Source code for tart.utils.model_utils

from argparse import Namespace
from typing import Tuple, Iterator

import torch
import torch.optim as optim

DEVICE_CACHE = None


[docs]def get_device() -> torch.device: """get device (cpu or gpu) Returns: torch.device: available device """ # get device (cpu or gpu) global DEVICE_CACHE if DEVICE_CACHE is None: if torch.cuda.is_available(): # print("GPU is available!!!") DEVICE_CACHE = torch.device("cuda") else: DEVICE_CACHE = torch.device("cpu") return DEVICE_CACHE
[docs]def get_torch_tensor_type() -> torch.FloatTensor or torch.cuda.FloatTensor: """get torch tensor type (cpu or gpu) Returns: torch.dtype: correct torch tensor type """ if torch.cuda.is_available(): return torch.cuda.FloatTensor else: return torch.FloatTensor
[docs]def build_model(model_type: torch.nn.Module, args: Namespace) -> torch.nn.Module: """build the user specified model Args: model_type (torch.nn.Module): model class args (Namespace): user defined configs Returns: torch.nn.Module: built model """ # build model model = model_type(1, args.hidden_dim, args) model.to(get_device()) if args.test and args.model_path: model.load_state_dict(torch.load(args.model_path, map_location=get_device())) return model
[docs]def build_optimizer(args: Namespace, params: Iterator) -> Tuple: """build optimizer and scheduler Args: args (Namespace): user defined configs params (Iterator): model parameters Returns: Tuple[optim.lr_scheduler.LRScheduler, optim.Optimizer]: _description_ """ # build optimizer weight_decay = args.weight_decay filter_fn = filter(lambda p: p.requires_grad, params) if args.opt == "adam": optimizer = optim.Adam(filter_fn, lr=args.lr, weight_decay=weight_decay) elif args.opt == "sgd": optimizer = optim.SGD(filter_fn, lr=args.lr, momentum=0.95, weight_decay=weight_decay) elif args.opt == "rmsprop": optimizer = optim.RMSprop(filter_fn, lr=args.lr, weight_decay=weight_decay) elif args.opt == "adagrad": optimizer = optim.Adagrad(filter_fn, lr=args.lr, weight_decay=weight_decay) if args.opt_scheduler == "none": return None, optimizer elif args.opt_scheduler == "step": scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.opt_decay_step, gamma=args.opt_decay_rate) elif args.opt_scheduler == "cos": scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.opt_restart) return scheduler, optimizer