def load_episode(self, index, data_type='train'):
    """Load one episode (support and query batches) for meta-learning.

    Args:
        index: index of the episode within the selected data list.
        data_type: meta-learning phase, one of 'train', 'test', 'val'.

    Returns:
        A tuple (inputa, labela, inputb, labelb) holding the processed
        support inputs/labels and query inputs/labels of the episode.

    Raises:
        Exception: if data_type is not a recognized phase name.
    """
    # Pick the episode list and the per-class query sample count for
    # the requested phase; 'test' and 'val' share the same FLAGS logic.
    if data_type == 'train':
        episode_list = self.train_data
        query_sample_num = FLAGS.metatrain_epite_sample_num
    elif data_type in ('test', 'val'):
        episode_list = self.test_data if data_type == 'test' else self.val_data
        if FLAGS.metatest_epite_sample_num == 0:
            query_sample_num = FLAGS.shot_num
        else:
            query_sample_num = FLAGS.metatest_episode_test_sample
    else:
        raise Exception('Please check data list type')

    # Flattened RGB image dimensionality.
    dim_input = FLAGS.img_size * FLAGS.img_size * 3
    support_sample_num = FLAGS.shot_num

    episode = episode_list[index]
    support_files = episode['filenamea']
    query_files = episode['filenameb']
    support_labels = episode['labela']
    query_labels = episode['labelb']

    # Only the support ("a") batch is ever augmented, and only when
    # base_augmentation is on outside of meta-training.
    if FLAGS.base_augmentation and FLAGS.metatrain == False:
        inputa, labela = process_batch_augmentation(
            support_files, support_labels, dim_input, support_sample_num)
    else:
        inputa, labela = process_batch(
            support_files, support_labels, dim_input, support_sample_num)
    # The query ("b") batch is processed identically in both cases.
    inputb, labelb = process_batch(
        query_files, query_labels, dim_input, query_sample_num)
    return inputa, labela, inputb, labelb
def load_episode(self, index, data_type='train'):
    """Load one episode (support and query batches) for meta-learning.

    Args:
        index: index of the episode within the selected data list.
        data_type: meta-learning phase, one of 'train', 'test', 'val'.

    Returns:
        A tuple (inputa, labela, inputb, labelb) holding the processed
        support inputs/labels and query inputs/labels of the episode.

    Raises:
        Exception: if data_type is not a recognized phase name.
    """
    if data_type == 'train':
        data_list = self.train_data
        epite_sample_num = FLAGS.metatrain_epite_sample_num
    elif data_type == 'test':
        data_list = self.test_data
        if FLAGS.metatest_epite_sample_num == 0:
            epite_sample_num = FLAGS.shot_num
        else:
            epite_sample_num = FLAGS.metatest_episode_test_sample
    elif data_type == 'val':
        data_list = self.val_data
        if FLAGS.metatest_epite_sample_num == 0:
            epite_sample_num = FLAGS.shot_num
        else:
            epite_sample_num = FLAGS.metatest_episode_test_sample
    else:
        # Bug fix: this branch previously only printed an error and fell
        # through, which then crashed with a NameError on the undefined
        # data_list below.  Raise immediately instead, matching the
        # sibling implementation of load_episode in this file.
        raise Exception('Please check data list type')
    # Flattened RGB image dimensionality.
    dim_input = FLAGS.img_size * FLAGS.img_size * 3
    epitr_sample_num = FLAGS.shot_num
    this_episode = data_list[index]
    this_task_tr_filenames = this_episode['filenamea']
    this_task_te_filenames = this_episode['filenameb']
    this_task_tr_labels = this_episode['labela']
    this_task_te_labels = this_episode['labelb']
    # Only the support ("a") batch is augmented, and only when
    # base_augmentation is on outside of meta-training; the query ("b")
    # batch goes through plain process_batch in both cases.
    if FLAGS.base_augmentation and FLAGS.metatrain == False:
        this_inputa, this_labela = process_batch_augmentation(this_task_tr_filenames, \
            this_task_tr_labels, dim_input, epitr_sample_num)
        this_inputb, this_labelb = process_batch(this_task_te_filenames, \
            this_task_te_labels, dim_input, epite_sample_num)
    else:
        this_inputa, this_labela = process_batch(this_task_tr_filenames, \
            this_task_tr_labels, dim_input, epitr_sample_num)
        this_inputb, this_labelb = process_batch(this_task_te_filenames, \
            this_task_te_labels, dim_input, epite_sample_num)
    return this_inputa, this_labela, this_inputb, this_labelb
