Пример #1
0
def sample_text(model, length: int, conditional_files: List[str] = None, temperature: float = 1.0) -> Tuple[str, str]:
    """Generate a text sample from ``model`` conditioned on a git context.

    Args:
        model: possibly-wrapped model; ``unwrap_model`` is applied before
            it is handed to ``generate_text``.
        length: forwarded to ``generate_text`` as its length argument.
        conditional_files: optional sequence of files. The last entry is
            passed to ``prepare_git_context`` as the first argument. The
            preceding entries, or ``None`` when there are none, are passed
            as the second. ``None`` or an empty sequence builds the default
            context instead.
        temperature: NOTE(review): accepted but unused in this body. It is
            not forwarded to ``generate_text``; confirm this is intended.

    Returns:
        ``(context, text)``: the prepared context and the first sample
        (``[0][0]``) of ``generate_text``'s result.
    """
    if conditional_files:
        primary, history = conditional_files[-1], conditional_files[:-1]
        context = prepare_git_context(primary, history or None)
    else:
        context = prepare_git_context()

    unwrapped = unwrap_model(model)
    samples = generate_text(unwrapped, context, length,
                            num_diversity_groups=1,
                            tokenizer=g.corpus.vocab.tokenizer,
                            verbose=False)
    text = samples[0][0]
    return context, text
