import os
import glob
import json
import os.path as osp
from argparse import Namespace
from typing import Callable, List
from rich.progress import track, Progress, TextColumn, SpinnerColumn
from rich.console import Console
import networkx as nx
import torch
import torch.nn as nn
import argparse
from deepsnap.batch import Batch
import torch.multiprocessing as mp
from tart.representation.encoders import get_feature_encoder
from tart.representation import config, models
from tart.utils.model_utils import build_model, get_device
from tart.utils.graph_utils import read_graph_from_json, featurize_graph
# Module-level rich console; functions in this module call `console.print`
# on it to emit status messages.
console = Console()
# ########## MULTI PROC ##########
def start_workers_process(in_queue: mp.Queue, out_queue: mp.Queue, args: Namespace) -> List[mp.Process]:
    """Launch the worker processes that generate node neighborhoods.

    Starts ``args.n_workers`` processes, each running
    ``generate_neighborhoods(args, in_queue, out_queue)``. A transient
    spinner is displayed while the processes are being started.

    Args:
        in_queue (mp.Queue): queue the workers read their jobs from
        out_queue (mp.Queue): queue the workers report completion on
        args (Namespace): tart configs; ``n_workers`` is read here

    Returns:
        List[mp.Process]: the started worker processes, in start order
    """
    procs: List[mp.Process] = []

    with Progress(
        SpinnerColumn(),
        TextColumn("Starting workers..{task.description}"),
        transient=True,
    ) as spinner:
        # total=None => indeterminate task (nothing to count, just spin)
        spinner.add_task("", total=None)

        for _ in range(args.n_workers):
            proc = mp.Process(
                target=generate_neighborhoods,
                args=(args, in_queue, out_queue),
            )
            proc.start()
            procs.append(proc)

    return procs
def start_workers_embed(model: nn.Module, in_queue: mp.Queue, out_queue: mp.Queue, args: Namespace) -> List[mp.Process]:
    """Launch the worker processes that generate embeddings.

    Starts ``args.n_workers`` processes, each running
    ``generate_embeddings(args, model, in_queue, out_queue)``. A transient
    spinner is displayed while the processes are being started.

    Args:
        model (nn.Module): tart model handed to every worker
        in_queue (mp.Queue): queue the workers read their jobs from
        out_queue (mp.Queue): queue the workers report completion on
        args (Namespace): tart configs; ``n_workers`` is read here

    Returns:
        List[mp.Process]: the started worker processes, in start order
    """
    procs: List[mp.Process] = []

    with Progress(
        SpinnerColumn(),
        TextColumn("Starting workers..{task.description}"),
        transient=True,
    ) as spinner:
        # total=None => indeterminate task (nothing to count, just spin)
        spinner.add_task("", total=None)

        for _ in range(args.n_workers):
            proc = mp.Process(
                target=generate_embeddings,
                args=(args, model, in_queue, out_queue),
            )
            proc.start()
            procs.append(proc)

    return procs
# ########## UTILITIES ##########
def get_neighborhoods(args: Namespace, graph: nx.Graph, feat_encoder: Callable) -> List:
    """Build a featurized (sampled) radial neighborhood around every node.

    For each node, all nodes within ``args.emb_sssp_radius`` hops are
    collected and ordered by hop distance. If ``args.emb_subg_sample_size``
    is non-zero, only that many closest nodes are kept. Neighborhoods with
    a single node are skipped, so the result may be shorter than the number
    of nodes in ``graph``.

    Args:
        args (Namespace): tart configs (``emb_sssp_radius`` and
            ``emb_subg_sample_size`` are read here; ``args`` is also
            forwarded to ``featurize_graph``)
        graph (nx.Graph): graph to find neighborhoods for
        feat_encoder (Callable): feature encoder for graph nodes

    Returns:
        List: featurized neighborhoods, in node-iteration order
    """
    radius = args.emb_sssp_radius
    sample_size = args.emb_subg_sample_size
    featurized = []

    for node in graph.nodes:
        # hop distance from `node` to every node within `radius` (SSSP)
        hops = nx.single_source_shortest_path_length(graph, node, cutoff=radius)
        by_distance = sorted(hops.items(), key=lambda item: item[1])
        members = [member for member, _ in by_distance]

        if sample_size != 0:
            # NOTE: randomly sampling the radius-hop neighbors can yield
            # nodes with no edges between them. Keeping the K closest
            # nodes (list is already ordered by hops) avoids that.
            members = members[:sample_size]

        if len(members) <= 1:
            continue

        # NOTE: G.subgraph(nodes) is the subgraph induced on `nodes`, i.e.
        # those nodes plus every edge between them => here, the (sampled)
        # radial neighborhood of `node`.
        induced = graph.subgraph(members)
        featurized.append(featurize_graph(args, feat_encoder, induced, anchor=0))

    return featurized
