Source code for tart.representation.encoders
from typing import Callable
from transformers import RobertaTokenizer, RobertaModel
import torch
# Device that the CodeBERT model and codebert_encoder's input ids are placed on:
# the GPU when CUDA is available, otherwise the CPU.
my_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module state left unset at import time and filled in by get_feature_encoder():
#   tokenizer - RobertaTokenizer used by both CodeBERT encoders
#   model     - RobertaModel, loaded only for the "CodeBert" encoder
#   max_len   - padding/truncation length, set only for the "CodeBertBPE" encoder
tokenizer = model = max_len = None
# ########## FEATURE ENCODERS ##########
def codebert_encoder(x: str) -> torch.Tensor:
    """Feature encoder using the CodeBERT model (mean-pooled token embeddings).

    Relies on the module-level ``tokenizer`` and ``model`` that
    ``get_feature_encoder("CodeBert")`` loads, so obtain this encoder through
    that factory.

    Args:
        x (str): string to be encoded

    Returns:
        torch.Tensor: mean over the sequence axis of the model's first output,
        with a leading batch dimension of 1, on ``my_device``.

    Raises:
        RuntimeError: if the CodeBERT tokenizer/model have not been loaded yet.
    """
    # Fail with an actionable message instead of "'NoneType' has no attribute"
    # when the factory was never called (or only the BPE variant was loaded,
    # which sets the tokenizer but not the model).
    if tokenizer is None or model is None:
        raise RuntimeError(
            "CodeBert encoder is not initialised; obtain it via "
            'get_feature_encoder("CodeBert") first.'
        )
    # truncation=True without max_length truncates to the tokenizer's default
    # maximum length, so long inputs still fit the model.
    tokens_ids = tokenizer.encode(x, truncation=True)
    tokens_tensor = torch.tensor(tokens_ids, device=my_device)
    with torch.no_grad():
        # [None, :] adds the batch dimension; [0] takes the first model output
        # (token-level embeddings for RobertaModel).
        context_embeddings = model(tokens_tensor[None, :])[0]
    # Average over the token axis (dim=1) to get one vector per input string.
    return torch.mean(context_embeddings, dim=1)
def codebert_bpe_encoder(x: str) -> torch.Tensor:
    """Feature encoder using the CodeBERT BPE tokenizer (token ids only).

    Uses the module-level ``tokenizer`` and ``max_len`` configured by
    ``get_feature_encoder("CodeBertBPE", max_len=...)``. No model is run.

    Args:
        x (str): string to be encoded

    Returns:
        torch.Tensor: the ``input_ids`` tensor produced by the tokenizer with
        ``return_tensors="pt"``. Every input is padded/truncated to
        ``max_len`` (presumably giving shape (1, max_len) — confirm).

    Raises:
        RuntimeError: if the tokenizer has not been loaded yet.
    """
    # Fail with an actionable message instead of "'NoneType' is not callable"
    # when the factory was never called.
    if tokenizer is None:
        raise RuntimeError(
            "CodeBertBPE encoder is not initialised; obtain it via "
            'get_feature_encoder("CodeBertBPE", max_len=...) first.'
        )
    encoded_input = tokenizer(
        x,
        return_tensors="pt",
        max_length=max_len,
        # Pad every input to the same fixed length so encodings are uniform.
        padding="max_length",
        truncation=True,
    )
    return encoded_input["input_ids"]
# ########## FEATURE ENCODER FACTORY ##########
# Registry: encoder name -> callable implementing it. The callables read module
# state (tokenizer / model / max_len) that get_feature_encoder sets up, so
# encoders should be obtained through that factory rather than looked up here.
ENCODER_STRATEGY: "dict[str, Callable[[str], torch.Tensor]]" = {
    "CodeBert": codebert_encoder,  # mean-pooled CodeBERT embeddings
    "CodeBertBPE": codebert_bpe_encoder,  # fixed-length BPE token ids
}
def get_feature_encoder(encoder_name: str, **kwargs) -> Callable[[str], torch.Tensor]:
    """Factory function to get a feature encoder.

    Loads the resources the requested encoder needs into module-level state
    (``tokenizer``, ``model``, ``max_len``) and returns the matching callable
    from ``ENCODER_STRATEGY``.

    Args:
        encoder_name (str): name of the encoder to retrieve
            ("CodeBert" or "CodeBertBPE").
        **kwargs: encoder-specific options. ``max_len`` is required for
            "CodeBertBPE" and is the length every tokenized input is padded or
            truncated to.

    Returns:
        Callable[[str], torch.Tensor]: callable feature encoder

    Raises:
        ValueError: if ``encoder_name`` is unknown, or if ``max_len`` is
            missing for "CodeBertBPE".
        NotImplementedError: for the reserved names "Bert", "BertBPE",
            "GPT2" and "GPT3".
    """
    global tokenizer, model, max_len
    codebert_name = "microsoft/codebert-base"

    if encoder_name == "CodeBert":
        tokenizer = RobertaTokenizer.from_pretrained(codebert_name)
        model = RobertaModel.from_pretrained(codebert_name).to(my_device)
        # Inference only: eval() turns off training-time behaviour such as dropout.
        model.eval()
        return ENCODER_STRATEGY[encoder_name]

    if encoder_name == "CodeBertBPE":
        try:
            max_len = kwargs["max_len"]
        except KeyError:
            # `from None`: the missing-key KeyError is an implementation detail.
            raise ValueError(
                "max_len is required for CodeBertBPE encoder; "
                "please provide it as a keyword argument."
            ) from None
        tokenizer = RobertaTokenizer.from_pretrained(codebert_name)
        return ENCODER_STRATEGY[encoder_name]

    if encoder_name in ("Bert", "BertBPE", "GPT2", "GPT3"):
        raise NotImplementedError(f"{encoder_name} encoder is not implemented yet.")

    # Implicit literal concatenation (not a backslash continuation inside the
    # f-string) keeps source indentation out of the message.
    raise ValueError(
        f"Oops, {encoder_name} is not a default encoder! "
        "You can register it as a custom encoder in encoders.py."
    )