def main():
    """Meta-train MAML on mini-ImageNet, resuming from a checkpoint when one exists.

    Command line: -n (n-way), -k (k-shot), -b (meta batch size), -l (learning rate).
    """
    argparser = argparse.ArgumentParser()
    # let argparse do the type conversion instead of casting the strings afterwards
    argparser.add_argument('-n', help='n way', type=int, default=5)
    argparser.add_argument('-k', help='k shot', type=int, default=1)
    argparser.add_argument('-b', help='batch size', type=int, default=4)
    argparser.add_argument('-l', help='learning rate', type=float, default=1e-3)
    args = argparser.parse_args()
    n_way = args.n
    k_shot = args.k
    meta_batchsz = args.b
    lr = args.l

    k_query = 1
    imgsz = 84
    # threshold for when to test full version of episode (fixed typo: threhold -> threshold)
    threshold = 0.699 if k_shot == 5 else 0.584
    mdl_file = 'ckpt/maml%d%d.mdl' % (n_way, k_shot)
    print('mini-imagnet: %d-way %d-shot lr:%f, threshold:%f' % (n_way, k_shot, lr, threshold))

    device = torch.device('cuda')
    net = MAML(n_way, k_shot, k_query, meta_batchsz=meta_batchsz, K=5, device=device)
    print(net)

    # resume from checkpoint if one was saved for this (n_way, k_shot) setting
    if os.path.exists(mdl_file):
        print('load from checkpoint ...', mdl_file)
        net.load_state_dict(torch.load(mdl_file))
    else:
        print('training from scratch.')

    # whole parameters number
    model_parameters = filter(lambda p: p.requires_grad, net.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    print('Total params:', params)

    for epoch in range(1000):
        # batchsz here means total episode number
        mini = MiniImagenet('/hdd1/liangqu/datasets/miniimagenet/', mode='train',
                            n_way=n_way, k_shot=k_shot, k_query=k_query,
                            batchsz=10000, resize=imgsz)
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, meta_batchsz, shuffle=True, num_workers=8, pin_memory=True)

        for step, batch in enumerate(db):
            # 2. train
            support_x = batch[0].to(device)
            support_y = batch[1].to(device)
            query_x = batch[2].to(device)
            query_y = batch[3].to(device)

            accs = net(support_x, support_y, query_x, query_y, training=True)

            if step % 10 == 0:
                print(accs)
def main():
    """Build the MAML train/test graphs, restore weights when resuming, and launch
    training on the MLDG/PACS domain split (train on 3 domains, test on 1)."""
    # Constructing training and test graphs
    model = MAML()
    model.construct_model_train()
    model.construct_model_test()
    model.summ_op = tf.summary.merge_all()

    # keep up to 10 checkpoints; only trainable variables are saved
    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)
    sess = tf.InteractiveSession()

    # experiment id string encodes the main hyper-parameters
    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
        FLAGS.train_update_batch_size) + '.numstep' + str(
        FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    resume_itr = 0
    model_file = None
    if not os.path.exists(FLAGS.logdir):
        os.makedirs(FLAGS.logdir)
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    print('Loading pretrained weights')
    model.load_initial_weights(sess)

    # when resuming (or testing) restore the latest checkpoint; the iteration
    # number is encoded in the checkpoint name right after the literal 'model'
    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    filelist_root = '../data/MLDG/'
    # PACS domains, indexed by the ids used in the train/test lists below
    domain_dict = {
        1: 'art_painting.txt',
        2: 'cartoon.txt',
        3: 'photo.txt',
        4: 'sketch.txt'
    }
    train_domain_list = [2, 3, 4]
    test_domain_list = [1]
    train_file_list = [
        os.path.join(filelist_root, domain_dict[i]) for i in train_domain_list
    ]
    test_file_list = [
        os.path.join(filelist_root, domain_dict[i]) for i in test_domain_list
    ]
    train(model, saver, sess, exp_string, train_file_list, test_file_list[0],
          resume_itr)
def main():
    """Configure the sinusoid data generator and MAML model, then run
    meta-training followed by meta-testing via the Trainer."""
    # fewer inner-loop updates are evaluated during training than at test time
    if FLAGS.train:
        test_num_updates = 5
    else:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1
        test_num_updates = 10

    data_generator = SinusoidDataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)
    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input
    input_tensors = None

    model = MAML(
        stop_grad=FLAGS.stop_grad,
        meta_lr=FLAGS.meta_lr,
        num_updates=FLAGS.num_updates,
        update_lr=FLAGS.update_lr,
        dim_input=dim_input,
        dim_output=dim_output,
        test_num_updates=test_num_updates,
        meta_batch_size=FLAGS.meta_batch_size,
        metatrain_iterations=FLAGS.metatrain_iterations,
        norm=FLAGS.norm,
    )
    model.build(input_tensors=input_tensors, prefix="metatrain")

    # idiomatic form of the original `if FLAGS.train == False:`
    if not FLAGS.train:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    trainer = Trainer(
        model,
        data_generator,
        Path(FLAGS.logdir),
        FLAGS.pretrain_iterations,
        FLAGS.metatrain_iterations,
        FLAGS.meta_batch_size,
        FLAGS.update_batch_size,
        FLAGS.num_updates,
        FLAGS.update_lr,
        stop_grad=FLAGS.stop_grad,
        baseline=FLAGS.baseline,
        is_training=True
    )
    trainer.train()
    trainer.test()
def __init__(self, module, task_map, finetune=1, fine_optim=None, optim=None,
             second_order=False, distributed=False, world_size=1, rank=-1):
    """Wrap `module` in a MAML meta-learner and set up optional distributed state.

    Args:
        module: model whose parameters are meta-trained.
        task_map: task description passed straight through to MAML.
        finetune: inner-loop (fine-tune) setting forwarded to MAML
            -- presumably the number of fine-tune steps; TODO confirm.
        fine_optim: optimizer for the inner (fine-tune) loop.
        optim: outer-loop (meta) optimizer.
        second_order: if True, MAML is built with second-order gradients.
        distributed: whether distributed training is enabled.
        world_size: number of distributed workers.
        rank: rank of this worker (-1 means unset).
    """
    super(MetaTrainWrapper, self).__init__()
    self.module = module
    self.task_map = task_map
    self.finetune = finetune
    self.fine_optim = fine_optim
    self.optim = optim
    self.distributed = distributed
    # configures process-group/device state for distributed runs
    self.init_distributed(world_size, rank)
    # MAML wrapper that performs the inner-loop adaptation over `module`
    self.meta_module = MAML(self.module, self.finetune, self.fine_optim,
                            self.task_map, second_order=second_order)
    # metric bookkeeping, populated later during training/validation
    self.train_history = None
    self.train_meter = None
    self.val_history = None
    self.val_meter = None
def test_with_maml(dataset, learner, checkpoint, steps, loss_fn):
    """Run MAML evaluation of `learner` on `dataset`.

    Args:
        dataset: evaluation dataset; accuracy is computed only for Omniglot
            (classification) datasets.
        learner: base model wrapped by MAML.
        checkpoint: checkpoint to restore; if falsy, evaluates random weights.
        steps: number of inner-loop adaptation steps.
        loss_fn: loss function used by MAML.
    """
    print("[*] Testing...")
    model = MAML(learner, steps=steps, loss_function=loss_fn)
    model.to(device)
    if checkpoint:
        model.restore(checkpoint, resume_training=False)
    else:
        print("[!] You are running inference on a randomly initialized model!")
    # isinstance is the idiomatic type check (original used `type(dataset) is ...`)
    # and also accepts subclasses of OmniglotDataset
    model.eval(dataset, compute_accuracy=isinstance(dataset, OmniglotDataset))
    print("[*] Done!")
