Example #1
0
def eval_on_batch(model: Tree2Seq, criterion: nn.modules.loss,
                  graph: dgl.BatchedDGLGraph, labels: List[str],
                  device: torch.device) -> Tuple[Dict, torch.Tensor]:
    """Evaluate the model on a single batch without tracking gradients.

    :param model: Tree2Seq model; it is switched to eval mode by this call
    :param criterion: loss function applied to the flattened logits/targets
    :param graph: batched dgl graph with the input trees
    :param labels: ground truth labels of the batch
    :param device: device the root indexes are moved to
    :return: Tuple[
        batch info [Dict] with 'loss' (float) and 'statistics' entries
        prediction produced by ``model.predict``
    ]
    """
    model.eval()
    roots = get_root_indexes(graph).to(device)

    with torch.no_grad():
        # teacher forcing ratio 0.0: decoding never sees the ground truth
        logits, target = model(graph, roots, labels, 0.0, device)
        # the first decoding step is excluded from both loss and metrics
        logits, target = logits[1:], target[1:]
        vocab_size = logits.shape[-1]
        loss = criterion(logits.view(-1, vocab_size), target.view(-1))

        prediction = model.predict(logits)
        ignored_ids = [model.decoder.label_to_id[tok] for tok in (PAD, UNK, EOS)]
        statistics = calculate_batch_statistics(target, prediction, ignored_ids)
        return {'loss': loss.item(), 'statistics': statistics}, prediction
Example #2
0
def eval_on_batch(
        model: Tree2Seq, criterion: nn.modules.loss, graph: dgl.DGLGraph,
        labels: torch.Tensor
) -> Tuple[Dict, torch.Tensor]:
    """Run a gradient-free forward pass over one batch.

    :param model: Tree2Seq model; it is switched to eval mode by this call
    :param criterion: loss function forwarded to ``_forward_pass``
    :param graph: batched dgl graph
    :param labels: ground truth labels of the batch
    :return: Tuple[batch info dict, prediction tensor] as produced by ``_forward_pass``
    """
    model.eval()
    with torch.no_grad():
        batch_loss, prediction, batch_info = _forward_pass(model, graph, labels, criterion)
        # only the scalar loss stored in batch_info is needed; release the tensor at once
        del batch_loss
    return batch_info, prediction
def train_on_dataset(
        train_dataset: Dataset, val_dataset, model: Tree2Seq, criterion: nn.modules.loss, optimizer: torch.optim,
        scheduler: torch.optim.lr_scheduler, clip_norm: int, logger: AbstractLogger, start_batch_id: int = 0,
        log_step: int = -1, eval_step: int = -1, save_step: int = -1
):
    """Train the model over ``train_dataset``, beginning at ``start_batch_id``.

    The periodic actions are gated by ``is_step_match`` on the batch id:

    - ``log_step``: log the accumulated training info, then reset it;
    - ``save_step``: write a checkpoint with the model, optimizer and scheduler
      state plus the batch id;
    - ``eval_step``: evaluate on ``val_dataset`` and log the result.

    Any training info accumulated after the last log step is logged at the end.
    """
    n_batches = len(train_dataset)
    accumulated = LearningInfo()

    progress = tqdm(range(start_batch_id, n_batches), total=n_batches)
    # advance the bar past the skipped batches so it shows the true position
    progress.update(start_batch_id)
    progress.refresh()

    for batch_id in progress:
        graph, labels = train_dataset[batch_id]
        accumulated.accumulate_info(
            train_on_batch(model, criterion, optimizer, scheduler, graph, labels, clip_norm)
        )

        if is_step_match(batch_id, log_step):
            logger.log(accumulated.get_state_dict(), batch_id, is_train=True)
            accumulated = LearningInfo()

        if is_step_match(batch_id, save_step):
            checkpoint = dict(
                state_dict=model.state_dict(),
                optimizer_state_dict=optimizer.state_dict(),
                scheduler_state_dict=scheduler.state_dict(),
                batch_id=batch_id,
            )
            logger.save_model(f'batch_{batch_id}.pt', checkpoint)

        if is_step_match(batch_id, eval_step):
            val_info = evaluate_on_dataset(val_dataset, model, criterion)
            logger.log(val_info.get_state_dict(), batch_id, is_train=False)

    # flush the tail that was accumulated after the last log step
    if accumulated.batch_processed > 0:
        logger.log(accumulated.get_state_dict(), n_batches - 1, is_train=True)
 def save_model(self, model: Tree2Seq, output_name: str,
                configuration: Dict, **kwargs: Dict) -> str:
     """Write a checkpoint with the model weights and its configuration.

     :param model: model whose ``state_dict()`` is stored under ``'state_dict'``
     :param output_name: file name of the checkpoint inside ``self.checkpoints_folder``
     :param configuration: stored under the ``'configuration'`` key
     :param kwargs: extra entries merged into the checkpoint; an entry whose
         key is ``'state_dict'`` or ``'configuration'`` overrides that value
     :return: path the checkpoint was saved to
     """
     target_path = join_path(self.checkpoints_folder, output_name)
     checkpoint = {'state_dict': model.state_dict(), 'configuration': configuration, **kwargs}
     torch.save(checkpoint, target_path)
     return target_path
