def train(self):
    """Run the training loop with per-epoch evaluation and early stopping.

    Each epoch trains on mini-batches of ``self.data_train``. ``self.drop_out_batch``
    is applied to a copy of every batch. The epoch then reports accuracy and losses
    on the train and dev sets, and prints detailed accuracy on ``self.data_test``
    computed by ``evaluate_advanced``.

    A checkpoint is saved to ``self.model_folder`` whenever dev accuracy improves.
    Training stops early once more than ``self.config['early_stopping_threshold']``
    consecutive epochs pass without improvement.
    """

    def safe_ratio(numerator, denominator):
        # Empty subsets (e.g. no OOD turns) report 0 instead of raising ZeroDivisionError.
        return numerator / denominator if denominator != 0 else 0

    print('\n:: training started\n')
    epochs = self.config['epochs']
    best_dev_accuracy = 0.0
    epochs_without_improvement = 0
    # Learning rate from the most recent train_step (second element of its return value).
    # NaN until a step has run, so the progress line still prints if an epoch has no batches.
    lr = float('nan')
    for j in range(epochs):
        for batch in batch_generator(self.data_train, self.batch_size):
            # Work on copies so in-place changes made by drop_out_batch cannot leak
            # back into the arrays held in self.data_train.
            batch_copy = [np.copy(elem) for elem in batch]
            dropped_out_batch = self.drop_out_batch(batch_copy)
            _, lr = self.net.train_step(*dropped_out_batch)

        # evaluate every epoch
        train_accuracy, train_loss_dict = evaluate(self.net, self.data_train)
        train_loss_report = ' '.join(['{}: {:.3f}'.format(key, value) for key, value in train_loss_dict.items()])
        dev_accuracy, dev_loss_dict = evaluate(self.net, self.data_dev, runs_number=3)
        dev_loss_report = ' '.join(['{}: {:.3f}'.format(key, value) for key, value in dev_loss_dict.items()])
        print(':: {}@lr={:.5f} || trn accuracy {:.3f} {} || dev accuracy {:.3f} {}'.format(j + 1, lr, train_accuracy, train_loss_report, dev_accuracy, dev_loss_report))

        eval_stats_noisy = evaluate_advanced(self.net,
                                             self.data_test,
                                             self.action_templates,
                                             BABI_CONFIG['backoff_utterance'].lower(),
                                             post_ood_turns=self.post_ood_turns_noisy,
                                             runs_number=1)
        print('\n\n')
        print('Noisy dataset: {} turns overall, {} turns after the first OOD'.format(eval_stats_noisy['total_turns'],
                                                                                     eval_stats_noisy['total_turns_after_ood']))
        print('Accuracy:')
        accuracy = safe_ratio(eval_stats_noisy['correct_turns'], eval_stats_noisy['total_turns'])
        accuracy_after_ood = safe_ratio(eval_stats_noisy['correct_turns_after_ood'],
                                        eval_stats_noisy['total_turns_after_ood'])
        accuracy_post_ood = safe_ratio(eval_stats_noisy['correct_post_ood_turns'],
                                       eval_stats_noisy['total_post_ood_turns'])
        accuracy_ood = safe_ratio(eval_stats_noisy['correct_ood_turns'], eval_stats_noisy['total_ood_turns'])
        print('overall: {:.3f}; after first OOD: {:.3f}, directly post-OOD: {:.3f}; OOD: {:.3f}'.format(accuracy,
                                                                                                        accuracy_after_ood,
                                                                                                        accuracy_post_ood,
                                                                                                        accuracy_ood))

        if best_dev_accuracy < dev_accuracy:
            print('New best dev accuracy. Saving checkpoint')
            self.net.save(self.model_folder)
            best_dev_accuracy = dev_accuracy
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if self.config['early_stopping_threshold'] < epochs_without_improvement:
            # j is 0-based; j + 1 epochs have completed at this point.
            print('Finished after {} epochs due to early stopping'.format(j + 1))
            break
