Source code for tart.utils.train_utils

import argparse
from typing import List, Tuple, Callable
from argparse import Namespace
from rich.progress import track, Progress, TextColumn, SpinnerColumn

import torch
import torch.multiprocessing as mp
from deepsnap.batch import Batch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from test_tube import HyperOptArgumentParser

from tart.representation.dataset import Corpus


def set_parser(tune: bool) -> argparse.ArgumentParser:
    """Create the command-line argument parser

    Args:
        tune (bool): when truthy, build a ``HyperOptArgumentParser`` that
            uses the ``grid_search`` strategy; otherwise build a plain
            ``argparse.ArgumentParser``

    Returns:
        argparse.ArgumentParser: a newly created parser with no arguments
            registered yet
    """
    if tune:
        return HyperOptArgumentParser(strategy="grid_search")
    return argparse.ArgumentParser()
def init_logger(args: Namespace) -> SummaryWriter:
    """Initialize tensorboard logger

    The writer's ``comment`` is assembled from the config entries whose
    names appear in the logged-key list below, rendered as ``key=value``
    pairs, ordered by key and joined with ``.``. Config entries that are
    not in that list are ignored.

    Args:
        args (Namespace): tart configs

    Returns:
        SummaryWriter: tensorboard logger
    """
    log_keys = (
        "conv_type",
        "n_layers",
        "hidden_dim",
        "margin",
        "dataset",
        "max_graph_size",
        "skip",
    )
    config = vars(args)
    # Sorting the keys keeps the comment stable regardless of the order in
    # which the options were set on the namespace.
    log_str = ".".join(f"{key}={config[key]}" for key in sorted(config) if key in log_keys)
    return SummaryWriter(comment=log_str)
def start_workers(
    train_func: Callable,
    model: torch.nn.Module,
    corpus: Corpus,
    in_queue: mp.Queue,
    out_queue: mp.Queue,
    args: Namespace,
) -> List[mp.Process]:
    """Spawn the training worker processes

    One process is created and started per ``args.n_workers``; each runs
    ``train_func(args, model, corpus, in_queue, out_queue)``. A transient
    spinner is displayed while the processes are being started.

    Args:
        train_func (Callable): function executed by every worker
        model (torch.nn.Module): tart model handed to each worker
        corpus (Corpus): dataset handed to each worker
        in_queue (mp.Queue): mp queue for input
        out_queue (mp.Queue): mp queue for output
        args (Namespace): tart configs (``n_workers`` is read here)

    Returns:
        List[mp.Process]: the started processes, in start order
    """
    # Every worker receives the same positional arguments.
    process_args = (args, model, corpus, in_queue, out_queue)
    spinner = Progress(
        SpinnerColumn(),
        TextColumn("Starting workers..{task.description}"),
        transient=True,
    )

    started = []
    with spinner as progress:
        # total=None gives an indeterminate task (no known step count).
        progress.add_task("", total=None)
        for _ in range(args.n_workers):
            proc = mp.Process(target=train_func, args=process_args)
            proc.start()
            started.append(proc)

    return started
def make_validation_set(
    dataloader: DataLoader,
) -> List[Tuple[Batch, Batch, Batch, Batch]]:
    """Collate every batch of a dataloader into a validation set

    Each item yielded by ``dataloader`` is transposed with ``zip(*batch)``
    into four groups, and every group is merged into a single ``Batch``
    via ``Batch.from_data_list``. A progress bar is shown while iterating.

    Args:
        dataloader (DataLoader): dataloader for validation set

    Returns:
        List[Tuple[Batch, Batch, Batch, Batch]]: one tuple per dataloader
            item, laid out as (pos_q, pos_t, neg_q, neg_t)
    """
    collated = []
    progress_iter = track(dataloader, total=len(dataloader), description="TestBatches")
    for samples in progress_iter:
        # Transpose the per-sample 4-tuples into four parallel groups.
        q_pos, t_pos, q_neg, t_neg = zip(*samples)
        collated.append(
            (
                Batch.from_data_list(q_pos),
                Batch.from_data_list(t_pos),
                Batch.from_data_list(q_neg),
                Batch.from_data_list(t_neg),
            )
        )
    return collated