def test(self):
    """Evaluate the model on the meta-test split.

    Runs NUM_TEST_POINTS episodes through the model with meta_lr=0.0
    (no meta update), reports mean accuracy per inner-update step with
    a 95% confidence interval, and writes the results to a CSV file and
    a pickle file under FLAGS.logdir.
    """
    NUM_TEST_POINTS = 600
    exp_string = FLAGS.exp_string
    # Fixed seed for reproducible evaluation.
    np.random.seed(1)
    metaval_accuracies = []
    # Each class contributes shot_num support + shot_num query samples.
    num_samples_per_class = FLAGS.shot_num*2
    task_num = FLAGS.way_num * num_samples_per_class
    half_num_samples = FLAGS.shot_num
    # Flattened RGB image dimensionality.
    dim_input = FLAGS.img_size * FLAGS.img_size * 3
    filename_dir = FLAGS.logdir_base + 'filenames_and_labels/'
    this_setting_filename_dir = filename_dir + str(FLAGS.shot_num) + 'shot_' + str(FLAGS.way_num) + 'way/'
    # Pre-generated episode filenames/labels; assumed to be laid out as
    # contiguous runs of task_num entries per episode -- TODO confirm
    # against the generator that wrote these .npy files.
    all_filenames = np.load(this_setting_filename_dir + 'test_filenames.npy').tolist()
    labels = np.load(this_setting_filename_dir + 'test_labels.npy').tolist()
    for test_idx in trange(NUM_TEST_POINTS):
        # Slice out this episode's filenames.
        this_task_filenames = all_filenames[test_idx*task_num:(test_idx+1)*task_num]
        this_task_tr_filenames = []
        this_task_tr_labels = []
        this_task_te_filenames = []
        this_task_te_labels = []
        # Per class: first half of the samples form the support ("tr")
        # set, the second half the query ("te") set.
        for class_k in range(FLAGS.way_num):
            this_class_filenames = this_task_filenames[class_k*num_samples_per_class:(class_k+1)*num_samples_per_class]
            this_class_label = labels[class_k*num_samples_per_class:(class_k+1)*num_samples_per_class]
            this_task_tr_filenames += this_class_filenames[0:half_num_samples]
            this_task_tr_labels += this_class_label[0:half_num_samples]
            this_task_te_filenames += this_class_filenames[half_num_samples:]
            this_task_te_labels += this_class_label[half_num_samples:]
        if FLAGS.base_augmentation or FLAGS.test_base_augmentation:
            inputa, labela = process_batch_augmentation(this_task_tr_filenames, this_task_tr_labels, dim_input, half_num_samples)
            inputb, labelb = process_batch_augmentation(this_task_te_filenames, this_task_te_labels, dim_input, half_num_samples)
        else:
            inputa, labela = process_batch(this_task_tr_filenames, this_task_tr_labels, dim_input, half_num_samples)
            inputb, labelb = process_batch(this_task_te_filenames, this_task_te_labels, dim_input, half_num_samples)
        # meta_lr=0.0 disables the meta update during evaluation.
        feed_dict = {self.model.inputa: inputa, self.model.inputb: inputb, self.model.labela: labela, self.model.labelb: labelb, self.model.meta_lr: 0.0}
        result = self.sess.run(self.model.metaval_total_accuracies, feed_dict)
        metaval_accuracies.append(result)
    # Rows: episodes; columns: accuracy after each inner-update step
    # (matches the 'update{i}' CSV header written below).
    metaval_accuracies = np.array(metaval_accuracies)
    means = np.mean(metaval_accuracies, 0)
    max_idx = np.argmax(means)
    max_acc = np.max(means)
    stds = np.std(metaval_accuracies, 0)
    # 95% confidence interval of the mean over NUM_TEST_POINTS episodes.
    ci95 = 1.96*stds/np.sqrt(NUM_TEST_POINTS)
    max_ci95 = ci95[max_idx]
    print('Mean validation accuracy and confidence intervals')
    print((means, ci95))
    print('***** Best Acc: '+ str(max_acc) + ' CI95: ' + str(max_ci95))
    # Output names encode the augmentation setting and checkpoint iter.
    if FLAGS.base_augmentation or FLAGS.test_base_augmentation:
        out_filename = FLAGS.logdir +'/'+ exp_string + '/' + 'result_aug_' + str(FLAGS.shot_num) + 'shot_' + str(FLAGS.test_iter) + '.csv'
        out_pkl = FLAGS.logdir +'/'+ exp_string + '/' + 'result_aug_' + str(FLAGS.shot_num) + 'shot_' + str(FLAGS.test_iter) + '.pkl'
    else:
        out_filename = FLAGS.logdir +'/'+ exp_string + '/' + 'result_noaug_' + str(FLAGS.shot_num) + 'shot_' + str(FLAGS.test_iter) + '.csv'
        out_pkl = FLAGS.logdir +'/'+ exp_string + '/' + 'result_noaug_' + str(FLAGS.shot_num) + 'shot_' + str(FLAGS.test_iter) + '.pkl'
    # NOTE(review): the pickle key is 'mses' but the values are
    # accuracies -- looks like a copy-paste leftover; confirm no
    # downstream reader depends on the key before renaming it.
    with open(out_pkl, 'wb') as f:
        pickle.dump({'mses': metaval_accuracies}, f)
    with open(out_filename, 'w') as f:
        writer = csv.writer(f, delimiter=',')
        writer.writerow(['update'+str(i) for i in range(len(means))])
        writer.writerow(means)
        writer.writerow(stds)
        writer.writerow(ci95)
