def eval_on_batch(model: Tree2Seq, criterion: nn.modules.loss, graph: dgl.BatchedDGLGraph,
                  labels: List[str], device: torch.device) -> Tuple[Dict, torch.Tensor]:
    model.eval()
    root_indexes = get_root_indexes(graph).to(device)

    # Model step
    with torch.no_grad():
        root_logits, ground_truth = model(graph, root_indexes, labels, 0.0, device)
        root_logits = root_logits[1:]
        ground_truth = ground_truth[1:]
        loss = criterion(root_logits.view(-1, root_logits.shape[-1]), ground_truth.view(-1))

    # Calculate metrics
    prediction = model.predict(root_logits)
    batch_eval_info = {
        'loss': loss.item(),
        'statistics': calculate_batch_statistics(
            ground_truth, prediction,
            [model.decoder.label_to_id[token] for token in [PAD, UNK, EOS]]
        )
    }
    return batch_eval_info, prediction
def eval_on_batch(
        model: Tree2Seq, criterion: nn.modules.loss, graph: dgl.DGLGraph, labels: torch.Tensor
) -> Tuple[Dict, torch.Tensor]:
    model.eval()
    # Model step
    with torch.no_grad():
        loss, prediction, batch_info = _forward_pass(model, graph, labels, criterion)
        del loss
    return batch_info, prediction
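# `evaluate_on_dataset` is called below (in `train_on_dataset` and `evaluate`) but is not
# shown in this section. A minimal sketch of what it might look like, assuming LearningInfo
# exposes `accumulate_info` and `get_state_dict` (as used in `train_on_dataset`) and that
# the dataset yields (graph, labels) pairs by index:
def evaluate_on_dataset(dataset: Dataset, model: Tree2Seq, criterion: nn.modules.loss) -> LearningInfo:
    eval_info = LearningInfo()
    for batch_id in tqdm(range(len(dataset))):
        graph, labels = dataset[batch_id]
        batch_info, prediction = eval_on_batch(model, criterion, graph, labels)
        eval_info.accumulate_info(batch_info)
        del prediction
    return eval_info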
def train_on_dataset(
        train_dataset: Dataset, val_dataset: Dataset, model: Tree2Seq, criterion: nn.modules.loss,
        optimizer: torch.optim, scheduler: torch.optim.lr_scheduler, clip_norm: int,
        logger: AbstractLogger, start_batch_id: int = 0, log_step: int = -1,
        eval_step: int = -1, save_step: int = -1
):
    train_epoch_info = LearningInfo()

    batch_iterator_pb = tqdm(range(start_batch_id, len(train_dataset)), total=len(train_dataset))
    batch_iterator_pb.update(start_batch_id)
    batch_iterator_pb.refresh()

    for batch_id in batch_iterator_pb:
        graph, labels = train_dataset[batch_id]
        batch_info = train_on_batch(model, criterion, optimizer, scheduler, graph, labels, clip_norm)
        train_epoch_info.accumulate_info(batch_info)

        if is_step_match(batch_id, log_step):
            logger.log(train_epoch_info.get_state_dict(), batch_id, is_train=True)
            train_epoch_info = LearningInfo()

        if is_step_match(batch_id, save_step):
            train_dump = {
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'batch_id': batch_id
            }
            logger.save_model(f'batch_{batch_id}.pt', train_dump)

        if is_step_match(batch_id, eval_step):
            eval_info = evaluate_on_dataset(val_dataset, model, criterion)
            logger.log(eval_info.get_state_dict(), batch_id, is_train=False)

    if train_epoch_info.batch_processed > 0:
        logger.log(train_epoch_info.get_state_dict(), len(train_dataset) - 1, is_train=True)
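# `is_step_match` is used above but not defined in this section. A plausible sketch,
# assuming a step value of -1 disables the corresponding action and that an action fires
# once every `template` batches (both assumptions, not confirmed by the original code):
def is_step_match(current_step: int, template: int) -> bool:
    return template != -1 and (current_step + 1) % template == 0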
def save_model(self, model: Tree2Seq, output_name: str, configuration: Dict, **kwargs: Dict) -> str:
    saving_path = join_path(self.checkpoints_folder, output_name)
    output = {
        'state_dict': model.state_dict(),
        'configuration': configuration
    }
    output.update(kwargs)
    torch.save(output, saving_path)
    return saving_path
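# Because `save_model` bundles the weights with the model configuration, a checkpoint can
# be restored without the original hyperparameter file. A hypothetical loader (the helper
# name is ours), mirroring how `interactive` and `evaluate` below read checkpoints:
def load_model(path_to_model: str, device: torch.device) -> Tree2Seq:
    checkpoint = torch.load(path_to_model, map_location=device)
    model = Tree2Seq(**checkpoint['configuration']).to(device)
    model.load_state_dict(checkpoint['state_dict'])
    return model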
def train_on_batch(
        model: Tree2Seq, criterion: nn.modules.loss, optimizer: torch.optim,
        scheduler: torch.optim.lr_scheduler, graph: dgl.DGLGraph, labels: torch.Tensor,
        clip_norm: int
) -> Dict:
    model.train()

    # Model step
    model.zero_grad()
    loss, prediction, batch_info = _forward_pass(model, graph, labels, criterion)
    batch_info['learning_rate'] = scheduler.get_last_lr()[0]
    loss.backward()
    nn.utils.clip_grad_value_(model.parameters(), clip_norm)
    optimizer.step()
    scheduler.step()

    del loss
    del prediction
    torch.cuda.empty_cache()

    return batch_info
def train_on_batch(model: Tree2Seq, criterion: nn.modules.loss, optimizer: torch.optim,
                   scheduler: torch.optim.lr_scheduler, graph: dgl.BatchedDGLGraph,
                   labels: List[str], params: Dict, device: torch.device) -> Dict:
    model.train()
    root_indexes = get_root_indexes(graph).to(device)

    # Model step
    model.zero_grad()
    root_logits, ground_truth = model(graph, root_indexes, labels, params['teacher_force'], device)
    root_logits = root_logits[1:]
    ground_truth = ground_truth[1:]
    loss = criterion(root_logits.view(-1, root_logits.shape[-1]), ground_truth.view(-1))
    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), params['clip_norm'])
    optimizer.step()
    scheduler.step()

    # Calculate metrics
    prediction = model.predict(root_logits)
    batch_train_info = {
        'loss': loss.item(),
        'statistics': calculate_batch_statistics(
            ground_truth, prediction,
            [model.decoder.label_to_id[token] for token in [PAD, UNK, EOS]]
        )
    }
    return batch_train_info
def interactive(path_to_function: str, path_to_model: str):
    fix_seed()
    device = get_device()
    print(f"using {device} device")

    # load model
    print("loading model...")
    checkpoint = torch.load(path_to_model, map_location=device)
    model = Tree2Seq(**checkpoint['configuration']).to(device)
    model.load_state_dict(checkpoint['state_dict'])
    token_to_id = model.token_to_id
    type_to_id = model.type_to_id
    label_to_id = model.label_to_id
    id_to_label = {v: k for k, v in label_to_id.items()}

    # convert function to dgl format
    print("converting function to dgl format...")
    create_folder(TMP_FOLDER)
    build_asts(path_to_function, TMP_FOLDER, ASTMINER_PATH, *ASTMINER_PARAMS)
    project_folder = os.path.join(TMP_FOLDER, 'java')
    convert_project(project_folder, token_to_id, type_to_id, label_to_id, True, True, 5, 6, False, True, '|')

    # load function
    graph, labels = load_graphs(os.path.join(project_folder, 'converted.dgl'))
    labels = labels['labels']
    assert len(labels) == 1, f"found {len(labels)} functions instead of 1"
    ast = graph[0].reverse(share_ndata=True)
    ast.ndata['token'] = ast.ndata['token'].to(device)
    ast.ndata['type'] = ast.ndata['type'].to(device)
    labels = labels.t().to(device)
    root_indexes = torch.tensor([0], dtype=torch.long)

    # forward pass
    model.eval()
    with torch.no_grad():
        logits = model(ast, root_indexes, labels, device)
    logits = logits[1:]
    prediction = model.predict(logits).reshape(-1)
    sublabels = [id_to_label[label_id.item()] for label_id in prediction]
    label = '|'.join(takewhile(lambda sl: sl != EOS, sublabels))
    print(f"the predicted label is:\n{label}")
def _forward_pass(
        model: Tree2Seq, graph: dgl.DGLGraph, labels: torch.Tensor, criterion: nn.modules.loss
) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
    """Make model step

    :param model: Tree2Seq model
    :param graph: batched dgl graph
    :param labels: [seq len; batch size] ground truth labels
    :param criterion: criterion to optimize
    :return: Tuple[
        loss: [1] torch tensor with loss information,
        prediction: [the longest sequence; batch size],
        batch info: Dict with statistics
    ]
    """
    # [seq len; batch size; vocab size]
    root_logits = model(graph, labels)
    # if the seq len in labels is equal to 1, the model solves a classification task;
    # for longer sequences we remove the <SOS> token, since it's always in the first position
    if labels.shape[0] > 1:
        # [seq len - 1; batch size; vocab size]
        root_logits = root_logits[1:]
        # [seq len - 1; batch size]
        labels = labels[1:]

    loss = criterion(root_logits.reshape(-1, root_logits.shape[-1]), labels.reshape(-1))
    # [the longest sequence; batch size]
    prediction = model.predict(root_logits)

    # Calculate metrics
    skipping_tokens = [
        model.decoder.label_to_id[token]
        for token in [PAD, UNK, EOS]
        if token in model.decoder.label_to_id
    ]
    batch_info = {
        'loss': loss.item(),
        'statistics': calculate_batch_statistics(labels.t(), prediction.t(), skipping_tokens)
    }

    return loss, prediction, batch_info
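# The reshaping into the criterion above flattens the time and batch dimensions together,
# since nn.CrossEntropyLoss expects [N; vocab size] logits against [N] targets. A runnable
# sanity check with arbitrary dummy shapes (the helper name and values are ours, for
# illustration only):
def _check_loss_shapes() -> None:
    seq_len, batch_size, vocab_size = 7, 4, 100
    dummy_logits = torch.randn(seq_len, batch_size, vocab_size)
    dummy_labels = torch.randint(vocab_size, (seq_len, batch_size))
    loss = nn.CrossEntropyLoss()(dummy_logits.reshape(-1, vocab_size), dummy_labels.reshape(-1))
    assert loss.dim() == 0  # the criterion reduces to a scalar by default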
def evaluate(params: Dict) -> None:
    fix_seed()
    device = get_device()
    print(f"using {device} device")

    checkpoint = torch.load(params['model'], map_location=device)

    print('model initializing...')
    # create model
    model = Tree2Seq(**checkpoint['configuration']).to(device)
    model.load_state_dict(checkpoint['state_dict'])

    evaluation_set = TreeDGLDataset(params['dataset'], params['batch_size'], device, True)

    # define loss function
    criterion = nn.CrossEntropyLoss(ignore_index=model.label_to_id[PAD]).to(device)

    # evaluation loop
    print("ok, let's evaluate it")
    eval_epoch_info = evaluate_on_dataset(evaluation_set, model, criterion)
    print(eval_epoch_info.get_state_dict())
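# A hypothetical entry point showing the keys the `evaluate` params dict needs
# ('model', 'dataset', 'batch_size'); the paths and batch size below are placeholders,
# not values from the original code:
if __name__ == '__main__':
    evaluate({
        'model': 'checkpoints/batch_1000.pt',
        'dataset': 'data/test',
        'batch_size': 32
    })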