# Source code for tart.representation.config

"""Configs for model and optimizer"""

from argparse import ArgumentParser, Namespace
from test_tube import HyperOptArgumentParser
from typing import Dict, List


def make_tunable(parser: HyperOptArgumentParser, tunable: List[str]) -> None:
    """Re-register selected arguments as tunable hyper-parameters.

    Each named argument is removed from the parser's option-string table and
    added again through ``parser.opt_list`` with ``tunable=True`` and a fixed
    list of candidate values.

    All names are validated before the parser is touched, so an unsupported
    name never leaves the parser partially modified.

    Args:
        parser (HyperOptArgumentParser): parser on which the arguments are
            already registered
        tunable (List[str]): argument names (without the leading ``--``) to
            make tunable

    Raises:
        ValueError: an entry of ``tunable`` is not a supported tunable argument
        KeyError: a supported argument is not currently registered on ``parser``
    """
    # opt_list keyword arguments for every argument that may be tuned
    search_spaces = {
        "batch_size": dict(type=int, help="Training batch size", default=64, options=[32, 64, 128]),
        "agg_type": dict(type=str, help="type of aggregation", default="GINE", options=["GINE", "GIN", "GCN"]),
        "n_layers": dict(type=int, help="Number of graph conv layers", default=7, options=[5, 7, 9, 11]),
        "hidden_dim": dict(type=int, help="Training hidden size", default=64, options=[32, 64, 128]),
        "skip": dict(type=str, help="skip connections", default="learnable", options=["all", "learnable"]),
    }

    # materialize once so the names can be walked twice
    names = list(tunable)

    # reject unsupported names up front, before any option is removed
    for arg in names:
        if arg not in search_spaces:
            raise ValueError("Argument {} is not tunable.".format(arg))

    for arg in names:
        flag = "--" + arg
        # drop the existing option string so the same flag can be registered again
        parser._option_string_actions.pop(flag)
        # add the argument back through the test_tube parser as a tunable option
        parser.opt_list(flag, tunable=True, **search_spaces[arg])
def build_model_configs(parser: ArgumentParser) -> None:
    """Register the model/training command-line options and their defaults.

    The options are added to a fresh, untitled argument group of ``parser``.

    Args:
        parser (ArgumentParser): argparse parser
    """
    group = parser.add_argument_group()

    # (flag, add_argument keyword arguments), listed in registration order
    option_specs = [
        ("--agg_type", {"type": str, "help": "type of aggregation/convolution"}),
        ("--batch_size", {"type": int, "help": "Training batch size"}),
        ("--n_layers", {"type": int, "help": "Number of graph conv layers"}),
        ("--hidden_dim", {"type": int, "help": "Training hidden size"}),
        ("--skip", {"type": str, "help": '"all" or "learnable"'}),
        ("--dropout", {"type": float, "help": "Dropout rate"}),
        ("--n_iters", {"type": int, "help": "Number of training iterations"}),
        ("--n_batches", {"type": int, "help": "Number of training minibatches"}),
        ("--margin", {"type": float, "help": "margin for loss"}),
        ("--dataset", {"type": str, "help": "Dataset name"}),
        (
            "--data_dir",
            {"type": str, "help": "path to the root directory of the train/test sub-directories"},
        ),
        ("--test_set", {"type": str, "help": "test set filename"}),
        ("--eval_interval", {"type": int, "help": "how often to eval during training"}),
        ("--val_size", {"type": int, "help": "validation set size"}),
        ("--model_path", {"type": str, "help": "path to save/load model"}),
        ("--opt_scheduler", {"type": str, "help": "scheduler name"}),
        # NOTE(review): store_true combined with the default of True set below
        # means this flag cannot switch node anchoring off - confirm intent.
        (
            "--node_anchored",
            {"action": "store_true", "help": "whether to use node anchoring in training"},
        ),
        ("--test", {"action": "store_true"}),
        ("--n_workers", {"type": int}),
        ("--tag", {"type": str, "help": "tag to identify the run"}),
    ]
    for flag, spec in option_specs:
        group.add_argument(flag, **spec)

    # Besides the options above, this also seeds defaults for opt, opt_restart,
    # weight_decay and lr, which are not registered by this function.
    defaults = {
        "agg_type": "GINE",
        "dataset": "example",
        "data_dir": "../data/example",
        "n_layers": 7,
        "batch_size": 64,
        "hidden_dim": 64,
        "skip": "learnable",
        "dropout": 0.0,
        "n_iters": 5,
        "n_batches": 10000,
        "opt": "adam",
        "opt_scheduler": "none",
        "opt_restart": 100,
        "weight_decay": 0.0,
        "lr": 1e-4,
        "margin": 0.1,
        "test_set": "",
        "eval_interval": 1000,
        "n_workers": 4,
        "model_path": "ckpt/model.pt",
        "tag": "",
        "val_size": 4096,
        "node_anchored": True,
    }
    group.set_defaults(**defaults)