Ejemplo n.º 2
0
def main(in_clean_dataset_folder, in_noisy_dataset_folder, in_model_folder, in_mode, in_runs_number):
    """Restore a trained VariationalHierarchicalLSTMv3 and print its accuracy on a bAbI task-6 test set.

    Args:
        in_clean_dataset_folder: folder containing the clean 'dialog-babi-task6-dstc2-tst.txt'.
        in_noisy_dataset_folder: folder containing the noisy variant of 'dialog-babi-task6-dstc2-tst.txt'.
        in_model_folder: folder to load the vocabulary, KB, action templates, config and weights from.
        in_mode: which evaluation to run:
            - 'clean' evaluates the clean data.
            - 'noisy' evaluates the noisy data.
            - 'noisy_ignore_ood' evaluates the noisy data with ``ignore_ood_accuracy=True``.
        in_runs_number: number of evaluation runs passed to ``evaluate_advanced``.

    Raises:
        ValueError: if ``in_mode`` is not a supported mode. This is checked before
            any model or data is loaded.
    """
    supported_modes = ('clean', 'noisy', 'noisy_ignore_ood')
    if in_mode not in supported_modes:
        raise ValueError('Unknown mode {!r}; expected one of: {}'.format(in_mode, ', '.join(supported_modes)))

    def safe_ratio(numerator, denominator):
        # Empty subsets (e.g. no post-OOD turns) report 0 instead of raising ZeroDivisionError.
        return numerator / denominator if denominator != 0 else 0

    rev_vocab, kb, action_templates, config = load_model(in_model_folder)
    clean_dialogs, clean_indices = read_dialogs(os.path.join(in_clean_dataset_folder, 'dialog-babi-task6-dstc2-tst.txt'),
                                                with_indices=True)
    noisy_dialogs, noisy_indices = read_dialogs(os.path.join(in_noisy_dataset_folder, 'dialog-babi-task6-dstc2-tst.txt'),
                                                with_indices=True)

    backoff_utterance = BABI_CONFIG['backoff_utterance'].lower()
    # max_input_length becomes the length (end - start + 1) of the longest noisy dialog.
    config['max_input_length'] = max(item['end'] - item['start'] + 1 for item in noisy_indices)
    post_ood_turns_clean, post_ood_turns_noisy = mark_post_ood_turns(noisy_dialogs, backoff_utterance)

    et = EntityTracker(kb)
    at = ActionTracker(None, et)
    at.set_action_templates(action_templates)

    vocab = {word: idx for idx, word in enumerate(rev_vocab)}
    data_clean = make_dataset_for_vhcn_v2(clean_dialogs, clean_indices, vocab, et, at, **config)
    data_noisy = make_dataset_for_vhcn_v2(noisy_dialogs, noisy_indices, vocab, et, at, **config)

    # Items 2 and 3 of the dataset tuple are used as context features and action masks;
    # their last dimension sizes the network.
    context_features_clean, action_masks_clean = data_clean[2:4]
    net = VariationalHierarchicalLSTMv3(rev_vocab, config, context_features_clean.shape[-1], action_masks_clean.shape[-1])
    net.restore(in_model_folder)

    if in_mode == 'clean':
        eval_stats_clean = evaluate_advanced(net,
                                             data_clean,
                                             at.action_templates,
                                             backoff_utterance,
                                             post_ood_turns=post_ood_turns_clean,
                                             runs_number=in_runs_number)
        print('Clean dataset: {} turns overall'.format(eval_stats_clean['total_turns']))
        print('Accuracy:')
        accuracy = safe_ratio(eval_stats_clean['correct_turns'], eval_stats_clean['total_turns'])
        accuracy_continuous = safe_ratio(eval_stats_clean['correct_continuous_turns'], eval_stats_clean['total_turns'])
        accuracy_post_ood = safe_ratio(eval_stats_clean['correct_post_ood_turns'],
                                       eval_stats_clean['total_post_ood_turns'])
        print('overall: {:.3f}; continuous: {:.3f}; directly post-OOD: {:.3f}'.format(accuracy, accuracy_continuous, accuracy_post_ood))
        print('Loss : {:.3f}'.format(eval_stats_clean['avg_loss']))
    elif in_mode == 'noisy':
        eval_stats_noisy = evaluate_advanced(net,
                                             data_noisy,
                                             at.action_templates,
                                             backoff_utterance,
                                             post_ood_turns=post_ood_turns_noisy,
                                             runs_number=in_runs_number)
        print('\n\n')
        print('Noisy dataset: {} turns overall'.format(eval_stats_noisy['total_turns']))
        print('Accuracy:')
        accuracy = safe_ratio(eval_stats_noisy['correct_turns'], eval_stats_noisy['total_turns'])
        accuracy_continuous = safe_ratio(eval_stats_noisy['correct_continuous_turns'], eval_stats_noisy['total_turns'])
        accuracy_post_ood = safe_ratio(eval_stats_noisy['correct_post_ood_turns'],
                                       eval_stats_noisy['total_post_ood_turns'])
        accuracy_ood = safe_ratio(eval_stats_noisy['correct_ood_turns'], eval_stats_noisy['total_ood_turns'])
        print('overall: {:.3f}; continuous: {:.3f}; directly post-OOD: {:.3f}; OOD: {:.3f}'.format(accuracy,
                                                                                                   accuracy_continuous,
                                                                                                   accuracy_post_ood,
                                                                                                   accuracy_ood))
        print('Loss : {:.3f}'.format(eval_stats_noisy['avg_loss']))
    else:  # 'noisy_ignore_ood' (validated above)
        eval_stats_no_ood = evaluate_advanced(net,
                                              data_noisy,
                                              at.action_templates,
                                              backoff_utterance,
                                              post_ood_turns=post_ood_turns_noisy,
                                              ignore_ood_accuracy=True,
                                              runs_number=in_runs_number)
        print('Accuracy (OOD turns ignored):')
        accuracy = safe_ratio(eval_stats_no_ood['correct_turns'], eval_stats_no_ood['total_turns'])
        accuracy_after_ood = safe_ratio(eval_stats_no_ood['correct_turns_after_ood'],
                                        eval_stats_no_ood['total_turns_after_ood'])
        accuracy_post_ood = safe_ratio(eval_stats_no_ood['correct_post_ood_turns'],
                                       eval_stats_no_ood['total_post_ood_turns'])
        print('overall: {:.3f}; after first OOD: {:.3f}, directly post-OOD: {:.3f}'.format(accuracy, accuracy_after_ood, accuracy_post_ood))
    def train(self):
        """Train with random-input augmentation, per-epoch evaluation and early stopping.

        For every mini-batch of the training arrays, each labelled turn of each dialog
        may be replaced by a random sample from ``self.random_input[0]``. This happens
        with probability ``config['random_input_prob']`` (default 0.0). A replaced turn
        is relabelled with the UNK action, and so is the next turn's previous-action
        entry.

        After each epoch the method:
        - reports train/dev accuracy and losses;
        - prints detailed test-set accuracy computed by ``evaluate_advanced``;
        - saves a checkpoint to ``self.model_folder`` when dev accuracy improves.

        Training stops early once more than ``config['early_stopping_threshold']``
        consecutive epochs pass without improvement.
        """
        print('\n:: training started\n')
        epochs = self.config['epochs']
        best_dev_accuracy = 0.0
        epochs_without_improvement = 0
        random_input_prob = self.config.get('random_input_prob', 0.0)
        unk_action_id = self.action_templates.index(UNK)

        def safe_ratio(numerator, denominator):
            # Empty subsets (e.g. no OOD turns) report 0 instead of raising ZeroDivisionError.
            return numerator / denominator if denominator != 0 else 0

        def inject_random_inputs(X, prev_action, y):
            # Modifies the batch copies in place, dialog by dialog. A dialog's turn count
            # is the number of non-zero labels in its row of y. This assumes labelled
            # turns come before label-0 padding; TODO confirm.
            for dialog_idx in range(len(y)):
                num_turns = np.count_nonzero(y[dialog_idx])
                for turn_idx in range(num_turns):
                    if np.random.random() < random_input_prob:
                        # self.random_input is only read when a replacement actually happens.
                        random_input_x = self.random_input[0]
                        random_input_idx = np.random.choice(range(random_input_x.shape[0]))
                        X[dialog_idx][turn_idx] = random_input_x[random_input_idx]
                        y[dialog_idx][turn_idx] = unk_action_id
                        if turn_idx + 1 < num_turns:
                            prev_action[dialog_idx][turn_idx + 1] = unk_action_id

        # Learning rate from the most recent train_step (second element of its return value).
        # NaN until a step has run, so the progress line still prints if an epoch has no batches.
        lr = float('nan')
        for j in range(epochs):
            batch_gen = batch_generator([self.X_train, self.context_features_train, self.action_masks_train, self.prev_action_train, self.y_train], self.batch_size)
            for batch in batch_gen:
                # Copies keep the augmentation from altering the stored training arrays.
                X, context_features, action_masks, prev_action, y = [np.copy(elem) for elem in batch]
                inject_random_inputs(X, prev_action, y)
                _, lr = self.net.train_step(X, context_features, action_masks, prev_action, y)

            # evaluate every epoch
            train_accuracy, train_loss_dict = evaluate(self.net, (self.X_train, self.context_features_train, self.action_masks_train, self.prev_action_train, self.y_train))
            train_loss_report = ' '.join(['{}: {:.3f}'.format(key, value) for key, value in train_loss_dict.items()])
            dev_accuracy, dev_loss_dict = evaluate(self.net, (self.X_dev, self.context_features_dev, self.action_masks_dev, self.prev_action_dev, self.y_dev))
            dev_loss_report = ' '.join(['{}: {:.3f}'.format(key, value) for key, value in dev_loss_dict.items()])
            print(':: {}@lr={:.5f} || trn accuracy {:.3f} {} || dev accuracy {:.3f} {}'.format(j + 1, lr, train_accuracy, train_loss_report, dev_accuracy, dev_loss_report))

            eval_stats_noisy = evaluate_advanced(self.net,
                                                 (self.X_test, self.context_features_test, self.action_masks_test, self.prev_action_test, self.y_test),
                                                 self.action_templates,
                                                 BABI_CONFIG['backoff_utterance'].lower(),
                                                 post_ood_turns=self.post_ood_turns_noisy,
                                                 runs_number=1)
            print('\n\n')
            print('Noisy dataset: {} turns overall, {} turns after the first OOD'.format(eval_stats_noisy['total_turns'],
                                                                                         eval_stats_noisy['total_turns_after_ood']))
            print('Accuracy:')
            accuracy = safe_ratio(eval_stats_noisy['correct_turns'], eval_stats_noisy['total_turns'])
            accuracy_after_ood = safe_ratio(eval_stats_noisy['correct_turns_after_ood'],
                                            eval_stats_noisy['total_turns_after_ood'])
            accuracy_post_ood = safe_ratio(eval_stats_noisy['correct_post_ood_turns'],
                                           eval_stats_noisy['total_post_ood_turns'])
            accuracy_ood = safe_ratio(eval_stats_noisy['correct_ood_turns'], eval_stats_noisy['total_ood_turns'])
            print('overall: {:.3f}; after first OOD: {:.3f}, directly post-OOD: {:.3f}; OOD: {:.3f}'.format(accuracy,
                                                                                                            accuracy_after_ood,
                                                                                                            accuracy_post_ood,
                                                                                                            accuracy_ood))

            if best_dev_accuracy < dev_accuracy:
                print('New best dev accuracy. Saving checkpoint')
                self.net.save(self.model_folder)
                best_dev_accuracy = dev_accuracy
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
            if self.config['early_stopping_threshold'] < epochs_without_improvement:
                # j is 0-based; j + 1 epochs have completed at this point.
                print('Finished after {} epochs due to early stopping'.format(j + 1))
                break
