# Source code for tart.representation.train

import os
import json
from argparse import Namespace
from typing import Callable
from rich.console import Console

import torch
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
from deepsnap.batch import Batch

from tart.representation.dataset import Corpus
from tart.representation.encoders import get_feature_encoder
from tart.representation.test import validation
from tart.representation import config, models, dataset
from tart.utils.model_utils import build_model, build_optimizer, get_device
from tart.utils.train_utils import set_parser, init_logger, start_workers, make_validation_set
from tart.utils.config_utils import validate_feat_encoder
from tart.utils.tart_utils import print_header, summarize_tart_run

# Select PyTorch's "file_system" strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid exhausting file descriptors when many
# tensors are shared with the training workers — confirm.
torch.multiprocessing.set_sharing_strategy("file_system")
# Module-wide rich console used for the styled progress output below.
console = Console()


def train(args: Namespace, model: nn.Module, corpus: Corpus, in_queue: mp.Queue, out_queue: mp.Queue):
    """Worker loop: run one optimisation step per message received on ``in_queue``.

    For every batch drawn from the corpus the worker first waits for a message
    on ``in_queue``. A ``"done"`` message ends the loop; any other message
    triggers one training step, after which ``("step", (loss, accuracy))`` is
    put on ``out_queue``. When a data loader is exhausted a new one is created.

    NOTE: This function is called by each worker process.

    Args:
        args (Namespace): tart configs (``batch_size`` and ``lr`` are read here)
        model (nn.Module): tart model to train; must expose ``encoder``,
            ``classifier``, ``criterion`` and ``predict``
        corpus (Corpus): corpus generator for graphs
        in_queue (mp.Queue): multiprocessing queue for input
        out_queue (mp.Queue): multiprocessing queue for output
    """
    scheduler, opt = build_optimizer(args, model.parameters())
    # The classifier head is stepped by its own Adam optimiser.
    clf_opt = optim.Adam(model.classifier.parameters(), lr=args.lr)
    clf_criterion = nn.NLLLoss()

    while True:
        for batch in corpus.gen_data_loader(args.batch_size, train=True):
            # Block until the coordinator asks for a step (or tells us to stop).
            msg, _ = in_queue.get()
            if msg == "done":
                return

            model.train()
            model.zero_grad()

            pos_a, pos_b, neg_a, neg_b = zip(*batch)

            # Convert each group of graphs to a DeepSnap Batch on the device.
            pos_a, pos_b, neg_a, neg_b = [
                Batch.from_data_list(graphs).to(get_device())
                for graphs in (pos_a, pos_b, neg_a, neg_b)
            ]

            # Embeddings in order: pos target, pos query, neg target, neg query.
            emb_pos_a, emb_pos_b, emb_neg_a, emb_neg_b = [
                model.encoder(graphs) for graphs in (pos_a, pos_b, neg_a, neg_b)
            ]

            # Stack positives before negatives; labels follow the same layout.
            emb_as = torch.cat((emb_pos_a, emb_neg_a), dim=0)
            emb_bs = torch.cat((emb_pos_b, emb_neg_b), dim=0)
            positives = [1] * pos_a.num_graphs
            negatives = [0] * neg_a.num_graphs
            labels = torch.tensor(positives + negatives).to(get_device())

            # Forward pass and loss of the embedding model.
            raw_pred = model(emb_as, emb_bs)
            loss = model.criterion(raw_pred, labels)

            # Backward pass with gradient-norm clipping at 1.0.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            if scheduler:
                scheduler.step()

            # The classifier is trained on predictions computed without
            # gradient tracking, so only the classifier receives gradients.
            with torch.no_grad():
                scores = model.predict(raw_pred)

            model.classifier.zero_grad()
            clf_out = model.classifier(scores.unsqueeze(1))
            clf_loss = clf_criterion(clf_out, labels)
            clf_loss.backward()
            clf_opt.step()

            # Metrics reported back to the coordinator.
            hard_pred = clf_out.argmax(dim=-1)
            acc = (hard_pred == labels).type(torch.float).mean()
            train_loss = loss.item()
            train_acc = acc.item()

            out_queue.put(("step", (train_loss, train_acc)))
