Source code for tart.utils.config_utils
import inspect
from typing import Callable, Dict
class FeatEncoderValidationError(ValueError, AssertionError):
    """Raised when a user-defined feature encoder fails validation.

    Subclasses both ``ValueError`` and ``AssertionError`` so callers written
    against the previous behaviour keep working: the old checks were
    ``assert`` statements (AssertionError), and a config without a "str"
    feature type failed inside ``list.index`` (ValueError).
    """


def validate_feat_encoder(user_feat_encoder: Callable, config_json: Dict) -> Callable:
    """Validate a user-defined feature encoder for string node features.

    The encoder must be callable with a single positional (str) argument, and
    calling it on a sample string must return an object whose ``shape`` is
    ``(1, D)``. Here ``D`` is the entry of ``config_json["node_feat_dims"]``
    at the position of the first "str" in ``config_json["node_feat_types"]``.

    Args:
        user_feat_encoder (Callable): user-defined feature encoder. It is
            expected to return a torch.tensor, but only the ``.shape`` of its
            output is inspected here.
        config_json (Dict): user-defined configs; must contain the parallel
            lists ``"node_feat_types"`` and ``"node_feat_dims"``.

    Returns:
        Callable: ``user_feat_encoder`` itself, unchanged.

    Raises:
        KeyError: if either config key is missing.
        FeatEncoderValidationError: if no "str" feature type is configured,
            the encoder is not callable with a single argument, or its output
            does not have shape ``(1, D)``.
    """
    feat_types = config_json["node_feat_types"]
    if "str" not in feat_types:
        raise FeatEncoderValidationError(
            'config_json["node_feat_types"] contains no "str" entry; '
            "a feat_encoder is only used for string features"
        )
    # Assumes every "str" feature shares one encoder, so the dimension of the
    # first "str" entry is the expected output width.
    exp_dim = config_json["node_feat_dims"][feat_types.index("str")]

    if not callable(user_feat_encoder):
        raise FeatEncoderValidationError("feat_encoder must be callable")

    # Check arity via inspect.signature rather than __code__.co_argcount.
    # co_argcount does not exist on functools.partial objects or callable
    # instances, counts `self` on bound methods, and counts defaulted extras.
    try:
        signature = inspect.signature(user_feat_encoder)
    except (TypeError, ValueError):
        # Some builtins expose no signature; rely on the sample call below.
        pass
    else:
        try:
            signature.bind("test")
        except TypeError:
            raise FeatEncoderValidationError(
                "feat_encoder must take a single (str) argument"
            ) from None

    # Probe with a sample input; exceptions raised by the encoder propagate.
    output = user_feat_encoder("test")
    recv_dim = getattr(output, "shape", None)
    if recv_dim is None:
        raise FeatEncoderValidationError(
            f"feat_encoder must return a torch.tensor of shape (1, {exp_dim}), "
            f"got an object of type {type(output).__name__} without .shape"
        )
    if tuple(recv_dim) != (1, exp_dim):
        raise FeatEncoderValidationError(
            f"feat_encoder must return a torch.tensor of shape (1, {exp_dim}) got {recv_dim}"
        )
    return user_feat_encoder