Example #1
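The listing shows only the function body; the imports below are a sketch of what it assumes. The `Trainer`/`TrainingEvents` API comes from an early pre-1.0 release of pytorch-ignite, and `create_dataset`, `Transformer`, and the `*_hook`/logging helpers are project-local modules whose exact import paths are guesses here.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from ignite.trainer import Trainer, TrainingEvents    # early ignite API (assumed path)
from dataset import create_dataset                    # project-local (assumed)
from model import Transformer                         # project-local (assumed)
from hooks import (restore_checkpoint_hook, save_checkpoint_hook,
                   print_current_prediction_hook,
                   log_training_simple_moving_average,
                   log_validation_simple_moving_average)  # project-local (assumed)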
def run(model_dir, max_len, source_train_path, target_train_path,
        source_val_path, target_val_path, enc_max_vocab, dec_max_vocab,
        encoder_emb_size, decoder_emb_size, encoder_units, decoder_units,
        batch_size, epochs, learning_rate, decay_step, decay_percent,
        log_interval, save_interval, compare_interval):

    train_iter, val_iter, source_vocab, target_vocab = create_dataset(
        batch_size, enc_max_vocab, dec_max_vocab, source_train_path,
        target_train_path, source_val_path, target_val_path)
    transformer = Transformer(max_length=max_len,
                              enc_vocab=source_vocab,
                              dec_vocab=target_vocab,
                              enc_emb_size=encoder_emb_size,
                              dec_emb_size=decoder_emb_size,
                              enc_units=encoder_units,
                              dec_units=decoder_units)
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(transformer.parameters(), lr=learning_rate)
    lr_decay = StepLR(opt, step_size=decay_step, gamma=decay_percent)

    if torch.cuda.is_available():
        transformer.cuda()
        loss_fn.cuda()

    def training_update_function(batch):
        transformer.train()
        opt.zero_grad()

        softmaxed_predictions, predictions = transformer(batch.src, batch.trg)

        # Flatten (batch, seq, vocab) logits and (batch, seq) targets into the
        # 2-D/1-D shapes CrossEntropyLoss expects.
        flattened_predictions = predictions.view(-1, len(target_vocab.itos))
        flattened_target = batch.trg.view(-1)

        loss = loss_fn(flattened_predictions, flattened_target)

        loss.backward()
        opt.step()
        # Step the scheduler after the optimizer, per current PyTorch
        # semantics (the original stepped it before opt.step()).
        lr_decay.step()

        return softmaxed_predictions.data, loss.item(), batch.trg.data

    def validation_inference_function(batch):
        transformer.eval()
        # Disable autograd during validation; the original ran inference with
        # gradients enabled, which only wastes memory.
        with torch.no_grad():
            softmaxed_predictions, predictions = transformer(batch.src,
                                                             batch.trg)

            flattened_predictions = predictions.view(-1,
                                                     len(target_vocab.itos))
            flattened_target = batch.trg.view(-1)

            loss = loss_fn(flattened_predictions, flattened_target)

        return loss.item()

    trainer = Trainer(train_iter, training_update_function, val_iter,
                      validation_inference_function)
    trainer.add_event_handler(TrainingEvents.TRAINING_STARTED,
                              restore_checkpoint_hook(transformer, model_dir))
    trainer.add_event_handler(
        TrainingEvents.TRAINING_ITERATION_COMPLETED,
        log_training_simple_moving_average,
        window_size=10,
        metric_name="CrossEntropy",
        should_log=lambda t: t.current_iteration % log_interval == 0,
        history_transform=lambda history: history[1])
    trainer.add_event_handler(
        TrainingEvents.TRAINING_ITERATION_COMPLETED,
        save_checkpoint_hook(transformer, model_dir),
        should_save=lambda t: t.current_iteration % save_interval == 0)
    trainer.add_event_handler(
        TrainingEvents.TRAINING_ITERATION_COMPLETED,
        print_current_prediction_hook(target_vocab),
        should_print=lambda t: t.current_iteration % compare_interval == 0)
    trainer.add_event_handler(TrainingEvents.VALIDATION_COMPLETED,
                              log_validation_simple_moving_average,
                              window_size=10,
                              metric_name="CrossEntropy")
    trainer.add_event_handler(TrainingEvents.TRAINING_COMPLETED,
                              save_checkpoint_hook(transformer, model_dir),
                              should_save=lambda trainer: True)
    trainer.run(max_epochs=epochs, validate_every_epoch=True)
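For context: in this early ignite API, extra keyword arguments given to `add_event_handler` (such as `should_save` above) are forwarded to the handler when the event fires. Below is a minimal, hypothetical sketch of what a closure-style hook like `save_checkpoint_hook` could look like under that assumption; the project's real helper is not shown on this page.

import os

def save_checkpoint_hook(model, model_dir):
    # Returns a handler that ignite invokes as handler(trainer, should_save=...).
    def handler(trainer, should_save=lambda trainer: False):
        if should_save(trainer):
            path = os.path.join(model_dir,
                                "model_%d.pt" % trainer.current_iteration)
            torch.save(model.state_dict(), path)
    return handler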
Example #2
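This variant targets the stable pytorch-ignite `Engine` API. Again the imports are not shown on the original page; the ignite paths below are the library's real ones, while `create_dataset`, `Transformer`, and `validation_result_hook` remain project-local modules with assumed paths.

import logging

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, Timer
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar
from dataset import create_dataset            # project-local (assumed)
from model import Transformer                 # project-local (assumed)
from hooks import validation_result_hook      # project-local (assumed)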
def run(model_dir, max_len, source_train_path, target_train_path,
        source_val_path, target_val_path, enc_max_vocab, dec_max_vocab,
        encoder_emb_size, decoder_emb_size, encoder_units, decoder_units,
        batch_size, epochs, learning_rate, decay_step, decay_percent,
        val_interval, save_interval, compare_interval):

    logging.basicConfig(filename="validation.log",
                        filemode="w",
                        level=logging.INFO)

    train_iter, val_iter, source_vocab, target_vocab = create_dataset(
        batch_size, enc_max_vocab, dec_max_vocab, source_train_path,
        target_train_path, source_val_path, target_val_path)
    transformer = Transformer(max_length=max_len,
                              enc_vocab=source_vocab,
                              dec_vocab=target_vocab,
                              enc_emb_size=encoder_emb_size,
                              dec_emb_size=decoder_emb_size,
                              enc_units=encoder_units,
                              dec_units=decoder_units)
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(transformer.parameters(), lr=learning_rate)
    lr_decay = StepLR(opt, step_size=decay_step, gamma=decay_percent)

    if torch.cuda.is_available():
        transformer.cuda()
        loss_fn.cuda()

    def training_step(engine, batch):
        transformer.train()
        opt.zero_grad()

        _, predictions = transformer(batch.src, batch.trg)

        # Flatten logits/targets to the 2-D/1-D shapes CrossEntropyLoss expects.
        flattened_predictions = predictions.view(-1, len(target_vocab.itos))
        flattened_target = batch.trg.view(-1)

        loss = loss_fn(flattened_predictions, flattened_target)

        loss.backward()
        opt.step()
        # Step the scheduler after the optimizer (the original stepped it
        # first, which PyTorch >= 1.1 warns about).
        lr_decay.step()

        return loss.item()

    def validation_step(engine, batch):
        transformer.eval()
        with torch.no_grad():
            softmaxed_predictions, predictions = transformer(
                batch.src, batch.trg)

            flattened_predictions = predictions.view(-1,
                                                     len(target_vocab.itos))
            flattened_target = batch.trg.view(-1)

            loss = loss_fn(flattened_predictions, flattened_target)

            # Accumulate token-id predictions and targets across the whole
            # validation run via engine.state.output.
            new_predictions = softmaxed_predictions.argmax(
                -1).cpu().numpy().tolist()
            new_targets = batch.trg.cpu().numpy().tolist()
            if not engine.state.output:
                predictions = new_predictions
                targets = new_targets
            else:
                predictions = engine.state.output["predictions"] + new_predictions
                targets = engine.state.output["targets"] + new_targets

            return {
                "loss": loss.cpu().item(),
                "predictions": predictions,
                "targets": targets
            }

    trainer = Engine(training_step)
    evaluator = Engine(validation_step)
    checkpoint_handler = ModelCheckpoint(model_dir,
                                         "Transformer",
                                         save_interval=save_interval,
                                         n_saved=10,
                                         require_empty=False)

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # Attach training metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "train_loss")
    # Attach validation metrics
    RunningAverage(output_transform=lambda x: x["loss"]).attach(
        evaluator, "val_loss")

    pbar = ProgressBar()
    pbar.attach(trainer, ["train_loss"])

    # trainer.add_event_handler(Events.TRAINING_STARTED,
    #                           restore_checkpoint_hook(transformer, model_dir))
    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              handler=validation_result_hook(
                                  evaluator,
                                  val_iter,
                                  target_vocab,
                                  val_interval,
                                  logger=logging.info))

    trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  "nmt": {
                                      "transformer": transformer,
                                      "opt": opt,
                                      "lr_decay": lr_decay
                                  }
                              })

    # Run the training loop
    trainer.run(train_iter, max_epochs=epochs)
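A minimal invocation sketch for Example #2's signature; every value below is a placeholder chosen for illustration, not taken from the original project.

if __name__ == "__main__":
    run(model_dir="checkpoints/",
        max_len=100,
        source_train_path="data/train.src",
        target_train_path="data/train.trg",
        source_val_path="data/val.src",
        target_val_path="data/val.trg",
        enc_max_vocab=30000,
        dec_max_vocab=30000,
        encoder_emb_size=512,
        decoder_emb_size=512,
        encoder_units=512,
        decoder_units=512,
        batch_size=32,
        epochs=10,
        learning_rate=1e-4,
        decay_step=10000,
        decay_percent=0.9,
        val_interval=500,
        save_interval=1000,
        compare_interval=100)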