tart.inference
tart.inference.embed module
- tart.inference.embed.embed_main(args)
Pipeline to generate embeddings for a dataset of graphs.
- Parameters:
args (Namespace) – tart configs
- tart.inference.embed.generate_embeddings(args, model, in_queue, out_queue)
Generates embeddings for each node in the graph. NOTE: This function is called by each worker process.
- Parameters:
args (Namespace) – tart configs
model (nn.Module) – tart model to generate embeddings
in_queue (mp.Queue) – multiprocessing queue for input
out_queue (mp.Queue) – multiprocessing queue for output
- tart.inference.embed.generate_neighborhoods(args, in_queue, out_queue)
Generates neighborhoods for each node in the graph. NOTE: This function is called by each worker process.
- Parameters:
args (Namespace) – tart configs
in_queue (mp.Queue) – multiprocessing queue for input
out_queue (mp.Queue) – multiprocessing queue for output
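`generate_embeddings` and `generate_neighborhoods` are both documented as running inside each worker process, reading from `in_queue` and writing to `out_queue`. The following is a minimal sketch of such a consumer loop, not tart's code: the name `worker_loop`, the `(idx, payload)` item layout, and the `None` shutdown sentinel are assumptions, and `transform` stands in for either the model forward pass or neighborhood extraction.

```python
from typing import Any, Callable


def worker_loop(transform: Callable[[Any], Any], in_queue: Any, out_queue: Any) -> None:
    """Hypothetical worker: consume (idx, payload) items until a None sentinel.

    Works with any queue exposing get()/put(), e.g. multiprocessing.Queue
    or queue.Queue.
    """
    while True:
        item = in_queue.get()
        if item is None:  # assumed shutdown sentinel
            break
        idx, payload = item
        # Send the index back with the result so the consumer can match
        # results to inputs even when several workers finish out of order.
        out_queue.put((idx, transform(payload)))
```

Keeping `idx` next to each result matters because several workers drain the same queue, so results can arrive in a different order than the inputs.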
- tart.inference.embed.get_neighborhoods(args, graph, feat_encoder)
Returns a featurized (sampled) radial neighborhood for every node in a graph.
- Parameters:
args (Namespace) – tart configs
graph (nx.Graph) – graph to find neighborhoods for
feat_encoder (Callable) – feature encoder for graph nodes
- Return type:
List
- Returns:
List – list of featurized neighborhoods
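A radial neighborhood is the set of nodes within a fixed number of hops of an anchor node. The sketch below illustrates one way to build featurized neighborhoods; it is an assumption-laden stand-in, not tart's implementation: it uses an adjacency dict instead of `nx.Graph`, a breadth-first search up to `radius` hops, and optional uniform sampling down to `max_size` nodes. `radius`, `max_size`, `seed`, and the output dict layout are hypothetical parameters, not tart config keys.

```python
import random
from collections import deque
from typing import Any, Callable, Dict, Hashable, List, Optional


def radial_neighborhood(adj: Dict[Hashable, List[Hashable]], anchor: Hashable, radius: int) -> List[Hashable]:
    """Return the nodes within `radius` hops of `anchor`, in BFS order."""
    depth = {anchor: 0}
    order = [anchor]
    frontier = deque([anchor])
    while frontier:
        node = frontier.popleft()
        if depth[node] == radius:
            continue  # do not expand past the radius
        for nbr in adj.get(node, ()):
            if nbr not in depth:
                depth[nbr] = depth[node] + 1
                order.append(nbr)
                frontier.append(nbr)
    return order


def get_neighborhoods_sketch(
    adj: Dict[Hashable, List[Hashable]],
    feat_encoder: Callable[[Hashable], Any],
    radius: int = 2,
    max_size: Optional[int] = None,
    seed: int = 0,
) -> List[Dict[str, Any]]:
    """Hypothetical: one featurized (optionally sampled) neighborhood per node."""
    if max_size is not None and max_size < 1:
        raise ValueError("max_size must be >= 1")
    rng = random.Random(seed)
    out = []
    for anchor in adj:
        nodes = radial_neighborhood(adj, anchor, radius)
        if max_size is not None and len(nodes) > max_size:
            # Always keep the anchor; sample the remaining nodes uniformly.
            nodes = [anchor] + rng.sample(nodes[1:], max_size - 1)
        out.append({"anchor": anchor, "nodes": nodes,
                    "feats": [feat_encoder(n) for n in nodes]})
    return out
```

Keeping the anchor at position 0 after sampling preserves which node the neighborhood is centered on.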
- tart.inference.embed.start_workers_embed(model, in_queue, out_queue, args)
Starts worker processes for generating embeddings.
- Parameters:
model (nn.Module) – tart model to embed graphs
in_queue (mp.Queue) – multiprocessing queue for input
out_queue (mp.Queue) – multiprocessing queue for output
args (Namespace) – tart configs
- Return type:
List[Process]
- Returns:
List[mp.Process] – list of worker processes
- tart.inference.embed.start_workers_process(in_queue, out_queue, args)
Starts worker processes for generating neighborhoods.
- Parameters:
in_queue (mp.Queue) – multiprocessing queue for input
out_queue (mp.Queue) – multiprocessing queue for output
args (Namespace) – tart configs
- Return type:
List[Process]
- Returns:
List[mp.Process] – list of worker processes
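`start_workers_embed` and `start_workers_process` both return a list of `mp.Process` objects. A minimal sketch of that pattern follows, assuming the worker target takes the two queues as its final arguments. The names `start_workers` and `square_worker`, the `n_workers` and `extra_args` arguments, and the `ctx` parameter (a `multiprocessing` context, so callers can pick the start method) are hypothetical.

```python
import multiprocessing as mp
from multiprocessing.process import BaseProcess
from typing import Any, Callable, List, Optional, Sequence


def start_workers(
    target: Callable[..., None],
    in_queue: Any,
    out_queue: Any,
    n_workers: int,
    extra_args: Sequence[Any] = (),
    ctx: Optional[Any] = None,
) -> List[BaseProcess]:
    """Hypothetical: start n_workers processes running
    target(*extra_args, in_queue, out_queue)."""
    ctx = ctx or mp.get_context()  # default start method unless one is given
    workers = []
    for _ in range(n_workers):
        p = ctx.Process(target=target, args=(*extra_args, in_queue, out_queue))
        p.daemon = True  # terminate with the parent if it exits early
        p.start()
        workers.append(p)
    return workers


def square_worker(in_queue: Any, out_queue: Any) -> None:
    """Toy worker: squares (idx, x) items until a None sentinel arrives."""
    while True:
        item = in_queue.get()
        if item is None:
            break
        idx, x = item
        out_queue.put((idx, x * x))
```

Callers typically enqueue one sentinel per worker and drain `out_queue` before calling `join()`, because a process that still has buffered items in a `multiprocessing.Queue` may not terminate until they are consumed.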
tart.inference.predict module
- tart.inference.predict.load_search_space(args, file_indices)
Loads the embeddings of the search-space graphs into a list of batches.
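Grouping the stored embeddings into batches lets the scoring step process many search-space graphs per model call. The sketch below assumes, hypothetically, that each graph's embedding is saved as a JSON list of floats named `<idx>.json` in a single directory, and it uses plain lists in place of `torch.Tensor`. tart's real on-disk format and batch size are not documented here; `read_embedding_json`, `load_search_space_sketch`, and `batch_size` are illustrative names.

```python
import json
import os
from typing import List, Sequence

Embedding = List[float]


def read_embedding_json(src_dir: str, idx: str) -> Embedding:
    """Hypothetical stand-in for read_embedding: load <src_dir>/<idx>.json."""
    with open(os.path.join(src_dir, f"{idx}.json")) as fh:
        return json.load(fh)


def load_search_space_sketch(src_dir: str, file_indices: Sequence[str],
                             batch_size: int = 64) -> List[List[Embedding]]:
    """Read one embedding per index, then chunk them into batches."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    embs = [read_embedding_json(src_dir, idx) for idx in file_indices]
    # The final batch may be smaller than batch_size.
    return [embs[i:i + batch_size] for i in range(0, len(embs), batch_size)]
```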
- tart.inference.predict.predict_neighs_batched(model, search_embs, query_emb)
(Batched) Predicts the number of neighborhoods in which the query graph (query_emb) is a subgraph of the search graphs (search_embs).
- Parameters:
model (nn.Module) – tart model
search_embs (List[torch.Tensor]) – list of batched embeddings of search space graphs
query_emb (torch.Tensor) – embedding of query graph
- Return type:
int
- Returns:
int – score, i.e. the number of neighborhoods in which the query graph is a subgraph of the search graphs
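The score counts the search-space neighborhoods that the model judges to contain the query. The sketch below assumes an order-embedding rule of the kind used by NeuroMatch: the query is predicted to be a subgraph of a target when the violation penalty sum(max(0, q_i - t_i)^2) is at most a threshold. The penalty, the `threshold` value, the function names, and the list-of-lists batches (instead of `torch.Tensor`) are all assumptions; tart's model may use a different prediction rule.

```python
from typing import Sequence


def order_violation(query: Sequence[float], target: Sequence[float]) -> float:
    """Order-embedding penalty: zero iff query <= target elementwise."""
    return sum(max(0.0, q - t) ** 2 for q, t in zip(query, target))


def count_subgraph_neighborhoods(search_batches: Sequence[Sequence[Sequence[float]]],
                                 query: Sequence[float],
                                 threshold: float = 0.1) -> int:
    """Count target embeddings whose violation w.r.t. the query is <= threshold."""
    score = 0
    for batch in search_batches:
        score += sum(1 for target in batch if order_violation(query, target) <= threshold)
    return score
```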
- tart.inference.predict.read_embedding(args, idx)
Reads the embedding of a graph from disk.
- Parameters:
args (Namespace) – tart configs
idx (str) – index of graph
- Return type:
Tensor
- Returns:
torch.Tensor – embedding of graph
- tart.inference.predict.search_space_sample(src_dir, k=None, seed=24)
Samples k graphs from the search space.
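- Parameters:
src_dir – path to the search space directory
k – number of graphs to sample
seed – random seed for sampling. Defaults to 24.
- Return type:
Tuple[List[str], List[str]]
- Returns:
Tuple[List[str], List[str]] – lists of sampled files and their indices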
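A reproducible sample of `k` graphs can be drawn with a seeded `random.Random`, so repeated calls with the same `seed` return the same subset. The sketch assumes one file per graph in `src_dir`, that a graph's index is its file name without extension, and that `k=None` means "use every file"; these conventions and the name `search_space_sample_sketch` are guesses, not confirmed by the reference above.

```python
import os
import random
from typing import List, Optional, Tuple


def search_space_sample_sketch(src_dir: str, k: Optional[int] = None,
                               seed: int = 24) -> Tuple[List[str], List[str]]:
    """Hypothetical: sample k files from src_dir and derive their indices."""
    # Sort first so the sample depends only on the seed and directory contents,
    # not on the (unspecified) order returned by os.listdir.
    files = sorted(f for f in os.listdir(src_dir)
                   if os.path.isfile(os.path.join(src_dir, f)))
    if k is not None and k < len(files):
        files = random.Random(seed).sample(files, k)
    indices = [os.path.splitext(f)[0] for f in files]
    return files, indices
```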
- tart.inference.predict.tart_predict(user_config_file, query_json, search_space_path, search_sample=None, outcome='count_subgraphs')
Prediction API for tart.
- Parameters:
user_config_file (str) – config file path
query_json (str) – path to query graph json file
search_space_path (str) – path to search space directory
search_sample (Optional[int], optional) – number of graphs to sample from the search space. Defaults to None.
outcome (str, optional) – prediction outcome, one of {“count_subgraphs”, “is_subgraph”}. Defaults to “count_subgraphs”.