Example No. 1
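This first example trains the gSCAN sequence-to-sequence generator adversarially, SeqGAN-style: the generator samples action sequences, a Rollout module estimates per-token rewards by letting the discriminator score Monte Carlo completions of each prefix, the generator is updated with a reward-weighted policy-gradient loss, and the discriminator is retrained on real versus generated sequences every iteration.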
import os

import torch
from torch.optim.lr_scheduler import LambdaLR

# Assumed to be defined at module level in the surrounding project:
# logger, use_cuda, GroundedScanDataset, Model, Discriminator, Rollout,
# log_parameters, pre_train_generator, train_discriminator.


def train(
        data_path: str,
        data_directory: str,
        generate_vocabularies: bool,
        input_vocab_path: str,
        target_vocab_path: str,
        embedding_dimension: int,
        num_encoder_layers: int,
        encoder_dropout_p: float,
        encoder_bidirectional: bool,
        training_batch_size: int,
        test_batch_size: int,
        max_decoding_steps: int,
        num_decoder_layers: int,
        decoder_dropout_p: float,
        cnn_kernel_size: int,
        cnn_dropout_p: float,
        cnn_hidden_num_channels: int,
        simple_situation_representation: bool,
        decoder_hidden_size: int,
        encoder_hidden_size: int,
        learning_rate: float,
        adam_beta_1: float,
        adam_beta_2: float,
        lr_decay: float,
        lr_decay_steps: int,
        resume_from_file: str,
        max_training_iterations: int,
        output_directory: str,
        print_every: int,
        evaluate_every: int,
        conditional_attention: bool,
        auxiliary_task: bool,
        weight_target_loss: float,
        attention_type: str,
        k: int,
        max_training_examples,
        max_testing_examples,
        # SeqGAN params begin
        pretrain_gen_path,       # generator checkpoint to load (None: pretrain with MLE)
        pretrain_gen_epochs,     # MLE pretraining epochs for the generator
        pretrain_disc_path,      # discriminator checkpoint to load (None: pretrain)
        pretrain_disc_epochs,    # pretraining epochs for the discriminator
        rollout_trails,          # Monte Carlo rollouts per reward estimate
        rollout_update_rate,     # update rate for the rollout copy of the generator
        disc_emb_dim,            # discriminator embedding dimension
        disc_hid_dim,            # discriminator hidden dimension
        load_tensors_from_path,  # load pre-vectorized tensors from a pickle file
        # SeqGAN params end
        seed=42,
        **kwargs):
    device = torch.device("cuda" if use_cuda else "cpu")
    cfg = locals().copy()
    torch.manual_seed(seed)

    logger.info("Loading Training set...")

    training_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="train",
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=generate_vocabularies,
        k=k)
    training_set.read_dataset(
        max_examples=max_training_examples,
        simple_situation_representation=simple_situation_representation,
        load_tensors_from_path=load_tensors_from_path
    )  # set load_tensors_from_path to False if no pickle file is available

    logger.info("Done Loading Training set.")
    logger.info("  Loaded {} training examples.".format(
        training_set.num_examples))
    logger.info("  Input vocabulary size training set: {}".format(
        training_set.input_vocabulary_size))
    logger.info("  Most common input words: {}".format(
        training_set.input_vocabulary.most_common(5)))
    logger.info("  Output vocabulary size training set: {}".format(
        training_set.target_vocabulary_size))
    logger.info("  Most common target words: {}".format(
        training_set.target_vocabulary.most_common(5)))

    if generate_vocabularies:
        training_set.save_vocabularies(input_vocab_path, target_vocab_path)
        logger.info(
            "Saved vocabularies to {} for input and {} for target.".format(
                input_vocab_path, target_vocab_path))

    # logger.info("Loading Dev. set...")
    # test_set = GroundedScanDataset(data_path, data_directory, split="dev",
    #                                input_vocabulary_file=input_vocab_path,
    #                                target_vocabulary_file=target_vocab_path, generate_vocabulary=False, k=0)
    # test_set.read_dataset(max_examples=max_testing_examples,
    #                       simple_situation_representation=simple_situation_representation)
    #
    # # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.
    # test_set.shuffle_data()

    # logger.info("Done Loading Dev. set.")

    generator = Model(
        input_vocabulary_size=training_set.input_vocabulary_size,
        target_vocabulary_size=training_set.target_vocabulary_size,
        num_cnn_channels=training_set.image_channels,
        input_padding_idx=training_set.input_vocabulary.pad_idx,
        target_pad_idx=training_set.target_vocabulary.pad_idx,
        target_eos_idx=training_set.target_vocabulary.eos_idx,
        **cfg)
    total_vocabulary = set(
        list(training_set.input_vocabulary._word_to_idx.keys()) +
        list(training_set.target_vocabulary._word_to_idx.keys()))
    total_vocabulary_size = len(total_vocabulary)
    discriminator = Discriminator(embedding_dim=disc_emb_dim,
                                  hidden_dim=disc_hid_dim,
                                  vocab_size=total_vocabulary_size,
                                  max_seq_len=max_decoding_steps)

    generator = generator.cuda() if use_cuda else generator
    discriminator = discriminator.cuda() if use_cuda else discriminator
    rollout = Rollout(generator, rollout_update_rate)
    log_parameters(generator)
    trainable_parameters = [
        parameter for parameter in generator.parameters()
        if parameter.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_parameters,
                                 lr=learning_rate,
                                 betas=(adam_beta_1, adam_beta_2))
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))

    # Load model and vocabularies if resuming.
    start_iteration = 1
    best_iteration = 1
    best_accuracy = 0
    best_exact_match = 0
    best_loss = float('inf')
    if resume_from_file:
        assert os.path.isfile(
            resume_from_file), "No checkpoint found at {}".format(
                resume_from_file)
        logger.info(
            "Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = generator.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = generator.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            resume_from_file, start_iteration))

    if pretrain_gen_path is None:
        logger.info("Pretraining generator with MLE...")
        pre_train_generator(training_set,
                            training_batch_size,
                            generator,
                            seed,
                            pretrain_gen_epochs,
                            name='pretrained_generator')
    else:
        logger.info("Loading pretrained generator weights...")
        generator_weights = torch.load(pretrain_gen_path)
        generator.load_state_dict(generator_weights)

    if pretrain_disc_path is None:
        logger.info("Pretraining discriminator...")
        train_discriminator(training_set,
                            discriminator,
                            training_batch_size,
                            generator,
                            seed,
                            pretrain_disc_epochs,
                            name="pretrained_discriminator")
    else:
        logger.info("Loading pretrained discriminator weights...")
        discriminator_weights = torch.load(pretrain_disc_path)
        discriminator.load_state_dict(discriminator_weights)

    logger.info("Training starts..")
    training_iteration = start_iteration
    torch.autograd.set_detect_anomaly(True)
    while training_iteration < max_training_iterations:

        # Shuffle the dataset and loop over it.
        training_set.shuffle_data()

        for (input_batch, input_lengths, _, situation_batch, _, target_batch,
             target_lengths, agent_positions, target_positions) in \
                training_set.get_data_iterator(batch_size=training_batch_size):

            is_best = False
            generator.train()

            # Forward pass.
            samples = generator.sample(
                batch_size=training_batch_size,
                max_seq_len=max(target_lengths).astype(int),
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                target_batch=target_batch,
                sos_idx=training_set.input_vocabulary.sos_idx,
                eos_idx=training_set.input_vocabulary.eos_idx)

            rewards = rollout.get_reward(samples, rollout_trails, input_batch,
                                         input_lengths, situation_batch,
                                         target_batch,
                                         training_set.input_vocabulary.sos_idx,
                                         training_set.input_vocabulary.eos_idx,
                                         discriminator)

            assert samples.shape == rewards.shape

            # Exponentiate the rollout rewards and flatten to one weight per
            # generated token.
            rewards = torch.exp(rewards).contiguous().view((-1, ))
            if use_cuda:
                rewards = rewards.cuda()

            # Get the generator's normalized log-probabilities for the samples.
            target_scores = generator.get_normalized_logits(
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                samples=samples,
                sample_lengths=target_lengths,
                sos_idx=training_set.input_vocabulary.sos_idx)

            del samples

            # calculate loss on the generated sequence given the rewards
            loss = generator.get_gan_loss(target_scores, target_batch, rewards)

            del rewards

            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            scheduler.step(training_iteration)
            optimizer.zero_grad()
            generator.update_state(is_best=is_best)

            # Print current metrics.
            if training_iteration % print_every == 0:
                # accuracy, exact_match = generator.get_metrics(target_scores, target_batch)
                learning_rate = scheduler.get_lr()[0]
                logger.info("Iteration %08d, loss %8.4f, learning_rate %.5f," %
                            (training_iteration, loss, learning_rate))
                # logger.info("Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                #             % (training_iteration, loss, accuracy, exact_match, learning_rate))
            del target_scores, target_batch

            # # Evaluate on test set.
            # if training_iteration % evaluate_every == 0:
            #     with torch.no_grad():
            #         generator.eval()
            #         logger.info("Evaluating..")
            #         accuracy, exact_match, target_accuracy = evaluate(
            #             test_set.get_data_iterator(batch_size=1), model=generator,
            #             max_decoding_steps=max_decoding_steps, pad_idx=test_set.target_vocabulary.pad_idx,
            #             sos_idx=test_set.target_vocabulary.sos_idx,
            #             eos_idx=test_set.target_vocabulary.eos_idx,
            #             max_examples_to_evaluate=kwargs["max_testing_examples"])
            #         logger.info("  Evaluation Accuracy: %5.2f Exact Match: %5.2f "
            #                     " Target Accuracy: %5.2f" % (accuracy, exact_match, target_accuracy))
            #         if exact_match > best_exact_match:
            #             is_best = True
            #             best_accuracy = accuracy
            #             best_exact_match = exact_match
            #             generator.update_state(accuracy=accuracy, exact_match=exact_match, is_best=is_best)
            #         file_name = "checkpoint.pth.tar"
            #         if is_best:
            #             generator.save_checkpoint(file_name=file_name, is_best=is_best,
            #                                       optimizer_state_dict=optimizer.state_dict())

            rollout.update_params()

            train_discriminator(training_set,
                                discriminator,
                                training_batch_size,
                                generator,
                                seed,
                                epochs=1,
                                name="training_discriminator")
            training_iteration += 1
            if training_iteration > max_training_iterations:
                break
            del loss

        torch.save(
            generator.state_dict(),
            '{}/{}'.format(output_directory,
                           'gen_{}_{}.ckpt'.format(training_iteration, seed)))
        torch.save(
            discriminator.state_dict(),
            '{}/{}'.format(output_directory,
                           'dis_{}_{}.ckpt'.format(training_iteration, seed)))

    logger.info("Finished training.")
Example No. 2
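This second example is the plain supervised (MLE) training loop for the same model: a forward pass, cross-entropy loss on the target action sequence (plus an optional auxiliary loss on the predicted target position), Adam with an exponentially decayed learning rate, periodic evaluation on a held-out split, and checkpointing whenever exact match improves.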
import os

import torch
from torch.optim.lr_scheduler import LambdaLR

# Assumed to be defined at module level in the surrounding project:
# logger, use_cuda, GroundedScanDataset, Model, log_parameters, evaluate.


def train(data_path: str,
          data_directory: str,
          generate_vocabularies: bool,
          input_vocab_path: str,
          target_vocab_path: str,
          embedding_dimension: int,
          num_encoder_layers: int,
          encoder_dropout_p: float,
          encoder_bidirectional: bool,
          training_batch_size: int,
          test_batch_size: int,
          max_decoding_steps: int,
          num_decoder_layers: int,
          decoder_dropout_p: float,
          cnn_kernel_size: int,
          cnn_dropout_p: float,
          cnn_hidden_num_channels: int,
          simple_situation_representation: bool,
          decoder_hidden_size: int,
          encoder_hidden_size: int,
          learning_rate: float,
          adam_beta_1: float,
          adam_beta_2: float,
          lr_decay: float,
          lr_decay_steps: int,
          resume_from_file: str,
          max_training_iterations: int,
          output_directory: str,
          print_every: int,
          evaluate_every: int,
          conditional_attention: bool,
          auxiliary_task: bool,
          weight_target_loss: float,
          attention_type: str,
          max_training_examples=None,
          seed=42,
          **kwargs):
    device = torch.device("cuda" if use_cuda else "cpu")
    cfg = locals().copy()

    torch.manual_seed(seed)

    logger.info("Loading Training set...")
    training_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="train",
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=generate_vocabularies)
    training_set.read_dataset(
        max_examples=max_training_examples,
        simple_situation_representation=simple_situation_representation)
    logger.info("Done Loading Training set.")
    logger.info("  Loaded {} training examples.".format(
        training_set.num_examples))
    logger.info("  Input vocabulary size training set: {}".format(
        training_set.input_vocabulary_size))
    logger.info("  Most common input words: {}".format(
        training_set.input_vocabulary.most_common(5)))
    logger.info("  Output vocabulary size training set: {}".format(
        training_set.target_vocabulary_size))
    logger.info("  Most common target words: {}".format(
        training_set.target_vocabulary.most_common(5)))

    if generate_vocabularies:
        training_set.save_vocabularies(input_vocab_path, target_vocab_path)
        logger.info(
            "Saved vocabularies to {} for input and {} for target.".format(
                input_vocab_path, target_vocab_path))

    logger.info("Loading Test set...")
    test_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="test",  # TODO: use dev set here
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=False)
    test_set.read_dataset(
        max_examples=None,
        simple_situation_representation=simple_situation_representation)

    # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.
    test_set.shuffle_data()
    logger.info("Done Loading Test set.")

    model = Model(input_vocabulary_size=training_set.input_vocabulary_size,
                  target_vocabulary_size=training_set.target_vocabulary_size,
                  num_cnn_channels=training_set.image_channels,
                  input_padding_idx=training_set.input_vocabulary.pad_idx,
                  target_pad_idx=training_set.target_vocabulary.pad_idx,
                  target_eos_idx=training_set.target_vocabulary.eos_idx,
                  **cfg)
    model = model.cuda() if use_cuda else model
    log_parameters(model)
    trainable_parameters = [
        parameter for parameter in model.parameters()
        if parameter.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_parameters,
                                 lr=learning_rate,
                                 betas=(adam_beta_1, adam_beta_2))
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))

    # Load model and vocabularies if resuming.
    start_iteration = 1
    best_iteration = 1
    best_accuracy = 0
    best_exact_match = 0
    best_loss = float('inf')
    if resume_from_file:
        assert os.path.isfile(
            resume_from_file), "No checkpoint found at {}".format(
                resume_from_file)
        logger.info(
            "Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = model.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = model.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            resume_from_file, start_iteration))

    logger.info("Training starts..")
    training_iteration = start_iteration
    while training_iteration < max_training_iterations:

        # Shuffle the dataset and loop over it.
        training_set.shuffle_data()
        for (input_batch, input_lengths, _, situation_batch, _, target_batch,
             target_lengths, agent_positions,
             target_positions) in training_set.get_data_iterator(
                 batch_size=training_batch_size):
            is_best = False
            model.train()

            # Forward pass.
            target_scores, target_position_scores = model(
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                target_batch=target_batch,
                target_lengths=target_lengths)
            loss = model.get_loss(target_scores, target_batch)
            if auxiliary_task:
                target_loss = model.get_auxiliary_loss(target_position_scores,
                                                       target_positions)
            else:
                target_loss = 0
            loss += weight_target_loss * target_loss

            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            model.update_state(is_best=is_best)

            # Print current metrics.
            if training_iteration % print_every == 0:
                accuracy, exact_match = model.get_metrics(
                    target_scores, target_batch)
                if auxiliary_task:
                    auxiliary_accuracy_target = model.get_auxiliary_accuracy(
                        target_position_scores, target_positions)
                else:
                    auxiliary_accuracy_target = 0.
                learning_rate = scheduler.get_lr()[0]
                logger.info(
                    "Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                    " aux. accuracy target pos %5.2f" %
                    (training_iteration, loss, accuracy, exact_match,
                     learning_rate, auxiliary_accuracy_target))

            # Evaluate on test set.
            if training_iteration % evaluate_every == 0:
                with torch.no_grad():
                    model.eval()
                    logger.info("Evaluating..")
                    accuracy, exact_match, target_accuracy = evaluate(
                        test_set.get_data_iterator(batch_size=1),
                        model=model,
                        max_decoding_steps=max_decoding_steps,
                        pad_idx=test_set.target_vocabulary.pad_idx,
                        sos_idx=test_set.target_vocabulary.sos_idx,
                        eos_idx=test_set.target_vocabulary.eos_idx,
                        max_examples_to_evaluate=kwargs["max_testing_examples"]
                    )
                    logger.info(
                        "  Evaluation Accuracy: %5.2f Exact Match: %5.2f "
                        " Target Accuracy: %5.2f" %
                        (accuracy, exact_match, target_accuracy))
                    if exact_match > best_exact_match:
                        is_best = True
                        best_accuracy = accuracy
                        best_exact_match = exact_match
                        model.update_state(accuracy=accuracy,
                                           exact_match=exact_match,
                                           is_best=is_best)
                    file_name = "checkpoint.pth.tar".format(
                        str(training_iteration))
                    if is_best:
                        model.save_checkpoint(
                            file_name=file_name,
                            is_best=is_best,
                            optimizer_state_dict=optimizer.state_dict())

            training_iteration += 1
            if training_iteration > max_training_iterations:
                break
    logger.info("Finished training.")