def main():
    """Build the time-series MAML graph, restore the latest checkpoint when
    resuming/testing, then dispatch to train() or test()."""
    if FLAGS.train:
        test_num_updates = FLAGS.num_updates
    else:
        test_num_updates = 5

    data_generator = DataGenerator()
    data_generator.generate_time_series_batch(train=FLAGS.train)

    model = MAML(data_generator.batch_size, test_num_updates)
    model.construct_model(input_tensors=None, prefix='metatrain_')
    model.summ_op = tf.summary.merge_all()

    # keep up to 10 checkpoints; only trainable variables are saved
    saver = loader = tf.train.Saver(
        tf.get_collection(
            tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)
    sess = tf.InteractiveSession()

    # experiment id encodes csv file, update count and meta learning rate
    exp_string = FLAGS.train_csv_file + '.numstep' + str(test_num_updates) + '.updatelr' + str(FLAGS.meta_lr)
    resume_itr = 0
    model_file = None
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(
            FLAGS.logdir + '/' + exp_string)
        print(model_file)
        # optionally pin to a specific iteration instead of the latest one
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index(
                'model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            # the iteration number follows the literal 'model' in the name
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, sess, exp_string, data_generator)
def main():
    """Meta-train MAML on Omniglot n-way k-shot episodes.

    Command line: -n (n-way), -k (k-shot), -b (meta batch size), -l (meta lr).
    """
    argparser = argparse.ArgumentParser()
    # let argparse convert types instead of casting the parsed strings later
    argparser.add_argument('-n', help='n way', type=int, default=5)
    argparser.add_argument('-k', help='k shot', type=int, default=1)
    argparser.add_argument('-b', help='batch size', type=int, default=32)
    argparser.add_argument('-l', help='meta learning rate', type=float, default=1e-3)
    args = argparser.parse_args()
    n_way = args.n
    k_shot = args.k
    meta_batchsz = args.b
    meta_lr = args.l

    train_lr = 0.4
    k_query = 15
    imgsz = 84
    mdl_file = 'ckpt/omniglot%d%d.mdl' % (n_way, k_shot)
    print('omniglot: %d-way %d-shot meta-lr:%f, train-lr:%f' % (n_way, k_shot, meta_lr, train_lr))

    device = torch.device('cuda:0')
    net = MAML(n_way, k_shot, k_query, meta_batchsz, 5, meta_lr, train_lr, device)
    print(net)

    # batchsz here means total episode number
    db = OmniglotNShot('omniglot', batchsz=meta_batchsz, n_way=n_way,
                       k_shot=k_shot, k_query=k_query, imgsz=imgsz)

    for step in range(10000000):
        # train
        support_x, support_y, query_x, query_y = db.get_batch('train')
        # move channel axis to position 2 and tile the single gray channel to 3
        # (assumes arrays arrive as [b, set, h, w, 1] -- TODO confirm)
        support_x = torch.from_numpy(support_x).float().transpose(
            2, 4).transpose(3, 4).repeat(1, 1, 3, 1, 1).to(device)
        query_x = torch.from_numpy(query_x).float().transpose(2, 4).transpose(
            3, 4).repeat(1, 1, 3, 1, 1).to(device)
        support_y = torch.from_numpy(support_y).long().to(device)
        query_y = torch.from_numpy(query_y).long().to(device)

        accs = net(support_x, support_y, query_x, query_y, training=True)

        if step % 20 == 0:
            print(step, '\t', accs)

        if step % 1000 == 0:
            # test
            pass
def train_ganabi():
    """Parse the ganabi gin config, then build and meta-train a MAML model."""
    cfg_path = './config/ganabi.config.gin'
    gin.parse_config_file(cfg_path)

    train_cfg = TrainConfig()
    generator = DataGenerator(train_cfg)

    model = MAML(train_cfg)
    model.save_gin_config(cfg_path)
    model.train_manager(generator)
def train_with_maml(dataset, learner, save_path: str, steps: int, meta_batch_size: int, iterations: int, checkpoint=None, loss_fn=None):
    """Meta-train `learner` with MAML on `dataset` and return the trained model.

    If `checkpoint` is given, training resumes from its stored epoch.
    """
    print("[*] Training...")
    meta_model = MAML(learner, steps=steps, loss_function=loss_fn)
    meta_model.to(device)

    start_epoch = 0
    if checkpoint:
        meta_model.restore(checkpoint)
        start_epoch = checkpoint['epoch']

    meta_model.fit(dataset, iterations, save_path, start_epoch, 100)
    print("[*] Done!")
    return meta_model
def main(args):
    """Entry point: seed numpy, build the dataset and MAML model, then either
    train (`args.is_train`) or evaluate from a checkpoint."""
    np.random.seed(args.seed)

    task_data = get_dataset(args.dataset, args.K)
    maml_model = MAML(task_data, args.model_type, args.loss_type,
                      task_data.dim_input, task_data.dim_output,
                      args.alpha, args.beta, args.K, args.batch_size,
                      args.is_train, args.num_updates, args.norm)

    if not args.is_train:
        maml_model.evaluate(task_data, args.test_sample, args.draw,
                            restore_checkpoint=args.restore_checkpoint,
                            restore_dir=args.restore_dir)
        return
    maml_model.learn(args.batch_size, task_data, args.max_steps)
def main():
    """Create the job directory, record the training hyper-parameters, meta-train
    the selected meta-learner, save the trained model and run the test.

    Raises:
        AssertionError: if the JOB_NAME directory already exists.
    """
    if os.path.exists(JOB_NAME):
        raise AssertionError("Job name already exists")
    os.mkdir(JOB_NAME)

    # record hyper-parameters; `with` guarantees the file is closed even on error
    # (original opened/closed the file manually)
    params = [
        ("META_LEARNER", META_LEARNER),
        ("FUNCTION", FUNCTION_TRAIN),
        ("K_TRAIN", K_TRAIN),
        ("SGD_STEPS_TRAIN", SGD_STEPS_TRAIN),
        ("NOISE_PERCENT_TRAIN", NOISE_PERCENT_TRAIN),
        ("ITERATIONS_TRAIN", ITERATIONS_TRAIN),
        ("OUTER_LR_TRAIN", OUTER_LR_TRAIN),
        ("INNER_LR_TRAIN", INNER_LR_TRAIN),
        ("AVERAGER_SIZE_TRAIN", AVERAGER_SIZE_TRAIN),
    ]
    with open(os.path.join(JOB_NAME, "train_params.txt"), 'w') as f:
        for name, value in params:
            f.write(name + " " + str(value) + '\n')

    model = Net()
    # pick the meta-learning algorithm; anything other than reptile/maml
    # falls back to Insect (which additionally takes an averager size)
    if META_LEARNER == "reptile":
        learning_alg = Reptile(lr_inner=INNER_LR_TRAIN, lr_outer=OUTER_LR_TRAIN,
                               sgd_steps_inner=SGD_STEPS_TRAIN)
    elif META_LEARNER == "maml":
        learning_alg = MAML(lr_inner=INNER_LR_TRAIN, lr_outer=OUTER_LR_TRAIN,
                            sgd_steps_inner=SGD_STEPS_TRAIN)
    else:
        learning_alg = Insect(lr_inner=INNER_LR_TRAIN, lr_outer=OUTER_LR_TRAIN,
                              sgd_steps_inner=SGD_STEPS_TRAIN,
                              averager=AVERAGER_SIZE_TRAIN)

    meta_train_data = DataGenerator(function=FUNCTION_TRAIN, size=ITERATIONS_TRAIN,
                                    K=K_TRAIN, noise_percent=NOISE_PERCENT_TRAIN)
    learning_alg.train(model, meta_train_data)

    torch.save(model, os.path.join(JOB_NAME, "trained_model.pth"))
    test(model)
def main():
    """Test driver for per-AU (facial action unit) MAML/VAE models: restores
    trained weights (per-AU, per-subject, or single all-AU model depending on
    FLAGS.model), evaluates each test subject and prints a concatenated summary.

    NOTE(review): this block was reconstructed from collapsed source; the exact
    nesting of some branches should be verified against the original file.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu
    TOTAL_NUM_AU = 8
    all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']

    if not FLAGS.train:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1
        temp_kshot = FLAGS.update_batch_size
        FLAGS.update_batch_size = 1
        if FLAGS.model.startswith('m2'):
            temp_num_updates = FLAGS.num_updates
            FLAGS.num_updates = 1

    data_generator = DataGenerator()
    dim_output = data_generator.num_classes
    dim_input = data_generator.dim_input

    inputa, inputb, labela, labelb = data_generator.make_data_tensor()
    metatrain_input_tensors = {'inputa': inputa, 'inputb': inputb,
                               'labela': labela, 'labelb': labelb}

    model = MAML(dim_input, dim_output)
    model.construct_model(input_tensors=metatrain_input_tensors, prefix='metatrain_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=20)
    sess = tf.InteractiveSession()

    if not FLAGS.train:
        # change to original meta batch size when loading model.
        FLAGS.update_batch_size = temp_kshot
        FLAGS.meta_batch_size = orig_meta_batch_size
        if FLAGS.model.startswith('m2'):
            FLAGS.num_updates = temp_num_updates

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()
    print('initial weights: ', sess.run('model/b1:0'))
    print("========================================================================================")

    ################## Test ##################
    def _load_weight_m(trained_model_dir):
        # Restore one checkpoint per AU and stack the last-layer weights/biases.
        all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']
        if FLAGS.au_idx < TOTAL_NUM_AU:
            all_au = [all_au[FLAGS.au_idx]]
        w_arr = None
        b_arr = None
        for au in all_au:
            model_file = None
            print('model file dir: ', FLAGS.logdir + '/' + au + '/' + trained_model_dir)
            model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + au + '/' + trained_model_dir)
            print("model_file from ", au, ": ", model_file)
            if (model_file == None):
                print(
                    "############################################################################################")
                print("####################################################################### None for ", au)
                print(
                    "############################################################################################")
            else:
                # optionally pin to a specific saved iteration if it exists
                if FLAGS.test_iter > 0:
                    files = os.listdir(model_file[:model_file.index('model')])
                    if 'model' + str(FLAGS.test_iter) + '.index' in files:
                        model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
                        print("model_file by test_iter > 0: ", model_file)
                    else:
                        print(" >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>", files)
                print("Restoring model weights from " + model_file)
                saver.restore(sess, model_file)
                w = sess.run('model/w1:0')
                b = sess.run('model/b1:0')
                print("updated weights from ckpt: ", b)
                print('----------------------------------------------------------')
                # stack per-AU columns (weights) and rows (biases)
                if w_arr is None:
                    w_arr = w
                    b_arr = b
                else:
                    w_arr = np.hstack((w_arr, w))
                    b_arr = np.vstack((b_arr, b))
        return w_arr, b_arr

    def _load_weight_s(sbjt_start_idx):
        # Load weights for the VAE-style 's*' models via the feature_layer helper.
        batch_size = 10
        # (translated) if a single model was trained with all AUs, only that one
        # model needs to be loaded.
        if FLAGS.model.startswith('s1'):
            three_layers = feature_layer(batch_size, TOTAL_NUM_AU)
            three_layers.loadWeight(FLAGS.vae_model_to_test, FLAGS.au_idx, num_au_for_rm=TOTAL_NUM_AU)
        # (translated) if there is a different model per AU, the per-AU weights
        # must be stacked.
        else:
            three_layers = feature_layer(batch_size, 1)
            all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']
            if FLAGS.au_idx < TOTAL_NUM_AU:
                all_au = [all_au[FLAGS.au_idx]]
            w_arr = None
            b_arr = None
            for au in all_au:
                # each model variant encodes its hyper-parameters in the path
                if FLAGS.model.startswith('s3'):
                    load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au + '_kshot' + str(
                        FLAGS.update_batch_size) + '_iter100'
                elif FLAGS.model.startswith('s4'):
                    load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au + '_subject' + str(
                        sbjt_start_idx + 1) + '_kshot' + str(
                        FLAGS.update_batch_size) + '_iter10_maml_adad' + str(FLAGS.test_iter)
                else:
                    load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au + '_kshot' + str(
                        FLAGS.update_batch_size) + '_iter200_kshot10_iter10_nobatch_adam_noinit'
                three_layers.loadWeight(load_model_path, au)
                print('=============== Model S loaded from ', load_model_path)
                w = three_layers.model_intensity.layers[-1].get_weights()[0]
                b = three_layers.model_intensity.layers[-1].get_weights()[1]
                print('----------------------------------------------------------')
                if w_arr is None:
                    w_arr = w
                    b_arr = b
                else:
                    w_arr = np.hstack((w_arr, w))
                    b_arr = np.vstack((b_arr, b))
        # NOTE(review): in the 's1' branch w_arr/b_arr are never assigned before
        # this return -- verify against the original indentation.
        return w_arr, b_arr

    def _load_weight_m0(trained_model_dir):
        # Restore the single all-AU checkpoint and return its last-layer w/b.
        model_file = None
        print('--------- model file dir: ', FLAGS.logdir + trained_model_dir)
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + trained_model_dir)
        print(">>>> model_file from all_aus: ", model_file)
        if (model_file == None):
            print("####################################################################### None for all_aus")
        else:
            if FLAGS.test_iter > 0:
                files = os.listdir(model_file[:model_file.index('model')])
                if 'model' + str(FLAGS.test_iter) + '.index' in files:
                    model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
                    print(">>>> model_file2: ", model_file)
                else:
                    print(" >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>", files)
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
            w = sess.run('model/w1:0')
            b = sess.run('model/b1:0')
            print("updated weights from ckpt: ", b)
            print('----------------------------------------------------------')
        # NOTE(review): w/b are unassigned when no checkpoint was found -- verify.
        return w, b

    print("<<<<<<<<<<<< CONCATENATE >>>>>>>>>>>>>>")
    save_path = "./logs/result/"
    y_hat = []
    y_lab = []
    # (translated) FLAGS.all_sub_model: the model was trained with all subjects
    if FLAGS.all_sub_model:
        print('---------------- all sub model ----------------')
        # (translated) loading weights once is enough here because the model does
        # not differ per subject
        if FLAGS.model.startswith('m'):
            trained_model_dir = '/cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
                FLAGS.meta_batch_size) + '.ubs_' + str(
                FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
                FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr)
            if FLAGS.model.startswith('m0'):
                w_arr, b_arr = _load_weight_m0(trained_model_dir)
            else:
                # (translated) the model differs per AU
                w_arr, b_arr = _load_weight_m(trained_model_dir)
        ### test per each subject and concatenate
        for i in range(FLAGS.sbjt_start_idx, FLAGS.sbjt_start_idx + FLAGS.num_test_tasks):
            if FLAGS.model.startswith('s'):
                w_arr, b_arr = _load_weight_s(i)
            result = test_each_subject(w_arr, b_arr, i)
            y_hat.append(result[0])
            y_lab.append(result[1])
            print("y_hat shape:", result[0].shape)
            print("y_lab shape:", result[1].shape)
        print(">> y_hat_all shape:", np.vstack(y_hat).shape)
        print(">> y_lab_all shape:", np.vstack(y_lab).shape)
        print_summary(np.vstack(y_hat), np.vstack(y_lab), log_dir=save_path + "/" + "test.txt")
    else:
        # (translated) the model was trained per subject: only possible for the
        # train/test cases of the VAE and MAML, plus the local-weight test case
        for subj_idx in range(FLAGS.sbjt_start_idx, FLAGS.sbjt_start_idx + FLAGS.num_test_tasks):
            if FLAGS.model.startswith('s'):
                w_arr, b_arr = _load_weight_s(subj_idx)
            else:
                trained_model_dir = '/sbjt' + str(subj_idx) + '.ubs_' + str(
                    FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
                    FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr)
                w_arr, b_arr = _load_weight_m(trained_model_dir)
            result = test_each_subject(w_arr, b_arr, subj_idx)
            y_hat.append(result[0])
            y_lab.append(result[1])
            print("y_hat shape:", result[0].shape)
            print("y_lab shape:", result[1].shape)
        print(">> y_hat_all shape:", np.vstack(y_hat).shape)
        print(">> y_lab_all shape:", np.vstack(y_lab).shape)
        print_summary(np.vstack(y_hat), np.vstack(y_lab), log_dir=save_path + "/test.txt")

    # `start_time` is presumably a module-level timestamp -- verify against caller
    end_time = datetime.now()
    elapse = end_time - start_time
    print("=======================================================")
    print(">>>>>> elapse time: " + str(elapse))
    print("=======================================================")
def train_omniglot():
    """Build the Omniglot configuration and run MAML meta-training on it."""
    omniglot_cfg = get_Omniglot_config()
    generator = DataGenerator(omniglot_cfg)

    model = MAML(omniglot_cfg)
    model.train_manager(generator)
def main():
    """Build MAML train/val graphs from the tf input pipeline, restore any
    checkpoint, then run training immediately followed by testing."""
    if not os.path.exists(FLAGS.logdir):
        os.makedirs(FLAGS.logdir, exist_ok=True)

    test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    data_generator = DataGenerator(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size)
    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    tf_data_load = True
    num_classes = data_generator.num_classes

    if FLAGS.train:  # only construct training model if needed
        random.seed(5)
        image_tensor, label_tensor = data_generator.make_data_tensor()
        # first num_classes*update_batch_size examples form the inner-loop
        # (support) set 'a', the rest form the meta-update (query) set 'b'
        inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1])
        input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }

    # validation pipeline, same a/b split but with the eval data
    random.seed(6)
    image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
    inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1])
    inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1])
    labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1])
    labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1])
    metaval_input_tensors = {
        'inputa': inputa,
        'inputb': inputb,
        'labela': labela,
        'labelb': labelb
    }

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    # keep up to 10 checkpoints; only trainable variables are saved
    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)
    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'model_{}'.format(FLAGS.model_num)
    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        # optionally pin to a specific iteration instead of the latest one
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(
                FLAGS.test_iter)
        if model_file:
            # iteration number follows the literal 'model' in the file name
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    # run training then evaluation back-to-back, toggling the global flag
    FLAGS.train = True
    train(model, saver, sess, exp_string, data_generator, resume_itr)
    FLAGS.train = False
    test(model, saver, sess, exp_string, data_generator, test_num_updates)
def main():
    """Build and run meta-training/testing for a CACTUs-style MAML setup with
    task construction controlled by FLAGS (partition algorithm, mt_mode, ...).

    NOTE(review): reconstructed from collapsed source; verify branch nesting.
    """
    # inner-loop updates evaluated: few while training, many at test time,
    # most when training from scratch at test time
    if FLAGS.train:
        test_num_updates = 20
    elif FLAGS.from_scratch:
        test_num_updates = 200
    else:
        test_num_updates = 50

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    sess = tf.InteractiveSession()

    # support size = inner batch + outer batch, separately for train and val
    if not FLAGS.dataset == 'imagenet':
        data_generator = DataGenerator(FLAGS.inner_update_batch_size_train + FLAGS.outer_update_batch_size,
                                       FLAGS.inner_update_batch_size_val + FLAGS.outer_update_batch_size,
                                       FLAGS.meta_batch_size)
    else:
        data_generator = DataGeneratorImageNet(FLAGS.inner_update_batch_size_train + FLAGS.outer_update_batch_size,
                                               FLAGS.inner_update_batch_size_val + FLAGS.outer_update_batch_size,
                                               FLAGS.meta_batch_size)

    dim_output_train = data_generator.dim_output_train
    dim_output_val = data_generator.dim_output_val
    dim_input = data_generator.dim_input

    tf_data_load = True
    num_classes_train = data_generator.num_classes_train
    num_classes_val = data_generator.num_classes_val

    if FLAGS.train:  # only construct training model if needed
        random.seed(5)
        image_tensor, label_tensor = data_generator.make_data_tensor()
        # split each task into inner-loop set 'a' and meta-update set 'b'
        inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes_train*FLAGS.inner_update_batch_size_train, -1])
        inputb = tf.slice(image_tensor, [0,num_classes_train*FLAGS.inner_update_batch_size_train, 0], [-1,-1,-1])
        labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes_train*FLAGS.inner_update_batch_size_train, -1])
        labelb = tf.slice(label_tensor, [0,num_classes_train*FLAGS.inner_update_batch_size_train, 0], [-1,-1,-1])
        input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}

    # validation pipeline, same a/b split with the val batch sizes
    random.seed(6)
    image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
    inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes_val*FLAGS.inner_update_batch_size_val, -1])
    inputb = tf.slice(image_tensor, [0,num_classes_val*FLAGS.inner_update_batch_size_val, 0], [-1,-1,-1])
    labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes_val*FLAGS.inner_update_batch_size_val, -1])
    labelb = tf.slice(label_tensor, [0,num_classes_val*FLAGS.inner_update_batch_size_val, 0], [-1,-1,-1])
    metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}

    model = MAML(dim_input, dim_output_train, dim_output_val, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    # keep up to 10 checkpoints; only trainable variables are saved
    saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    if FLAGS.debug:
        sess = tf_debug.LocalCLIDebugWrapperSession(sess)

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.log_inner_update_batch_size_val == -1:
        FLAGS.log_inner_update_batch_size_val = FLAGS.inner_update_batch_size_val
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    # build the experiment id string from every hyper-parameter that matters
    exp_string = ''
    exp_string += '.nu_' + str(FLAGS.num_updates) + '.ilr_' + str(FLAGS.train_update_lr)
    if FLAGS.meta_lr != 0.001:
        exp_string += '.olr_' + str(FLAGS.meta_lr)
    # task-construction settings only matter when not using ground-truth tasks
    if FLAGS.mt_mode != 'gtgt':
        if FLAGS.partition_algorithm == 'hyperplanes':
            exp_string += '.m_' + str(FLAGS.margin)
        if FLAGS.partition_algorithm == 'kmeans' or FLAGS.partition_algorithm == 'kmodes':
            exp_string += '.k_' + str(FLAGS.num_clusters)
        exp_string += '.p_' + str(FLAGS.num_partitions)
        if FLAGS.scaled_encodings and FLAGS.num_partitions != 1:
            exp_string += '.scaled'
        if FLAGS.mt_mode == 'encenc':
            exp_string += '.ned_' + str(FLAGS.num_encoding_dims)
        elif FLAGS.mt_mode == 'semi':
            exp_string += '.pgtgt_' + str(FLAGS.p_gtgt)
    exp_string += '.mt_' + FLAGS.mt_mode
    exp_string += '.mbs_' + str(FLAGS.meta_batch_size) + \
                  '.nct_' + str(FLAGS.num_classes_train) + \
                  '.iubst_' + str(FLAGS.inner_update_batch_size_train) + \
                  '.iubsv_' + str(FLAGS.log_inner_update_batch_size_val) + \
                  '.oubs' + str(FLAGS.outer_update_batch_size)
    exp_string = exp_string[1:]  # get rid of leading period
    if FLAGS.on_encodings:
        exp_string += '.onenc'
        exp_string += '.nhl_' + str(FLAGS.num_hidden_layers)
    if FLAGS.num_filters != 64:
        exp_string += '.hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += '.maxpool'
    if FLAGS.stop_grad:
        exp_string += '.stopgrad'
    if FLAGS.norm == 'batch_norm':
        exp_string += '.batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += '.layernorm'
    elif FLAGS.norm == 'None':
        exp_string += '.nonorm'
    else:
        print('Norm setting not recognized.')
    if FLAGS.resnet:
        exp_string += '.res{}parts{}'.format(FLAGS.num_res_blocks, FLAGS.num_parts_per_res_block)
    if FLAGS.miniimagenet_only:
        exp_string += '.mini'
    if FLAGS.suffix != '':
        exp_string += '.' + FLAGS.suffix

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    print(exp_string)

    if FLAGS.resume or not FLAGS.train:
        # `logdir` here is presumably a module-level constant (not FLAGS.logdir)
        # -- verify against the rest of the file
        model_file = tf.train.latest_checkpoint(logdir + '/' + exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            # iteration number follows the literal 'model' in the file name
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1+5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
        else:
            print("No checkpoint found")
            if FLAGS.from_scratch:
                exp_string = ''

    if FLAGS.from_scratch and not os.path.isdir(logdir):
        os.makedirs(logdir)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
def main():
    """Classic MAML entry point: pick per-datasource eval settings, build the
    data pipeline and model graph, restore checkpoints, then train or test."""
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5  # During base-testing (and thus meta updating) 5 updates are used
        else:
            test_num_updates = 10  # During meta-testing 10 updates are used
    else:
        if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'cifarfs':
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = 10  # eval on 10 updates during testing
        else:
            test_num_updates = 10  # Omniglot gets 10 updates during training AND testing

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        # DataGenerator(num_samples_per_class, batch_size, config={})
        data_generator = DataGenerator(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size)
    else:
        # Dealing with a non 'sinusoid' dataset here
        if FLAGS.metatrain_iterations == 0 and (
                FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'cifarfs'):
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(
                1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'cifarfs':
                # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    # TODO: why +15 and *2 --> followin Ravi: "15 examples per class were used for evaluating the post-update meta-gradient" = MAML algo 2, line 10 --> see how 5 and 15 is split up in maml.py?
                    # DataGenerator(number_of_images_per_class, number_of_tasks_in_batch)
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size + 15, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
                else:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
            else:
                # this is for omniglot
                data_generator = DataGenerator(
                    FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                )  # only use one datapoint for testing to save memory

    dim_output = data_generator.dim_output  # number of classes, e.g. 5 for miniImagenet tasks
    if FLAGS.baseline == 'oracle':  # NOTE - this flag is specific to sinusoid
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        # oracle baseline: fold meta-training iterations into pretraining
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input  # np.prod(self.img_size) for images

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot' or FLAGS.datasource == 'cifarfs':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            # meta train : num_total_batches = 200000 (number of tasks, not number of meta-iterations)
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1
                               ])  # slice(tensor, begin, slice_size)
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])  # The extra 15 add here?!
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb
            }

        # meta val: num_total_batches = 600 (number of tasks, not number of meta-iterations)
        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1
                           ])  # slice the training examples here
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(
        dim_input, dim_output, test_num_updates=test_num_updates
    )  # test_num_updates = eval on at least one update for training, 10 testing
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    # Op to retrieve summaries?
    model.summ_op = tf.summary.merge_all()

    # keep last 10 copies of trainable variables
    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    # remove the need to explicitly pass this Session object to run ops
    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    # cls = no of classes
    # mbs = meta batch size
    # ubs = update batch size
    # numstep = number of INNER GRADIENT updates
    # updatelr = inner gradient step
    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
        FLAGS.train_update_batch_size) + '.numstep' + str(
        FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    # Initialize all variables, and
    tf.global_variables_initializer().run()
    # starts threads for all queue runners collected in the graph
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        # optionally pin to a specific iteration instead of the latest one
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(
                FLAGS.test_iter)
        if model_file:
            # iteration number follows the literal 'model' in the file name
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
def main():
    """Entry point for TF-1.x MAML on mini-imagenet.

    Builds the episode input pipeline (queue-based tensors), constructs the
    MAML train/eval (or test) graphs, restores the latest checkpoint if one
    exists, and dispatches to train() or test().

    Reads the module-level `args` (expects `args.test`); all other
    hyper-parameters are hard-coded below.
    """
    training = not args.test
    # Episode configuration: 5-way 1-shot with 15 query images per class,
    # 4 tasks per meta-batch, K inner-loop gradient steps.
    kshot = 1
    kquery = 15
    nway = 5
    meta_batchsz = 4
    K = 5
    # kshot + kquery images per category, nway categories, meta_batchsz tasks.
    db = DataGenerator(nway, kshot, kquery, meta_batchsz, 200000)

    if training:  # only construct training model if needed
        # get the tensor
        # image_tensor: [4, 80, 84*84*3]
        # label_tensor: [4, 80, 5]
        image_tensor, label_tensor = db.make_data_tensor(training=True)
        # NOTICE: the image order in 80 images should like this now:
        # [label2, label1, label3, label0, label4, and then repeat by 15 times,
        # namely one task]
        # support_x : [4, 1*5, 84*84*3]
        # query_x   : [4, 15*5, 84*84*3]
        # support_y : [4, 5, 5]
        # query_y   : [4, 15*5, 5]
        support_x = tf.slice(image_tensor, [0, 0, 0], [-1, nway * kshot, -1],
                             name='support_x')
        query_x = tf.slice(image_tensor, [0, nway * kshot, 0], [-1, -1, -1],
                           name='query_x')
        support_y = tf.slice(label_tensor, [0, 0, 0], [-1, nway * kshot, -1],
                             name='support_y')
        query_y = tf.slice(label_tensor, [0, nway * kshot, 0], [-1, -1, -1],
                           name='query_y')

    # construct test tensors (built unconditionally: needed by both the
    # 'eval' graph during training and the 'test' graph otherwise).
    image_tensor, label_tensor = db.make_data_tensor(training=False)
    support_x_test = tf.slice(image_tensor, [0, 0, 0], [-1, nway * kshot, -1],
                              name='support_x_test')
    query_x_test = tf.slice(image_tensor, [0, nway * kshot, 0], [-1, -1, -1],
                            name='query_x_test')
    support_y_test = tf.slice(label_tensor, [0, 0, 0], [-1, nway * kshot, -1],
                              name='support_y_test')
    query_y_test = tf.slice(label_tensor, [0, nway * kshot, 0], [-1, -1, -1],
                            name='query_y_test')

    # 1. construct MAML model
    model = MAML(84, 3, 5)

    # construct metatrain_ and metaval_ graphs; at test time the inner loop
    # runs K + 5 extra update steps.
    if training:
        model.build(support_x, support_y, query_x, query_y, K, meta_batchsz,
                    mode='train')
        model.build(support_x_test, support_y_test, query_x_test, query_y_test,
                    K, meta_batchsz, mode='eval')
    else:
        model.build(support_x_test, support_y_test, query_x_test, query_y_test,
                    K + 5, meta_batchsz, mode='test')
    model.summ_op = tf.summary.merge_all()

    # Print trainable variables, excluding the meta-optimizer's slots.
    all_vars = filter(lambda x: 'meta_optim' not in x.name,
                      tf.trainable_variables())
    for p in all_vars:
        print(p)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)
    # tf.global_variables() to save moving_mean and moving variance of batch norm
    # tf.trainable_variables() NOT include moving_mean and moving_variance.
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)

    # initialize, under interative session
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if os.path.exists(os.path.join('ckpt', 'checkpoint')):
        # alway load ckpt both train and test.
        model_file = tf.train.latest_checkpoint('ckpt')
        print("Restoring model weights from ", model_file)
        saver.restore(sess, model_file)

    if training:
        train(model, saver, sess)
    else:
        test(model, sess)
testset = miniimagenet("data", ways=5, shots=5, test_shots=15, meta_test=True, download=True) testloader = BatchMetaDataLoader(testset, batch_size=2, num_workers=4, shuffle=True) # training epochs = 6000 # batch sizeが2だと7751が上限(dataloaderの制限) model = MAML().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) loss_fn = torch.nn.CrossEntropyLoss().to(device) model_path = "./model/" result_path = "./log/train" trainiter = iter(trainloader) evaliter = iter(testloader) train_loss_log = [] train_acc_log = [] test_loss_log = [] test_acc_log = [] for epoch in range(epochs): # train
def construct_model(self):
    """Build the MAML graph plus its data pipeline and saver.

    Creates an interactive TF session on self, picks a DataGenerator based
    on FLAGS.datasource, wires queue-based meta-train / meta-val input
    tensors for image datasets, and constructs the MAML model.

    Side effects: sets ``self.sess``; temporarily forces
    ``FLAGS.meta_batch_size = 1`` while building the test-time graph and
    restores it afterwards; fills in ``FLAGS.train_update_batch_size`` /
    ``FLAGS.train_update_lr`` defaults when they are -1.

    Returns:
        (model, saver, data_generator) tuple.

    NOTE(review): assumes ``self.clusters`` and ``self.test_num_updates``
    were set by the caller before this runs — confirm against the class
    __init__ (not visible here).
    """
    self.sess = tf.InteractiveSession()

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    # Choose how many samples per task the generator yields:
    # - regression sources: update + eval batch per task
    # - image sources: K support + 15 query at train time, 2*K at test time
    if FLAGS.datasource in ['sinusoid', 'mixture']:
        data_generator = DataGenerator(
            FLAGS.update_batch_size + FLAGS.update_batch_size_eval,
            FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0 and FLAGS.datasource in [
                'miniimagenet', 'multidataset'
        ]:
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(
                1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource in [
                    'miniimagenet', 'multidataset'
            ]:  # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    # K support examples + 15 validation examples per class.
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size + 15, FLAGS.meta_batch_size)
                else:
                    # only 2*K datapoints per task at test time to save memory
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size)
            else:
                # non-image fallback: 2*K datapoints per task
                data_generator = DataGenerator(
                    FLAGS.update_batch_size * 2, FLAGS.meta_batch_size)

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    if FLAGS.datasource in ['miniimagenet', 'omniglot', 'multidataset']:
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            # Fixed seed so the meta-train episode order is reproducible.
            random.seed(5)
            if FLAGS.datasource in ['miniimagenet', 'omniglot']:
                image_tensor, label_tensor = data_generator.make_data_tensor(
                )
            elif FLAGS.datasource == 'multidataset':
                image_tensor, label_tensor = data_generator.make_data_tensor_multidataset(
                    sel_num=self.clusters, train=True)
            # First num_classes*K examples per task are the inner-update
            # ("a") split; the remainder is the meta-objective ("b") split.
            inputa = tf.slice(
                image_tensor, [0, 0, 0],
                [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(
                image_tensor, [0, num_classes * FLAGS.update_batch_size, 0],
                [-1, -1, -1])
            labela = tf.slice(
                label_tensor, [0, 0, 0],
                [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(
                label_tensor, [0, num_classes * FLAGS.update_batch_size, 0],
                [-1, -1, -1])
            input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb
            }

        # Meta-validation pipeline (separate seed → distinct episode order).
        random.seed(6)
        if FLAGS.datasource in ['miniimagenet', 'omniglot']:
            image_tensor, label_tensor = data_generator.make_data_tensor(
                train=False)
        elif FLAGS.datasource == 'multidataset':
            image_tensor, label_tensor = data_generator.make_data_tensor_multidataset(
                sel_num=self.clusters, train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        # Regression sources feed data through placeholders instead.
        tf_data_load = False
        input_tensors = None

    model = MAML(self.sess,
                 dim_input,
                 dim_output,
                 test_num_updates=self.test_num_updates)
    model.cluster_layer_0 = self.clusters
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors,
                              prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    # Keep the last 10 checkpoints of the trainable variables only
    # (batch-norm moving statistics are NOT in this collection).
    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    # -1 means "inherit the test-time value".
    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    return model, saver, data_generator
# ---------------------------------------------------------------------------
# Meta-test script: evaluate a trained MAML model on 5-way 5-shot
# miniImagenet. Loads ./model/model.pth and runs 1000 evaluation episodes,
# fine-tuning 10 inner steps per task (inner lr 0.01), collecting
# loss/accuracy into the log lists.
# ---------------------------------------------------------------------------
torch.backends.cudnn.benchmark = True

# Meta-test split: 5 ways, 5 support shots, 15 query shots per class.
testset = miniimagenet("data",
                       ways=5,
                       shots=5,
                       test_shots=15,
                       meta_test=True,
                       download=True)
testloader = BatchMetaDataLoader(testset,
                                 batch_size=2,
                                 num_workers=4,
                                 shuffle=True)
evaliter = iter(testloader)

model_path = './model/model.pth'
model = MAML().to(device)
model.load_state_dict(torch.load(model_path))
loss_fn = torch.nn.CrossEntropyLoss().to(device)

test_loss_log = []
test_acc_log = []
for i in range(1000):
    # FIX: use the builtin next(); the Py2-style `evaliter.next()` raises
    # AttributeError on Python 3 iterators.
    evalbatch = next(evaliter)
    model.eval()
    testloss, testacc = test(model,
                             evalbatch,
                             loss_fn,
                             lr=0.01,
                             train_step=10,
                             device=device)
def main():
    """Entry point for TF-1.x MAML on the sinusoid/ball regression tasks.

    Builds a DataGenerator for the configured datasource, constructs the
    MAML graph, optionally restores the latest checkpoint, and dispatches
    to train() or test().

    Uses module-level FLAGS, ``rect_truncated`` and ``test_num_updates``;
    temporarily forces ``FLAGS.meta_batch_size = 1`` while building the
    test-time graph.
    """
    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    # Each task provides 2 * update_batch_size samples: half for the inner
    # update, half for the meta-objective.
    if FLAGS.datasource == 'sinusoid':
        # ME: update_batch_size = 10 (20 samples/task); meta_batch_size = 25 (25 tasks)
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='sinusoid')
    elif FLAGS.datasource == 'ball':
        # ME: update_batch_size = 10 (20 samples/task); meta_batch_size = 25 (25 tasks)
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='ball')
    elif FLAGS.datasource == 'ball_file':
        # ME: update_batch_size = 10 (20 samples/task); meta_batch_size = 25 (25 tasks)
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='ball_file')
    else:  # 'rect_file"
        # ME: update_batch_size = 10 (20 samples/task); meta_batch_size = 25 (25 tasks)
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='rect_file',
                                       rect_truncated=rect_truncated)

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    model.construct_model()
    model.summ_op = tf.summary.merge_all()

    # Only trainable variables are checkpointed; keep the last 10 copies.
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
                           max_to_keep=10)
    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    exp_string = get_exp_string(model)
    resume_itr = 0

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        # FIX: latest_checkpoint() returns None when no checkpoint exists;
        # guard before slicing, otherwise a fresh logdir with test_iter > 0
        # crashed with a TypeError.
        if model_file and FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(
                FLAGS.test_iter)
        if model_file:
            # Checkpoint names end in 'model<iter>'; recover the iteration
            # number so training resumes from where it stopped.
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        # ME: test_num_updates = 10; 10 gradient updates
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
def run(d):
    """End-to-end pipeline for the model-based roach controller.

    Loads config and on-disk rollout data, aggregates random + on-policy
    datasets, normalizes them, builds the dynamics regressor / MPC policy /
    MAML model, (re)trains or restores the dynamics model, then performs one
    MPC rollout on the physical robot and saves all artifacts to disk.

    Args:
        d: dict with keys 'config' (parsed config dict), 'train_bool'
           (whether to train now) and 'exp_name' (save directory).

    Side effects: creates directories, writes .npy/.obj/.yaml files, opens a
    TF interactive session, and drives the robot via GBAC_Controller.
    """
    #IPython.embed()
    config = d['config']

    ################################################################
    ###################### Load parameters #########################
    ################################################################

    previous_dynamics_model = config["previous_dynamics_model"]
    train_now = d['train_bool']

    ################################################################

    desired_shape_for_rollout = config["testing"]["desired_shape_for_rollout"]
    save_rollout_run_num = config["testing"]["save_rollout_run_num"]
    rollout_save_filename = desired_shape_for_rollout + str(
        save_rollout_run_num)

    # Some trajectory shapes override the configured rollout length.
    num_steps_per_rollout = config["testing"]["num_steps_per_rollout"]
    if (desired_shape_for_rollout == "figure8"):
        num_steps_per_rollout = 400
    elif (desired_shape_for_rollout == "zigzag"):
        num_steps_per_rollout = 150

    ##############################################################

    #settings
    cheaty_training = False
    use_one_hot = False  #True
    use_camera = False  #True
    playback_mode = False
    state_representation = "exclude_x_y"  #["exclude_x_y", "all"]

    # Settings (generally, keep these to default)
    default_addrs = [b'\x00\x01']
    use_pid_mode = True
    slow_pid_mode = True
    visualize_rviz = True  #turning this off can make things go faster
    visualize_True = True
    visualize_False = False
    noise_True = True
    noise_False = False
    make_aggregated_dataset_noisy = True
    make_training_dataset_noisy = True
    perform_forwardsim_for_vis = True
    print_minimal = False
    noiseToSignal = 0
    if (make_training_dataset_noisy):
        noiseToSignal = 0.01

    # Defining datatypes
    tf_datatype = tf.float32
    np_datatype = np.float32

    # Setting motor limits (raw PWM range by default; PID mode below
    # rescales to the firmware's fixed-point velocity units).
    left_min = 1200
    right_min = 1200
    left_max = 2000
    right_max = 2000
    if (use_pid_mode):
        if (slow_pid_mode):
            left_min = 2 * math.pow(2, 16) * 0.001
            right_min = 2 * math.pow(2, 16) * 0.001
            left_max = 9 * math.pow(2, 16) * 0.001
            right_max = 9 * math.pow(2, 16) * 0.001
        else:  #this hasnt been tested yet
            left_min = 4 * math.pow(2, 16) * 0.001
            right_min = 4 * math.pow(2, 16) * 0.001
            left_max = 12 * math.pow(2, 16) * 0.001
            right_max = 12 * math.pow(2, 16) * 0.001

    #vars from config
    curr_agg_iter = config['aggregation']['curr_agg_iter']
    save_dir = d['exp_name']
    print("\n\nSAVING EVERYTHING TO: ", save_dir)

    #make directories
    if not os.path.exists(save_dir + '/saved_rollouts'):
        os.makedirs(save_dir + '/saved_rollouts')
    if not os.path.exists(save_dir + '/saved_rollouts/' +
                          rollout_save_filename + '_aggIter' +
                          str(curr_agg_iter)):
        os.makedirs(save_dir + '/saved_rollouts/' + rollout_save_filename +
                    '_aggIter' + str(curr_agg_iter))

    ######################################
    ######## GET TRAINING DATA ###########
    ######################################

    print("\n\nCURR AGGREGATION ITER: ", curr_agg_iter)

    # Training data # Random
    dataX = []
    dataX_full = [
    ]  #this is just for your personal use for forwardsim (for debugging)
    dataY = []
    dataZ = []

    # Training data # MPC
    dataX_onPol = []
    dataX_full_onPol = []
    dataY_onPol = []
    dataZ_onPol = []

    # Validation data # Random
    dataX_val = []
    dataX_full_val = []
    dataY_val = []
    dataZ_val = []

    # Validation data # MPC
    dataX_val_onPol = []
    dataX_full_val_onPol = []
    dataY_val_onPol = []
    dataZ_val_onPol = []

    training_ratio = config['training']['training_ratio']

    # Aggregation iteration 0 holds the random-policy data; iterations >= 1
    # hold on-policy (MPC) data collected in earlier aggregation rounds.
    for agg_itr_counter in range(curr_agg_iter + 1):

        #getDataFromDisk should give (tasks, rollouts from that task, each rollout has its points)
        dataX_curr, dataY_curr, dataZ_curr, dataX_curr_full = getDataFromDisk(
            config['experiment_type'], use_one_hot, use_camera,
            cheaty_training, state_representation, agg_itr_counter,
            config_training=config['training'])

        # NOTE(review): hard-coded data patch — scans task index 2 of agg
        # iter 1 for rollouts whose column 4 goes negative and overwrites
        # them with the next rollout. `rollout + 1` overruns on the last
        # rollout, and the task-2 index is dataset-specific; confirm this
        # still matches the data on disk.
        if (agg_itr_counter == 1):
            print("*********TRYING TO FIND THE WEIRD ROLLOUT...")
            for rollout in range(len(dataX_curr[2])):
                val = dataX_curr[2][rollout][:, 4]
                if (np.any(val < 0)):
                    dataX_curr[2][rollout] = dataX_curr[2][rollout + 1]
                    dataY_curr[2][rollout] = dataY_curr[2][rollout + 1]
                    dataZ_curr[2][rollout] = dataZ_curr[2][rollout + 1]
                    print("FOUND IT!!!!!!! rollout number ", rollout)

        #random data
        #go from dataX_curr (tasks, rollouts, steps) --> to dataX (tasks, some rollouts, steps) and dataX_val (tasks, some rollouts, steps)
        if (agg_itr_counter == 0):
            for task_num in range(len(dataX_curr)):
                taski_num_rollout = len(dataX_curr[task_num])
                print("task" + str(task_num) + "_num_rollouts: ",
                      taski_num_rollout)
                #for each task, append something like (356, 48, 22) (numrollouts per task, num steps in that rollout, dim)
                dataX.append(dataX_curr[task_num][:int(taski_num_rollout *
                                                       training_ratio)])
                dataX_full.append(
                    dataX_curr_full[task_num][:int(taski_num_rollout *
                                                   training_ratio)])
                dataY.append(dataY_curr[task_num][:int(taski_num_rollout *
                                                       training_ratio)])
                dataZ.append(dataZ_curr[task_num][:int(taski_num_rollout *
                                                       training_ratio)])
                dataX_val.append(dataX_curr[task_num][int(taski_num_rollout *
                                                          training_ratio):])
                dataX_full_val.append(
                    dataX_curr_full[task_num][int(taski_num_rollout *
                                                  training_ratio):])
                dataY_val.append(dataY_curr[task_num][int(taski_num_rollout *
                                                          training_ratio):])
                dataZ_val.append(dataZ_curr[task_num][int(taski_num_rollout *
                                                          training_ratio):])

        #on-policy data
        #go from dataX_curr (tasks, rollouts, steps) --> to dataX_onPol (tasks, some rollouts, steps) and dataX_val_onPol (tasks, some rollouts, steps)
        elif (agg_itr_counter == 1):
            for task_num in range(len(dataX_curr)):
                taski_num_rollout = len(dataX_curr[task_num])
                print("task" + str(task_num) + "_num_rollouts for onpolicy: ",
                      taski_num_rollout)
                dataX_onPol.append(
                    dataX_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])
                dataX_full_onPol.append(
                    dataX_curr_full[task_num][:int(taski_num_rollout *
                                                   training_ratio)])
                dataY_onPol.append(
                    dataY_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])
                dataZ_onPol.append(
                    dataZ_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])
                dataX_val_onPol.append(
                    dataX_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])
                dataX_full_val_onPol.append(
                    dataX_curr_full[task_num][int(taski_num_rollout *
                                                  training_ratio):])
                dataY_val_onPol.append(
                    dataY_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])
                dataZ_val_onPol.append(
                    dataZ_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])

        #on-policy data
        #go from dataX_curr (tasks, rollouts, steps) --> to ADDING ONTO dataX_onPol (tasks, some more rollouts than before, steps) and dataX_val_onPol (tasks, some more rollouts than before, steps)
        else:
            for task_num in range(len(dataX_curr)):
                taski_num_rollout = len(dataX_curr[task_num])
                print("task" + str(task_num) + "_num_rollouts for onpolicy: ",
                      taski_num_rollout)
                dataX_onPol[task_num].extend(
                    dataX_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])
                dataX_full_onPol[task_num].extend(
                    dataX_curr_full[task_num][:int(taski_num_rollout *
                                                   training_ratio)])
                dataY_onPol[task_num].extend(
                    dataY_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])
                dataZ_onPol[task_num].extend(
                    dataZ_curr[task_num][:int(taski_num_rollout *
                                              training_ratio)])
                dataX_val_onPol[task_num].extend(
                    dataX_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])
                dataX_full_val_onPol[task_num].extend(
                    dataX_curr_full[task_num][int(taski_num_rollout *
                                                  training_ratio):])
                dataY_val_onPol[task_num].extend(
                    dataY_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])
                dataZ_val_onPol[task_num].extend(
                    dataZ_curr[task_num][int(taski_num_rollout *
                                             training_ratio):])

    #############################################################

    #count number of random and onpol data points
    # NOTE(review): indexes dataX[1], so this assumes at least two tasks
    # exist; also approximate, since tasks differ in rollout/step counts.
    total_random_data = len(dataX) * len(dataX[1]) * len(
        dataX[1][0])  # numSteps = tasks * rollouts * steps
    if (len(dataX_onPol) == 0):
        total_onPol_data = 0
    else:
        total_onPol_data = len(dataX_onPol) * len(dataX_onPol[0]) * len(
            dataX_onPol[0][0]
        )  #this is approximate because each task doesn't have the same num rollouts or the same num steps
    total_num_data = total_random_data + total_onPol_data
    print()
    print()
    print("Number of random data points: ", total_random_data)
    print("Number of on-policy data points: ", total_onPol_data)
    print("TOTAL number of data points: ", total_num_data)

    #############################################################

    #combine random and onpol data into a single dataset for training
    # Oversample on-policy data so it makes up ratio_new of the final set.
    ratio_new = config["aggregation"]["ratio_new"]
    num_new_pts = ratio_new * (total_random_data) / (1.0 - ratio_new)
    if (len(dataX_onPol) == 0):
        num_times_to_copy_onPol = 0
    else:
        num_times_to_copy_onPol = int(num_new_pts / total_onPol_data)

    #copy all rollouts from each task of onpol data, and do this copying this many times
    for i in range(num_times_to_copy_onPol):
        for task_num in range(len(dataX_onPol)):
            for rollout_num in range(len(dataX_onPol[task_num])):
                dataX[task_num].append(dataX_onPol[task_num][rollout_num])
                dataX_full[task_num].append(
                    dataX_full_onPol[task_num][rollout_num])
                dataY[task_num].append(dataY_onPol[task_num][rollout_num])
                dataZ[task_num].append(dataZ_onPol[task_num][rollout_num])
    #print("num_times_to_copy_onPol: ", num_times_to_copy_onPol)

    # make a list of all X,Y,Z so can take mean of them
    # concatenate state and action --> inputs (for training)
    all_points_inp = []
    all_points_outp = []
    outputs = copy.deepcopy(dataZ)
    inputs = copy.deepcopy(dataX)
    for task_num in range(len(dataX)):
        for rollout_num in range(len(dataX[task_num])):
            #this will just be a big list of everything, so can take the mean
            input_pts = np.concatenate(
                (dataX[task_num][rollout_num], dataY[task_num][rollout_num]),
                axis=1)
            output_pts = dataZ[task_num][rollout_num]
            #this will the concatenate thing for later
            inputs[task_num][rollout_num] = np.concatenate(
                [dataX[task_num][rollout_num], dataY[task_num][rollout_num]],
                axis=1)
            all_points_inp.append(input_pts)
            all_points_outp.append(output_pts)
    all_points_inp = np.concatenate(all_points_inp)
    all_points_outp = np.concatenate(all_points_outp)

    ## concatenate state and action --> inputs (for validation)
    outputs_val = copy.deepcopy(dataZ_val)
    inputs_val = copy.deepcopy(dataX_val)
    for task_num in range(len(dataX_val)):
        for rollout_num in range(len(dataX_val[task_num])):
            #dataX[task_num][rollout_num] (steps x s_dim)
            #dataY[task_num][rollout_num] (steps x a_dim)
            inputs_val[task_num][rollout_num] = np.concatenate([
                dataX_val[task_num][rollout_num],
                dataY_val[task_num][rollout_num]
            ],
                                                               axis=1)

    ## concatenate state and action --> inputs (for validation onpol)
    outputs_val_onPol = copy.deepcopy(dataZ_val_onPol)
    inputs_val_onPol = copy.deepcopy(dataX_val_onPol)
    for task_num in range(len(dataX_val_onPol)):
        for rollout_num in range(len(dataX_val_onPol[task_num])):
            #dataX[task_num][rollout_num] (steps x s_dim)
            #dataY[task_num][rollout_num] (steps x a_dim)
            inputs_val_onPol[task_num][rollout_num] = np.concatenate([
                dataX_val_onPol[task_num][rollout_num],
                dataY_val_onPol[task_num][rollout_num]
            ],
                                                                    axis=1)

    #############################################################

    #inputs should now be (tasks, rollouts from that task, [s,a])
    #outputs should now be (tasks, rollouts from that task, [ds])
    #IPython.embed()
    inputSize = inputs[0][0].shape[1]
    outputSize = outputs[1][0].shape[1]
    print("\n\nDimensions:")
    print("states: ", dataX[1][0].shape[1])
    print("actions: ", dataY[1][0].shape[1])
    print("inputs to NN: ", inputSize)
    print("outputs of NN: ", outputSize)

    # Per-dimension normalization statistics over every training point.
    mean_inp = np.expand_dims(np.mean(all_points_inp, axis=0), axis=0)
    std_inp = np.expand_dims(np.std(all_points_inp, axis=0), axis=0)
    mean_outp = np.expand_dims(np.mean(all_points_outp, axis=0), axis=0)
    std_outp = np.expand_dims(np.std(all_points_outp, axis=0), axis=0)
    print("\n\nCalulated means and stds... ", mean_inp.shape, std_inp.shape,
          mean_outp.shape, std_outp.shape, "\n\n")

    ###########################################################
    ## CREATE regressor, policy, data generator, maml model
    ###########################################################

    # create regressor (NN dynamics model)
    regressor = DeterministicMLPRegressor(
        inputSize, outputSize, outputSize, tf_datatype, config['seed'],
        config['training']['weight_initializer'], config['model'])

    # create policy (MPC controller)
    policy = Policy(regressor,
                    inputSize,
                    outputSize,
                    left_min,
                    right_min,
                    left_max,
                    right_max,
                    state_representation=state_representation,
                    visualize_rviz=visualize_rviz,
                    x_index=config['roach']['x_index'],
                    y_index=config['roach']['y_index'],
                    yaw_cos_index=config['roach']['yaw_cos_index'],
                    yaw_sin_index=config['roach']['yaw_sin_index'],
                    **config['policy'])

    # create MAML model
    # note: this also constructs the actual regressor network/weights
    model = MAML(regressor, inputSize, outputSize, config)
    model.construct_model(input_tensors=None, prefix='metatrain_')
    model.summ_op = tf.summary.merge_all()

    # GPU config proto
    gpu_device = 0
    gpu_frac = 0.4  #0.4 #0.8 #0.3
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_device)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_frac)
    config_2 = tf.ConfigProto(gpu_options=gpu_options,
                              log_device_placement=False,
                              allow_soft_placement=True,
                              inter_op_parallelism_threads=1,
                              intra_op_parallelism_threads=1)

    # saving
    saver = tf.train.Saver(max_to_keep=10)
    sess = tf.InteractiveSession(config=config_2)

    # initialize tensorflow vars
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    # set the mean/std of regressor according to mean/std of the data we have so far
    regressor.update_params_data_dist(mean_inp, std_inp, mean_outp, std_outp,
                                      total_num_data)

    ###########################################################
    ## TRAIN THE DYNAMICS MODEL
    ###########################################################

    #train on the given full dataset, for max_epochs
    if train_now:
        if config["training"]["restore_previous_dynamics_model"]:
            print("\n\nRESTORING PREVIOUS DYNAMICS MODEL FROM ",
                  previous_dynamics_model, " AND CONTINUING TRAINING...\n\n")
            saver.restore(sess, previous_dynamics_model)

        # Snapshot the aggregated dataset next to the experiment outputs.
        np.save(save_dir + "/inputs.npy", inputs)
        np.save(save_dir + "/outputs.npy", outputs)
        np.save(save_dir + "/inputs_val.npy", inputs_val)
        np.save(save_dir + "/outputs_val.npy", outputs_val)

        train(inputs, outputs, curr_agg_iter, model, saver, sess, config,
              inputs_val, outputs_val, inputs_val_onPol, outputs_val_onPol)
    else:
        print("\n\nRESTORING A DYNAMICS MODEL FROM ", previous_dynamics_model)
        saver.restore(sess, previous_dynamics_model)

    ###########################################################
    ## RUN THE MPC CONTROLLER
    ###########################################################

    #create controller node
    controller_node = GBAC_Controller(
        sess,
        policy,
        model,
        use_pid_mode=use_pid_mode,
        state_representation=state_representation,
        default_addrs=default_addrs,
        update_batch_size=config['testing']['update_batch_size'],
        num_updates=config['testing']['num_updates'],
        de=config['testing']['dynamic_evaluation'],
        roach_config=config['roach'])

    #do 1 rollout
    print(
        "\n\n\nPAUSING... right before a controller run... RESET THE ROBOT TO A GOOD LOCATION BEFORE CONTINUING..."
    )
    #IPython.embed()
    resulting_x, selected_u, desired_seq, list_robot_info, list_mocap_info, old_saving_format_dict, list_best_action_sequences = controller_node.run(
        num_steps_per_rollout, desired_shape_for_rollout)

    #where to save this rollout
    pathStartName = save_dir + '/saved_rollouts/' + rollout_save_filename + '_aggIter' + str(
        curr_agg_iter)
    print("\n\n************** TRYING TO SAVE EVERYTHING TO: ", pathStartName)

    #save the result of the run
    np.save(pathStartName + '/oldFormat_actions.npy',
            old_saving_format_dict['actions_taken'])
    np.save(pathStartName + '/oldFormat_desired.npy',
            old_saving_format_dict['desired_states'])
    np.save(pathStartName + '/oldFormat_executed.npy',
            old_saving_format_dict['traj_taken'])
    np.save(pathStartName + '/oldFormat_perp.npy',
            old_saving_format_dict['save_perp_dist'])
    np.save(pathStartName + '/oldFormat_forward.npy',
            old_saving_format_dict['save_forward_dist'])
    np.save(pathStartName + '/oldFormat_oldforward.npy',
            old_saving_format_dict['saved_old_forward_dist'])
    np.save(pathStartName + '/oldFormat_movedtonext.npy',
            old_saving_format_dict['save_moved_to_next'])
    np.save(pathStartName + '/oldFormat_desheading.npy',
            old_saving_format_dict['save_desired_heading'])
    np.save(pathStartName + '/oldFormat_currheading.npy',
            old_saving_format_dict['save_curr_heading'])
    np.save(pathStartName + '/list_best_action_sequences.npy',
            list_best_action_sequences)
    # NOTE(review): file handles below are opened without a `with` block and
    # never closed; and pickle.dump to a text-mode ('w') file is Py2-only —
    # Python 3 requires 'wb'. Flagging, not changing, in this doc pass.
    yaml.dump(config, open(osp.join(pathStartName, 'saved_config.yaml'), 'w'))

    #save the result of the run
    np.save(pathStartName + '/actions.npy', selected_u)
    np.save(pathStartName + '/states.npy', resulting_x)
    np.save(pathStartName + '/desired.npy', desired_seq)
    pickle.dump(list_robot_info, open(pathStartName + '/robotInfo.obj', 'w'))
    pickle.dump(list_mocap_info, open(pathStartName + '/mocapInfo.obj', 'w'))

    #stop roach
    print("killing robot")
    controller_node.kill_robot()

    return
def main():
    """Entry point for PyTorch MAML meta-training on mini-imagenet.

    Parses n-way / k-shot / batch-size / meta-lr from the command line,
    builds the MAML learner, and runs the meta-train loop with a periodic
    600-episode meta-test evaluation.
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument('-n', help='n way', default=5)
    argparser.add_argument('-k', help='k shot', default=1)
    argparser.add_argument('-b', help='batch size', default=4)
    argparser.add_argument('-l', help='meta learning rate', default=1e-3)
    args = argparser.parse_args()
    n_way = int(args.n)
    k_shot = int(args.k)
    meta_batchsz = int(args.b)
    meta_lr = float(args.l)
    train_lr = 1e-2  # inner-loop (fast-adaptation) learning rate
    k_query = 15
    imgsz = 84
    # NOTE(review): mdl_file is computed but never used in this function —
    # presumably checkpointing happens inside MAML; confirm.
    mdl_file = 'ckpt/miniimagenet%d%d.mdl' % (n_way, k_shot)
    print('mini-imagnet: %d-way %d-shot meta-lr:%f, train-lr:%f' %
          (n_way, k_shot, meta_lr, train_lr))

    device = torch.device('cuda:0')
    # 5 = number of inner-loop update steps.
    net = MAML(n_way, k_shot, k_query, meta_batchsz, 5, meta_lr, train_lr,
               device)
    print(net)

    for epoch in range(1000):
        # batchsz here means total episode number
        mini = MiniImagenet('/hdd1/liangqu/datasets/miniimagenet/',
                            mode='train',
                            n_way=n_way,
                            k_shot=k_shot,
                            k_query=k_query,
                            batchsz=10000,
                            resize=imgsz)
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini,
                        meta_batchsz,
                        shuffle=True,
                        num_workers=4,
                        pin_memory=True)

        for step, batch in enumerate(db):
            # 2. train
            support_x = batch[0].to(device)
            support_y = batch[1].to(device)
            query_x = batch[2].to(device)
            query_y = batch[3].to(device)

            accs = net(support_x, support_y, query_x, query_y, training=True)

            if step % 50 == 0:
                print(epoch, step, '\t', accs)

            # Periodic meta-test over 600 held-out episodes.
            if step % 1000 == 0 and step != 0:
                # batchsz here means total episode number
                mini_test = MiniImagenet(
                    '/hdd1/liangqu/datasets/miniimagenet/',
                    mode='test',
                    n_way=n_way,
                    k_shot=k_shot,
                    k_query=k_query,
                    batchsz=600,
                    resize=imgsz)
                # fetch meta_batchsz num of episode each time
                db_test = DataLoader(mini_test,
                                     meta_batchsz,
                                     shuffle=True,
                                     num_workers=4,
                                     pin_memory=True)
                accs_all_test = []
                for batch in db_test:
                    support_x = batch[0].to(device)
                    support_y = batch[1].to(device)
                    query_x = batch[2].to(device)
                    query_y = batch[3].to(device)

                    # NOTE(review): evaluation also passes training=True;
                    # looks intentional if the inner-loop finetune is gated
                    # on that flag inside MAML.forward, but verify it does
                    # not update the meta-parameters during this test pass.
                    accs = net(support_x,
                               support_y,
                               query_x,
                               query_y,
                               training=True)
                    accs_all_test.append(accs)  # [600, K+1]

                # [600, K+1] => [K+1]: mean accuracy after each inner step.
                accs_all_test = np.array(accs_all_test)
                accs_all_test = accs_all_test.mean(axis=0)
                print('>>Test:\t', accs_all_test, '<<')
