def get_tfrecords(self):
    # xtrain: all records
    # *_l   : partial (labeled) records
    if self.dataset == 'CIFAR10':
        from cifar10 import inputs, unlabeled_inputs
        xtrain_l, ytrain_l = inputs(batch_size=self.batch_size, train=True,
                                    validation=False, shuffle=True)
        xtrain = unlabeled_inputs(batch_size=self.batch_size,
                                  validation=False, shuffle=True)
        xtest, ytest = inputs(batch_size=self.batch_size, train=False,
                              validation=False, shuffle=True)
    elif self.dataset == 'SVHN':
        from svhn import inputs, unlabeled_inputs
        xtrain_l, ytrain_l = inputs(batch_size=self.batch_size, train=True,
                                    validation=False, shuffle=True)
        xtrain = unlabeled_inputs(batch_size=self.batch_size,
                                  validation=False, shuffle=True)
        xtest, ytest = inputs(batch_size=self.batch_size, train=False,
                              validation=False, shuffle=True)
    elif self.dataset == 'MNIST':
        from mnist import inputs
        xtrain, _ = inputs(self.batch_size, 'train')
        xtrain_l, ytrain_l = inputs(self.batch_size, 'train_labeled')
        xtest, ytest = inputs(self.batch_size, 'test')
    return (xtrain_l, ytrain_l), xtrain, (xtest, ytest)
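# Usage sketch (illustrative, not from the source; `loader` is a
# hypothetical instance of the class owning get_tfrecords): the three
# return values separate the labeled, unlabeled, and test input streams.
#
#   (x_l, y_l), x_ul, (x_te, y_te) = loader.get_tfrecords()
#   # x_l, y_l   : labeled training batch tensors
#   # x_ul       : unlabeled training batch tensor (all records)
#   # x_te, y_te : test batch tensors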
def get_tfrecords(self, idxes=None):
    """
    idxes:  fold indices of the TEST set in k-fold cross validation,
            i.e., a list of i where 0 <= i < k.
    xtrain: all records
    *_l   : partial (labeled) records
    """
    if c.FLAGS.dataset == 'CIFAR10':
        from cifar10 import inputs, unlabeled_inputs
        xtrain_l, ytrain_l = inputs(batch_size=c.BATCH_SIZE, train=True,
                                    validation=False, shuffle=True)
        xtrain = unlabeled_inputs(batch_size=c.BATCH_SIZE_UL,
                                  validation=False, shuffle=True)
        xtest, ytest = inputs(batch_size=c.BATCH_SIZE_TEST, train=False,
                              validation=False, shuffle=True)
    elif c.FLAGS.dataset == 'SVHN':
        from svhn import inputs, unlabeled_inputs
        xtrain_l, ytrain_l = inputs(batch_size=c.BATCH_SIZE, train=True,
                                    validation=False, shuffle=True)
        xtrain = unlabeled_inputs(batch_size=c.BATCH_SIZE_UL,
                                  validation=False, shuffle=True)
        xtest, ytest = inputs(batch_size=c.BATCH_SIZE_TEST, train=False,
                              validation=False, shuffle=True)
    elif c.FLAGS.dataset == 'MNIST':
        from mnist import inputs
        xtrain_l, ytrain_l = inputs(c.BATCH_SIZE, 'train_labeled')
        xtrain, _ = inputs(c.BATCH_SIZE_UL, 'train')
        xtest, ytest = inputs(c.BATCH_SIZE_TEST, 'test')
    else:
        indices_train, indices_test = _split_k_into_train_and_test(self._k, idxes)
        paths_train = self.paths_tfrecord_train

        if idxes is None:
            # Use different data sources for training and test, respectively.
            if self.path_root_tests is None:
                raise ValueError('path_root_tests was not given.')
            paths_test = self.path_tfrecord_test
        else:
            # Do cross-validation over a single data source.
            if self.path_root_tests is not None:
                raise ValueError('selection of the test data source is ambiguous '
                                 'since both path_root_tests and idxes are given.')
            paths_test = self.paths_tfrecord_train

        print('... reading TFRecords for train')
        (xtrain_l, ytrain_l), n_train_l = self.inputs(paths_train, indices_train,
                                                      batch_size=c.BATCH_SIZE)
        print('... reading TFRecords for test')
        (xtest, ytest), n_test = self.inputs(paths_test, indices_test,
                                             batch_size=c.BATCH_SIZE_TEST)

        if IS_UNLABELED_ENABLE:
            print('... reading TFRecords of unlabeled')
            # xtrain holds all records, so the unlabeled input reads every
            # fold. (Assumption: the original code left indices_all
            # undefined; adjust if unlabeled data must exclude test folds.)
            indices_all = indices_train + indices_test
            (xtrain, _), n_train_u = self.inputs(paths_train, indices_all,
                                                 batch_size=c.BATCH_SIZE_UL)
        else:
            xtrain = xtrain_l
            n_train_u = 0

        # CAUTION: in the SSL setting, x_l and x_ul are fed to the model in
        # parallel and independently. n_batches_train below is used as the
        # number of training iterations per epoch, but when c.BATCH_SIZE and
        # c.BATCH_SIZE_UL differ it is no longer valid. Instead of
        # n_batches_train, the number of training iterations per epoch
        # should be given as an argument, as in the original VAT code.
        # self.n_batches_train = int((n_train_l + n_train_u)
        #                            / (c.BATCH_SIZE + c.BATCH_SIZE_UL))
        self.n_batches_train = int((n_train_l + n_train_u) / c.BATCH_SIZE)
        self.n_batches_test = int(n_test / c.BATCH_SIZE_TEST)
        print(' n_batches_train: %d, n_batches_test: %d'
              % (self.n_batches_train, self.n_batches_test))

        # [ToDo] the way n_labeled is set is inconsistent across datasets.
        self.n_labeled = n_train_l
        self.n_train = n_train_l + n_train_u

    return (xtrain_l, ytrain_l), xtrain, (xtest, ytest)
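# _split_k_into_train_and_test is called above but not shown in this
# snippet. A minimal sketch (an assumption, not the author's code) for the
# cross-validation case: `idxes` names the test folds, the remaining folds
# form the training set.
def _split_k_into_train_and_test(k, idxes):
    folds = list(range(k))
    if idxes is None:
        # Separate test source: all folds are used for training. What the
        # real helper returns as test indices here is not recoverable from
        # this snippet; all folds is one plausible choice.
        return folds, folds
    indices_test = list(idxes)
    indices_train = [i for i in folds if i not in indices_test]
    return indices_train, indices_test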
def main(_):
    print(FLAGS.epsilon, FLAGS.top_bn)
    numpy.random.seed(seed=FLAGS.seed)
    tf.set_random_seed(numpy.random.randint(1234))
    with tf.Graph().as_default() as g:
        with tf.device("/cpu:0"):
            images, labels = inputs(batch_size=FLAGS.batch_size,
                                    train=True,
                                    validation=FLAGS.validation,
                                    shuffle=True)
            ul_images = unlabeled_inputs(batch_size=FLAGS.ul_batch_size,
                                         validation=FLAGS.validation,
                                         shuffle=True)

            images_eval_train, labels_eval_train = inputs(
                batch_size=FLAGS.eval_batch_size,
                train=True,
                validation=FLAGS.validation,
                shuffle=True)
            ul_images_eval_train = unlabeled_inputs(
                batch_size=FLAGS.eval_batch_size,
                validation=FLAGS.validation,
                shuffle=True)
            images_eval_test, labels_eval_test = inputs(
                batch_size=FLAGS.eval_batch_size,
                train=False,
                validation=FLAGS.validation,
                shuffle=True)

        with tf.device(FLAGS.device):
            lr = tf.placeholder(tf.float32, shape=[], name="learning_rate")
            mom = tf.placeholder(tf.float32, shape=[], name="momentum")
            with tf.variable_scope("CNN") as scope:
                # Build training graph
                loss, train_op, global_step = build_training_graph(
                    images, labels, ul_images, lr, mom)
                scope.reuse_variables()
                # Build eval graph (test images double as unlabeled inputs)
                losses_eval_train = build_eval_graph(images_eval_train,
                                                     labels_eval_train,
                                                     ul_images_eval_train)
                losses_eval_test = build_eval_graph(images_eval_test,
                                                    labels_eval_test,
                                                    images_eval_test)
            init_op = tf.global_variables_initializer()

        if not FLAGS.log_dir:
            logdir = None
            writer_train = None
            writer_test = None
        else:
            logdir = FLAGS.log_dir
            writer_train = tf.summary.FileWriter(FLAGS.log_dir + "/train", g)
            writer_test = tf.summary.FileWriter(FLAGS.log_dir + "/test", g)

        saver = tf.train.Saver(tf.global_variables())
        sv = tf.train.Supervisor(is_chief=True,
                                 logdir=logdir,
                                 init_op=init_op,
                                 init_feed_dict={lr: FLAGS.learning_rate,
                                                 mom: FLAGS.mom1},
                                 saver=saver,
                                 global_step=global_step,
                                 summary_op=None,
                                 summary_writer=None,
                                 save_model_secs=150,
                                 recovery_wait_secs=0)

        print("Training...")
        with sv.managed_session() as sess:
            for ep in range(FLAGS.num_epochs):
                if sv.should_stop():
                    break

                if ep < FLAGS.epoch_decay_start:
                    feed_dict = {lr: FLAGS.learning_rate, mom: FLAGS.mom1}
                else:
                    decayed_lr = ((FLAGS.num_epochs - ep)
                                  / float(FLAGS.num_epochs - FLAGS.epoch_decay_start)
                                  ) * FLAGS.learning_rate
                    feed_dict = {lr: decayed_lr, mom: FLAGS.mom2}

                sum_loss = 0
                start = time.time()
                for i in range(FLAGS.num_iter_per_epoch):
                    _, batch_loss, _ = sess.run([train_op, loss, global_step],
                                                feed_dict=feed_dict)
                    sum_loss += batch_loss
                end = time.time()
                print("Epoch:", ep,
                      "CE_loss_train:", sum_loss / FLAGS.num_iter_per_epoch,
                      "elapsed_time:", end - start)

                if (ep + 1) % FLAGS.eval_freq == 0 or ep + 1 == FLAGS.num_epochs:
                    # Eval on training data
                    act_values_dict = {}
                    for key in losses_eval_train:  # Python 3: .iteritems() removed
                        act_values_dict[key] = 0
                    # Integer division: '/' would yield a float in Python 3.
                    n_iter_per_epoch = NUM_EVAL_EXAMPLES // FLAGS.eval_batch_size
                    for i in range(n_iter_per_epoch):
                        values = list(losses_eval_train.values())
                        act_values = sess.run(values)
                        for key, value in zip(act_values_dict.keys(), act_values):
                            act_values_dict[key] += value
                    summary = tf.Summary()
                    current_global_step = sess.run(global_step)
                    for key, value in act_values_dict.items():
                        print("train-" + key, value / n_iter_per_epoch)
                        summary.value.add(tag=key,
                                          simple_value=value / n_iter_per_epoch)
                    if writer_train is not None:
                        writer_train.add_summary(summary, current_global_step)

                    # Eval on test data
                    act_values_dict = {}
                    for key in losses_eval_test:
                        act_values_dict[key] = 0
                    n_iter_per_epoch = NUM_EVAL_EXAMPLES // FLAGS.eval_batch_size
                    for i in range(n_iter_per_epoch):
                        values = list(losses_eval_test.values())
                        act_values = sess.run(values)
                        for key, value in zip(act_values_dict.keys(), act_values):
                            act_values_dict[key] += value
                    summary = tf.Summary()
                    current_global_step = sess.run(global_step)
                    for key, value in act_values_dict.items():
                        print("test-" + key, value / n_iter_per_epoch)
                        summary.value.add(tag=key,
                                          simple_value=value / n_iter_per_epoch)
                    if writer_test is not None:
                        writer_test.add_summary(summary, current_global_step)

            saver.save(sess, sv.save_path, global_step=global_step)
        sv.stop()
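# The epoch-wise schedule inlined in the training loop above keeps the
# learning rate constant until epoch_decay_start and then decays it
# linearly to zero at num_epochs. Factored out here for clarity
# (illustrative sketch only; the loops in this file inline the formula):
def linear_decay_lr(ep, num_epochs, epoch_decay_start, base_lr):
    if ep < epoch_decay_start:
        return base_lr
    return base_lr * (num_epochs - ep) / float(num_epochs - epoch_decay_start)

# e.g. with num_epochs=120, epoch_decay_start=80:
#   ep=80 -> 1.00 * base_lr, ep=100 -> 0.50 * base_lr, ep=119 -> 0.025 * base_lr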
def main(_):
    print(FLAGS.epsilon, FLAGS.top_bn)
    np.random.seed(seed=FLAGS.seed)
    tf.set_random_seed(np.random.randint(1234))

    with tf.Graph().as_default() as g:
        with tf.device("/cpu:0"):
            if FLAGS.data_set == 'CelebA':
                (images, labels), (_, _), (_, _) = d.get_data(
                    batch_size=FLAGS.batch_size, image_size=FLAGS.img_size)
                (images_eval_train, labels_eval_train), (_, _), (images_eval_test, labels_eval_test) = \
                    d.get_data(batch_size=FLAGS.eval_batch_size, image_size=FLAGS.img_size)
                ul_images = images
                ul_images_eval_train = images_eval_train
            else:
                images, labels = inputs(batch_size=FLAGS.batch_size,
                                        train=True,
                                        validation=FLAGS.validation,
                                        shuffle=True)
                ul_images = unlabeled_inputs(batch_size=FLAGS.ul_batch_size,
                                             validation=FLAGS.validation,
                                             shuffle=True)
                images_eval_train, labels_eval_train = inputs(
                    batch_size=FLAGS.eval_batch_size,
                    train=True,
                    validation=FLAGS.validation,
                    shuffle=True)
                ul_images_eval_train = unlabeled_inputs(
                    batch_size=FLAGS.eval_batch_size,
                    validation=FLAGS.validation,
                    shuffle=True)
                images_eval_test, labels_eval_test = inputs(
                    batch_size=FLAGS.eval_batch_size,
                    train=False,
                    validation=FLAGS.validation,
                    shuffle=True)

        lr = tf.placeholder(tf.float32, shape=[], name="learning_rate")
        mom = tf.placeholder(tf.float32, shape=[], name="momentum")

        loss, train_op, x_adv, x_reconst = build_training_graph(
            images, labels, ul_images, lr, mom)

        # Build eval graph
        if not FLAGS.draw_adv_img:
            losses_eval_train, _ = build_eval_graph(images_eval_train,
                                                    labels_eval_train,
                                                    ul_images_eval_train)
            losses_eval_test, results = build_eval_graph(images_eval_test,
                                                         labels_eval_test,
                                                         images_eval_test)

        saver = tf.train.Saver()

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        config.gpu_options.allocator_type = 'BFC'
        sess = tf.Session(config=config)

        if FLAGS.method == 'lvat':
            print('-------------------------------------------')
            print("... restore the variables from the frozen model.")
            u.restore(sess, SCOPE_ENCODER, CKPT_AE)
            print('-------------------------------------------')

        # if FLAGS.draw_adv_img:
        if False:
            print("... restore the variables of the classifier. log__dir:", FLAGS.log__dir)
            ckpt = tf.train.get_checkpoint_state(FLAGS.log__dir)
            if ckpt and ckpt.model_checkpoint_path:
                u.restore(sess, SCOPE_CLASSIFIER, FLAGS.log__dir)
                op_init = u.init_uninitialized_vars(sess)
                sess.run(op_init, feed_dict={lr: FLAGS.learning_rate, mom: FLAGS.mom1})
            else:
                sys.exit('failed to restore')
        else:
            print("... init the variables for the classifier to be trained.")
            classifier_vars = tf.get_collection(tf.GraphKeys.VARIABLES,
                                                scope=SCOPE_CLASSIFIER)
            print('classifier_vars:', classifier_vars)
            op_init = tf.variables_initializer(classifier_vars)

            optimizer_vars = tf.get_collection(tf.GraphKeys.VARIABLES,
                                               scope='scope_optimizer')
            print('optimizer_vars:', optimizer_vars)
            op_init_optimiser = tf.variables_initializer(optimizer_vars)

            sess.run([op_init, op_init_optimiser],
                     feed_dict={lr: FLAGS.learning_rate, mom: FLAGS.mom1})

        tf.train.start_queue_runners(sess=sess)

        if FLAGS.draw_adv_img:
            print('... skip training')
            _x, _x_adv, _x_reconst = sess.run([ul_images, x_adv, x_reconst])
            _N = 7
            print(FLAGS.ul_batch_size // _N)
            for i in range(FLAGS.ul_batch_size // _N):
                draw_x(_x, _x_reconst, _x_adv, n_x=_N, offset=i,
                       show_reconst=(FLAGS.method == 'lvat'),
                       filename='ep_%s_%.2f_%d' % (FLAGS.method, FLAGS.epsilon, i))
            sys.exit('exit draw_adv_img')
        else:
            print('... start training')
            for ep in range(FLAGS.num_epochs):
                if ep < FLAGS.epoch_decay_start:
                    feed_dict = {lr: FLAGS.learning_rate, mom: FLAGS.mom1}
                else:
                    decayed_lr = ((FLAGS.num_epochs - ep)
                                  / float(FLAGS.num_epochs - FLAGS.epoch_decay_start)
                                  ) * FLAGS.learning_rate
                    feed_dict = {lr: decayed_lr, mom: FLAGS.mom2}

                sum_loss = 0
                start = time.time()
                for i in tqdm(range(FLAGS.num_iter_per_epoch), leave=False):
                    _, batch_loss = sess.run([train_op, loss], feed_dict=feed_dict)
                    sum_loss += batch_loss
                end = time.time()
                print("Epoch:", ep,
                      "CE_loss_train:", sum_loss / FLAGS.num_iter_per_epoch,
                      "elapsed_time:", end - start, flush=True)

                if (ep >= FLAGS.eval_start) and \
                   ((ep + 1) % FLAGS.eval_freq == 0 or ep + 1 == FLAGS.num_epochs):
                    test(sess, losses_eval_train, ep, "train-")
                    test(sess, losses_eval_test, ep, "test-")

                if ep % 10 == 0:
                    print("Model saved in file: %s"
                          % saver.save(sess, FLAGS.log__dir + '/model.ckpt'))
    return
def main(_):
    print(FLAGS.epsilon, FLAGS.top_bn)
    np.random.seed(seed=FLAGS.seed)
    tf.set_random_seed(np.random.randint(1234))
    with tf.Graph().as_default() as g:
        with tf.device("/cpu:0"):
            images, labels = inputs(batch_size=FLAGS.batch_size,
                                    train=True,
                                    validation=FLAGS.validation,
                                    shuffle=True)
            # The unlabeled batch is fed through a placeholder instead of the
            # input pipeline so that per-example perturbation state (ul_u)
            # can be carried across steps. Disabled original call:
            # unlabeled_inputs(batch_size=FLAGS.ul_batch_size,
            #                  validation=FLAGS.validation, shuffle=True)
            ul_images = tf.placeholder(shape=images.shape, dtype=tf.float32)

            images_eval_train, labels_eval_train = inputs(
                batch_size=FLAGS.eval_batch_size,
                train=True,
                validation=FLAGS.validation,
                shuffle=True)
            ul_images_eval_train = unlabeled_inputs(
                batch_size=FLAGS.eval_batch_size,
                validation=FLAGS.validation,
                shuffle=True)
            images_eval_test, labels_eval_test = inputs(
                batch_size=FLAGS.eval_batch_size,
                train=False,
                validation=FLAGS.validation,
                shuffle=True)

        def placeholder_like(x, name=None):
            return tf.placeholder(shape=x.shape, dtype=tf.float32, name=name)

        def random_sphere(shape):
            # Sample a random direction uniformly on the unit L2 sphere.
            n = tf.random_normal(shape=shape, dtype=tf.float32)
            n = tf.reshape(n, shape=(int(shape[0]), -1))
            n = tf.nn.l2_normalize(n, dim=1)
            n = tf.reshape(n, shape)
            return n

        def random_sphere_numpy(shape):
            # NumPy counterpart of random_sphere.
            n = np.random.normal(size=shape)
            proj_shape = tuple([n.shape[0]] + [1 for _ in range(len(shape) - 1)])
            return n / np.linalg.norm(n.reshape((n.shape[0], -1)),
                                      axis=1).reshape(proj_shape)

        print(ul_images.shape)
        # ul_u = random_sphere(ul_images.shape)
        # ul_u_eval_train = random_sphere(ul_images_eval_train.shape)
        # ul_u_eval_test = random_sphere(images_eval_test.shape)
        ul_u = placeholder_like(ul_images, "ul_u")
        ul_u_eval_train = placeholder_like(ul_images_eval_train, "ul_u_eval_train")
        ul_u_eval_test = placeholder_like(images_eval_test, "ul_u_eval_test")

        with tf.device(FLAGS.device):
            lr = tf.placeholder(tf.float32, shape=[], name="learning_rate")
            mom = tf.placeholder(tf.float32, shape=[], name="momentum")
            with tf.variable_scope("CNN") as scope:
                # Build training graph
                loss, train_op, global_step, ul_u_updated = build_training_graph(
                    images, labels, ul_images, ul_u, lr, mom)
                scope.reuse_variables()
                # Build eval graph
                losses_eval_train = build_eval_graph(images_eval_train,
                                                     labels_eval_train,
                                                     ul_images_eval_train,
                                                     ul_u_eval_train)
                losses_eval_test = build_eval_graph(images_eval_test,
                                                    labels_eval_test,
                                                    images_eval_test,
                                                    ul_u_eval_test)
            init_op = tf.global_variables_initializer()

        if not FLAGS.log_dir:
            logdir = None
            writer_train = None
            writer_test = None
        else:
            logdir = FLAGS.log_dir
            writer_train = tf.summary.FileWriter(FLAGS.log_dir + "/train", g)
            writer_test = tf.summary.FileWriter(FLAGS.log_dir + "/test", g)

        saver = tf.train.Saver(tf.global_variables())
        sv = tf.train.Supervisor(
            is_chief=True,
            logdir=logdir,
            init_op=init_op,
            init_feed_dict={lr: FLAGS.learning_rate, mom: FLAGS.mom1},
            saver=saver,
            global_step=global_step,
            summary_op=None,
            summary_writer=None,
            save_model_secs=150,
            recovery_wait_secs=0)

        ul_images_np = np.load("train_images.npy").reshape((-1, 32, 32, 3))
        print("TRUNCATING UL DATA")
        ul_images_np = ul_images_np[:FLAGS.batch_size]
        ul_u_np = random_sphere_numpy(ul_images_np.shape)
        print(ul_images_np.shape, ul_u_np.shape)

        print("Training...")
        with sv.managed_session() as sess:
            for ep in range(FLAGS.num_epochs):
                if sv.should_stop():
                    break

                if ep < FLAGS.epoch_decay_start:
                    feed_dict = {lr: FLAGS.learning_rate, mom: FLAGS.mom1}
                else:
                    decayed_lr = ((FLAGS.num_epochs - ep)
                                  / float(FLAGS.num_epochs - FLAGS.epoch_decay_start)
                                  ) * FLAGS.learning_rate
                    feed_dict = {lr: decayed_lr, mom: FLAGS.mom2}

                sum_loss = 0
                start = time.time()
                for i in range(FLAGS.num_iter_per_epoch):
                    picked = np.arange(FLAGS.batch_size)
                    # np.random.choice(len(ul_images_np), size=FLAGS.batch_size, replace=False)
                    feed_dict[ul_images] = ul_images_np[picked]
                    feed_dict[ul_u] = ul_u_np[picked]
                    ul_u_updated_np, _, batch_loss, _ = sess.run(
                        [ul_u_updated, train_op, loss, global_step],
                        feed_dict=feed_dict)
                    delta = ul_u_updated_np - ul_u_np[picked]
                    # print("pos", ul_u_updated_np.reshape((FLAGS.batch_size, -1))[0, :4])
                    # print("delta", np.linalg.norm(delta.reshape((FLAGS.batch_size, -1)), axis=1)[:4])
                    print(np.linalg.norm(ul_u_updated_np - ul_u_np[picked]),
                          ul_u_updated_np.reshape((FLAGS.batch_size, -1))[0, :3])
                    # Carry the updated perturbation state to the next step.
                    ul_u_np[picked] = ul_u_updated_np
                    sum_loss += batch_loss
                end = time.time()
                print("Epoch:", ep,
                      "CE_loss_train:", sum_loss / FLAGS.num_iter_per_epoch,
                      "elapsed_time:", end - start)

                if (ep + 1) % FLAGS.eval_freq == 0 or ep + 1 == FLAGS.num_epochs:
                    # Eval on training data
                    act_values_dict = {}
                    feed_dict = {ul_u_eval_train:
                                 random_sphere_numpy(ul_u_eval_train.shape)}
                    for key in losses_eval_train:  # Python 3: .iteritems() removed
                        act_values_dict[key] = 0
                    # Integer division: '/' would yield a float in Python 3.
                    n_iter_per_epoch = NUM_EVAL_EXAMPLES // FLAGS.eval_batch_size
                    for i in range(n_iter_per_epoch):
                        values = list(losses_eval_train.values())
                        act_values = sess.run(values, feed_dict=feed_dict)
                        for key, value in zip(act_values_dict.keys(), act_values):
                            act_values_dict[key] += value
                    summary = tf.Summary()
                    current_global_step = sess.run(global_step)
                    for key, value in act_values_dict.items():
                        print("train-" + key, value / n_iter_per_epoch)
                        summary.value.add(tag=key,
                                          simple_value=value / n_iter_per_epoch)
                    if writer_train is not None:
                        writer_train.add_summary(summary, current_global_step)

                    # Eval on test data
                    act_values_dict = {}
                    # FIXME: how come this does not depend on
                    # ul_images_eval_train? Something may be wrong here.
                    feed_dict = {ul_u_eval_test:
                                 random_sphere_numpy(ul_u_eval_test.shape)}
                    for key in losses_eval_test:
                        act_values_dict[key] = 0
                    n_iter_per_epoch = NUM_EVAL_EXAMPLES // FLAGS.eval_batch_size
                    for i in range(n_iter_per_epoch):
                        values = list(losses_eval_test.values())
                        act_values = sess.run(values, feed_dict=feed_dict)
                        for key, value in zip(act_values_dict.keys(), act_values):
                            act_values_dict[key] += value
                    summary = tf.Summary()
                    current_global_step = sess.run(global_step)
                    for key, value in act_values_dict.items():
                        print("test-" + key, value / n_iter_per_epoch)
                        summary.value.add(tag=key,
                                          simple_value=value / n_iter_per_epoch)
                    if writer_test is not None:
                        writer_test.add_summary(summary, current_global_step)

            saver.save(sess, sv.save_path, global_step=global_step)
        sv.stop()
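# Quick sanity check for random_sphere_numpy (illustrative; run inside
# main(), where the helper is defined): every sample should be a unit-norm
# direction, since ul_u_np seeds the perturbation state with points on the
# unit L2 sphere.
#
#   u = random_sphere_numpy((4, 32, 32, 3))
#   norms = np.linalg.norm(u.reshape((4, -1)), axis=1)
#   assert np.allclose(norms, 1.0)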