Example #1
import os
import json

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# NOTE: `args`, `Config`, `Model`, `SentenceProcessor` and `prepare_feed_data`
# are assumed to be defined elsewhere in the project (argparse setup, model and
# data utilities); see the sketch after this function for an assumed argparse setup.


def main():
    config = Config()
    seq2seq = Model(config)
    seq2seq.eval()
    seq2seq.print_parameters()
    if os.path.isfile(args.seq2seq_path):
        _, _ = seq2seq.load_model(args.seq2seq_path)
        print('seq2seq model loaded!')
    else:
        print('Please provide a seq2seq model!')
        return
    if args.gpu:
        seq2seq.to('cuda')

    file_readers = []  # open all sampled-result files
    if os.path.isdir(args.samples_path):
        for root, dirs, files in os.walk(args.samples_path):
            for idx, file in enumerate(files):
                print(f'Opening sample file #{idx}: {file}')
                file_readers.append(open(os.path.join(root, file), 'r', encoding='utf8'))
        print(f'All sample files opened: {len(file_readers)} files in total!')
    else:
        print(f'Invalid path: {os.path.abspath(args.samples_path)}!')
        return

    results = []  # merge the results from all files
    for fid, fr in enumerate(file_readers):
        for lid, line in enumerate(fr):
            data = json.loads(line.strip())
            if fid == 0:
                result = {'post': data['post'], 'response': data['response'], 'result': [data['result']]}
                results.append(result)
            else:
                results[lid]['result'].append(data['result'])
    print(f'Read {len(results)} examples in total!')
    for fr in file_readers:
        fr.close()

    vocab, vads = [], []  # load the VAD lexicon
    with open(args.vad_path, 'r', encoding='utf8') as fr:
        for line in fr:
            line = line.strip()
            word = line[: line.find(' ')]
            vad = line[line.find(' ') + 1:].split()
            vad = [float(item) for item in vad]
            vocab.append(word)
            vads.append(vad)
    print(f'Vocabulary loaded: {len(vocab)} words')
    print(f'VAD lexicon loaded: {len(vads)} entries')

    sp = SentenceProcessor(vocab, vads, config.pad_id, config.start_id, config.end_id, config.unk_id)

    if not os.path.exists(args.result_path):  # create the output directory
        os.makedirs(args.result_path)
    fw = open(os.path.join(args.result_path, 'result.txt'), 'w', encoding='utf8')
    fwd = open(os.path.join(args.result_path, 'detail.txt'), 'w', encoding='utf8')

    for result in tqdm(results):  # rerank the sampled responses for each post
        str_post = result['post']  # [len]
        str_response = result['response']  # [len]
        str_results = result['result']  # [sample, len]
        sample_times = len(str_results)

        # 1. fluency score given by the seq2seq model
        id_post, len_post = sp.word2index(str_post)
        id_post = [sp.start_id] + id_post + [sp.end_id]
        id_posts = [id_post for _ in range(sample_times)]  # [sample, len]
        len_posts = [len_post for _ in range(sample_times)]  # [sample]

        id_results, len_results = [], []
        for str_result in str_results:
            id_result, len_result = sp.word2index(str_result)
            id_results.append(id_result)
            len_results.append(len_result)

        len_posts = [l+2 for l in len_posts]  # add 2 for the start and end tokens
        len_results = [l+2 for l in len_results]

        max_len_results = max(len_results)
        id_results = [sp.pad_sentence(id_result, max_len_results) for id_result in id_results]  # pad to the same length

        feed_data = {'posts': id_posts, 'responses': id_results, 'len_posts': len_posts, 'len_responses': len_results}
        feed_data = prepare_feed_data(feed_data)
        output_vocab = seq2seq(feed_data, gpu=args.gpu)  # [sample, len_decoder, num_vocab]

        masks = feed_data['masks']  # [sample, len_decoder]
        labels = feed_data['responses'][:, 1:]  # [sample, len_decoder]
        token_per_batch = masks.sum(1)
        nll_loss = F.nll_loss(output_vocab.reshape(-1, config.num_vocab).clamp_min(1e-12).log(),
                              labels.reshape(-1), reduction='none') * masks.reshape(-1)
        nll_loss = nll_loss.reshape(sample_times, -1).sum(1)  # [sample]
        ppl = (nll_loss / token_per_batch.clamp_min(1e-12)).exp().cpu().detach().numpy()  # [sample]
        score_ppl = (ppl - ppl.min()) / (ppl.max() - ppl.min())  # [sample]

        # semantic similarity score
        embed_posts = torch.cat([seq2seq.embedding(feed_data['posts']), seq2seq.affect_embedding(feed_data['posts'])], 2)
        embed_responses = torch.cat([seq2seq.embedding(feed_data['responses']), seq2seq.affect_embedding(feed_data['responses'])], 2)
        embed_posts = embed_posts.sum(1) / feed_data['len_posts'].float().unsqueeze(1).clamp_min(1e-12)  # [sample, 303]
        embed_responses = embed_responses.sum(1) / feed_data['len_responses'].float().unsqueeze(1).clamp_min(1e-12)
        score_cos = torch.cosine_similarity(embed_posts, embed_responses, 1).cpu().detach().numpy()  # [sample]
        score_cos = (score_cos - score_cos.min()) / (score_cos.max() - score_cos.min())

        # 2. VAD reward score
        vad_posts = np.array([sp.index2vad(id_post) for id_post in id_posts])[:, 1:]  # [sample, len, 3]
        vad_results = np.array([sp.index2vad(id_result) for id_result in id_results])[:, 1:]

        neutral_posts = np.tile(np.array([0.5, 0.5, 0.5]).reshape(1, 1, -1), (sample_times, vad_posts.shape[1], 1))
        neutral_results = np.tile(np.array([0.5, 0.5, 0.5]).reshape(1, 1, -1), (sample_times, vad_results.shape[1], 1))

        posts_mask = 1 - (vad_posts == neutral_posts).astype(float).prod(2)  # [sample, len]
        affect_posts = (vad_posts * np.expand_dims(posts_mask, 2)).sum(1) / posts_mask.sum(1).clip(1e-12).reshape(sample_times, 1)
        results_mask = 1 - (vad_results == neutral_results).astype(float).prod(2)  # [sample, len]
        affect_results = (vad_results * np.expand_dims(results_mask, 2)).sum(1) / results_mask.sum(1).clip(1e-12).reshape(sample_times, 1)

        post_v = affect_posts[:, 0]  # batch
        post_a = affect_posts[:, 1]
        post_d = affect_posts[:, 2]
        result_v = affect_results[:, 0]
        result_a = affect_results[:, 1]
        result_d = affect_results[:, 2]

        score_v = 1 - np.abs(post_v - result_v)  # [0, 1]
        score_a = np.abs(post_a - result_a)
        score_d = np.abs(post_d - result_d)
        score_vad = score_v + score_a + score_d
        baseline_score_vad = score_vad.mean()
        score_vad = score_vad - baseline_score_vad
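        # NOTE: this baseline shift cancels out under the min-max normalization below.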
        score_vad = (score_vad - score_vad.min()) / (score_vad.max() - score_vad.min())

        # 3. affect score
        # score_af = ((vad_results - neutral_results) ** 2).sum(2) ** 0.5  # [sample, len]
        # token_per_batch = token_per_batch.cpu().detach().numpy() - 1  # [sample]
        # score_af = score_af.sum(1) / token_per_batch.clip(1e-12)
        # score_af = (score_af - score_af.min()) / (score_af.max() - score_af.min())

        # 4. sentence length
        # score_len = np.array([len(str_result) for str_result in str_results])  # [sample]
        # score_len = (score_len - score_len.min()) / (score_len.max() - score_len.min())

        # Weighted combination of the fluency, VAD and semantic scores; the
        # highest-scoring candidate is selected.
        score = 0.1*score_ppl + 0.4*score_vad + 0.5*score_cos
        output_id = score.argmax()

        output = {'post': str_post, 'response': str_response, 'result': str_results[output_id]}
        fw.write(json.dumps(output, ensure_ascii=False) + '\n')

        fwd.write('post: {}\n'.format(' '.join(str_post)))
        fwd.write('chosen response: {}\n'.format(' '.join(str_results[output_id])))
        fwd.write('response: {}\n'.format(' '.join(str_response)))
        for idx, str_result in enumerate(str_results):
            fwd.write('candidate{}: {} (t:{:.2f} p:{:.2f} v:{:.2f} c:{:.2f})\n'
                      .format(idx, ' '.join(str_result), score[idx], 0.1*score_ppl[idx], 0.4*score_vad[idx],
                              0.5*score_cos[idx]))
        fwd.write('\n')
    fw.close()
    fwd.close()
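
# For reference, a minimal sketch of the argument parsing that main() assumes.
# Only the flag names are taken from the `args` attributes used above; the
# defaults and help strings are illustrative assumptions, not original values.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--seq2seq_path', type=str, default='', help='path to a trained seq2seq checkpoint')
parser.add_argument('--samples_path', type=str, default='', help='directory containing the sampled result files')
parser.add_argument('--vad_path', type=str, default='', help='path to the VAD lexicon (one "word v a d" entry per line)')
parser.add_argument('--result_path', type=str, default='', help='output directory for result.txt and detail.txt')
parser.add_argument('--gpu', action='store_true', help='run the seq2seq model on CUDA')
args = parser.parse_args()

if __name__ == '__main__':
    main()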
Example #2
import os

import torch
from torch.optim.lr_scheduler import LambdaLR

# NOTE: `GroundedScanDataset`, `Model`, `evaluate` and `log_parameters` are
# assumed to come from the surrounding project; `logger` and `use_cuda` are
# assumed module-level globals (see the sketch after this function).


def train(data_path: str,
          data_directory: str,
          generate_vocabularies: bool,
          input_vocab_path: str,
          target_vocab_path: str,
          embedding_dimension: int,
          num_encoder_layers: int,
          encoder_dropout_p: float,
          encoder_bidirectional: bool,
          training_batch_size: int,
          test_batch_size: int,
          max_decoding_steps: int,
          num_decoder_layers: int,
          decoder_dropout_p: float,
          cnn_kernel_size: int,
          cnn_dropout_p: float,
          cnn_hidden_num_channels: int,
          simple_situation_representation: bool,
          decoder_hidden_size: int,
          encoder_hidden_size: int,
          learning_rate: float,
          adam_beta_1: float,
          adam_beta_2: float,
          lr_decay: float,
          lr_decay_steps: int,
          resume_from_file: str,
          max_training_iterations: int,
          output_directory: str,
          print_every: int,
          evaluate_every: int,
          conditional_attention: bool,
          auxiliary_task: bool,
          weight_target_loss: float,
          attention_type: str,
          max_training_examples=None,
          seed=42,
          **kwargs):
    device = torch.device(type='cuda') if use_cuda else torch.device(
        type='cpu')
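    # Snapshot the call's arguments (plus `device`) so they can be forwarded
    # to the Model constructor via **cfg below.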
    cfg = locals().copy()

    torch.manual_seed(seed)

    logger.info("Loading Training set...")
    training_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="train",
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=generate_vocabularies)
    training_set.read_dataset(
        max_examples=max_training_examples,
        simple_situation_representation=simple_situation_representation)
    logger.info("Done Loading Training set.")
    logger.info("  Loaded {} training examples.".format(
        training_set.num_examples))
    logger.info("  Input vocabulary size training set: {}".format(
        training_set.input_vocabulary_size))
    logger.info("  Most common input words: {}".format(
        training_set.input_vocabulary.most_common(5)))
    logger.info("  Output vocabulary size training set: {}".format(
        training_set.target_vocabulary_size))
    logger.info("  Most common target words: {}".format(
        training_set.target_vocabulary.most_common(5)))

    if generate_vocabularies:
        training_set.save_vocabularies(input_vocab_path, target_vocab_path)
        logger.info(
            "Saved vocabularies to {} for input and {} for target.".format(
                input_vocab_path, target_vocab_path))

    logger.info("Loading Test set...")
    test_set = GroundedScanDataset(
        data_path,
        data_directory,
        split="test",  # TODO: use dev set here
        input_vocabulary_file=input_vocab_path,
        target_vocabulary_file=target_vocab_path,
        generate_vocabulary=False)
    test_set.read_dataset(
        max_examples=None,
        simple_situation_representation=simple_situation_representation)

    # Shuffle the test set to make sure that if we only evaluate max_testing_examples we get a random part of the set.
    test_set.shuffle_data()
    logger.info("Done Loading Test set.")

    model = Model(input_vocabulary_size=training_set.input_vocabulary_size,
                  target_vocabulary_size=training_set.target_vocabulary_size,
                  num_cnn_channels=training_set.image_channels,
                  input_padding_idx=training_set.input_vocabulary.pad_idx,
                  target_pad_idx=training_set.target_vocabulary.pad_idx,
                  target_eos_idx=training_set.target_vocabulary.eos_idx,
                  **cfg)
    model = model.cuda() if use_cuda else model
    log_parameters(model)
    trainable_parameters = [
        parameter for parameter in model.parameters()
        if parameter.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_parameters,
                                 lr=learning_rate,
                                 betas=(adam_beta_1, adam_beta_2))
    scheduler = LambdaLR(optimizer,
                         lr_lambda=lambda t: lr_decay**(t / lr_decay_steps))
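    # LambdaLR scales the base learning rate by lr_decay ** (t / lr_decay_steps),
    # i.e. a smooth exponential decay over optimizer steps.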

    # Load model and vocabularies if resuming.
    start_iteration = 1
    best_iteration = 1
    best_accuracy = 0
    best_exact_match = 0
    best_loss = float('inf')
    if resume_from_file:
        assert os.path.isfile(
            resume_from_file), "No checkpoint found at {}".format(
                resume_from_file)
        logger.info(
            "Loading checkpoint from file at '{}'".format(resume_from_file))
        optimizer_state_dict = model.load_model(resume_from_file)
        optimizer.load_state_dict(optimizer_state_dict)
        start_iteration = model.trained_iterations
        logger.info("Loaded checkpoint '{}' (iter {})".format(
            resume_from_file, start_iteration))

    logger.info("Training starts..")
    training_iteration = start_iteration
    while training_iteration < max_training_iterations:

        # Shuffle the dataset and loop over it.
        training_set.shuffle_data()
        for (input_batch, input_lengths, _, situation_batch, _, target_batch,
             target_lengths, agent_positions,
             target_positions) in training_set.get_data_iterator(
                 batch_size=training_batch_size):
            is_best = False
            model.train()

            # Forward pass.
            target_scores, target_position_scores = model(
                commands_input=input_batch,
                commands_lengths=input_lengths,
                situations_input=situation_batch,
                target_batch=target_batch,
                target_lengths=target_lengths)
            loss = model.get_loss(target_scores, target_batch)
            if auxiliary_task:
                target_loss = model.get_auxiliary_loss(target_position_scores,
                                                       target_positions)
            else:
                target_loss = 0
            loss += weight_target_loss * target_loss

            # Backward pass and update model parameters.
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            model.update_state(is_best=is_best)

            # Print current metrics.
            if training_iteration % print_every == 0:
                accuracy, exact_match = model.get_metrics(
                    target_scores, target_batch)
                if auxiliary_task:
                    auxiliary_accuracy_target = model.get_auxiliary_accuracy(
                        target_position_scores, target_positions)
                else:
                    auxiliary_accuracy_target = 0.
                learning_rate = scheduler.get_last_lr()[0]
                logger.info(
                    "Iteration %08d, loss %8.4f, accuracy %5.2f, exact match %5.2f, learning_rate %.5f,"
                    " aux. accuracy target pos %5.2f" %
                    (training_iteration, loss, accuracy, exact_match,
                     learning_rate, auxiliary_accuracy_target))

            # Evaluate on test set.
            if training_iteration % evaluate_every == 0:
                with torch.no_grad():
                    model.eval()
                    logger.info("Evaluating..")
                    accuracy, exact_match, target_accuracy = evaluate(
                        test_set.get_data_iterator(batch_size=1),
                        model=model,
                        max_decoding_steps=max_decoding_steps,
                        pad_idx=test_set.target_vocabulary.pad_idx,
                        sos_idx=test_set.target_vocabulary.sos_idx,
                        eos_idx=test_set.target_vocabulary.eos_idx,
                        max_examples_to_evaluate=kwargs["max_testing_examples"]
                    )
                    logger.info(
                        "  Evaluation Accuracy: %5.2f Exact Match: %5.2f "
                        " Target Accuracy: %5.2f" %
                        (accuracy, exact_match, target_accuracy))
                    if exact_match > best_exact_match:
                        is_best = True
                        best_accuracy = accuracy
                        best_exact_match = exact_match
                        model.update_state(accuracy=accuracy,
                                           exact_match=exact_match,
                                           is_best=is_best)
                    file_name = "checkpoint.pth.tar".format(
                        str(training_iteration))
                    if is_best:
                        model.save_checkpoint(
                            file_name=file_name,
                            is_best=is_best,
                            optimizer_state_dict=optimizer.state_dict())

            training_iteration += 1
            if training_iteration > max_training_iterations:
                break
    logger.info("Finished training.")