def evaluate(epoch, dataloader, eval_type='valid', final_eval=False):
    """Evaluate the model over `dataloader` and report metrics.

    For validation epochs, additionally saves the model whenever the tracked
    metric ties or beats the best so far, and periodically shrinks the
    learning rate. Returns the scalar value of `args.eval_metric`.
    """
    global val_metric_best, lr, stop_training

    if eval_type == 'valid':
        print('\nVALIDATION : Epoch {0}'.format(epoch))

    # Accumulate metrics across the whole loader with the model frozen.
    metrics = Metrics(tok2i, i2tok, field=TRG)
    metrics.reset()
    model.eval()
    for _, batch in enumerate(dataloader, 0):
        scores, samples = predict_batch(batch)
        metrics.update(scores, samples, (batch.trg[0], None))
    model.train()

    kind = 'final_' + eval_type if final_eval else eval_type
    ms = metrics.report(kind)
    eval_metric = ms['%s/%s' % (kind, args.eval_metric)]
    metrics_to_log = ['bleu', 'avg_span', 'f1', 'em', 'depth_score']
    # Both branches of the original printed then logged; only the prefix differs.
    prefix = 'final: ' if final_eval else ('valid (epoch %d): ' % epoch)
    print(prefix + metrics.log(ms, kind, metrics_to_log))
    log_tensorboard(ms, step=args.logstep)

    if eval_type == 'valid' and epoch <= args.n_epochs:
        # Checkpoint on a new (or tied) best value of the tracked metric.
        if eval_metric >= val_metric_best:
            print('saving model at epoch {0}'.format(epoch))
            torch.save(model.state_dict(), os.path.join(args.log_directory, args.expr_name))
            val_metric_best = eval_metric
        # Periodic learning-rate decay.
        if epoch > 1 and epoch % args.lrshrink_nepochs == 0:
            optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] / args.lrshrink
            print('Shrinking lr by : {0}. New lr = {1}'
                  .format(args.lrshrink, optimizer.param_groups[0]['lr']))
    return eval_metric
# Example #2
# 0
def adjust(epoch):
    """Anneal `args.rollin_beta` (and optionally the self-teach beta) once per epoch.

    No-op until `args.beta_burnin` epochs have elapsed.
    """
    if epoch <= args.beta_burnin:
        return
    # Linear decay of the roll-in beta, floored at args.beta_min.
    args.rollin_beta = max(args.rollin_beta - args.beta_step, args.beta_min)
    log_tensorboard({'sampler.beta': args.rollin_beta}, step=args.logstep)
    if args.self_teach_beta_step > 0 and 'self_teach_beta' in loss_flags:
        decayed = max(loss_flags['self_teach_beta'] - args.self_teach_beta_step, 0.0)
        loss_flags['self_teach_beta'] = decayed
        log_tensorboard({'self_teach_beta': decayed}, step=args.logstep)
def adjust(epoch, sampler):
    """Decay `sampler.beta` (and optionally the self-teach beta) after burn-in.

    No-op until `args.beta_burnin` epochs have elapsed. Samplers without a
    `beta` attribute are left untouched.
    """
    if epoch <= args.beta_burnin:
        return
    if hasattr(sampler, 'beta'):
        # Linear decay of the sampler's roll-in beta, floored at zero.
        sampler.beta = max(sampler.beta - args.beta_step, 0.0)
        log_tensorboard({'sampler.beta': sampler.beta}, step=args.logstep)
    if args.self_teach_beta_step > 0 and 'self_teach_beta' in loss_flags:
        decayed = max(loss_flags['self_teach_beta'] - args.self_teach_beta_step, 0.0)
        loss_flags['self_teach_beta'] = decayed
        log_tensorboard({'self_teach_beta': decayed}, step=args.logstep)