def main():
    """Entry point for MAML-based recommendation (ml / bpr datasources).

    Builds placeholder input tensors shaped for the selected datasource,
    constructs the MAML graph, optionally restores a checkpoint at
    FLAGS.resume_iter, then dispatches to train / test_existing_user / test.

    Raises:
        Exception: for an unsupported FLAGS.datasource, or when resuming
            and no checkpoint exists at the expected path.
    """
    data_generator = DataGenerator(FLAGS.update_batch_size,
                                   FLAGS.meta_batch_size,
                                   k_shot=FLAGS.k_shot)
    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    # Placeholders: last axis 2 = (user, item) pairs with explicit ratings
    # for 'ml'; 3 = (user, pos_item, neg_item) triples for the BPR variants
    # (no labels needed — the ranking loss is pairwise).
    if FLAGS.datasource == 'ml':
        input_tensors = {
            'inputa': tf.placeholder(tf.int32, shape=[None, None, 2]),
            'inputb': tf.placeholder(tf.int32, shape=[None, None, 2]),
            'labela': tf.placeholder(tf.float32, shape=[None, None, 1]),
            'labelb': tf.placeholder(tf.float32, shape=[None, None, 1])
        }
    elif FLAGS.datasource == 'bpr' or FLAGS.datasource == 'bpr_time':
        input_tensors = {
            'inputa': tf.placeholder(tf.int32, shape=[None, None, 3]),
            'inputb': tf.placeholder(tf.int32, shape=[None, None, 3]),
        }
    else:
        raise Exception('non-supported data source: {}'.format(
            FLAGS.datasource))

    model = MAML(dim_input, dim_output)
    if FLAGS.train or FLAGS.test_existing_user:
        model.construct_model(input_tensors=input_tensors,
                              prefix='metatrain_')
    else:
        model.construct_model(input_tensors=input_tensors,
                              prefix='META_TEST')
    model.summ_op = tf.summary.merge_all()

    # Checkpoint only trainable variables; keep the last 10 snapshots.
    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=10)

    sess = tf.InteractiveSession()

    # Experiment tag encoding the hyper-parameters plus a timestamp.
    exp_string = 'mtype_{}.mbs_{}.ubs_{}.meta_lr_{}.' \
                 'update_step_{}.update_lr_{}.' \
                 'lambda_lr_{}.avg_f_{}' \
                 '.time_{}'.format(FLAGS.datasource, FLAGS.meta_batch_size,
                                   FLAGS.update_batch_size, FLAGS.meta_lr,
                                   FLAGS.num_updates, FLAGS.update_lr,
                                   FLAGS.lambda_lr, FLAGS.use_avg_init,
                                   str(datetime.now()))

    resume_itr = 0
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume:
        model_path = '{}/mlRRS/model/{}/model_{}'.format(
            FLAGS.logdir, FLAGS.load_dir, FLAGS.resume_iter)
        # A valid TF checkpoint always has a companion '.meta' graph file.
        if os.path.exists(model_path + '.meta'):
            loader.restore(sess=sess, save_path=model_path)
            resume_itr = FLAGS.resume_iter
        else:
            raise Exception('No model saved at path {}'.format(model_path))

    # Flags are not mutually exclusive: each requested phase runs in turn.
    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    if FLAGS.test_existing_user:
        test_existing_user(model, saver, sess, exp_string, data_generator,
                           resume_itr)
    if FLAGS.test:
        test(model, saver, sess, exp_string, data_generator, resume_itr)