def train(self):
    """Run the meta-training loop.

    Assembles meta-batches of episodes from pre-generated filename and
    label arrays, feeds them to the model's meta-train op, and
    periodically: writes TF summaries, prints running loss/accuracy,
    checkpoints the three weight groups, runs a small validation pass,
    and halves the meta learning rate on a fixed step schedule.  A final
    checkpoint is saved after the last iteration.
    """
    exp_string = FLAGS.exp_string
    train_writer = tf.summary.FileWriter(FLAGS.logdir + '/' + exp_string, self.sess.graph)
    print('Done initializing, starting training')
    loss_list, acc_list = [], []
    train_lr = FLAGS.meta_lr
    # Training episodes: shot_num support + 15 query samples per class.
    num_samples_per_class = FLAGS.shot_num + 15
    task_num = FLAGS.way_num * num_samples_per_class
    # Validation episodes: shot_num support + shot_num query per class.
    num_samples_per_class_test = FLAGS.shot_num * 2
    test_task_num = FLAGS.way_num * num_samples_per_class_test
    epitr_sample_num = FLAGS.shot_num
    epite_sample_num = 15
    test_task_sample_num = FLAGS.shot_num
    # Flattened RGB image dimensionality.
    dim_input = FLAGS.img_size * FLAGS.img_size * 3
    filename_dir = FLAGS.logdir_base + 'filenames_and_labels/'
    this_setting_filename_dir = filename_dir + str(FLAGS.shot_num) + 'shot_' + str(FLAGS.way_num) + 'way/'
    # Pre-generated episode data; assumed to be contiguous runs of
    # task_num (resp. test_task_num) entries per episode -- TODO confirm.
    all_filenames = np.load(this_setting_filename_dir + 'train_filenames.npy').tolist()
    labels = np.load(this_setting_filename_dir + 'train_labels.npy').tolist()
    all_test_filenames = np.load(this_setting_filename_dir + 'val_filenames.npy').tolist()
    test_labels = np.load(this_setting_filename_dir + 'val_labels.npy').tolist()
    # Running cursor into the validation episode arrays, advanced by
    # every validation pass so fresh episodes are used each time.
    test_idx = 0
    for train_idx in trange(FLAGS.metatrain_iterations):
        inputa = []
        labela = []
        inputb = []
        labelb = []
        # Assemble one meta-batch of episodes.
        for meta_batch_idx in range(FLAGS.meta_batch_size):
            this_task_filenames = all_filenames[(train_idx*FLAGS.meta_batch_size+meta_batch_idx)*task_num:(train_idx*FLAGS.meta_batch_size+meta_batch_idx+1)*task_num]
            this_task_tr_filenames = []
            this_task_tr_labels = []
            this_task_te_filenames = []
            this_task_te_labels = []
            # Per class: first epitr_sample_num samples form the support
            # ("tr") set, the remainder the query ("te") set.
            for class_k in range(FLAGS.way_num):
                this_class_filenames = this_task_filenames[class_k*num_samples_per_class:(class_k+1)*num_samples_per_class]
                this_class_label = labels[class_k*num_samples_per_class:(class_k+1)*num_samples_per_class]
                this_task_tr_filenames += this_class_filenames[0:epitr_sample_num]
                this_task_tr_labels += this_class_label[0:epitr_sample_num]
                this_task_te_filenames += this_class_filenames[epitr_sample_num:]
                this_task_te_labels += this_class_label[epitr_sample_num:]
            if FLAGS.base_augmentation:
                this_inputa, this_labela = process_batch_augmentation(this_task_tr_filenames, this_task_tr_labels, dim_input, epitr_sample_num, reshape_with_one=False)
                this_inputb, this_labelb = process_batch_augmentation(this_task_te_filenames, this_task_te_labels, dim_input, epite_sample_num, reshape_with_one=False)
            else:
                this_inputa, this_labela = process_batch(this_task_tr_filenames, this_task_tr_labels, dim_input, epitr_sample_num, reshape_with_one=False)
                this_inputb, this_labelb = process_batch(this_task_te_filenames, this_task_te_labels, dim_input, epite_sample_num, reshape_with_one=False)
            inputa.append(this_inputa)
            labela.append(this_labela)
            inputb.append(this_inputb)
            labelb.append(this_labelb)
        inputa = np.array(inputa)
        labela = np.array(labela)
        inputb = np.array(inputb)
        labelb = np.array(labelb)
        feed_dict = {self.model.inputa: inputa, self.model.inputb: inputb, self.model.labela: labela, self.model.labelb: labelb, self.model.meta_lr: train_lr}
        # Fetch list layout: [0]=metatrain_op, [1]=total_loss,
        # [2]=total_accuracy, and when extended below [3]=summ_op,
        # [4]=total_loss (duplicate).
        input_tensors = [self.model.metatrain_op]
        input_tensors.extend([self.model.total_loss])
        input_tensors.extend([self.model.total_accuracy])
        if (train_idx % FLAGS.meta_sum_step == 0 or train_idx % FLAGS.meta_print_step == 0):
            # NOTE(review): total_loss is fetched a second time here and
            # the duplicate (result[4]) is never read.
            input_tensors.extend([self.model.summ_op, self.model.total_loss])
        result = self.sess.run(input_tensors, feed_dict)
        loss_list.append(result[1])
        acc_list.append(result[2])
        if train_idx % FLAGS.meta_sum_step == 0:
            # result[3] is summ_op; it is guaranteed to exist because
            # the extension condition above includes meta_sum_step.
            train_writer.add_summary(result[3], train_idx)
        if (train_idx!=0) and train_idx % FLAGS.meta_print_step == 0:
            print_str = 'Iteration:' + str(train_idx)
            print_str += ' Loss:' + str(np.mean(loss_list)) + ' Acc:' + str(np.mean(acc_list))
            print(print_str)
            # Reset the running averages after each report.
            loss_list, acc_list = [], []
        if (train_idx!=0) and train_idx % FLAGS.meta_save_step == 0:
            # Periodic checkpoint of the three weight groups.
            weights = self.sess.run(self.model.weights)
            ss_weights = self.sess.run(self.model.ss_weights)
            fc_weights = self.sess.run(self.model.fc_weights)
            np.save(FLAGS.logdir + '/' + exp_string + '/weights_' + str(train_idx) + '.npy', weights)
            np.save(FLAGS.logdir + '/' + exp_string + '/ss_weights_' + str(train_idx) + '.npy', ss_weights)
            np.save(FLAGS.logdir + '/' + exp_string + '/fc_weights_' + str(train_idx) + '.npy', fc_weights)
        if (train_idx!=0) and train_idx % FLAGS.meta_val_print_step == 0:
            # Quick validation: 10 episodes with meta_lr=0.0 (no update).
            test_loss = []
            test_accs = []
            for test_itr in range(10):
                this_test_task_filenames = all_test_filenames[test_idx*test_task_num:(test_idx+1)*test_task_num]
                this_test_task_tr_filenames = []
                this_test_task_tr_labels = []
                this_test_task_te_filenames = []
                this_test_task_te_labels = []
                for class_k in range(FLAGS.way_num):
                    this_test_class_filenames = this_test_task_filenames[class_k*num_samples_per_class_test:(class_k+1)*num_samples_per_class_test]
                    this_test_class_label = test_labels[class_k*num_samples_per_class_test:(class_k+1)*num_samples_per_class_test]
                    this_test_task_tr_filenames += this_test_class_filenames[0:test_task_sample_num]
                    this_test_task_tr_labels += this_test_class_label[0:test_task_sample_num]
                    this_test_task_te_filenames += this_test_class_filenames[test_task_sample_num:]
                    this_test_task_te_labels += this_test_class_label[test_task_sample_num:]
                test_inputa, test_labela = process_batch(this_test_task_tr_filenames, this_test_task_tr_labels, dim_input, test_task_sample_num)
                test_inputb, test_labelb = process_batch(this_test_task_te_filenames, this_test_task_te_labels, dim_input, test_task_sample_num)
                test_feed_dict = {self.model.inputa: test_inputa, self.model.inputb: test_inputb, self.model.labela: test_labela, self.model.labelb: test_labelb, self.model.meta_lr: 0.0}
                test_input_tensors = [self.model.total_loss, self.model.total_accuracy]
                test_result = self.sess.run(test_input_tensors, test_feed_dict)
                test_loss.append(test_result[0])
                test_accs.append(test_result[1])
                test_idx += 1
            # Scaled by meta_batch_size -- presumably to undo a
            # per-meta-batch averaging inside total_loss/total_accuracy;
            # verify against the model definition.
            print_str = '[***] Val Loss:' + str(np.mean(test_loss)*FLAGS.meta_batch_size) + ' Val Acc:' + str(np.mean(test_accs)*FLAGS.meta_batch_size)
            print(print_str)
        # Step decay of the meta learning rate, floored at min_meta_lr.
        if (train_idx!=0) and train_idx % FLAGS.lr_drop_step == 0:
            train_lr = train_lr * 0.5
            if train_lr < FLAGS.min_meta_lr:
                train_lr = FLAGS.min_meta_lr
            print('Train LR: {}'.format(train_lr))
    # Final checkpoint after the last iteration (saved as train_idx+1).
    weights = self.sess.run(self.model.weights)
    ss_weights = self.sess.run(self.model.ss_weights)
    fc_weights = self.sess.run(self.model.fc_weights)
    np.save(FLAGS.logdir + '/' + exp_string + '/weights_' + str(train_idx+1) + '.npy', weights)
    np.save(FLAGS.logdir + '/' + exp_string + '/ss_weights_' + str(train_idx+1) + '.npy', ss_weights)
    np.save(FLAGS.logdir + '/' + exp_string + '/fc_weights_' + str(train_idx+1) + '.npy', fc_weights)