def train(train_generator, vocab: Vocab, model: Seq2Seq, params: Params,
          valid_generator=None, saved_state: dict = None, losses=None):
    # variables for plotting
    plot_points_per_epoch = max(math.log(params.n_batches, 1.6), 1.)
    plot_every = round(params.n_batches / plot_points_per_epoch)
    if losses is None:
        plot_losses, cached_losses, plot_val_losses, plot_val_metrics = [], [], [], []
    else:
        plot_losses, cached_losses, plot_val_losses, plot_val_metrics = losses

    # total_parameters = sum(parameter.numel() for parameter in model.parameters()
    #                        if parameter.requires_grad)
    # print("Training %d trainable parameters..." % total_parameters)
    model.to(DEVICE)

    if saved_state is None:
        if params.optimizer == 'adagrad':
            optimizer = optim.Adagrad(model.parameters(), lr=params.lr,
                                      initial_accumulator_value=params.adagrad_accumulator)
        else:
            optimizer = optim.Adam(model.parameters(), lr=params.lr)
        past_epochs = 0
        total_batch_count = 0
    else:
        optimizer = saved_state['optimizer']
        past_epochs = saved_state['epoch']
        total_batch_count = saved_state['total_batch_count']
    if params.lr_decay:
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, params.lr_decay_step,
                                                 params.lr_decay, past_epochs - 1)

    criterion = nn.NLLLoss(ignore_index=vocab.PAD)

    best_avg_loss, best_epoch_id = float("inf"), None

    for epoch_count in range(1 + past_epochs, params.n_epochs + 1):
        if params.lr_decay:
            lr_scheduler.step()
        rl_ratio = params.rl_ratio if epoch_count >= params.rl_start_epoch else 0
        epoch_loss, epoch_metric = 0, 0
        epoch_avg_loss, valid_avg_loss, valid_avg_metric = None, None, None
        prog_bar = tqdm(range(1, params.n_batches + 1), desc='Epoch %d' % epoch_count)
        model.train()

        for batch_count in prog_bar:  # training batches
            if params.forcing_decay_type:
                if params.forcing_decay_type == 'linear':
                    forcing_ratio = max(0, params.forcing_ratio
                                        - params.forcing_decay * total_batch_count)
                elif params.forcing_decay_type == 'exp':
                    forcing_ratio = params.forcing_ratio * (params.forcing_decay ** total_batch_count)
                elif params.forcing_decay_type == 'sigmoid':
                    forcing_ratio = params.forcing_ratio * params.forcing_decay / (
                            params.forcing_decay + math.exp(total_batch_count / params.forcing_decay))
                else:
                    raise ValueError('Unrecognized forcing_decay_type: ' + params.forcing_decay_type)
            else:
                forcing_ratio = params.forcing_ratio

            batch = next(train_generator)
            loss, metric = train_batch(batch, model, criterion, optimizer,
                                       pack_seq=params.pack_seq,
                                       forcing_ratio=forcing_ratio,
                                       partial_forcing=params.partial_forcing,
                                       sample=params.sample,
                                       rl_ratio=rl_ratio, vocab=vocab,
                                       grad_norm=params.grad_norm,
                                       show_cover_loss=params.show_cover_loss)

            epoch_loss += float(loss)
            epoch_avg_loss = epoch_loss / batch_count
            if metric is not None:  # print ROUGE as well if reinforcement learning is enabled
                epoch_metric += metric
                epoch_avg_metric = epoch_metric / batch_count
                prog_bar.set_postfix(loss='%g' % epoch_avg_loss,
                                     rouge='%.4g' % (epoch_avg_metric * 100))
            else:
                prog_bar.set_postfix(loss='%g' % epoch_avg_loss)

            cached_losses.append(loss)
            total_batch_count += 1
            if total_batch_count % plot_every == 0:
                period_avg_loss = sum(cached_losses) / len(cached_losses)
                plot_losses.append(period_avg_loss)
                cached_losses = []

        if valid_generator is not None:  # validation batches
            valid_loss, valid_metric = 0, 0
            prog_bar = tqdm(range(1, params.n_val_batches + 1), desc='Valid %d' % epoch_count)
            model.eval()

            for batch_count in prog_bar:
                batch = next(valid_generator)
                loss, metric = eval_batch(batch, model, vocab, criterion,
                                          pack_seq=params.pack_seq,
                                          show_cover_loss=params.show_cover_loss)
                valid_loss += loss
                valid_metric += metric
                valid_avg_loss = valid_loss / batch_count
                valid_avg_metric = valid_metric / batch_count
                prog_bar.set_postfix(loss='%g' % valid_avg_loss,
                                     rouge='%.4g' % (valid_avg_metric * 100))

            plot_val_losses.append(valid_avg_loss)
            plot_val_metrics.append(valid_avg_metric)

            metric_loss = -valid_avg_metric  # choose the best model by ROUGE instead of loss
            if metric_loss < best_avg_loss:
                best_epoch_id = epoch_count
                best_avg_loss = metric_loss
        else:  # no validation, "best" is defined by training loss
            if epoch_avg_loss < best_avg_loss:
                best_epoch_id = epoch_count
                best_avg_loss = epoch_avg_loss

        if params.model_path_prefix:
            # save model
            filename = '%s.%02d.pt' % (params.model_path_prefix, epoch_count)
            torch.save(model, filename)
            if not params.keep_every_epoch:  # clear previously saved models
                for epoch_id in range(1 + past_epochs, epoch_count):
                    if epoch_id != best_epoch_id:
                        try:
                            prev_filename = '%s.%02d.pt' % (params.model_path_prefix, epoch_id)
                            os.remove(prev_filename)
                        except FileNotFoundError:
                            pass
            # save training status
            torch.save({
                'epoch': epoch_count,
                'total_batch_count': total_batch_count,
                'train_avg_loss': epoch_avg_loss,
                'valid_avg_loss': valid_avg_loss,
                'valid_avg_metric': valid_avg_metric,
                'best_epoch_so_far': best_epoch_id,
                'params': params,
                'optimizer': optimizer,
                'train_loss': plot_losses,
                'cached_losses': cached_losses,
                'val_loss': plot_val_losses,
                'plot_val_metrics': plot_val_metrics
            }, '%s.train.pt' % params.model_path_prefix)

        if rl_ratio > 0:
            params.rl_ratio **= params.rl_ratio_power

        show_plot(plot_losses, plot_every, plot_val_losses, plot_val_metrics, params.n_batches,
                  params.model_path_prefix)
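# The '%s.train.pt' checkpoint written above stores the optimizer, the epoch and
# batch counters, and the cached loss curves, so an interrupted run can be picked
# up again through the `saved_state` and `losses` arguments of the train() version
# above.  The helper below is a hypothetical sketch of that resume path (it is not
# part of the original script, and it assumes the per-epoch model file of the
# checkpointed epoch is still on disk).
def load_training_state(params: Params):
    # training status saved at the end of the last completed epoch
    saved_state = torch.load('%s.train.pt' % params.model_path_prefix, map_location=DEVICE)
    # rebuild the `losses` tuple in the order train() unpacks it:
    # (plot_losses, cached_losses, plot_val_losses, plot_val_metrics)
    losses = (saved_state['train_loss'], saved_state['cached_losses'],
              saved_state['val_loss'], saved_state['plot_val_metrics'])
    # the full model object of that epoch was saved with torch.save(model, filename)
    model = torch.load('%s.%02d.pt' % (params.model_path_prefix, saved_state['epoch']),
                       map_location=DEVICE)
    return model, saved_state, losses

# Example (hypothetical):
#   model, saved_state, losses = load_training_state(params)
#   train(train_generator, vocab, model, params, valid_generator,
#         saved_state=saved_state, losses=losses)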
# Simpler variant of train(): no checkpoint resuming, learning-rate decay or
# forcing-ratio scheduling, and total_batch_count is only advanced once per epoch.
def train(train_generator, vocab: Vocab, model: Seq2Seq, params: Params, valid_generator=None):
    # variables for plotting
    plot_points_per_epoch = max(math.log(params.n_batches, 1.6), 1.)
    plot_every = round(params.n_batches / plot_points_per_epoch)
    plot_losses, cached_losses = [], []
    total_batch_count = 0
    plot_val_losses, plot_val_metrics = [], []

    total_parameters = sum(parameter.numel() for parameter in model.parameters()
                           if parameter.requires_grad)
    print("Training %d trainable parameters..." % total_parameters)
    model.to(DEVICE)

    optimizer = optim.Adam(model.parameters(), lr=params.lr)
    criterion = nn.NLLLoss(ignore_index=vocab.PAD)

    best_avg_loss, best_epoch_id = float("inf"), None

    for epoch_count in range(1, params.n_epochs + 1):
        rl_ratio = params.rl_ratio if epoch_count >= params.rl_start_epoch else 0
        epoch_loss, epoch_metric = 0, 0
        epoch_avg_loss, valid_avg_loss, valid_avg_metric = None, None, None
        prog_bar = tqdm(range(1, params.n_batches + 1), desc='Epoch %d' % epoch_count)
        model.train()

        for batch_count in prog_bar:  # training batches
            batch = next(train_generator)
            loss, metric = train_batch(batch, model, criterion, optimizer,
                                       pack_seq=params.pack_seq,
                                       forcing_ratio=params.forcing_ratio,
                                       partial_forcing=params.partial_forcing,
                                       rl_ratio=rl_ratio, vocab=vocab)

            epoch_loss += float(loss)
            epoch_avg_loss = epoch_loss / batch_count
            if metric is not None:  # print ROUGE as well if reinforcement learning is enabled
                epoch_metric += metric
                epoch_avg_metric = epoch_metric / batch_count
                prog_bar.set_postfix(loss='%g' % epoch_avg_loss,
                                     rouge='%.4g' % (epoch_avg_metric * 100))
            else:
                prog_bar.set_postfix(loss='%g' % epoch_avg_loss)

            cached_losses.append(loss)
            if (total_batch_count + batch_count) % plot_every == 0:
                period_avg_loss = sum(cached_losses) / len(cached_losses)
                plot_losses.append(period_avg_loss)
                cached_losses = []

        if valid_generator is not None:  # validation batches
            valid_loss, valid_metric = 0, 0
            prog_bar = tqdm(range(1, params.n_val_batches + 1), desc='Valid %d' % epoch_count)
            model.eval()

            for batch_count in prog_bar:
                batch = next(valid_generator)
                loss, metric = eval_batch(batch, model, vocab, criterion,
                                          pack_seq=params.pack_seq)
                valid_loss += loss
                valid_metric += metric
                valid_avg_loss = valid_loss / batch_count
                valid_avg_metric = valid_metric / batch_count
                prog_bar.set_postfix(loss='%g' % valid_avg_loss,
                                     rouge='%.4g' % (valid_avg_metric * 100))

            plot_val_losses.append(valid_avg_loss)
            plot_val_metrics.append(valid_avg_metric)

            metric_loss = -valid_avg_metric  # choose the best model by ROUGE instead of loss
            if metric_loss < best_avg_loss:
                best_epoch_id = epoch_count
                best_avg_loss = metric_loss
        else:  # no validation, "best" is defined by training loss
            if epoch_avg_loss < best_avg_loss:
                best_epoch_id = epoch_count
                best_avg_loss = epoch_avg_loss

        if params.model_path_prefix:
            # save model
            filename = '%s.%02d.pt' % (params.model_path_prefix, epoch_count)
            torch.save(model, filename)
            if not params.keep_every_epoch:  # clear previously saved models
                for epoch_id in range(1, epoch_count):
                    if epoch_id != best_epoch_id:
                        try:
                            prev_filename = '%s.%02d.pt' % (params.model_path_prefix, epoch_id)
                            os.remove(prev_filename)
                        except FileNotFoundError:
                            pass
            # save training status
            torch.save({
                'epoch': epoch_count,
                'train_avg_loss': epoch_avg_loss,
                'valid_avg_loss': valid_avg_loss,
                'valid_avg_metric': valid_avg_metric,
                'best_epoch_so_far': best_epoch_id,
                'params': params,
                'optimizer': optimizer
            }, '%s.train.pt' % params.model_path_prefix)

        if rl_ratio > 0:
            params.rl_ratio **= params.rl_ratio_power

        total_batch_count += params.n_batches

        show_plot(plot_losses, plot_every, plot_val_losses, plot_val_metrics, params.n_batches,
                  params.model_path_prefix)
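# Both versions of train() above cache roughly log_1.6(n_batches) loss points per
# epoch; `plot_every` is the number of batches averaged into each plotted point.
# A hypothetical sanity check of that arithmetic (the n_batches value here is an
# arbitrary example, not taken from the original configuration):
def example_plot_every(n_batches=5000):
    plot_points_per_epoch = max(math.log(n_batches, 1.6), 1.)  # ~18.1 points for 5000 batches
    return round(n_batches / plot_points_per_epoch)            # ~276, i.e. one point every ~276 batches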