Пример #2
0
def evaluate(model,
             eval_iter,
             label: str,
             max_eval_steps: int = 0,
             reset_mems_interval: int = None):
    """Run ``model`` over ``eval_iter`` without gradients and collect metrics.

    Args:
        model: called as ``model(data, target, *mems, return_hidden=True)``.
            Its unwrapped form must expose ``hidden_to_softmax``.
        eval_iter: iterable yielding ``(data, target, seq_len)`` batches.
        label: prefix shown in the progress-bar description.
        max_eval_steps: stop after this many batches. ``0`` (or any
            non-positive value) means no limit.
        reset_mems_interval: if not ``None``, clear the recurrent memory every
            this many batches, including batch 0.

    Returns:
        dict with keys:

        * ``total_loss``: sum of ``seq_len * mean batch loss``.
        * ``accuracy_top1``, ``accuracy_top5``, ``MRR_top5``: running metrics.
        * ``total_len``: sum of ``seq_len``.

        The three ratio metrics are 0.0 when no batch was evaluated.

    Note:
        The model is left in eval mode; callers switch it back.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_len, total_count = 0, 0
    total_loss, total_top1, total_top5, MRR_total = 0., 0, 0, 0.
    # Defined up front so an empty ``eval_iter`` yields zeros instead of an
    # UnboundLocalError when ``metrics`` is built below.
    accuracy_top1 = accuracy_top5 = MRR_top5 = 0.0
    with torch.no_grad():
        mems = tuple()
        bar = tqdm.tqdm(eval_iter, leave=False)
        for i, (data, target, seq_len) in enumerate(bar):
            if 0 < max_eval_steps <= i:
                break
            if reset_mems_interval is not None and i % reset_mems_interval == 0:
                mems = tuple()

            ret = model(data, target, *mems, return_hidden=True)
            # NOTE(review): the hidden state is cast to fp16 before the softmax
            # head; this assumes the head runs in half precision. Confirm.
            pred_hidden, loss, mems = ret[0].half(), ret[1], ret[2:]
            pred_probas = unwrap_model(model).hidden_to_softmax(pred_hidden)

            # Loss: weight each batch mean by seq_len so that dividing the
            # total by total_len gives a seq_len-weighted average.
            loss = loss.mean()
            total_loss += seq_len * loss.item()

            # Accuracy: true_pos[..., k] is True where the k-th ranked
            # prediction equals the target (k = 0..4).
            _, pred_top = torch.topk(pred_probas, 5)
            true_pos = pred_top == target.unsqueeze(-1).expand_as(pred_top)
            # Accumulate plain Python ints rather than device tensors.
            total_top1 += true_pos[:, :, 0].sum().item()
            total_top5 += true_pos.sum().item()  # last dim is exactly the top 5

            # MRR@5: a hit at 1-based rank r contributes 1 / r.
            ranks = torch.arange(1, true_pos.size(-1) + 1,
                                 dtype=torch.double,
                                 device=true_pos.device)
            MRR_total += float((true_pos.double() / ranks).sum())

            total_len += seq_len
            # Number of scored positions. This assumes target is
            # (seq_len, batch, ...), with size(1) as batch. TODO confirm.
            total_count += seq_len * target.size(1)
            MRR_top5 = MRR_total / total_count
            accuracy_top1 = float(total_top1) / total_count
            accuracy_top5 = float(total_top5) / total_count
            bar.set_description(f'{label} '
                                f'| loss: {total_loss / total_len:.2f} '
                                f'| accuracy@1: {accuracy_top1:.2f} '
                                f'| accuracy@5: {accuracy_top5:.2f} '
                                f'| MRR@5: {MRR_top5:.2f} ')
        metrics = {
            "total_loss": total_loss,
            "accuracy_top1": accuracy_top1,
            "accuracy_top5": accuracy_top5,
            "MRR_top5": MRR_top5,
            "total_len": total_len,
        }
    return metrics
Пример #3
0
 def _load_model(model_path: str, device: torch.device) -> MemTransformerLM:
     """Load a pickled model object from ``model_path`` onto ``device``, unwrapped.

     NOTE(review): ``torch.load`` unpickles arbitrary objects, so only load
     trusted files. Newer PyTorch releases default ``weights_only`` to True,
     which rejects fully pickled modules; confirm the pinned torch version.
     """
     with open(model_path, "rb") as checkpoint_file:
         loaded = torch.load(checkpoint_file, map_location=device)
     return unwrap_model(loaded)
Пример #4
0
def evaluate_and_log(model: torch.nn.Module,
                     eval_iter,
                     split: str,
                     generate_text: bool = True,
                     reset_mems_interval: int = None):
    """Evaluate ``model`` on ``eval_iter``, log the metrics, and checkpoint a new best.

    While ``evaluate`` runs, the model's lengths are switched to
    ``args.eval_tgt_len``. Afterwards the training lengths and training mode
    are restored, even if evaluation raises.

    Args:
        model: possibly-wrapped model. ``util.unwrap_model`` is used to reach
            ``reset_length``.
        eval_iter: batches passed straight to ``evaluate``.
        split: split name used in log keys. ``'val'`` enables saving the
            ``best`` checkpoint.
        generate_text: unused in this body; kept for caller compatibility.
            NOTE(review): inside this function it also shadows any
            module-level ``generate_text``.
        reset_mems_interval: forwarded to ``evaluate``.
    """
    args = g.args
    state = g.state
    optimizer = state.optimizer
    eval_start_time = time.time()

    model_to_reset = util.unwrap_model(model)
    # If the model does not use memory at all, make the ext_len longer.
    # Otherwise, make the mem_len longer and keep the ext_len the same.
    # Either way, tgt_len + ext_len + mem_len keeps its training-time total.
    if args.mem_len == 0:
        model_to_reset.reset_length(
            args.eval_tgt_len, args.ext_len + args.tgt_len - args.eval_tgt_len,
            args.mem_len)
    else:
        model_to_reset.reset_length(
            args.eval_tgt_len, args.ext_len,
            args.mem_len + args.tgt_len - args.eval_tgt_len)

    try:
        # Calculate metrics
        ret = evaluate(model,
                       eval_iter,
                       split,
                       args.max_eval_steps,
                       reset_mems_interval=reset_mems_interval)
    finally:
        # Switch back to the training lengths and mode even if evaluation
        # raised, so a failed eval does not leave the model misconfigured.
        model_to_reset.reset_length(args.tgt_len, args.ext_len, args.mem_len)
        model.train()

    total_loss = ret["total_loss"]
    accuracy_top1 = ret["accuracy_top1"]
    accuracy_top5 = ret["accuracy_top5"]
    MRR = ret["MRR_top5"]
    total_len = ret["total_len"]

    # Log all the things. ``loss`` and the raw accuracies come from this
    # call's ``evaluate`` result and appear in the text log. The ``mean_*``
    # values pass through util.dist_mean (presumably averaged across
    # distributed workers; confirm) and feed TensorBoard and checkpoint
    # selection.
    # NOTE(review): if dist_mean is a collective op, every rank must make
    # these calls in the same order; keep them unconditional.
    loss = total_loss / total_len
    mean_loss = util.dist_mean(loss)
    mean_accuracy_top1 = util.dist_mean(accuracy_top1)
    mean_accuracy_top5 = util.dist_mean(accuracy_top5)
    mean_MRR = util.dist_mean(MRR)
    g.logger.info('-' * 100)
    log_str = (
        f'| Eval {state.train_step // args.eval_interval:3d} at step {state.train_step:>8d} | '
        f'time: {time.time() - eval_start_time:5.2f}s '
        f'| {split} loss {loss:5.2f}')
    log_tb(f'learning/{split}_loss', mean_loss)
    if args.dataset in ['enwik8', 'text8']:
        # Character-level datasets: report bits per character (loss / ln 2).
        log_str += f' | bpc {loss / math.log(2):9.5f}'
        log_tb(f'learning/{split}_bpc', mean_loss / math.log(2))
    elif args.dataset == 'git':
        log_str += f' | accuracy@1 {accuracy_top1:.2f} ' \
                   f'| accuracy@5 {accuracy_top5:.2f} ' \
                   f'| MRR@5 {MRR:.2f}'
        log_tb(f'learning/{split}_acc@1', mean_accuracy_top1)
        log_tb(f'learning/{split}_acc@5', mean_accuracy_top5)
        log_tb(f'learning/{split}_MRR@5', mean_MRR)
    else:
        log_str += f' | {split} ppl {math.exp(loss):9.3f}'
        log_tb(f'learning/{split}_ppl', math.exp(mean_loss))
    g.logger.info(log_str)
    g.logger.info('-' * 100)

    # Update checkpoint if validation loss improved. A falsy best_val_loss
    # (e.g. None or 0) is treated as "no best recorded yet".
    if split == 'val' and (not state.best_val_loss
                           or mean_loss < state.best_val_loss):
        g.logger.info('Saving checkpoint for new best loss')
        util.dist_save_checkpoint(model, optimizer, args.logdir, suffix='best')
        state.best_val_loss = mean_loss