# ########## PIPELINE FUNCTIONS ##########
def generate_embeddings(args: Namespace, model: nn.Module, in_queue: mp.Queue, out_queue: mp.Queue):
    """Generates embeddings for the neighborhoods of each processed graph.

    NOTE: This function is called by each worker process.

    Jobs are read from ``in_queue`` as ``(msg, idx)`` tuples until a
    ``"done"`` message arrives. For every other job the neighborhoods saved
    at ``<args.proc_dir>/data_<idx>.pt`` are loaded, batched, encoded with
    ``model.encoder`` and written to ``<args.emb_dir>/emb_<idx>.pt``.
    ``"complete"`` is put on ``out_queue`` once per job, whether the job was
    embedded or skipped, so the consumer can count finished jobs.

    Args:
        args (Namespace): tart configs (``proc_dir`` and ``emb_dir`` are read)
        model (nn.Module): tart model to generate embeddings
        in_queue (mp.Queue): multiprocessing queue for input
        out_queue (mp.Queue): multiprocessing queue for output
    """
    while True:
        msg, idx = in_queue.get()
        if msg == "done":
            break

        # read only graphs of processed programs: a job whose processed
        # file cannot be loaded is reported as complete and skipped.
        # `Exception` (rather than a bare `except`) keeps this best-effort
        # skip while letting KeyboardInterrupt / SystemExit stop the worker.
        # NOTE(review): every load failure is skipped silently, so a
        # systematic failure (e.g. an incompatible torch.load default)
        # would look like "nothing to embed" -- consider logging here.
        try:
            neighs = torch.load(osp.join(args.proc_dir, f"data_{idx}.pt"))
        except Exception:
            out_queue.put("complete")
            continue

        # inference only: no autograd bookkeeping needed
        with torch.no_grad():
            emb = model.encoder(Batch.from_data_list(neighs).to(get_device()))
            torch.save(emb, osp.join(args.emb_dir, f"emb_{idx}.pt"))

        out_queue.put("complete")
def generate_neighborhoods(args: Namespace, in_queue: mp.Queue, out_queue: mp.Queue):
    """Worker loop: turn raw graph files into saved node neighborhoods.

    NOTE: This function is called by each worker process.

    Jobs are read from ``in_queue`` as ``(msg, idx)`` tuples until a
    ``"done"`` message arrives. For every other job the graph stored at
    ``<args.raw_dir>/example_<idx>.json`` is read; if a graph is returned,
    the graph itself is saved to ``<args.graph_dir>/data_<idx>.pt`` and its
    neighborhoods to ``<args.proc_dir>/data_<idx>.pt``. ``"complete"`` is
    put on ``out_queue`` once per job, including jobs whose graph is
    ``None``.

    Args:
        args (Namespace): tart configs
        in_queue (mp.Queue): multiprocessing queue for input
        out_queue (mp.Queue): multiprocessing queue for output
    """
    encoder = get_feature_encoder(args.feat_encoder)

    while True:
        msg, idx = in_queue.get()
        if msg == "done":
            return

        source_file = osp.join(args.raw_dir, f"example_{idx}.json")
        graph = read_graph_from_json(args, source_file)

        if graph is not None:
            # save graph object for future apps like search
            torch.save(graph, osp.join(args.graph_dir, f"data_{idx}.pt"))

            # get neighborhoods of each node in the graph
            neighs = get_neighborhoods(args, graph, encoder)
            torch.save(neighs, osp.join(args.proc_dir, f"data_{idx}.pt"))

            # drop the references before blocking on the next job
            del graph, neighs

        out_queue.put("complete")
# ########## MAIN ##########
def _standardize_raw_filenames(raw_dir: str) -> int:
    """Rename every ``*.json`` file in ``raw_dir`` to ``example_<idx>.json``.

    Indices follow the lexicographically sorted original file names. The
    rename is safe to repeat: if the directory already holds exactly
    ``example_0.json`` .. ``example_<n-1>.json`` nothing is touched, and
    otherwise files are first moved to temporary names so that a target
    name never replaces a file that has not been renamed yet.

    Args:
        raw_dir (str): directory holding the raw graph JSON files

    Returns:
        int: number of JSON files found in ``raw_dir``
    """
    raw_paths = sorted(glob.glob(osp.join(raw_dir, "*.json")))
    n_files = len(raw_paths)

    # Already standardized (e.g. a re-run): renaming again would be
    # destructive, because "example_10.json" sorts before "example_2.json"
    # and would be renamed onto the still-existing "example_2.json".
    expected = {f"example_{idx}.json" for idx in range(n_files)}
    if {osp.basename(p) for p in raw_paths} == expected:
        return n_files

    # Phase 1: move everything out of the way under temporary names.
    staged = []
    for idx, path in enumerate(raw_paths):
        tmp_path = osp.join(raw_dir, f"__tart_tmp_{os.getpid()}_{idx}.json")
        os.rename(path, tmp_path)
        staged.append(tmp_path)

    # Phase 2: no original *.json name is left, so targets cannot collide.
    for idx, tmp_path in enumerate(staged):
        os.rename(tmp_path, osp.join(raw_dir, f"example_{idx}.json"))

    return n_files