# Example #4
# 0
def adjust(sampler, epoch):
    """Per-epoch schedule updates: shrink the LR periodically, then decay betas.

    The beta decay (sampler beta and optional self-teach beta) only starts
    after `args.beta_burnin` epochs.
    """
    if epoch > 1 and epoch % args.lrshrink_nepochs == 0:
        group = optimizer.param_groups[0]
        group['lr'] = group['lr'] / args.lrshrink
        print('Shrinking lr by : {0}. New lr = {1}'
              .format(args.lrshrink, group['lr']))

    if epoch <= args.beta_burnin:
        return
    if hasattr(sampler, 'beta'):
        # Linear decay of the sampler's roll-in beta, floored at zero.
        sampler.beta = max(sampler.beta - args.beta_step, 0.0)
        log_tensorboard({'sampler.beta': sampler.beta}, step=args.logstep)
    if args.self_teach_beta_step > 0 and 'self_teach_beta' in loss_flags:
        decayed = max(loss_flags['self_teach_beta'] - args.self_teach_beta_step, 0.0)
        loss_flags['self_teach_beta'] = decayed
        log_tensorboard({'self_teach_beta': decayed}, step=args.logstep)
def train_epoch(epoch):
    """Run one training epoch over `trainloader`.

    Per batch: builds an Oracle from the padded targets, runs the model
    forward, computes the loss, and takes a gradient-clipped optimizer step.
    Every `args.print_every` batches it prints/logs a training report plus a
    single-batch validation snapshot; every `args.save_every` batches it
    writes a full checkpoint and dumps model_config.json.
    """
    print('\nTRAINING : Epoch ' + str(epoch))
    model.train()
    losses = []   # per-batch loss values since the last report
    logs = []     # human-readable report lines

    last_time = time.time()

    metrics = Metrics(tok2i, i2tok, field=TRG)
    for i, batch in enumerate(trainloader):
        # -- Actual Training
        gt.reset()
        gt.stamp("load_data")

        oracle = Oracle(batch.trg[0].detach(), model.n_classes, tok2i, i2tok, **oracle_flags)
        gt.stamp("create_oracle")
        # Decode-length cap: twice the longest non-pad target, plus one.
        max_steps = 2*batch.trg[0].detach().ne(tok2i[constants.PAD_WORD]).sum(1).max()+1
        scores, samples, p_oracle = model.forward(xs=batch.src, oracle=oracle, max_steps=max_steps, num_samples=len(batch),
                                                  return_p_oracle=True)
        gt.stamp("forward")
        loss = loss_fn(scores, samples, p_oracle, end_idx=tok2i['<end>'], **loss_flags)
        gt.stamp("loss")

        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), args.max_norm)
        optimizer.step()
        gt.stamp("backward")

        losses.append(loss.item())

        # -- Report metrics every `print_every` batches.
        if i % args.print_every == 0:
            # Only compute training metrics once here for efficiency.
            metrics.update(scores, samples, (batch.trg[0], None), kind='train')
            gt.stamp("metrics.update")
            # Training report computed over the last `print_every` batches.
            ms = metrics.report('train')
            ms['train/loss'] = round(np.mean(losses), 2)
            logs.append('{0} ; loss {1} ; sentence/s {2} ; {3} train {4} '.format(
                        i+1,
                        round(np.mean(losses), 2),
                        int(len(losses) * args.batch_size / (time.time() - last_time)),
                        args.eval_metric,
                        ms['train/%s' % args.eval_metric],
                        ))
            args.logstep += 1
            last_time = time.time()
            losses = []
            # Single reset (the original called metrics.reset() twice in a row).
            metrics.reset()

            # -- Validation report with a single batch. Distinct names so the
            # training loop's `batch`/`scores`/`samples` are not shadowed.
            model.eval()
            vbatch = next(iter(validloader))
            vscores, vsamples = predict_batch(vbatch)
            model.train()
            metrics.update(vscores, vsamples, (vbatch.trg[0], None))
            vms = metrics.report('valid_batch')
            logs[-1] = logs[-1] + metrics.log(vms, 'valid_batch', ['bleu', 'avg_span', 'f1', 'em', 'depth_score'])
            metrics.reset()

            print_samples(vsamples, (vbatch.trg[0], None), n=len(vbatch))
            gt.stamp("validation_batch")

            log_tensorboard(ms, step=args.logstep)
            log_tensorboard(vms, step=args.logstep)
            print(logs[-1])
            print(gt.report(include_itrs=False, format_options={'itr_name_width': 30}))

        # -- Checkpointing
        if i % args.save_every == 0:
            print('saving checkpoint at epoch {0} batch {1}'.format(epoch, i))
            print(os.path.join(args.log_directory, args.expr_name + '.checkpoint'))
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'optimizer_param': args.optimizer,
                'loss': loss.item()
            }, os.path.join(args.log_directory, args.expr_name + '.checkpoint'))

            model_config['longest_label'] = model.longest_label
            with open(os.path.join(args.log_directory, 'model_config.json'), 'w') as f:
                json.dump(model_config, f)

    print('end : epoch {0} '.format(epoch))
    log_tensorboard({'lr': optimizer.param_groups[0]['lr']}, step=args.logstep)