Ejemplo n.º 4
0
def main(in_clean_dataset_folder, in_noisy_dataset_folder, in_model_folder,
         in_mode, in_runs_number):
    """Restore a trained model and print its accuracy on the clean or noisy test set.

    Args:
        in_clean_dataset_folder: folder containing 'train.json' and 'test.json'.
        in_noisy_dataset_folder: folder containing 'test_ood.json'.
        in_model_folder: folder to load the vocabulary, KB, action templates,
            config and network weights from.
        in_mode: which evaluation to run:
            - 'clean' evaluates test.json.
            - 'noisy' evaluates test_ood.json.
            - 'noisy_ignore_ood' evaluates test_ood.json with
              ``ignore_ood_accuracy=True``.
        in_runs_number: number of evaluation runs passed to ``evaluate_advanced``.

    Raises:
        ValueError: if ``in_mode`` is not a supported mode. This is checked before
            any model or data is loaded.
    """
    supported_modes = ('clean', 'noisy', 'noisy_ignore_ood')
    if in_mode not in supported_modes:
        raise ValueError('Unknown mode {!r}; expected one of: {}'.format(
            in_mode, ', '.join(supported_modes)))

    def safe_ratio(numerator, denominator):
        # Empty subsets (e.g. no OOD turns) report 0 instead of raising ZeroDivisionError.
        return numerator / denominator if denominator != 0 else 0

    rev_vocab, kb, action_templates, config = load_model(in_model_folder)
    train_json = load_hcn_json(
        os.path.join(in_clean_dataset_folder, 'train.json'))
    test_json = load_hcn_json(
        os.path.join(in_clean_dataset_folder, 'test.json'))
    test_ood_json = load_hcn_json(
        os.path.join(in_noisy_dataset_folder, 'test_ood.json'))

    # max_input_length becomes the turn count of the longest noisy dialog.
    config['max_input_length'] = max(
        len(dialog['turns']) for dialog in test_ood_json['dialogs'])
    post_ood_turns_clean, post_ood_turns_noisy = mark_post_ood_turns(
        test_ood_json)

    et = EntityTracker(kb)

    # Every action id gets weight 1.0 except id 0, which gets 0.0.
    # Presumably id 0 is the padding label; confirm.
    action_weights = defaultdict(lambda: 1.0)
    action_weights[0] = 0.0
    action_weighting = np.vectorize(action_weights.__getitem__)
    vocab = {word: idx for idx, word in enumerate(rev_vocab)}

    # The context-feature vocabulary is built from the training dialogs only.
    ctx_features = [utterance['context_features']
                    for dialog in train_json['dialogs']
                    for utterance in dialog['turns']
                    if 'context_features' in utterance]
    ctx_features_vocab, ctx_features_rev_vocab = make_vocabulary(
        ctx_features, config['max_vocabulary_size'], special_tokens=[])

    # The preparation function is chosen by name from utils.preprocessing via the config.
    data_preparation_function = getattr(utils.preprocessing,
                                        config['data_preparation_function'])
    data_clean = data_preparation_function(test_json, vocab,
                                           ctx_features_vocab, et, **config)
    data_noisy = data_preparation_function(test_ood_json, vocab,
                                           ctx_features_vocab, et, **config)

    # Append per-turn weights derived from the last element of each dataset tuple.
    data_clean_with_weights = (*data_clean, action_weighting(data_clean[-1]))
    data_noisy_with_weights = (*data_noisy, action_weighting(data_noisy[-1]))

    net = getattr(modules, config['model_name'])(vocab, config,
                                                 len(ctx_features_vocab),
                                                 len(action_templates))
    net.restore(in_model_folder)

    if in_mode == 'clean':
        eval_stats_clean = evaluate_advanced(
            net,
            data_clean_with_weights,
            action_templates,
            BABI_CONFIG['backoff_utterance'],
            post_ood_turns=post_ood_turns_clean,
            runs_number=in_runs_number)
        print('Clean dataset: {} turns overall'.format(
            eval_stats_clean['total_turns']))
        print('Accuracy:')
        accuracy = safe_ratio(eval_stats_clean['correct_turns'],
                              eval_stats_clean['total_turns'])
        accuracy_continuous = safe_ratio(
            eval_stats_clean['correct_continuous_turns'],
            eval_stats_clean['total_turns'])
        accuracy_post_ood = safe_ratio(
            eval_stats_clean['correct_post_ood_turns'],
            eval_stats_clean['total_post_ood_turns'])
        ood_f1 = eval_stats_clean['ood_f1']
        print('overall acc: {:.3f}; continuous acc: {:.3f}; '
              'directly post-OOD acc: {:.3f}; OOD F1: {:.3f}'.format(
                  accuracy, accuracy_continuous, accuracy_post_ood, ood_f1))
        print('Loss : {:.3f}'.format(eval_stats_clean['avg_loss']))
    elif in_mode == 'noisy':
        eval_stats_noisy = evaluate_advanced(
            net,
            data_noisy_with_weights,
            action_templates,
            BABI_CONFIG['backoff_utterance'],
            post_ood_turns=post_ood_turns_noisy,
            runs_number=in_runs_number)
        print('\n\n')
        print('Noisy dataset: {} turns overall'.format(
            eval_stats_noisy['total_turns']))
        print('Accuracy:')
        accuracy = safe_ratio(eval_stats_noisy['correct_turns'],
                              eval_stats_noisy['total_turns'])
        accuracy_continuous = safe_ratio(
            eval_stats_noisy['correct_continuous_turns'],
            eval_stats_noisy['total_turns'])
        accuracy_post_ood = safe_ratio(
            eval_stats_noisy['correct_post_ood_turns'],
            eval_stats_noisy['total_post_ood_turns'])
        accuracy_ood = safe_ratio(eval_stats_noisy['correct_ood_turns'],
                                  eval_stats_noisy['total_ood_turns'])
        ood_f1 = eval_stats_noisy['ood_f1']
        print('overall acc: {:.3f}; continuous acc: {:.3f}; '
              'directly post-OOD acc: {:.3f}; OOD acc: {:.3f}; OOD F1: {:.3f}'.
              format(accuracy, accuracy_continuous, accuracy_post_ood,
                     accuracy_ood, ood_f1))
        print('Loss : {:.3f}'.format(eval_stats_noisy['avg_loss']))
    else:  # 'noisy_ignore_ood' (validated above)
        eval_stats_no_ood = evaluate_advanced(
            net,
            data_noisy_with_weights,
            action_templates,
            BABI_CONFIG['backoff_utterance'],
            post_ood_turns=post_ood_turns_noisy,
            ignore_ood_accuracy=True,
            runs_number=in_runs_number)
        print('Accuracy (OOD turns ignored):')
        accuracy = safe_ratio(eval_stats_no_ood['correct_turns'],
                              eval_stats_no_ood['total_turns'])
        accuracy_after_ood = safe_ratio(
            eval_stats_no_ood['correct_turns_after_ood'],
            eval_stats_no_ood['total_turns_after_ood'])
        accuracy_post_ood = safe_ratio(
            eval_stats_no_ood['correct_post_ood_turns'],
            eval_stats_no_ood['total_post_ood_turns'])
        print('overall: {:.3f}; '
              'after first OOD: {:.3f}; '
              'directly post-OOD: {:.3f}'.format(accuracy, accuracy_after_ood,
                                                 accuracy_post_ood))