def build_optimizer_configs(parser: ArgumentParser) -> None:
    """Register the optimizer command-line options and their defaults.

    The options are added to a fresh, untitled argument group of ``parser``.
    Defaults are set for ``opt``, ``opt_scheduler``, ``opt_restart``,
    ``weight_decay`` and ``lr``; the remaining options default to ``None``.

    Args:
        parser (ArgumentParser): argparse parser
    """
    opt_parser = parser.add_argument_group()
    opt_parser.add_argument("--opt", dest="opt", type=str, help="Type of optimizer")
    opt_parser.add_argument(
        "--opt-scheduler",
        dest="opt_scheduler",
        type=str,
        help="Type of optimizer scheduler. default: none",
    )
    opt_parser.add_argument(
        "--opt-restart",
        dest="opt_restart",
        type=int,
        # the help text previously advertised 0, but the default set below is 100
        help="Number of epochs before restart, default: 100",
    )
    opt_parser.add_argument(
        "--opt-decay-step",
        dest="opt_decay_step",
        type=int,
        help="Number of epochs before decay",
    )
    opt_parser.add_argument(
        "--opt-decay-rate",
        dest="opt_decay_rate",
        type=float,
        help="Learning rate decay ratio",
    )
    opt_parser.add_argument("--lr", dest="lr", type=float, help="Learning rate.")
    opt_parser.add_argument("--clip", dest="clip", type=float, help="Gradient clipping.")
    opt_parser.add_argument("--weight_decay", type=float, help="Optimizer weight decay.")
    opt_parser.set_defaults(
        opt="adam",
        opt_scheduler="none",
        opt_restart=100,
        weight_decay=0.0,
        lr=1e-4,
    )
def build_feature_configs(parser: ArgumentParser) -> None:
    """Register the graph feature command-line options.

    Each option accepts one or more values (``nargs="+"``) and is added to a
    fresh, untitled argument group of ``parser``.

    Args:
        parser (ArgumentParser): argparse parser
    """
    group = parser.add_argument_group()

    # (flag, help text), listed in registration order
    feature_options = (
        ("--node_feats", "node features to use in training"),
        ("--edge_feats", "edge features to use in training"),
        ("--node_feat_dims", "node feature dimension"),
        ("--edge_feat_dims", "edge feature dimension"),
    )
    for flag, description in feature_options:
        group.add_argument(flag, nargs="+", help=description)
def init_user_configs(args: Namespace, configs_json: Dict, tune: bool = False) -> Namespace:
    """Copy user defined configs onto the argparse namespace.

    The node feature lists are extended with three built-in features
    (``node_degree``, ``node_pagerank``, ``node_cc``), each with dimension 1
    and type ``"int"``. Every other key of ``configs_json`` is copied onto
    ``args`` as an attribute of the same name.

    All checks run before ``args`` is modified, so a failing call leaves the
    namespace untouched.

    Args:
        args (Namespace): argparse namespace
        configs_json (Dict): user defined configs
        tune (bool): when True, reject any extra config key that is also
            listed under ``configs_json["tunable"]``; a missing ``tunable``
            entry is treated as empty

    Raises:
        ValueError: node_feats not provided in configs.json
        ValueError: edge_feats not provided in configs.json
        ValueError: node and edge feats names overlap
        ValueError: ``tune`` is set and a config key is listed as tunable
        KeyError: a feature dims/types entry is missing from ``configs_json``

    Returns:
        Namespace: updated argparse namespace
    """
    # check if node_feats and edge_feats are provided
    if "node_feats" not in configs_json:
        raise ValueError("node_feats not provided in configs.json")
    if "edge_feats" not in configs_json:
        raise ValueError("edge_feats not provided in configs.json")

    # check if there is an overlap between node_feats and edge_feats
    feat_overlap = set(configs_json["node_feats"]) & set(configs_json["edge_feats"])
    if feat_overlap:
        raise ValueError("node and edge feats overlap on features: {}! Please rename them. ".format(feat_overlap))

    # build every feature value first; list concatenation creates new lists,
    # so the caller's config is not modified
    feature_values = {
        "node_feats": configs_json["node_feats"] + ["node_degree", "node_pagerank", "node_cc"],
        "edge_feats": configs_json["edge_feats"],
        "node_feat_dims": configs_json["node_feat_dims"] + [1, 1, 1],
        "edge_feat_dims": configs_json["edge_feat_dims"],
        "node_feat_types": configs_json["node_feat_types"] + ["int", "int", "int"],
        "edge_feat_types": configs_json["edge_feat_types"],
    }

    # other (assumes non list) features that were provided, in config order
    extra_keys = [key for key in configs_json if key not in feature_values]

    if tune:
        # a missing or empty "tunable" entry means nothing is reserved for tuning
        tunable_names = configs_json.get("tunable") or ()
        for key in extra_keys:
            if key in tunable_names:
                raise ValueError(f"Feature {key} is tunable. Please remove it from the config file.")

    # all checks passed: apply the updates
    for name, value in feature_values.items():
        setattr(args, name, value)
    for key in extra_keys:
        setattr(args, key, configs_json[key])
    return args