Example 1
import math
import os

import torch
from torch import nn, optim
from tqdm import tqdm

# Project-local names (Vocab, Seq2Seq, Params, DEVICE, train_batch,
# eval_batch, show_plot) are assumed to be importable from the surrounding
# repository; they are not defined in this snippet.


def train(train_generator,
          vocab: Vocab,
          model: Seq2Seq,
          params: Params,
          valid_generator=None,
          saved_state: dict = None,
          losses=None):
    # variables for plotting
    plot_points_per_epoch = max(math.log(params.n_batches, 1.6), 1.)
    plot_every = round(params.n_batches / plot_points_per_epoch)
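    # The number of plot points per epoch grows logarithmically with the
    # number of batches (base 1.6 is an arbitrary smoothing constant), so
    # plot_every is the interval, in batches, between recorded loss points.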
    if losses is None:
        plot_losses, cached_losses, plot_val_losses, plot_val_metrics = [], [], [], []
    else:
        plot_losses, cached_losses, plot_val_losses, plot_val_metrics = losses

    # total_parameters = sum(parameter.numel() for parameter in model.parameters()
    #                        if parameter.requires_grad)
    # print("Training %d trainable parameters..." % total_parameters)
    model.to(DEVICE)
    if saved_state is None:
        if params.optimizer == 'adagrad':
            optimizer = optim.Adagrad(
                model.parameters(),
                lr=params.lr,
                initial_accumulator_value=params.adagrad_accumulator)
        else:
            optimizer = optim.Adam(model.parameters(), lr=params.lr)
        past_epochs = 0
        total_batch_count = 0
    else:
        optimizer = saved_state['optimizer']
        past_epochs = saved_state['epoch']
        total_batch_count = saved_state['total_batch_count']
    if params.lr_decay:
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                 step_size=params.lr_decay_step,
                                                 gamma=params.lr_decay,
                                                 last_epoch=past_epochs - 1)
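    # NLLLoss expects log-probabilities from the decoder; positions whose
    # target is the PAD index do not contribute to the loss.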
    criterion = nn.NLLLoss(ignore_index=vocab.PAD)
    best_avg_loss, best_epoch_id = float("inf"), None

    for epoch_count in range(1 + past_epochs, params.n_epochs + 1):
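        # NOTE: stepping the scheduler at the top of the epoch follows the
        # pre-1.1 PyTorch convention; on newer versions the recommended order
        # is to call scheduler.step() after the epoch's optimizer steps.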
        if params.lr_decay:
            lr_scheduler.step()
        rl_ratio = params.rl_ratio if epoch_count >= params.rl_start_epoch else 0
        epoch_loss, epoch_metric = 0, 0
        epoch_avg_loss, valid_avg_loss, valid_avg_metric = None, None, None
        prog_bar = tqdm(range(1, params.n_batches + 1),
                        desc='Epoch %d' % epoch_count)
        model.train()

        for batch_count in prog_bar:  # training batches
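            # The teacher-forcing ratio decays with the global step
            # t = total_batch_count:
            #   linear:  f(t) = max(0, f0 - d*t)
            #   exp:     f(t) = f0 * d**t                (0 < d < 1)
            #   sigmoid: f(t) = f0 * d / (d + e^(t/d))   (inverse sigmoid, d >> 1)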
            if params.forcing_decay_type:
                if params.forcing_decay_type == 'linear':
                    forcing_ratio = max(
                        0, params.forcing_ratio -
                        params.forcing_decay * total_batch_count)
                elif params.forcing_decay_type == 'exp':
                    forcing_ratio = params.forcing_ratio * (
                        params.forcing_decay**total_batch_count)
                elif params.forcing_decay_type == 'sigmoid':
                    forcing_ratio = params.forcing_ratio * params.forcing_decay / (
                        params.forcing_decay +
                        math.exp(total_batch_count / params.forcing_decay))
                else:
                    raise ValueError('Unrecognized forcing_decay_type: ' +
                                     params.forcing_decay_type)
            else:
                forcing_ratio = params.forcing_ratio

            batch = next(train_generator)
            loss, metric = train_batch(batch,
                                       model,
                                       criterion,
                                       optimizer,
                                       pack_seq=params.pack_seq,
                                       forcing_ratio=forcing_ratio,
                                       partial_forcing=params.partial_forcing,
                                       sample=params.sample,
                                       rl_ratio=rl_ratio,
                                       vocab=vocab,
                                       grad_norm=params.grad_norm,
                                       show_cover_loss=params.show_cover_loss)

            epoch_loss += float(loss)
            epoch_avg_loss = epoch_loss / batch_count
            if metric is not None:  # print ROUGE as well if reinforcement learning is enabled
                epoch_metric += metric
                epoch_avg_metric = epoch_metric / batch_count
                prog_bar.set_postfix(loss='%g' % epoch_avg_loss,
                                     rouge='%.4g' % (epoch_avg_metric * 100))
            else:
                prog_bar.set_postfix(loss='%g' % epoch_avg_loss)

            cached_losses.append(float(loss))  # store a plain float; keeping the tensor would retain its autograd graph
            total_batch_count += 1
            if total_batch_count % plot_every == 0:
                period_avg_loss = sum(cached_losses) / len(cached_losses)
                plot_losses.append(period_avg_loss)
                cached_losses = []

        if valid_generator is not None:  # validation batches
            valid_loss, valid_metric = 0, 0
            prog_bar = tqdm(range(1, params.n_val_batches + 1),
                            desc='Valid %d' % epoch_count)
            model.eval()
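            # NOTE: this assumes eval_batch runs its forward pass under
            # torch.no_grad(); if it does not, wrap this loop in one so no
            # autograd graphs are built during validation.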

            for batch_count in prog_bar:
                batch = next(valid_generator)
                loss, metric = eval_batch(
                    batch,
                    model,
                    vocab,
                    criterion,
                    pack_seq=params.pack_seq,
                    show_cover_loss=params.show_cover_loss)
                valid_loss += loss
                valid_metric += metric
                valid_avg_loss = valid_loss / batch_count
                valid_avg_metric = valid_metric / batch_count
                prog_bar.set_postfix(loss='%g' % valid_avg_loss,
                                     rouge='%.4g' % (valid_avg_metric * 100))

            plot_val_losses.append(valid_avg_loss)
            plot_val_metrics.append(valid_avg_metric)

            metric_loss = -valid_avg_metric  # choose the best model by ROUGE instead of loss
            if metric_loss < best_avg_loss:
                best_epoch_id = epoch_count
                best_avg_loss = metric_loss

        else:  # no validation, "best" is defined by training loss
            if epoch_avg_loss < best_avg_loss:
                best_epoch_id = epoch_count
                best_avg_loss = epoch_avg_loss

        if params.model_path_prefix:
            # save model
            filename = '%s.%02d.pt' % (params.model_path_prefix, epoch_count)
            torch.save(model, filename)
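            # Saving the whole module ties the checkpoint to the exact class
            # definition; torch.save(model.state_dict(), filename) is the more
            # portable alternative.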
            if not params.keep_every_epoch:  # clear previously saved models
                for epoch_id in range(1 + past_epochs, epoch_count):
                    if epoch_id != best_epoch_id:
                        try:
                            prev_filename = '%s.%02d.pt' % (
                                params.model_path_prefix, epoch_id)
                            os.remove(prev_filename)
                        except FileNotFoundError:
                            pass
            # save training status
            torch.save(
                {
                    'epoch': epoch_count,
                    'total_batch_count': total_batch_count,
                    'train_avg_loss': epoch_avg_loss,
                    'valid_avg_loss': valid_avg_loss,
                    'valid_avg_metric': valid_avg_metric,
                    'best_epoch_so_far': best_epoch_id,
                    'params': params,
                    'optimizer': optimizer,
                    'train_loss': plot_losses,
                    'cached_losses': cached_losses,
                    'val_loss': plot_val_losses,
                    'plot_val_metrics': plot_val_metrics
                }, '%s.train.pt' % params.model_path_prefix)

        if rl_ratio > 0:
            params.rl_ratio **= params.rl_ratio_power
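            # With 0 < rl_ratio < 1 and 0 < rl_ratio_power < 1, repeated
            # exponentiation anneals the RL weight toward 1 over epochs.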

        show_plot(plot_losses, plot_every, plot_val_losses, plot_val_metrics,
                  params.n_batches, params.model_path_prefix)
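
The three forcing-decay schedules are easy to sanity-check in isolation. The sketch below is self-contained and illustrative only: the function name and the default values are assumptions, not taken from the repository.

import math

def forcing_ratio_at(step, f0=0.75, decay=0.9999, kind='exp'):
    """Reproduce the three teacher-forcing decay schedules from train()."""
    if kind == 'linear':
        return max(0.0, f0 - decay * step)
    if kind == 'exp':
        return f0 * decay ** step
    if kind == 'sigmoid':
        # Inverse-sigmoid decay: stays near f0, then falls off around
        # step ~ decay * ln(decay).
        return f0 * decay / (decay + math.exp(step / decay))
    raise ValueError('Unrecognized kind: ' + kind)

# Each schedule needs a decay constant on a different scale.
for step in (0, 5000, 20000, 60000):
    print(step,
          round(forcing_ratio_at(step, decay=1e-5, kind='linear'), 3),
          round(forcing_ratio_at(step, decay=0.9999, kind='exp'), 3),
          round(forcing_ratio_at(step, decay=5000, kind='sigmoid'), 3))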
Example 2

A simpler variant of the same training loop, without checkpoint resuming, learning-rate decay, or teacher-forcing decay.
# Uses the same imports and project-local assumptions as Example 1.
def train(train_generator,
          vocab: Vocab,
          model: Seq2Seq,
          params: Params,
          valid_generator=None):
    # variables for plotting
    plot_points_per_epoch = max(math.log(params.n_batches, 1.6), 1.)
    plot_every = round(params.n_batches / plot_points_per_epoch)
    plot_losses, cached_losses = [], []
    total_batch_count = 0
    plot_val_losses, plot_val_metrics = [], []

    total_parameters = sum(parameter.numel()
                           for parameter in model.parameters()
                           if parameter.requires_grad)
    print("Training %d trainable parameters..." % total_parameters)
    model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=params.lr)
    criterion = nn.NLLLoss(ignore_index=vocab.PAD)
    best_avg_loss, best_epoch_id = float("inf"), None

    for epoch_count in range(1, params.n_epochs + 1):
        rl_ratio = params.rl_ratio if epoch_count >= params.rl_start_epoch else 0
        epoch_loss, epoch_metric = 0, 0
        epoch_avg_loss, valid_avg_loss, valid_avg_metric = None, None, None
        prog_bar = tqdm(range(1, params.n_batches + 1),
                        desc='Epoch %d' % epoch_count)
        model.train()

        for batch_count in prog_bar:  # training batches
            batch = next(train_generator)
            loss, metric = train_batch(batch,
                                       model,
                                       criterion,
                                       optimizer,
                                       pack_seq=params.pack_seq,
                                       forcing_ratio=params.forcing_ratio,
                                       partial_forcing=params.partial_forcing,
                                       rl_ratio=rl_ratio,
                                       vocab=vocab)

            epoch_loss += float(loss)
            epoch_avg_loss = epoch_loss / batch_count
            if metric is not None:  # print ROUGE as well if reinforcement learning is enabled
                epoch_metric += metric
                epoch_avg_metric = epoch_metric / batch_count
                prog_bar.set_postfix(loss='%g' % epoch_avg_loss,
                                     rouge='%.4g' % (epoch_avg_metric * 100))
            else:
                prog_bar.set_postfix(loss='%g' % epoch_avg_loss)

            cached_losses.append(float(loss))  # store a plain float, as in Example 1
            if (total_batch_count + batch_count) % plot_every == 0:
                period_avg_loss = sum(cached_losses) / len(cached_losses)
                plot_losses.append(period_avg_loss)
                cached_losses = []

        if valid_generator is not None:  # validation batches
            valid_loss, valid_metric = 0, 0
            prog_bar = tqdm(range(1, params.n_val_batches + 1),
                            desc='Valid %d' % epoch_count)
            model.eval()

            for batch_count in prog_bar:
                batch = next(valid_generator)
                loss, metric = eval_batch(batch,
                                          model,
                                          vocab,
                                          criterion,
                                          pack_seq=params.pack_seq)
                valid_loss += loss
                valid_metric += metric
                valid_avg_loss = valid_loss / batch_count
                valid_avg_metric = valid_metric / batch_count
                prog_bar.set_postfix(loss='%g' % valid_avg_loss,
                                     rouge='%.4g' % (valid_avg_metric * 100))

            plot_val_losses.append(valid_avg_loss)
            plot_val_metrics.append(valid_avg_metric)

            metric_loss = -valid_avg_metric  # choose the best model by ROUGE instead of loss
            if metric_loss < best_avg_loss:
                best_epoch_id = epoch_count
                best_avg_loss = metric_loss

        else:  # no validation, "best" is defined by training loss
            if epoch_avg_loss < best_avg_loss:
                best_epoch_id = epoch_count
                best_avg_loss = epoch_avg_loss

        if params.model_path_prefix:
            # save model
            filename = '%s.%02d.pt' % (params.model_path_prefix, epoch_count)
            torch.save(model, filename)
            if not params.keep_every_epoch:  # clear previously saved models
                for epoch_id in range(1, epoch_count):
                    if epoch_id != best_epoch_id:
                        try:
                            prev_filename = '%s.%02d.pt' % (
                                params.model_path_prefix, epoch_id)
                            os.remove(prev_filename)
                        except FileNotFoundError:
                            pass
            # save training status
            torch.save(
                {
                    'epoch': epoch_count,
                    'train_avg_loss': epoch_avg_loss,
                    'valid_avg_loss': valid_avg_loss,
                    'valid_avg_metric': valid_avg_metric,
                    'best_epoch_so_far': best_epoch_id,
                    'params': params,
                    'optimizer': optimizer
                }, '%s.train.pt' % params.model_path_prefix)

        if rl_ratio > 0:
            params.rl_ratio **= params.rl_ratio_power

        total_batch_count += params.n_batches
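        # The global batch counter advances once per epoch here, which keeps
        # the plot_every window aligned across epochs.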
        show_plot(plot_losses, plot_every, plot_val_losses, plot_val_metrics,
                  params.n_batches, params.model_path_prefix)