tart.utils

tart.utils.config_utils module

tart.utils.config_utils.validate_feat_encoder(user_feat_encoder, config_json)

Validate a user-defined feature encoder.

Parameters:
  • user_feat_encoder (Callable) – user-defined feature encoder

  • config_json (Dict) – user-defined configuration

Return type:

Callable

Returns:

Callable – the user-defined feature encoder
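The checks that validate_feat_encoder performs are not documented here. The sketch below is a hypothetical illustration, not tart's implementation: it assumes validation means confirming that the encoder is callable and that encoding a probe string yields as many values as a hypothetical config_json["feat_dim"] entry. Plain Python lists stand in for torch.tensor, and validate_feat_encoder_sketch is a made-up name.

```python
from typing import Any, Callable, Dict, List

Encoder = Callable[[str], List[float]]


def validate_feat_encoder_sketch(
    user_feat_encoder: Encoder, config_json: Dict[str, Any]
) -> Encoder:
    """Hypothetical checks: the encoder is callable and its output size
    matches config_json["feat_dim"] (an assumed config key)."""
    if not callable(user_feat_encoder):
        raise TypeError("feature encoder must be callable")
    expected_dim = config_json["feat_dim"]  # hypothetical config key
    probe = user_feat_encoder("probe")  # encode a sample string feature
    if len(probe) != expected_dim:
        raise ValueError(
            f"encoder returned {len(probe)} values, expected {expected_dim}"
        )
    # Return the encoder unchanged so callers can use it directly.
    return user_feat_encoder
```

Returning the encoder (rather than a boolean) matches the documented return type, so a caller can write encoder = validate_feat_encoder_sketch(encoder, config) in one step.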

tart.utils.graph_utils module

tart.utils.graph_utils.featurize_graph(args, feat_encoder, g, anchor=None)

Featurize a networkx graph into a DeepSnap graph:

  • all features are converted to torch.tensor and added to the {feat}_t key

  • string features are converted to torch.tensor by the encoder model

Parameters:
  • args (Namespace) – tart configs

  • feat_encoder (Callable) – encoder function that converts a string to a torch.tensor

  • g (nx.DiGraph) – networkx graph

  • anchor (_type_, optional) – anchor node id. Defaults to None.

Return type:

Graph

Returns:

DSGraph – DeepSnap graph
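As a rough, stdlib-only sketch of the {feat}_t convention described above, assume node attributes are held in plain dicts and lists stand in for torch.tensor. Each attribute feat gets a numeric copy under f"{feat}_t", and string values are routed through the encoder. The DeepSnap conversion and the anchor argument are left out, and add_tensor_features is a hypothetical name.

```python
from typing import Any, Callable, Dict, List

NodeAttrs = Dict[str, Any]


def add_tensor_features(
    node_attrs: Dict[Any, NodeAttrs],
    feat_encoder: Callable[[str], List[float]],
) -> Dict[Any, NodeAttrs]:
    """For every attribute `feat`, store a numeric version under f"{feat}_t".

    Lists of floats stand in for torch.tensor; strings go through the encoder.
    """
    for attrs in node_attrs.values():
        # Snapshot the items because new *_t keys are added while looping.
        for feat, value in list(attrs.items()):
            if isinstance(value, str):
                attrs[f"{feat}_t"] = feat_encoder(value)
            elif isinstance(value, (int, float)):
                attrs[f"{feat}_t"] = [float(value)]
            else:  # assume any other value is an iterable of numbers
                attrs[f"{feat}_t"] = [float(v) for v in value]
    return node_attrs
```

The original attributes are kept alongside the *_t copies, so the raw strings remain available after featurization.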

tart.utils.graph_utils.read_graph_from_json(args, path)

Read a graph from a JSON file.

Parameters:
  • args (Namespace) – tart configs

  • path (str) – path to the JSON file

Return type:

Graph

Returns:

nx.Graph – networkx graph
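The JSON layout that read_graph_from_json expects is not specified here. Assuming a hypothetical layout with a "nodes" list (each entry carrying an "id" plus attributes) and an "edges" list of source/target objects, a stdlib-only parser might look like the following. It takes the JSON text rather than a path so it runs without a file, and read_graph_json_sketch is a made-up name. If the files happen to use networkx's node-link format, networkx.readwrite.json_graph.node_link_graph can build the graph directly instead.

```python
import json
from typing import Any, Dict, List, Tuple


def read_graph_json_sketch(
    text: str,
) -> Tuple[Dict[Any, dict], List[Tuple[Any, Any]]]:
    """Parse a hypothetical {"nodes": [...], "edges": [...]} document into
    node attributes keyed by id and a directed edge list."""
    data = json.loads(text)
    nodes: Dict[Any, dict] = {}
    for node in data["nodes"]:
        attrs = dict(node)
        node_id = attrs.pop("id")  # remaining keys become node attributes
        nodes[node_id] = attrs
    edges = [(e["source"], e["target"]) for e in data["edges"]]
    # Reject edges that point at nodes missing from the "nodes" list.
    for src, dst in edges:
        if src not in nodes or dst not in nodes:
            raise ValueError(f"edge ({src}, {dst}) references an unknown node")
    return nodes, edges
```

To read from disk, pass pathlib.Path(path).read_text() as the text argument.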

tart.utils.infer_utils module

tart.utils.model_utils module

tart.utils.model_utils.build_model(model_type, args)

Build the user-specified model.

Parameters:
  • model_type (torch.nn.Module) – model class (a torch.nn.Module subclass) to instantiate

  • args (Namespace) – user-defined configs

Return type:

Module

Returns:

torch.nn.Module – built model
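Because model_type is documented as a model class, build_model presumably instantiates it from the configs. A minimal sketch, assuming the class constructor takes the args namespace directly; build_model_sketch, TinyModel and the hidden_dim field are all hypothetical:

```python
from argparse import Namespace
from typing import Type, TypeVar

M = TypeVar("M")


def build_model_sketch(model_type: Type[M], args: Namespace) -> M:
    """Instantiate model_type from the config namespace (assumed constructor signature)."""
    if not isinstance(model_type, type):
        # Catch the common mistake of passing an instance instead of the class.
        raise TypeError("model_type must be a class, not an instance")
    return model_type(args)


class TinyModel:
    """Hypothetical model whose constructor reads one config field."""

    def __init__(self, args: Namespace):
        self.hidden_dim = args.hidden_dim


model = build_model_sketch(TinyModel, Namespace(hidden_dim=8))
```

A real torch.nn.Module subclass would pass the same isinstance(model_type, type) check.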

tart.utils.model_utils.build_optimizer(args, params)

Build the optimizer and learning-rate scheduler.

Parameters:
  • args (Namespace) – user-defined configs

  • params (Iterator) – model parameters

Return type:

Tuple

Returns:

Tuple[optim.lr_scheduler.LRScheduler, optim.Optimizer] – the learning-rate scheduler and the optimizer, in that order
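The documented return order puts the scheduler first, which is easy to get backwards when unpacking. The sketch below uses small pure-Python stand-ins (SGDSketch for torch.optim.SGD, StepDecaySketch for torch.optim.lr_scheduler.StepLR) so it runs without torch. The optimizer choice and the args.lr, args.decay_step and args.decay_rate fields are assumptions, not tart's actual configuration.

```python
from argparse import Namespace
from typing import Iterable, List, Tuple


class SGDSketch:
    """Minimal stand-in for torch.optim.SGD over plain Python lists."""

    def __init__(self, params: Iterable[List[float]], lr: float):
        self.params = list(params)  # keep references so updates are in place
        self.lr = lr

    def step(self, grads: List[List[float]]) -> None:
        for p, g in zip(self.params, grads):
            for i, gi in enumerate(g):
                p[i] -= self.lr * gi


class StepDecaySketch:
    """Stand-in for StepLR: multiply lr by gamma every step_size epochs."""

    def __init__(self, optimizer: SGDSketch, step_size: int, gamma: float):
        self.optimizer = optimizer
        self.step_size = step_size
        self.gamma = gamma
        self.epoch = 0

    def step(self) -> None:
        self.epoch += 1
        if self.epoch % self.step_size == 0:
            self.optimizer.lr *= self.gamma


def build_optimizer_sketch(
    args: Namespace, params: Iterable[List[float]]
) -> Tuple[StepDecaySketch, SGDSketch]:
    optimizer = SGDSketch(params, lr=args.lr)  # hypothetical args.lr
    # Hypothetical decay fields; tart's real config keys may differ.
    scheduler = StepDecaySketch(optimizer, args.decay_step, args.decay_rate)
    return scheduler, optimizer  # same order as the documented return type
```

Unpack as scheduler, optimizer = build_optimizer_sketch(args, params); swapping the names would silently call the wrong step() method.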

tart.utils.model_utils.get_device()

Get the available device (CPU or GPU).

Return type:

device

Returns:

torch.device – available device

tart.utils.model_utils.get_torch_tensor_type()

Get the torch tensor type for the available device (CPU or GPU).

Return type:

FloatTensor

Returns:

FloatTensor – float tensor type matching the available device
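Both helpers return values that depend on whether a GPU is available. Assuming they branch on CUDA availability, the selection logic reduces to the function below. With torch installed, cuda_available would come from torch.cuda.is_available(), and the returned names would map to torch.device(...) and to the torch.cuda.FloatTensor or torch.FloatTensor type. device_and_tensor_type_names is a hypothetical helper that keeps the sketch runnable without torch.

```python
from typing import Tuple


def device_and_tensor_type_names(cuda_available: bool) -> Tuple[str, str]:
    """Return (device string, float tensor type name) for the given availability.

    With torch, call this as device_and_tensor_type_names(torch.cuda.is_available()).
    """
    if cuda_available:
        return "cuda", "torch.cuda.FloatTensor"
    return "cpu", "torch.FloatTensor"
```

Keeping the device and the tensor type in one decision avoids mixing CPU tensors with a GPU model.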

tart.utils.tart_utils module

tart.utils.tart_utils.print_header()

Print the tart header to the console.

Return type:

None

tart.utils.tart_utils.summarize_tart_run(args)

Print a summary of the tart run to the console.

Parameters:

args (Namespace) – tart configs

Return type:

None
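One possible shape for the run summary, assuming it simply lists every field of the args namespace. format_run_summary and summarize_run_sketch are hypothetical names; the real output format is not documented here. Formatting is split from printing so the text can be checked without capturing stdout.

```python
from argparse import Namespace


def format_run_summary(args: Namespace) -> str:
    """Render every config field as an aligned `name : value` line, sorted by name."""
    fields = sorted(vars(args).items())
    if not fields:
        return ""
    width = max(len(name) for name, _ in fields)
    return "\n".join(f"{name.ljust(width)} : {value}" for name, value in fields)


def summarize_run_sketch(args: Namespace) -> None:
    # Mirrors the documented behavior: print to the console, return None.
    print(format_run_summary(args))
```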

tart.utils.train_utils module

tart.utils.train_utils.init_logger(args)

Initialize the TensorBoard logger.

Parameters:

args (Namespace) – tart configs

Return type:

SummaryWriter

Returns:

SummaryWriter – TensorBoard logger

tart.utils.train_utils.make_validation_set(dataloader)

Build a validation set from a dataloader.

Parameters:

dataloader (DataLoader) – dataloader for the validation set

Return type:

List[Tuple[Batch, Batch, Batch, Batch]]

Returns:

List[Tuple[Batch, Batch, Batch, Batch]] – list of validation batches; each batch is a tuple (pos_q, pos_t, neg_q, neg_t)
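A likely reason to materialize the dataloader once is that every evaluation then sees the same batches, even if the loader samples randomly. A minimal sketch, assuming the dataloader yields 4-tuples (pos_q, pos_t, neg_q, neg_t) as the return type states; make_validation_set_sketch and its max_batches cap are hypothetical additions:

```python
from typing import Any, Iterable, List, Optional, Tuple

Batch = Any  # stand-in for a DeepSnap Batch
ValBatch = Tuple[Batch, Batch, Batch, Batch]


def make_validation_set_sketch(
    dataloader: Iterable[ValBatch], max_batches: Optional[int] = None
) -> List[ValBatch]:
    """Collect (pos_q, pos_t, neg_q, neg_t) tuples from the loader into a fixed list."""
    val_set: List[ValBatch] = []
    for i, batch in enumerate(dataloader):
        if max_batches is not None and i >= max_batches:
            break
        pos_q, pos_t, neg_q, neg_t = batch  # fails loudly if a batch is not a 4-tuple
        val_set.append((pos_q, pos_t, neg_q, neg_t))
    return val_set
```

The trade-off is memory: all validation batches stay resident, which is usually acceptable for a small validation split.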

tart.utils.train_utils.set_parser(tune)

tart.utils.train_utils.start_workers(train_func, model, corpus, in_queue, out_queue, args)

Start worker processes for training.

Parameters:
  • train_func (Callable) – training function

  • model (torch.nn.Module) – tart model to train

  • corpus (Corpus) – dataset to train the model on

  • in_queue (mp.Queue) – multiprocessing queue for input

  • out_queue (mp.Queue) – multiprocessing queue for output

  • args (Namespace) – tart configs

Return type:

List[Process]

Returns:

List[mp.Process] – list of worker processes
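A sketch of the usual pattern behind a function like this: one process per worker, each running train_func with the shared queues. It assumes a hypothetical args.n_workers field and a hypothetical optional ctx argument for choosing the multiprocessing start method; neither is documented for tart, and start_workers_sketch is a made-up name.

```python
import multiprocessing as mp
from argparse import Namespace
from typing import Any, Callable, List, Optional


def start_workers_sketch(
    train_func: Callable,
    model: Any,
    corpus: Any,
    in_queue: Any,
    out_queue: Any,
    args: Namespace,
    ctx: Optional[mp.context.BaseContext] = None,
) -> List[mp.process.BaseProcess]:
    """Start args.n_workers processes (hypothetical field), each running
    train_func(model, corpus, in_queue, out_queue, args)."""
    ctx = ctx or mp.get_context()  # default start method unless one is given
    workers = []
    for _ in range(args.n_workers):
        worker = ctx.Process(
            target=train_func,
            args=(model, corpus, in_queue, out_queue, args),
        )
        worker.start()
        workers.append(worker)
    return workers
```

With the "spawn" or "forkserver" start methods, train_func must be defined at module top level so it can be pickled, and the launching code should sit under if __name__ == "__main__". Drain out_queue before calling join() on the workers, since a child that still has queued data may not exit until it is consumed.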

Module contents