# Example #6
# 0
def train_epoch(epoch):
    """Run one buffer-based training epoch of `updates_per_epoch` updates.

    Draws trajectories via `buffer.TrajectorySampler`: pure-oracle
    trajectories when `args.rollin_beta == 1.0`, otherwise mixed rollins.
    Logs sample/update timings every 20 updates, prints/logs a metrics
    report (plus a single-batch validation snapshot) every
    `args.print_every` updates, and checkpoints every `args.save_every`
    updates.
    """
    print('\nTRAINING : Epoch ' + str(epoch))
    model.train()
    losses = []
    logs = []
    sample_avgs = []   # per-update trajectory-sampling times (seconds)
    update_avgs = []   # per-update loss-computation times (seconds)

    last_time = time.time()
    metrics = Metrics(tok2i, i2tok, field=TRG)

    trajectory_sampler = buffer.TrajectorySampler(trainloader)
    n_updates = 0
    oracle_samples_only = args.rollin_beta == 1.0
    while n_updates < updates_per_epoch:
        gt.reset()
        if oracle_samples_only:
            start = time.time()
            trajectory = trajectory_sampler.get_oracle_trajectory(model, Oracle, oracle_flags=oracle_flags)
            sample_time = (time.time() - start)
            start = time.time()
            loss = trajectory_sampler.get_loss(model, trajectory, loss_flags)
            update_time = (time.time() - start)
        else:
            # Mixed rollin: sampling and loss happen inside one call, so the
            # whole duration is attributed to update_time.
            start = time.time()
            loss = trajectory_sampler.get_mixed_trajectory_loss(model, Oracle,
                                                                oracle_flags=oracle_flags,
                                                                beta=args.rollin_beta,
                                                                loss_flags=loss_flags)
            sample_time = 0
            update_time = (time.time() - start)

        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), args.max_norm)
        losses.append(loss.item())
        optimizer.step()
        n_updates += 1

        sample_avgs.append(sample_time)
        update_avgs.append(update_time)

        gt.stamp("buffer updates")

        if n_updates % 20 == 0:
            print("%d|%d\t%.3f\tSample: %.3fs\tUpdate: %.3fs" % (epoch, n_updates, round(np.mean(losses), 3), np.mean(sample_avgs), np.mean(update_avgs)))
            log_tensorboard({'sample_avgs': np.mean(sample_avgs),
                             'update_avgs': np.mean(update_avgs)}, step=args.logstep)
            sample_avgs = []
            update_avgs = []

        # -- Report metrics every `print_every` updates.
        if n_updates % args.print_every == 0:
            gt.stamp("report")
            # Training report computed over the last `print_every` updates.
            # `metrics` receives no training-sample updates in this variant,
            # hence the .get(...) fallback below.
            ms = metrics.report('train')
            ms['train/loss'] = round(np.mean(losses), 2)
            logs.append('{0} ; loss {1} ; sentence/s {2} ; {3} train {4} '.format(
                        epoch,
                        round(np.mean(losses), 2),
                        int(len(losses) * args.batch_size / (time.time() - last_time)),
                        args.eval_metric,
                        ms.get('train/%s' % args.eval_metric, 0.0),
                        ))
            args.logstep += 1
            last_time = time.time()
            losses = []
            # Single reset (the original called metrics.reset() twice in a row).
            metrics.reset()

            # -- Validation report with a single batch.
            model.eval()
            batch = next(iter(validloader))
            scores, samples = predict_batch(batch)
            model.train()
            metrics.update(scores, samples, (batch.trg[0], None))
            vms = metrics.report('valid_batch')
            logs[-1] = logs[-1] + metrics.log(vms, 'valid_batch', ['bleu', 'avg_span', 'f1', 'em', 'depth_score'])
            metrics.reset()

            print_samples(samples, (batch.trg[0], None), n=len(batch))
            gt.stamp("validation_batch")

            log_tensorboard(ms, step=args.logstep)
            log_tensorboard(vms, step=args.logstep)
            print(logs[-1])
            print(gt.report(include_itrs=False, format_options={'itr_name_width': 30}))

        # -- Checkpointing
        if n_updates % args.save_every == 0:
            # BUG FIX: the original formatted an undefined `i` here, raising
            # NameError the first time this branch ran; this loop counts
            # updates with `n_updates`.
            print('saving checkpoint at epoch {0} batch {1}'.format(epoch, n_updates))
            print(os.path.join(args.log_directory, args.expr_name + '.checkpoint'))
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'optimizer_param': args.optimizer,
                'loss': loss.item()
            }, os.path.join(args.log_directory, args.expr_name + '.checkpoint'))

            # NOTE(review): other variants treat model_config as a dict and dump
            # JSON; this one sets an attribute and pickles — confirm intended.
            model_config.longest_label = model.longest_label
            with open(os.path.join(args.log_directory, 'model_config.pkl'), 'wb') as f:
                pickle.dump(model_config, f)

    print('end : epoch {0} '.format(epoch))
    log_tensorboard({'lr': optimizer.param_groups[0]['lr']}, step=args.logstep)
