Example #1
def __init__(self, data_path, hparams, flags):
    super(SeqGAN, self).__init__()
    self.hparams = hparams
    device = torch.device("cuda")
    # cfg = locals().copy()
    torch.manual_seed(self.hparams.seed)
    self.training_set = self.load_training_set(data_path)
    self.total_vocabulary_size = self.generate_vocab()
    self.generator = Model(
        input_vocabulary_size=self.training_set.input_vocabulary_size,
        target_vocabulary_size=self.training_set.target_vocabulary_size,
        num_cnn_channels=self.training_set.image_channels,
        input_padding_idx=self.training_set.input_vocabulary.pad_idx,
        target_pad_idx=self.training_set.target_vocabulary.pad_idx,
        target_eos_idx=self.training_set.target_vocabulary.eos_idx,
        device=device,
        **flags)
    self.discriminator = Discriminator(
        embedding_dim=self.hparams.disc_emb_dim,
        hidden_dim=self.hparams.disc_hid_dim,
        vocab_size=self.total_vocabulary_size,
        max_seq_len=self.hparams.max_decoding_steps)
    self.rollout = Rollout(self.generator,
                           self.hparams.rollout_update_rate)
Example #2
def main(flags):
    for argument, value in flags.items():
        logger.info("{}: {}".format(argument, value))

    if not os.path.exists(flags["output_directory"]):
        os.mkdir(os.path.join(os.getcwd(), flags["output_directory"]))

    # Some checks on the flags
    if flags["generate_vocabularies"]:
        assert flags["input_vocab_path"] and flags[
            "target_vocab_path"], "Please specify paths to vocabularies to save."

    if flags["test_batch_size"] > 1:
        raise NotImplementedError(
            "Test batch size larger than 1 not implemented.")

    data_path = os.path.join(flags["data_directory"], "dataset.txt")
    if flags["mode"] == "train":
        train(data_path=data_path, **flags)
    elif flags["mode"] == "test":
        assert os.path.exists(os.path.join(flags["data_directory"], flags["input_vocab_path"])) and os.path.exists(
            os.path.join(flags["data_directory"], flags["target_vocab_path"])), \
            "No vocabs found at {} and {}".format(flags["input_vocab_path"], flags["target_vocab_path"])
        logger.info("Loading {} dataset split...".format(flags["split"]))
        test_set = GroundedScanDataset(
            data_path,
            flags["data_directory"],
            split=flags["split"],
            input_vocabulary_file=flags["input_vocab_path"],
            target_vocabulary_file=flags["target_vocab_path"],
            generate_vocabulary=False)
        test_set.read_dataset(max_examples=flags["max_testing_examples"],
                              simple_situation_representation=flags[
                                  "simple_situation_representation"])
        logger.info("Done Loading {} dataset split.".format(flags["split"]))
        logger.info("  Loaded {} examples.".format(test_set.num_examples))
        logger.info("  Input vocabulary size: {}".format(
            test_set.input_vocabulary_size))
        logger.info("  Most common input words: {}".format(
            test_set.input_vocabulary.most_common(5)))
        logger.info("  Output vocabulary size: {}".format(
            test_set.target_vocabulary_size))
        logger.info("  Most common target words: {}".format(
            test_set.target_vocabulary.most_common(5)))

        model = Model(input_vocabulary_size=test_set.input_vocabulary_size,
                      target_vocabulary_size=test_set.target_vocabulary_size,
                      num_cnn_channels=test_set.image_channels,
                      input_padding_idx=test_set.input_vocabulary.pad_idx,
                      target_pad_idx=test_set.target_vocabulary.pad_idx,
                      target_eos_idx=test_set.target_vocabulary.eos_idx,
                      **flags)
        model = model.cuda() if use_cuda else model

        # Load model and vocabularies if resuming.
        assert os.path.isfile(
            flags["resume_from_file"]), "No checkpoint found at {}".format(
                flags["resume_from_file"])
        logger.info("Loading checkpoint from file at '{}'".format(
            flags["resume_from_file"]))
        model.load_model(flags["resume_from_file"])
        start_iteration = model.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            flags["resume_from_file"], start_iteration))
        output_file_path = os.path.join(flags["output_directory"],
                                        flags["output_file_name"])
        output_file = predict_and_save(dataset=test_set,
                                       model=model,
                                       output_file_path=output_file_path,
                                       **flags)
        logger.info("Saved predictions to {}".format(output_file))
    elif flags["mode"] == "predict":
        raise NotImplementedError()
    else:
        raise ValueError("Wrong value for parameters --mode ({}).".format(
            flags["mode"]))
Example #3
def main():
    config = Config()
    seq2seq = Model(config)
    seq2seq.eval()
    seq2seq.print_parameters()
    if os.path.isfile(args.seq2seq_path):
        _, _ = seq2seq.load_model(args.seq2seq_path)
        print('Finished loading the seq2seq model!')
    else:
        print('Please load a seq2seq model first!')
        return
    if args.gpu:
        seq2seq.to('cuda')

    file_readers = []  # readers for all the sampled-result files
    if os.path.isdir(args.samples_path):
        for root, dirs, files in os.walk(args.samples_path):
            for idx, file in enumerate(files):
                print(f'Opening sample file {idx}: {file}')
                file_readers.append(open(os.path.join(args.samples_path, file), 'r', encoding='utf8'))
        print(f'All sample files opened, {len(file_readers)} files in total!')
    else:
        print(f'Invalid path: {os.path.abspath(args.samples_path)}!')
        return

    results = []  # merge the results from all files
    for fid, fr in enumerate(file_readers):
        for lid, line in enumerate(fr):
            data = json.loads(line.strip())
            if fid == 0:
                result = {'post': data['post'], 'response': data['response'], 'result': [data['result']]}
                results.append(result)
            else:
                results[lid]['result'].append(data['result'])
    print(f'Read {len(results)} examples in total!')
    for fr in file_readers:
        fr.close()

    vocab, vads = [], []  # load the VAD lexicon
    with open(args.vad_path, 'r', encoding='utf8') as fr:
        for line in fr:
            line = line.strip()
            word = line[: line.find(' ')]
            vad = line[line.find(' ') + 1:].split()
            vad = [float(item) for item in vad]
            vocab.append(word)
            vads.append(vad)
    print(f'Loaded vocabulary: {len(vocab)} words')
    print(f'Loaded VAD lexicon: {len(vads)} entries')

    sp = SentenceProcessor(vocab, vads, config.pad_id, config.start_id, config.end_id, config.unk_id)

    if not os.path.exists(args.result_path):  # create the result directory
        os.makedirs(args.result_path)
    fw = open(os.path.join(args.result_path, 'result.txt'), 'w', encoding='utf8')
    fwd = open(os.path.join(args.result_path, 'detail.txt'), 'w', encoding='utf8')

    for result in tqdm(results):  # rerank the sampled responses for each post
        str_post = result['post']  # [len]
        str_response = result['response']  # [len]
        str_results = result['result']  # [sample, len]
        sample_times = len(str_results)

        # 1. fluency score from the seq2seq model
        id_post, len_post = sp.word2index(str_post)
        id_post = [sp.start_id] + id_post + [sp.end_id]
        id_posts = [id_post for _ in range(sample_times)]  # [sample, len]
        len_posts = [len_post for _ in range(sample_times)]  # [sample]

        id_results, len_results = [], []
        for str_result in str_results:
            id_result, len_result = sp.word2index(str_result)
            id_results.append(id_result)
            len_results.append(len_result)

        len_posts = [l+2 for l in len_posts]  # account for the start and end tokens
        len_results = [l+2 for l in len_results]

        max_len_results = max(len_results)
        id_results = [sp.pad_sentence(id_result, max_len_results) for id_result in id_results]  # pad to equal length

        feed_data = {'posts': id_posts, 'responses': id_results, 'len_posts': len_posts, 'len_responses': len_results}
        feed_data = prepare_feed_data(feed_data)
        output_vocab = seq2seq(feed_data, gpu=args.gpu)  # [sample, len_decoder, num_vocab]

        masks = feed_data['masks']  # [sample, len_decoder]
        labels = feed_data['responses'][:, 1:]  # [sample, len_decoder]
        token_per_batch = masks.sum(1)
        nll_loss = F.nll_loss(output_vocab.reshape(-1, config.num_vocab).clamp_min(1e-12).log(),
                              labels.reshape(-1), reduction='none') * masks.reshape(-1)
        nll_loss = nll_loss.reshape(sample_times, -1).sum(1)  # [sample]
        ppl = (nll_loss / token_per_batch.clamp_min(1e-12)).exp().cpu().detach().numpy()  # [sample]
        score_ppl = (ppl - ppl.min()) / (ppl.max() - ppl.min())  # [sample]

        # semantic similarity score
        embed_posts = torch.cat([seq2seq.embedding(feed_data['posts']), seq2seq.affect_embedding(feed_data['posts'])], 2)
        embed_responses = torch.cat([seq2seq.embedding(feed_data['responses']), seq2seq.affect_embedding(feed_data['responses'])], 2)
        embed_posts = embed_posts.sum(1) / feed_data['len_posts'].float().unsqueeze(1).clamp_min(1e-12)  # [sample, 303]
        embed_responses = embed_responses.sum(1) / feed_data['len_responses'].float().unsqueeze(1).clamp_min(1e-12)
        score_cos = torch.cosine_similarity(embed_posts, embed_responses, 1).cpu().detach().numpy()  # [sample]
        score_cos = (score_cos - score_cos.min()) / (score_cos.max() - score_cos.min())

        # 2. VAD reward score
        vad_posts = np.array([sp.index2vad(id_post) for id_post in id_posts])[:, 1:]  # [sample, len, 3]
        vad_results = np.array([sp.index2vad(id_result) for id_result in id_results])[:, 1:]

        neutral_posts = np.tile(np.array([0.5, 0.5, 0.5]).reshape(1, 1, -1), (sample_times, vad_posts.shape[1], 1))
        neutral_results = np.tile(np.array([0.5, 0.5, 0.5]).reshape(1, 1, -1), (sample_times, vad_results.shape[1], 1))

        posts_mask = 1 - (vad_posts == neutral_posts).astype(float).prod(2)  # [sample, len]
        affect_posts = (vad_posts * np.expand_dims(posts_mask, 2)).sum(1) / posts_mask.sum(1).clip(1e-12).reshape(sample_times, 1)
        results_mask = 1 - (vad_results == neutral_results).astype(float).prod(2)  # [sample, len]
        affect_results = (vad_results * np.expand_dims(results_mask, 2)).sum(1) / results_mask.sum(1).clip(1e-12).reshape(sample_times, 1)

        post_v = affect_posts[:, 0]  # batch
        post_a = affect_posts[:, 1]
        post_d = affect_posts[:, 2]
        result_v = affect_results[:, 0]
        result_a = affect_results[:, 1]
        result_d = affect_results[:, 2]

        score_v = 1 - np.abs(post_v - result_v)  # [0, 1]
        score_a = np.abs(post_a - result_a)
        score_d = np.abs(post_d - result_d)
        score_vad = score_v + score_a + score_d
        baseline_score_vad = score_vad.mean()
        score_vad = score_vad - baseline_score_vad
        score_vad = (score_vad - score_vad.min()) / (score_vad.max() - score_vad.min())

        # 3. affect score
        # score_af = ((vad_results - neutral_results) ** 2).sum(2) ** 0.5  # [sample, len]
        # token_per_batch = token_per_batch.cpu().detach().numpy() - 1  # [sample]
        # score_af = score_af.sum(1) / token_per_batch.clip(1e-12)
        # score_af = (score_af - score_af.min()) / (score_af.max() - score_af.min())

        # 4. sentence length score
        # score_len = np.array([len(str_result) for str_result in str_results])  # [sample]
        # score_len = (score_len - score_len.min()) / (score_len.max() - score_len.min())

        score = 0.1*score_ppl + 0.4*score_vad + 0.5*score_cos
        output_id = score.argmax()

        output = {'post': str_post, 'response': str_response, 'result': str_results[output_id]}
        fw.write(json.dumps(output, ensure_ascii=False) + '\n')

        fwd.write('post: {}\n'.format(' '.join(str_post)))
        fwd.write('chosen response: {}\n'.format(' '.join(str_results[output_id])))
        fwd.write('response: {}\n'.format(' '.join(str_response)))
        for idx, str_result in enumerate(str_results):
            fwd.write('candidate{}: {} (t:{:.2f} p:{:.2f} v:{:.2f} c:{:.2f})\n'
                      .format(idx, ' '.join(str_result), score[idx], 0.1*score_ppl[idx], 0.4*score_vad[idx],
                              0.5*score_cos[idx]))
        fwd.write('\n')
    fw.close()
    fwd.close()
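
The reranking loop above scores every candidate with three signals, min-max normalizes each one, and keeps the candidate with the highest weighted sum (0.1 fluency, 0.4 VAD, 0.5 cosine). A small standalone sketch of that combination, with illustrative inputs and a guard against a zero range:

import numpy as np


def rerank(score_ppl, score_vad, score_cos):
    def minmax(x):
        x = np.asarray(x, dtype=float)
        span = x.max() - x.min()
        return (x - x.min()) / span if span > 0 else np.zeros_like(x)

    score = 0.1 * minmax(score_ppl) + 0.4 * minmax(score_vad) + 0.5 * minmax(score_cos)
    return int(score.argmax()), score


best, scores = rerank([3.2, 1.1, 2.4], [0.5, 0.9, 0.2], [0.7, 0.6, 0.8])
print(best, scores)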
Example #4
def train(
        data_path: str,
        data_directory: str,
        generate_vocabularies: bool,
        input_vocab_path: str,
        target_vocab_path: str,
        embedding_dimension: int,
        num_encoder_layers: int,
        encoder_dropout_p: float,
        encoder_bidirectional: bool,
        training_batch_size: int,
        test_batch_size: int,
        max_decoding_steps: int,
        num_decoder_layers: int,
        decoder_dropout_p: float,
        cnn_kernel_size: int,
        cnn_dropout_p: float,
        cnn_hidden_num_channels: int,
        simple_situation_representation: bool,
        decoder_hidden_size: int,
        encoder_hidden_size: int,
        learning_rate: float,
        adam_beta_1: float,
        adam_beta_2: float,
        lr_decay: float,
        lr_decay_steps: int,
        resume_from_file: str,
        max_training_iterations: int,
        output_directory: str,
        print_every: int,
        evaluate_every: int,
        conditional_attention: bool,
        auxiliary_task: bool,
        weight_target_loss: float,
        attention_type: str,
        k: int,
        max_training_examples,
        max_testing_examples,
        # SeqGAN params begin
        pretrain_gen_path,
        pretrain_gen_epochs,
        pretrain_disc_path,
        pretrain_disc_epochs,
        rollout_trails,
        rollout_update_rate,
        disc_emb_dim,
        disc_hid_dim,
        load_tensors_from_path,
        # SeqGAN params end
        seed=42,
        **kwargs):
    device = torch.device("cpu")
    cfg = locals().copy()
    torch.manual_seed(seed)

    logger.info("Loading Training set...")

    training_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="train",
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=generate_vocabularies,
        k=k)
    training_set.read_dataset(
        max_examples=max_training_examples,
        simple_situation_representation=simple_situation_representation,
        load_tensors_from_path=load_tensors_from_path
    )  # set this to False if no pickle file available

    logger.info("Done Loading Training set.")
    logger.info("  Loaded {} training examples.".format(
        training_set.num_examples))
    logger.info("  Input vocabulary size training set: {}".format(
        training_set.input_vocabulary_size))
    logger.info("  Most common input words: {}".format(
        training_set.input_vocabulary.most_common(5)))
    logger.info("  Output vocabulary size training set: {}".format(
        training_set.target_vocabulary_size))
    logger.info("  Most common target words: {}".format(
        training_set.target_vocabulary.most_common(5)))

    if generate_vocabularies:
        training_set.save_vocabularies(input_vocab_path, target_vocab_path)
        logger.info(
            "Saved vocabularies to {} for input and {} for target.".format(
                input_vocab_path, target_vocab_path))

    # logger.info("Loading Dev. set...")
    # test_set = GroundedScanDataset(data_path, data_directory, split="dev",
    #                                input_vocabulary_file=input_vocab_path,
    #                                target_vocabulary_file=target_vocab_path, generate_vocabulary=False, k=0)
    # test_set.read_dataset(max_examples=max_testing_examples,
    #                       simple_situation_representation=simple_situation_representation)
    #
    # # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.
    # test_set.shuffle_data()

    # logger.info("Done Loading Dev. set.")

    generator = Model(
        input_vocabulary_size=training_set.input_vocabulary_size,
        target_vocabulary_size=training_set.target_vocabulary_size,
        num_cnn_channels=training_set.image_channels,
        input_padding_idx=training_set.input_vocabulary.pad_idx,
        target_pad_idx=training_set.target_vocabulary.pad_idx,
        target_eos_idx=training_set.target_vocabulary.eos_idx,
        **cfg)
    total_vocabulary = set(
        list(training_set.input_vocabulary._word_to_idx.keys()) +
        list(training_set.target_vocabulary._word_to_idx.keys()))
    total_vocabulary_size = len(total_vocabulary)
    discriminator = Discriminator(embedding_dim=disc_emb_dim,
                                  hidden_dim=disc_hid_dim,
                                  vocab_size=total_vocabulary_size,
                                  max_seq_len=max_decoding_steps)

    generator = generator.cuda() if use_cuda else generator
    discriminator = discriminator.cuda() if use_cuda else discriminator
    rollout = Rollout(generator, rollout_update_rate)
    log_parameters(generator)
    trainable_parameters = [
        parameter for parameter in generator.parameters()
        if parameter.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_parameters,
                                 lr=learning_rate,
                                 betas=(adam_beta_1, adam_beta_2))
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))

    # Load model and vocabularies if resuming.
    start_iteration = 1
    best_iteration = 1
    best_accuracy = 0
    best_exact_match = 0
    best_loss = float('inf')
    if resume_from_file:
        assert os.path.isfile(
            resume_from_file), "No checkpoint found at {}".format(
                resume_from_file)
        logger.info(
            "Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = generator.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = generator.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            resume_from_file, start_iteration))

    if pretrain_gen_path is None:
        print('Pretraining generator with MLE...')
        pre_train_generator(training_set,
                            training_batch_size,
                            generator,
                            seed,
                            pretrain_gen_epochs,
                            name='pretrained_generator')
    else:
        print('Load pretrained generator weights')
        generator_weights = torch.load(pretrain_gen_path)
        generator.load_state_dict(generator_weights)

    if pretrain_disc_path is None:
        print('Pretraining Discriminator....')
        train_discriminator(training_set,
                            discriminator,
                            training_batch_size,
                            generator,
                            seed,
                            pretrain_disc_epochs,
                            name="pretrained_discriminator")
    else:
        print('Loading Discriminator....')
        discriminator_weights = torch.load(pretrain_disc_path)
        discriminator.load_state_dict(discriminator_weights)

    logger.info("Training starts..")
    training_iteration = start_iteration
    torch.autograd.set_detect_anomaly(True)
    while training_iteration < max_training_iterations:

        # Shuffle the dataset and loop over it.
        training_set.shuffle_data()

        for (input_batch, input_lengths, _, situation_batch, _, target_batch,
             target_lengths, agent_positions, target_positions) in \
                training_set.get_data_iterator(batch_size=training_batch_size):

            is_best = False
            generator.train()

            # Forward pass.
            samples = generator.sample(
                batch_size=training_batch_size,
                max_seq_len=max(target_lengths).astype(int),
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                target_batch=target_batch,
                sos_idx=training_set.input_vocabulary.sos_idx,
                eos_idx=training_set.input_vocabulary.eos_idx)

            rewards = rollout.get_reward(samples, rollout_trails, input_batch,
                                         input_lengths, situation_batch,
                                         target_batch,
                                         training_set.input_vocabulary.sos_idx,
                                         training_set.input_vocabulary.eos_idx,
                                         discriminator)

            assert samples.shape == rewards.shape

            # calculate rewards
            rewards = torch.exp(rewards).contiguous().view((-1, ))
            if use_cuda:
                rewards = rewards.cuda()

            # get generator scores for sequence
            target_scores = generator.get_normalized_logits(
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                samples=samples,
                sample_lengths=target_lengths,
                sos_idx=training_set.input_vocabulary.sos_idx)

            del samples

            # calculate loss on the generated sequence given the rewards
            loss = generator.get_gan_loss(target_scores, target_batch, rewards)

            del rewards

            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            scheduler.step(training_iteration)
            optimizer.zero_grad()
            generator.update_state(is_best=is_best)

            # Print current metrics.
            if training_iteration % print_every == 0:
                # accuracy, exact_match = generator.get_metrics(target_scores, target_batch)
                learning_rate = scheduler.get_lr()[0]
                logger.info("Iteration %08d, loss %8.4f, learning_rate %.5f," %
                            (training_iteration, loss, learning_rate))
                # logger.info("Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                #             % (training_iteration, loss, accuracy, exact_match, learning_rate))
            del target_scores, target_batch

            # # Evaluate on test set.
            # if training_iteration % evaluate_every == 0:
            #     with torch.no_grad():
            #         generator.eval()
            #         logger.info("Evaluating..")
            #         accuracy, exact_match, target_accuracy = evaluate(
            #             test_set.get_data_iterator(batch_size=1), model=generator,
            #             max_decoding_steps=max_decoding_steps, pad_idx=test_set.target_vocabulary.pad_idx,
            #             sos_idx=test_set.target_vocabulary.sos_idx,
            #             eos_idx=test_set.target_vocabulary.eos_idx,
            #             max_examples_to_evaluate=kwargs["max_testing_examples"])
            #         logger.info("  Evaluation Accuracy: %5.2f Exact Match: %5.2f "
            #                     " Target Accuracy: %5.2f" % (accuracy, exact_match, target_accuracy))
            #         if exact_match > best_exact_match:
            #             is_best = True
            #             best_accuracy = accuracy
            #             best_exact_match = exact_match
            #             generator.update_state(accuracy=accuracy, exact_match=exact_match, is_best=is_best)
            #         file_name = "checkpoint.pth.tar".format(str(training_iteration))
            #         if is_best:
            #             generator.save_checkpoint(file_name=file_name, is_best=is_best,
            #                                       optimizer_state_dict=optimizer.state_dict())

            rollout.update_params()

            train_discriminator(training_set,
                                discriminator,
                                training_batch_size,
                                generator,
                                seed,
                                epochs=1,
                                name="training_discriminator")
            training_iteration += 1
            if training_iteration > max_training_iterations:
                break
            del loss

        torch.save(
            generator.state_dict(),
            '{}/{}'.format(output_directory,
                           'gen_{}_{}.ckpt'.format(training_iteration, seed)))
        torch.save(
            discriminator.state_dict(),
            '{}/{}'.format(output_directory,
                           'dis_{}_{}.ckpt'.format(training_iteration, seed)))

    logger.info("Finished training.")
Example #5
class SeqGAN(pl.LightningModule):
    def __init__(self, data_path, hparams, flags):
        super(SeqGAN, self).__init__()
        self.hparams = hparams
        device = torch.device("cuda")
        # cfg = locals().copy()
        torch.manual_seed(self.hparams.seed)
        self.training_set = self.load_training_set(data_path)
        self.total_vocabulary_size = self.generate_vocab()
        self.generator = Model(
            input_vocabulary_size=self.training_set.input_vocabulary_size,
            target_vocabulary_size=self.training_set.target_vocabulary_size,
            num_cnn_channels=self.training_set.image_channels,
            input_padding_idx=self.training_set.input_vocabulary.pad_idx,
            target_pad_idx=self.training_set.target_vocabulary.pad_idx,
            target_eos_idx=self.training_set.target_vocabulary.eos_idx,
            device=device,
            **flags)
        self.discriminator = Discriminator(
            embedding_dim=self.hparams.disc_emb_dim,
            hidden_dim=self.hparams.disc_hid_dim,
            vocab_size=self.total_vocabulary_size,
            max_seq_len=self.hparams.max_decoding_steps)
        self.rollout = Rollout(self.generator,
                               self.hparams.rollout_update_rate)

        # if pretrain_gen_path is None:
        #     print('Pretraining generator with MLE...')
        #     pre_train_generator(training_set, training_batch_size, generator, seed, pretrain_gen_epochs,
        #                         name='pretrained_generator')
        # else:
        #     print('Load pretrained generator weights')
        #     generator_weights = torch.load(pretrain_gen_path)
        #     generator.load_state_dict(generator_weights)
        #
        # if pretrain_disc_path is None:
        #     print('Pretraining Discriminator....')
        #     train_discriminator(training_set, discriminator, training_batch_size, generator, seed, pretrain_disc_epochs,
        #                         name="pretrained_discriminator")
        # else:
        #     print('Loading Discriminator....')
        #     discriminator_weights = torch.load(pretrain_disc_path)
        #     discriminator.load_state_dict(discriminator_weights)

    def load_training_set(self, data_path):
        logger.info("Loading Training set...")
        training_set = GroundedScanDataset(
            data_path,
            self.hparams.data_directory,
            split="train",
            input_vocabulary_file=self.hparams.input_vocab_path,
            target_vocabulary_file=self.hparams.target_vocab_path,
            generate_vocabulary=self.hparams.generate_vocabularies,
            k=self.hparams.k)
        training_set.read_dataset(
            max_examples=self.hparams.max_training_examples,
            simple_situation_representation=self.hparams.simple_situation_representation,
            load_tensors_from_path=self.hparams.load_tensors_from_path)
        logger.info("Done Loading Training set.")
        return training_set

    def generate_vocab(self):
        if bool(self.hparams.generate_vocabularies):
            self.training_set.save_vocabularies(self.hparams.input_vocab_path,
                                                self.hparams.target_vocab_path)
        total_vocab_size = len(
            set(
                list(self.training_set.input_vocabulary._word_to_idx.keys()) +
                list(self.training_set.target_vocabulary._word_to_idx.keys())))
        return total_vocab_size

    def forward(self, input_batch, input_lengths, situation_batch,
                target_batch, target_lengths):
        samples = self.generator.sample(
            batch_size=self.hparams.training_batch_size,
            max_seq_len=max(target_lengths).astype(int),
            commands_input=input_batch,
            commands_lengths=input_lengths,
            situations_input=situation_batch,
            target_batch=target_batch,
            sos_idx=self.training_set.input_vocabulary.sos_idx,
            eos_idx=self.training_set.input_vocabulary.eos_idx)

        target_scores = self.generator.get_normalized_logits(
            commands_input=input_batch,
            commands_lengths=input_lengths,
            situations_input=situation_batch,
            samples=samples,
            sample_lengths=target_lengths,
            sos_idx=self.training_set.input_vocabulary.sos_idx)

        rewards = self.rollout.get_reward(
            samples, self.hparams.rollout_trails, input_batch, input_lengths,
            situation_batch, target_batch,
            self.training_set.input_vocabulary.sos_idx,
            self.training_set.input_vocabulary.eos_idx, self.discriminator)
        return target_scores, rewards

    def configure_optimizers(self):
        lr = self.hparams.learning_rate
        lr_decay = self.hparams.lr_decay
        lr_decay_steps = self.hparams.lr_decay_steps
        b1 = self.hparams.adam_beta_1
        b2 = self.hparams.adam_beta_2

        trainable_parameters = [
            parameter for parameter in self.generator.parameters()
            if parameter.requires_grad
        ]
        opt_g = torch.optim.Adam(trainable_parameters, lr=lr, betas=(b1, b2))
        trainable_parameters = [
            parameter for parameter in self.discriminator.parameters()
            if parameter.requires_grad
        ]
        opt_d = torch.optim.Adam(trainable_parameters, lr=lr, betas=(b1, b2))

        scheduler_g = LambdaLR(
            opt_g, lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))
        scheduler_d = LambdaLR(
            opt_d, lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))

        return [opt_g, opt_d], [scheduler_g, scheduler_d]

    def training_step(self, batch, batch_idx, optimizer_idx):
        input_batch, input_lengths, _, situation_batch, _, target_batch, target_lengths, agent_positions, target_positions = batch
        # input_batch, input_lengths, situation_batch, target_batch, target_lengths = batch

        if optimizer_idx == 0:
            pred, rewards = self(input_batch, input_lengths, situation_batch,
                                 target_batch, target_lengths)
            pred = pred.cuda()
            rewards = rewards.cuda()
            g_loss = self.generator.get_gan_loss(pred, target_batch, rewards)
            del rewards, pred
            # print("Iteration %08d, loss %8.4f" % (batch_idx, g_loss))
            self.rollout.update_params()
            tqdm_dict = {'g_loss': g_loss}
            output = OrderedDict({
                'loss': g_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

        if optimizer_idx == 1:
            neg_samples = self.generator.sample(
                batch_size=self.hparams.training_batch_size,
                max_seq_len=max(target_lengths).astype(int),
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                target_batch=target_batch,
                sos_idx=self.training_set.input_vocabulary.sos_idx,
                eos_idx=self.training_set.input_vocabulary.eos_idx)
            # Generated sequences are the "fake" class for the discriminator.
            neg_out = self.discriminator.batchClassify(neg_samples.long())
            fake = torch.zeros_like(neg_out)
            fake_loss = F.binary_cross_entropy(neg_out, fake)

            # Ground-truth target sequences are the "real" class.
            pos_out = self.discriminator.batchClassify(target_batch.long())
            valid = torch.ones_like(pos_out)
            real_loss = F.binary_cross_entropy(pos_out, valid)

            d_loss = (real_loss + fake_loss) / 2
            tqdm_dict = {'d_loss': d_loss}
            output = OrderedDict({
                'loss': d_loss,
                'progress_bar': tqdm_dict,
                'log': tqdm_dict
            })
            return output

    def train_dataloader(self):
        return self.training_set.get_data_iterator(
            batch_size=self.hparams.training_batch_size)

    def on_epoch_end(self):
        torch.save(
            self.generator.state_dict(),
            os.path.join(self.hparams.output_directory,
                         'gen_{}.ckpt'.format(self.hparams.seed)))
        torch.save(
            self.discriminator.state_dict(),
            os.path.join(self.hparams.output_directory,
                         'dis_{}.ckpt'.format(self.hparams.seed)))
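
training_step above receives an optimizer_idx because configure_optimizers registers two optimizers: index 0 updates the generator with the reward-weighted loss, index 1 updates the discriminator on real versus generated sequences. Written out as a plain PyTorch loop with toy modules, the alternation looks roughly like the sketch below; the modules and losses are illustrative only, not the project's.

import torch
import torch.nn as nn
import torch.nn.functional as F

gen, disc = nn.Linear(8, 8), nn.Linear(8, 1)
opt_g = torch.optim.Adam(gen.parameters(), lr=1e-3)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-3)

for step in range(3):
    real = torch.randn(4, 8)
    ones, zeros = torch.ones(4, 1), torch.zeros(4, 1)

    # optimizer_idx == 0: the generator tries to make the discriminator say "real"
    g_loss = F.binary_cross_entropy_with_logits(disc(gen(real)), ones)
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()

    # optimizer_idx == 1: the discriminator separates real from generated data
    d_loss = (F.binary_cross_entropy_with_logits(disc(real), ones) +
              F.binary_cross_entropy_with_logits(disc(gen(real).detach()), zeros)) / 2
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()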
Example #6
def train(data_path: str,
          data_directory: str,
          generate_vocabularies: bool,
          input_vocab_path: str,
          target_vocab_path: str,
          embedding_dimension: int,
          num_encoder_layers: int,
          encoder_dropout_p: float,
          encoder_bidirectional: bool,
          training_batch_size: int,
          test_batch_size: int,
          max_decoding_steps: int,
          num_decoder_layers: int,
          decoder_dropout_p: float,
          cnn_kernel_size: int,
          cnn_dropout_p: float,
          cnn_hidden_num_channels: int,
          simple_situation_representation: bool,
          decoder_hidden_size: int,
          encoder_hidden_size: int,
          learning_rate: float,
          adam_beta_1: float,
          adam_beta_2: float,
          lr_decay: float,
          lr_decay_steps: int,
          resume_from_file: str,
          max_training_iterations: int,
          output_directory: str,
          print_every: int,
          evaluate_every: int,
          conditional_attention: bool,
          auxiliary_task: bool,
          weight_target_loss: float,
          attention_type: str,
          max_training_examples=None,
          seed=42,
          **kwargs):
    device = torch.device(type='cuda') if use_cuda else torch.device(
        type='cpu')
    cfg = locals().copy()

    torch.manual_seed(seed)

    logger.info("Loading Training set...")
    training_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="train",
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=generate_vocabularies)
    training_set.read_dataset(
        max_examples=max_training_examples,
        simple_situation_representation=simple_situation_representation)
    logger.info("Done Loading Training set.")
    logger.info("  Loaded {} training examples.".format(
        training_set.num_examples))
    logger.info("  Input vocabulary size training set: {}".format(
        training_set.input_vocabulary_size))
    logger.info("  Most common input words: {}".format(
        training_set.input_vocabulary.most_common(5)))
    logger.info("  Output vocabulary size training set: {}".format(
        training_set.target_vocabulary_size))
    logger.info("  Most common target words: {}".format(
        training_set.target_vocabulary.most_common(5)))

    if generate_vocabularies:
        training_set.save_vocabularies(input_vocab_path, target_vocab_path)
        logger.info(
            "Saved vocabularies to {} for input and {} for target.".format(
                input_vocab_path, target_vocab_path))

    logger.info("Loading Test set...")
    test_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="test",  # TODO: use dev set here
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=False)
    test_set.read_dataset(
        max_examples=None,
        simple_situation_representation=simple_situation_representation)

    # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.
    test_set.shuffle_data()
    logger.info("Done Loading Test set.")

    model = Model(input_vocabulary_size=training_set.input_vocabulary_size,
                  target_vocabulary_size=training_set.target_vocabulary_size,
                  num_cnn_channels=training_set.image_channels,
                  input_padding_idx=training_set.input_vocabulary.pad_idx,
                  target_pad_idx=training_set.target_vocabulary.pad_idx,
                  target_eos_idx=training_set.target_vocabulary.eos_idx,
                  **cfg)
    model = model.cuda() if use_cuda else model
    log_parameters(model)
    trainable_parameters = [
        parameter for parameter in model.parameters()
        if parameter.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_parameters,
                                 lr=learning_rate,
                                 betas=(adam_beta_1, adam_beta_2))
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))

    # Load model and vocabularies if resuming.
    start_iteration = 1
    best_iteration = 1
    best_accuracy = 0
    best_exact_match = 0
    best_loss = float('inf')
    if resume_from_file:
        assert os.path.isfile(
            resume_from_file), "No checkpoint found at {}".format(
                resume_from_file)
        logger.info(
            "Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = model.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = model.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            resume_from_file, start_iteration))

    logger.info("Training starts..")
    training_iteration = start_iteration
    while training_iteration < max_training_iterations:

        # Shuffle the dataset and loop over it.
        training_set.shuffle_data()
        for (input_batch, input_lengths, _, situation_batch, _, target_batch,
             target_lengths, agent_positions,
             target_positions) in training_set.get_data_iterator(
                 batch_size=training_batch_size):
            is_best = False
            model.train()

            # Forward pass.
            target_scores, target_position_scores = model(
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                target_batch=target_batch,
                target_lengths=target_lengths)
            loss = model.get_loss(target_scores, target_batch)
            if auxiliary_task:
                target_loss = model.get_auxiliary_loss(target_position_scores,
                                                       target_positions)
            else:
                target_loss = 0
            loss += weight_target_loss * target_loss

            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            model.update_state(is_best=is_best)

            # Print current metrics.
            if training_iteration % print_every == 0:
                accuracy, exact_match = model.get_metrics(
                    target_scores, target_batch)
                if auxiliary_task:
                    auxiliary_accuracy_target = model.get_auxiliary_accuracy(
                        target_position_scores, target_positions)
                else:
                    auxiliary_accuracy_target = 0.
                learning_rate = scheduler.get_lr()[0]
                logger.info(
                    "Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                    " aux. accuracy target pos %5.2f" %
                    (training_iteration, loss, accuracy, exact_match,
                     learning_rate, auxiliary_accuracy_target))

            # Evaluate on test set.
            if training_iteration % evaluate_every == 0:
                with torch.no_grad():
                    model.eval()
                    logger.info("Evaluating..")
                    accuracy, exact_match, target_accuracy = evaluate(
                        test_set.get_data_iterator(batch_size=1),
                        model=model,
                        max_decoding_steps=max_decoding_steps,
                        pad_idx=test_set.target_vocabulary.pad_idx,
                        sos_idx=test_set.target_vocabulary.sos_idx,
                        eos_idx=test_set.target_vocabulary.eos_idx,
                        max_examples_to_evaluate=kwargs["max_testing_examples"]
                    )
                    logger.info(
                        "  Evaluation Accuracy: %5.2f Exact Match: %5.2f "
                        " Target Accuracy: %5.2f" %
                        (accuracy, exact_match, target_accuracy))
                    if exact_match > best_exact_match:
                        is_best = True
                        best_accuracy = accuracy
                        best_exact_match = exact_match
                        model.update_state(accuracy=accuracy,
                                           exact_match=exact_match,
                                           is_best=is_best)
                    file_name = "checkpoint.pth.tar".format(
                        str(training_iteration))
                    if is_best:
                        model.save_checkpoint(
                            file_name=file_name,
                            is_best=is_best,
                            optimizer_state_dict=optimizer.state_dict())

            training_iteration += 1
            if training_iteration > max_training_iterations:
                break
    logger.info("Finished training.")
Example #7
# TODO: test gSCAN_dataset.py (SOS and EOS and padding and unk)
# TODO: test model.py (masking and stuff)
import unittest
import torch

from seq2seq.model import Model

test_model = Model(input_vocabulary_size=5,
                   embedding_dimension=10,
                   encoder_hidden_size=15,
                   num_encoder_layers=1,
                   target_vocabulary_size=4,
                   encoder_dropout_p=0.,
                   encoder_bidirectional=False,
                   num_decoder_layers=1,
                   decoder_dropout_p=0.,
                   image_dimensions=3,
                   num_cnn_channels=3,
                   cnn_kernel_size=1,
                   cnn_dropout_p=0.,
                   cnn_hidden_num_channels=5,
                   input_padding_idx=0,
                   target_pad_idx=0,
                   target_eos_idx=3,
                   output_directory="test_dir")


class TestGroundedScanDataset(unittest.TestCase):
    def test_situation_encoder(self):
        # Placeholder test: the situation encoder is not exercised yet; the
        # shape of the dummy image below is an assumption, not from the project.
        input_image = torch.zeros(1, 3, 3, 3)
        self.assertEqual('foo'.upper(), 'FOO')
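
The test module above has no entry point of its own; a conventional addition (not part of the original file) is a main guard so it can be run directly with python, or discovered via python -m unittest:

if __name__ == "__main__":
    unittest.main()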