def main():
    """Build the MAML graph for the configured datasource, restore any
    checkpoint found under FLAGS.logdir, then meta-train or evaluate."""
    # Number of inner-gradient steps taken at evaluation time.
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 2
    else:
        if FLAGS.datasource == 'miniimagenet':
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = 10
        else:
            test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(
                1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource == 'miniimagenet':
                # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size + 15, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
                else:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
            else:
                data_generator = DataGenerator(
                    FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                )  # only use one datapoint for testing to save memory

    dim_output = data_generator.dim_output
    if FLAGS.baseline == 'oracle':
        # Oracle baseline: the true task parameters are appended to the input.
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            # The first num_classes*update_batch_size examples of each task
            # are the support set (a); the remainder is the query set (b).
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb
            }

        # Meta-validation pipeline — built in both train and test mode, since
        # the 'metaval_' graph below is constructed whenever tf_data_load.
        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(max_to_keep=10)
    #saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)
    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    # Experiment tag encoding the hyper-parameters; used as the log subdir.
    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        print("Seeing if resume....")
        print("File string: ", FLAGS.logdir + '/' + exp_string)
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        print("model file name: ", model_file)
        # BUG FIX: tf.train.latest_checkpoint returns None when no checkpoint
        # exists; slicing None raised a TypeError here. Only rewrite the path
        # to a specific iteration when a checkpoint was actually found.
        if model_file and FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(
                FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            # The iteration number follows the literal 'model' in the filename.
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
def main():
    """Interactive driver: prompt for train/test mode and GPU index, build
    the MAML graph on the one-instance data generator, then train or test."""
    print('Train(0) or Test(1)?')
    train_ = input()
    train_count = 100
    if train_ == '0':
        FLAGS.train = True
        print('训练模式下的训练次数')  # prompt: number of training iterations
        train_count = input()
        FLAGS.metatrain_iterations = int(train_count)
    else:
        FLAGS.train = False

    print('选择GPU:')  # prompt: GPU index
    gpu_index = input()
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_index
    config_gpu = tf.ConfigProto()
    config_gpu.gpu_options.allow_growth = True

    if FLAGS.train is True:
        test_num_updates = 1
    else:
        test_num_updates = 10  # upstream code uses 10 inner updates at test time

    if FLAGS.train is False:  # testing
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    print('main.py: 生成data_generator')
    if FLAGS.train:
        data_generator = DataGeneratorOneInstance(FLAGS.update_batch_size + 15,
                                                  FLAGS.meta_batch_size)
        # data_generator = DataGenerator_embedding(FLAGS.update_batch_size + 15, FLAGS.meta_batch_size)
    else:
        data_generator = DataGeneratorOneInstance(FLAGS.update_batch_size * 2,
                                                  FLAGS.meta_batch_size)
        # data_generator = DataGenerator_embedding(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size)

    # Output dimensionality.
    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input
    print('dim_input in main is {}'.format(dim_input))

    tf_data_load = True
    num_classes = data_generator.num_classes
    sess = tf.InteractiveSession(config=config_gpu)
    # sess = tf.InteractiveSession()

    if FLAGS.train:  # only construct training model if needed
        random.seed(5)
        # Layout of image_tensor / label_tensor (from make_data_tensor):
        #   all_image_batches: [batch1:[pic1, pic2, ...], batch2:[...], ...]
        #     where each pic is a flat vector of length 84*84*3
        #   all_label_batches: [batch1:[[0,1,0,..], [1,0,0,..], ...], ...]
        #     where each one-hot vector has num_classes entries
        print(
            'main.py: train: data_generator.make_data_tensor(),得到inputa等并进行切分')
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=True)
        # Support set (a) = first num_classes*update_batch_size examples,
        # query set (b) = the rest.
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }

    # Validation pipeline so accuracy can be printed during training; also
    # needed in test mode because the 'metaval_' graph is always built.
    random.seed(6)
    print('main.py: val: data_generator.make_data_tensor()')
    image_tensor, label_tensor = data_generator.make_data_tensor(
        train=False)  # train=False only affects the source folder and batch count
    inputa = tf.slice(
        image_tensor, [0, 0, 0],
        [-1, num_classes * FLAGS.update_batch_size, -1])  # e.g. rows 0..5*4 form input_a
    inputb = tf.slice(image_tensor,
                      [0, num_classes * FLAGS.update_batch_size, 0],
                      [-1, -1, -1])
    labela = tf.slice(label_tensor, [0, 0, 0],
                      [-1, num_classes * FLAGS.update_batch_size, -1])
    labelb = tf.slice(label_tensor,
                      [0, num_classes * FLAGS.update_batch_size, 0],
                      [-1, -1, -1])
    metaval_input_tensors = {
        'inputa': inputa,
        'inputb': inputb,
        'labela': labela,
        'labelb': labelb
    }

    print('model = MAML()')
    # test_num_updates: 1 while training, 10 at test time (inner steps).
    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        # construct_model must be called after initialization.
        print('model.construct_model(\'metatrain_\')')
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        print('model.construct_model(\'metaval_\')')
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    # Training phase.
    if FLAGS.train is False:
        # Test phase: restore the originally configured meta batch size.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0  # resume training from a checkpoint if one is found
    model_file = None

    # Initialize variables.
    tf.global_variables_initializer().run()
    tf.local_variables_initializer().run()
    tf.train.start_queue_runners()

    # e.g. cls_5.mbs_4.ubs_5.numstep5.updatelr0.01hidden32maxpoolbatchnorm
    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        # BUG FIX: latest_checkpoint returns None when no checkpoint exists;
        # slicing None crashed here. Only pick a specific iteration when a
        # checkpoint file was actually found.
        if model_file and FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(
                FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("读取已有训练数据Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    # if FLAGS.train:
    if FLAGS.train:
        print('main.py: 跳转到 train(model, saver, sess, exp_string...)...')
        # my(model, sess)
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        print('main.py: 跳转到 _test(model, saver, sess, exp_string...)...')
        _test(model, saver, sess, exp_string, data_generator,
              test_num_updates)
def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) x = self.conv4(x) x = x.view(x.size(0), -1) return self.logits(x) if __name__ == "__main__": trans = transforms.Compose( [transforms.Resize((28, 28)), transforms.ToTensor()]) tasks = Omniglot_Task_Distribution( datasets.Omniglot('./Omniglot/', transform=trans), 20) N, K = 5, 5 task = tasks.sample_task(N, K, 15) meta_model = Classifier(N) maml = MAML(meta_model.cuda(), tasks, inner_lr=0.01, meta_lr=0.001, K=10, inner_steps=1, tasks_per_meta_batch=32, criterion=nn.CrossEntropyLoss()) maml.main_loop(num_iterations=100)
def main():
    """MAML driver with TF-Hub image augmentation: during training the
    support set of each task is replaced by an augmented copy of itself."""
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 10
    else:
        if FLAGS.datasource == 'miniimagenet':
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = TEST_NUM_UPDATES
        else:
            test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(
                1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource == 'miniimagenet':
                # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    # data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
                else:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                    )  # only use one datapoint for testing to save memory
            else:
                data_generator = DataGenerator(
                    FLAGS.update_batch_size * 2, FLAGS.meta_batch_size
                )  # only use one datapoint for testing to save memory

    dim_output = data_generator.dim_output
    if FLAGS.baseline == 'oracle':
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])

            # Build an augmented copy of the support set with two TF-Hub
            # augmentation modules plus a random integer translation.
            import tensorflow_hub as hub
            augmentation_module = hub.Module(
                'https://tfhub.dev/google/image_augmentation/nas_cifar/1',
                name='am1')
            augmentation_module2 = hub.Module(
                'https://tfhub.dev/google/image_augmentation/flipx_crop_rotate_color/1',
                name='am2')
            meta_batch_size = inputa.get_shape()[0]
            dim = inputa.get_shape()[1]
            # NOTE(review): this deliberately reshapes `inputa` (the support
            # set); the query slice assigned to `inputb` above is discarded.
            inputb = tf.reshape(inputa, (meta_batch_size, dim, 84, 84, 3))
            result = list()
            for i in range(meta_batch_size):
                images = augmentation_module(
                    {
                        'images': inputb[i, ...],
                        'image_size': (84, 84),
                        'augmentation': True,
                    },
                    signature='from_decoded_images')
                images = augmentation_module2(
                    {
                        'images': images,
                        'image_size': (84, 84),
                        'augmentation': True,
                    },
                    signature='from_decoded_images')
                # Affine translation by a random offset on both axes.
                transforms = [
                    1, 0, -tf.random.uniform(
                        shape=(), minval=-20, maxval=20, dtype=tf.int32),
                    0, 1, -tf.random.uniform(
                        shape=(), minval=-20, maxval=20, dtype=tf.int32),
                    0, 0
                ]
                images = tf.contrib.image.transform(images, transforms)
                result.append(images)
            inputb = tf.stack(result)
            inputb = tf.reshape(inputb, (meta_batch_size, dim, 84 * 84 * 3))
            # Augmented images keep the original support labels.
            labelb = labela

            if FLAGS.train:
                # Augmented copy becomes the support set (a); the original
                # images become the query set (b).
                input_tensors = {
                    'inputa': inputb,
                    'inputb': inputa,
                    'labela': labela,
                    'labelb': labelb
                }
            else:
                input_tensors = {
                    'inputa': inputa,
                    'inputb': inputb,
                    'labela': labela,
                    'labelb': labelb
                }

        # Meta-validation pipeline — built in both train and test mode.
        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)
    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        # BUG FIX: latest_checkpoint returns None when no checkpoint exists;
        # slicing None crashed here. Only rewrite the path to a specific
        # iteration when a checkpoint file was actually found.
        if model_file and FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(
                FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
def main():
    """Pretrain and fine-tune the SOM-MAML traffic model on SZtaxi, then
    train the plain (no-pretraining) baseline for comparison.

    Note: the commented-out NYbike variant of this pipeline (identical code
    with NYbike data and 'ckpt/SOM_NYbike' checkpoints) was removed as dead
    code; restore it from version control if needed.
    """
    mode = args.mode
    short_term_seq_len = 7
    kshot = 1
    kquery = 4
    nway = 5
    meta_batchsz = 32
    K = 5
    iterations_pre = 2000
    iterations = 1000

    ################################
    # SOM_MAML without attention
    db_with_attention = DataGenerator_SOM_MAML_with_attention(
        nway, kshot, kquery, meta_batchsz)
    data_tensor, label_tensor = db_with_attention.make_data_tensor(
        mode='pretrain-NYtaxi', total_batch_num=meta_batchsz * iterations_pre)
    # Each episode: first nway*kshot entries are the support set, rest query.
    support_x_pretrain = tf.slice(data_tensor, [0, 0, 0, 0, 0, 0],
                                  [-1, nway * kshot, -1, -1, -1, -1],
                                  name='support_x_pretrain')
    query_x_pretrain = tf.slice(data_tensor, [0, nway * kshot, 0, 0, 0, 0],
                                [-1, -1, -1, -1, -1, -1],
                                name='query_x_pretrain')
    support_y_pretrain = tf.slice(label_tensor, [0, 0, 0],
                                  [-1, nway * kshot, -1],
                                  name='support_y_pretrain')
    query_y_pretrain = tf.slice(label_tensor, [0, nway * kshot, 0],
                                [-1, -1, -1], name='query_y_pretrain')

    x_fine_tune_SZtaxi, y_fine_tune_SZtaxi = db_with_attention.make_data_tensor(
        mode='fine-tune_SZtaxi')
    x_test_SZtaxi, y_test_SZtaxi = shenzhen_SZ_test()

    # 1. construct MAML model
    modelSZtaxi_MAML = MAML(short_term_seq_len, 3, 2, nway)
    modelSZtaxi = NO_MAML(short_term_seq_len, 3, 2)

    # SZtaxi + SOM_MAML: build metatrain_ and metaval graphs.
    modelSZtaxi_MAML.pretrain(support_x_pretrain, support_y_pretrain,
                              query_x_pretrain, query_y_pretrain, K,
                              meta_batchsz)
    modelSZtaxi_MAML.fine_tune(x_fine_tune_SZtaxi, y_fine_tune_SZtaxi,
                               x_test_SZtaxi, y_test_SZtaxi)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sessSZtaxiSOM_MAML = tf.InteractiveSession(config=config)
    # tf.global_variables() also saves moving_mean and moving_variance of
    # batch norm; tf.trainable_variables() would NOT include them.
    saverSZtaxiSOM_MAML = tf.train.Saver(tf.global_variables(), max_to_keep=5)
    # initialize, under interactive session
    tf.global_variables_initializer().run()
    # tf.train.start_queue_runners()

    # BUG FIX: the existence check previously looked in 'ckpt/SOM_NYbike'
    # while the checkpoint was restored from 'ckpt/SOM_SZtaxi'; both now use
    # the SZtaxi directory.
    if os.path.exists(os.path.join('ckpt/SOM_SZtaxi', 'checkpoint')):
        model_file = tf.train.latest_checkpoint('ckpt/SOM_SZtaxi')
        print("Restoring model weights from ", model_file)
        saverSZtaxiSOM_MAML.restore(sessSZtaxiSOM_MAML, model_file)
    pretrain(modelSZtaxi_MAML, saverSZtaxiSOM_MAML, sessSZtaxiSOM_MAML,
             iterations_pre)
    fine_tune(modelSZtaxi_MAML, saverSZtaxiSOM_MAML, sessSZtaxiSOM_MAML,
              iterations)
    sessSZtaxiSOM_MAML.close()

    # SZtaxi baseline without meta-pretraining.
    modelSZtaxi.train(x_fine_tune_SZtaxi, y_fine_tune_SZtaxi, x_test_SZtaxi,
                      y_test_SZtaxi)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sessSZtaxi = tf.InteractiveSession(config=config)
    # tf.global_variables() to save moving_mean and moving_variance of batch norm
    saverSZtaxi = tf.train.Saver(tf.global_variables(), max_to_keep=5)
    # initialize, under interactive session
    tf.global_variables_initializer().run()
    # tf.train.start_queue_runners()
    train_without_pretrain(modelSZtaxi, saverSZtaxi, sessSZtaxi, iterations)
def main():
    """Driver for the poisoning MAML experiments: build the omniglot data
    pipeline (optionally with poison examples), construct the model graph
    for FLAGS.mode, restore any checkpoint, then train and/or test."""
    test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        FLAGS.meta_batch_size = 1

    # NOTE(review): this immediately undoes the assignment above, so the
    # meta-batch-size override never takes effect — preserved as-is from the
    # original code; confirm the intended ordering with the author.
    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(
                    FLAGS.train_update_lr) + '.poison_lr' + str(
                        FLAGS.poison_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    # Three groups of update_batch_size images per class (support/query/test).
    num_images_per_class = FLAGS.update_batch_size * 3
    data_generator = DataGenerator(
        num_images_per_class, FLAGS.meta_batch_size
    )  # only use one datapoint for testing to save memory

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    if FLAGS.mode == 'train_with_poison':
        # BUG FIX: the message previously interpolated FLAGS.poison_path
        # while the array is actually loaded from FLAGS.poison_dir.
        print('Loading poison examples from %s' % FLAGS.poison_dir)
        poison_example = np.load(FLAGS.poison_dir)
        # poison_example=np.load(FLAGS.logdir + '/' + exp_string+'/poisonx_%d.npy'%FLAGS.poison_itr)
    else:
        poison_example = None

    model = MAML(dim_input=dim_input,
                 dim_output=dim_output,
                 num_images_per_class=num_images_per_class,
                 num_classes=FLAGS.num_classes,
                 poison_example=poison_example)

    sess = tf.InteractiveSession()
    print('Session created')

    if FLAGS.datasource == 'omniglot':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor(
                train=True, poison=(model.poisonx, model.poisony), sess=sess)
            if FLAGS.reptile:
                # Reptile uses the whole batch as the adaptation set.
                inputa = image_tensor
                labela = label_tensor
            else:
                inputa = tf.slice(
                    image_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
                labela = tf.slice(
                    label_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])

            # Clean (un-poisoned) tensors used to evaluate the poison.
            image_tensor, label_tensor = data_generator.make_data_tensor(
                train=False)
            if FLAGS.mode == 'train_poison':
                inputa_test = tf.slice(
                    image_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
                inputb_test = tf.slice(
                    image_tensor,
                    [0, num_classes * FLAGS.update_batch_size, 0],
                    [-1, -1, -1])
                labela_test = tf.slice(
                    label_tensor, [0, 0, 0],
                    [-1, num_classes * FLAGS.update_batch_size, -1])
                labelb_test = tf.slice(
                    label_tensor,
                    [0, num_classes * FLAGS.update_batch_size, 0],
                    [-1, -1, -1])
                input_tensors = {
                    'inputa': inputa,
                    'inputb': inputb,
                    'labela': labela,
                    'labelb': labelb,
                    'inputa_test': inputa_test,
                    'inputb_test': inputb_test,
                    'labela_test': labela_test,
                    'labelb_test': labelb_test
                }
            else:
                input_tensors = {
                    'inputa': inputa,
                    'inputb': inputb,
                    'labela': labela,
                    'labelb': labelb
                }

        # Meta-validation pipeline — built in both train and test mode.
        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(
            train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix=FLAGS.mode)
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors,
                              prefix='metaval_')
    print('Model built')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    resume_itr = 0
    model_file = None

    tf.train.start_queue_runners()
    tf.global_variables_initializer().run()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        # BUG FIX: latest_checkpoint returns None when no checkpoint exists;
        # slicing None crashed here. Only rewrite the path to a specific
        # iteration when a checkpoint file was actually found.
        if model_file and FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(
                FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    test_params = [
        model, saver, sess, exp_string, data_generator, test_num_updates
    ]
    # NOTE(review): test() is invoked unconditionally here AND again in the
    # else branch below, so evaluation runs twice when FLAGS.train is False —
    # preserved as-is; confirm whether the pre-branch call is intentional.
    test(model, saver, sess, exp_string, data_generator, test_num_updates)
    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr,
              test_params=test_params)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
def main():
    """Build the MAML graph for the configured datasource, restore any
    checkpoint found under FLAGS.logdir, then meta-train or evaluate."""
    # Number of inner-gradient steps taken at evaluation time.
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 10
    else:
        if FLAGS.datasource == 'miniimagenet':
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = 10
        else:
            test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(1, FLAGS.meta_batch_size)  # only use one datapoint,
        else:
            if FLAGS.datasource == 'miniimagenet':
                # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory
                else:
                    data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory
            else:
                data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory

    dim_output = data_generator.dim_output
    if FLAGS.baseline == 'oracle':
        # Oracle baseline: the true task parameters are appended to the input.
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            # The first num_classes*update_batch_size examples of each task
            # are the support set (a); the remainder is the query set (b).
            inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
            labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
            input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}

        # Meta-validation pipeline — built in both train and test mode.
        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
        inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
        labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1])
        metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    # Only trainable variables are checkpointed.
    saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_'+str(FLAGS.num_classes)+'.mbs_'+str(FLAGS.meta_batch_size) + '.ubs_' + str(FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        # BUG FIX: latest_checkpoint returns None when no checkpoint exists;
        # slicing None crashed here. Only rewrite the path to a specific
        # iteration when a checkpoint file was actually found.
        if model_file and FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            # The iteration number follows the literal 'model' in the filename.
            resume_itr = int(model_file[ind1+5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)