def _run_jobs(in_queue: mp.Queue, out_queue: mp.Queue, workers: List[mp.Process], n_jobs: int, description: str) -> None:
    """Feed ``n_jobs`` index jobs to running workers and wait for them all.

    Puts ``("idx", i)`` for every job, consumes one message from
    ``out_queue`` per job (with a progress bar), then sends one
    ``("done", None)`` per worker and joins the workers.
    """
    for i in range(n_jobs):
        in_queue.put(("idx", i))

    # each worker reports exactly one message per job
    for _ in track(range(n_jobs), description=description):
        out_queue.get()

    for _ in workers:
        in_queue.put(("done", None))
    for worker in workers:
        worker.join()


def embed_main(args: Namespace):
    """Pipeline to generate embeddings for a dataset of graphs

    Phase 1 renames the raw JSON files to ``example_<idx>.json`` and has
    worker processes save each graph and its node neighborhoods. Phase 2
    builds the embedding model and has worker processes embed the saved
    neighborhoods.

    Args:
        args (Namespace): tart configs; ``raw_dir``, ``graph_dir``,
            ``proc_dir``, ``emb_dir`` and ``n_workers`` are read here

    Raises:
        AssertionError: if ``args.raw_dir`` is not an existing directory
    """
    # Raised explicitly (instead of via `assert`) so the check survives
    # `python -O`; AssertionError is kept for backward compatibility.
    if not osp.isdir(args.raw_dir):
        raise AssertionError("raw_dir does not exist!")

    for out_dir in (args.graph_dir, args.proc_dir, args.emb_dir):
        os.makedirs(out_dir, exist_ok=True)

    # ######### PHASE1: PROCESS GRAPHS #########

    # util: to rename .json files into a standard filename format
    # TODO: write to a txt file the mapping between index and original filename
    # TODO: should we write the renamed files to a tmp folder?
    n_graphs = _standardize_raw_filenames(args.raw_dir)

    in_queue, out_queue = mp.Queue(), mp.Queue()
    workers = start_workers_process(in_queue, out_queue, args)
    _run_jobs(in_queue, out_queue, workers, n_graphs, "Processing graphs")

    # ######### EMBED GRAPHS #########

    model = build_model(models.SubgraphEmbedder, args)
    model.share_memory()

    console.print(f"\n[bright_green]\\[tart] [/bright_green] Moving model to device: [bright_blue]{get_device()}[/bright_blue]\n")
    model = model.to(get_device())
    model.eval()

    in_queue, out_queue = mp.Queue(), mp.Queue()
    workers = start_workers_embed(model, in_queue, out_queue, args)
    _run_jobs(in_queue, out_queue, workers, n_graphs, "Embedding graphs")
def tart_embed(user_config_file: str):
    """tart's embed API

    Loads the user's JSON config, builds the argument parser with tart's
    optimizer, model and feature config groups, applies the user config on
    top of the parsed arguments, derives the working directories under
    ``<data_dir>/embed`` and runs ``embed_main``.

    Args:
        user_config_file (str): config file path
    """
    console.print("[bright_green underline]Embedding Search Space[/ bright_green underline]\n")

    # reading user config from json file
    with open(user_config_file) as config_fp:
        user_overrides = json.load(config_fp)

    # build configs and their defaults
    arg_parser = argparse.ArgumentParser()
    for register_group in (
        config.build_optimizer_configs,
        config.build_model_configs,
        config.build_feature_configs,
    ):
        register_group(arg_parser)

    # set user defined configs on top of the parsed ones
    args = config.init_user_configs(arg_parser.parse_args(), user_overrides)

    # set default file paths for results
    embed_root = osp.join(args.data_dir, "embed")
    for attr, subdir in (
        ("raw_dir", "raw"),
        ("graph_dir", "graphs"),
        ("proc_dir", "processed"),
        ("emb_dir", "embs"),
    ):
        setattr(args, attr, osp.join(embed_root, subdir))

    embed_main(args)