Example #5
0
def train_on_batch(
        model: Tree2Seq, criterion: nn.modules.loss, optimizer: torch.optim, scheduler: torch.optim.lr_scheduler,
        graph: dgl.DGLGraph, labels: torch.Tensor, clip_norm: int
) -> Dict:
    """Perform one optimisation step on a single batch.

    :param model: Tree2Seq model; it is switched to train mode by this call
    :param criterion: loss function forwarded to ``_forward_pass``
    :param optimizer: optimizer stepped once
    :param scheduler: learning-rate scheduler stepped once after the optimizer
    :param graph: batched dgl graph
    :param labels: ground truth labels of the batch
    :param clip_norm: threshold passed to ``clip_grad_value_``
    :return: batch info dict from ``_forward_pass`` extended with ``'learning_rate'``
    """
    model.train()
    model.zero_grad()

    loss, prediction, batch_info = _forward_pass(model, graph, labels, criterion)
    # read before scheduler.step() so the logged value is the one used by this optimizer step
    batch_info['learning_rate'] = scheduler.get_last_lr()[0]

    loss.backward()
    # NOTE(review): despite the parameter name, gradients are clipped element-wise by value, not by norm
    nn.utils.clip_grad_value_(model.parameters(), clip_norm)
    optimizer.step()
    scheduler.step()

    # drop the step's tensors before asking the CUDA allocator to release cached blocks
    del loss, prediction
    torch.cuda.empty_cache()
    return batch_info
Example #6
0
def train_on_batch(model: Tree2Seq, criterion: nn.modules.loss,
                   optimizer: torch.optim, scheduler: torch.optim.lr_scheduler,
                   graph: dgl.BatchedDGLGraph, labels: List[str], params: Dict,
                   device: torch.device) -> Dict:
    """Perform one optimisation step on a single batch and collect its metrics.

    :param model: Tree2Seq model; it is switched to train mode by this call
    :param criterion: loss function applied to the flattened logits/targets
    :param optimizer: optimizer stepped once
    :param scheduler: learning-rate scheduler stepped once after the optimizer
    :param graph: batched dgl graph with the input trees
    :param labels: ground truth labels of the batch
    :param params: must provide 'teacher_force' (passed to the model) and
        'clip_norm' (max gradient norm)
    :param device: device the root indexes are moved to
    :return: dict with 'loss' (float) and 'statistics' entries
    """
    model.train()
    roots = get_root_indexes(graph).to(device)

    model.zero_grad()
    logits, target = model(graph, roots, labels, params['teacher_force'], device)
    # the first decoding step is excluded from both loss and metrics
    logits = logits[1:]
    target = target[1:]
    n_classes = logits.shape[-1]
    loss = criterion(logits.view(-1, n_classes), target.view(-1))

    loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), params['clip_norm'])
    optimizer.step()
    scheduler.step()

    prediction = model.predict(logits)
    ignored_ids = [model.decoder.label_to_id[tok] for tok in (PAD, UNK, EOS)]
    return {
        'loss': loss.item(),
        'statistics': calculate_batch_statistics(target, prediction, ignored_ids),
    }