def train_loop(args: Namespace, feat_encoder: Callable):
    """Train the model for a number of iterations by spinning up worker processes.

    Each iteration starts a fresh set of workers running ``train``, drives them
    through ``args.n_batches // args.eval_interval`` epochs by posting ``"step"``
    messages on the input queue, logs the loss/accuracy every worker reports,
    runs ``validation`` after each epoch, and finally sends one ``"done"``
    message per worker before joining them.

    Args:
        args (Namespace): tart configs
        feat_encoder (Callable): encoder function that converts string to torch.tensor

    Raises:
        AssertionError: if ``args.n_iters`` is not positive or
            ``args.n_batches // args.eval_interval`` is zero.
    """
    # ``os.path.dirname`` returns "" when ``model_path`` is a bare filename, and
    # ``os.makedirs("")`` raises, so only create the directory when there is one.
    # ``exist_ok=True`` replaces the racy exists-then-create check.
    model_dir = os.path.dirname(args.model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    os.makedirs("plots/", exist_ok=True)

    in_queue, out_queue = mp.Queue(), mp.Queue()

    # init logger
    logger = init_logger(args)

    # build model
    model = build_model(models.SubgraphEmbedder, args)
    model.share_memory()
    model = model.to(get_device())

    # create a corpus for train and test
    corpus = dataset.Corpus(args, feat_encoder, train=(not args.test))

    # create validation points
    loader = corpus.gen_data_loader(args.batch_size, train=False)

    summarize_tart_run(args)

    # ====== TRAINING ======
    validation_pts = make_validation_set(loader)

    assert args.n_iters > 0, "Number of iterations must be greater than 0"
    assert args.n_batches // args.eval_interval > 0, "Number of epochs per iteration must be greater than 0"

    n_epochs = args.n_batches // args.eval_interval

    for iteration in range(args.n_iters):
        console.print(f"\n[bright_green underline]Iteration #{iteration}[/bright_green underline]\n")

        workers = start_workers(train, model, corpus, in_queue, out_queue, args)

        # Global step counter for the logger; restarts with every iteration.
        batch_n = 0
        for epoch in range(n_epochs):
            console.print(f"\n[bright_blue]=============== Epoch #{epoch} ===============[/ bright_blue]")

            # Request one training step per mini-batch of this epoch ...
            for _ in range(args.eval_interval):
                in_queue.put(("step", None))

            # ... then collect exactly as many results from the workers.
            for _ in range(args.eval_interval):
                _, result = out_queue.get()
                train_loss, train_acc = result

                console.print(f"Batch {batch_n}. Loss: {train_loss:.4f}. Train acc: {train_acc:.4f}\n")

                logger.add_scalar("Loss(train)", train_loss, batch_n)
                logger.add_scalar("Acc(train)", train_acc, batch_n)
                batch_n += 1

            # validation after an epoch
            validation(args, model, validation_pts, logger, batch_n, epoch)

        # One "done" message per worker so each of them leaves its loop.
        for _ in range(args.n_workers):
            in_queue.put(("done", None))

        for worker in workers:
            worker.join()
def _set_split_sizes(args: Namespace) -> None:
    """Set ``n_train`` and ``n_test`` on ``args`` in place from its batch settings."""
    args.n_train = args.n_batches * args.batch_size
    # The test set is sized at 20% of the training set.
    args.n_test = int(0.2 * args.n_train)


def tart_train(user_config_file: str, tune: bool = False, trials: int = 0) -> None:
    """tart's train API

    Reads the user config, builds the argument parser with its defaults,
    resolves the feature encoder and then runs ``train_loop`` — once for a
    plain run, or once per sampled configuration when tuning.

    Args:
        user_config_file (str): config file path (JSON)
        tune (bool, optional): flag to perform hyperparam tuning. Defaults to False.
        trials (int, optional): #trials for hyperparam tuning. Defaults to 0;
            tuning only takes place when ``tune`` is set and ``trials`` is positive,
            otherwise a single regular training run is performed.
    """
    print_header()

    parser = set_parser(tune)

    # reading user config from json file
    # JSON is read as UTF-8 explicitly rather than with the locale's encoding.
    with open(user_config_file, encoding="utf-8") as f:
        config_json = json.load(f)

    # build configs and their defaults
    config.build_optimizer_configs(parser)
    config.build_model_configs(parser)
    config.build_feature_configs(parser)

    if tune:
        config.make_tunable(parser, config_json["tunable"])

    args = parser.parse_args()

    # set user defined configs
    args = config.init_user_configs(args, config_json)

    # set feature encoder
    feat_encoder = get_feature_encoder(args.feat_encoder)
    validate_feat_encoder(feat_encoder, config_json)

    _set_split_sizes(args)

    if tune and trials > 0:
        for i, trial_args in enumerate(args.trials(trials)):
            console.print(f"\n[bright_green underline]Hyperparameter Tuning Trial #{i}[/bright_green underline]\n")

            # Re-apply the user config on top of the sampled trial arguments.
            trial_args = config.init_user_configs(trial_args, config_json, tune=True)
            _set_split_sizes(trial_args)

            print(f"Trial args: {trial_args}")
            train_loop(trial_args, feat_encoder)
    else:
        train_loop(args, feat_encoder)