from torch.utils.data import DataLoader
from torchvision import transforms

# MNIST, Omniglot, and ClassBalancedSampler are project-local dataset / sampler classes.


def get_data_loader(task, batch_size=1, split='train'):
    # NOTE: batch_size here is the number of instances PER CLASS.
    if task.dataset == 'mnist':
        normalize = transforms.Normalize(mean=[0.13066, 0.13066, 0.13066],
                                         std=[0.30131, 0.30131, 0.30131])
        dset = MNIST(task,
                     transform=transforms.Compose([transforms.ToTensor(), normalize]),
                     split=split)
    else:
        normalize = transforms.Normalize(mean=[0.92206, 0.92206, 0.92206],
                                         std=[0.08426, 0.08426, 0.08426])
        dset = Omniglot(task,
                        transform=transforms.Compose([transforms.ToTensor(), normalize]),
                        split=split)

    # Sample class-balanced batches; only the training split is cut off at
    # batch_size instances per class.
    sampler = ClassBalancedSampler(task.num_cl, task.num_inst,
                                   batch_cutoff=(None if split != 'train' else batch_size))
    loader = DataLoader(dset,
                        batch_size=batch_size * task.num_cl,
                        sampler=sampler,
                        num_workers=0,
                        pin_memory=True)
    return loader
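# The ClassBalancedSampler above is project-local. Below is a minimal sketch of the
# underlying idea, assuming the dataset stores num_cl classes with num_inst instances
# each, laid out contiguously (class c occupies indices [c * num_inst, (c + 1) * num_inst)).
# The class name and layout assumption are illustrative, not the project's exact code.
import random
from torch.utils.data.sampler import Sampler


class SimpleClassBalancedSampler(Sampler):
    """Yields indices so every class contributes the same number of instances."""

    def __init__(self, num_cl, num_inst, batch_cutoff=None):
        self.num_cl = num_cl
        self.num_inst = num_inst
        # When batch_cutoff is set, only that many instances per class are used.
        self.cutoff = batch_cutoff if batch_cutoff is not None else num_inst

    def __iter__(self):
        per_class = []
        for c in range(self.num_cl):
            idx = [c * self.num_inst + i for i in range(self.num_inst)]
            random.shuffle(idx)
            per_class.append(idx[:self.cutoff])
        # Interleave classes so each contiguous chunk of num_cl indices is balanced.
        return iter([i for group in zip(*per_class) for i in group])

    def __len__(self):
        return self.num_cl * self.cutoff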
def get_data_loader(task, split='train'):
    dset = Omniglot(task, transform=transforms.ToTensor(), split=split)
    print('img ids', dset.img_ids)
    print('labels', dset.labels)
    # inner_batch_size is expected to be a module-level hyperparameter
    # (the batch size used for inner-loop updates).
    loader = DataLoader(dset,
                        batch_size=inner_batch_size,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=True)
    return loader
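# A sketch of how this loader is typically consumed in an inner adaptation loop,
# assuming a standard PyTorch classifier `model` and a cross-entropy objective.
# The function name, lr, and inner_train_steps below are illustrative assumptions,
# not the project's actual training code.
import torch
import torch.nn.functional as F


def inner_train(model, task, lr=1e-3, inner_train_steps=5):
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loader = get_data_loader(task, split='train')
    batches = iter(loader)
    for _ in range(inner_train_steps):
        try:
            images, labels = next(batches)
        except StopIteration:
            batches = iter(loader)
            images, labels = next(batches)
        loss = F.cross_entropy(model(images), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return model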
import numpy as np
import tensorflow as tf
from functools import partial
from tqdm import tqdm

# Omniglot, Maml, _omniglot_arch, and _xent_loss are project-local modules / helpers.


def main(config,
         RANDOM_SEED,
         LOG_DIR,
         TASK_NUM,
         N_WAY,
         K_SHOTS,
         TRAIN_NUM,
         ALPHA,
         TRAIN_NUM_SGD,   # Inner SGD steps during training.
         VALID_NUM_SGD,   # Inner SGD steps during validation.
         LEARNING_RATE,   # Beta (outer-loop learning rate).
         DECAY_VAL,
         DECAY_STEPS,
         DECAY_STAIRCASE,
         SAVE_PERIOD,
         SUMMARY_PERIOD):
    np.random.seed(RANDOM_SEED)
    tf.set_random_seed(RANDOM_SEED)

    # >>>>>>> DATASET
    omni = Omniglot(seed=RANDOM_SEED)
    _, x, y, x_prime, y_prime = omni.build_queue(TASK_NUM, N_WAY, K_SHOTS)
    _, x_val, y_val, x_prime_val, y_prime_val = omni.build_queue(
        TASK_NUM, N_WAY, K_SHOTS, train=False)
    # <<<<<<<

    # >>>>>>> MODEL
    with tf.variable_scope('train'):
        # Outer-loop optimization with an exponentially decayed learning rate.
        global_step = tf.Variable(0, trainable=False)
        learning_rate = tf.train.exponential_decay(
            LEARNING_RATE, global_step, DECAY_STEPS, DECAY_VAL,
            staircase=DECAY_STAIRCASE)
        tf.summary.scalar('lr', learning_rate)

        with tf.variable_scope('params') as params:
            pass
        net = Maml(ALPHA, TRAIN_NUM_SGD, learning_rate, global_step,
                   x, y, x_prime, y_prime,
                   partial(_omniglot_arch, num_classes=N_WAY),
                   partial(_xent_loss, num_classes=N_WAY),
                   params, is_training=True)

    with tf.variable_scope('valid'):
        params.reuse_variables()
        valid_net = Maml(ALPHA, VALID_NUM_SGD, 0.0,
                         tf.Variable(0, trainable=False),
                         x_val, y_val, x_prime_val, y_prime_val,
                         partial(_omniglot_arch, num_classes=N_WAY),
                         partial(_xent_loss, num_classes=N_WAY),
                         params, is_training=False)

    with tf.variable_scope('misc'):
        def _get_acc(logits, labels):
            return tf.reduce_mean(
                tf.cast(tf.equal(tf.argmax(logits, axis=-1), labels), tf.float32))

        # Summary operations
        tf.summary.scalar('loss', net.loss)
        tf.summary.scalar('acc', _get_acc(net.logits, y_prime))
        for it in range(TRAIN_NUM_SGD - 1):
            tf.summary.scalar('acc_it_%d' % it,
                              _get_acc(net.logits_per_steps[:, :, :, it], y_prime))
        summary_op = tf.summary.merge_all()

        # Initialize op
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        config_summary = tf.summary.text(
            'TrainConfig', tf.convert_to_tensor(config.as_matrix()), collections=[])

        extended_summary_op = tf.summary.merge(
            [tf.summary.scalar('valid_loss', valid_net.loss),
             tf.summary.scalar('valid_acc', _get_acc(valid_net.logits, y_prime_val))] +
            [tf.summary.scalar('valid_acc_it_%d' % it,
                               _get_acc(valid_net.logits_per_steps[:, :, :, it],
                                        y_prime_val))
             for it in range(VALID_NUM_SGD - 1)])

    # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> Run!
    sess_config = tf.ConfigProto()  # Renamed so it does not shadow the experiment config.
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)
    sess.graph.finalize()
    sess.run(init_op)

    summary_writer = tf.summary.FileWriter(LOG_DIR, sess.graph)
    summary_writer.add_summary(config_summary.eval(session=sess))

    # Start queueing
    coord = tf.train.Coordinator()
    threads = []
    try:
        threads = tf.train.start_queue_runners(coord=coord, sess=sess)
        for step in tqdm(range(TRAIN_NUM), dynamic_ncols=True):
            it, loss, _ = sess.run([global_step, net.loss, net.train_op])
            tqdm.write('[%5d] Loss: %1.3f' % (it, loss))

            if it % SAVE_PERIOD == 0:
                net.save(sess, LOG_DIR, step=it)

            if it % SUMMARY_PERIOD == 0:
                summary = sess.run(summary_op)
                summary_writer.add_summary(summary, it)

            if it % (SUMMARY_PERIOD * 10) == 0:
                # Extended summary (validation metrics)
                summary = sess.run(extended_summary_op)
                summary_writer.add_summary(summary, it)
    except Exception as e:
        coord.request_stop(e)
    finally:
        net.save(sess, LOG_DIR)
        coord.request_stop()
        coord.join(threads)
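# For reference, a minimal sketch of the inner adaptation step the Maml class is
# expected to build. It assumes the parameters are held in a dict of tf.Variables
# and that arch_fn(inputs, params) / loss_fn(logits, labels) are callable in that
# form; the project's Maml class manages parameters through the `params` variable
# scope instead, so this is an illustration of the standard MAML update, not the
# exact implementation.
def maml_inner_step(arch_fn, loss_fn, params, x, y, x_prime, y_prime, alpha):
    """One inner-loop SGD step on the support set, then the query-set (outer) loss."""
    support_loss = loss_fn(arch_fn(x, params), y)
    grads = tf.gradients(support_loss, list(params.values()))
    # Fast weights: a single gradient step with step size alpha.
    fast_params = {name: p - alpha * g
                   for (name, p), g in zip(params.items(), grads)}
    # Outer objective: loss of the adapted parameters on the query set,
    # differentiated with respect to the original (meta) parameters.
    query_logits = arch_fn(x_prime, fast_params)
    outer_loss = loss_fn(query_logits, y_prime)
    return outer_loss, query_logits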