def evaluate(epoch, dataloader, eval_type='valid', final_eval=False):
    global val_metric_best, lr, stop_training
    if eval_type == 'valid':
        print('\nVALIDATION : Epoch {0}'.format(epoch))

    vmetrics = Metrics(tok2i, i2tok, field=TRG)
    vmetrics.reset()

    model.eval()
    with torch.no_grad():  # no gradients needed during evaluation
        for i, batch in enumerate(dataloader, 0):
            scores, samples = predict_batch(batch)
            vmetrics.update(scores, samples, (batch.trg[0], None))
    model.train()

    kind = eval_type if not final_eval else 'final_' + eval_type
    ms = vmetrics.report(kind)
    eval_metric = ms['%s/%s' % (kind, args.eval_metric)]
    metrics_to_log = ['bleu', 'avg_span', 'f1', 'em', 'depth_score']
    if final_eval:
        print('final: ' + vmetrics.log(ms, kind, metrics_to_log))
    else:
        print(('valid (epoch %d): ' % epoch) + vmetrics.log(ms, kind, metrics_to_log))
    log_tensorboard(ms, step=args.logstep)

    if eval_type == 'valid' and epoch <= args.n_epochs:
        # Save the model whenever the validation metric improves.
        if eval_metric >= val_metric_best:
            print('saving model at epoch {0}'.format(epoch))
            torch.save(model.state_dict(), os.path.join(args.log_directory, args.expr_name))
            val_metric_best = eval_metric
        # Decay the learning rate every `lrshrink_nepochs` epochs.
        if epoch > 1 and epoch % args.lrshrink_nepochs == 0:
            optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] / args.lrshrink
            print('Shrinking lr by : {0}. New lr = {1}'.format(
                args.lrshrink, optimizer.param_groups[0]['lr']))
    return eval_metric
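# Usage sketch (not from the source): `evaluate` is presumably called once per
# epoch on the validation set, and once more with `final_eval=True` after the
# best checkpoint has been reloaded, so metrics land under the 'final_test'
# prefix. `testloader` here is a hypothetical DataLoader name.
#
#   val_metric = evaluate(epoch, validloader, eval_type='valid')
#   model.load_state_dict(torch.load(os.path.join(args.log_directory, args.expr_name)))
#   evaluate(args.n_epochs, testloader, eval_type='test', final_eval=True)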
def adjust(epoch):
    if epoch <= args.beta_burnin:
        return
    # Linearly anneal the oracle roll-in probability down to `beta_min`.
    args.rollin_beta = max(args.rollin_beta - args.beta_step, args.beta_min)
    log_tensorboard({'sampler.beta': args.rollin_beta}, step=args.logstep)
    if args.self_teach_beta_step > 0 and 'self_teach_beta' in loss_flags:
        loss_flags['self_teach_beta'] = max(loss_flags['self_teach_beta'] - args.self_teach_beta_step, 0.0)
        log_tensorboard({'self_teach_beta': loss_flags['self_teach_beta']}, step=args.logstep)
def adjust(epoch, sampler):
    if epoch <= args.beta_burnin:
        return
    if hasattr(sampler, 'beta'):
        sampler.beta = max(sampler.beta - args.beta_step, 0.0)
        log_tensorboard({'sampler.beta': sampler.beta}, step=args.logstep)
    if args.self_teach_beta_step > 0 and 'self_teach_beta' in loss_flags:
        loss_flags['self_teach_beta'] = max(loss_flags['self_teach_beta'] - args.self_teach_beta_step, 0.0)
        log_tensorboard({'self_teach_beta': loss_flags['self_teach_beta']}, step=args.logstep)
def adjust(sampler, epoch):
    # This variant folds the learning-rate decay into the per-epoch adjustment.
    if epoch > 1 and epoch % args.lrshrink_nepochs == 0:
        optimizer.param_groups[0]['lr'] = optimizer.param_groups[0]['lr'] / args.lrshrink
        print('Shrinking lr by : {0}. New lr = {1}'.format(
            args.lrshrink, optimizer.param_groups[0]['lr']))
    if epoch <= args.beta_burnin:
        return
    if hasattr(sampler, 'beta'):
        sampler.beta = max(sampler.beta - args.beta_step, 0.0)
        log_tensorboard({'sampler.beta': sampler.beta}, step=args.logstep)
    if args.self_teach_beta_step > 0 and 'self_teach_beta' in loss_flags:
        loss_flags['self_teach_beta'] = max(loss_flags['self_teach_beta'] - args.self_teach_beta_step, 0.0)
        log_tensorboard({'self_teach_beta': loss_flags['self_teach_beta']}, step=args.logstep)
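# Note on the three `adjust` variants above: they implement the same linear
# annealing schedule for the oracle roll-in probability beta, differing only in
# whether beta lives on `args` or on a sampler object, and whether the
# learning-rate decay is folded in. A minimal self-contained sketch of the
# schedule (the default values are chosen for illustration, not taken from the
# source):

def beta_schedule(epoch, burnin=2, beta0=1.0, step=0.05, floor=0.0):
    """Beta after `epoch` adjustments: flat during burn-in, then linear decay."""
    return max(beta0 - step * max(epoch - burnin, 0), floor)

# e.g. beta_schedule(2) == 1.0, beta_schedule(12) == 0.5, beta_schedule(30) == 0.0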
def train_epoch(epoch):
    print('\nTRAINING : Epoch ' + str(epoch))
    model.train()
    losses = []
    logs = []
    last_time = time.time()
    metrics = Metrics(tok2i, i2tok, field=TRG)

    for i, batch in enumerate(trainloader):
        # -- Actual Training
        gt.reset()
        gt.stamp("load_data")

        oracle = Oracle(batch.trg[0].detach(), model.n_classes, tok2i, i2tok, **oracle_flags)
        gt.stamp("create_oracle")

        # Upper bound on decoding steps: twice the longest non-pad target, plus one.
        max_steps = 2 * batch.trg[0].detach().ne(tok2i[constants.PAD_WORD]).sum(1).max() + 1
        scores, samples, p_oracle = model.forward(xs=batch.src,
                                                  oracle=oracle,
                                                  max_steps=max_steps,
                                                  num_samples=len(batch),
                                                  return_p_oracle=True)
        gt.stamp("forward")

        loss = loss_fn(scores, samples, p_oracle, end_idx=tok2i['<end>'], **loss_flags)
        gt.stamp("loss")

        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), args.max_norm)
        optimizer.step()
        gt.stamp("backward")

        losses.append(loss.item())

        # -- Report metrics every `print_every` batches.
        if i % args.print_every == 0:
            # Only compute training metrics once here for efficiency.
            metrics.update(scores, samples, (batch.trg[0], None), kind='train')
            gt.stamp("metrics.update")

            # Training report computed over the last `print_every` batches.
            ms = metrics.report('train')
            ms['train/loss'] = round(np.mean(losses), 2)
            logs.append('{0} ; loss {1} ; sentence/s {2} ; {3} train {4} '.format(
                i + 1,
                round(np.mean(losses), 2),
                int(len(losses) * args.batch_size / (time.time() - last_time)),
                args.eval_metric,
                ms['train/%s' % args.eval_metric],
            ))
            args.logstep += 1
            last_time = time.time()
            losses = []
            metrics.reset()

            # -- Validation report with a single batch.
            metrics.reset()
            model.eval()
            batch = next(iter(validloader))
            scores, samples = predict_batch(batch)
            model.train()
            metrics.update(scores, samples, (batch.trg[0], None))
            vms = metrics.report('valid_batch')
            logs[-1] = logs[-1] + metrics.log(vms, 'valid_batch',
                                              ['bleu', 'avg_span', 'f1', 'em', 'depth_score'])
            metrics.reset()
            print_samples(samples, (batch.trg[0], None), n=len(batch))
            gt.stamp("validation_batch")

            log_tensorboard(ms, step=args.logstep)
            log_tensorboard(vms, step=args.logstep)
            print(logs[-1])
            print(gt.report(include_itrs=False, format_options={'itr_name_width': 30}))

        # -- Checkpointing
        if i % args.save_every == 0:
            print('saving checkpoint at epoch {0} batch {1}'.format(epoch, i))
            print(os.path.join(args.log_directory, args.expr_name + '.checkpoint'))
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'optimizer_param': args.optimizer,
                'loss': loss.item()
            }, os.path.join(args.log_directory, args.expr_name + '.checkpoint'))
            model_config['longest_label'] = model.longest_label
            with open(os.path.join(args.log_directory, 'model_config.json'), 'w') as f:
                json.dump(model_config, f)

    print('end : epoch {0} '.format(epoch))
    log_tensorboard({'lr': optimizer.param_groups[0]['lr']}, step=args.logstep)
def train_epoch(epoch):
    print('\nTRAINING : Epoch ' + str(epoch))
    model.train()
    losses = []
    logs = []
    sample_avgs = []
    update_avgs = []
    last_time = time.time()
    metrics = Metrics(tok2i, i2tok, field=TRG)
    trajectory_sampler = buffer.TrajectorySampler(trainloader)
    n_updates = 0
    # With beta == 1.0 the model trains on oracle trajectories only;
    # otherwise roll-ins are mixed between the oracle and the policy.
    oracle_samples_only = args.rollin_beta == 1.0

    while n_updates < updates_per_epoch:
        gt.reset()
        if oracle_samples_only:
            start = time.time()
            trajectory = trajectory_sampler.get_oracle_trajectory(model, Oracle, oracle_flags=oracle_flags)
            sample_time = (time.time() - start)
            start = time.time()
            loss = trajectory_sampler.get_loss(model, trajectory, loss_flags)
            update_time = (time.time() - start)
        else:
            start = time.time()
            loss = trajectory_sampler.get_mixed_trajectory_loss(model, Oracle,
                                                                oracle_flags=oracle_flags,
                                                                beta=args.rollin_beta,
                                                                loss_flags=loss_flags)
            sample_time = 0
            update_time = (time.time() - start)

        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), args.max_norm)
        losses.append(loss.item())
        optimizer.step()
        n_updates += 1

        sample_avgs.append(sample_time)
        update_avgs.append(update_time)
        gt.stamp("buffer updates")

        if n_updates % 20 == 0:
            print("%d|%d\t%.3f\tSample: %.3fs\tUpdate: %.3fs" % (
                epoch, n_updates, round(np.mean(losses), 3),
                np.mean(sample_avgs), np.mean(update_avgs)))
            log_tensorboard({'sample_avgs': np.mean(sample_avgs),
                             'update_avgs': np.mean(update_avgs)}, step=args.logstep)
            sample_avgs = []
            update_avgs = []

        # -- Report metrics every `print_every` updates.
        if n_updates % args.print_every == 0:
            gt.stamp("report")
            # Training report computed over the last `print_every` updates.
            # (Train metrics are not accumulated in this variant, so the eval
            # metric falls back to 0.0 when absent from the report.)
            ms = metrics.report('train')
            ms['train/loss'] = round(np.mean(losses), 2)
            logs.append('{0} ; loss {1} ; sentence/s {2} ; {3} train {4} '.format(
                epoch,
                round(np.mean(losses), 2),
                int(len(losses) * args.batch_size / (time.time() - last_time)),
                args.eval_metric,
                ms.get('train/%s' % args.eval_metric, 0.0),
            ))
            args.logstep += 1
            last_time = time.time()
            losses = []
            metrics.reset()

            # -- Validation report with a single batch.
            metrics.reset()
            model.eval()
            batch = next(iter(validloader))
            scores, samples = predict_batch(batch)
            model.train()
            metrics.update(scores, samples, (batch.trg[0], None))
            vms = metrics.report('valid_batch')
            logs[-1] = logs[-1] + metrics.log(vms, 'valid_batch',
                                              ['bleu', 'avg_span', 'f1', 'em', 'depth_score'])
            metrics.reset()
            print_samples(samples, (batch.trg[0], None), n=len(batch))
            gt.stamp("validation_batch")

            log_tensorboard(ms, step=args.logstep)
            log_tensorboard(vms, step=args.logstep)
            print(logs[-1])
            print(gt.report(include_itrs=False, format_options={'itr_name_width': 30}))

        # -- Checkpointing
        if n_updates % args.save_every == 0:
            print('saving checkpoint at epoch {0} update {1}'.format(epoch, n_updates))
            print(os.path.join(args.log_directory, args.expr_name + '.checkpoint'))
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'optimizer_param': args.optimizer,
                'loss': loss.item()
            }, os.path.join(args.log_directory, args.expr_name + '.checkpoint'))
            model_config['longest_label'] = model.longest_label
            with open(os.path.join(args.log_directory, 'model_config.pkl'), 'wb') as f:
                pickle.dump(model_config, f)

    print('end : epoch {0} '.format(epoch))
    log_tensorboard({'lr': optimizer.param_groups[0]['lr']}, step=args.logstep)
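# The `buffer.TrajectorySampler` interface, as inferred from the calls in the
# buffer-based `train_epoch` above (a sketch of the assumed contract; the real
# class may differ):

class TrajectorySampler:
    """Wraps a dataloader and produces trajectories / losses for training."""

    def __init__(self, dataloader):
        self.dataloader = dataloader

    def get_oracle_trajectory(self, model, oracle_cls, oracle_flags=None):
        """Roll in with the oracle only and return the sampled trajectory."""
        raise NotImplementedError

    def get_loss(self, model, trajectory, loss_flags):
        """Score a stored trajectory under the current model and return a loss."""
        raise NotImplementedError

    def get_mixed_trajectory_loss(self, model, oracle_cls, oracle_flags=None,
                                  beta=1.0, loss_flags=None):
        """Mix oracle and policy roll-ins with probability `beta`; return a loss."""
        raise NotImplementedError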
def train_epoch(epoch):
    print('\nTRAINING : Epoch ' + str(epoch))
    model.train()
    losses = []
    logs = []
    last_time = time.time()
    metrics = Metrics(tok2i, i2tok)

    for i, data in enumerate(trainloader, 0):
        # -- Actual Training
        gt.reset()
        xs, annots = data
        xs = xs.to(args.device)
        gt.stamp("load_data")

        oracle = Oracle(xs, model.n_classes, tok2i, i2tok, **oracle_flags)
        gt.stamp("create_oracle")

        # Upper bound on decoding steps: twice the longest non-pad sequence, plus one.
        max_steps = 2 * xs.ne(tok2i['<p>']).sum(1).max() + 1
        scores, samples, p_oracle = model.forward(num_samples=args.batch_size,
                                                  oracle=oracle,
                                                  max_steps=max_steps,
                                                  return_p_oracle=True)
        gt.stamp("forward")

        loss = loss_fn(scores, samples, p_oracle, tok2i['<end>'], **loss_flags)
        gt.stamp("loss")

        optimizer.zero_grad()
        loss.backward()
        clip_grad_norm_(model.parameters(), args.max_norm)
        optimizer.step()
        gt.stamp("backward")

        losses.append(loss.item())

        # -- Report metrics every `print_every` batches.
        if i % args.print_every == 0:
            # Training report; loss averaged over the last `print_every` batches.
            metrics.update(scores, samples, data)
            gt.stamp("metrics.update")
            ms = metrics.report('train')
            ms['train/loss'] = round(np.mean(losses), 2)
            logs.append('{0} ; loss {1} ; sentence/s {2} ; f1 train {3} '.format(
                i + 1,
                round(np.mean(losses), 2),
                int(len(losses) * args.batch_size / (time.time() - last_time)),
                ms.get('train/f1', 0.0),
            ))
            args.logstep += 1
            last_time = time.time()
            losses = []
            metrics.reset()

            scores, samples = predict_batch(data)
            print_samples(samples, data)
            gt.stamp("validation_batch")

            log_tensorboard(ms, step=args.logstep)
            print(logs[-1])
            print(gt.report(include_itrs=False, format_options={'itr_name_width': 30}))

        # -- Checkpointing
        if i % args.save_every == 0:
            print('saving checkpoint at epoch {0} batch {1}'.format(epoch, i))
            torch.save(model.state_dict(),
                       os.path.join(args.log_directory, args.expr_name + '.checkpoint'))
            model_config['longest_label'] = model.longest_label
            with open(os.path.join(args.log_directory, 'model_config.json'), 'w') as f:
                json.dump(model_config, f)

    print('end : epoch {0} '.format(epoch))
    log_tensorboard({'lr': optimizer.param_groups[0]['lr']}, step=args.logstep)
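# A hypothetical top-level driver tying the pieces together (not from the
# source; `sampler` and the chosen `adjust` signature depend on which variant
# of `train_epoch`/`adjust` is in use):
#
#   epoch = 1
#   while not stop_training and epoch <= args.n_epochs:
#       train_epoch(epoch)
#       evaluate(epoch, validloader, eval_type='valid')
#       adjust(epoch)
#       epoch += 1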