Example #1
def main():
	argparser = argparse.ArgumentParser()
	argparser.add_argument('-n', help='n way', default=5)
	argparser.add_argument('-k', help='k shot', default=1)
	argparser.add_argument('-b', help='batch size', default=4)
	argparser.add_argument('-l', help='learning rate', default=1e-3)
	args = argparser.parse_args()
	n_way = int(args.n)
	k_shot = int(args.k)
	meta_batchsz = int(args.b)
	lr = float(args.l)

	k_query = 1
	imgsz = 84
	threshold = 0.699 if k_shot == 5 else 0.584  # accuracy threshold for when to test the full version of an episode
	mdl_file = 'ckpt/maml%d%d.mdl' % (n_way, k_shot)
	print('mini-imagenet: %d-way %d-shot lr:%f, threshold:%f' % (n_way, k_shot, lr, threshold))



	device = torch.device('cuda')
	net = MAML(n_way, k_shot, k_query, meta_batchsz=meta_batchsz, K=5, device=device)
	print(net)

	if os.path.exists(mdl_file):
		print('load from checkpoint ...', mdl_file)
		net.load_state_dict(torch.load(mdl_file))
	else:
		print('training from scratch.')

	# whole parameters number
	model_parameters = filter(lambda p: p.requires_grad, net.parameters())
	params = sum([np.prod(p.size()) for p in model_parameters])
	print('Total params:', params)


	for epoch in range(1000):
		# batchsz here means total episode number
		mini = MiniImagenet('/hdd1/liangqu/datasets/miniimagenet/', mode='train', n_way=n_way, k_shot=k_shot, k_query=k_query,
		                    batchsz=10000, resize=imgsz)
		# fetch meta_batchsz num of episode each time
		db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=8, pin_memory=True)

		for step, batch in enumerate(db):

			# 2. train
			support_x = batch[0].to(device)
			support_y = batch[1].to(device)
			query_x = batch[2].to(device)
			query_y = batch[3].to(device)

			accs = net(support_x, support_y, query_x, query_y, training = True)

			if step % 10 == 0:
				print(accs)
Example #2
def main():

    # Constructing training and test graphs
    model = MAML()
    model.construct_model_train()
    model.construct_model_test()

    model.summ_op = tf.summary.merge_all()
    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)
    sess = tf.InteractiveSession()

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    resume_itr = 0
    model_file = None

    if not os.path.exists(FLAGS.logdir):
        os.makedirs(FLAGS.logdir)

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    print('Loading pretrained weights')
    model.load_initial_weights(sess)

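    # Resume from the latest checkpoint under FLAGS.logdir/exp_string when resuming
    # training or when running in test mode.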
    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    filelist_root = '../data/MLDG/'
    domain_dict = {
        1: 'art_painting.txt',
        2: 'cartoon.txt',
        3: 'photo.txt',
        4: 'sketch.txt'
    }
    train_domain_list = [2, 3, 4]
    test_domain_list = [1]

    train_file_list = [
        os.path.join(filelist_root, domain_dict[i]) for i in train_domain_list
    ]
    test_file_list = [
        os.path.join(filelist_root, domain_dict[i]) for i in test_domain_list
    ]
    train(model, saver, sess, exp_string, train_file_list, test_file_list[0],
          resume_itr)
Example #3
def main():
    if FLAGS.train:
        test_num_updates = 5
    else:
        test_num_updates = 10

        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    data_generator = SinusoidDataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    input_tensors = None

    model = MAML(
        stop_grad=FLAGS.stop_grad,
        meta_lr=FLAGS.meta_lr,
        num_updates=FLAGS.num_updates,
        update_lr=FLAGS.update_lr,
        dim_input=dim_input,
        dim_output=dim_output,
        test_num_updates=test_num_updates,
        meta_batch_size=FLAGS.meta_batch_size,
        metatrain_iterations=FLAGS.metatrain_iterations,
        norm=FLAGS.norm,
    )
    model.build(input_tensors=input_tensors, prefix="metatrain")

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    trainer = Trainer(
        model,
        data_generator,
        Path(FLAGS.logdir),
        FLAGS.pretrain_iterations,
        FLAGS.metatrain_iterations,
        FLAGS.meta_batch_size,
        FLAGS.update_batch_size,
        FLAGS.num_updates,
        FLAGS.update_lr,
        stop_grad=FLAGS.stop_grad,
        baseline=FLAGS.baseline,
        is_training=True
    )

    trainer.train()
    trainer.test()
Example #4
def __init__(self,
             module,
             task_map,
             finetune=1,
             fine_optim=None,
             optim=None,
             second_order=False,
             distributed=False,
             world_size=1,
             rank=-1):
    super(MetaTrainWrapper, self).__init__()
    self.module = module
    self.task_map = task_map
    self.finetune = finetune
    self.fine_optim = fine_optim
    self.optim = optim
    self.distributed = distributed
    self.init_distributed(world_size, rank)
    self.meta_module = MAML(self.module,
                            self.finetune,
                            self.fine_optim,
                            self.task_map,
                            second_order=second_order)
    self.train_history = None
    self.train_meter = None
    self.val_history = None
    self.val_meter = None
Example #5
def test_with_maml(dataset, learner, checkpoint, steps, loss_fn):
    print("[*] Testing...")
    model = MAML(learner, steps=steps, loss_function=loss_fn)
    model.to(device)
    if checkpoint:
        model.restore(checkpoint, resume_training=False)
    else:
        print("[!] You are running inference on a randomly initialized model!")
    model.eval(dataset, compute_accuracy=(type(dataset) is OmniglotDataset))
    print("[*] Done!")
Example #6
def main():
    if FLAGS.train:
        test_num_updates = FLAGS.num_updates
    else:
        test_num_updates = 5
    data_generator = DataGenerator()
    data_generator.generate_time_series_batch(train=FLAGS.train)
    model = MAML(data_generator.batch_size, test_num_updates)
    model.construct_model(input_tensors=None, prefix='metatrain_')
    model.summ_op = tf.summary.merge_all()
    saver = loader = tf.train.Saver(
        tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES),
        max_to_keep=10)
    sess = tf.InteractiveSession()

    exp_string = FLAGS.train_csv_file + '.numstep' + str(test_num_updates) + '.updatelr' + str(FLAGS.meta_lr)


    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(
            FLAGS.logdir + '/' + exp_string)
        print(model_file)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index(
                'model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, sess, exp_string, data_generator)
Example #7
def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('-n', help='n way', default=5)
    argparser.add_argument('-k', help='k shot', default=1)
    argparser.add_argument('-b', help='batch size', default=32)
    argparser.add_argument('-l', help='meta learning rate', default=1e-3)
    args = argparser.parse_args()
    n_way = int(args.n)
    k_shot = int(args.k)
    meta_batchsz = int(args.b)
    meta_lr = float(args.l)
    train_lr = 0.4

    k_query = 15
    imgsz = 84
    mdl_file = 'ckpt/omniglot%d%d.mdl' % (n_way, k_shot)
    print('omniglot: %d-way %d-shot meta-lr:%f, train-lr:%f' %
          (n_way, k_shot, meta_lr, train_lr))

    device = torch.device('cuda:0')
    net = MAML(n_way, k_shot, k_query, meta_batchsz, 5, meta_lr, train_lr,
               device)
    print(net)

    # batchsz here means total episode number
    db = OmniglotNShot('omniglot',
                       batchsz=meta_batchsz,
                       n_way=n_way,
                       k_shot=k_shot,
                       k_query=k_query,
                       imgsz=imgsz)

    for step in range(10000000):

        # train
        support_x, support_y, query_x, query_y = db.get_batch('train')
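        # Move the channel axis before the spatial axes and tile the single grayscale
        # channel to 3 channels (assuming get_batch returns [b, setsz, h, w, 1] arrays).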
        support_x = torch.from_numpy(support_x).float().transpose(
            2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1).to(device)
        query_x = torch.from_numpy(query_x).float().transpose(2, 4).transpose(
            3, 4).repeat(1, 1, 3, 1, 1).to(device)
        support_y = torch.from_numpy(support_y).long().to(device)
        query_y = torch.from_numpy(query_y).long().to(device)

        accs = net(support_x, support_y, query_x, query_y, training=True)

        if step % 20 == 0:
            print(step, '\t', accs)

        if step % 1000 == 0:
            # test
            pass
Example #8
def train_ganabi():
    config_file = './config/ganabi.config.gin'
    gin.parse_config_file(config_file)
    config_obj = TrainConfig()
    # config = config_obj.get_config()

    data_generator = DataGenerator(config_obj)
    maml = MAML(config_obj)
    maml.save_gin_config(config_file)
    maml.train_manager(data_generator)
Example #9
def train_with_maml(dataset,
                    learner,
                    save_path: str,
                    steps: int,
                    meta_batch_size: int,
                    iterations: int,
                    checkpoint=None,
                    loss_fn=None):
    print("[*] Training...")
    model = MAML(learner, steps=steps, loss_function=loss_fn)
    model.to(device)
    epoch = 0
    if checkpoint:
        model.restore(checkpoint)
        epoch = checkpoint['epoch']
    model.fit(dataset, iterations, save_path, epoch, 100)
    print("[*] Done!")
    return model
Example #10
def main(args):
    np.random.seed(args.seed)
    dataset = get_dataset(args.dataset, args.K)
    model = MAML(dataset, args.model_type, args.loss_type, dataset.dim_input,
                 dataset.dim_output, args.alpha, args.beta, args.K,
                 args.batch_size, args.is_train, args.num_updates, args.norm)
    if args.is_train:
        model.learn(args.batch_size, dataset, args.max_steps)
    else:
        model.evaluate(dataset,
                       args.test_sample,
                       args.draw,
                       restore_checkpoint=args.restore_checkpoint,
                       restore_dir=args.restore_dir)
Example #11
def main():
    if os.path.exists(JOB_NAME):
        raise AssertionError("Job name already exists")
    else:
        os.mkdir(JOB_NAME)
        f = open(os.path.join(JOB_NAME, "train_params.txt"), 'w')
        f.write("META_LEARNER " + str(META_LEARNER) + '\n')
        f.write("FUNCTION " + str(FUNCTION_TRAIN) + '\n')
        f.write("K_TRAIN " + str(K_TRAIN) + '\n')
        f.write("SGD_STEPS_TRAIN " + str(SGD_STEPS_TRAIN) + '\n')
        f.write("NOISE_PERCENT_TRAIN " + str(NOISE_PERCENT_TRAIN) + '\n')
        f.write("ITERATIONS_TRAIN " + str(ITERATIONS_TRAIN) + '\n')
        f.write("OUTER_LR_TRAIN " + str(OUTER_LR_TRAIN) + '\n')
        f.write("INNER_LR_TRAIN " + str(INNER_LR_TRAIN) + '\n')
        f.write("AVERAGER_SIZE_TRAIN " + str(AVERAGER_SIZE_TRAIN) + '\n')
        f.close()

    model = Net()
    if META_LEARNER == "reptile":
        learning_alg = Reptile(lr_inner=INNER_LR_TRAIN,
                               lr_outer=OUTER_LR_TRAIN,
                               sgd_steps_inner=SGD_STEPS_TRAIN)
    elif META_LEARNER == "maml":
        learning_alg = MAML(lr_inner=INNER_LR_TRAIN,
                            lr_outer=OUTER_LR_TRAIN,
                            sgd_steps_inner=SGD_STEPS_TRAIN)
    else:
        learning_alg = Insect(lr_inner=INNER_LR_TRAIN,
                              lr_outer=OUTER_LR_TRAIN,
                              sgd_steps_inner=SGD_STEPS_TRAIN,
                              averager=AVERAGER_SIZE_TRAIN)
    meta_train_data = DataGenerator(function=FUNCTION_TRAIN,
                                    size=ITERATIONS_TRAIN,
                                    K=K_TRAIN,
                                    noise_percent=NOISE_PERCENT_TRAIN)
    learning_alg.train(model, meta_train_data)

    torch.save(model, os.path.join(JOB_NAME, "trained_model.pth"))
    test(model)
Example #12
def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu
    TOTAL_NUM_AU = 8
    all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']

    if not FLAGS.train:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1
        temp_kshot = FLAGS.update_batch_size
        FLAGS.update_batch_size = 1
    if FLAGS.model.startswith('m2'):
        temp_num_updates = FLAGS.num_updates
        FLAGS.num_updates = 1



    data_generator = DataGenerator()

    dim_output = data_generator.num_classes
    dim_input = data_generator.dim_input

    inputa, inputb, labela, labelb = data_generator.make_data_tensor()
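    # inputa/labela hold the inner-update (support) examples, inputb/labelb the
    # meta-update (query) examples, following the usual MAML data split.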
    metatrain_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}

    model = MAML(dim_input, dim_output)
    model.construct_model(input_tensors=metatrain_input_tensors, prefix='metatrain_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=20)

    sess = tf.InteractiveSession()


    if not FLAGS.train:
        # change to original meta batch size when loading model.
        FLAGS.update_batch_size = temp_kshot
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.model.startswith('m2'):
        FLAGS.num_updates = temp_num_updates

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    print('initial weights: ', sess.run('model/b1:0'))
    print("========================================================================================")

    ################## Test ##################
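    # Helper: restore each per-AU checkpoint in turn and stack the restored last-layer
    # weights (np.hstack) and biases (np.vstack) into combined arrays.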
    def _load_weight_m(trained_model_dir):
        all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']
        if FLAGS.au_idx < TOTAL_NUM_AU: all_au = [all_au[FLAGS.au_idx]]
        w_arr = None
        b_arr = None
        for au in all_au:
            model_file = None
            print('model file dir: ', FLAGS.logdir + '/' + au + '/' + trained_model_dir)
            model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + au + '/' + trained_model_dir)
            print("model_file from ", au, ": ", model_file)
            if (model_file == None):
                print(
                    "############################################################################################")
                print("####################################################################### None for ", au)
                print(
                    "############################################################################################")
            else:
                if FLAGS.test_iter > 0:
                    files = os.listdir(model_file[:model_file.index('model')])
                    if 'model' + str(FLAGS.test_iter) + '.index' in files:
                        model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
                        print("model_file by test_iter > 0: ", model_file)
                    else:
                        print(" >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>", files)
                print("Restoring model weights from " + model_file)

                saver.restore(sess, model_file)
                w = sess.run('model/w1:0')
                b = sess.run('model/b1:0')
                print("updated weights from ckpt: ", b)
                print('----------------------------------------------------------')
                if w_arr is None:
                    w_arr = w
                    b_arr = b
                else:
                    w_arr = np.hstack((w_arr, w))
                    b_arr = np.vstack((b_arr, b))

        return w_arr, b_arr

    def _load_weight_s(sbjt_start_idx):
        batch_size = 10
        # if one model was trained using all AUs, only that single model needs to be loaded
        if FLAGS.model.startswith('s1'):
            three_layers = feature_layer(batch_size, TOTAL_NUM_AU)
            three_layers.loadWeight(FLAGS.vae_model_to_test, FLAGS.au_idx, num_au_for_rm=TOTAL_NUM_AU)
        # if there is a separate model per AU, the per-AU weights have to be stacked
        else:
            three_layers = feature_layer(batch_size, 1)
            all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']
            if FLAGS.au_idx < TOTAL_NUM_AU: all_au = [all_au[FLAGS.au_idx]]
            w_arr = None
            b_arr = None
            for au in all_au:
                if FLAGS.model.startswith('s3'):
                    load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au + '_kshot' + str(
                        FLAGS.update_batch_size) + '_iter100'
                elif FLAGS.model.startswith('s4'):
                    load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au + '_subject' + str(
                        sbjt_start_idx + 1) + '_kshot' + str(
                        FLAGS.update_batch_size) + '_iter10_maml_adad' + str(FLAGS.test_iter)
                else:
                    load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au + '_kshot' + str(
                        FLAGS.update_batch_size) + '_iter200_kshot10_iter10_nobatch_adam_noinit'
                three_layers.loadWeight(load_model_path, au)
                print('=============== Model S loaded from ', load_model_path)
                w = three_layers.model_intensity.layers[-1].get_weights()[0]
                b = three_layers.model_intensity.layers[-1].get_weights()[1]
                print('----------------------------------------------------------')
                if w_arr is None:
                    w_arr = w
                    b_arr = b
                else:
                    w_arr = np.hstack((w_arr, w))
                    b_arr = np.vstack((b_arr, b))

        return w_arr, b_arr



    def _load_weight_m0(trained_model_dir):
        model_file = None
        w, b = None, None  # fall back to None when no checkpoint is found
        print('--------- model file dir: ', FLAGS.logdir + trained_model_dir)
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + trained_model_dir)
        print(">>>> model_file from all_aus: ", model_file)
        if (model_file == None):
            print("####################################################################### None for all_aus")
        else:
            if FLAGS.test_iter > 0:
                files = os.listdir(model_file[:model_file.index('model')])
                if 'model' + str(FLAGS.test_iter) + '.index' in files:
                    model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
                    print(">>>> model_file2: ", model_file)
                else:
                    print(" >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>", files)
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
            w = sess.run('model/w1:0')
            b = sess.run('model/b1:0')
            print("updated weights from ckpt: ", b)
            print('----------------------------------------------------------')
        return w, b

    print("<<<<<<<<<<<< CONCATENATE >>>>>>>>>>>>>>")
    save_path = "./logs/result/"
    y_hat = []
    y_lab = []
    if FLAGS.all_sub_model:  # the model was trained using all subjects
        print('---------------- all sub model ----------------')
        # loading the weights once is enough, since the model does not differ per subject
        if FLAGS.model.startswith('m'):
            trained_model_dir = '/cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
                FLAGS.meta_batch_size) + '.ubs_' + str(
                FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
                FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr)
            if FLAGS.model.startswith('m0'):
                w_arr, b_arr = _load_weight_m0(trained_model_dir)
            else:
                w_arr, b_arr = _load_weight_m(trained_model_dir)  # a separate model is used per AU

        ### test per each subject and concatenate
        for i in range(FLAGS.sbjt_start_idx, FLAGS.sbjt_start_idx + FLAGS.num_test_tasks):
            if FLAGS.model.startswith('s'):
                w_arr, b_arr = _load_weight_s(i)

            result = test_each_subject(w_arr, b_arr, i)
            y_hat.append(result[0])
            y_lab.append(result[1])
            print("y_hat shape:", result[0].shape)
            print("y_lab shape:", result[1].shape)
            print(">> y_hat_all shape:", np.vstack(y_hat).shape)
            print(">> y_lab_all shape:", np.vstack(y_lab).shape)
        print_summary(np.vstack(y_hat), np.vstack(y_lab), log_dir=save_path + "/" + "test.txt")
    else:  # the model was trained per subject: only possible for the VAE and MAML train_test cases, plus the local weight test case
        for subj_idx in range(FLAGS.sbjt_start_idx, FLAGS.sbjt_start_idx + FLAGS.num_test_tasks):
            if FLAGS.model.startswith('s'):
                w_arr, b_arr = _load_weight_s(subj_idx)
            else:
                trained_model_dir = '/sbjt' + str(subj_idx) + '.ubs_' + str(
                    FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
                    FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr)
                w_arr, b_arr = _load_weight_m(trained_model_dir)
            result = test_each_subject(w_arr, b_arr, subj_idx)
            y_hat.append(result[0])
            y_lab.append(result[1])
            print("y_hat shape:", result[0].shape)
            print("y_lab shape:", result[1].shape)
            print(">> y_hat_all shape:", np.vstack(y_hat).shape)
            print(">> y_lab_all shape:", np.vstack(y_lab).shape)
        print_summary(np.vstack(y_hat), np.vstack(y_lab),
                      log_dir=save_path + "/test.txt")

    end_time = datetime.now()
    elapse = end_time - start_time
    print("=======================================================")
    print(">>>>>> elapse time: " + str(elapse))
    print("=======================================================")
Example #13
def train_omniglot():
    config = get_Omniglot_config()
    data_generator = DataGenerator(config)
    maml = MAML(config)
    maml.train_manager(data_generator)
Example #14
def main():
    if not os.path.exists(FLAGS.logdir):
        os.makedirs(FLAGS.logdir, exist_ok=True)

    test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                   FLAGS.meta_batch_size)

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    tf_data_load = True
    num_classes = data_generator.num_classes

    if FLAGS.train:  # only construct training model if needed
        random.seed(5)
        image_tensor, label_tensor = data_generator.make_data_tensor()
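        # Split each task along the example axis: the first num_classes*update_batch_size
        # examples form the inner-update (a) split, the remainder the meta-update (b) split.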
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }

    random.seed(6)
    image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
    inputa = tf.slice(image_tensor, [0, 0, 0],
                      [-1, num_classes * FLAGS.update_batch_size, -1])
    inputb = tf.slice(image_tensor,
                      [0, num_classes * FLAGS.update_batch_size, 0],
                      [-1, -1, -1])
    labela = tf.slice(label_tensor, [0, 0, 0],
                      [-1, num_classes * FLAGS.update_batch_size, -1])
    labelb = tf.slice(label_tensor,
                      [0, num_classes * FLAGS.update_batch_size, 0],
                      [-1, -1, -1])
    metaval_input_tensors = {
        'inputa': inputa,
        'inputb': inputb,
        'labela': labela,
        'labelb': labelb
    }

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'model_{}'.format(FLAGS.model_num)

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

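    # Run training first, then flip the flag and evaluate with the same session.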
    FLAGS.train = True
    train(model, saver, sess, exp_string, data_generator, resume_itr)
    FLAGS.train = False
    test(model, saver, sess, exp_string, data_generator, test_num_updates)
Example #15
def main():
    if FLAGS.train:
        test_num_updates = 20
    elif FLAGS.from_scratch:
        test_num_updates = 200
    else:
        test_num_updates = 50

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    sess = tf.InteractiveSession()

    if not FLAGS.dataset == 'imagenet':
        data_generator = DataGenerator(FLAGS.inner_update_batch_size_train + FLAGS.outer_update_batch_size,
                                       FLAGS.inner_update_batch_size_val + FLAGS.outer_update_batch_size,
                                       FLAGS.meta_batch_size)
    else:
        data_generator = DataGeneratorImageNet(FLAGS.inner_update_batch_size_train + FLAGS.outer_update_batch_size,
                                               FLAGS.inner_update_batch_size_val + FLAGS.outer_update_batch_size,
                                               FLAGS.meta_batch_size)

    dim_output_train = data_generator.dim_output_train
    dim_output_val = data_generator.dim_output_val
    dim_input = data_generator.dim_input


    tf_data_load = True
    num_classes_train = data_generator.num_classes_train
    num_classes_val = data_generator.num_classes_val

    if FLAGS.train: # only construct training model if needed
        random.seed(5)
        image_tensor, label_tensor = data_generator.make_data_tensor()
        inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes_train*FLAGS.inner_update_batch_size_train, -1])
        inputb = tf.slice(image_tensor, [0,num_classes_train*FLAGS.inner_update_batch_size_train, 0], [-1,-1,-1])
        labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes_train*FLAGS.inner_update_batch_size_train, -1])
        labelb = tf.slice(label_tensor, [0,num_classes_train*FLAGS.inner_update_batch_size_train, 0], [-1,-1,-1])
        input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}

    random.seed(6)
    image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
    inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes_val*FLAGS.inner_update_batch_size_val, -1])
    inputb = tf.slice(image_tensor, [0,num_classes_val*FLAGS.inner_update_batch_size_val, 0], [-1,-1,-1])
    labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes_val*FLAGS.inner_update_batch_size_val, -1])
    labelb = tf.slice(label_tensor, [0,num_classes_val*FLAGS.inner_update_batch_size_val, 0], [-1,-1,-1])
    metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}

    model = MAML(dim_input, dim_output_train, dim_output_val, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    if FLAGS.debug:
        sess = tf_debug.LocalCLIDebugWrapperSession(sess)

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.log_inner_update_batch_size_val == -1:
        FLAGS.log_inner_update_batch_size_val = FLAGS.inner_update_batch_size_val
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

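    # Encode the hyperparameter settings into an experiment string; it doubles as the
    # checkpoint/summary subdirectory name under the log directory.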
    exp_string = ''
    exp_string += '.nu_' + str(FLAGS.num_updates) + '.ilr_' + str(FLAGS.train_update_lr)
    if FLAGS.meta_lr != 0.001:
        exp_string += '.olr_' + str(FLAGS.meta_lr)
    if FLAGS.mt_mode != 'gtgt':
        if FLAGS.partition_algorithm == 'hyperplanes':
            exp_string += '.m_' + str(FLAGS.margin)
        if FLAGS.partition_algorithm == 'kmeans' or FLAGS.partition_algorithm == 'kmodes':
            exp_string += '.k_' + str(FLAGS.num_clusters)
            exp_string += '.p_' + str(FLAGS.num_partitions)
            if FLAGS.scaled_encodings and FLAGS.num_partitions != 1:
                exp_string += '.scaled'
        if FLAGS.mt_mode == 'encenc':
            exp_string += '.ned_' + str(FLAGS.num_encoding_dims)
        elif FLAGS.mt_mode == 'semi':
            exp_string += '.pgtgt_' + str(FLAGS.p_gtgt)
    exp_string += '.mt_' + FLAGS.mt_mode
    exp_string += '.mbs_' + str(FLAGS.meta_batch_size) + \
                  '.nct_' + str(FLAGS.num_classes_train) + \
                  '.iubst_' + str(FLAGS.inner_update_batch_size_train) + \
                  '.iubsv_' + str(FLAGS.log_inner_update_batch_size_val) + \
                  '.oubs' + str(FLAGS.outer_update_batch_size)
    exp_string = exp_string[1:]     # get rid of leading period

    if FLAGS.on_encodings:
        exp_string += '.onenc'
        exp_string += '.nhl_' + str(FLAGS.num_hidden_layers)
    if FLAGS.num_filters != 64:
        exp_string += '.hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += '.maxpool'
    if FLAGS.stop_grad:
        exp_string += '.stopgrad'
    if FLAGS.norm == 'batch_norm':
        exp_string += '.batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += '.layernorm'
    elif FLAGS.norm == 'None':
        exp_string += '.nonorm'
    else:
        print('Norm setting not recognized.')
    if FLAGS.resnet:
        exp_string += '.res{}parts{}'.format(FLAGS.num_res_blocks, FLAGS.num_parts_per_res_block)
    if FLAGS.miniimagenet_only:
        exp_string += '.mini'
    if FLAGS.suffix != '':
        exp_string += '.' + FLAGS.suffix

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()

    print(exp_string)

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(logdir + '/' + exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1+5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
        else:
            print("No checkpoint found")

    if FLAGS.from_scratch:
        exp_string = ''

    if FLAGS.from_scratch and not os.path.isdir(logdir):
        os.makedirs(logdir)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
Example #16
def main():
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5  # During base-testing (and thus meta updating) 5 updates are used
        else:
            test_num_updates = 10  # During meta-testing 10 updates are used
    else:
        if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'cifarfs':
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = 10  # eval on 10 updates during testing
        else:
            test_num_updates = 10  # Omniglot gets 10 updates during training AND testing

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        # DataGenerator(num_samples_per_class, batch_size, config={})
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size)
    else:  # Dealing with a non 'sinusoid' dataset here
        if FLAGS.metatrain_iterations == 0 and (
                FLAGS.datasource == 'miniimagenet'
                or FLAGS.datasource == 'cifarfs'):
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(
                1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'cifarfs':  # TODO - use 15 val examples for imagenet?
                if FLAGS.train:  # TODO: why +15 and *2 --> following Ravi: "15 examples per class were used for evaluating the post-update meta-gradient" = MAML algo 2, line 10 --> see how 5 and 15 are split up in maml.py?
                    # DataGenerator(number_of_images_per_class, number_of_tasks_in_batch)
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size + 15, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
                else:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
            else:  # this is for omniglot
                data_generator = DataGenerator(
                    FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                )  # only use one datapoint for testing to save memory

    dim_output = data_generator.dim_output  # number of classes, e.g. 5 for miniImagenet tasks
    if FLAGS.baseline == 'oracle':  # NOTE - this flag is specific to sinusoid
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input  # np.prod(self.img_size) for images

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot' or FLAGS.datasource == 'cifarfs':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            # meta train : num_total_batches = 200000 (number of tasks, not number of meta-iterations)
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1
                               ])  # slice(tensor, begin, slice_size)
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])  # The extra 15 add here?!
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb
            }

        # meta val: num_total_batches = 600 (number of tasks, not number of meta-iterations)
        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1
                           ])  # slice the training examples here
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(
        dim_input, dim_output, test_num_updates=test_num_updates
    )  # test_num_updates = eval on at least one update for training, 10 testing
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')

    # Op to retrieve summaries?
    model.summ_op = tf.summary.merge_all()

    # keep last 10 copies of trainable variables
    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    # remove the need to explicitly pass this Session object to run ops
    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    # cls = no of classes
    # mbs = meta batch size
    # ubs = update batch size
    # numstep = number of INNER GRADIENT updates
    # updatelr = inner gradient step
    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    # Initialize all variables
    tf.global_variables_initializer().run()
    # starts threads for all queue runners collected in the graph
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
Example #17
def main():
    training = not args.test
    kshot = 1
    kquery = 15
    nway = 5
    meta_batchsz = 4
    K = 5

    # kshot + kquery images per category, nway categories, meta_batchsz tasks.
    db = DataGenerator(nway, kshot, kquery, meta_batchsz, 200000)

    if training:  # only construct training model if needed
        # get the tensor
        # image_tensor: [4, 80, 84*84*3]
        # label_tensor: [4, 80, 5]
        image_tensor, label_tensor = db.make_data_tensor(training=True)

        # NOTICE: the image order in the 80 images should look like this now:
        # [label2, label1, label3, label0, label4, and then repeat by 15 times, namely one task]
        # support_x : [4, 1*5, 84*84*3]
        # query_x   : [4, 15*5, 84*84*3]
        # support_y : [4, 5, 5]
        # query_y   : [4, 15*5, 5]
        support_x = tf.slice(image_tensor, [0, 0, 0], [-1, nway * kshot, -1],
                             name='support_x')
        query_x = tf.slice(image_tensor, [0, nway * kshot, 0], [-1, -1, -1],
                           name='query_x')
        support_y = tf.slice(label_tensor, [0, 0, 0], [-1, nway * kshot, -1],
                             name='support_y')
        query_y = tf.slice(label_tensor, [0, nway * kshot, 0], [-1, -1, -1],
                           name='query_y')

    # construct test tensors.
    image_tensor, label_tensor = db.make_data_tensor(training=False)
    support_x_test = tf.slice(image_tensor, [0, 0, 0], [-1, nway * kshot, -1],
                              name='support_x_test')
    query_x_test = tf.slice(image_tensor, [0, nway * kshot, 0], [-1, -1, -1],
                            name='query_x_test')
    support_y_test = tf.slice(label_tensor, [0, 0, 0], [-1, nway * kshot, -1],
                              name='support_y_test')
    query_y_test = tf.slice(label_tensor, [0, nway * kshot, 0], [-1, -1, -1],
                            name='query_y_test')

    # 1. construct MAML model
    model = MAML(84, 3, 5)

    # construct metatrain_ and metaval_
    if training:
        model.build(support_x,
                    support_y,
                    query_x,
                    query_y,
                    K,
                    meta_batchsz,
                    mode='train')
        model.build(support_x_test,
                    support_y_test,
                    query_x_test,
                    query_y_test,
                    K,
                    meta_batchsz,
                    mode='eval')
    else:
        model.build(support_x_test,
                    support_y_test,
                    query_x_test,
                    query_y_test,
                    K + 5,
                    meta_batchsz,
                    mode='test')
    model.summ_op = tf.summary.merge_all()

    all_vars = filter(lambda x: 'meta_optim' not in x.name,
                      tf.trainable_variables())
    for p in all_vars:
        print(p)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)
    # tf.global_variables() to save moving_mean and moving variance of batch norm
    # tf.trainable_variables()  NOT include moving_mean and moving_variance.
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)

    # initialize, under the interactive session
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if os.path.exists(os.path.join('ckpt', 'checkpoint')):
        # always load the checkpoint, for both train and test.
        model_file = tf.train.latest_checkpoint('ckpt')
        print("Restoring model weights from ", model_file)
        saver.restore(sess, model_file)

    if training:
        train(model, saver, sess)
    else:
        test(model, sess)
Example #18
testset = miniimagenet("data",
                       ways=5,
                       shots=5,
                       test_shots=15,
                       meta_test=True,
                       download=True)
testloader = BatchMetaDataLoader(testset,
                                 batch_size=2,
                                 num_workers=4,
                                 shuffle=True)

# training

epochs = 6000  # with a batch size of 2, the upper limit is 7751 (a DataLoader limitation)
model = MAML().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss().to(device)
model_path = "./model/"
result_path = "./log/train"

trainiter = iter(trainloader)
evaliter = iter(testloader)

train_loss_log = []
train_acc_log = []
test_loss_log = []
test_acc_log = []

for epoch in range(epochs):
    # train
Example #19
    def construct_model(self):
        self.sess = tf.InteractiveSession()
        if FLAGS.train == False:
            orig_meta_batch_size = FLAGS.meta_batch_size
            # always use meta batch size of 1 when testing.
            FLAGS.meta_batch_size = 1

        if FLAGS.datasource in ['sinusoid', 'mixture']:
            data_generator = DataGenerator(
                FLAGS.update_batch_size + FLAGS.update_batch_size_eval,
                FLAGS.meta_batch_size)
        else:
            if FLAGS.metatrain_iterations == 0 and FLAGS.datasource in [
                    'miniimagenet', 'multidataset'
            ]:
                assert FLAGS.meta_batch_size == 1
                assert FLAGS.update_batch_size == 1
                data_generator = DataGenerator(
                    1, FLAGS.meta_batch_size)  # only use one datapoint,
            else:
                if FLAGS.datasource in [
                        'miniimagenet', 'multidataset'
                ]:  # TODO - use 15 val examples for imagenet?
                    if FLAGS.train:
                        data_generator = DataGenerator(
                            FLAGS.update_batch_size + 15, FLAGS.meta_batch_size
                        )  # only use one datapoint for testing to save memory
                    else:
                        data_generator = DataGenerator(
                            FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                        )  # only use one datapoint for testing to save memory
                else:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory

        dim_output = data_generator.dim_output

        dim_input = data_generator.dim_input

        if FLAGS.datasource in ['miniimagenet', 'omniglot', 'multidataset']:
            tf_data_load = True
            num_classes = data_generator.num_classes

            if FLAGS.train:  # only construct training model if needed
                random.seed(5)
                if FLAGS.datasource in ['miniimagenet', 'omniglot']:
                    image_tensor, label_tensor = data_generator.make_data_tensor(
                    )
                elif FLAGS.datasource == 'multidataset':
                    image_tensor, label_tensor = data_generator.make_data_tensor_multidataset(
                        sel_num=self.clusters, train=True)
                inputa = tf.slice(
                    image_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
                inputb = tf.slice(
                    image_tensor,
                    [0, num_classes * FLAGS.update_batch_size, 0],
                    [-1, -1, -1])
                labela = tf.slice(
                    label_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
                labelb = tf.slice(
                    label_tensor,
                    [0, num_classes * FLAGS.update_batch_size, 0],
                    [-1, -1, -1])
                input_tensors = {
                    'inputa': inputa,
                    'inputb': inputb,
                    'labela': labela,
                    'labelb': labelb
                }

            random.seed(6)
            if FLAGS.datasource in ['miniimagenet', 'omniglot']:
                image_tensor, label_tensor = data_generator.make_data_tensor(
                    train=False)
            elif FLAGS.datasource == 'multidataset':
                image_tensor, label_tensor = data_generator.make_data_tensor_multidataset(
                    sel_num=self.clusters, train=False)
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            metaval_input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb
            }
        else:
            tf_data_load = False
            input_tensors = None

        model = MAML(self.sess,
                     dim_input,
                     dim_output,
                     test_num_updates=self.test_num_updates)

        model.cluster_layer_0 = self.clusters

        if FLAGS.train or not tf_data_load:
            model.construct_model(input_tensors=input_tensors,
                                  prefix='metatrain_')
        if tf_data_load:
            model.construct_model(input_tensors=metaval_input_tensors,
                                  prefix='metaval_')
        model.summ_op = tf.summary.merge_all()
        saver = loader = tf.train.Saver(tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES),
                                        max_to_keep=10)

        if FLAGS.train == False:
            # change to original meta batch size when loading model.
            FLAGS.meta_batch_size = orig_meta_batch_size

        if FLAGS.train_update_batch_size == -1:
            FLAGS.train_update_batch_size = FLAGS.update_batch_size
        if FLAGS.train_update_lr == -1:
            FLAGS.train_update_lr = FLAGS.update_lr

        return model, saver, data_generator
Example #20
torch.backends.cudnn.benchmark = True

testset = miniimagenet("data",
                       ways=5,
                       shots=5,
                       test_shots=15,
                       meta_test=True,
                       download=True)
testloader = BatchMetaDataLoader(testset,
                                 batch_size=2,
                                 num_workers=4,
                                 shuffle=True)
evaliter = iter(testloader)

model_path = './model/model.pth'
model = MAML().to(device)
model.load_state_dict(torch.load(model_path))
loss_fn = torch.nn.CrossEntropyLoss().to(device)

test_loss_log = []
test_acc_log = []

for i in range(1000):
    evalbatch = next(evaliter)
    model.eval()
    testloss, testacc = test(model,
                             evalbatch,
                             loss_fn,
                             lr=0.01,
                             train_step=10,
                             device=device)
Example #21
def main():

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        # ME: update_batch_size = 10 (20 samples/task); meta_batch_size = 25 (25 tasks)
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='sinusoid')
    elif FLAGS.datasource == 'ball':
        # ME: update_batch_size = 10 (20 samples/task); meta_batch_size = 25 (25 tasks)
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='ball')
    elif FLAGS.datasource == 'ball_file':
        # ME: update_batch_size = 10 (20 samples/task); meta_batch_size = 25 (25 tasks)
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='ball_file')
    else:  # 'rect_file"
        # ME: update_batch_size = 10 (20 samples/task); meta_batch_size = 25 (25 tasks)
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='rect_file',
                                       rect_truncated=rect_truncated)

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    model.construct_model()

    model.summ_op = tf.summary.merge_all()

    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
                           max_to_keep=10)
    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    exp_string = get_exp_string(model)
    resume_itr = 0

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        # ME: test_num_updates = 10; 10 gradient updates
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
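
main() above reads every hyperparameter from TF1-style FLAGS. The sketch below shows how such flags are typically declared; the flag names mirror the ones referenced above, but the defaults and help strings are illustrative assumptions, not the original repo's values.

# Sketch of the TF1 flag definitions this kind of script assumes (defaults illustrative).
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_bool('train', True, 'train a model or run meta-testing')
tf.app.flags.DEFINE_bool('resume', False, 'resume from the latest checkpoint')
tf.app.flags.DEFINE_string('logdir', 'logs', 'directory for checkpoints and summaries')
tf.app.flags.DEFINE_string('datasource', 'sinusoid', 'sinusoid / ball / ball_file / rect_file')
tf.app.flags.DEFINE_integer('meta_batch_size', 25, 'number of tasks sampled per meta-update')
tf.app.flags.DEFINE_integer('update_batch_size', 10, 'examples per inner update')
tf.app.flags.DEFINE_integer('test_iter', -1, 'checkpoint iteration to load at test time')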
Example #22
0
def run(d):
    #IPython.embed()
    config = d['config']

    ################################################################
    ###################### Load parameters #########################
    ################################################################
    previous_dynamics_model = config["previous_dynamics_model"]
    train_now = d['train_bool']

    ################################################################
    desired_shape_for_rollout = config["testing"]["desired_shape_for_rollout"]
    save_rollout_run_num = config["testing"]["save_rollout_run_num"]
    rollout_save_filename = desired_shape_for_rollout + str(
        save_rollout_run_num)

    num_steps_per_rollout = config["testing"]["num_steps_per_rollout"]
    if (desired_shape_for_rollout == "figure8"):
        num_steps_per_rollout = 400
    elif (desired_shape_for_rollout == "zigzag"):
        num_steps_per_rollout = 150
    ##############################################################

    #settings
    cheaty_training = False
    use_one_hot = False  #True
    use_camera = False  #True
    playback_mode = False

    state_representation = "exclude_x_y"  #["exclude_x_y", "all"]

    # Settings (generally, keep these to default)
    default_addrs = [b'\x00\x01']
    use_pid_mode = True
    slow_pid_mode = True
    visualize_rviz = True  #turning this off can make things go faster
    visualize_True = True
    visualize_False = False
    noise_True = True
    noise_False = False
    make_aggregated_dataset_noisy = True
    make_training_dataset_noisy = True
    perform_forwardsim_for_vis = True
    print_minimal = False
    noiseToSignal = 0
    if (make_training_dataset_noisy):
        noiseToSignal = 0.01

    # Defining datatypes
    tf_datatype = tf.float32
    np_datatype = np.float32

    # Setting motor limits
    left_min = 1200
    right_min = 1200
    left_max = 2000
    right_max = 2000
    if (use_pid_mode):
        if (slow_pid_mode):
            left_min = 2 * math.pow(2, 16) * 0.001
            right_min = 2 * math.pow(2, 16) * 0.001
            left_max = 9 * math.pow(2, 16) * 0.001
            right_max = 9 * math.pow(2, 16) * 0.001
        else:  # this hasn't been tested yet
            left_min = 4 * math.pow(2, 16) * 0.001
            right_min = 4 * math.pow(2, 16) * 0.001
            left_max = 12 * math.pow(2, 16) * 0.001
            right_max = 12 * math.pow(2, 16) * 0.001

    #vars from config

    curr_agg_iter = config['aggregation']['curr_agg_iter']
    save_dir = d['exp_name']
    print("\n\nSAVING EVERYTHING TO: ", save_dir)

    #make directories
    if not os.path.exists(save_dir + '/saved_rollouts'):
        os.makedirs(save_dir + '/saved_rollouts')
    if not os.path.exists(save_dir + '/saved_rollouts/' +
                          rollout_save_filename + '_aggIter' +
                          str(curr_agg_iter)):
        os.makedirs(save_dir + '/saved_rollouts/' + rollout_save_filename +
                    '_aggIter' + str(curr_agg_iter))

    ######################################
    ######## GET TRAINING DATA ###########
    ######################################

    print("\n\nCURR AGGREGATION ITER: ", curr_agg_iter)

    # Training data
    # Random
    dataX = []
    dataX_full = []  # only used for forward-sim debugging
    dataY = []
    dataZ = []

    # Training data
    # MPC
    dataX_onPol = []
    dataX_full_onPol = []
    dataY_onPol = []
    dataZ_onPol = []

    # Validation data
    # Random
    dataX_val = []
    dataX_full_val = []
    dataY_val = []
    dataZ_val = []

    # Validation data
    # MPC
    dataX_val_onPol = []
    dataX_full_val_onPol = []
    dataY_val_onPol = []
    dataZ_val_onPol = []

    training_ratio = config['training']['training_ratio']
    for agg_itr_counter in range(curr_agg_iter + 1):

        #getDataFromDisk should give (tasks, rollouts from that task, each rollout has its points)
        dataX_curr, dataY_curr, dataZ_curr, dataX_curr_full = getDataFromDisk(
            config['experiment_type'],
            use_one_hot,
            use_camera,
            cheaty_training,
            state_representation,
            agg_itr_counter,
            config_training=config['training'])

        if (agg_itr_counter == 1):
            print("*********TRYING TO FIND THE WEIRD ROLLOUT...")
            for rollout in range(len(dataX_curr[2])):
                val = dataX_curr[2][rollout][:, 4]
                if (np.any(val < 0)):
                    dataX_curr[2][rollout] = dataX_curr[2][rollout + 1]
                    dataY_curr[2][rollout] = dataY_curr[2][rollout + 1]
                    dataZ_curr[2][rollout] = dataZ_curr[2][rollout + 1]
                    print("FOUND IT!!!!!!! rollout number ", rollout)

        #random data
        #go from dataX_curr (tasks, rollouts, steps) --> to dataX (tasks, some rollouts, steps) and dataX_val (tasks, some rollouts, steps)
        if (agg_itr_counter == 0):
            for task_num in range(len(dataX_curr)):
                taski_num_rollout = len(dataX_curr[task_num])
                print("task" + str(task_num) + "_num_rollouts: ",
                      taski_num_rollout)

                #for each task, append something like (356, 48, 22) (numrollouts per task, num steps in that rollout, dim)
                dataX.append(dataX_curr[task_num][:int(taski_num_rollout *
                                                       training_ratio)])
                dataX_full.append(
                    dataX_curr_full[task_num][:int(taski_num_rollout *
                                                   training_ratio)])
                dataY.append(dataY_curr[task_num][:int(taski_num_rollout *
                                                       training_ratio)])
                dataZ.append(dataZ_curr[task_num][:int(taski_num_rollout *
                                                       training_ratio)])

                dataX_val.append(dataX_curr[task_num][int(taski_num_rollout *
                                                          training_ratio):])
                dataX_full_val.append(
                    dataX_curr_full[task_num][int(taski_num_rollout *
                                                  training_ratio):])
                dataY_val.append(dataY_curr[task_num][int(taski_num_rollout *
                                                          training_ratio):])
                dataZ_val.append(dataZ_curr[task_num][int(taski_num_rollout *
                                                          training_ratio):])

        #on-policy data
        #go from dataX_curr (tasks, rollouts, steps) --> to dataX_onPol (tasks, some rollouts, steps) and dataX_val_onPol (tasks, some rollouts, steps)
        elif (agg_itr_counter == 1):

            for task_num in range(len(dataX_curr)):
                taski_num_rollout = len(dataX_curr[task_num])
                print("task" + str(task_num) + "_num_rollouts for onpolicy: ",
                      taski_num_rollout)

                dataX_onPol.append(
                    dataX_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])
                dataX_full_onPol.append(
                    dataX_curr_full[task_num][:int(taski_num_rollout *
                                                   training_ratio)])
                dataY_onPol.append(
                    dataY_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])
                dataZ_onPol.append(
                    dataZ_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])

                dataX_val_onPol.append(
                    dataX_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])
                dataX_full_val_onPol.append(
                    dataX_curr_full[task_num][int(taski_num_rollout *
                                                  training_ratio):])
                dataY_val_onPol.append(
                    dataY_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])
                dataZ_val_onPol.append(
                    dataZ_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])

        #on-policy data
        #go from dataX_curr (tasks, rollouts, steps) --> to ADDING ONTO dataX_onPol (tasks, some more rollouts than before, steps) and dataX_val_onPol (tasks, some more rollouts than before, steps)
        else:
            for task_num in range(len(dataX_curr)):

                taski_num_rollout = len(dataX_curr[task_num])
                print("task" + str(task_num) + "_num_rollouts for onpolicy: ",
                      taski_num_rollout)

                dataX_onPol[task_num].extend(
                    dataX_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])
                dataX_full_onPol[task_num].extend(
                    dataX_curr_full[task_num][:int(taski_num_rollout *
                                                   training_ratio)])
                dataY_onPol[task_num].extend(
                    dataY_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])
                dataZ_onPol[task_num].extend(
                    dataZ_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])

                dataX_val_onPol[task_num].extend(
                    dataX_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])
                dataX_full_val_onPol[task_num].extend(
                    dataX_curr_full[task_num][int(taski_num_rollout *
                                                  training_ratio):])
                dataY_val_onPol[task_num].extend(
                    dataY_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])
                dataZ_val_onPol[task_num].extend(
                    dataZ_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])

    #############################################################

    #count number of random and onpol data points
    total_random_data = len(dataX) * len(dataX[1]) * len(
        dataX[1][0])  # numSteps = tasks * rollouts * steps
    if (len(dataX_onPol) == 0):
        total_onPol_data = 0
    else:
        total_onPol_data = len(dataX_onPol) * len(dataX_onPol[0]) * len(
            dataX_onPol[0][0]
        )  #this is approximate because each task doesn't have the same num rollouts or the same num steps
    total_num_data = total_random_data + total_onPol_data
    print()
    print()
    print("Number of random data points: ", total_random_data)
    print("Number of on-policy data points: ", total_onPol_data)
    print("TOTAL number of data points: ", total_num_data)

    #############################################################

    #combine random and onpol data into a single dataset for training
    ratio_new = config["aggregation"]["ratio_new"]
    num_new_pts = ratio_new * (total_random_data) / (1.0 - ratio_new)
    if (len(dataX_onPol) == 0):
        num_times_to_copy_onPol = 0
    else:
        num_times_to_copy_onPol = int(num_new_pts / total_onPol_data)

    #copy all rollouts from each task of onpol data, and do this copying this many times
    for i in range(num_times_to_copy_onPol):
        for task_num in range(len(dataX_onPol)):
            for rollout_num in range(len(dataX_onPol[task_num])):
                dataX[task_num].append(dataX_onPol[task_num][rollout_num])
                dataX_full[task_num].append(
                    dataX_full_onPol[task_num][rollout_num])
                dataY[task_num].append(dataY_onPol[task_num][rollout_num])
                dataZ[task_num].append(dataZ_onPol[task_num][rollout_num])
    #print("num_times_to_copy_onPol: ", num_times_to_copy_onPol)

    # make a list of all X,Y,Z so can take mean of them
    # concatenate state and action --> inputs (for training)
    all_points_inp = []
    all_points_outp = []
    outputs = copy.deepcopy(dataZ)
    inputs = copy.deepcopy(dataX)
    for task_num in range(len(dataX)):
        for rollout_num in range(len(dataX[task_num])):

            #this will just be a big list of everything, so can take the mean
            input_pts = np.concatenate(
                (dataX[task_num][rollout_num], dataY[task_num][rollout_num]),
                axis=1)
            output_pts = dataZ[task_num][rollout_num]

            # the concatenated [state, action] inputs, kept for later training
            inputs[task_num][rollout_num] = np.concatenate(
                [dataX[task_num][rollout_num], dataY[task_num][rollout_num]],
                axis=1)

            all_points_inp.append(input_pts)
            all_points_outp.append(output_pts)
    all_points_inp = np.concatenate(all_points_inp)
    all_points_outp = np.concatenate(all_points_outp)

    ## concatenate state and action --> inputs (for validation)
    outputs_val = copy.deepcopy(dataZ_val)
    inputs_val = copy.deepcopy(dataX_val)
    for task_num in range(len(dataX_val)):
        for rollout_num in range(len(dataX_val[task_num])):
            #dataX[task_num][rollout_num] (steps x s_dim)
            #dataY[task_num][rollout_num] (steps x a_dim)
            inputs_val[task_num][rollout_num] = np.concatenate([
                dataX_val[task_num][rollout_num],
                dataY_val[task_num][rollout_num]
            ],
                                                               axis=1)

    ## concatenate state and action --> inputs (for validation onpol)
    outputs_val_onPol = copy.deepcopy(dataZ_val_onPol)
    inputs_val_onPol = copy.deepcopy(dataX_val_onPol)
    for task_num in range(len(dataX_val_onPol)):
        for rollout_num in range(len(dataX_val_onPol[task_num])):
            #dataX[task_num][rollout_num] (steps x s_dim)
            #dataY[task_num][rollout_num] (steps x a_dim)
            inputs_val_onPol[task_num][rollout_num] = np.concatenate([
                dataX_val_onPol[task_num][rollout_num],
                dataY_val_onPol[task_num][rollout_num]
            ],
                                                                     axis=1)

    #############################################################

    #inputs should now be (tasks, rollouts from that task, [s,a])
    #outputs should now be (tasks, rollouts from that task, [ds])
    #IPython.embed()

    inputSize = inputs[0][0].shape[1]
    outputSize = outputs[1][0].shape[1]
    print("\n\nDimensions:")
    print("states: ", dataX[1][0].shape[1])
    print("actions: ", dataY[1][0].shape[1])
    print("inputs to NN: ", inputSize)
    print("outputs of NN: ", outputSize)

    mean_inp = np.expand_dims(np.mean(all_points_inp, axis=0), axis=0)
    std_inp = np.expand_dims(np.std(all_points_inp, axis=0), axis=0)
    mean_outp = np.expand_dims(np.mean(all_points_outp, axis=0), axis=0)
    std_outp = np.expand_dims(np.std(all_points_outp, axis=0), axis=0)
    print("\n\nCalulated means and stds... ", mean_inp.shape, std_inp.shape,
          mean_outp.shape, std_outp.shape, "\n\n")

    ###########################################################
    ## CREATE regressor, policy, data generator, maml model
    ###########################################################

    # create regressor (NN dynamics model)
    regressor = DeterministicMLPRegressor(
        inputSize, outputSize, outputSize, tf_datatype, config['seed'],
        config['training']['weight_initializer'], config['model'])

    # create policy (MPC controller)
    policy = Policy(regressor,
                    inputSize,
                    outputSize,
                    left_min,
                    right_min,
                    left_max,
                    right_max,
                    state_representation=state_representation,
                    visualize_rviz=visualize_rviz,
                    x_index=config['roach']['x_index'],
                    y_index=config['roach']['y_index'],
                    yaw_cos_index=config['roach']['yaw_cos_index'],
                    yaw_sin_index=config['roach']['yaw_sin_index'],
                    **config['policy'])

    # create MAML model
    # note: this also constructs the actual regressor network/weights
    model = MAML(regressor, inputSize, outputSize, config)
    model.construct_model(input_tensors=None, prefix='metatrain_')
    model.summ_op = tf.summary.merge_all()

    # GPU config proto
    gpu_device = 0
    gpu_frac = 0.4  #0.4 #0.8 #0.3
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_device)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_frac)
    config_2 = tf.ConfigProto(gpu_options=gpu_options,
                              log_device_placement=False,
                              allow_soft_placement=True,
                              inter_op_parallelism_threads=1,
                              intra_op_parallelism_threads=1)
    # saving
    saver = tf.train.Saver(max_to_keep=10)
    sess = tf.InteractiveSession(config=config_2)

    # initialize tensorflow vars
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    # set the mean/std of regressor according to mean/std of the data we have so far
    regressor.update_params_data_dist(mean_inp, std_inp, mean_outp, std_outp,
                                      total_num_data)

    ###########################################################
    ## TRAIN THE DYNAMICS MODEL
    ###########################################################

    #train on the given full dataset, for max_epochs
    if train_now:
        if config["training"]["restore_previous_dynamics_model"]:
            print("\n\nRESTORING PREVIOUS DYNAMICS MODEL FROM ",
                  previous_dynamics_model, " AND CONTINUING TRAINING...\n\n")
            saver.restore(sess, previous_dynamics_model)

        np.save(save_dir + "/inputs.npy", inputs)
        np.save(save_dir + "/outputs.npy", outputs)
        np.save(save_dir + "/inputs_val.npy", inputs_val)
        np.save(save_dir + "/outputs_val.npy", outputs_val)

        train(inputs, outputs, curr_agg_iter, model, saver, sess, config,
              inputs_val, outputs_val, inputs_val_onPol, outputs_val_onPol)
    else:
        print("\n\nRESTORING A DYNAMICS MODEL FROM ", previous_dynamics_model)
        saver.restore(sess, previous_dynamics_model)

    ###########################################################
    ## RUN THE MPC CONTROLLER
    ###########################################################

    #create controller node
    controller_node = GBAC_Controller(
        sess,
        policy,
        model,
        use_pid_mode=use_pid_mode,
        state_representation=state_representation,
        default_addrs=default_addrs,
        update_batch_size=config['testing']['update_batch_size'],
        num_updates=config['testing']['num_updates'],
        de=config['testing']['dynamic_evaluation'],
        roach_config=config['roach'])

    #do 1 rollout
    print(
        "\n\n\nPAUSING... right before a controller run... RESET THE ROBOT TO A GOOD LOCATION BEFORE CONTINUING..."
    )
    #IPython.embed()
    resulting_x, selected_u, desired_seq, list_robot_info, list_mocap_info, old_saving_format_dict, list_best_action_sequences = controller_node.run(
        num_steps_per_rollout, desired_shape_for_rollout)

    #where to save this rollout
    pathStartName = save_dir + '/saved_rollouts/' + rollout_save_filename + '_aggIter' + str(
        curr_agg_iter)
    print("\n\n************** TRYING TO SAVE EVERYTHING TO: ", pathStartName)

    #save the result of the run
    np.save(pathStartName + '/oldFormat_actions.npy',
            old_saving_format_dict['actions_taken'])
    np.save(pathStartName + '/oldFormat_desired.npy',
            old_saving_format_dict['desired_states'])
    np.save(pathStartName + '/oldFormat_executed.npy',
            old_saving_format_dict['traj_taken'])
    np.save(pathStartName + '/oldFormat_perp.npy',
            old_saving_format_dict['save_perp_dist'])
    np.save(pathStartName + '/oldFormat_forward.npy',
            old_saving_format_dict['save_forward_dist'])
    np.save(pathStartName + '/oldFormat_oldforward.npy',
            old_saving_format_dict['saved_old_forward_dist'])
    np.save(pathStartName + '/oldFormat_movedtonext.npy',
            old_saving_format_dict['save_moved_to_next'])
    np.save(pathStartName + '/oldFormat_desheading.npy',
            old_saving_format_dict['save_desired_heading'])
    np.save(pathStartName + '/oldFormat_currheading.npy',
            old_saving_format_dict['save_curr_heading'])
    np.save(pathStartName + '/list_best_action_sequences.npy',
            list_best_action_sequences)

    yaml.dump(config, open(osp.join(pathStartName, 'saved_config.yaml'), 'w'))

    #save the result of the run
    np.save(pathStartName + '/actions.npy', selected_u)
    np.save(pathStartName + '/states.npy', resulting_x)
    np.save(pathStartName + '/desired.npy', desired_seq)
    pickle.dump(list_robot_info, open(pathStartName + '/robotInfo.obj', 'wb'))  # pickle needs binary mode
    pickle.dump(list_mocap_info, open(pathStartName + '/mocapInfo.obj', 'wb'))

    #stop roach
    print("killing robot")
    controller_node.kill_robot()

    return
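
The three branches above repeat the same per-task split of rollouts into training and validation portions by training_ratio. The helper below is a refactoring sketch of that split; the name split_rollouts is ours, not something from the original repo.

# Refactoring sketch: per-task train/val split of rollouts by training_ratio.
def split_rollouts(data_curr, training_ratio):
    """Split each task's rollouts into (train, val) lists by training_ratio."""
    train, val = [], []
    for task_rollouts in data_curr:
        n_train = int(len(task_rollouts) * training_ratio)
        train.append(task_rollouts[:n_train])
        val.append(task_rollouts[n_train:])
    return train, val

# e.g. dataX, dataX_val = split_rollouts(dataX_curr, training_ratio)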
Example #23
0
def main():
    argparser = argparse.ArgumentParser()
    argparser.add_argument('-n', help='n way', default=5)
    argparser.add_argument('-k', help='k shot', default=1)
    argparser.add_argument('-b', help='batch size', default=4)
    argparser.add_argument('-l', help='meta learning rate', default=1e-3)
    args = argparser.parse_args()
    n_way = int(args.n)
    k_shot = int(args.k)
    meta_batchsz = int(args.b)
    meta_lr = float(args.l)
    train_lr = 1e-2

    k_query = 15
    imgsz = 84
    mdl_file = 'ckpt/miniimagenet%d%d.mdl' % (n_way, k_shot)
    print('mini-imagenet: %d-way %d-shot meta-lr:%f, train-lr:%f' %
          (n_way, k_shot, meta_lr, train_lr))

    device = torch.device('cuda:0')
    net = MAML(n_way, k_shot, k_query, meta_batchsz, 5, meta_lr, train_lr,
               device)
    print(net)

    for epoch in range(1000):
        # batchsz here means total episode number
        mini = MiniImagenet('/hdd1/liangqu/datasets/miniimagenet/',
                            mode='train',
                            n_way=n_way,
                            k_shot=k_shot,
                            k_query=k_query,
                            batchsz=10000,
                            resize=imgsz)
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini,
                        meta_batchsz,
                        shuffle=True,
                        num_workers=4,
                        pin_memory=True)

        for step, batch in enumerate(db):

            # 2. train
            support_x = batch[0].to(device)
            support_y = batch[1].to(device)
            query_x = batch[2].to(device)
            query_y = batch[3].to(device)

            accs = net(support_x, support_y, query_x, query_y, training=True)

            if step % 50 == 0:
                print(epoch, step, '\t', accs)

            if step % 1000 == 0 and step != 0:  # batchsz here means total episode number
                mini_test = MiniImagenet(
                    '/hdd1/liangqu/datasets/miniimagenet/',
                    mode='test',
                    n_way=n_way,
                    k_shot=k_shot,
                    k_query=k_query,
                    batchsz=600,
                    resize=imgsz)
                # fetch meta_batchsz num of episode each time
                db_test = DataLoader(mini_test,
                                     meta_batchsz,
                                     shuffle=True,
                                     num_workers=4,
                                     pin_memory=True)
                accs_all_test = []
                for batch in db_test:
                    support_x = batch[0].to(device)
                    support_y = batch[1].to(device)
                    query_x = batch[2].to(device)
                    query_y = batch[3].to(device)

                    accs = net(support_x,
                               support_y,
                               query_x,
                               query_y,
                               training=True)
                    accs_all_test.append(accs)
                # [600, K+1]
                accs_all_test = np.array(accs_all_test)
                # [600, K+1] => [K+1]
                accs_all_test = accs_all_test.mean(axis=0)
                print('>>Test:\t', accs_all_test, '<<')
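
mdl_file is built above but never written. Below is a minimal checkpointing sketch, assuming best_acc = 0.0 is initialised before the epoch loop and that accs_all_test[-1] holds the accuracy after the final inner-update step (as in the usual MAML accuracy vector).

# Checkpointing sketch: save whenever the final-step test accuracy improves.
import os
os.makedirs('ckpt', exist_ok=True)
if accs_all_test[-1] > best_acc:
    best_acc = accs_all_test[-1]
    torch.save(net.state_dict(), mdl_file)
    print('saved checkpoint to', mdl_file)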
Example #24
0
def main():
    data_generator = DataGenerator(FLAGS.update_batch_size,
                                   FLAGS.meta_batch_size,
                                   k_shot=FLAGS.k_shot)

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    if FLAGS.datasource == 'ml':
        input_tensors = {
            'inputa': tf.placeholder(tf.int32, shape=[None, None, 2]),
            'inputb': tf.placeholder(tf.int32, shape=[None, None, 2]),
            'labela': tf.placeholder(tf.float32, shape=[None, None, 1]),
            'labelb': tf.placeholder(tf.float32, shape=[None, None, 1])
        }
    elif FLAGS.datasource == 'bpr' or FLAGS.datasource == 'bpr_time':
        input_tensors = {
            'inputa': tf.placeholder(tf.int32, shape=[None, None, 3]),
            'inputb': tf.placeholder(tf.int32, shape=[None, None, 3]),
        }
    else:
        raise Exception('non-supported data source: {}'.format(
            FLAGS.datasource))

    model = MAML(dim_input, dim_output)
    if FLAGS.train or FLAGS.test_existing_user:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    else:
        model.construct_model(input_tensors=input_tensors, prefix='META_TEST')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    sess = tf.InteractiveSession()

    exp_string = 'mtype_{}.mbs_{}.ubs_{}.meta_lr_{}.' \
                 'update_step_{}.update_lr_{}.' \
                 'lambda_lr_{}.avg_f_{}' \
                 '.time_{}'.format(FLAGS.datasource,
                                   FLAGS.meta_batch_size,
                                   FLAGS.update_batch_size,
                                   FLAGS.meta_lr, FLAGS.num_updates,
                                   FLAGS.update_lr,
                                   FLAGS.lambda_lr,
                                   FLAGS.use_avg_init,
                                   str(datetime.now()))

    resume_itr = 0
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()
    if FLAGS.resume:
        model_path = '{}/mlRRS/model/{}/model_{}'.format(
            FLAGS.logdir, FLAGS.load_dir, FLAGS.resume_iter)
        if os.path.exists(model_path + '.meta'):
            loader.restore(sess=sess, save_path=model_path)
            resume_itr = FLAGS.resume_iter
        else:
            raise Exception('No model saved at path {}'.format(model_path))
    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    if FLAGS.test_existing_user:
        test_existing_user(model, saver, sess, exp_string, data_generator,
                           resume_itr)
    if FLAGS.test:
        test(model, saver, sess, exp_string, data_generator, resume_itr)
Example #25
0
def main():
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 2
    else:
        if FLAGS.datasource == 'miniimagenet':
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = 10
        else:
            test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(
                1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource == 'miniimagenet':  # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size + 15, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
                else:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
            else:
                data_generator = DataGenerator(
                    FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                )  # only use one datapoint for testing to save memory

    dim_output = data_generator.dim_output
    if FLAGS.baseline == 'oracle':
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb
            }

        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(max_to_keep=10)

    #saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        print("Seeing if resume....")
        print("File string: ", FLAGS.logdir + '/' + exp_string)
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        print("model file name: ", model_file)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
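
The tf.slice calls above split each episode along the example axis: the first num_classes * update_batch_size examples become the support set (inputa/labela) and the remainder the query set (inputb/labelb). Below is a small NumPy analogue of that split with illustrative 5-way 1-shot, 15-query shapes; it is only meant to show the indexing, not to replace the TF pipeline.

# NumPy analogue of the support/query split performed with tf.slice above.
import numpy as np

meta_bs, n_support, n_query, feat = 4, 5 * 1, 5 * 15, 84 * 84 * 3
episode = np.random.rand(meta_bs, n_support + n_query, feat)
inputa = episode[:, :n_support, :]   # support examples: first n_way * k_shot per task
inputb = episode[:, n_support:, :]   # query examples: the rest of the episode
print(inputa.shape, inputb.shape)    # (4, 5, 21168) (4, 75, 21168)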
Example #26
0
def main():
    print('Train(0) or Test(1)?')
    train_ = input()
    train_count = 100
    if train_ == '0':
        FLAGS.train = True
        print('Number of training iterations in training mode')
        train_count = input()
        FLAGS.metatrain_iterations = int(train_count)
    else:
        FLAGS.train = False

    print('Select GPU:')
    gpu_index = input()

    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_index
    config_gpu = tf.ConfigProto()
    config_gpu.gpu_options.allow_growth = True

    if FLAGS.train is True:
        test_num_updates = 1
    else:
        test_num_updates = 10  # the original code uses 10 inner gradient updates at test time

    if FLAGS.train is False:  # testing
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1
    print('main.py: creating data_generator')
    if FLAGS.train:
        data_generator = DataGeneratorOneInstance(FLAGS.update_batch_size + 15,
                                                  FLAGS.meta_batch_size)
        # data_generator = DataGenerator_embedding(FLAGS.update_batch_size + 15, FLAGS.meta_batch_size)
    else:
        data_generator = DataGeneratorOneInstance(FLAGS.update_batch_size * 2,
                                                  FLAGS.meta_batch_size)
        # data_generator = DataGenerator_embedding(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size)

    # output dimensions
    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input
    print('dim_input in main is {}'.format(dim_input))

    tf_data_load = True
    num_classes = data_generator.num_classes
    sess = tf.InteractiveSession(config=config_gpu)
    # sess = tf.InteractiveSession()

    if FLAGS.train:  # only construct training model if needed
        random.seed(5)
        '''
        Notes on image_tensor and label_tensor:
        returns all_image_batches, all_label_batches
        all_image_batches:
        [batch1:[pic1, pic2, ...], batch2:[]...], where each pic is [0.1, 0.08, ... of length 84*84*3]
        all_label_batches:
        [batch1:[  [[0,1,0..], [1,0,0..], []..]  ], batch2:[]...], where each [0,1,..] has num_classes entries
        '''
        # make_data_tensor
        print(
            'main.py: train: data_generator.make_data_tensor(), building inputa etc. and slicing')
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=True)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    # build the validation data used to print accuracy during training
    random.seed(6)
    print('main.py: val: data_generator.make_data_tensor()')
    image_tensor, label_tensor = data_generator.make_data_tensor(
        train=False)  # train=False only affects the data folder and batch_count
    inputa = tf.slice(
        image_tensor, [0, 0, 0],
        [-1, num_classes * FLAGS.update_batch_size, -1])  # indices 0 to 5*4 form input_a
    inputb = tf.slice(image_tensor,
                      [0, num_classes * FLAGS.update_batch_size, 0],
                      [-1, -1, -1])
    labela = tf.slice(label_tensor, [0, 0, 0],
                      [-1, num_classes * FLAGS.update_batch_size, -1])
    labelb = tf.slice(label_tensor,
                      [0, num_classes * FLAGS.update_batch_size, 0],
                      [-1, -1, -1])
    metaval_input_tensors = {
        'inputa': inputa,
        'inputb': inputb,
        'labela': labela,
        'labelb': labelb
    }

    print('model = MAML()')
    # test_num_updates: number of inner gradient-descent steps (1 during training, 10 at test time)
    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        # construct_model must be called once the model object has been created
        print('model.construct_model(\'metatrain_\')')
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        print('model.construct_model(\'metaval_\')')
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()
    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    # test phase
    if FLAGS.train is False:
        # use the original meta batch size at test time
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0  # resume training from a checkpoint
    model_file = None
    # initialize variables
    tf.global_variables_initializer().run()
    tf.local_variables_initializer().run()
    tf.train.start_queue_runners()

    # cls_5.mbs_4.ubs_5.numstep5.updatelr0.01hidden32maxpoolbatchnorm
    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("读取已有训练数据Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    # if FLAGS.train:
    if FLAGS.train:
        print('main.py: entering train(model, saver, sess, exp_string...)...')
        # my(model, sess)
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        print('main.py: entering _test(model, saver, sess, exp_string...)...')
        _test(model, saver, sess, exp_string, data_generator, test_num_updates)
Example #27
0
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)

        x = x.view(x.size(0), -1)

        return self.logits(x)


if __name__ == "__main__":

    trans = transforms.Compose(
        [transforms.Resize((28, 28)),
         transforms.ToTensor()])
    tasks = Omniglot_Task_Distribution(
        datasets.Omniglot('./Omniglot/', transform=trans), 20)
    N, K = 5, 5
    task = tasks.sample_task(N, K, 15)
    meta_model = Classifier(N)
    maml = MAML(meta_model.cuda(),
                tasks,
                inner_lr=0.01,
                meta_lr=0.001,
                K=10,
                inner_steps=1,
                tasks_per_meta_batch=32,
                criterion=nn.CrossEntropyLoss())
    maml.main_loop(num_iterations=100)
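
The forward() above implies a Classifier with four conv blocks followed by a linear logits head. Below is a minimal sketch of such a backbone for 28x28 greyscale Omniglot inputs; the channel counts and layer choices are assumptions, not the original definition.

# Sketch of a Classifier backbone compatible with the forward() above (assumed layout).
import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self, n_way, hidden=64):
        super().__init__()
        def block(c_in, c_out):
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out), nn.ReLU(), nn.MaxPool2d(2))
        self.conv1 = block(1, hidden)       # 28x28 -> 14x14
        self.conv2 = block(hidden, hidden)  # 14x14 -> 7x7
        self.conv3 = block(hidden, hidden)  # 7x7  -> 3x3
        self.conv4 = block(hidden, hidden)  # 3x3  -> 1x1
        self.logits = nn.Linear(hidden, n_way)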
Example #28
0
def main():
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 10
    else:
        if FLAGS.datasource == 'miniimagenet':
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = TEST_NUM_UPDATES
        else:
            test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(
                1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource == 'miniimagenet':  # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    # data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
                else:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
            else:
                data_generator = DataGenerator(
                    FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                )  # only use one datapoint for testing to save memory

    dim_output = data_generator.dim_output
    if FLAGS.baseline == 'oracle':
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])

            import tensorflow_hub as hub
            augmentation_module = hub.Module(
                'https://tfhub.dev/google/image_augmentation/nas_cifar/1',
                name='am1')

            augmentation_module2 = hub.Module(
                'https://tfhub.dev/google/image_augmentation/flipx_crop_rotate_color/1',
                name='am2')

            meta_batch_size = inputa.get_shape()[0]
            dim = inputa.get_shape()[1]

            inputb = tf.reshape(inputa, (meta_batch_size, dim, 84, 84, 3))
            result = list()
            for i in range(meta_batch_size):
                images = augmentation_module(
                    {
                        'images': inputb[i, ...],
                        'image_size': (84, 84),
                        'augmentation': True,
                    },
                    signature='from_decoded_images')

                images = augmentation_module2(
                    {
                        'images': images,
                        'image_size': (84, 84),
                        'augmentation': True,
                    },
                    signature='from_decoded_images')

                transforms = [
                    1, 0, -tf.random.uniform(
                        shape=(), minval=-20, maxval=20, dtype=tf.int32), 0, 1,
                    -tf.random.uniform(
                        shape=(), minval=-20, maxval=20, dtype=tf.int32), 0, 0
                ]
                images = tf.contrib.image.transform(images, transforms)
                result.append(images)

            inputb = tf.stack(result)

            inputb = tf.reshape(inputb, (meta_batch_size, dim, 84 * 84 * 3))
            labelb = labela

            if FLAGS.train:
                input_tensors = {
                    'inputa': inputb,
                    'inputb': inputa,
                    'labela': labela,
                    'labelb': labelb
                }
            else:
                input_tensors = {
                    'inputa': inputa,
                    'inputb': inputb,
                    'labela': labela,
                    'labelb': labelb
                }

        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
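
The transforms list passed to tf.contrib.image.transform above is the flattened 8-parameter projective matrix [a0, a1, a2, b0, b1, b2, c0, c1]; with an identity 2x2 block it reduces to a random pixel translation. The sketch below isolates that step; the helper name random_translate is ours, and it assumes TF 1.x with tf.contrib available.

# Sketch: random-translation augmentation via the 8-parameter projective transform.
import tensorflow as tf

def random_translate(images, max_shift=20):
    # Random integer shifts in pixels, cast to float32 as the transform op expects.
    tx = tf.cast(tf.random.uniform((), -max_shift, max_shift, dtype=tf.int32), tf.float32)
    ty = tf.cast(tf.random.uniform((), -max_shift, max_shift, dtype=tf.int32), tf.float32)
    # [a0, a1, a2, b0, b1, b2, c0, c1]: identity 2x2 block plus translation (a2, b2).
    transform = tf.stack([1.0, 0.0, tx, 0.0, 1.0, ty, 0.0, 0.0])
    return tf.contrib.image.transform(images, transform)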
Example #29
0
def main():
    mode = args.mode
    short_term_seq_len = 7
    kshot = 1
    kquery = 4
    nway = 5
    meta_batchsz = 32
    K = 5
    iterations_pre = 2000
    iterations = 1000



################################
    #SOM_MAML without attention
    db_with_attention = DataGenerator_SOM_MAML_with_attention(nway, kshot, kquery, meta_batchsz)

    data_tensor, label_tensor = db_with_attention.make_data_tensor(mode='pretrain-NYtaxi', total_batch_num=meta_batchsz*iterations_pre)
    support_x_pretrain = tf.slice(data_tensor, [0, 0, 0, 0, 0, 0], [-1,  nway * kshot, -1, -1, -1, -1], name='support_x_pretrain')
    query_x_pretrain = tf.slice(data_tensor, [0,  nway * kshot, 0, 0, 0, 0], [-1, -1, -1, -1, -1, -1], name='query_x_pretrain')
    support_y_pretrain = tf.slice(label_tensor, [0, 0, 0], [-1,  nway * kshot, -1], name='support_y_pretrain')
    query_y_pretrain = tf.slice(label_tensor, [0,  nway * kshot, 0], [-1, -1, -1], name='query_y_pretrain')

    #x_fine_tune_NYbike, y_fine_tune_NYbike = db_with_attention.make_data_tensor(mode='fine-tune_NYbike')
    #x_test_NYbike, y_test_NYbike = db_with_attention.make_data_tensor(mode='NYbike-test')
    x_fine_tune_SZtaxi, y_fine_tune_SZtaxi = db_with_attention.make_data_tensor(mode='fine-tune_SZtaxi')
    x_test_SZtaxi, y_test_SZtaxi = shenzhen_SZ_test()
    #print('------- removed -------')
    #print(np.array(x_test_SZtaxi).shape)
    #x_test_SZtaxi, y_test_SZtaxi = db_with_attention.make_data_tensor(mode='SZtaxi-test')
    #print('------- not removed -------')
    #print(np.array(x_test_SZtaxi).shape)





    # 1. construct MAML model
    #modelNYbike_MAML = MAML(short_term_seq_len, 3, 2, nway)
    modelSZtaxi_MAML = MAML(short_term_seq_len, 3, 2, nway)
    #modelNYbike = NO_MAML(short_term_seq_len, 3, 2)
    modelSZtaxi = NO_MAML(short_term_seq_len, 3, 2)



    # construct metatrain_ and metaval
    # NYbike + SOM_MAML
    #modelNYbike_MAML.pretrain(support_x_pretrain, support_y_pretrain, query_x_pretrain, query_y_pretrain, K, meta_batchsz)
    #modelNYbike_MAML.fine_tune(x_fine_tune_NYbike, y_fine_tune_NYbike, x_test_NYbike, y_test_NYbike)

    #config = tf.ConfigProto()
    #config.gpu_options.allow_growth = True
    #sessNYbikeSOM_MAML = tf.InteractiveSession(config=config)
    # tf.global_variables() to save moving_mean and moving variance of batch norm
    # tf.trainable_variables()  NOT include moving_mean and moving_variance.
    #saverNYbikeSOM_MAML = tf.train.Saver(tf.global_variables(), max_to_keep=5)

    # initialize, under interactive session
    #tf.global_variables_initializer().run()
    # tf.train.start_queue_runners()

    #if os.path.exists(os.path.join('ckpt/SOM_NYbike', 'checkpoint')):
    #    model_file = tf.train.latest_checkpoint('ckpt/SOM_NYbike')
    #    print("Restoring model weights from ", model_file)
    #    saverNYbikeSOM_MAML.restore(sessNYbikeSOM_MAML, model_file)

    #pretrain(modelNYbike_MAML, saverNYbikeSOM_MAML, sessNYbikeSOM_MAML, iterations_pre)
    #fine_tune(modelNYbike_MAML, saverNYbikeSOM_MAML, sessNYbikeSOM_MAML, iterations)
    #sessNYbikeSOM_MAML.close()



    # SZtaxi + SOM_MAML
    modelSZtaxi_MAML.pretrain(support_x_pretrain, support_y_pretrain, query_x_pretrain, query_y_pretrain, K, meta_batchsz)
    modelSZtaxi_MAML.fine_tune(x_fine_tune_SZtaxi, y_fine_tune_SZtaxi, x_test_SZtaxi, y_test_SZtaxi)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sessSZtaxiSOM_MAML = tf.InteractiveSession(config=config)
    # tf.global_variables() to save moving_mean and moving variance of batch norm
    # tf.trainable_variables()  NOT include moving_mean and moving_variance.
    saverSZtaxiSOM_MAML = tf.train.Saver(tf.global_variables(), max_to_keep=5)

    # initialize, under interactive session
    tf.global_variables_initializer().run()
    # tf.train.start_queue_runners()

    if os.path.exists(os.path.join('ckpt/SOM_NYbike', 'checkpoint')):
        model_file = tf.train.latest_checkpoint('ckpt/SOM_SZtaxi')
        print("Restoring model weights from ", model_file)
        saverSZtaxiSOM_MAML.restore(sessSZtaxiSOM_MAML, model_file)

    pretrain(modelSZtaxi_MAML, saverSZtaxiSOM_MAML, sessSZtaxiSOM_MAML, iterations_pre)
    fine_tune(modelSZtaxi_MAML, saverSZtaxiSOM_MAML, sessSZtaxiSOM_MAML, iterations)
    sessSZtaxiSOM_MAML.close()

    # NYbike
    #modelNYbike.train(x_fine_tune_NYbike, y_fine_tune_NYbike, x_test_NYbike, y_test_NYbike)

    #config = tf.ConfigProto()
    #config.gpu_options.allow_growth = True
    #sessNYbike = tf.InteractiveSession(config=config)
    # tf.global_variables() includes the moving_mean and moving_variance of batch norm;
    # tf.trainable_variables() does NOT include moving_mean and moving_variance.
    #saverNYbike = tf.train.Saver(tf.global_variables(), max_to_keep=5)

    # initialize, under interactive session
    #tf.global_variables_initializer().run()
    # tf.train.start_queue_runners()

    #model_file = tf.train.latest_checkpoint('ckpt/NYbike')
    #print("Restoring model weights from ", model_file)
    #saverNYbike.restore(sessNYbike, model_file)

    #train_without_pretrain(modelNYbike, saverNYbike, sessNYbike, iterations)




    # SZtaxi
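    # Baseline without meta-learning: train NO_MAML directly on the SZtaxi fine-tuning
    # split and evaluate on the same test split used above.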
    modelSZtaxi.train(x_fine_tune_SZtaxi, y_fine_tune_SZtaxi, x_test_SZtaxi, y_test_SZtaxi)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sessSZtaxi = tf.InteractiveSession(config=config)
    # tf.global_variables() includes the moving_mean and moving_variance of batch norm;
    # tf.trainable_variables() does NOT include moving_mean and moving_variance.
    saverSZtaxi = tf.train.Saver(tf.global_variables(), max_to_keep=5)

    # initialize, under interactive session
    tf.global_variables_initializer().run()
    # tf.train.start_queue_runners()

    #model_file = tf.train.latest_checkpoint('ckpt/SZtaxi')
    #print("Restoring model weights from ", model_file)
    #saverSZtaxi.restore(sessSZtaxi, model_file)

    train_without_pretrain(modelSZtaxi, saverSZtaxi, sessSZtaxi, iterations)
Example #30
0
def main():

    test_num_updates = 10

    if not FLAGS.train:
        orig_meta_batch_size = FLAGS.meta_batch_size
        FLAGS.meta_batch_size = 1
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr
    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(
                    FLAGS.train_update_lr) + '.poison_lr' + str(
                        FLAGS.poison_lr)
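    # exp_string encodes the run's hyperparameters and is reused below as the
    # checkpoint/log sub-directory name under FLAGS.logdir.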

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    num_images_per_class = FLAGS.update_batch_size * 3

    data_generator = DataGenerator(
        num_images_per_class, FLAGS.meta_batch_size
    )  # update_batch_size * 3 images per class for each sampled task
    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input
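    # In 'train_with_poison' mode a precomputed poison batch is loaded from disk and
    # passed into the MAML model; otherwise no poison example is injected.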
    if FLAGS.mode == 'train_with_poison':
        print('Loading poison examples from %s' % FLAGS.poison_dir)
        poison_example = np.load(FLAGS.poison_dir)
        # poison_example=np.load(FLAGS.logdir + '/' + exp_string+'/poisonx_%d.npy'%FLAGS.poison_itr)
    else:
        poison_example = None
    model = MAML(dim_input=dim_input,
                 dim_output=dim_output,
                 num_images_per_class=num_images_per_class,
                 num_classes=FLAGS.num_classes,
                 poison_example=poison_example)
    sess = tf.InteractiveSession()
    print('Session created')
    if FLAGS.datasource == 'omniglot':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor(
                train=True, poison=(model.poisonx, model.poisony), sess=sess)
            if FLAGS.reptile:
                inputa = image_tensor
                labela = label_tensor
            else:
                inputa = tf.slice(
                    image_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
                labela = tf.slice(
                    label_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
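            # In the non-reptile case, inputa/labela hold the first
            # num_classes*FLAGS.update_batch_size examples of each task (the support set
            # for the inner updates) and inputb/labelb hold the rest (the query set for
            # the meta-update).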
            image_tensor, label_tensor = data_generator.make_data_tensor(
                train=False)
            if FLAGS.mode == 'train_poison':
                inputa_test = tf.slice(
                    image_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
                inputb_test = tf.slice(
                    image_tensor,
                    [0, num_classes * FLAGS.update_batch_size, 0],
                    [-1, -1, -1])
                labela_test = tf.slice(
                    label_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
                labelb_test = tf.slice(
                    label_tensor,
                    [0, num_classes * FLAGS.update_batch_size, 0],
                    [-1, -1, -1])
                input_tensors = {
                    'inputa': inputa,
                    'inputb': inputb,
                    'labela': labela,
                    'labelb': labelb,
                    'inputa_test': inputa_test,
                    'inputb_test': inputb_test,
                    'labela_test': labela_test,
                    'labelb_test': labelb_test
                }
            else:
                input_tensors = {
                    'inputa': inputa,
                    'inputb': inputb,
                    'labela': labela,
                    'labelb': labelb
                }

        random.seed(6)
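        # Fixing the seed here presumably makes the sampled meta-validation tasks
        # reproducible across runs.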
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix=FLAGS.mode)

    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    print('Model built')
    model.summ_op = tf.summary.merge_all()
    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)
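    # Note: the Saver only tracks TRAINABLE_VARIABLES, so batch-norm moving statistics
    # are not written to the checkpoints.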

    resume_itr = 0
    model_file = None
    tf.train.start_queue_runners()
    tf.global_variables_initializer().run()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' +
                                                exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model'
                                                      )] + 'model' + str(
                                                          FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
    test_params = [
        model, saver, sess, exp_string, data_generator, test_num_updates
    ]
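    # Evaluate the freshly restored (or initialised) model once before entering the
    # training branch.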
    test(model, saver, sess, exp_string, data_generator, test_num_updates)
    if FLAGS.train:
        train(model,
              saver,
              sess,
              exp_string,
              data_generator,
              resume_itr,
              test_params=test_params)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
Example #31
0
def main():
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 10
    else:
        if FLAGS.datasource == 'miniimagenet':
            if FLAGS.train:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = 10
        else:
            test_num_updates = 10
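    # test_num_updates sets how many inner gradient steps are used at evaluation time.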

    if not FLAGS.train:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(1, FLAGS.meta_batch_size)  # only use one datapoint per class
        else:
            if FLAGS.datasource == 'miniimagenet': # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size)  # update_batch_size support + 15 query examples per class
                else:
                    data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)  # equal-sized support and query sets at test time
            else:
                data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)  # equal-sized support and query sets


    dim_output = data_generator.dim_output
    if FLAGS.baseline == 'oracle':
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
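        # The oracle baseline receives the task parameters (presumably the sinusoid's
        # amplitude and phase) as two extra input dimensions, and all iterations are run
        # as joint pretraining instead of meta-training.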
    else:
        dim_input = data_generator.dim_input

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train: # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
            labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
            input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}

        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
        inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
        labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
        metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
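    # The metatrain_ and metaval_ graphs presumably share the same underlying weight
    # variables; the prefixes mainly separate their losses, ops and summaries.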
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    sess = tf.InteractiveSession()

    if not FLAGS.train:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_'+str(FLAGS.num_classes)+'.mbs_'+str(FLAGS.meta_batch_size) + '.ubs_' + str(FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1+5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)