n_support_examples=config['num_shots']['ner'],
        n_query_examples=config['num_test_samples']['ner'],
        task='ner',
        meta_train=False)
    train_episodes.extend(ner_train_episodes)
    val_episodes.extend(ner_val_episodes)
    test_episodes.extend(ner_test_episodes)
    logger.info('Finished generating episodes for NER')

    # Initialize meta learner
    if config['meta_learner'] == 'maml':
        meta_learner = MAML(config)
    elif config['meta_learner'] == 'proto_net':
        meta_learner = PrototypicalNetwork(config)
    elif config['meta_learner'] == 'baseline':
        meta_learner = Baseline(config)
    elif config['meta_learner'] == 'majority':
        meta_learner = MajorityClassifier()
    elif config['meta_learner'] == 'nearest_neighbor':
        meta_learner = NearestNeighborClassifier(config)
    else:
        raise NotImplementedError

    # Meta-training
    meta_learner.training(train_episodes, val_episodes)
    logger.info('Meta-learning completed')

    # Meta-testing
    meta_learner.testing(test_episodes)
    logger.info('Meta-testing completed')
示例#2
0
    # Initialize meta learner
    if config['meta_learner'] == 'maml':
        meta_learner = MAML(config)
    elif config['meta_learner'] == 'proto_net':
        meta_learner = PrototypicalNetwork(config)
    elif config['meta_learner'] == 'baseline':
        meta_learner = Baseline(config)
    elif config['meta_learner'] == 'majority':
        meta_learner = MajorityClassifier()
    elif config['meta_learner'] == 'nearest_neighbor':
        meta_learner = NearestNeighborClassifier(config)
    else:
        raise NotImplementedError

    # Meta-training
    meta_learner.training(train_episodes, val_episodes)
    logger.info('Meta-learning completed')

    # Meta-testing
    for _ in trange(5):
        test_episodes, label_map = utils.generate_ner_episodes(
            dir=ner_test_path,
            labels_file=labels_test,
            n_episodes=config['num_test_episodes']['ner'],
            n_support_examples=config['num_shots']['ner'],
            n_query_examples=config['num_test_samples']['ner'],
            task='ner',
            meta_train=False)
        meta_learner.testing(test_episodes, label_map)
    logger.info('Meta-testing completed')
示例#3
0
        if config['meta_learner'] == 'maml':
            meta_learner = MAML(config)
        elif config['meta_learner'] == 'proto_net':
            meta_learner = PrototypicalNetwork(config)
        elif config['meta_learner'] == 'baseline':
            meta_learner = Baseline(config)
        elif config['meta_learner'] == 'majority':
            meta_learner = MajorityClassifier()
        elif config['meta_learner'] == 'nearest_neighbor':
            meta_learner = NearestNeighborClassifier(config)
        else:
            raise NotImplementedError

        logger.info('Run {}'.format(i + 1))
        val_f1 = meta_learner.training(wsd_train_episodes, wsd_val_episodes)
        test_f1 = meta_learner.testing(wsd_test_episodes)
        run_dict['val_' + str(i+1) + '_f1'] = val_f1
        run_dict['test_' + str(i+1) + '_f1'] = test_f1
        val_f1s.append(val_f1)
        test_f1s.append(test_f1)
    avg_val_f1 = np.mean(val_f1s)
    avg_test_f1 = np.mean(test_f1s)
    std_test_f1 = np.std(test_f1s)
    run_dict['avg_val_f1'] = avg_val_f1
    run_dict['avg_test_f1'] = avg_test_f1
    run_dict['std_test_f1'] = std_test_f1
    logger.info('Got average validation F1: {}'.format(avg_val_f1))
    logger.info('Got average test F1: {}'.format(avg_test_f1))

    results_columns = ['model_name', 'output_lr', 'learner_lr', 'meta_lr', 'hidden_size', 'num_updates', 'dropout_ratio', 'meta_weight_decay'] \
                      + ['val_' + str(k) + '_f1' for k in range(1, args.n_runs + 1)] \