예제 #1
0
def main():
	"""Build the MAML graph, optionally restore a checkpoint, then train or test.

	All configuration is read from the module-level ``FLAGS`` object. The
	function:

	1. builds the data generator and the MAML model (with a meta batch size
	   of 1 while the graph is constructed for testing),
	2. derives the experiment string that names the checkpoint directory,
	3. restores the latest (or an explicitly requested) checkpoint when
	   resuming or testing, and
	4. dispatches to ``meta_train`` or ``meta_test``.

	Side effects: mutates ``FLAGS.meta_batch_size``, ``FLAGS.meta_train_k_shot``
	and ``FLAGS.meta_train_inner_update_lr``, and opens an interactive
	TensorFlow session.
	"""
	testing = not FLAGS.meta_train

	if testing:
		orig_meta_batch_size = FLAGS.meta_batch_size
		# always use meta batch size of 1 when testing.
		FLAGS.meta_batch_size = 1

	# call data_generator and get data with FLAGS.k_shot*2 samples per class
	data_generator = DataGenerator(FLAGS.n_way, FLAGS.k_shot*2, FLAGS.n_way, FLAGS.k_shot*2, config={'data_folder': FLAGS.data_path})

	# set up MAML model
	dim_output = data_generator.dim_output
	dim_input = data_generator.dim_input
	meta_test_num_inner_updates = FLAGS.meta_test_num_inner_updates
	model = MAML(dim_input, dim_output,
		meta_test_num_inner_updates=meta_test_num_inner_updates,
		learn_inner_lr=FLAGS.learn_inner_update_lr)
	model.construct_model(prefix='maml')
	model.summ_op = tf.summary.merge_all()

	saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

	tf_config = tf.ConfigProto()
	tf_config.gpu_options.allow_growth=True
	sess = tf.InteractiveSession(config=tf_config)

	if testing:
		# change to original meta batch size when loading model, because the
		# checkpoint directory name below embeds the training-time value.
		FLAGS.meta_batch_size = orig_meta_batch_size

	# -1 means "not set": fall back to the corresponding test-time value.
	if FLAGS.meta_train_k_shot == -1:
		FLAGS.meta_train_k_shot = FLAGS.k_shot
	if FLAGS.meta_train_inner_update_lr == -1:
		FLAGS.meta_train_inner_update_lr = FLAGS.inner_update_lr

	exp_string = 'cls_'+str(FLAGS.n_way)+'.mbs_'+str(FLAGS.meta_batch_size) + '.k_shot_' + str(FLAGS.meta_train_k_shot) + '.inner_numstep' + str(FLAGS.num_inner_updates) + '.inner_updatelr' + str(FLAGS.meta_train_inner_update_lr)
	if FLAGS.learn_inner_update_lr:
		exp_string += ".learn_inner_lr"

	resume_itr = 0
	model_file = None

	tf.global_variables_initializer().run()

	if FLAGS.resume or testing:
		checkpoint_dir = FLAGS.logdir + '/' + exp_string
		# latest_checkpoint returns None when the directory holds no checkpoint.
		model_file = tf.train.latest_checkpoint(checkpoint_dir)
		if model_file and FLAGS.meta_test_iter > 0:
			# Swap the iteration suffix for the requested one. rindex targets
			# the checkpoint file name rather than an earlier 'model' that may
			# appear in the log directory path.
			model_file = model_file[:model_file.rindex('model')] + 'model' + str(FLAGS.meta_test_iter)
		if model_file:
			ind1 = model_file.rindex('model')
			# The text after 'model' is parsed as the saved iteration count.
			resume_itr = int(model_file[ind1+5:])
			print("Restoring model weights from " + model_file)
			saver.restore(sess, model_file)
		else:
			# Keep going with freshly initialised weights, but say so.
			print("No checkpoint found in " + checkpoint_dir)

	if FLAGS.meta_train:
		meta_train(model, saver, sess, exp_string, data_generator, resume_itr)
	else:
		FLAGS.meta_batch_size = 1
		meta_test(model, saver, sess, exp_string, data_generator, meta_test_num_inner_updates)
예제 #2
0
    ner_val_episodes, _ = utils.generate_ner_episodes(
        dir=ner_val_path,
        labels_file=labels_test,
        n_episodes=config['num_val_episodes']['ner'],
        n_support_examples=config['num_shots']['ner'],
        n_query_examples=config['num_test_samples']['ner'],
        task='ner',
        meta_train=False)

    train_episodes.extend(ner_train_episodes)
    val_episodes.extend(ner_val_episodes)
    logger.info('Finished generating episodes for NER')

    # Initialize meta learner
    if config['meta_learner'] == 'maml':
        meta_learner = MAML(config)
    elif config['meta_learner'] == 'proto_net':
        meta_learner = PrototypicalNetwork(config)
    elif config['meta_learner'] == 'baseline':
        meta_learner = Baseline(config)
    elif config['meta_learner'] == 'majority':
        meta_learner = MajorityClassifier()
    elif config['meta_learner'] == 'nearest_neighbor':
        meta_learner = NearestNeighborClassifier(config)
    else:
        raise NotImplementedError

    # Meta-training
    meta_learner.training(train_episodes, val_episodes)
    logger.info('Meta-learning completed')
예제 #3
0
        elif not config['fomaml'] and not config['proto_maml']:
            model_name = 'MAML'
    else:
        model_name = config['meta_learner']
    model_name += '_' + config['vectors'] + '_' + str(config['num_shots']['wsd'])

    run_dict = {'model_name': model_name, 'output_lr': config['output_lr'], 'learner_lr': config['learner_lr'],
                'meta_lr': config['meta_lr'], 'hidden_size': config['learner_params']['hidden_size'],
                'num_updates': config['num_updates'], 'dropout_ratio': config['learner_params']['dropout_ratio'],
                'meta_weight_decay': config['meta_weight_decay']}
    for i in range(args.n_runs):
        torch.manual_seed(42 + i)

        # Initialize meta learner
        if config['meta_learner'] == 'maml':
            meta_learner = MAML(config)
        elif config['meta_learner'] == 'proto_net':
            meta_learner = PrototypicalNetwork(config)
        elif config['meta_learner'] == 'baseline':
            meta_learner = Baseline(config)
        elif config['meta_learner'] == 'majority':
            meta_learner = MajorityClassifier()
        elif config['meta_learner'] == 'nearest_neighbor':
            meta_learner = NearestNeighborClassifier(config)
        else:
            raise NotImplementedError

        logger.info('Run {}'.format(i + 1))
        val_f1 = meta_learner.training(wsd_train_episodes, wsd_val_episodes)
        test_f1 = meta_learner.testing(wsd_test_episodes)
        run_dict['val_' + str(i+1) + '_f1'] = val_f1