# Example #7
# 0
def train_epoch(epoch):
    """Run one training epoch over `trainloader` (unconditional variant).

    Each `trainloader` element is a `(xs, annots)` pair; `model.forward` is
    called without source inputs (only `num_samples`), generating against an
    Oracle built from `xs`. Every `args.print_every` batches it prints/logs a
    training report and a qualitative sample dump; every `args.save_every`
    batches it saves model weights and dumps model_config.json.
    """
    print('\nTRAINING : Epoch ' + str(epoch))
    model.train()
    losses = []  # per-batch loss values since the last report
    logs = []    # human-readable report lines

    last_time = time.time()

    metrics = Metrics(tok2i, i2tok)
    for i, data in enumerate(trainloader, 0):
        # -- Actual Training
        gt.reset()
        xs, annots = data
        xs = xs.to(args.device)
        gt.stamp("load_data")

        oracle = Oracle(xs, model.n_classes, tok2i, i2tok, **oracle_flags)
        gt.stamp("create_oracle")
        # Decode-length cap: twice the longest non-'<p>' sequence, plus one.
        max_steps = 2*xs.ne(tok2i['<p>']).sum(1).max()+1
        scores, samples, p_oracle = model.forward(num_samples=args.batch_size, oracle=oracle, max_steps=max_steps, return_p_oracle=True)
        gt.stamp("forward")
        loss = loss_fn(scores, samples, p_oracle, tok2i['<end>'], **loss_flags)
        gt.stamp("loss")

        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), args.max_norm)
        optimizer.step()
        gt.stamp("backward")

        losses.append(loss.item())

        # -- Report metrics every `print_every` batches.
        if i % args.print_every == 0:
            # Training report; loss averaged over the last `print_every` batches.
            metrics.update(scores, samples, data)
            gt.stamp("metrics.update")
            ms = metrics.report('train')
            ms['train/loss'] = round(np.mean(losses), 2)
            # NOTE(review): the "f1 train" slot is hard-coded to 0 in this
            # variant — confirm whether a real metric was intended here.
            logs.append('{0} ; loss {1} ; sentence/s {2} ; f1 train {3} '.format(
                        i+1,
                        round(np.mean(losses), 2),
                        int(len(losses) * args.batch_size / (time.time() - last_time)),
                        0,
                        ))
            args.logstep += 1
            last_time = time.time()
            losses = []
            metrics.reset()

            # Re-sample on the same data for qualitative inspection; this
            # rebinds `scores`/`samples` (unused afterwards in this iteration).
            scores, samples = predict_batch(data)
            print_samples(samples, data)
            gt.stamp("validation_batch")

            log_tensorboard(ms, step=args.logstep)
            print(logs[-1])
            print(gt.report(include_itrs=False, format_options={'itr_name_width': 30}))

        # -- Checkpointing
        if i % args.save_every == 0:
            print('saving checkpoint at epoch {0} batch {1}'.format(epoch, i))
            torch.save(model.state_dict(), os.path.join(args.log_directory, args.expr_name + '.checkpoint'))
            model_config['longest_label'] = model.longest_label
            with open(os.path.join(args.log_directory, 'model_config.json'), 'w') as f:
                json.dump(model_config, f)

    print('end : epoch {0} '.format(epoch))
    log_tensorboard({'lr': optimizer.param_groups[0]['lr']}, step=args.logstep)