Source code for tart.inference.predict

# predict API for tart

import os.path as osp
import json
import glob
import argparse
from argparse import Namespace
from typing import List, Tuple, Optional
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn

import numpy as np
import torch
import torch.nn as nn
from deepsnap.batch import Batch

from tart.representation.encoders import get_feature_encoder
from tart.representation import config, models
from tart.inference.embed import get_neighborhoods
from tart.utils.model_utils import build_model, get_device
from tart.utils.graph_utils import read_graph_from_json, featurize_graph
from tart.utils.tart_utils import print_header


console = Console()


[docs]def search_space_sample(src_dir: str, k: Optional[int] = None, seed: int = 24) -> Tuple[List[str], List[str]]: """sample k graphs from search space Args: src_dir (str): directory containing search space graphs k (int, optional): number of graphs to sample. Defaults to None. seed (int, optional): random seed. Defaults to 24. Returns: Tuple[List[str], List[str]]: list of sampled files and their indices """ np.random.seed(seed) files = [f for f in sorted(glob.glob(osp.join(src_dir, "*.pt")))] if k is None: k = len(files) random_files = np.random.choice(files, min(len(files), k)) random_index = [f.split("_")[-1][:-3] for f in random_files] return random_files, random_index
[docs]def read_embedding(args: Namespace, idx: str) -> torch.Tensor: """read embedding of a graph from disk Args: args (Namespace): tart configs idx (str): index of graph Returns: torch.Tensor: embedding of graph """ emb_path = f"emb_{idx}.pt" emb_path = osp.join(args.emb_dir, emb_path) return torch.load(emb_path, map_location=torch.device("cpu"))
[docs]def load_search_space(args: Namespace, file_indices: List[str]) -> List[torch.Tensor]: """load embeddings of search space graphs into a list of batches Args: args (Namespace): tart configs file_indices (List[str]): list of indices of search space graphs Returns: List[torch.Tensor]: list of batched embeddings """ embs, batch_embs = [], [] count = 0 for i, idx in enumerate(file_indices): batch_embs.append(read_embedding(args, idx)) if i > 0 and i % args.batch_size == 0: embs.append(torch.cat(batch_embs, dim=0)) count += len(batch_embs) batch_embs = [] # add remaining embs as a batch if len(batch_embs) > 0: embs.append(torch.cat(batch_embs, dim=0)) count += len(batch_embs) assert count == len(file_indices) return embs
[docs]def predict_neighs_batched(model: nn.Module, search_embs: List[torch.Tensor], query_emb: torch.Tensor) -> int: """(batched) predict number of neighborhoods in which query graph (query_emb) is subgraph of search graphs (embs) Args: model (nn.Module): tart model search_embs (List[torch.Tensor]): list of batched embeddings of search space graphs query_emb (torch.Tensor): embedding of query graph Returns: int: score = number of neighborhoods in which query graph is subgraph of search graphs """ score = 0 with Progress( SpinnerColumn(), TextColumn("Searching for subgraphs..{task.description}"), transient=True, ) as progress: progress.add_task("", total=len(search_embs)) for emb_batch in search_embs: with torch.no_grad(): predictions, _ = model.predictv2((emb_batch.to(get_device()), query_emb)) score += torch.sum(predictions).item() return score
[docs]def tart_predict( user_config_file: str, query_json: str, search_space_path: str, search_sample: Optional[int] = None, outcome: str = "count_subgraphs", ): """predict API for tart Args: user_config_file (str): config file path query_json (str): path to query graph json file search_space_path (str): path to search space directory search_sample (Optional[int], optional): number of graphs to sample from search space. Defaults to None. outcome (str, optional): prediction outcome = {"count_subgraphs", "is_subgraph"}. Defaults to "count_subgraphs". """ print_header() console.print("[bright_green underline]Prediction API[/ bright_green underline]\n") parser = argparse.ArgumentParser() # is_subgraph expects search_space to be a single graph if outcome == "is_subgraph" and osp.isdir(search_space_path): raise ValueError("is_subgraph expects search_space_path to point to a single graph") # count_subgraph expects search_space to be a directory if outcome == "count_subgraphs" and not osp.isdir(search_space_path): raise ValueError("count_subgraphs expects search_space_path to point to a directory") # reading user config from json file with open(user_config_file) as f: config_json = json.load(f) # build configs and their defaults config.build_optimizer_configs(parser) config.build_model_configs(parser) config.build_feature_configs(parser) args = parser.parse_args() # set to test mode args.test = True # set user defined configs args = config.init_user_configs(args, config_json) # set feature encoder feat_encoder = get_feature_encoder(args.feat_encoder) # set search space embeddings directory args.emb_dir = search_space_path # featurize the query graph query_graph = read_graph_from_json(args, query_json) console.print(f"Query graph: {query_graph}") query_feat = featurize_graph(args, feat_encoder, query_graph, anchor=0) query_tensor = Batch.from_data_list([query_feat]).to(get_device()) # build model model = build_model(models.SubgraphEmbedder, args) # embed query graph query_emb = model.encoder(query_tensor) # ======= PREDICT ========= if outcome == "count_subgraphs": # load search space embeddings _, file_indices = search_space_sample(search_space_path, k=search_sample, seed=4) search_embs = load_search_space(args, file_indices) console.print(f"Search space: {len(file_indices)} graphs loaded.") # predict number of neighborhoods score = predict_neighs_batched(model, search_embs, query_emb) console.print(f"Number of subgraph (neighs): {score}") # NOTE: returns number of nodes that have subgraphs # rooted at the node that are isomorphic to query graph. # TODO: count number of graphs in which query was found as a subG. elif outcome == "is_subgraph": # predict if query graph is subgraph of search graph # featurize the search graph (like in embed.py) search_graph = read_graph_from_json(args, query_json) console.print(f"Search graph: {search_graph}") search_neighs = get_neighborhoods(args, search_graph, feat_encoder) search_embs = model.encoder(Batch.from_data_list(search_neighs).to(get_device())) # predict if subgraph score = predict_neighs_batched(model, [search_embs], query_emb) if score > 0.0: console.print(f"Query graph is a subgraph of search graph (score = {score}).") else: console.print(f"Query graph is not a subgraph of search graph!") else: raise ValueError("Invalid outcome. Please choose from: count_subgraphs, is_subgraph")