Example #7
0
def interactive(path_to_function: str, path_to_model: str):
    """Predict and print the label of a single function.

    Loads a Tree2Seq checkpoint, converts the function found at
    ``path_to_function`` to dgl format (intermediate files go under
    ``TMP_FOLDER``), runs one forward pass and prints the predicted
    sublabels joined with ``'|'`` up to the first EOS.

    :param path_to_function: source path handed to ``build_asts``
    :param path_to_model: checkpoint with 'configuration' and 'state_dict' entries
    :raises ValueError: if the converted data does not contain exactly one function
    """
    fix_seed()
    device = get_device()
    print(f"using {device} device")

    # load model
    print("loading model...")
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    checkpoint = torch.load(path_to_model, map_location=device)

    model = Tree2Seq(**checkpoint['configuration']).to(device)
    model.load_state_dict(checkpoint['state_dict'])

    token_to_id = model.token_to_id
    type_to_id = model.type_to_id
    label_to_id = model.label_to_id
    id_to_label = {v: k for k, v in label_to_id.items()}

    # convert function to dgl format
    print("convert function to dgl format...")
    create_folder(TMP_FOLDER)
    build_asts(path_to_function, TMP_FOLDER, ASTMINER_PATH, *ASTMINER_PARAMS)
    project_folder = os.path.join(TMP_FOLDER, 'java')
    convert_project(project_folder, token_to_id, type_to_id, label_to_id, True,
                    True, 5, 6, False, True, '|')

    # load function
    graph, labels = load_graphs(os.path.join(project_folder, 'converted.dgl'))
    labels = labels['labels']
    # an explicit check (not assert) so it is not stripped under `python -O`,
    # and the message reports the real count of converted functions
    if len(labels) != 1:
        raise ValueError(f"found {len(labels)} functions, instead of 1")
    ast = graph[0].reverse(share_ndata=True)
    ast.ndata['token'] = ast.ndata['token'].to(device)
    ast.ndata['type'] = ast.ndata['type'].to(device)
    labels = labels.t().to(device)
    # keep the root index on the same device as the node features
    root_indexes = torch.tensor([0], dtype=torch.long, device=device)

    # forward pass
    model.eval()
    with torch.no_grad():
        # NOTE(review): eval_on_batch calls model(graph, root_indexes, labels, teacher_force, device)
        # and unpacks a tuple; this call has no teacher-force argument and expects a single
        # tensor — confirm against the Tree2Seq.forward signature of this checkpoint's version
        logits = model(ast, root_indexes, labels, device)
    logits = logits[1:]
    prediction = model.predict(logits).reshape(-1)
    sublabels = [id_to_label[label_id.item()] for label_id in prediction]
    label = '|'.join(takewhile(lambda sl: sl != EOS, sublabels))
    print(f"the predicted label is:\n{label}")
Example #8
0
def _forward_pass(
        model: Tree2Seq, graph: dgl.DGLGraph, labels: torch.Tensor, criterion: nn.modules.loss
) -> Tuple[torch.Tensor, torch.Tensor, Dict]:
    """Run the model on a batch and compute loss, prediction and statistics.

    :param model: Tree2Seq model
    :param graph: batched dgl graph
    :param labels: [seq len; batch size] ground truth labels
    :param criterion: criterion to optimize
    :return: Tuple[
        loss [1] torch tensor with loss information
        prediction [the longest sequence, batch size]
        batch info [Dict] dict with 'loss' (float) and 'statistics'
    ]
    """
    # [seq len; batch size; vocab size]
    logits = model(graph, labels)

    # a single-step target means a classification task; for longer sequences the
    # first position is always <SOS>, so it is removed from logits and targets
    targets = labels
    if labels.shape[0] > 1:
        # [seq len - 1; batch size; vocab size] and [seq len - 1; batch size]
        logits, targets = logits[1:], labels[1:]

    vocab_size = logits.shape[-1]
    loss = criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))
    # [the longest sequence, batch size]
    prediction = model.predict(logits)

    # special tokens excluded from the statistics; ones missing from the vocabulary are ignored
    label_to_id = model.decoder.label_to_id
    skipping_tokens = [label_to_id[tok] for tok in (PAD, UNK, EOS) if tok in label_to_id]
    # statistics are computed batch-first, hence the transposes
    statistics = calculate_batch_statistics(targets.t(), prediction.t(), skipping_tokens)

    return loss, prediction, {'loss': loss.item(), 'statistics': statistics}
def evaluate(params: Dict) -> None:
    """Evaluate a saved checkpoint on a dataset and print the collected metrics.

    :param params: must provide 'model' (checkpoint path), 'dataset' and 'batch_size'
    """
    fix_seed()
    device = get_device()
    print(f"using {device} device")

    checkpoint = torch.load(params['model'], map_location=device)

    print('model initializing...')
    model = Tree2Seq(**checkpoint['configuration']).to(device)
    model.load_state_dict(checkpoint['state_dict'])

    dataset = TreeDGLDataset(params['dataset'], params['batch_size'], device, True)

    # padding positions do not contribute to the loss
    pad_id = model.label_to_id[PAD]
    criterion = nn.CrossEntropyLoss(ignore_index=pad_id).to(device)

    print("ok, let's evaluate it")
    results = evaluate_on_dataset(dataset, model, criterion)
    print(results.get_state_dict())