def main():
    # Constructing training and test graphs
    model = MAML()
    model.construct_model_train()
    model.construct_model_test()
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)
    sess = tf.InteractiveSession()

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    resume_itr = 0
    model_file = None
    if not os.path.exists(FLAGS.logdir):
        os.makedirs(FLAGS.logdir)

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    print('Loading pretrained weights')
    model.load_initial_weights(sess)

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    filelist_root = '../data/MLDG/'
    domain_dict = {
        1: 'art_painting.txt',
        2: 'cartoon.txt',
        3: 'photo.txt',
        4: 'sketch.txt'
    }
    train_domain_list = [2, 3, 4]
    test_domain_list = [1]
    train_file_list = [
        os.path.join(filelist_root, domain_dict[i]) for i in train_domain_list
    ]
    test_file_list = [
        os.path.join(filelist_root, domain_dict[i]) for i in test_domain_list
    ]

    train(model, saver, sess, exp_string, train_file_list, test_file_list[0],
          resume_itr)
def main():
    if FLAGS.train:
        test_num_updates = FLAGS.num_updates
    else:
        test_num_updates = 5

    data_generator = DataGenerator()
    data_generator.generate_time_series_batch(train=FLAGS.train)

    model = MAML(data_generator.batch_size, test_num_updates)
    model.construct_model(input_tensors=None, prefix='metatrain_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)
    sess = tf.InteractiveSession()

    exp_string = FLAGS.train_csv_file + '.numstep' + str(
        test_num_updates) + '.updatelr' + str(FLAGS.meta_lr)

    resume_itr = 0
    model_file = None
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        print(model_file)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(
                FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, sess, exp_string, data_generator)
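# Several of the main() variants in this collection recover the resume iteration
# by slicing the checkpoint filename after the literal 'model'. A minimal,
# self-contained sketch of that convention, assuming checkpoints are written as
# <logdir>/<exp_string>/model<iteration>; the paths below are hypothetical.
def parse_resume_itr(model_file):
    """Return the iteration encoded in a 'model<itr>' checkpoint path, or 0."""
    if not model_file:
        return 0
    ind1 = model_file.index('model')   # position of the literal 'model'
    return int(model_file[ind1 + 5:])  # the digits after 'model' are the iteration

assert parse_resume_itr('/logs/cls_5.mbs_4/model60000') == 60000
assert parse_resume_itr(None) == 0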
def main(): if FLAGS.datasource == 'sinusoid': if FLAGS.train: test_num_updates = 5 # During base-testing (and thus meta updating) 5 updates are used else: test_num_updates = 10 # During meta-testing 10 updates are used else: if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'cifarfs': if FLAGS.train == True: test_num_updates = 1 # eval on at least one update during training else: test_num_updates = 10 # eval on 10 updates during testing else: test_num_updates = 10 # Omniglot gets 10 updates during training AND testing if FLAGS.train == False: orig_meta_batch_size = FLAGS.meta_batch_size # always use meta batch size of 1 when testing. FLAGS.meta_batch_size = 1 if FLAGS.datasource == 'sinusoid': # DataGenerator(num_samples_per_class, batch_size, config={}) data_generator = DataGenerator(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size) else: # Dealing with a non 'sinusoid' dataset here if FLAGS.metatrain_iterations == 0 and ( FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'cifarfs'): assert FLAGS.meta_batch_size == 1 assert FLAGS.update_batch_size == 1 data_generator = DataGenerator( 1, FLAGS.meta_batch_size) # only use one datapoint, else: if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'cifarfs': # TODO - use 15 val examples for imagenet? if FLAGS.train: # TODO: why +15 and *2 --> followin Ravi: "15 examples per class were used for evaluating the post-update meta-gradient" = MAML algo 2, line 10 --> see how 5 and 15 is split up in maml.py? # DataGenerator(number_of_images_per_class, number_of_tasks_in_batch) data_generator = DataGenerator( FLAGS.update_batch_size + 15, FLAGS.meta_batch_size ) # only use one datapoint for testing to save memory else: data_generator = DataGenerator( FLAGS.update_batch_size * 2, FLAGS.meta_batch_size ) # only use one datapoint for testing to save memory else: # this is for omniglot data_generator = DataGenerator( FLAGS.update_batch_size * 2, FLAGS.meta_batch_size ) # only use one datapoint for testing to save memory dim_output = data_generator.dim_output # number of classes, e.g. 5 for miniImagenet tasks if FLAGS.baseline == 'oracle': # NOTE - this flag is specific to sinusoid assert FLAGS.datasource == 'sinusoid' dim_input = 3 FLAGS.pretrain_iterations += FLAGS.metatrain_iterations FLAGS.metatrain_iterations = 0 else: dim_input = data_generator.dim_input # np.prod(self.img_size) for images if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot' or FLAGS.datasource == 'cifarfs': tf_data_load = True num_classes = data_generator.num_classes if FLAGS.train: # only construct training model if needed # meta train : num_total_batches = 200000 (number of tasks, not number of meta-iterations) random.seed(5) image_tensor, label_tensor = data_generator.make_data_tensor() inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1 ]) # slice(tensor, begin, slice_size) inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) # The extra 15 add here?! 
labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } # meta val: num_total_batches = 600 (number of tasks, not number of meta-iterations) random.seed(6) image_tensor, label_tensor = data_generator.make_data_tensor( train=False) inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1 ]) # slice the training examples here inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) metaval_input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } else: tf_data_load = False input_tensors = None model = MAML( dim_input, dim_output, test_num_updates=test_num_updates ) # test_num_updates = eval on at least one update for training, 10 testing if FLAGS.train or not tf_data_load: model.construct_model(input_tensors=input_tensors, prefix='metatrain_') if tf_data_load: model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_') # Op to retrieve summaries? model.summ_op = tf.summary.merge_all() # keep last 10 copies of trainable variables saver = loader = tf.train.Saver(tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) # remove the need to explicitly pass this Session object to run ops sess = tf.InteractiveSession() if FLAGS.train == False: # change to original meta batch size when loading model. FLAGS.meta_batch_size = orig_meta_batch_size if FLAGS.train_update_batch_size == -1: FLAGS.train_update_batch_size = FLAGS.update_batch_size if FLAGS.train_update_lr == -1: FLAGS.train_update_lr = FLAGS.update_lr # cls = no of classes # mbs = meta batch size # ubs = update batch size # numstep = number of INNER GRADIENT updates # updatelr = inner gradient step exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str( FLAGS.meta_batch_size) + '.ubs_' + str( FLAGS.train_update_batch_size) + '.numstep' + str( FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr) if FLAGS.num_filters != 64: exp_string += 'hidden' + str(FLAGS.num_filters) if FLAGS.max_pool: exp_string += 'maxpool' if FLAGS.stop_grad: exp_string += 'stopgrad' if FLAGS.baseline: exp_string += FLAGS.baseline if FLAGS.norm == 'batch_norm': exp_string += 'batchnorm' elif FLAGS.norm == 'layer_norm': exp_string += 'layernorm' elif FLAGS.norm == 'None': exp_string += 'nonorm' else: print('Norm setting not recognized.') resume_itr = 0 model_file = None # Initialize all variables, and tf.global_variables_initializer().run() # starts threads for all queue runners collected in the graph tf.train.start_queue_runners() if FLAGS.resume or not FLAGS.train: model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) if FLAGS.test_iter > 0: model_file = model_file[:model_file.index('model' )] + 'model' + str( FLAGS.test_iter) if model_file: ind1 = model_file.index('model') resume_itr = int(model_file[ind1 + 5:]) print("Restoring model weights from " + model_file) saver.restore(sess, model_file) if FLAGS.train: train(model, saver, sess, exp_string, data_generator, resume_itr) else: test(model, saver, sess, exp_string, data_generator, test_num_updates)
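# A short aside on the "+15" question raised in the comments above: for
# miniImageNet meta-training each task batch holds
# num_classes * (update_batch_size + 15) examples, and tf.slice splits the first
# num_classes * update_batch_size of them into the support set (inputa) and the
# remaining 15 per class into the query set (inputb). A minimal NumPy sketch of
# that layout; the sizes below are illustrative, not taken from the code.
import numpy as np

num_classes, update_batch_size, num_query = 5, 1, 15
meta_batch_size, dim_input = 4, 84 * 84 * 3

# one meta-batch: [tasks, examples per task, flattened image]
images = np.zeros((meta_batch_size,
                   num_classes * (update_batch_size + num_query), dim_input))

split = num_classes * update_batch_size
inputa = images[:, :split, :]      # support set, consumed by the inner update
inputb = images[:, split:, :]      # query set, consumed by the meta-gradient
print(inputa.shape, inputb.shape)  # (4, 5, 21168) (4, 75, 21168)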
def main():
    if not FLAGS.train:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use a meta batch size of 1 when testing
        FLAGS.meta_batch_size = 1

    # ME: update_batch_size = 10 (20 samples/task); meta_batch_size = 25 (25 tasks)
    if FLAGS.datasource == 'sinusoid':
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='sinusoid')
    elif FLAGS.datasource == 'ball':
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='ball')
    elif FLAGS.datasource == 'ball_file':
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='ball_file')
    else:  # 'rect_file'
        data_generator = DataGenerator(FLAGS.update_batch_size * 2,
                                       FLAGS.meta_batch_size,
                                       datasource='rect_file',
                                       rect_truncated=rect_truncated)

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    model.construct_model()
    model.summ_op = tf.summary.merge_all()

    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
                           max_to_keep=10)
    sess = tf.InteractiveSession()

    if not FLAGS.train:
        # change back to the original meta batch size when loading the model
        FLAGS.meta_batch_size = orig_meta_batch_size

    exp_string = get_exp_string(model)
    resume_itr = 0

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index('model')] + 'model' + str(
                FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        # ME: test_num_updates = 10; 10 gradient updates at test time
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
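# The "FLAGS.update_batch_size * 2" convention above gives every task K points
# for the inner update plus K more for the meta-objective. A small NumPy sketch
# of sampling one sinusoid meta-batch that way; the amplitude/phase/input ranges
# follow the usual MAML sinusoid setup and are assumptions, not read from this code.
import numpy as np

def sample_sinusoid_task(num_samples, rng):
    amp = rng.uniform(0.1, 5.0)
    phase = rng.uniform(0.0, np.pi)
    x = rng.uniform(-5.0, 5.0, size=(num_samples, 1))
    return x, amp * np.sin(x - phase)

rng = np.random.RandomState(0)
update_batch_size, meta_batch_size = 5, 25
xs, ys = zip(*[sample_sinusoid_task(update_batch_size * 2, rng)
               for _ in range(meta_batch_size)])
inputs = np.stack(xs)                                    # [25, 10, 1]
inputa, inputb = inputs[:, :update_batch_size], inputs[:, update_batch_size:]
print(inputs.shape, inputa.shape, inputb.shape)          # (25, 10, 1) (25, 5, 1) (25, 5, 1)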
def main(): if FLAGS.datasource == 'sinusoid': if FLAGS.train: test_num_updates = 5 else: test_num_updates = 10 else: if FLAGS.datasource == 'miniimagenet': if FLAGS.train == True: test_num_updates = 1 # eval on at least one update during training else: test_num_updates = 10 else: test_num_updates = 10 if FLAGS.train == False: orig_meta_batch_size = FLAGS.meta_batch_size # always use meta batch size of 1 when testing. FLAGS.meta_batch_size = 1 if FLAGS.datasource == 'sinusoid': data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) else: if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet': assert FLAGS.meta_batch_size == 1 assert FLAGS.update_batch_size == 1 data_generator = DataGenerator(1, FLAGS.meta_batch_size) # only use one datapoint, else: if FLAGS.datasource == 'miniimagenet': # TODO - use 15 val examples for imagenet? if FLAGS.train: data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory else: data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory else: data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory dim_output = data_generator.dim_output if FLAGS.baseline == 'oracle': assert FLAGS.datasource == 'sinusoid' dim_input = 3 FLAGS.pretrain_iterations += FLAGS.metatrain_iterations FLAGS.metatrain_iterations = 0 else: dim_input = data_generator.dim_input if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot': tf_data_load = True num_classes = data_generator.num_classes if FLAGS.train: # only construct training model if needed random.seed(5) image_tensor, label_tensor = data_generator.make_data_tensor() inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1]) labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1]) input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb} random.seed(6) image_tensor, label_tensor = data_generator.make_data_tensor(train=False) inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1]) labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes*FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0,num_classes*FLAGS.update_batch_size, 0], [-1,-1,-1]) metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb} else: tf_data_load = False input_tensors = None model = MAML(dim_input, dim_output, test_num_updates=test_num_updates) if FLAGS.train or not tf_data_load: model.construct_model(input_tensors=input_tensors, prefix='metatrain_') if tf_data_load: model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_') model.summ_op = tf.summary.merge_all() saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) sess = tf.InteractiveSession() if FLAGS.train == False: # change to original meta batch size when loading model. 
FLAGS.meta_batch_size = orig_meta_batch_size if FLAGS.train_update_batch_size == -1: FLAGS.train_update_batch_size = FLAGS.update_batch_size if FLAGS.train_update_lr == -1: FLAGS.train_update_lr = FLAGS.update_lr exp_string = 'cls_'+str(FLAGS.num_classes)+'.mbs_'+str(FLAGS.meta_batch_size) + '.ubs_' + str(FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr) if FLAGS.num_filters != 64: exp_string += 'hidden' + str(FLAGS.num_filters) if FLAGS.max_pool: exp_string += 'maxpool' if FLAGS.stop_grad: exp_string += 'stopgrad' if FLAGS.baseline: exp_string += FLAGS.baseline if FLAGS.norm == 'batch_norm': exp_string += 'batchnorm' elif FLAGS.norm == 'layer_norm': exp_string += 'layernorm' elif FLAGS.norm == 'None': exp_string += 'nonorm' else: print('Norm setting not recognized.') resume_itr = 0 model_file = None tf.global_variables_initializer().run() tf.train.start_queue_runners() if FLAGS.resume or not FLAGS.train: model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) if FLAGS.test_iter > 0: model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter) if model_file: ind1 = model_file.index('model') resume_itr = int(model_file[ind1+5:]) print("Restoring model weights from " + model_file) saver.restore(sess, model_file) if FLAGS.train: train(model, saver, sess, exp_string, data_generator, resume_itr) else: test(model, saver, sess, exp_string, data_generator, test_num_updates)
def main(): if not os.path.exists(FLAGS.logdir): os.makedirs(FLAGS.logdir, exist_ok=True) test_num_updates = 10 if FLAGS.train == False: orig_meta_batch_size = FLAGS.meta_batch_size # always use meta batch size of 1 when testing. FLAGS.meta_batch_size = 1 data_generator = DataGenerator(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size) dim_output = data_generator.dim_output dim_input = data_generator.dim_input tf_data_load = True num_classes = data_generator.num_classes if FLAGS.train: # only construct training model if needed random.seed(5) image_tensor, label_tensor = data_generator.make_data_tensor() inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } random.seed(6) image_tensor, label_tensor = data_generator.make_data_tensor(train=False) inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) metaval_input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } model = MAML(dim_input, dim_output, test_num_updates=test_num_updates) if FLAGS.train or not tf_data_load: model.construct_model(input_tensors=input_tensors, prefix='metatrain_') if tf_data_load: model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_') model.summ_op = tf.summary.merge_all() saver = loader = tf.train.Saver(tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) sess = tf.InteractiveSession() if FLAGS.train == False: # change to original meta batch size when loading model. FLAGS.meta_batch_size = orig_meta_batch_size if FLAGS.train_update_batch_size == -1: FLAGS.train_update_batch_size = FLAGS.update_batch_size if FLAGS.train_update_lr == -1: FLAGS.train_update_lr = FLAGS.update_lr exp_string = 'model_{}'.format(FLAGS.model_num) resume_itr = 0 model_file = None tf.global_variables_initializer().run() tf.train.start_queue_runners() if FLAGS.resume or not FLAGS.train: model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) if FLAGS.test_iter > 0: model_file = model_file[:model_file.index('model' )] + 'model' + str( FLAGS.test_iter) if model_file: ind1 = model_file.index('model') resume_itr = int(model_file[ind1 + 5:]) print("Restoring model weights from " + model_file) saver.restore(sess, model_file) FLAGS.train = True train(model, saver, sess, exp_string, data_generator, resume_itr) FLAGS.train = False test(model, saver, sess, exp_string, data_generator, test_num_updates)
def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu
    TOTAL_NUM_AU = 8
    all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']

    if not FLAGS.train:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use a meta batch size of 1 when testing
        FLAGS.meta_batch_size = 1
        temp_kshot = FLAGS.update_batch_size
        FLAGS.update_batch_size = 1
        if FLAGS.model.startswith('m2'):
            temp_num_updates = FLAGS.num_updates
            FLAGS.num_updates = 1

    data_generator = DataGenerator()
    dim_output = data_generator.num_classes
    dim_input = data_generator.dim_input

    inputa, inputb, labela, labelb = data_generator.make_data_tensor()
    metatrain_input_tensors = {'inputa': inputa, 'inputb': inputb,
                               'labela': labela, 'labelb': labelb}

    model = MAML(dim_input, dim_output)
    model.construct_model(input_tensors=metatrain_input_tensors, prefix='metatrain_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=20)
    sess = tf.InteractiveSession()

    if not FLAGS.train:
        # change back to the original settings when loading the model
        FLAGS.update_batch_size = temp_kshot
        FLAGS.meta_batch_size = orig_meta_batch_size
        if FLAGS.model.startswith('m2'):
            FLAGS.num_updates = temp_num_updates

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()
    print('initial weights: ', sess.run('model/b1:0'))
    print("=" * 88)

    ################## Test ##################
    def _load_weight_m(trained_model_dir):
        all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']
        if FLAGS.au_idx < TOTAL_NUM_AU:
            all_au = [all_au[FLAGS.au_idx]]
        w_arr = None
        b_arr = None
        for au in all_au:
            print('model file dir: ', FLAGS.logdir + '/' + au + '/' + trained_model_dir)
            model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + au + '/' + trained_model_dir)
            print("model_file from ", au, ": ", model_file)
            if model_file is None:
                print('#' * 92)
                print('#' * 71 + ' None for ', au)
                print('#' * 92)
            else:
                if FLAGS.test_iter > 0:
                    files = os.listdir(model_file[:model_file.index('model')])
                    if 'model' + str(FLAGS.test_iter) + '.index' in files:
                        model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
                        print("model_file by test_iter > 0: ", model_file)
                    else:
                        print(" >>>> ", files)
                print("Restoring model weights from " + model_file)
                saver.restore(sess, model_file)
                w = sess.run('model/w1:0')
                b = sess.run('model/b1:0')
                print("updated weights from ckpt: ", b)
                print('-' * 58)
                if w_arr is None:
                    w_arr = w
                    b_arr = b
                else:
                    w_arr = np.hstack((w_arr, w))
                    b_arr = np.vstack((b_arr, b))
        return w_arr, b_arr

    def _load_weight_s(sbjt_start_idx):
        batch_size = 10
        # If one model was trained using all AUs, only that single model needs to be loaded.
        if FLAGS.model.startswith('s1'):
            three_layers = feature_layer(batch_size, TOTAL_NUM_AU)
            three_layers.loadWeight(FLAGS.vae_model_to_test, FLAGS.au_idx,
                                    num_au_for_rm=TOTAL_NUM_AU)
        # If there is a separate model per AU, the per-AU weights have to be stacked.
        else:
            three_layers = feature_layer(batch_size, 1)
        all_au = ['au1', 'au2', 'au4', 'au6', 'au9', 'au12', 'au25', 'au26']
        if FLAGS.au_idx < TOTAL_NUM_AU:
            all_au = [all_au[FLAGS.au_idx]]
        w_arr = None
        b_arr = None
        for au in all_au:
            if FLAGS.model.startswith('s3'):
                load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au \
                    + '_kshot' + str(FLAGS.update_batch_size) + '_iter100'
            elif FLAGS.model.startswith('s4'):
                load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au \
                    + '_subject' + str(sbjt_start_idx + 1) + '_kshot' + str(FLAGS.update_batch_size) \
                    + '_iter10_maml_adad' + str(FLAGS.test_iter)
            else:
                load_model_path = FLAGS.vae_model_to_test + '/' + FLAGS.model + '_' + au \
                    + '_kshot' + str(FLAGS.update_batch_size) + '_iter200_kshot10_iter10_nobatch_adam_noinit'
            three_layers.loadWeight(load_model_path, au)
            print('=============== Model S loaded from ', load_model_path)
            w = three_layers.model_intensity.layers[-1].get_weights()[0]
            b = three_layers.model_intensity.layers[-1].get_weights()[1]
            print('-' * 58)
            if w_arr is None:
                w_arr = w
                b_arr = b
            else:
                w_arr = np.hstack((w_arr, w))
                b_arr = np.vstack((b_arr, b))
        return w_arr, b_arr

    def _load_weight_m0(trained_model_dir):
        print('--------- model file dir: ', FLAGS.logdir + trained_model_dir)
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + trained_model_dir)
        print(">>>> model_file from all_aus: ", model_file)
        if model_file is None:
            print('#' * 71 + ' None for all_aus')
        else:
            if FLAGS.test_iter > 0:
                files = os.listdir(model_file[:model_file.index('model')])
                if 'model' + str(FLAGS.test_iter) + '.index' in files:
                    model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter)
                    print(">>>> model_file2: ", model_file)
                else:
                    print(" >>>> ", files)
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
            w = sess.run('model/w1:0')
            b = sess.run('model/b1:0')
            print("updated weights from ckpt: ", b)
            print('-' * 58)
        return w, b

    print("<<<<<<<<<<<< CONCATENATE >>>>>>>>>>>>>>")
    save_path = "./logs/result/"
    y_hat = []
    y_lab = []

    if FLAGS.all_sub_model:
        # the model was trained with all subjects
        print('---------------- all sub model ----------------')
        # loading the weights once is enough, because the model is the same for every subject
        if FLAGS.model.startswith('m'):
            trained_model_dir = '/cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
                FLAGS.meta_batch_size) + '.ubs_' + str(
                FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
                FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr)
            if FLAGS.model.startswith('m0'):
                w_arr, b_arr = _load_weight_m0(trained_model_dir)
            else:
                w_arr, b_arr = _load_weight_m(trained_model_dir)  # a separate model per AU

        ### test per each subject and concatenate
        for i in range(FLAGS.sbjt_start_idx, FLAGS.sbjt_start_idx + FLAGS.num_test_tasks):
            if FLAGS.model.startswith('s'):
                w_arr, b_arr = _load_weight_s(i)
            result = test_each_subject(w_arr, b_arr, i)
            y_hat.append(result[0])
            y_lab.append(result[1])
            print("y_hat shape:", result[0].shape)
            print("y_lab shape:", result[1].shape)
        print(">> y_hat_all shape:", np.vstack(y_hat).shape)
        print(">> y_lab_all shape:", np.vstack(y_lab).shape)
        print_summary(np.vstack(y_hat), np.vstack(y_lab), log_dir=save_path + "/" + "test.txt")
    else:
        # the model was trained per subject: only possible for the VAE and MAML train_test cases,
        # plus the local-weight test case
        for subj_idx in range(FLAGS.sbjt_start_idx, FLAGS.sbjt_start_idx + FLAGS.num_test_tasks):
            if FLAGS.model.startswith('s'):
                w_arr, b_arr = _load_weight_s(subj_idx)
            else:
                trained_model_dir = '/sbjt' + str(subj_idx) + '.ubs_' + str(
                    FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
                    FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr)
                w_arr, b_arr = _load_weight_m(trained_model_dir)
            result = test_each_subject(w_arr, b_arr, subj_idx)
            y_hat.append(result[0])
            y_lab.append(result[1])
            print("y_hat shape:", result[0].shape)
            print("y_lab shape:", result[1].shape)
        print(">> y_hat_all shape:", np.vstack(y_hat).shape)
        print(">> y_lab_all shape:", np.vstack(y_lab).shape)
        print_summary(np.vstack(y_hat), np.vstack(y_lab), log_dir=save_path + "/test.txt")

    end_time = datetime.now()
    elapse = end_time - start_time
    print("=" * 55)
    print(">>>>>> elapsed time: " + str(elapse))
    print("=" * 55)
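# The _load_weight_* helpers above collect one weight column and one bias per
# action unit and stack them into a single multi-output linear layer. A minimal
# NumPy sketch of that stacking; feat_dim and the random values are placeholders
# for the checkpointed 'model/w1:0' and 'model/b1:0' tensors.
import numpy as np

feat_dim, num_au = 500, 8
w_arr, b_arr = None, None
for au in range(num_au):
    w = np.random.randn(feat_dim, 1)   # stand-in for sess.run('model/w1:0')
    b = np.random.randn(1, 1)          # stand-in for sess.run('model/b1:0')
    if w_arr is None:
        w_arr, b_arr = w, b
    else:
        w_arr = np.hstack((w_arr, w))  # grow along the output (AU) axis
        b_arr = np.vstack((b_arr, b))  # one bias row per AU
print(w_arr.shape, b_arr.shape)        # (500, 8) (8, 1)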
def main(): if FLAGS.datasource == 'sinusoid': if FLAGS.train: test_num_updates = 5 else: test_num_updates = 10 else: if FLAGS.datasource == 'miniimagenet': if FLAGS.train == True: test_num_updates = 1 # eval on at least one update during training else: test_num_updates = TEST_NUM_UPDATES else: test_num_updates = 10 if FLAGS.train == False: orig_meta_batch_size = FLAGS.meta_batch_size # always use meta batch size of 1 when testing. FLAGS.meta_batch_size = 1 if FLAGS.datasource == 'sinusoid': data_generator = DataGenerator(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size) else: if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet': assert FLAGS.meta_batch_size == 1 assert FLAGS.update_batch_size == 1 data_generator = DataGenerator( 1, FLAGS.meta_batch_size) # only use one datapoint, else: if FLAGS.datasource == 'miniimagenet': # TODO - use 15 val examples for imagenet? if FLAGS.train: # data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory data_generator = DataGenerator( FLAGS.update_batch_size * 2, FLAGS.meta_batch_size ) # only use one datapoint for testing to save memory else: data_generator = DataGenerator( FLAGS.update_batch_size * 2, FLAGS.meta_batch_size ) # only use one datapoint for testing to save memory else: data_generator = DataGenerator( FLAGS.update_batch_size * 2, FLAGS.meta_batch_size ) # only use one datapoint for testing to save memory dim_output = data_generator.dim_output if FLAGS.baseline == 'oracle': assert FLAGS.datasource == 'sinusoid' dim_input = 3 FLAGS.pretrain_iterations += FLAGS.metatrain_iterations FLAGS.metatrain_iterations = 0 else: dim_input = data_generator.dim_input if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot': tf_data_load = True num_classes = data_generator.num_classes if FLAGS.train: # only construct training model if needed random.seed(5) image_tensor, label_tensor = data_generator.make_data_tensor() inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) import tensorflow_hub as hub augmentation_module = hub.Module( 'https://tfhub.dev/google/image_augmentation/nas_cifar/1', name='am1') augmentation_module2 = hub.Module( 'https://tfhub.dev/google/image_augmentation/flipx_crop_rotate_color/1', name='am2') meta_batch_size = inputa.get_shape()[0] dim = inputa.get_shape()[1] inputb = tf.reshape(inputa, (meta_batch_size, dim, 84, 84, 3)) result = list() for i in range(meta_batch_size): images = augmentation_module( { 'images': inputb[i, ...], 'image_size': (84, 84), 'augmentation': True, }, signature='from_decoded_images') images = augmentation_module2( { 'images': images, 'image_size': (84, 84), 'augmentation': True, }, signature='from_decoded_images') transforms = [ 1, 0, -tf.random.uniform( shape=(), minval=-20, maxval=20, dtype=tf.int32), 0, 1, -tf.random.uniform( shape=(), minval=-20, maxval=20, dtype=tf.int32), 0, 0 ] images = tf.contrib.image.transform(images, transforms) result.append(images) inputb = tf.stack(result) inputb = tf.reshape(inputb, (meta_batch_size, dim, 84 * 84 * 3)) labelb = labela if FLAGS.train: input_tensors = { 'inputa': inputb, 'inputb': inputa, 'labela': labela, 'labelb': labelb } else: 
input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } random.seed(6) image_tensor, label_tensor = data_generator.make_data_tensor( train=False) inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) metaval_input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } else: tf_data_load = False input_tensors = None model = MAML(dim_input, dim_output, test_num_updates=test_num_updates) if FLAGS.train or not tf_data_load: model.construct_model(input_tensors=input_tensors, prefix='metatrain_') if tf_data_load: model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_') model.summ_op = tf.summary.merge_all() saver = loader = tf.train.Saver(tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) sess = tf.InteractiveSession() if FLAGS.train == False: # change to original meta batch size when loading model. FLAGS.meta_batch_size = orig_meta_batch_size if FLAGS.train_update_batch_size == -1: FLAGS.train_update_batch_size = FLAGS.update_batch_size if FLAGS.train_update_lr == -1: FLAGS.train_update_lr = FLAGS.update_lr exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str( FLAGS.meta_batch_size) + '.ubs_' + str( FLAGS.train_update_batch_size) + '.numstep' + str( FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr) if FLAGS.num_filters != 64: exp_string += 'hidden' + str(FLAGS.num_filters) if FLAGS.max_pool: exp_string += 'maxpool' if FLAGS.stop_grad: exp_string += 'stopgrad' if FLAGS.baseline: exp_string += FLAGS.baseline if FLAGS.norm == 'batch_norm': exp_string += 'batchnorm' elif FLAGS.norm == 'layer_norm': exp_string += 'layernorm' elif FLAGS.norm == 'None': exp_string += 'nonorm' else: print('Norm setting not recognized.') resume_itr = 0 model_file = None tf.global_variables_initializer().run() tf.train.start_queue_runners() if FLAGS.resume or not FLAGS.train: model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) if FLAGS.test_iter > 0: model_file = model_file[:model_file.index('model' )] + 'model' + str( FLAGS.test_iter) if model_file: ind1 = model_file.index('model') resume_itr = int(model_file[ind1 + 5:]) print("Restoring model weights from " + model_file) saver.restore(sess, model_file) if FLAGS.train: train(model, saver, sess, exp_string, data_generator, resume_itr) else: test(model, saver, sess, exp_string, data_generator, test_num_updates)
def main(): if FLAGS.train: test_num_updates = 20 elif FLAGS.from_scratch: test_num_updates = 200 else: test_num_updates = 50 if FLAGS.train == False: orig_meta_batch_size = FLAGS.meta_batch_size # always use meta batch size of 1 when testing. FLAGS.meta_batch_size = 1 sess = tf.InteractiveSession() if not FLAGS.dataset == 'imagenet': data_generator = DataGenerator(FLAGS.inner_update_batch_size_train + FLAGS.outer_update_batch_size, FLAGS.inner_update_batch_size_val + FLAGS.outer_update_batch_size, FLAGS.meta_batch_size) else: data_generator = DataGeneratorImageNet(FLAGS.inner_update_batch_size_train + FLAGS.outer_update_batch_size, FLAGS.inner_update_batch_size_val + FLAGS.outer_update_batch_size, FLAGS.meta_batch_size) dim_output_train = data_generator.dim_output_train dim_output_val = data_generator.dim_output_val dim_input = data_generator.dim_input tf_data_load = True num_classes_train = data_generator.num_classes_train num_classes_val = data_generator.num_classes_val if FLAGS.train: # only construct training model if needed random.seed(5) image_tensor, label_tensor = data_generator.make_data_tensor() inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes_train*FLAGS.inner_update_batch_size_train, -1]) inputb = tf.slice(image_tensor, [0,num_classes_train*FLAGS.inner_update_batch_size_train, 0], [-1,-1,-1]) labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes_train*FLAGS.inner_update_batch_size_train, -1]) labelb = tf.slice(label_tensor, [0,num_classes_train*FLAGS.inner_update_batch_size_train, 0], [-1,-1,-1]) input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb} random.seed(6) image_tensor, label_tensor = data_generator.make_data_tensor(train=False) inputa = tf.slice(image_tensor, [0,0,0], [-1,num_classes_val*FLAGS.inner_update_batch_size_val, -1]) inputb = tf.slice(image_tensor, [0,num_classes_val*FLAGS.inner_update_batch_size_val, 0], [-1,-1,-1]) labela = tf.slice(label_tensor, [0,0,0], [-1,num_classes_val*FLAGS.inner_update_batch_size_val, -1]) labelb = tf.slice(label_tensor, [0,num_classes_val*FLAGS.inner_update_batch_size_val, 0], [-1,-1,-1]) metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb} model = MAML(dim_input, dim_output_train, dim_output_val, test_num_updates=test_num_updates) if FLAGS.train or not tf_data_load: model.construct_model(input_tensors=input_tensors, prefix='metatrain_') if tf_data_load: model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_') model.summ_op = tf.summary.merge_all() saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) if FLAGS.debug: sess = tf_debug.LocalCLIDebugWrapperSession(sess) if FLAGS.train == False: # change to original meta batch size when loading model. 
FLAGS.meta_batch_size = orig_meta_batch_size if FLAGS.log_inner_update_batch_size_val == -1: FLAGS.log_inner_update_batch_size_val = FLAGS.inner_update_batch_size_val if FLAGS.train_update_lr == -1: FLAGS.train_update_lr = FLAGS.update_lr exp_string = '' exp_string += '.nu_' + str(FLAGS.num_updates) + '.ilr_' + str(FLAGS.train_update_lr) if FLAGS.meta_lr != 0.001: exp_string += '.olr_' + str(FLAGS.meta_lr) if FLAGS.mt_mode != 'gtgt': if FLAGS.partition_algorithm == 'hyperplanes': exp_string += '.m_' + str(FLAGS.margin) if FLAGS.partition_algorithm == 'kmeans' or FLAGS.partition_algorithm == 'kmodes': exp_string += '.k_' + str(FLAGS.num_clusters) exp_string += '.p_' + str(FLAGS.num_partitions) if FLAGS.scaled_encodings and FLAGS.num_partitions != 1: exp_string += '.scaled' if FLAGS.mt_mode == 'encenc': exp_string += '.ned_' + str(FLAGS.num_encoding_dims) elif FLAGS.mt_mode == 'semi': exp_string += '.pgtgt_' + str(FLAGS.p_gtgt) exp_string += '.mt_' + FLAGS.mt_mode exp_string += '.mbs_' + str(FLAGS.meta_batch_size) + \ '.nct_' + str(FLAGS.num_classes_train) + \ '.iubst_' + str(FLAGS.inner_update_batch_size_train) + \ '.iubsv_' + str(FLAGS.log_inner_update_batch_size_val) + \ '.oubs' + str(FLAGS.outer_update_batch_size) exp_string = exp_string[1:] # get rid of leading period if FLAGS.on_encodings: exp_string += '.onenc' exp_string += '.nhl_' + str(FLAGS.num_hidden_layers) if FLAGS.num_filters != 64: exp_string += '.hidden' + str(FLAGS.num_filters) if FLAGS.max_pool: exp_string += '.maxpool' if FLAGS.stop_grad: exp_string += '.stopgrad' if FLAGS.norm == 'batch_norm': exp_string += '.batchnorm' elif FLAGS.norm == 'layer_norm': exp_string += '.layernorm' elif FLAGS.norm == 'None': exp_string += '.nonorm' else: print('Norm setting not recognized.') if FLAGS.resnet: exp_string += '.res{}parts{}'.format(FLAGS.num_res_blocks, FLAGS.num_parts_per_res_block) if FLAGS.miniimagenet_only: exp_string += '.mini' if FLAGS.suffix != '': exp_string += '.' + FLAGS.suffix resume_itr = 0 model_file = None tf.global_variables_initializer().run() print(exp_string) if FLAGS.resume or not FLAGS.train: model_file = tf.train.latest_checkpoint(logdir + '/' + exp_string) if FLAGS.test_iter > 0: model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter) if model_file: ind1 = model_file.index('model') resume_itr = int(model_file[ind1+5:]) print("Restoring model weights from " + model_file) saver.restore(sess, model_file) else: print("No checkpoint found") if FLAGS.from_scratch: exp_string = '' if FLAGS.from_scratch and not os.path.isdir(logdir): os.makedirs(logdir) if FLAGS.train: train(model, saver, sess, exp_string, data_generator, resume_itr) else: test(model, saver, sess, exp_string, data_generator, test_num_updates)
def main():
    training = not args.test
    kshot = 1
    kquery = 15
    nway = 5
    meta_batchsz = 4
    K = 5

    # kshot + kquery images per category, nway categories, meta_batchsz tasks.
    db = DataGenerator(nway, kshot, kquery, meta_batchsz, 200000)

    if training:  # only construct the training pipeline if needed
        # get the tensors
        # image_tensor: [4, 80, 84*84*3]
        # label_tensor: [4, 80, 5]
        image_tensor, label_tensor = db.make_data_tensor(training=True)

        # NOTICE: the image order in the 80 images should now look like:
        # [label2, label1, label3, label0, label4, then repeated 15 times, namely one task]
        # support_x: [4, 1*5, 84*84*3]
        # query_x:   [4, 15*5, 84*84*3]
        # support_y: [4, 1*5, 5]
        # query_y:   [4, 15*5, 5]
        support_x = tf.slice(image_tensor, [0, 0, 0], [-1, nway * kshot, -1], name='support_x')
        query_x = tf.slice(image_tensor, [0, nway * kshot, 0], [-1, -1, -1], name='query_x')
        support_y = tf.slice(label_tensor, [0, 0, 0], [-1, nway * kshot, -1], name='support_y')
        query_y = tf.slice(label_tensor, [0, nway * kshot, 0], [-1, -1, -1], name='query_y')

    # construct test tensors
    image_tensor, label_tensor = db.make_data_tensor(training=False)
    support_x_test = tf.slice(image_tensor, [0, 0, 0], [-1, nway * kshot, -1], name='support_x_test')
    query_x_test = tf.slice(image_tensor, [0, nway * kshot, 0], [-1, -1, -1], name='query_x_test')
    support_y_test = tf.slice(label_tensor, [0, 0, 0], [-1, nway * kshot, -1], name='support_y_test')
    query_y_test = tf.slice(label_tensor, [0, nway * kshot, 0], [-1, -1, -1], name='query_y_test')

    # 1. construct MAML model
    model = MAML(84, 3, 5)

    # construct metatrain_ and metaval_ graphs
    if training:
        model.build(support_x, support_y, query_x, query_y, K, meta_batchsz, mode='train')
        model.build(support_x_test, support_y_test, query_x_test, query_y_test, K, meta_batchsz, mode='eval')
    else:
        model.build(support_x_test, support_y_test, query_x_test, query_y_test, K + 5, meta_batchsz, mode='test')
    model.summ_op = tf.summary.merge_all()

    all_vars = filter(lambda x: 'meta_optim' not in x.name, tf.trainable_variables())
    for p in all_vars:
        print(p)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)

    # tf.global_variables() saves moving_mean and moving_variance of batch norm;
    # tf.trainable_variables() does NOT include moving_mean and moving_variance.
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)

    # initialize variables under the interactive session
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if os.path.exists(os.path.join('ckpt', 'checkpoint')):
        # always load the checkpoint, for both train and test
        model_file = tf.train.latest_checkpoint('ckpt')
        print("Restoring model weights from ", model_file)
        saver.restore(sess, model_file)

    if training:
        train(model, saver, sess)
    else:
        test(model, sess)
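# A note on the saver choice above: batch-norm moving statistics are not
# trainable, so a Saver built from tf.trainable_variables() would silently drop
# them, while tf.global_variables() keeps them. A small TF 1.x sketch of the
# difference; the layer here is hypothetical and only exists to create the variables.
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 8])
h = tf.layers.batch_normalization(tf.layers.dense(x, 16), training=True)

trainable = {v.name for v in tf.trainable_variables()}
moving_stats = sorted({v.name for v in tf.global_variables()} - trainable)
print(moving_stats)  # the .../moving_mean:0 and .../moving_variance:0 variables

saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)  # checkpoints include the moving stats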
def main():
    temp = FLAGS.update_batch_size
    temp2 = FLAGS.meta_batch_size
    FLAGS.update_batch_size = 1
    FLAGS.meta_batch_size = 1

    data_generator = DataGenerator(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size)
    dim_output = data_generator.num_classes
    dim_input = data_generator.dim_input

    if FLAGS.train:  # only construct training model if needed
        # image_tensor, label_tensor = data_generator.make_data_tensor()
        # inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1])
        #     # (number of tasks, NK, all dims) = (meta_batch_size, NK, 2000)
        #     # Here NK means N examples stacked K times; each group of N holds the labels 0..N-1 once, in random order.
        # inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1])
        #     # (number of tasks, NK, all dims) = (meta_batch_size, NK, 2000)
        # labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1])
        #     # (number of tasks, NK, all labels) = (meta_batch_size, NK, N)
        # labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1])
        #     # (number of tasks, NK, all labels) = (meta_batch_size, NK, N)
        inputa, inputb, labela, labelb = data_generator.make_data_tensor()
        metatrain_input_tensors = {'inputa': inputa, 'inputb': inputb,
                                   'labela': labela, 'labelb': labelb}

        inputa, inputb, labela, labelb = data_generator.make_data_tensor(train=False)
        metaval_input_tensors = {'inputa': inputa, 'inputb': inputb,
                                 'labela': labela, 'labelb': labelb}

    pred_weights = data_generator.pred_weights

    model = MAML(dim_input, dim_output)
    if FLAGS.train:
        model.construct_model(input_tensors=metatrain_input_tensors, prefix='metatrain_')
    else:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES),
                                    max_to_keep=20)
    sess = tf.InteractiveSession()

    FLAGS.update_batch_size = temp
    FLAGS.meta_batch_size = temp2

    trained_model_dir = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(FLAGS.meta_batch_size) \
        + '.ubs_' + str(FLAGS.update_batch_size) + '.numstep' + str(FLAGS.num_updates) \
        + '.updatelr' + str(FLAGS.update_lr) + '.metalr' + str(FLAGS.meta_lr) \
        + '.initweight' + str(FLAGS.init_weight) \
        + '/sbjt14:13.ubs_' + str(FLAGS.update_batch_size) + '.numstep5.updatelr0.005.metalr0.005'
    # if FLAGS.stop_grad:
    #     trained_model_dir += 'stopgrad'
    # if FLAGS.baseline:
    #     trained_model_dir += FLAGS.baseline
    # else:
    #     print('Norm setting not recognized.')

    resume_itr = 0
    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    # load the globally trained weights and report their norms
    model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + trained_model_dir)
    print(">>> kshot: ", FLAGS.update_batch_size)
    print(">>>> train_test model dir: ", model_file)
    saver.restore(sess, model_file)
    w = sess.run('model/w1:0')
    print("global norm of w: ", np.linalg.norm(w))
    b = sess.run('model/b1:0')
    print("global norm of b: ", np.linalg.norm(b))

    # load each subject's locally adapted weights and report their norms
    model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + trained_model_dir + '/local')
    for i in range(13):
        model_file = model_file[:model_file.index('subject')] + 'subject' + str(i)
        print(">>>> model_file_local: ", model_file)
        saver.restore(sess, model_file)
        w = sess.run('model/w1:0')
        print("subject ", i, ", norm of w: ", np.linalg.norm(w))
        b = sess.run('model/b1:0')
        print("subject ", i, ", norm of b: ", np.linalg.norm(b))
def construct_model(self): self.sess = tf.InteractiveSession() if FLAGS.train == False: orig_meta_batch_size = FLAGS.meta_batch_size # always use meta batch size of 1 when testing. FLAGS.meta_batch_size = 1 if FLAGS.datasource in ['sinusoid', 'mixture']: data_generator = DataGenerator( FLAGS.update_batch_size + FLAGS.update_batch_size_eval, FLAGS.meta_batch_size) else: if FLAGS.metatrain_iterations == 0 and FLAGS.datasource in [ 'miniimagenet', 'multidataset' ]: assert FLAGS.meta_batch_size == 1 assert FLAGS.update_batch_size == 1 data_generator = DataGenerator( 1, FLAGS.meta_batch_size) # only use one datapoint, else: if FLAGS.datasource in [ 'miniimagenet', 'multidataset' ]: # TODO - use 15 val examples for imagenet? if FLAGS.train: data_generator = DataGenerator( FLAGS.update_batch_size + 15, FLAGS.meta_batch_size ) # only use one datapoint for testing to save memory else: data_generator = DataGenerator( FLAGS.update_batch_size * 2, FLAGS.meta_batch_size ) # only use one datapoint for testing to save memory else: data_generator = DataGenerator( FLAGS.update_batch_size * 2, FLAGS.meta_batch_size ) # only use one datapoint for testing to save memory dim_output = data_generator.dim_output dim_input = data_generator.dim_input if FLAGS.datasource in ['miniimagenet', 'omniglot', 'multidataset']: tf_data_load = True num_classes = data_generator.num_classes if FLAGS.train: # only construct training model if needed random.seed(5) if FLAGS.datasource in ['miniimagenet', 'omniglot']: image_tensor, label_tensor = data_generator.make_data_tensor( ) elif FLAGS.datasource == 'multidataset': image_tensor, label_tensor = data_generator.make_data_tensor_multidataset( sel_num=self.clusters, train=True) inputa = tf.slice( image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb = tf.slice( image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice( label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice( label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } random.seed(6) if FLAGS.datasource in ['miniimagenet', 'omniglot']: image_tensor, label_tensor = data_generator.make_data_tensor( train=False) elif FLAGS.datasource == 'multidataset': image_tensor, label_tensor = data_generator.make_data_tensor_multidataset( sel_num=self.clusters, train=False) inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) metaval_input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } else: tf_data_load = False input_tensors = None model = MAML(self.sess, dim_input, dim_output, test_num_updates=self.test_num_updates) model.cluster_layer_0 = self.clusters if FLAGS.train or not tf_data_load: model.construct_model(input_tensors=input_tensors, prefix='metatrain_') if tf_data_load: model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_') model.summ_op = tf.summary.merge_all() saver = loader = tf.train.Saver(tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) if FLAGS.train == False: # change to original meta batch size when loading model. 
FLAGS.meta_batch_size = orig_meta_batch_size if FLAGS.train_update_batch_size == -1: FLAGS.train_update_batch_size = FLAGS.update_batch_size if FLAGS.train_update_lr == -1: FLAGS.train_update_lr = FLAGS.update_lr return model, saver, data_generator
def run(d): #restore old dynamics model train_now = True restore_previous = False # old_exp_name = 'MAML_roach/terrain_types_turf_model_on_turf' # old_model_num = 0 previous_dynamics_model = "/home/anagabandi/rllab-private/data/local/experiment/MAML_roach_copy/Tuesday_optimization/carpet_on_carpet/model_epoch10" num_steps_per_rollout= 140 desired_shape_for_rollout = "left" #straight, left, right, circle_left, zigzag, figure8 save_rollout_run_num = 0 rollout_save_filename= desired_shape_for_rollout + str(save_rollout_run_num) #settings cheaty_training = False use_one_hot = False #True use_camera = False #True playback_mode = False state_representation = "exclude_x_y" #["exclude_x_y", "all"] #don't change much default_addrs= [b'\x00\x01'] use_pid_mode = True slow_pid_mode = True visualize_rviz=True #turning this off can make things go faster visualize_True = True visualize_False = False noise_True = True noise_False = False make_aggregated_dataset_noisy = True make_training_dataset_noisy = True perform_forwardsim_for_vis= True print_minimal=False noiseToSignal = 0 if(make_training_dataset_noisy): noiseToSignal = 0.01 #datatypes tf_datatype= tf.float32 ############################# CHANGE BACK! np_datatype= np.float32 #motor limits left_min = 1200 right_min = 1200 left_max = 2000 right_max = 2000 if(use_pid_mode): if(slow_pid_mode): left_min = 2*math.pow(2,16)*0.001 right_min = 2*math.pow(2,16)*0.001 left_max = 9*math.pow(2,16)*0.001 right_max = 9*math.pow(2,16)*0.001 else: #this hasnt been tested yet left_min = 4*math.pow(2,16)*0.001 right_min = 4*math.pow(2,16)*0.001 left_max = 12*math.pow(2,16)*0.001 right_max = 12*math.pow(2,16)*0.001 #vars from config config = d['config'] curr_agg_iter = d['curr_agg_iter'] save_dir = '/media/anagabandi/f1e71f04-dc4b-4434-ae4c-fcb16447d5b3/' + d['exp_name'] ############################################################################################ CHANGE BACK! 
#save_dir = '/media/anagabandi/f1e71f04-dc4b-4434-ae4c-fcb16447d5b3/' + d['exp_name'] print("\n\nSAVING EVERYTHING TO: ", save_dir) #make directories if not os.path.exists(save_dir + '/saved_rollouts'): os.makedirs(save_dir + '/saved_rollouts') if not os.path.exists(save_dir + '/saved_rollouts/'+rollout_save_filename+ '_aggIter' +str(curr_agg_iter)): os.makedirs(save_dir + '/saved_rollouts/'+rollout_save_filename+ '_aggIter' +str(curr_agg_iter)) ###################################### ######## GET TRAINING DATA ########### ###################################### print("\n\nCURR AGGREGATION ITER: ", curr_agg_iter) # Training data dataX=[] dataX_full=[] #this is just for your personal use for forwardsim (for debugging) dataY=[] dataZ=[] # Validation data dataX_val = [] dataX_full_val=[] dataY_val=[] dataZ_val=[] agg_itr = 0 training_ratio = config['training']['training_ratio'] for agg_itr in range(curr_agg_iter+1): #getDataFromDisk should give (tasks, rollouts from that task, each rollout has its points) dataX_curr, dataY_curr, dataZ_curr, dataX_curr_full = getDataFromDisk(agg_itr, config['experiment_type'], use_one_hot, use_camera, cheaty_training, state_representation, config['training']) if(agg_itr==0): for i in range(len(dataX_curr)): taski_num_rollout = len(dataX_curr[i]) print("taski_num_rollout: ", taski_num_rollout) dataX.append(dataX_curr[0][:int(taski_num_rollout*training_ratio)]) dataX_full.append(dataX_curr_full[0][:int(taski_num_rollout*training_ratio)]) dataY.append(dataY_curr[0][:int(taski_num_rollout*training_ratio)]) dataZ.append(dataZ_curr[0][:int(taski_num_rollout*training_ratio)]) dataX_val.append(dataX_curr[0][int(taski_num_rollout*training_ratio):]) dataX_full_val.append(dataX_curr_full[0][int(taski_num_rollout*training_ratio):]) dataY_val.append(dataY_curr[0][int(taski_num_rollout*training_ratio):]) dataZ_val.append(dataZ_curr[0][int(taski_num_rollout*training_ratio):]) #IPython.embed() else: #combine these rollouts w previous rollouts, so everything is still organized by task for task_num in range(len(dataX)): for rollout_num in range(len(dataX_curr[task_num])): dataX[task_num].append(dataX_curr[task_num][rollout_num]) dataY[task_num].append(dataY_curr[task_num][rollout_num]) dataZ[task_num].append(dataZ_curr[task_num][rollout_num]) dataX_full[task_num].append(dataX_curr_full[task_num][rollout_num]) # Do validation for this too! 
total_num_data = len(dataX)*len(dataX[0])*len(dataX[0][0]) # numSteps = tasks * rollouts * steps print("\n\nTotal number of data points: ", total_num_data) #return ## concatenate state and action --> inputs outputs = copy.deepcopy(dataZ) inputs = copy.deepcopy(dataX) #IPython.embed() inputs_val = np.append(np.array(dataX_val), np.array(dataY_val), axis = 3) outputs_val = np.array(dataZ_val) #IPython.embed() # check shapes for task_num in range(len(dataX)): for rollout_num in range (len(dataX[task_num])): #dataX[task_num][rollout_num] (steps x s_dim) #dataY[task_num][rollout_num] (steps x a_dim) inputs[task_num][rollout_num] = np.concatenate([dataX[task_num][rollout_num], dataY[task_num][rollout_num]], axis=1) #inputs should now be (tasks, rollouts from that task, [s,a]) #outputs should now be (tasks, rollouts from that task, [ds]) inputSize = inputs[0][0].shape[1] outputSize = outputs[0][0].shape[1] print("\n\nDimensions:") print("states: ", dataX[0][0].shape[1]) print("actions: ", dataY[0][0].shape[1]) print("inputs to NN: ", inputSize) print("outputs of NN: ", outputSize) #calc mean/std on full dataset if config["model"]["nonlinearity"] == "tanh": # Do you scale inputs to [-1, 1] and then standardize outputs? #IPython.embed() inputs_array = np.array(inputs) mean_inp = (inputs_array.max() + inputs_array.min())/2.0 std_inp = inputs_array.max() - mean_inp mean_inp = mean_inp*np.ones((1, inputs_array.shape[3])) std_inp = std_inp*np.ones((1, inputs_array.shape[3])) #IPython.embed() mean_outp = np.expand_dims(np.mean(outputs,axis=(0,1,2)), axis=0) std_outp = np.expand_dims(np.std(outputs,axis=(0,1,2)), axis=0) #IPython.embed() # HOw should I expand_dims? # check that after the operation, all inputs do lie in this range elif config["model"]["nonlinearity"] == "sigmoid": # Do you scale inputs to [0, 1] and then standardize outputs? #IPython.embed() inputs_array = np.array(inputs) mean_inp = inputs_array.min() std_inp = inputs_array.max() - mean_inp mean_inp = mean_inp*np.ones((1, inputs_array.shape[3])) std_inp = std_inp*np.ones((1, inputs_array.shape[3])) #IPython.embed() mean_outp = np.expand_dims(np.mean(outputs,axis=(0,1,2)), axis=0) std_outp = np.expand_dims(np.std(outputs,axis=(0,1,2)), axis=0) #IPython.embed() # HOw should I expand_dims? # check that after the operation, all inputs do lie in this range else: # for all the relu variants mean_inp = np.expand_dims(np.mean(inputs,axis=(0,1,2)), axis=0) std_inp = np.expand_dims(np.std(inputs,axis=(0,1,2)), axis=0) mean_outp = np.expand_dims(np.mean(outputs,axis=(0,1,2)), axis=0) std_outp = np.expand_dims(np.std(outputs,axis=(0,1,2)), axis=0) print("\n\nCalulated means and stds... 
", mean_inp.shape, std_inp.shape, mean_outp.shape, std_outp.shape, "\n\n") ########################################################### ## CREATE regressor, policy, data generator, maml model ########################################################### # create regressor (NN dynamics model) regressor = DeterministicMLPRegressor(inputSize, outputSize, dim_obs=outputSize, tf_datatype=tf_datatype, seed=config['seed'],weight_initializer=config['training']['weight_initializer'], **config['model']) # create policy (MPC controller) policy = Policy(regressor, inputSize, outputSize, left_min, right_min, left_max, right_max, state_representation=state_representation, visualize_rviz=config['roach']['visualize_rviz'], x_index=config['roach']['x_index'], y_index=config['roach']['y_index'], yaw_cos_index=config['roach']['yaw_cos_index'], yaw_sin_index=config['roach']['yaw_sin_index'], **config['policy']) # create MAML model # note: this also constructs the actual regressor network/weights model = MAML(regressor, inputSize, outputSize, config=config['training']) model.construct_model(input_tensors=None, prefix='metatrain_') model.summ_op = tf.summary.merge_all() # GPU config proto gpu_device = 0 gpu_frac = 0.3 #0.3 os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_device) gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_frac) config_2 = tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False, allow_soft_placement=True, inter_op_parallelism_threads=1, intra_op_parallelism_threads=1) # saving saver = tf.train.Saver(max_to_keep=10) sess = tf.InteractiveSession(config=config_2) # initialize tensorflow vars tf.global_variables_initializer().run() tf.train.start_queue_runners() # set the mean/std of regressor according to mean/std of the data we have so far regressor.update_params_data_dist(mean_inp, std_inp, mean_outp, std_outp, len(inputs[0])*len(inputs[0][0])*4) ########################################################### ## TRAIN THE DYNAMICS MODEL ########################################################### #train on the given full dataset, for max_epochs if train_now: if(restore_previous): print("\n\nRESTORING PREVIOUS DYNAMICS MODEL FROM ", previous_dynamics_model, " AND CONTINUING TRAINING...\n\n") saver.restore(sess, previous_dynamics_model) """trainable_vars = tf.trainable_variables() weights = sess.run(trainable_vars) with open(osp.join(osp.dirname(previous_dynamics_model), "weights.pickle"), "wb") as output_file: pickle.dump(weights, output_file)""" #IPython.embed() # np.save(save_dir + "/inputs.npy", inputs) # np.save(save_dir + "/outputs.npy", outputs) # # mean_inp.shape, std_inp.shape, mean_outp.shape, std_outp.shape # np.save(save_dir + "/mean_inp.npy", mean_inp) # np.save(save_dir + "/std_inp.npy", std_inp) # np.save(save_dir + "/mean_outp.npy", mean_outp) # np.save(save_dir + "/std_outp.npy", std_outp) train(inputs, outputs, curr_agg_iter, model, saver, sess, config, inputs_val, outputs_val) else: print("\n\nRESTORING A DYNAMICS MODEL FROM ", previous_dynamics_model) saver.restore(sess, previous_dynamics_model) #IPython.embed() return #IPython.embed() predicted_traj = regressor.do_forward_sim(dataX_full[0][0][27:45], dataY[0][0][27:45], state_representation) #np.save(save_dir + '/forwardsim_true.npy', dataX_full[0][7][27:45]) #np.save(save_dir + '/forwardsim_pred.npy', predicted_traj) ########################################################### ## RUN THE MPC CONTROLLER ########################################################### #create controller node controller_node = 
GBAC_Controller(sess=sess, policy=policy, model=model, state_representation=state_representation, use_pid_mode=use_pid_mode, default_addrs=default_addrs, update_batch_size=config['training']['update_batch_size'], **config['roach']) #do 1 rollout print("\n\n\nPAUSING... right before a controller run... RESET THE ROBOT TO A GOOD LOCATION BEFORE CONTINUING...") #IPython.embed() resulting_x, selected_u, desired_seq, list_robot_info, list_mocap_info, old_saving_format_dict = controller_node.run(num_steps_per_rollout, desired_shape_for_rollout) #where to save this rollout pathStartName = save_dir + '/saved_rollouts/'+rollout_save_filename+ '_aggIter' +str(curr_agg_iter) print("\n\n************** TRYING TO SAVE EVERYTHING TO: ", pathStartName) #save the result of the run np.save(pathStartName + '/oldFormat_actions.npy', old_saving_format_dict['actions_taken']) np.save(pathStartName + '/oldFormat_desired.npy', old_saving_format_dict['desired_states']) np.save(pathStartName + '/oldFormat_executed.npy', old_saving_format_dict['traj_taken']) np.save(pathStartName + '/oldFormat_perp.npy', old_saving_format_dict['save_perp_dist']) np.save(pathStartName + '/oldFormat_forward.npy', old_saving_format_dict['save_forward_dist']) np.save(pathStartName + '/oldFormat_oldforward.npy', old_saving_format_dict['saved_old_forward_dist']) np.save(pathStartName + '/oldFormat_movedtonext.npy', old_saving_format_dict['save_moved_to_next']) np.save(pathStartName + '/oldFormat_desheading.npy', old_saving_format_dict['save_desired_heading']) np.save(pathStartName + '/oldFormat_currheading.npy', old_saving_format_dict['save_curr_heading']) yaml.dump(config, open(osp.join(pathStartName, 'saved_config.yaml'), 'w')) #save the result of the run np.save(pathStartName + '/actions.npy', selected_u) np.save(pathStartName + '/states.npy', resulting_x) np.save(pathStartName + '/desired.npy', desired_seq) pickle.dump(list_robot_info,open(pathStartName + '/robotInfo.obj','w')) pickle.dump(list_mocap_info,open(pathStartName + '/mocapInfo.obj','w')) #stop roach print("killing robot") controller_node.kill_robot() return
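# --- Illustrative sketch (not part of the original script): how the per-nonlinearity
# input scaling computed earlier in this function behaves on a toy array. For "tanh",
# inputs are shifted/scaled by the midpoint and half-range so they land in [-1, 1];
# for the relu variants, ordinary z-scoring is used.
import numpy as np

def input_scaling_stats(inputs_array, nonlinearity):
    if nonlinearity == "tanh":
        mean_inp = (inputs_array.max() + inputs_array.min()) / 2.0
        std_inp = inputs_array.max() - mean_inp
    else:
        mean_inp = np.mean(inputs_array)
        std_inp = np.std(inputs_array)
    return mean_inp, std_inp

toy = np.array([[0.0, 2.0], [4.0, 6.0]])
m, s = input_scaling_stats(toy, "tanh")   # m = 3.0, s = 3.0
scaled = (toy - m) / s
assert scaled.min() == -1.0 and scaled.max() == 1.0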
def main(): # test_num_updates ########################################################## if FLAGS.datasource == 'sinusoid': if FLAGS.train: test_num_updates = 5 else: test_num_updates = 10 else: if FLAGS.datasource == 'miniimagenet': if FLAGS.train == True: test_num_updates = 1 # eval on at least one update during training else: test_num_updates = 10 else: test_num_updates = 10 ########################################################## if FLAGS.train == False: orig_meta_batch_size = FLAGS.meta_batch_size # always use meta batch size of 1 when testing. FLAGS.meta_batch_size = 1 if FLAGS.datasource == 'sinusoid': "If this is the sinusoid regression task," "construct FLAGS.meta_batch_size sinusoid tasks and draw FLAGS.update_batch_size * 2 samples from each" "Defaults: FLAGS.update_batch_size=5, FLAGS.meta_batch_size=25" "so with default settings the sinusoid data generator produces batches of shape [25, 10, 1]" data_generator = DataGenerator(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size) else: "If this is not the sinusoid regression task," if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet': "If the number of meta-training iterations is 0 and the task is miniimagenet," "assert FLAGS.meta_batch_size == 1, i.e. only one task per meta batch" assert FLAGS.meta_batch_size == 1 "assert FLAGS.update_batch_size == 1, i.e. only one sample per class" assert FLAGS.update_batch_size == 1 "construct a data generator that samples one datapoint per class" data_generator = DataGenerator( 1, FLAGS.meta_batch_size) # only use one datapoint, else: if FLAGS.datasource == 'miniimagenet': # TODO - use 15 val examples for imagenet? "If the task is miniimagenet," if FLAGS.train: "If training," "construct a generator over FLAGS.meta_batch_size tasks, sampling FLAGS.update_batch_size + 15 examples per class" "Defaults: FLAGS.update_batch_size=5, FLAGS.meta_batch_size=25" "so with default settings the miniimagenet data generator produces batches of shape [25, 5+15, 84x84x3]" data_generator = DataGenerator( FLAGS.update_batch_size + 15, FLAGS.meta_batch_size ) # only use one datapoint for testing to save memory else: "If not training," "with default settings the miniimagenet data generator produces batches of shape [25, 5*2, 84x84x3]" data_generator = DataGenerator( FLAGS.update_batch_size * 2, FLAGS.meta_batch_size ) # only use one datapoint for testing to save memory else: "If the task is omniglot," "with default settings the omniglot data generator produces batches of shape [25, 5*2, 28x28]" data_generator = DataGenerator( FLAGS.update_batch_size * 2, FLAGS.meta_batch_size ) # only use one datapoint for testing to save memory "Output dimension of the task data, for:" "sinusoid regression: dim_output=1" "omniglot: dim_output=data_generator.num_classes, which defaults to 1" "miniimagenet: dim_output=data_generator.num_classes, which defaults to 1" # dim_output ############################################## "output dimension of the data" dim_output = data_generator.dim_output # dim_input ############################################## if FLAGS.baseline == 'oracle': "If FLAGS.baseline == 'oracle'," "assert that FLAGS.datasource is the sinusoid task, otherwise fail" assert FLAGS.datasource == 'sinusoid' "set the input dimension to 3" dim_input = 3 "fold the meta-training iterations into the pretraining iterations" FLAGS.pretrain_iterations += FLAGS.metatrain_iterations "and set the meta-training iterations to 0" FLAGS.metatrain_iterations = 0 else: "sinusoid regression: dim_input=1" "omniglot: dim_input=28x28" "miniimagenet: dim_input=84x84x3" dim_input = data_generator.dim_input if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot': "If the task is miniimagenet or omniglot," "build the TensorFlow graph ops for data loading" tf_data_load = True "number of classes" num_classes = data_generator.num_classes if FLAGS.train: # only construct training model if needed "If in the training phase," "seed the RNG so experiments are reproducible" random.seed(5) image_tensor, label_tensor = data_generator.make_data_tensor() inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, 
num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } random.seed(6) image_tensor, label_tensor = data_generator.make_data_tensor( train=False) inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) metaval_input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } else: tf_data_load = False input_tensors = None model = MAML(dim_input, dim_output, test_num_updates=test_num_updates) if FLAGS.train or not tf_data_load: model.construct_model(input_tensors=input_tensors, prefix='metatrain_') if tf_data_load: model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_') model.summ_op = tf.summary.merge_all() saver = loader = tf.train.Saver(tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) sess = tf.InteractiveSession() if FLAGS.train == False: # change to original meta batch size when loading model. FLAGS.meta_batch_size = orig_meta_batch_size if FLAGS.train_update_batch_size == -1: FLAGS.train_update_batch_size = FLAGS.update_batch_size if FLAGS.train_update_lr == -1: FLAGS.train_update_lr = FLAGS.update_lr exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str( FLAGS.meta_batch_size) + '.ubs_' + str( FLAGS.train_update_batch_size) + '.numstep' + str( FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr) if FLAGS.num_filters != 64: exp_string += 'hidden' + str(FLAGS.num_filters) if FLAGS.max_pool: exp_string += 'maxpool' if FLAGS.stop_grad: exp_string += 'stopgrad' if FLAGS.baseline: exp_string += FLAGS.baseline if FLAGS.norm == 'batch_norm': exp_string += 'batchnorm' elif FLAGS.norm == 'layer_norm': exp_string += 'layernorm' elif FLAGS.norm == 'None': exp_string += 'nonorm' else: print('Norm setting not recognized.') resume_itr = 0 model_file = None tf.global_variables_initializer().run() tf.train.start_queue_runners() if FLAGS.resume or not FLAGS.train: model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) if FLAGS.test_iter > 0: model_file = model_file[:model_file.index('model' )] + 'model' + str( FLAGS.test_iter) if model_file: ind1 = model_file.index('model') resume_itr = int(model_file[ind1 + 5:]) print("Restoring model weights from " + model_file) saver.restore(sess, model_file) if FLAGS.train: train(model, saver, sess, exp_string, data_generator, resume_itr) else: test(model, saver, sess, exp_string, data_generator, test_num_updates)
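# --- Illustrative sketch (NumPy stand-in, shapes assumed): the tf.slice calls above cut
# each task's examples along axis 1 of a [meta_batch, num_examples, dim] tensor into a
# support set (inputa/labela: the first num_classes * update_batch_size examples) and a
# query set (inputb/labelb: everything after that index).
import numpy as np

meta_batch, num_classes, update_batch_size = 4, 5, 1
per_class = update_batch_size + 15                  # as in the training branch above
images = np.zeros((meta_batch, num_classes * per_class, 84 * 84 * 3))

n_support = num_classes * update_batch_size
inputa = images[:, :n_support, :]                   # support set, shape (4, 5, 21168)
inputb = images[:, n_support:, :]                   # query set, shape (4, 75, 21168)
assert inputa.shape[1] + inputb.shape[1] == images.shape[1]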
def main(): training = not args.test main_dir = './data/' dataset_name = 'flickr' kshot = 5 meta_batchsz = 4 k = 5 batch_num = 50000 if dataset_name == 'flickr': batch_num = 50000 elif dataset_name == 'wiki': batch_num = 10000 elif dataset_name == 'email': batch_num = 5000 else: batch_num = 10000 db = DataGenerator(main_dir, dataset_name, kshot, meta_batchsz, 50000) if training: node_tensor, label_tensor, data_tensor = db.make_data_tensor( training=True) support_n = tf.slice(node_tensor, [0, 0, 0], [-1, kshot, -1], name='support_n') query_n = tf.slice(node_tensor, [0, kshot, 0], [-1, -1, -1], name='query_n') support_x = tf.slice(data_tensor, [0, 0, 0], [-1, kshot, -1], name='support_x') query_x = tf.slice(data_tensor, [0, kshot, 0], [-1, -1, -1], name='query_x') support_y = tf.slice(label_tensor, [0, 0, 0], [-1, kshot, -1], name='support_y') query_y = tf.slice(label_tensor, [0, kshot, 0], [-1, -1, -1], name='query_y') node_tensor, label_tensor, data_tensor = db.make_data_tensor( training=False) support_n_test = tf.slice(node_tensor, [0, 0, 0], [-1, kshot, -1], name='support_n_test') query_n_test = tf.slice(node_tensor, [0, kshot, 0], [-1, -1, -1], name='query_n_test') support_x_test = tf.slice(data_tensor, [0, 0, 0], [-1, kshot, -1], name='support_x_test') query_x_test = tf.slice(data_tensor, [0, kshot, 0], [-1, -1, -1], name='query_x_test') support_y_test = tf.slice(label_tensor, [0, 0, 0], [-1, kshot, -1], name='support_y_test') query_y_test = tf.slice(label_tensor, [0, kshot, 0], [-1, -1, -1], name='query_y_test') model = MAML(128) model.build(support_n, support_x, support_y, query_n, query_x, query_y, k, meta_batchsz, mode='train') model.build(support_n_test, support_x_test, support_y_test, query_n_test, query_x_test, query_y_test, k, meta_batchsz, mode='test') model.summ_op = tf.summary.merge_all() all_vars = filter(lambda x: 'meta_optim' not in x.name, tf.trainable_variables()) for p in all_vars: print(p) config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.InteractiveSession(config=config) saver = tf.train.Saver(tf.global_variables(), max_to_keep=5) tf.global_variables_initializer().run() tf.train.start_queue_runners() if os.path.exists(os.path.join('ckpt', 'checkpoint')): model_file = tf.train.latest_checkpoint('ckpt') print("Restoring model weights from ", model_file) saver.restore(sess, model_file) train(model, sess, batch_num) test(model, sess, dataset_name)
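# --- Illustrative sketch (plain Python stand-ins, no TF graph needed): the filter(...)
# call above keeps every trainable variable whose name does not contain 'meta_optim'.
# An equivalent list comprehension on fake variable objects:
class FakeVariable:
    def __init__(self, name):
        self.name = name

variables = [FakeVariable('model/w1:0'), FakeVariable('meta_optim/beta1_power:0')]
all_vars = [v for v in variables if 'meta_optim' not in v.name]
assert [v.name for v in all_vars] == ['model/w1:0']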
def main(): if FLAGS.datasource == 'sinusoid': # data source is a sine wave if FLAGS.train: test_num_updates = 5 # at least 5 inner updates during training else: test_num_updates = 10 # at least 10 inner updates during testing else: if FLAGS.datasource == 'miniimagenet': # data source is 'miniimagenet' if FLAGS.train == True: test_num_updates = 1 # at least one update during training else: test_num_updates = 10 else: test_num_updates = 10 if FLAGS.train == False: # when testing orig_meta_batch_size = FLAGS.meta_batch_size # always use a meta batch size of 1 when testing. FLAGS.meta_batch_size = 1 if FLAGS.datasource == 'sinusoid': data_generator = DataGenerator(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size) else: if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet': assert FLAGS.meta_batch_size == 1 assert FLAGS.update_batch_size == 1 data_generator = DataGenerator(1, FLAGS.meta_batch_size) # only use one datapoint, else: if FLAGS.datasource == 'miniimagenet': # TODO - use 15 val examples for imagenet? if FLAGS.train: data_generator = DataGenerator( FLAGS.update_batch_size + 15, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory else: data_generator = DataGenerator( FLAGS.update_batch_size * 2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory else: data_generator = DataGenerator( FLAGS.update_batch_size * 2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory dim_output = data_generator.dim_output if FLAGS.baseline == 'oracle': assert FLAGS.datasource == 'sinusoid' dim_input = 3 FLAGS.pretrain_iterations += FLAGS.metatrain_iterations FLAGS.metatrain_iterations = 0 else: dim_input = data_generator.dim_input if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot': # when the data source is 'miniimagenet' or 'omniglot' tf_data_load = True num_classes = data_generator.num_classes if FLAGS.train: # only construct training model if needed random.seed(5) image_tensor, label_tensor = data_generator.make_data_tensor( ) # load the data inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1 ]) # tf.slice(inputs, begin, size, name) inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } random.seed(6) image_tensor, label_tensor = data_generator.make_data_tensor( train=False) inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) metaval_input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } else: tf_data_load = False input_tensors = None model = MAML(dim_input, dim_output, test_num_updates=test_num_updates) if FLAGS.train or not tf_data_load: model.construct_model(input_tensors=input_tensors, prefix='metatrain_') if tf_data_load: model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_') model.summ_op = tf.summary.merge_all() saver = loader = tf.train.Saver(tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) sess = tf.InteractiveSession() if FLAGS.train == False: # change to original meta batch size when loading model. FLAGS.meta_batch_size = orig_meta_batch_size if FLAGS.train_update_batch_size == -1: FLAGS.train_update_batch_size = FLAGS.update_batch_size if FLAGS.train_update_lr == -1: FLAGS.train_update_lr = 
FLAGS.update_lr exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str( FLAGS.meta_batch_size) + '.ubs_' + str( FLAGS.train_update_batch_size) + '.numstep' + str( FLAGS.num_updates) + '.updatelr' + str( FLAGS.train_update_lr) if FLAGS.num_filters != 64: exp_string += 'hidden' + str(FLAGS.num_filters) if FLAGS.max_pool: exp_string += 'maxpool' if FLAGS.stop_grad: exp_string += 'stopgrad' if FLAGS.baseline: exp_string += FLAGS.baseline if FLAGS.norm == 'batch_norm': exp_string += 'batchnorm' elif FLAGS.norm == 'layer_norm': exp_string += 'layernorm' elif FLAGS.norm == 'None': exp_string += 'nonorm' else: print('Norm setting not recognized.') resume_itr = 0 model_file = None tf.global_variables_initializer().run() tf.train.start_queue_runners() if FLAGS.resume or not FLAGS.train: model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) if FLAGS.test_iter > 0: model_file = model_file[:model_file.index('model' )] + 'model' + str( FLAGS.test_iter) if model_file: ind1 = model_file.index('model') resume_itr = int(model_file[ind1 + 5:]) print("Restoring model weights from " + model_file) saver.restore(sess, model_file) if FLAGS.train: train(model, saver, sess, exp_string, data_generator, resume_itr) else: test(model, saver, sess, exp_string, data_generator, test_num_updates)
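# --- Illustrative sketch (hypothetical checkpoint path): how the restore logic above
# recovers the iteration to resume from. It looks for the literal substring 'model' in
# the checkpoint path and parses the digits that follow it; note that a directory name
# containing 'model' earlier in the path would break this parsing.
model_file = '/tmp/logs/cls_5.mbs_4.ubs_5.numstep5.updatelr0.01/model60000'  # hypothetical
ind1 = model_file.index('model')
resume_itr = int(model_file[ind1 + 5:])
assert resume_itr == 60000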
def main(): if FLAGS.datasource == 'sinusoid': if FLAGS.train: test_num_updates = 5 else: test_num_updates = 10 if FLAGS.train == False: orig_meta_batch_size = FLAGS.meta_batch_size # always use meta batch size of 1 when testing. FLAGS.meta_batch_size = 1 if FLAGS.datasource == 'sinusoid': # load data data_generator = DataGenerator( FLAGS.update_batch_size * 2, FLAGS.meta_batch_size) # k=update_batch_size*2=10 dim_output = data_generator.dim_output dim_input = data_generator.dim_input tf_data_load = False input_tensors = None # construct meta learning model model = MAML(dim_input, dim_output, test_num_updates=test_num_updates) model.construct_model(input_tensors=input_tensors, prefix='metatrain_') model.summ_op = tf.summary.merge_all() saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) sess = tf.InteractiveSession() # writer = tf.summary.FileWriter("../../../logs", sess.graph) if FLAGS.train == False: # change to original meta batch size when loading model. FLAGS.meta_batch_size = orig_meta_batch_size if FLAGS.train_update_batch_size == -1: FLAGS.train_update_batch_size = FLAGS.update_batch_size if FLAGS.train_update_lr == -1: FLAGS.train_update_lr = FLAGS.update_lr exp_string = 'cls_' + str(1) + '.mbs_' + str( FLAGS.meta_batch_size) + '.ubs_' + str( FLAGS.train_update_batch_size) + '.numstep' + str( FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr) if FLAGS.stop_grad: exp_string += 'stopgrad' if FLAGS.baseline: exp_string += FLAGS.baseline if FLAGS.norm == 'batch_norm': exp_string += 'batchnorm' elif FLAGS.norm == 'layer_norm': exp_string += 'layernorm' elif FLAGS.norm == 'None': exp_string += 'nonorm' else: print('Norm setting not recognized.') resume_itr = 0 model_file = None tf.global_variables_initializer().run() tf.train.start_queue_runners() print("exp_string is: ", exp_string) if FLAGS.resume or not FLAGS.train: model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) if FLAGS.test_iter > 0: model_file = model_file[:model_file.index('model' )] + 'model' + str( FLAGS.test_iter) if model_file: ind1 = model_file.index('model') resume_itr = int(model_file[ind1 + 5:]) print("Restoring model weights from " + model_file) saver.restore(sess, model_file) if FLAGS.train: train(model, saver, sess, exp_string, data_generator, resume_itr) else: test(model, saver, sess, exp_string, data_generator, test_num_updates) sin_test(model, saver, sess, exp_string, data_generator, test_num_updates)
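# --- Illustrative sketch (assumption: follows the standard MAML sinusoid setup with
# amplitude in [0.1, 5.0] and phase in [0, pi], not necessarily this fork's
# DataGenerator): each task is a sine wave with random amplitude and phase, and
# update_batch_size * 2 points are drawn per task so half can act as the support set
# and half as the query set.
import numpy as np

def sample_sinusoid_task(k_total, rng):
    amplitude = rng.uniform(0.1, 5.0)
    phase = rng.uniform(0.0, np.pi)
    x = rng.uniform(-5.0, 5.0, size=(k_total, 1))
    y = amplitude * np.sin(x - phase)
    return x, y

rng = np.random.RandomState(0)
x, y = sample_sinusoid_task(k_total=10, rng=rng)  # 10 = update_batch_size * 2 with K = 5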
def main(): test_num_updates = 10 if FLAGS.train == False: orig_meta_batch_size = FLAGS.meta_batch_size FLAGS.meta_batch_size = 1 if FLAGS.train == False: # change to original meta batch size when loading model. FLAGS.meta_batch_size = orig_meta_batch_size if FLAGS.train_update_batch_size == -1: FLAGS.train_update_batch_size = FLAGS.update_batch_size if FLAGS.train_update_lr == -1: FLAGS.train_update_lr = FLAGS.update_lr exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str( FLAGS.meta_batch_size) + '.ubs_' + str( FLAGS.train_update_batch_size) + '.numstep' + str( FLAGS.num_updates) + '.updatelr' + str( FLAGS.train_update_lr) + '.poison_lr' + str( FLAGS.poison_lr) if FLAGS.num_filters != 64: exp_string += 'hidden' + str(FLAGS.num_filters) if FLAGS.max_pool: exp_string += 'maxpool' if FLAGS.baseline: exp_string += FLAGS.baseline if FLAGS.norm == 'batch_norm': exp_string += 'batchnorm' elif FLAGS.norm == 'layer_norm': exp_string += 'layernorm' elif FLAGS.norm == 'None': exp_string += 'nonorm' else: print('Norm setting not recognized.') num_images_per_class = FLAGS.update_batch_size * 3 data_generator = DataGenerator( num_images_per_class, FLAGS.meta_batch_size ) # only use one datapoint for testing to save memory dim_output = data_generator.dim_output dim_input = data_generator.dim_input if FLAGS.mode == 'train_with_poison': print('Loading poison examples from %s' % FLAGS.poison_path) poison_example = np.load(FLAGS.poison_dir) # poison_example=np.load(FLAGS.logdir + '/' + exp_string+'/poisonx_%d.npy'%FLAGS.poison_itr) else: poison_example = None model = MAML(dim_input=dim_input, dim_output=dim_output, num_images_per_class=num_images_per_class, num_classes=FLAGS.num_classes, poison_example=poison_example) sess = tf.InteractiveSession() print('Session created') if FLAGS.datasource == 'omniglot': tf_data_load = True num_classes = data_generator.num_classes if FLAGS.train: # only construct training model if needed random.seed(5) image_tensor, label_tensor = data_generator.make_data_tensor( train=True, poison=(model.poisonx, model.poisony), sess=sess) if FLAGS.reptile: inputa = image_tensor labela = label_tensor else: inputa = tf.slice( image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labela = tf.slice( label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) image_tensor, label_tensor = data_generator.make_data_tensor( train=False) if FLAGS.mode == 'train_poison': inputa_test = tf.slice( image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb_test = tf.slice( image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela_test = tf.slice( label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb_test = tf.slice( label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb, 'inputa_test': inputa_test, 'inputb_test': inputb_test, 'labela_test': labela_test, 'labelb_test': labelb_test } else: input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } random.seed(6) image_tensor, label_tensor = data_generator.make_data_tensor( train=False) inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0, 
num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) metaval_input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } else: tf_data_load = False input_tensors = None if FLAGS.train or not tf_data_load: model.construct_model(input_tensors=input_tensors, prefix=FLAGS.mode) if tf_data_load: model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_') print('Model built') model.summ_op = tf.summary.merge_all() saver = loader = tf.train.Saver(tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) resume_itr = 0 model_file = None tf.train.start_queue_runners() tf.global_variables_initializer().run() if FLAGS.resume or not FLAGS.train: model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) if FLAGS.test_iter > 0: model_file = model_file[:model_file.index('model' )] + 'model' + str( FLAGS.test_iter) if model_file: ind1 = model_file.index('model') resume_itr = int(model_file[ind1 + 5:]) print("Restoring model weights from " + model_file) saver.restore(sess, model_file) test_params = [ model, saver, sess, exp_string, data_generator, test_num_updates ] test(model, saver, sess, exp_string, data_generator, test_num_updates) if FLAGS.train: train(model, saver, sess, exp_string, data_generator, resume_itr, test_params=test_params) else: test(model, saver, sess, exp_string, data_generator, test_num_updates)
def main(): if FLAGS.datasource == 'multidataset_leave_one_out': assert FLAGS.leave_one_out_id > -1 sess = tf.InteractiveSession() if FLAGS.datasource in ['sinusoid', 'mixture']: if FLAGS.train: test_num_updates = 1 else: test_num_updates = 10 else: if FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']: if FLAGS.train == True: test_num_updates = 1 # eval on at least one update during training else: test_num_updates = 10 else: test_num_updates = 10 if FLAGS.train == False: orig_meta_batch_size = FLAGS.meta_batch_size # always use meta batch size of 1 when testing. FLAGS.meta_batch_size = 1 if FLAGS.datasource in ['sinusoid', 'mixture']: data_generator = DataGenerator(FLAGS.update_batch_size + FLAGS.update_batch_size_eval, FLAGS.meta_batch_size) else: if FLAGS.metatrain_iterations == 0 and FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']: assert FLAGS.meta_batch_size == 1 assert FLAGS.update_batch_size == 1 data_generator = DataGenerator(1, FLAGS.meta_batch_size) # only use one datapoint, else: if FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']: if FLAGS.train: data_generator = DataGenerator(FLAGS.update_batch_size + 15, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory else: data_generator = DataGenerator(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory else: data_generator = DataGenerator(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory dim_output = data_generator.dim_output dim_input = data_generator.dim_input if FLAGS.datasource in ['miniimagenet', 'omniglot', 'multidataset', 'multidataset_leave_one_out']: tf_data_load = True num_classes = data_generator.num_classes if FLAGS.train: # only construct training model if needed random.seed(5) if FLAGS.datasource in ['miniimagenet', 'omniglot']: image_tensor, label_tensor = data_generator.make_data_tensor() elif FLAGS.datasource == 'multidataset': image_tensor, label_tensor = data_generator.make_data_tensor_multidataset() elif FLAGS.datasource == 'multidataset_leave_one_out': image_tensor, label_tensor = data_generator.make_data_tensor_multidataset_leave_one_out() inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb} random.seed(6) if FLAGS.datasource in ['miniimagenet', 'omniglot']: image_tensor, label_tensor = data_generator.make_data_tensor(train=False) elif FLAGS.datasource == 'multidataset': image_tensor, label_tensor = data_generator.make_data_tensor_multidataset(train=False) elif FLAGS.datasource == 'multidataset_leave_one_out': image_tensor, label_tensor = data_generator.make_data_tensor_multidataset_leave_one_out(train=False) inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) 
metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb} else: tf_data_load = False input_tensors = None model = MAML(sess, dim_input, dim_output, test_num_updates=test_num_updates) if FLAGS.train or not tf_data_load: model.construct_model(input_tensors=input_tensors, prefix='metatrain_') if tf_data_load: model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_') model.summ_op = tf.summary.merge_all() saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) if FLAGS.train == False: # change to original meta batch size when loading model. FLAGS.meta_batch_size = orig_meta_batch_size if FLAGS.train_update_batch_size == -1: FLAGS.train_update_batch_size = FLAGS.update_batch_size if FLAGS.train_update_lr == -1: FLAGS.train_update_lr = FLAGS.update_lr exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(FLAGS.meta_batch_size) + '.ubs_' + str( FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str( FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr) + '.emb_loss_weight' + str( FLAGS.emb_loss_weight) + '.num_groups' + str(FLAGS.num_groups) + '.emb_type' + str( FLAGS.emb_type) + '.hidden_dim' + str(FLAGS.hidden_dim) if FLAGS.num_filters != 64: exp_string += 'hidden' + str(FLAGS.num_filters) if FLAGS.max_pool: exp_string += 'maxpool' if FLAGS.stop_grad: exp_string += 'stopgrad' if FLAGS.norm == 'batch_norm': exp_string += 'batchnorm' elif FLAGS.norm == 'layer_norm': exp_string += 'layernorm' elif FLAGS.norm == 'None': exp_string += 'nonorm' else: print('Norm setting not recognized.') resume_itr = 0 model_file = None tf.global_variables_initializer().run() tf.train.start_queue_runners() print(exp_string) if FLAGS.resume or not FLAGS.train: if FLAGS.train == True: model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) else: print(FLAGS.test_epoch) model_file = '{0}/{2}/model{1}'.format(FLAGS.logdir, FLAGS.test_epoch, exp_string) if model_file: ind1 = model_file.index('model') resume_itr = int(model_file[ind1 + 5:]) print("Restoring model weights from " + model_file) saver.restore(sess, model_file) if FLAGS.train: train(model, saver, sess, exp_string, data_generator, resume_itr) else: test(model, saver, sess, exp_string, data_generator, test_num_updates)
def run(d): #IPython.embed() config = d['config'] ################################################################ ###################### Load parameters ######################### ################################################################ previous_dynamics_model = config["previous_dynamics_model"] train_now = d['train_bool'] ################################################################ desired_shape_for_rollout = config["testing"]["desired_shape_for_rollout"] save_rollout_run_num = config["testing"]["save_rollout_run_num"] rollout_save_filename = desired_shape_for_rollout + str( save_rollout_run_num) num_steps_per_rollout = config["testing"]["num_steps_per_rollout"] if (desired_shape_for_rollout == "figure8"): num_steps_per_rollout = 400 elif (desired_shape_for_rollout == "zigzag"): num_steps_per_rollout = 150 ############################################################## #settings cheaty_training = False use_one_hot = False #True use_camera = False #True playback_mode = False state_representation = "exclude_x_y" #["exclude_x_y", "all"] # Settings (generally, keep these to default) default_addrs = [b'\x00\x01'] use_pid_mode = True slow_pid_mode = True visualize_rviz = True #turning this off can make things go faster visualize_True = True visualize_False = False noise_True = True noise_False = False make_aggregated_dataset_noisy = True make_training_dataset_noisy = True perform_forwardsim_for_vis = True print_minimal = False noiseToSignal = 0 if (make_training_dataset_noisy): noiseToSignal = 0.01 # Defining datatypes tf_datatype = tf.float32 np_datatype = np.float32 # Setting motor limits left_min = 1200 right_min = 1200 left_max = 2000 right_max = 2000 if (use_pid_mode): if (slow_pid_mode): left_min = 2 * math.pow(2, 16) * 0.001 right_min = 2 * math.pow(2, 16) * 0.001 left_max = 9 * math.pow(2, 16) * 0.001 right_max = 9 * math.pow(2, 16) * 0.001 else: #this hasnt been tested yet left_min = 4 * math.pow(2, 16) * 0.001 right_min = 4 * math.pow(2, 16) * 0.001 left_max = 12 * math.pow(2, 16) * 0.001 right_max = 12 * math.pow(2, 16) * 0.001 #vars from config curr_agg_iter = config['aggregation']['curr_agg_iter'] save_dir = d['exp_name'] print("\n\nSAVING EVERYTHING TO: ", save_dir) #make directories if not os.path.exists(save_dir + '/saved_rollouts'): os.makedirs(save_dir + '/saved_rollouts') if not os.path.exists(save_dir + '/saved_rollouts/' + rollout_save_filename + '_aggIter' + str(curr_agg_iter)): os.makedirs(save_dir + '/saved_rollouts/' + rollout_save_filename + '_aggIter' + str(curr_agg_iter)) ###################################### ######## GET TRAINING DATA ########### ###################################### print("\n\nCURR AGGREGATION ITER: ", curr_agg_iter) # Training data # Random dataX = [] dataX_full = [ ] #this is just for your personal use for forwardsim (for debugging) dataY = [] dataZ = [] # Training data # MPC dataX_onPol = [] dataX_full_onPol = [] dataY_onPol = [] dataZ_onPol = [] # Validation data # Random dataX_val = [] dataX_full_val = [] dataY_val = [] dataZ_val = [] # Validation data # MPC dataX_val_onPol = [] dataX_full_val_onPol = [] dataY_val_onPol = [] dataZ_val_onPol = [] training_ratio = config['training']['training_ratio'] for agg_itr_counter in range(curr_agg_iter + 1): #getDataFromDisk should give (tasks, rollouts from that task, each rollout has its points) dataX_curr, dataY_curr, dataZ_curr, dataX_curr_full = getDataFromDisk( config['experiment_type'], use_one_hot, use_camera, cheaty_training, state_representation, agg_itr_counter, 
config_training=config['training']) if (agg_itr_counter == 1): print("*********TRYING TO FIND THE WEIRD ROLLOUT...") for rollout in range(len(dataX_curr[2])): val = dataX_curr[2][rollout][:, 4] if (np.any(val < 0)): dataX_curr[2][rollout] = dataX_curr[2][rollout + 1] dataY_curr[2][rollout] = dataY_curr[2][rollout + 1] dataZ_curr[2][rollout] = dataZ_curr[2][rollout + 1] print("FOUND IT!!!!!!! rollout number ", rollout) #random data #go from dataX_curr (tasks, rollouts, steps) --> to dataX (tasks, some rollouts, steps) and dataX_val (tasks, some rollouts, steps) if (agg_itr_counter == 0): for task_num in range(len(dataX_curr)): taski_num_rollout = len(dataX_curr[task_num]) print("task" + str(task_num) + "_num_rollouts: ", taski_num_rollout) #for each task, append something like (356, 48, 22) (numrollouts per task, num steps in that rollout, dim) dataX.append(dataX_curr[task_num][:int(taski_num_rollout * training_ratio)]) dataX_full.append( dataX_curr_full[task_num][:int(taski_num_rollout * training_ratio)]) dataY.append(dataY_curr[task_num][:int(taski_num_rollout * training_ratio)]) dataZ.append(dataZ_curr[task_num][:int(taski_num_rollout * training_ratio)]) dataX_val.append(dataX_curr[task_num][int(taski_num_rollout * training_ratio):]) dataX_full_val.append( dataX_curr_full[task_num][int(taski_num_rollout * training_ratio):]) dataY_val.append(dataY_curr[task_num][int(taski_num_rollout * training_ratio):]) dataZ_val.append(dataZ_curr[task_num][int(taski_num_rollout * training_ratio):]) #on-policy data #go from dataX_curr (tasks, rollouts, steps) --> to dataX_onPol (tasks, some rollouts, steps) and dataX_val_onPol (tasks, some rollouts, steps) elif (agg_itr_counter == 1): for task_num in range(len(dataX_curr)): taski_num_rollout = len(dataX_curr[task_num]) print("task" + str(task_num) + "_num_rollouts for onpolicy: ", taski_num_rollout) dataX_onPol.append( dataX_curr[task_num][:int(taski_num_rollout * training_ratio)]) dataX_full_onPol.append( dataX_curr_full[task_num][:int(taski_num_rollout * training_ratio)]) dataY_onPol.append( dataY_curr[task_num][:int(taski_num_rollout * training_ratio)]) dataZ_onPol.append( dataZ_curr[task_num][:int(taski_num_rollout * training_ratio)]) dataX_val_onPol.append( dataX_curr[task_num][int(taski_num_rollout * training_ratio):]) dataX_full_val_onPol.append( dataX_curr_full[task_num][int(taski_num_rollout * training_ratio):]) dataY_val_onPol.append( dataY_curr[task_num][int(taski_num_rollout * training_ratio):]) dataZ_val_onPol.append( dataZ_curr[task_num][int(taski_num_rollout * training_ratio):]) #on-policy data #go from dataX_curr (tasks, rollouts, steps) --> to ADDING ONTO dataX_onPol (tasks, some more rollouts than before, steps) and dataX_val_onPol (tasks, some more rollouts than before, steps) else: for task_num in range(len(dataX_curr)): taski_num_rollout = len(dataX_curr[task_num]) print("task" + str(task_num) + "_num_rollouts for onpolicy: ", taski_num_rollout) dataX_onPol[task_num].extend( dataX_curr[task_num][:int(taski_num_rollout * training_ratio)]) dataX_full_onPol[task_num].extend( dataX_curr_full[task_num][:int(taski_num_rollout * training_ratio)]) dataY_onPol[task_num].extend( dataY_curr[task_num][:int(taski_num_rollout * training_ratio)]) dataZ_onPol[task_num].extend( dataZ_curr[task_num][:int(taski_num_rollout * training_ratio)]) dataX_val_onPol[task_num].extend( dataX_curr[task_num][int(taski_num_rollout * training_ratio):]) dataX_full_val_onPol[task_num].extend( dataX_curr_full[task_num][int(taski_num_rollout * training_ratio):]) 
dataY_val_onPol[task_num].extend( dataY_curr[task_num][int(taski_num_rollout * training_ratio):]) dataZ_val_onPol[task_num].extend( dataZ_curr[task_num][int(taski_num_rollout * training_ratio):]) ############################################################# #count number of random and onpol data points total_random_data = len(dataX) * len(dataX[1]) * len( dataX[1][0]) # numSteps = tasks * rollouts * steps if (len(dataX_onPol) == 0): total_onPol_data = 0 else: total_onPol_data = len(dataX_onPol) * len(dataX_onPol[0]) * len( dataX_onPol[0][0] ) #this is approximate because each task doesn't have the same num rollouts or the same num steps total_num_data = total_random_data + total_onPol_data print() print() print("Number of random data points: ", total_random_data) print("Number of on-policy data points: ", total_onPol_data) print("TOTAL number of data points: ", total_num_data) ############################################################# #combine random and onpol data into a single dataset for training ratio_new = config["aggregation"]["ratio_new"] num_new_pts = ratio_new * (total_random_data) / (1.0 - ratio_new) if (len(dataX_onPol) == 0): num_times_to_copy_onPol = 0 else: num_times_to_copy_onPol = int(num_new_pts / total_onPol_data) #copy all rollouts from each task of onpol data, and do this copying this many times for i in range(num_times_to_copy_onPol): for task_num in range(len(dataX_onPol)): for rollout_num in range(len(dataX_onPol[task_num])): dataX[task_num].append(dataX_onPol[task_num][rollout_num]) dataX_full[task_num].append( dataX_full_onPol[task_num][rollout_num]) dataY[task_num].append(dataY_onPol[task_num][rollout_num]) dataZ[task_num].append(dataZ_onPol[task_num][rollout_num]) #print("num_times_to_copy_onPol: ", num_times_to_copy_onPol) # make a list of all X,Y,Z so can take mean of them # concatenate state and action --> inputs (for training) all_points_inp = [] all_points_outp = [] outputs = copy.deepcopy(dataZ) inputs = copy.deepcopy(dataX) for task_num in range(len(dataX)): for rollout_num in range(len(dataX[task_num])): #this will just be a big list of everything, so can take the mean input_pts = np.concatenate( (dataX[task_num][rollout_num], dataY[task_num][rollout_num]), axis=1) output_pts = dataZ[task_num][rollout_num] #this will the concatenate thing for later inputs[task_num][rollout_num] = np.concatenate( [dataX[task_num][rollout_num], dataY[task_num][rollout_num]], axis=1) all_points_inp.append(input_pts) all_points_outp.append(output_pts) all_points_inp = np.concatenate(all_points_inp) all_points_outp = np.concatenate(all_points_outp) ## concatenate state and action --> inputs (for validation) outputs_val = copy.deepcopy(dataZ_val) inputs_val = copy.deepcopy(dataX_val) for task_num in range(len(dataX_val)): for rollout_num in range(len(dataX_val[task_num])): #dataX[task_num][rollout_num] (steps x s_dim) #dataY[task_num][rollout_num] (steps x a_dim) inputs_val[task_num][rollout_num] = np.concatenate([ dataX_val[task_num][rollout_num], dataY_val[task_num][rollout_num] ], axis=1) ## concatenate state and action --> inputs (for validation onpol) outputs_val_onPol = copy.deepcopy(dataZ_val_onPol) inputs_val_onPol = copy.deepcopy(dataX_val_onPol) for task_num in range(len(dataX_val_onPol)): for rollout_num in range(len(dataX_val_onPol[task_num])): #dataX[task_num][rollout_num] (steps x s_dim) #dataY[task_num][rollout_num] (steps x a_dim) inputs_val_onPol[task_num][rollout_num] = np.concatenate([ dataX_val_onPol[task_num][rollout_num], 
dataY_val_onPol[task_num][rollout_num] ], axis=1) ############################################################# #inputs should now be (tasks, rollouts from that task, [s,a]) #outputs should now be (tasks, rollouts from that task, [ds]) #IPython.embed() inputSize = inputs[0][0].shape[1] outputSize = outputs[1][0].shape[1] print("\n\nDimensions:") print("states: ", dataX[1][0].shape[1]) print("actions: ", dataY[1][0].shape[1]) print("inputs to NN: ", inputSize) print("outputs of NN: ", outputSize) mean_inp = np.expand_dims(np.mean(all_points_inp, axis=0), axis=0) std_inp = np.expand_dims(np.std(all_points_inp, axis=0), axis=0) mean_outp = np.expand_dims(np.mean(all_points_outp, axis=0), axis=0) std_outp = np.expand_dims(np.std(all_points_outp, axis=0), axis=0) print("\n\nCalulated means and stds... ", mean_inp.shape, std_inp.shape, mean_outp.shape, std_outp.shape, "\n\n") ########################################################### ## CREATE regressor, policy, data generator, maml model ########################################################### # create regressor (NN dynamics model) regressor = DeterministicMLPRegressor( inputSize, outputSize, outputSize, tf_datatype, config['seed'], config['training']['weight_initializer'], config['model']) # create policy (MPC controller) policy = Policy(regressor, inputSize, outputSize, left_min, right_min, left_max, right_max, state_representation=state_representation, visualize_rviz=visualize_rviz, x_index=config['roach']['x_index'], y_index=config['roach']['y_index'], yaw_cos_index=config['roach']['yaw_cos_index'], yaw_sin_index=config['roach']['yaw_sin_index'], **config['policy']) # create MAML model # note: this also constructs the actual regressor network/weights model = MAML(regressor, inputSize, outputSize, config) model.construct_model(input_tensors=None, prefix='metatrain_') model.summ_op = tf.summary.merge_all() # GPU config proto gpu_device = 0 gpu_frac = 0.4 #0.4 #0.8 #0.3 os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_device) gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_frac) config_2 = tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False, allow_soft_placement=True, inter_op_parallelism_threads=1, intra_op_parallelism_threads=1) # saving saver = tf.train.Saver(max_to_keep=10) sess = tf.InteractiveSession(config=config_2) # initialize tensorflow vars tf.global_variables_initializer().run() tf.train.start_queue_runners() # set the mean/std of regressor according to mean/std of the data we have so far regressor.update_params_data_dist(mean_inp, std_inp, mean_outp, std_outp, total_num_data) ########################################################### ## TRAIN THE DYNAMICS MODEL ########################################################### #train on the given full dataset, for max_epochs if train_now: if config["training"]["restore_previous_dynamics_model"]: print("\n\nRESTORING PREVIOUS DYNAMICS MODEL FROM ", previous_dynamics_model, " AND CONTINUING TRAINING...\n\n") saver.restore(sess, previous_dynamics_model) np.save(save_dir + "/inputs.npy", inputs) np.save(save_dir + "/outputs.npy", outputs) np.save(save_dir + "/inputs_val.npy", inputs_val) np.save(save_dir + "/outputs_val.npy", outputs_val) train(inputs, outputs, curr_agg_iter, model, saver, sess, config, inputs_val, outputs_val, inputs_val_onPol, outputs_val_onPol) else: print("\n\nRESTORING A DYNAMICS MODEL FROM ", previous_dynamics_model) saver.restore(sess, previous_dynamics_model) ########################################################### ## RUN THE MPC 
CONTROLLER ########################################################### #create controller node controller_node = GBAC_Controller( sess, policy, model, use_pid_mode=use_pid_mode, state_representation=state_representation, default_addrs=default_addrs, update_batch_size=config['testing']['update_batch_size'], num_updates=config['testing']['num_updates'], de=config['testing']['dynamic_evaluation'], roach_config=config['roach']) #do 1 rollout print( "\n\n\nPAUSING... right before a controller run... RESET THE ROBOT TO A GOOD LOCATION BEFORE CONTINUING..." ) #IPython.embed() resulting_x, selected_u, desired_seq, list_robot_info, list_mocap_info, old_saving_format_dict, list_best_action_sequences = controller_node.run( num_steps_per_rollout, desired_shape_for_rollout) #where to save this rollout pathStartName = save_dir + '/saved_rollouts/' + rollout_save_filename + '_aggIter' + str( curr_agg_iter) print("\n\n************** TRYING TO SAVE EVERYTHING TO: ", pathStartName) #save the result of the run np.save(pathStartName + '/oldFormat_actions.npy', old_saving_format_dict['actions_taken']) np.save(pathStartName + '/oldFormat_desired.npy', old_saving_format_dict['desired_states']) np.save(pathStartName + '/oldFormat_executed.npy', old_saving_format_dict['traj_taken']) np.save(pathStartName + '/oldFormat_perp.npy', old_saving_format_dict['save_perp_dist']) np.save(pathStartName + '/oldFormat_forward.npy', old_saving_format_dict['save_forward_dist']) np.save(pathStartName + '/oldFormat_oldforward.npy', old_saving_format_dict['saved_old_forward_dist']) np.save(pathStartName + '/oldFormat_movedtonext.npy', old_saving_format_dict['save_moved_to_next']) np.save(pathStartName + '/oldFormat_desheading.npy', old_saving_format_dict['save_desired_heading']) np.save(pathStartName + '/oldFormat_currheading.npy', old_saving_format_dict['save_curr_heading']) np.save(pathStartName + '/list_best_action_sequences.npy', list_best_action_sequences) yaml.dump(config, open(osp.join(pathStartName, 'saved_config.yaml'), 'w')) #save the result of the run np.save(pathStartName + '/actions.npy', selected_u) np.save(pathStartName + '/states.npy', resulting_x) np.save(pathStartName + '/desired.npy', desired_seq) pickle.dump(list_robot_info, open(pathStartName + '/robotInfo.obj', 'w')) pickle.dump(list_mocap_info, open(pathStartName + '/mocapInfo.obj', 'w')) #stop roach print("killing robot") controller_node.kill_robot() return
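# --- Illustrative sketch (toy numbers): the aggregation step earlier in this function
# mixes random and on-policy data by copying the on-policy rollouts enough times that
# they make up roughly ratio_new of the combined dataset:
#   num_new_pts = ratio_new * total_random_data / (1 - ratio_new)
#   num_times_to_copy_onPol = int(num_new_pts / total_onPol_data)
ratio_new = 0.5
total_random_data = 40000
total_onPol_data = 8000
num_new_pts = ratio_new * total_random_data / (1.0 - ratio_new)  # 40000.0
num_times_to_copy_onPol = int(num_new_pts / total_onPol_data)    # 5
assert num_times_to_copy_onPol == 5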
def main(): os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu if FLAGS.train == False: orig_meta_batch_size = FLAGS.meta_batch_size # always use meta batch size of 1 when testing. FLAGS.meta_batch_size = 1 data_generator = DataGenerator() dim_output = data_generator.num_classes dim_input = data_generator.dim_input inputa, inputb, labela, labelb = data_generator.make_data_tensor() metatrain_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb} # pred_weights = data_generator.pred_weights model = MAML(dim_input, dim_output) model.construct_model(input_tensors=metatrain_input_tensors, prefix='metatrain_') model.summ_op = tf.summary.merge_all() saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=20) sess = tf.InteractiveSession() if not FLAGS.train: # change to original meta batch size when loading model. FLAGS.meta_batch_size = orig_meta_batch_size if FLAGS.train_update_batch_size == -1: FLAGS.train_update_batch_size = FLAGS.update_batch_size if FLAGS.train_update_lr == -1: FLAGS.train_update_lr = FLAGS.update_lr trained_model_dir = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(FLAGS.meta_batch_size) + '.ubs_' + str( FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str( FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr) print(">>>>> trained_model_dir: ", FLAGS.logdir + '/' + trained_model_dir) resume_itr = 0 tf.global_variables_initializer().run() tf.train.start_queue_runners() print("================================================================================") print('initial weights norm: ', np.linalg.norm(sess.run('model/w1:0'))) print('initial last weights: ', sess.run('model/w1:0')[-1]) print('initial bias: ', sess.run('model/b1:0')) print("================================================================================") ################## Train ################## if FLAGS.resume: model_file = None if FLAGS.model.startswith('m2'): trained_model_dir = 'sbjt' + str(FLAGS.sbjt_start_idx) + '.ubs_' + str( FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str( FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr) model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + trained_model_dir) print(">>>>> trained_model_dir: ", FLAGS.logdir + '/' + trained_model_dir) w = None b = None print(">>>> model_file1: ", model_file) if model_file: if FLAGS.test_iter > 0: files = os.listdir(model_file[:model_file.index('model')]) if 'model' + str(FLAGS.test_iter) + '.index' in files: model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter) print(">>>> model_file2: ", model_file) print("1. Restoring model weights from " + model_file) saver.restore(sess, model_file) b = sess.run('model/b1:0').tolist() print("updated weights from ckpt: ", np.array(b)) ind1 = model_file.index('model') resume_itr = int(model_file[ind1 + 5:]) elif FLAGS.keep_train_dir: # when the model needs to be initialized from another model. resume_itr = 0 print('resume_itr: ', resume_itr) model_file = tf.train.latest_checkpoint(FLAGS.keep_train_dir) print(">>>>> base_model_dir: ", FLAGS.keep_train_dir) if FLAGS.test_iter > 0: files = os.listdir(model_file[:model_file.index('model')]) if 'model' + str(FLAGS.test_iter) + '.index' in files: model_file = model_file[:model_file.index('model')] + 'model' + str(FLAGS.test_iter) print(">>>> model_file2: ", model_file) print("2. 
Restoring model weights from " + model_file) saver.restore(sess, model_file) print("updated weights from ckpt: ", sess.run('model/b1:0')) elif FLAGS.model.startswith('s4'): from feature_layers import feature_layer three_layers = feature_layer(10, 1) print('FLAGS.base_vae_model: ', FLAGS.base_vae_model) three_layers.model_intensity.load_weights(FLAGS.base_vae_model + '.h5') w = three_layers.model_intensity.layers[-1].get_weights()[0] b = three_layers.model_intensity.layers[-1].get_weights()[1] print('s2 b: ', b) print('s2 w: ', w) print('-----------------------------------------------------------------') with tf.variable_scope("model", reuse=True) as scope: scope.reuse_variables() b1 = tf.get_variable("b1", [1, 2]).assign(np.array(b)) w1 = tf.get_variable("w1", [300, 1, 2]).assign(np.array(w)) sess.run(b1) sess.run(w1) print("after: ", sess.run('model/b1:0')) print("after: ", sess.run('model/w1:0')) if not FLAGS.all_sub_model: trained_model_dir = 'sbjt' + str(FLAGS.sbjt_start_idx) + '.ubs_' + str( FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str( FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr) print("================================================================================") train(model, saver, sess, trained_model_dir, metatrain_input_tensors, resume_itr) end_time = datetime.now() elapse = end_time - start_time print("================================================================================") print(">>>>>> elapse time: " + str(elapse)) print("================================================================================")
def main(): data_generator = DataGenerator(FLAGS.update_batch_size, FLAGS.meta_batch_size, k_shot=FLAGS.k_shot) dim_output = data_generator.dim_output dim_input = data_generator.dim_input if FLAGS.datasource == 'ml': input_tensors = { 'inputa': tf.placeholder(tf.int32, shape=[None, None, 2]), 'inputb': tf.placeholder(tf.int32, shape=[None, None, 2]), 'labela': tf.placeholder(tf.float32, shape=[None, None, 1]), 'labelb': tf.placeholder(tf.float32, shape=[None, None, 1]) } elif FLAGS.datasource == 'bpr' or FLAGS.datasource == 'bpr_time': input_tensors = { 'inputa': tf.placeholder(tf.int32, shape=[None, None, 3]), 'inputb': tf.placeholder(tf.int32, shape=[None, None, 3]), } else: raise Exception('non-supported data source: {}'.format( FLAGS.datasource)) model = MAML(dim_input, dim_output) if FLAGS.train or FLAGS.test_existing_user: model.construct_model(input_tensors=input_tensors, prefix='metatrain_') else: model.construct_model(input_tensors=input_tensors, prefix='META_TEST') model.summ_op = tf.summary.merge_all() saver = loader = tf.train.Saver(tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) sess = tf.InteractiveSession() exp_string = 'mtype_{}.mbs_{}.ubs_{}.meta_lr_{}.' \ 'update_step_{}.update_lr_{}.' \ 'lambda_lr_{}.avg_f_{}' \ '.time_{}'.format(FLAGS.datasource, FLAGS.meta_batch_size, FLAGS.update_batch_size, FLAGS.meta_lr, FLAGS.num_updates, FLAGS.update_lr, FLAGS.lambda_lr, FLAGS.use_avg_init, str(datetime.now())) resume_itr = 0 tf.global_variables_initializer().run() tf.train.start_queue_runners() if FLAGS.resume: model_path = '{}/mlRRS/model/{}/model_{}'.format( FLAGS.logdir, FLAGS.load_dir, FLAGS.resume_iter) if os.path.exists(model_path + '.meta'): loader.restore(sess=sess, save_path=model_path) resume_itr = FLAGS.resume_iter else: raise Exception('No model saved at path {}'.format(model_path)) if FLAGS.train: train(model, saver, sess, exp_string, data_generator, resume_itr) if FLAGS.test_existing_user: test_existing_user(model, saver, sess, exp_string, data_generator, resume_itr) if FLAGS.test: test(model, saver, sess, exp_string, data_generator, resume_itr)
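# --- Illustrative sketch (hypothetical shapes, NumPy only): the kind of batch the
# placeholder-based input_tensors above expect for the 'ml' datasource, where every
# example is a (user, item) index pair with a scalar rating label.
import numpy as np

meta_batch, k_shot = 4, 5
example_batch = {
    'inputa': np.zeros((meta_batch, k_shot, 2), dtype=np.int32),    # (user, item) support pairs
    'inputb': np.zeros((meta_batch, k_shot, 2), dtype=np.int32),    # (user, item) query pairs
    'labela': np.zeros((meta_batch, k_shot, 1), dtype=np.float32),  # support ratings
    'labelb': np.zeros((meta_batch, k_shot, 1), dtype=np.float32),  # query ratings
}
assert example_batch['inputa'].shape == (4, 5, 2)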
def main(): if FLAGS.datasource == 'sinusoid': if FLAGS.train: test_num_updates = 5 else: test_num_updates = 10 else: if FLAGS.datasource == 'miniimagenet': if FLAGS.train == True: test_num_updates = 1 # eval on at least one update during training else: test_num_updates = 10 else: test_num_updates = 10 if FLAGS.train == False: orig_meta_batch_size = FLAGS.meta_batch_size # always use meta batch size of 1 when testing. FLAGS.meta_batch_size = 1 # if FLAGS.datasource == 'sinusoid': # data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) # else: # if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet': # assert FLAGS.meta_batch_size == 1 # assert FLAGS.update_batch_size == 1 # data_generator = DataGenerator(1, FLAGS.meta_batch_size) # only use one datapoint, # else: # if FLAGS.datasource == 'miniimagenet': # TODO - use 15 val examples for imagenet? # if FLAGS.train: # data_generator = DataGenerator(FLAGS.update_batch_size+15, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory # else: # data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory # else: # data_generator = DataGenerator(FLAGS.update_batch_size*2, FLAGS.meta_batch_size) # only use one datapoint for testing to save memory dim_output = FLAGS.num_classes if FLAGS.baseline == 'oracle': assert FLAGS.datasource == 'sinusoid' dim_input = 3 FLAGS.pretrain_iterations += FLAGS.metatrain_iterations FLAGS.metatrain_iterations = 0 else: dim_input = 84 * 84 * 3 model = MAML(dim_input, dim_output, test_num_updates=test_num_updates) if FLAGS.train: model.construct_model(input_tensors=None, prefix='metatrain_') else: model.construct_model(input_tensors=None, prefix='metaval_') model.summ_op = tf.summary.merge_all() saver = loader = tf.train.Saver(tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=40) for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES): print(var.name) sess = tf.InteractiveSession() if FLAGS.train == False: # change to original meta batch size when loading model. 
FLAGS.meta_batch_size = orig_meta_batch_size if FLAGS.train_update_batch_size == -1: FLAGS.train_update_batch_size = FLAGS.update_batch_size if FLAGS.train_update_lr == -1: FLAGS.train_update_lr = FLAGS.update_lr exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str( FLAGS.meta_batch_size) + '.ubs_' + str( FLAGS.train_update_batch_size) + '.numstep' + str( FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr) if FLAGS.num_filters != 64: exp_string += 'hidden' + str(FLAGS.num_filters) if FLAGS.max_pool: exp_string += 'maxpool' if FLAGS.stop_grad: exp_string += 'stopgrad' if FLAGS.baseline: exp_string += FLAGS.baseline if FLAGS.norm == 'batch_norm': exp_string += 'batchnorm' elif FLAGS.norm == 'layer_norm': exp_string += 'layernorm' elif FLAGS.norm == 'None': exp_string += 'nonorm' else: print('Norm setting not recognized.') if FLAGS.lr_mode > 0: exp_string += 'lrmode' + str(FLAGS.lr_mode) print(exp_string) resume_itr = 0 model_file = None tf.global_variables_initializer().run() #tf.train.start_queue_runners() if FLAGS.resume or not FLAGS.train: model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) if FLAGS.test_iter > 0: model_file = model_file[:model_file.index('model' )] + 'model' + str( FLAGS.test_iter) if model_file: ind1 = model_file.index('model') resume_itr = int(model_file[ind1 + 5:]) print("Restoring model weights from " + model_file) saver.restore(sess, model_file) if FLAGS.train: train(model, saver, sess, exp_string, resume_itr) else: test(model, saver, sess, exp_string, test_num_updates)
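# --- Illustrative sketch (assumption: follows the MAML 'oracle' baseline on sinusoid,
# which is why dim_input is set to 3 above): the true task parameters (amplitude, phase)
# are appended to every input x, so the network receives (x, amplitude, phase) instead
# of x alone.
import numpy as np

x = np.random.uniform(-5.0, 5.0, size=(10, 1))
amplitude, phase = 2.3, 0.7
oracle_inputs = np.concatenate(
    [x, np.full_like(x, amplitude), np.full_like(x, phase)], axis=1)
assert oracle_inputs.shape == (10, 3)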
def main(): print('Train(0) or Test(1)?') train_ = input() train_count = 100 if train_ == '0': FLAGS.train = True print('Number of training iterations for training mode') train_count = input() FLAGS.metatrain_iterations = int(train_count) else: FLAGS.train = False print('Select GPU:') gpu_index = input() os.environ['CUDA_VISIBLE_DEVICES'] = gpu_index config_gpu = tf.ConfigProto() config_gpu.gpu_options.allow_growth = True if FLAGS.train is True: test_num_updates = 1 else: test_num_updates = 10 # the original code uses 10 inner gradient steps at test time if FLAGS.train is False: # testing orig_meta_batch_size = FLAGS.meta_batch_size # always use meta batch size of 1 when testing. FLAGS.meta_batch_size = 1 print('main.py: creating data_generator') if FLAGS.train: data_generator = DataGeneratorOneInstance(FLAGS.update_batch_size + 15, FLAGS.meta_batch_size) # data_generator = DataGenerator_embedding(FLAGS.update_batch_size + 15, FLAGS.meta_batch_size) else: data_generator = DataGeneratorOneInstance(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size) # data_generator = DataGenerator_embedding(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size) # output dimension dim_output = data_generator.dim_output dim_input = data_generator.dim_input print('dim_input in main is {}'.format(dim_input)) tf_data_load = True num_classes = data_generator.num_classes sess = tf.InteractiveSession(config=config_gpu) # sess = tf.InteractiveSession() if FLAGS.train: # only construct training model if needed random.seed(5) ''' Notes on image_tensor and label_tensor: make_data_tensor returns all_image_batches, all_label_batches all_image_batches: [batch1:[pic1, pic2, ...], batch2:[]...], where each pic is [0.1, 0.08, ...] of length 84*84*3 all_label_batches: [batch1:[ [[0,1,0..], [1,0,0..], []..] ], batch2:[]...], where each [0,1,..] has length num_classes ''' # make_data_tensor print( 'main.py: train: data_generator.make_data_tensor(), get inputa etc. and slice them') image_tensor, label_tensor = data_generator.make_data_tensor( train=True) inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } # build the validation data used to print accuracy during training random.seed(6) print('main.py: val: data_generator.make_data_tensor()') image_tensor, label_tensor = data_generator.make_data_tensor( train=False) # train=False only affects the folder and batch_count inputa = tf.slice( image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) # examples 0 through 5*4 form input_a inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) metaval_input_tensors = { 'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb } print('model = MAML()') # test_num_updates: number of inner gradient steps (train: 1, test: 10) model = MAML(dim_input, dim_output, test_num_updates=test_num_updates) if FLAGS.train or not tf_data_load: # construct_model must be called after initialization print('model.construct_model(\'metatrain_\')') model.construct_model(input_tensors=input_tensors, prefix='metatrain_') if tf_data_load: print('model.construct_model(\'metaval_\')') model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_') model.summ_op = tf.summary.merge_all() saver = loader = 
tf.train.Saver(tf.get_collection( tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) # 训练阶段 if FLAGS.train is False: # 测试阶段使用原始的batch_size FLAGS.meta_batch_size = orig_meta_batch_size if FLAGS.train_update_batch_size == -1: FLAGS.train_update_batch_size = FLAGS.update_batch_size if FLAGS.train_update_lr == -1: FLAGS.train_update_lr = FLAGS.update_lr exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str( FLAGS.meta_batch_size) + '.ubs_' + str( FLAGS.train_update_batch_size) + '.numstep' + str( FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr) if FLAGS.num_filters != 64: exp_string += 'hidden' + str(FLAGS.num_filters) if FLAGS.max_pool: exp_string += 'maxpool' if FLAGS.stop_grad: exp_string += 'stopgrad' if FLAGS.baseline: exp_string += FLAGS.baseline if FLAGS.norm == 'batch_norm': exp_string += 'batchnorm' elif FLAGS.norm == 'layer_norm': exp_string += 'layernorm' elif FLAGS.norm == 'None': exp_string += 'nonorm' else: print('Norm setting not recognized.') resume_itr = 0 # 断点继续训练 model_file = None # 初始化变量 tf.global_variables_initializer().run() tf.local_variables_initializer().run() tf.train.start_queue_runners() # cls_5.mbs_4.ubs_5.numstep5.updatelr0.01hidden32maxpoolbatchnorm if FLAGS.resume or not FLAGS.train: model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) if FLAGS.test_iter > 0: model_file = model_file[:model_file.index('model' )] + 'model' + str( FLAGS.test_iter) if model_file: ind1 = model_file.index('model') resume_itr = int(model_file[ind1 + 5:]) print("读取已有训练数据Restoring model weights from " + model_file) saver.restore(sess, model_file) # if FLAGS.train: if FLAGS.train: print('main.py: 跳转到 train(model, saver, sess, exp_string...)...') # my(model, sess) train(model, saver, sess, exp_string, data_generator, resume_itr) else: print('main.py: 跳转到 _test(model, saver, sess, exp_string...)...') _test(model, saver, sess, exp_string, data_generator, test_num_updates)
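
# --- Illustrative sketch (not part of the original script) ---
# How the flag values map onto the checkpoint directory name built above. The
# values here are hypothetical; they reproduce the sample directory quoted in
# the comment ('cls_5.mbs_4.ubs_5.numstep5.updatelr0.01hidden32maxpoolbatchnorm').
_demo = 'cls_' + str(5) + '.mbs_' + str(4) + '.ubs_' + str(5) \
    + '.numstep' + str(5) + '.updatelr' + str(0.01)
_demo += 'hidden' + str(32)  # num_filters != 64
_demo += 'maxpool'           # max_pool is set
_demo += 'batchnorm'         # norm == 'batch_norm'
assert _demo == 'cls_5.mbs_4.ubs_5.numstep5.updatelr0.01hidden32maxpoolbatchnorm'
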
def main():
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 10
    else:
        if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'cifar100':
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = 10
        else:
            test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        data_generator = DataGenerator(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0:  # and FLAGS.datasource == 'miniimagenet':
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(1, FLAGS.meta_batch_size)  # only use one datapoint
        else:
            if FLAGS.datasource == 'miniimagenet':  # or FLAGS.datasource == 'cifar100':
                # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size + 15,
                        FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory
                else:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2,
                        FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory
            else:
                data_generator = DataGenerator(
                    FLAGS.update_batch_size * 2,
                    FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory

    dim_output = data_generator.dim_output
    if FLAGS.baseline == 'oracle':
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input

    random.seed(7)
    X = data_generator.make_autoencoder_data_tensor(train=True)
    Y = data_generator.make_autoencoder_data_tensor(train=False)
    autoencoder_input_tensors = {'X': X, 'Y': Y}

    dim_s = 32
    autoencoder = Autoencoder(dim_input, dim_s)
    if FLAGS.train:
        autoencoder.construct_autoencoder(input_tensors=autoencoder_input_tensors,
                                          prefix='autoencoder_train')
    else:
        autoencoder.construct_autoencoder(input_tensors=autoencoder_input_tensors,
                                          prefix='autoencoder_test')

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot' or FLAGS.datasource == 'cifar100':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            # s_tensor = tf.map_fn(lambda x: autoencoder.encode(x), inputa[0])
            # s_tensor = tf.reshape(s_tensor, [s_tensor.get_shape()[0], -1])
            # s_tensor = tf.reduce_sum(s_tensor, 0)
            s_tensor = tf.map_fn(lambda x: make_s(x, autoencoder), inputa)
            input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb,
                's_tensor': s_tensor
            }

        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        s_tensor = tf.map_fn(lambda x: make_s(x, autoencoder), inputa)
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb,
            's_tensor': s_tensor
        }
    else:
        tf_data_load = False
        input_tensors = None

    # autoencoder_for_maml = autoencoder.encode(input_tensors = input_tensors['inputa'])
    model = MAML(autoencoder.train_phase, dim_input, dim_output,
                 test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    # else:
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()
    autoencoder.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')
    # exp_string_model = exp_string
    # exp_string_autoencoder = exp_string

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        autoencoder_file = tf.train.latest_checkpoint(
            FLAGS.logdir_autoencoder + '/' + exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index(
                'model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)
            # ind1 = autoencoder_file.index('autoencoder')
            # resume_itr = int()
            print("Restoring autoencoder weights from " + autoencoder_file)
            w1 = sess.run(autoencoder.weights)
            saver.restore(sess, autoencoder_file)
            w2 = sess.run(autoencoder.weights)

    if FLAGS.train:
        print('training now')
        train(model, autoencoder, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, autoencoder, saver, sess, exp_string, data_generator, test_num_updates)
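
# --- Illustrative sketch (not part of the original script) ---
# make_s is called above via tf.map_fn but is not defined in this excerpt.
# Based on the commented-out lines (encode each support image, then pool), a
# plausible minimal version could look like the sketch below; the encode()
# signature and the reduce_sum pooling are assumptions, not the author's
# confirmed implementation.
import tensorflow as tf  # TF 1.x, as used throughout this file

def make_s_sketch(task_inputs, autoencoder):
    """Map one task's support images [num_support, dim_input] to a single
    task embedding of shape [dim_s]."""
    codes = tf.map_fn(lambda x: autoencoder.encode(x), task_inputs)  # one code per support image
    codes = tf.reshape(codes, [codes.get_shape()[0], -1])            # flatten each code
    return tf.reduce_sum(codes, 0)                                   # pool over the support set
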
def main():
    if FLAGS.datasource == 'sinusoid':
        if FLAGS.train:
            test_num_updates = 5
        else:
            test_num_updates = 2
    else:
        if FLAGS.datasource == 'miniimagenet':
            if FLAGS.train == True:
                test_num_updates = 1  # eval on at least one update during training
            else:
                test_num_updates = 10
        else:
            test_num_updates = 10

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource == 'sinusoid':
        data_generator = DataGenerator(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size)
    else:
        if FLAGS.metatrain_iterations == 0 and FLAGS.datasource == 'miniimagenet':
            assert FLAGS.meta_batch_size == 1
            assert FLAGS.update_batch_size == 1
            data_generator = DataGenerator(1, FLAGS.meta_batch_size)  # only use one datapoint
        else:
            if FLAGS.datasource == 'miniimagenet':
                # TODO - use 15 val examples for imagenet?
                if FLAGS.train:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size + 15,
                        FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory
                else:
                    data_generator = DataGenerator(
                        FLAGS.update_batch_size * 2,
                        FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory
            else:
                data_generator = DataGenerator(
                    FLAGS.update_batch_size * 2,
                    FLAGS.meta_batch_size)  # only use one datapoint for testing to save memory

    dim_output = data_generator.dim_output
    if FLAGS.baseline == 'oracle':
        assert FLAGS.datasource == 'sinusoid'
        dim_input = 3
        FLAGS.pretrain_iterations += FLAGS.metatrain_iterations
        FLAGS.metatrain_iterations = 0
    else:
        dim_input = data_generator.dim_input

    if FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'omniglot':
        tf_data_load = True
        num_classes = data_generator.num_classes

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = data_generator.make_data_tensor()
            inputa = tf.slice(image_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            inputb = tf.slice(image_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            labela = tf.slice(label_tensor, [0, 0, 0],
                              [-1, num_classes * FLAGS.update_batch_size, -1])
            labelb = tf.slice(label_tensor,
                              [0, num_classes * FLAGS.update_batch_size, 0],
                              [-1, -1, -1])
            input_tensors = {
                'inputa': inputa,
                'inputb': inputb,
                'labela': labela,
                'labelb': labelb
            }

        random.seed(6)
        image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
        inputa = tf.slice(image_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        inputb = tf.slice(image_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        labela = tf.slice(label_tensor, [0, 0, 0],
                          [-1, num_classes * FLAGS.update_batch_size, -1])
        labelb = tf.slice(label_tensor,
                          [0, num_classes * FLAGS.update_batch_size, 0],
                          [-1, -1, -1])
        metaval_input_tensors = {
            'inputa': inputa,
            'inputb': inputb,
            'labela': labela,
            'labelb': labelb
        }
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(max_to_keep=10)
    # saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(
        FLAGS.meta_batch_size) + '.ubs_' + str(
            FLAGS.train_update_batch_size) + '.numstep' + str(
                FLAGS.num_updates) + '.updatelr' + str(FLAGS.train_update_lr)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.baseline:
        exp_string += FLAGS.baseline
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    if FLAGS.resume or not FLAGS.train:
        print("Seeing if resume....")
        print("File string: ", FLAGS.logdir + '/' + exp_string)
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        print("model file name: ", model_file)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index(
                'model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
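
# --- Illustrative sketch (not part of the original script) ---
# The tf.slice calls above split each task's image batch along axis 1 into a
# support set (inputa, used for the inner-loop update) and a query set (inputb,
# used for the meta update). The shapes below are hypothetical (5-way task,
# update_batch_size=1, 15 query images per class) and only illustrate the split.
def _demo_support_query_split():
    import tensorflow as tf  # TF 1.x, as used throughout this file
    num_classes_demo, support_per_class, query_per_class = 5, 1, 15
    dim = 84 * 84 * 3
    images = tf.zeros([4, num_classes_demo * (support_per_class + query_per_class), dim])
    support = tf.slice(images, [0, 0, 0], [-1, num_classes_demo * support_per_class, -1])
    query = tf.slice(images, [0, num_classes_demo * support_per_class, 0], [-1, -1, -1])
    # support: [4, 5, 21168]   the first num_classes*update_batch_size images per task
    # query:   [4, 75, 21168]  the remaining images per task
    return support, query
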
def main():
    test_num_updates = 1

    if FLAGS.train == False:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 100 when testing.
        FLAGS.meta_batch_size = 100

    data_generator = DataGenerator(batch_size=FLAGS.meta_batch_size)

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input
    num_classes = data_generator.num_classes

    if FLAGS.train:  # only construct training model if needed
        random.seed(5)
        image_tensor, label_tensor = data_generator.make_data_tensor()
        input_tensors = {'input': image_tensor, 'label': label_tensor}

    random.seed(6)
    image_tensor, label_tensor = data_generator.make_data_tensor(train=False)
    metaval_input_tensors = {'input': image_tensor, 'label': label_tensor}

    model = MAML(dim_input, dim_output, test_num_updates=test_num_updates)
    if FLAGS.vanilla:
        if FLAGS.train:
            model.construct_vanilla_model(input_tensors=input_tensors, prefix='metatrain_')
        model.construct_vanilla_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    else:
        if FLAGS.train:
            model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()

    saver = loader = tf.train.Saver(tf.global_variables(), max_to_keep=10)

    sess = tf.InteractiveSession()

    if FLAGS.train == False:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    exp_string = FLAGS.dataset \
        + '_backbone_' + FLAGS.backbone \
        + '_scalar_lr_' + str(FLAGS.scalar_lr) \
        + '_mbs_' + str(FLAGS.meta_batch_size) \
        + '.dict_' + str(FLAGS.dict_size) + '.numstep' \
        + str(FLAGS.num_updates) + '.updatelr' + str(FLAGS.update_lr) \
        + '.vanilla_' + str(FLAGS.vanilla) \
        + '.fix_v_' + str(FLAGS.fix_v) \
        + '.alpha_' + str(FLAGS.alpha)

    if FLAGS.dropout_ratio != 0.5:
        exp_string += '_dropout_' + str(FLAGS.dropout_ratio)
    if FLAGS.vanilla and FLAGS.optimizer != 'sgd':
        exp_string += FLAGS.optimizer
        exp_string += '_weight_decay_' + str(FLAGS.weight_decay)
    if FLAGS.dot:
        exp_string += '_dot'
    if FLAGS.modulate in ['all', 'last', 'before_fc']:
        exp_string += '_modulate_' + FLAGS.modulate + '_size_' + str(
            FLAGS.film_dict_size)
    print(exp_string)

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    prev_best_accu = 0
    if FLAGS.resume or not FLAGS.train:
        model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        if FLAGS.test_iter > 0:
            model_file = model_file[:model_file.index(
                'model')] + 'model' + str(FLAGS.test_iter)
        if model_file:
            ind1 = model_file.index('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            loader.restore(sess, model_file)
            orig_train = FLAGS.train
            FLAGS.train = False
            if FLAGS.vanilla:
                prev_best_accu = test_vanilla(model, saver, sess, exp_string, data_generator)
            else:
                prev_best_accu = test(model, saver, sess, exp_string, data_generator)
            FLAGS.train = orig_train

    if FLAGS.vanilla:
        if FLAGS.train:
            train_vanilla(model, saver, sess, exp_string, data_generator, prev_best_accu, resume_itr)
        else:
            test_vanilla(model, saver, sess, exp_string, data_generator)
    else:
        if FLAGS.train:
            train(model, saver, sess, exp_string, data_generator, prev_best_accu, resume_itr)
        else:
            test(model, saver, sess, exp_string, data_generator)
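
# --- Illustrative note (not part of the original script) ---
# Unlike the earlier variants, this one passes tf.global_variables() to the
# Saver, so optimizer slot variables and batch-norm moving statistics are
# checkpointed in addition to the trainable weights. A minimal comparison of
# the two choices (variable names here are hypothetical):
def _demo_saver_variable_sets():
    import tensorflow as tf  # TF 1.x, as used throughout this file
    w = tf.Variable(1.0, name='w')                              # trainable weight
    mm = tf.Variable(0.0, name='moving_mean', trainable=False)  # e.g. a BN statistic
    saver_all = tf.train.Saver(tf.global_variables())           # saves 'w' and 'moving_mean'
    saver_trainable = tf.train.Saver(
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))    # saves only 'w'
    return saver_all, saver_trainable
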