def main(_): print("\nParameters:") for attr, value in tf.app.flags.FLAGS.flag_values_dict().items(): print("{}={}".format(attr, value)) print("") os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.gpu) if not os.path.exists(FLAGS.logdir): os.makedirs(FLAGS.logdir) # Random seed rng = np.random.RandomState(FLAGS.seed) # seed labels rng_data = np.random.RandomState(rng.randint(0, 2**10)) # seed shuffling # load CIFAR-10 trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir, 'train') # float [-1 1] images testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test') trainx_unl = trainx.copy() trainx_unl2 = trainx.copy() if FLAGS.validation: split = int(0.1 * trainx.shape[0]) print("validation enabled") testx = trainx[:split] testy = trainy[:split] trainx = trainx[split:] trainy = trainy[split:] nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size) nr_batches_test = int(testx.shape[0] / FLAGS.batch_size) # select labeled data inds = rng_data.permutation(trainx.shape[0]) trainx = trainx[inds] trainy = trainy[inds] txs = [] tys = [] for j in range(10): txs.append(trainx[trainy == j][:FLAGS.labeled]) tys.append(trainy[trainy == j][:FLAGS.labeled]) txs = np.concatenate(txs, axis=0) tys = np.concatenate(tys, axis=0) '''construct graph''' unl = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='unlabeled_data_input_pl') is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl') inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='labeled_data_input_pl') lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input_pl') # scalar pl lr_pl = tf.placeholder(tf.float32, [], name='learning_rate_pl') acc_train_pl = tf.placeholder(tf.float32, [], 'acc_train_pl') acc_test_pl = tf.placeholder(tf.float32, [], 'acc_test_pl') acc_test_pl_ema = tf.placeholder(tf.float32, [], 'acc_test_pl') random_z = tf.random_uniform([FLAGS.batch_size, 100], name='random_z') generator(random_z, is_training_pl, init=True) # init of weightnorm weights gen_inp = generator(random_z, is_training_pl, init=False, reuse=True) discriminator(unl, is_training_pl, init=True) logits_lab, _ = discriminator(inp, is_training_pl, init=False, reuse=True) logits_gen, layer_fake = discriminator(gen_inp, is_training_pl, init=False, reuse=True) logits_unl, layer_real = discriminator(unl, is_training_pl, init=False, reuse=True) with tf.name_scope('loss_functions'): # discriminator l_unl = tf.reduce_logsumexp(logits_unl, axis=1) l_gen = tf.reduce_logsumexp(logits_gen, axis=1) loss_lab = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits(labels=lbl, logits=logits_lab)) loss_unl = - 0.5 * tf.reduce_mean(l_unl) \ + 0.5 * tf.reduce_mean(tf.nn.softplus(l_unl)) \ + 0.5 * tf.reduce_mean(tf.nn.softplus(l_gen)) # generator m1 = tf.reduce_mean(layer_real, axis=0) m2 = tf.reduce_mean(layer_fake, axis=0) loss_dis = FLAGS.unl_weight * loss_unl + FLAGS.lbl_weight * loss_lab loss_gen = tf.reduce_mean(tf.abs(m1 - m2)) correct_pred = tf.equal(tf.cast(tf.argmax(logits_lab, 1), tf.int32), tf.cast(lbl, tf.int32)) accuracy_classifier = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) with tf.name_scope('optimizers'): # control op dependencies for batch norm and trainable variables tvars = tf.trainable_variables() dvars = [var for var in tvars if 'discriminator_model' in var.name] gvars = [var for var in tvars if 'generator_model' in var.name] update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) update_ops_gen = [ x for x in update_ops if ('generator_model' in x.name) ] update_ops_dis = [ x for x in 
update_ops if ('discriminator_model' in x.name) ] optimizer_dis = tf.train.AdamOptimizer(learning_rate=lr_pl, beta1=0.5, name='dis_optimizer') optimizer_gen = tf.train.AdamOptimizer(learning_rate=lr_pl, beta1=0.5, name='gen_optimizer') with tf.control_dependencies(update_ops_gen): train_gen_op = optimizer_gen.minimize(loss_gen, var_list=gvars) dis_op = optimizer_dis.minimize(loss_dis, var_list=dvars) ema = tf.train.ExponentialMovingAverage(decay=FLAGS.ma_decay) maintain_averages_op = ema.apply(dvars) with tf.control_dependencies([dis_op]): train_dis_op = tf.group(maintain_averages_op) logits_ema, _ = discriminator(inp, is_training_pl, getter=get_getter(ema), reuse=True) correct_pred_ema = tf.equal( tf.cast(tf.argmax(logits_ema, 1), tf.int32), tf.cast(lbl, tf.int32)) accuracy_ema = tf.reduce_mean(tf.cast(correct_pred_ema, tf.float32)) with tf.name_scope('summary'): with tf.name_scope('discriminator'): tf.summary.scalar('loss_discriminator', loss_dis, ['dis']) with tf.name_scope('generator'): tf.summary.scalar('loss_generator', loss_gen, ['gen']) with tf.name_scope('images'): tf.summary.image('gen_images', gen_inp, 10, ['image']) with tf.name_scope('epoch'): tf.summary.scalar('accuracy_train', acc_train_pl, ['epoch']) tf.summary.scalar('accuracy_test_moving_average', acc_test_pl_ema, ['epoch']) tf.summary.scalar('accuracy_test', acc_test_pl, ['epoch']) tf.summary.scalar('learning_rate', lr_pl, ['epoch']) sum_op_dis = tf.summary.merge_all('dis') sum_op_gen = tf.summary.merge_all('gen') sum_op_im = tf.summary.merge_all('image') sum_op_epoch = tf.summary.merge_all('epoch') # training global varialble global_epoch = tf.Variable(0, trainable=False, name='global_epoch') global_step = tf.Variable(0, trainable=False, name='global_step') inc_global_step = tf.assign(global_step, global_step + 1) inc_global_epoch = tf.assign(global_epoch, global_epoch + 1) # op initializer for session manager init_gen = [var.initializer for var in gvars][:-3] with tf.control_dependencies(init_gen): op = tf.global_variables_initializer() init_feed_dict = { inp: trainx_unl[:FLAGS.batch_size], unl: trainx_unl[:FLAGS.batch_size], is_training_pl: True } sv = tf.train.Supervisor(logdir=FLAGS.logdir, global_step=global_epoch, summary_op=None, save_model_secs=0, init_op=op, init_feed_dict=init_feed_dict) '''//////training //////''' print('start training') with sv.managed_session() as sess: tf.set_random_seed(rng.randint(2**10)) print('\ninitialization done') print('Starting training from epoch :%d, step:%d \n' % (sess.run(global_epoch), sess.run(global_step))) writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph) while not sv.should_stop(): epoch = sess.run(global_epoch) train_batch = sess.run(global_step) if (epoch >= FLAGS.epoch): print("Training done") sv.stop() break begin = time.time() train_loss_lab = train_loss_unl = train_loss_gen = train_acc = test_acc = test_acc_ma = 0 lr = FLAGS.learning_rate * linear_decay(FLAGS.decay_start, FLAGS.epoch, epoch) # construct randomly permuted batches trainx = [] trainy = [] for t in range( int(np.ceil( trainx_unl.shape[0] / float(txs.shape[0])))): # same size lbl and unlb inds = rng.permutation(txs.shape[0]) trainx.append(txs[inds]) trainy.append(tys[inds]) trainx = np.concatenate(trainx, axis=0) trainy = np.concatenate(trainy, axis=0) trainx_unl = trainx_unl[rng.permutation( trainx_unl.shape[0])] # shuffling unl dataset trainx_unl2 = trainx_unl2[rng.permutation(trainx_unl2.shape[0])] # training for t in range(nr_batches_train): display_progression_epoch(t, nr_batches_train) 
ran_from = t * FLAGS.batch_size ran_to = (t + 1) * FLAGS.batch_size # train discriminator feed_dict = { unl: trainx_unl[ran_from:ran_to], is_training_pl: True, inp: trainx[ran_from:ran_to], lbl: trainy[ran_from:ran_to], lr_pl: lr } _, acc, lu, lb, sm = sess.run([ train_dis_op, accuracy_classifier, loss_lab, loss_unl, sum_op_dis ], feed_dict=feed_dict) train_loss_unl += lu train_loss_lab += lb train_acc += acc if (train_batch % FLAGS.step_print) == 0: writer.add_summary(sm, train_batch) # train generator _, lg, sm = sess.run( [train_gen_op, loss_gen, sum_op_gen], feed_dict={ unl: trainx_unl2[ran_from:ran_to], is_training_pl: True, lr_pl: lr }) train_loss_gen += lg if (train_batch % FLAGS.step_print) == 0: writer.add_summary(sm, train_batch) if (train_batch % FLAGS.freq_print == 0) & (train_batch != 0): ran_from = np.random.randint( 0, trainx_unl.shape[0] - FLAGS.batch_size) ran_to = ran_from + FLAGS.batch_size sm = sess.run(sum_op_im, feed_dict={ is_training_pl: True, unl: trainx_unl[ran_from:ran_to] }) writer.add_summary(sm, train_batch) train_batch += 1 sess.run(inc_global_step) train_loss_lab /= nr_batches_train train_loss_unl /= nr_batches_train train_loss_gen /= nr_batches_train train_acc /= nr_batches_train # Testing moving averaged model and raw model if (epoch % FLAGS.freq_test == 0) | (epoch == FLAGS.epoch - 1): for t in range(nr_batches_test): ran_from = t * FLAGS.batch_size ran_to = (t + 1) * FLAGS.batch_size feed_dict = { inp: testx[ran_from:ran_to], lbl: testy[ran_from:ran_to], is_training_pl: False } acc, acc_ema = sess.run( [accuracy_classifier, accuracy_ema], feed_dict=feed_dict) test_acc += acc test_acc_ma += acc_ema test_acc /= nr_batches_test test_acc_ma /= nr_batches_test sum = sess.run(sum_op_epoch, feed_dict={ acc_train_pl: train_acc, acc_test_pl: test_acc, acc_test_pl_ema: test_acc_ma, lr_pl: lr }) writer.add_summary(sum, epoch) print( "Epoch %d | time = %ds | loss gen = %.4f | loss lab = %.4f | loss unl = %.4f " "| train acc = %.4f| test acc = %.4f | test acc ema = %0.4f" % (epoch, time.time() - begin, train_loss_gen, train_loss_lab, train_loss_unl, train_acc, test_acc, test_acc_ma)) sess.run(inc_global_epoch) # save snapshots of model if ((epoch % FLAGS.freq_save == 0) & (epoch != 0)) | (epoch == FLAGS.epoch - 1): string = 'model-' + str(epoch) save_path = os.path.join(FLAGS.logdir, string) sv.saver.save(sess, save_path) print("Model saved in file: %s" % (save_path))
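# ---------------------------------------------------------------------------
# Helpers referenced by the script above but not shown in this listing.
# linear_decay mirrors the definition that appears verbatim in the supervised
# scripts further down; get_getter and display_progression_epoch are assumed
# implementations of the usual EMA-getter and progress-print patterns and may
# differ from the repo's own versions. A minimal sketch:
# ---------------------------------------------------------------------------
import sys


def get_getter(ema):
    """Custom variable getter that swaps each variable for its EMA shadow."""
    def ema_getter(getter, name, *args, **kwargs):
        var = getter(name, *args, **kwargs)
        ema_var = ema.average(var)
        return ema_var if ema_var is not None else var
    return ema_getter


def linear_decay(decay_start, decay_end, epoch):
    """Multiplier of 1 until decay_start, then a linear ramp to 0 at decay_end."""
    return min(-1 / (decay_end - decay_start) * epoch + 1
               + decay_start / (decay_end - decay_start), 1)


def display_progression_epoch(t, total):
    """Print a simple in-place progress counter for the current epoch."""
    sys.stdout.write('%d / %d \r' % (t, total))
    sys.stdout.flush()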
def main(_): if not os.path.exists(FLAGS.logdir): os.makedirs(FLAGS.logdir) # Random seed rng = np.random.RandomState(FLAGS.seed) # seed labels rng_data = np.random.RandomState(FLAGS.seed_data) # seed shuffling # load CIFAR-10 trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir, 'train') # float [-1 1] images testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test') trainx_unl = trainx.copy() trainx_unl2 = trainx.copy() nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size) nr_batches_test = int(testx.shape[0] / FLAGS.batch_size) # select labeled data inds = rng_data.permutation(trainx.shape[0]) trainx = trainx[inds] trainy = trainy[inds] txs = [] tys = [] for j in range(10): txs.append(trainx[trainy == j][:FLAGS.labeled]) tys.append(trainy[trainy == j][:FLAGS.labeled]) txs = np.concatenate(txs, axis=0) tys = np.concatenate(tys, axis=0) print("Data:") print('train examples %d, batch %d, test examples %d, batch %d' \ % (trainx.shape[0], nr_batches_train, testx.shape[0], nr_batches_test)) print('histogram train', np.histogram(trainy, bins=10)[0]) print('histogram test ', np.histogram(testy, bins=10)[0]) print("histogram labeled", np.histogram(tys, bins=10)[0]) print("") '''construct graph''' print('constructing graph') unl = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='unlabeled_data_input_pl') is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl') inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='labeled_data_input_pl') lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input_pl') # scalar pl lr_pl = tf.placeholder(tf.float32, [], name='learning_rate_pl') acc_train_pl = tf.placeholder(tf.float32, [], 'acc_train_pl') acc_test_pl = tf.placeholder(tf.float32, [], 'acc_test_pl') acc_test_pl_ema = tf.placeholder(tf.float32, [], 'acc_test_pl') kl_weight = tf.placeholder(tf.float32, [], 'kl_weight') random_z = tf.random_uniform([FLAGS.batch_size, 100], name='random_z') perturb = tf.random_normal([FLAGS.batch_size, 100], mean=0, stddev=0.01) random_z_pert = random_z + FLAGS.scale * perturb / ( tf.expand_dims(tf.norm(perturb, axis=1), axis=1) * tf.ones([1, 100])) generator(random_z, is_training_pl, init=True) # init of weightnorm weights gen_inp = generator(random_z, is_training_pl, init=False, reuse=True) gen_inp_pert = generator(random_z_pert, is_training_pl, init=False, reuse=True) discriminator(unl, is_training_pl, init=True) logits_lab, _ = discriminator(inp, is_training_pl, init=False, reuse=True) logits_gen, layer_fake = discriminator(gen_inp, is_training_pl, init=False, reuse=True) logits_unl, layer_real = discriminator(unl, is_training_pl, init=False, reuse=True) logits_gen_perturb, layer_fake_perturb = discriminator(gen_inp_pert, is_training_pl, init=False, reuse=True) with tf.name_scope('loss_functions'): l_unl = tf.reduce_logsumexp(logits_unl, axis=1) l_gen = tf.reduce_logsumexp(logits_gen, axis=1) # discriminator loss_lab = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits(labels=lbl, logits=logits_lab)) loss_unl = - 0.5 * tf.reduce_mean(l_unl) \ + 0.5 * tf.reduce_mean(tf.nn.softplus(l_unl)) \ + 0.5 * tf.reduce_mean(tf.nn.softplus(l_gen)) # generator m1 = tf.reduce_mean(layer_real, axis=0) m2 = tf.reduce_mean(layer_fake, axis=0) j_loss = tf.reduce_mean( tf.reduce_sum(tf.square(logits_gen - logits_gen_perturb), axis=1)) if FLAGS.nabla: loss_dis = FLAGS.unl_weight * loss_unl + FLAGS.lbl_weight * loss_lab + kl_weight * j_loss loss_gen = tf.reduce_mean(tf.abs(m1 - m2)) print('manifold reg enabled') 
else: loss_dis = FLAGS.unl_weight * loss_unl + FLAGS.lbl_weight * loss_lab loss_gen = tf.reduce_mean(tf.abs(m1 - m2)) correct_pred = tf.equal(tf.cast(tf.argmax(logits_lab, 1), tf.int32), tf.cast(lbl, tf.int32)) accuracy_classifier = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) with tf.name_scope('optimizers'): # control op dependencies for batch norm and trainable variables tvars = tf.trainable_variables() dvars = [var for var in tvars if 'discriminator_model' in var.name] gvars = [var for var in tvars if 'generator_model' in var.name] update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) update_ops_gen = [ x for x in update_ops if ('generator_model' in x.name) ] update_ops_dis = [ x for x in update_ops if ('discriminator_model' in x.name) ] optimizer_dis = tf.train.AdamOptimizer(learning_rate=lr_pl, beta1=0.5, name='dis_optimizer') optimizer_gen = tf.train.AdamOptimizer(learning_rate=lr_pl, beta1=0.5, name='gen_optimizer') with tf.control_dependencies(update_ops_gen): train_gen_op = optimizer_gen.minimize(loss_gen, var_list=gvars) dis_op = optimizer_dis.minimize(loss_dis, var_list=dvars) ema = tf.train.ExponentialMovingAverage(decay=FLAGS.ma_decay) maintain_averages_op = ema.apply(dvars) with tf.control_dependencies([dis_op]): train_dis_op = tf.group(maintain_averages_op) logits_ema, _ = discriminator(inp, is_training_pl, getter=get_getter(ema), reuse=True) correct_pred_ema = tf.equal( tf.cast(tf.argmax(logits_ema, 1), tf.int32), tf.cast(lbl, tf.int32)) accuracy_ema = tf.reduce_mean(tf.cast(correct_pred_ema, tf.float32)) with tf.name_scope('summary'): with tf.name_scope('discriminator'): tf.summary.scalar('loss_discriminator', loss_dis, ['dis']) tf.summary.scalar('kl_loss', j_loss, ['dis']) with tf.name_scope('generator'): tf.summary.scalar('loss_generator', loss_gen, ['gen']) with tf.name_scope('images'): tf.summary.image('gen_images', gen_inp, 10, ['image']) with tf.name_scope('epoch'): tf.summary.scalar('accuracy_train', acc_train_pl, ['epoch']) tf.summary.scalar('accuracy_test_moving_average', acc_test_pl_ema, ['epoch']) tf.summary.scalar('accuracy_test_raw', acc_test_pl, ['epoch']) tf.summary.scalar('learning_rate', lr_pl, ['epoch']) tf.summary.scalar('j_weight', kl_weight, ['epoch']) sum_op_dis = tf.summary.merge_all('dis') sum_op_gen = tf.summary.merge_all('gen') sum_op_im = tf.summary.merge_all('image') sum_op_epoch = tf.summary.merge_all('epoch') # training global varialble global_epoch = tf.Variable(0, trainable=False, name='global_epoch') global_step = tf.Variable(0, trainable=False, name='global_step') inc_global_step = tf.assign(global_step, global_step + 1) inc_global_epoch = tf.assign(global_epoch, global_epoch + 1) # op initializer for session manager init_gen = [var.initializer for var in gvars][:-3] with tf.control_dependencies(init_gen): op = tf.global_variables_initializer() init_feed_dict = { inp: trainx_unl[:FLAGS.batch_size], unl: trainx_unl[:FLAGS.batch_size], is_training_pl: True, kl_weight: 0 } sv = tf.train.Supervisor(logdir=FLAGS.logdir, global_step=global_epoch, summary_op=None, save_model_secs=0, init_op=op, init_feed_dict=init_feed_dict) inception_scores = [] '''//////training //////''' print('start training') with sv.managed_session() as sess: tf.set_random_seed(rng.randint(2**10)) print('\ninitialization done') print('Starting training from epoch :%d, step:%d \n' % (sess.run(global_epoch), sess.run(global_step))) writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph) while not sv.should_stop(): epoch = sess.run(global_epoch) train_batch = 
sess.run(global_step) if (epoch >= FLAGS.epoch): print("Training done") sv.stop() break begin = time.time() train_loss_lab = train_loss_unl = train_loss_gen = train_acc = test_acc = test_acc_ma = train_j_loss = 0 lr = FLAGS.learning_rate * linear_decay(FLAGS.decay_start, FLAGS.epoch, epoch) klw = FLAGS.nabla_w # construct randomly permuted batches trainx = [] trainy = [] for t in range( int(np.ceil( trainx_unl.shape[0] / float(txs.shape[0])))): # same size lbl and unlb inds = rng.permutation(txs.shape[0]) trainx.append(txs[inds]) trainy.append(tys[inds]) trainx = np.concatenate(trainx, axis=0) trainy = np.concatenate(trainy, axis=0) trainx_unl = trainx_unl[rng.permutation( trainx_unl.shape[0])] # shuffling unl dataset trainx_unl2 = trainx_unl2[rng.permutation(trainx_unl2.shape[0])] # training for t in range(nr_batches_train): display_progression_epoch(t, nr_batches_train) ran_from = t * FLAGS.batch_size ran_to = (t + 1) * FLAGS.batch_size # train discriminator feed_dict = { unl: trainx_unl[ran_from:ran_to], is_training_pl: True, inp: trainx[ran_from:ran_to], lbl: trainy[ran_from:ran_to], lr_pl: lr, kl_weight: klw } _, acc, lu, lb, jl, sm = sess.run([ train_dis_op, accuracy_classifier, loss_lab, loss_unl, j_loss, sum_op_dis ], feed_dict=feed_dict) train_loss_unl += lu train_loss_lab += lb train_acc += acc train_j_loss += jl if (train_batch % FLAGS.step_print) == 0: writer.add_summary(sm, train_batch) # train generator _, lg, sm = sess.run( [train_gen_op, loss_gen, sum_op_gen], feed_dict={ unl: trainx_unl2[ran_from:ran_to], is_training_pl: True, lr_pl: lr, kl_weight: klw }) train_loss_gen += lg if (train_batch % FLAGS.step_print) == 0: writer.add_summary(sm, train_batch) if (train_batch % FLAGS.freq_print == 0) & (train_batch != 0): ran_from = np.random.randint( 0, trainx_unl.shape[0] - FLAGS.batch_size) ran_to = ran_from + FLAGS.batch_size sm = sess.run(sum_op_im, feed_dict={ is_training_pl: True, unl: trainx_unl[ran_from:ran_to] }) writer.add_summary(sm, train_batch) train_batch += 1 sess.run(inc_global_step) train_loss_lab /= nr_batches_train train_loss_unl /= nr_batches_train train_loss_gen /= nr_batches_train train_acc /= nr_batches_train train_j_loss /= nr_batches_train # Testing moving averaged model and raw model if (epoch % FLAGS.freq_test == 0) | (epoch == FLAGS.epoch - 1): for t in range(nr_batches_test): ran_from = t * FLAGS.batch_size ran_to = (t + 1) * FLAGS.batch_size feed_dict = { inp: testx[ran_from:ran_to], lbl: testy[ran_from:ran_to], is_training_pl: False } acc, acc_ema = sess.run( [accuracy_classifier, accuracy_ema], feed_dict=feed_dict) test_acc += acc test_acc_ma += acc_ema test_acc /= nr_batches_test test_acc_ma /= nr_batches_test sum = sess.run(sum_op_epoch, feed_dict={ acc_train_pl: train_acc, acc_test_pl: test_acc, acc_test_pl_ema: test_acc_ma, lr_pl: lr, kl_weight: klw }) writer.add_summary(sum, epoch) print( "Epoch %d | time = %ds | loss gen = %.4f | loss lab = %.4f | loss unl = %.4f " "| train acc = %.4f| test acc = %.4f | test acc ema = %0.4f" % (epoch, time.time() - begin, train_loss_gen, train_loss_lab, train_loss_unl, train_acc, test_acc, test_acc_ma)) sess.run(inc_global_epoch) # save snap shot of model if ((epoch % FLAGS.freq_save == 0) & (epoch != 0)) | (epoch == FLAGS.epoch - 1): string = 'model-' + str(epoch) save_path = os.path.join(FLAGS.logdir, string) sv.saver.save(sess, save_path) print("Model saved in file: %s" % (save_path)) print("saving images...") sample_images = sess.run(gen_inp, feed_dict={is_training_pl: False}) save_images(sample_images, 
os.path.join(FLAGS.logdir, '{:06d}.png'.format(epoch))) print('images saved @ ' + os.path.join(FLAGS.logdir, '{:06d}.png'.format(epoch))) num_images_to_eval = 50000 eval_images = [] num_batches = num_images_to_eval // FLAGS.batch_size + 1 print("Calculating Inception Score. Sampling {} images...".format( num_images_to_eval)) np.random.seed(0) for _ in range(num_batches): images = sess.run(gen_inp, feed_dict={is_training_pl: False}) eval_images.append(images) np.random.seed() eval_images = np.vstack(eval_images) eval_images = eval_images[:num_images_to_eval] eval_images = np.clip((eval_images + 1.0) * 127.5, 0.0, 255.0).astype(np.uint8) # Calc Inception score eval_images = list(eval_images) inception_score_mean, inception_score_std = get_inception_score( eval_images) print("Inception Score: Mean = {} \tStd = {}.".format( inception_score_mean, inception_score_std)) inception_scores.append( dict(mean=inception_score_mean, std=inception_score_std)) with open(INCEPTION_FILENAME, 'wb') as f: pickle.dump(inception_scores, f)
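# ---------------------------------------------------------------------------
# save_images is called above but not defined in this listing. Below is a
# minimal sketch of a grid writer, assuming Pillow is available and that the
# generator outputs float images in [-1, 1]; the repo's own helper may lay
# the grid out differently.
# ---------------------------------------------------------------------------
import math

import numpy as np
from PIL import Image


def save_images(images, path):
    """Tile a batch of [-1, 1] float images into one grid and save it as PNG."""
    images = np.clip((images + 1.0) * 127.5, 0, 255).astype(np.uint8)
    n, h, w, c = images.shape
    cols = int(math.ceil(math.sqrt(n)))
    rows = int(math.ceil(n / float(cols)))
    grid = np.zeros((rows * h, cols * w, c), dtype=np.uint8)
    for idx, img in enumerate(images):
        r, col = divmod(idx, cols)
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = img
    Image.fromarray(grid.squeeze()).save(path)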
def main(_): if not os.path.exists(FLAGS.log_dir): os.mkdir(FLAGS.log_dir) # Random seed rng = np.random.RandomState(FLAGS.seed) # load CIFAR-10 trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir, 'train') # float [0 1] images testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test') # overfitting test # trainx = trainx[:10000] # trainy = trainy[:10000] nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size) nr_batches_test = int(testx.shape[0] / FLAGS.batch_size) # whitten data print('Starting preprocessing') begin = time.time() m = np.mean(trainx, axis=0) std = np.mean(trainx, axis=0) trainx -= m # trainx /= std testx -= m # testx /= std trainx, testx = zca_whiten(trainx, testx, epsilon=1e-8) print('Preprocessing done in : %ds' % (time.time() - begin)) '''construct graph''' inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='data_input') lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input') is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl') accuracy_epoch = tf.placeholder(tf.float32, [], name='epoch_pl') adam_learning_rate_pl = tf.placeholder(tf.float32, [], name='adam_learning_rate_pl') adam_momentum_pl = tf.placeholder(tf.float32, [], name='adam_momentum_pl') with tf.variable_scope('cnn_model'): logits = cifar_model.inference(inp, is_training_pl) with tf.name_scope('loss_function'): loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=lbl) correct_prediction = tf.equal(tf.cast(tf.argmax(logits, 1), tf.int32), lbl) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) eval_correct = tf.reduce_sum(tf.cast(correct_prediction, tf.float32)) optimizer = tf.train.AdamOptimizer(learning_rate=adam_learning_rate_pl, beta1=adam_momentum_pl) update_ops = tf.get_collection( tf.GraphKeys.UPDATE_OPS) # control dependencies for batch norm ops with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss) # Summaries with tf.name_scope('per_batch_summary'): tf.summary.scalar('loss', loss, ['batch']) tf.summary.scalar('accuracy', accuracy, ['batch']) tf.summary.scalar('adam learning rate', adam_learning_rate_pl, ['batch']) tf.summary.scalar('adam momentum', adam_momentum_pl, ['batch']) with tf.name_scope('per_epoch_summary'): tf.summary.scalar('accuracy epoch', accuracy_epoch, ['per_epoch']) tf.summary.merge( tf.contrib.layers.summarize_collection( tf.GraphKeys.TRAINABLE_VARIABLES), ['per_epoch']) with tf.name_scope('input_data'): tf.summary.image('input image', inp, 10, ['per_epoch']) tf.summary.histogram('first input image', tf.reshape(inp[0], [-1]), ['per_epoch']) tf.summary.histogram('input labels', lbl, ['per_epoch']) tf.summary.histogram('output logits', tf.argmax(logits, axis=0), ['per_epoch']) sum_op = tf.summary.merge_all('batch') sum_epoch_op = tf.summary.merge_all('per_epoch') '''//////perform training //////''' with tf.Session() as sess: init = tf.global_variables_initializer() sess.run(init, {adam_momentum_pl: 0.9}) train_batch = 0 train_writer = tf.summary.FileWriter( os.path.join(FLAGS.log_dir, 'train'), sess.graph) test_writer = tf.summary.FileWriter( os.path.join(FLAGS.log_dir, 'test'), sess.graph) for epoch in tqdm(range(200)): begin = time.time() # randomly permuted minibatches inds = rng.permutation(trainx.shape[0]) trainx = trainx[inds] trainy = trainy[inds] train_loss, train_tp, test_tp = [0, 0, 0] for t in range(nr_batches_train): ran_from = t * FLAGS.batch_size ran_to = (t + 1) * FLAGS.batch_size feed_dict = { inp: trainx[ran_from:ran_to], lbl: trainy[ran_from:ran_to], 
is_training_pl: True, adam_learning_rate_pl: decayed_lr(epoch), adam_momentum_pl: momentum(epoch) } _, ls, tp, sm = sess.run( [train_op, loss, eval_correct, sum_op], feed_dict=feed_dict) train_loss += ls train_tp += tp train_batch += 1 train_writer.add_summary(sm, train_batch) train_loss /= nr_batches_train train_tp /= trainx.shape[0] for t in range(nr_batches_test): ran_from = t * FLAGS.batch_size ran_to = (t + 1) * FLAGS.batch_size feed_dict = { inp: testx[ran_from:ran_to], lbl: testy[ran_from:ran_to], is_training_pl: False } test_tp += sess.run(eval_correct, feed_dict=feed_dict) test_tp /= testx.shape[0] '''/////epoch summary/////''' sm = sess.run( sum_epoch_op, { accuracy_epoch: train_tp, inp: trainx[:FLAGS.batch_size], lbl: trainy[:FLAGS.batch_size], is_training_pl: False }) train_writer.add_summary(sm, epoch) x = np.random.randint( 0, testx.shape[0] - FLAGS.batch_size) # random batch extracted in testx sm = sess.run( sum_epoch_op, { accuracy_epoch: test_tp, inp: testx[x:x + FLAGS.batch_size], lbl: testy[x:x + FLAGS.batch_size], is_training_pl: False }) test_writer.add_summary(sm, epoch) # print("Epoch %d--Batch %d--Time = %ds | loss train = %.4f | train acc = %.4f | test acc = %.4f" % # (epoch, train_batch, time.time() - begin, train_loss, train_tp, test_tp)) tqdm.write( "Epoch %d--Batch %d--Time = %ds | loss train = %.4f | train acc = %.4f | test acc = %.4f" % (epoch, train_batch, time.time() - begin, train_loss, train_tp, test_tp))
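# ---------------------------------------------------------------------------
# zca_whiten is used in the preprocessing step above but not shown here. The
# sketch below is a standard ZCA transform fitted on the (already
# mean-centred) training images and applied to both splits; the repo's
# implementation may differ in details such as where epsilon is added.
# ---------------------------------------------------------------------------
import numpy as np


def zca_whiten(trainx, testx, epsilon=1e-8):
    """Fit a ZCA whitening matrix on trainx and apply it to trainx and testx."""
    train_shape, test_shape = trainx.shape, testx.shape
    flat_train = trainx.reshape(train_shape[0], -1)
    flat_test = testx.reshape(test_shape[0], -1)
    cov = np.dot(flat_train.T, flat_train) / flat_train.shape[0]
    U, S, _ = np.linalg.svd(cov)
    zca_matrix = np.dot(U, np.dot(np.diag(1.0 / np.sqrt(S + epsilon)), U.T))
    return (np.dot(flat_train, zca_matrix).reshape(train_shape),
            np.dot(flat_test, zca_matrix).reshape(test_shape))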
def main(_): print("\nParameters:") for attr, value in FLAGS.__flags.items(): print("{}={}".format(attr.lower(), value)) print("") if not os.path.exists(FLAGS.logdir): os.makedirs(FLAGS.logdir) rng = np.random.RandomState(FLAGS.seed) # seed labels trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir, 'train') # float [-1 1] images testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test') # select labeled data inds = rng.permutation(trainx.shape[0]) trainx = trainx[inds] trainy = trainy[inds] print("first labels trainy: ", trainy[:10]) txs = [] tys = [] for j in range(10): txs.append(trainx[trainy == j][:FLAGS.labeled]) tys.append(trainy[trainy == j][:FLAGS.labeled]) txs = np.concatenate(txs, axis=0) tys = np.concatenate(tys, axis=0) trainx = txs trainy = tys nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size) nr_batches_test = int(testx.shape[0] / FLAGS.batch_size) print("trainx shape:", trainx.shape) # placeholder model inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='data_input') lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input') is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl') gan_is_training_pl = tf.placeholder(tf.bool, [], name='gan_is_training_pl') learning_rate_pl = tf.placeholder(tf.float16, [], name='adam_learning_rate_pl') acc_train_pl = tf.placeholder(tf.float32, [], 'acc_train_pl') acc_test_pl = tf.placeholder(tf.float32, [], 'acc_test_pl') acc_test_pl_ema = tf.placeholder(tf.float32, [], 'acc_test_pl') generator = DCGANGenerator(batch_size=FLAGS.mc_size) latent_dim = generator.generate_noise().shape[1] z = tf.placeholder(tf.float32, shape=[FLAGS.mc_size, latent_dim]) if not FLAGS.tiny_cnn: from dnn import classifier as classifier print("standard cnn loaded") else: from dnn import tiny_classifier as classifier print("tiny cnn loaded") x_hat = generator(z, is_training=gan_is_training_pl) logits = classifier(inp, is_training=is_training_pl) logits_gen = classifier(x_hat, is_training=is_training_pl, reuse=True) def get_jacobian(y, x): with tf.name_scope("jacob"): grads = tf.stack( [tf.gradients(yi, x)[0] for yi in tf.unstack(y, axis=1)], axis=2) return grads if FLAGS.grad == 'stochastic': print('stochastic reg enabled ...') perturb = tf.random_normal([FLAGS.mc_size, latent_dim], mean=0, stddev=0.01) z_pert = z + FLAGS.scale * perturb / (tf.expand_dims( tf.norm(perturb, axis=1), axis=1) * tf.ones([1, latent_dim])) x_pert = generator(z_pert, is_training=gan_is_training_pl, reuse=True) logits_gen_perturb = classifier(x_pert, is_training=is_training_pl, reuse=True) j_loss = tf.reduce_mean( tf.reduce_sum(tf.square(logits_gen - logits_gen_perturb), axis=1)) tf.reduce_mean( tf.reduce_sum(tf.square(get_jacobian(logits_gen, z)), axis=[1, 2])) elif FLAGS.grad == 'stochastic_v2': print('stochastic v2 reg enabled ...') perturb = tf.nn.l2_normalize(tf.random_normal( [FLAGS.mc_size, latent_dim], mean=0, stddev=0.01), dim=[1]) x_pert = generator(z + FLAGS.scale * perturb, is_training=gan_is_training_pl, reuse=True) logits_gen_perturb = classifier(x_pert, is_training=is_training_pl, reuse=True) j_loss = tf.reduce_mean( tf.reduce_sum(tf.square(logits_gen - logits_gen_perturb), axis=1)) elif FLAGS.grad == 'isotropic_mc': print('isotropic mc reg enabled ...') perturb = tf.nn.l2_normalize( tf.random_normal([FLAGS.mc_size] + inp.get_shape().as_list()[-3:], mean=0, stddev=0.01), dim=[1, 2, 3]) # gaussian noise [mc_size, 32,32,3] x_pert = x_hat + FLAGS.scale * perturb logits_gen_pert = classifier(x_pert, 
is_training=is_training_pl, reuse=True) j_loss = tf.reduce_mean( tf.reduce_sum(tf.square(logits_gen - logits_gen_pert), axis=1)) elif FLAGS.grad == 'isotropic_inp': print('isotropic inp reg enabled ...') perturb = tf.nn.l2_normalize( tf.random_normal([FLAGS.mc_size] + inp.get_shape().as_list()[-3:], mean=0, stddev=0.01), dim=[1, 2, 3]) # gaussian noise [mc_size, 32,32,3] x_pert = inp + FLAGS.scale * perturb logits_inp_pert = classifier(x_pert, is_training=is_training_pl, reuse=True) j_loss = tf.reduce_mean( tf.reduce_sum(tf.square(logits_gen - logits_inp_pert), axis=1)) elif FLAGS.grad == 'isotropic_rnd': print('isotropic rnd reg enabled ...') epsilon = tf.random_normal( [FLAGS.mc_size] + inp.get_shape().as_list()[-3:], mean=0, stddev=0.01) # gaussian noise [mc_size, 32,32,3] epsilon_hat = tf.nn.l2_normalize( epsilon, dim=[1, 2, 3]) # normalised gaussian noise [mc_size, 32,32,3] rnd_img = tf.random_uniform(shape=[FLAGS.mc_size] + inp.get_shape().as_list()[-3:], minval=-1, maxval=1) x_pert = rnd_img + FLAGS.scale * epsilon_hat logits_pert = classifier(x_pert, is_training=is_training_pl, reuse=True) j_loss = tf.reduce_mean( tf.reduce_sum(tf.square(logits_gen - logits_pert), axis=1)) elif FLAGS.grad == 'grad_latent': print('grad latent enabled ...') grad = get_jacobian(logits_gen, z) j_loss = tf.reduce_mean(tf.reduce_sum(tf.square(grad), axis=[1, 2])) elif FLAGS.grad == 'grad_mc': print('grad mc enabled ...') grad = get_jacobian(logits_gen, x_hat) j_loss = tf.reduce_mean(tf.reduce_sum(tf.square(grad), axis=[1, 2])) elif FLAGS.grad == 'grad_inp': print('grad inp enabled ...') grad = get_jacobian(logits, inp) j_loss = tf.reduce_mean(tf.reduce_sum(tf.square(grad), axis=[1, 2])) elif FLAGS.grad == 'grad_old': print('old grad enabled ...') k = [] for j in range(10): grad = tf.gradients(logits_gen[:, j], z) k.append(grad) J = tf.stack(k) J = tf.squeeze(J) J = tf.transpose(J, perm=[1, 0, 2]) # jacobian j_n = tf.reduce_sum(tf.square(J), axis=[1, 2]) j_loss = tf.reduce_mean(j_n) elif FLAGS.grad == 'comb': jac_manifold = [] jac_ambient = [] for yi in tf.unstack(logits_gen, axis=1): g1, g2 = tf.gradients(yi, [z, x_hat]) jac_ambient.append(g2) jac_manifold.append(g1) jm = tf.square(tf.stack(jac_manifold)) ja = tf.square(tf.stack(jac_ambient)) j_manifold = tf.reduce_mean(tf.reduce_sum(jm, axis=[0, 2])) j_ambient = tf.reduce_mean(tf.reduce_sum(ja, axis=[0, 2, 3, 4])) j_loss = tf.constant(0.) 
######## loss function ####### xentropy = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=lbl) if not FLAGS.reg: print('reg disabeled') loss = xentropy else: print('laplacian reg enabled') loss = xentropy + FLAGS.reg_ambient * j_ambient + FLAGS.reg_manifold * j_manifold correct_prediction = tf.equal(tf.cast(tf.argmax(logits, 1), tf.int32), lbl) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) # g_vars = tf.global_variables(scope='generator') g_vars = [var for var in tf.global_variables() if 'generator' in var.name] dnn_vars = [var for var in tf.trainable_variables() if var not in g_vars] # [print(var.name) for var in dnn_vars] optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_pl) update_ops = tf.get_collection( tf.GraphKeys.UPDATE_OPS) # control dependencies for batch norm ops # with tf.control_dependencies(update_ops): # train_op = optimizer.minimize(loss, var_list=dnn_vars) dvars = [ var for var in tf.trainable_variables() if 'classifier' in var.name ] # [print(var) for var in dvars] with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss, var_list=dvars) ### ema ### ema = tf.train.ExponentialMovingAverage(decay=FLAGS.ma_decay) maintain_averages_op = ema.apply(dvars) with tf.control_dependencies([train_op]): train_dis_op = tf.group(maintain_averages_op) logits_ema = classifier(inp, is_training_pl, getter=get_getter(ema), reuse=True) correct_pred_ema = tf.equal(tf.cast(tf.argmax(logits_ema, 1), tf.int32), tf.cast(lbl, tf.int32)) accuracy_ema = tf.reduce_mean(tf.cast(correct_pred_ema, tf.float32)) def linear_decay(decay_start, decay_end, epoch): return min( -1 / (decay_end - decay_start) * epoch + 1 + decay_start / (decay_end - decay_start), 1) # all_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) # print("list des vars") # [print(var.name) for var in all_var] # print("name var") # [print(var.name for var in g_vars)] with tf.name_scope('summary'): with tf.name_scope('discriminator'): tf.summary.scalar('xentropy', xentropy, ['dis']) tf.summary.scalar('laplacian_loss', j_loss, ['dis']) with tf.name_scope('images'): tf.summary.image('gen_images', x_hat, 4, ['image']) # tf.summary.image('gen_pert', x_pert, 4, ['image']) with tf.name_scope('epoch'): tf.summary.scalar('accuracy_train', acc_train_pl, ['epoch']) tf.summary.scalar('accuracy_test_moving_average', acc_test_pl_ema, ['epoch']) tf.summary.scalar('accuracy_test_raw', acc_test_pl, ['epoch']) tf.summary.scalar('learning_rate', learning_rate_pl, ['epoch']) sum_op_dis = tf.summary.merge_all('dis') sum_op_im = tf.summary.merge_all('image') sum_op_epoch = tf.summary.merge_all('epoch') print("batch size monte carlo: ", generator.generate_noise().shape) print("") saver = tf.train.Saver(var_list=g_vars) var_init = [var for var in tf.global_variables() if var not in g_vars] init_op = tf.variables_initializer(var_list=var_init) # config = tf.ConfigProto(device_count={'GPU': 0}) # config.gpu_options.allow_growth = True with tf.Session() as sess: writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph) sess.run(init_op) if tf.train.latest_checkpoint(FLAGS.snapshot) is not None: saver.restore(sess, tf.train.latest_checkpoint(FLAGS.snapshot)) print("model restored @ %s" % FLAGS.snapshot) train_batch = 0 for epoch in tqdm(range(FLAGS.epoch), disable=not FLAGS.verbose): begin = time.time() # randomly permuted minibatches inds = rng.permutation(trainx.shape[0]) trainx = trainx[inds] trainy = trainy[inds] train_loss = train_acc = test_acc = train_j = test_acc_ema = 0 lr = 
FLAGS.learning_rate * linear_decay(FLAGS.decay, FLAGS.epoch, epoch) for t in tqdm(range(nr_batches_train), disable=not FLAGS.verbose): ran_from = t * FLAGS.batch_size ran_to = (t + 1) * FLAGS.batch_size feed_dict = { inp: trainx[ran_from:ran_to], lbl: trainy[ran_from:ran_to], is_training_pl: True, gan_is_training_pl: False, learning_rate_pl: lr, z: generator.generate_noise() } _, ls, acc, j, sm = sess.run( [train_dis_op, loss, accuracy, j_loss, sum_op_dis], feed_dict=feed_dict) train_loss += ls train_acc += acc train_j += j writer.add_summary(sm, train_batch) train_batch += 1 train_loss /= nr_batches_train train_acc /= nr_batches_train train_j /= nr_batches_train if (train_batch % FLAGS.freq_print == 0) & (train_batch != 0): sm = sess.run(sum_op_im, feed_dict={ gan_is_training_pl: False, z: generator.generate_noise(), inp: trainx[:FLAGS.batch_size] }) writer.add_summary(sm, train_batch) if (epoch % FLAGS.freq_test == 0): for t in range(nr_batches_test): ran_from = t * FLAGS.batch_size ran_to = (t + 1) * FLAGS.batch_size feed_dict = { inp: testx[ran_from:ran_to], lbl: testy[ran_from:ran_to], is_training_pl: False } acc, acc_ema = sess.run([accuracy, accuracy_ema], feed_dict=feed_dict) test_acc += acc test_acc_ema += acc_ema test_acc /= nr_batches_test test_acc_ema /= nr_batches_test sum = sess.run(sum_op_epoch, feed_dict={ acc_train_pl: train_acc, acc_test_pl: test_acc, acc_test_pl_ema: test_acc_ema, learning_rate_pl: lr }) writer.add_summary(sum, epoch) tqdm.write( "Epoch %03d | Time = %03ds | lr = %.3e | loss train = %.4f | train acc = %.2f | test acc = %.2f | test acc_ema = %.2f" % (epoch, time.time() - begin, lr, train_loss, train_acc * 100, test_acc * 100, test_acc_ema * 100)) if status_reporter: # report status for ray tune status_reporter(timesteps_total=epoch, mean_accuracy=test_acc_ema)
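# ---------------------------------------------------------------------------
# The FLAGS.grad == 'stochastic' branch above approximates the latent Jacobian
# penalty of the 'grad_latent' branch by a finite difference along one random
# unit direction in z. The numpy check below (illustrative names only, not
# part of the repo) verifies the identity it relies on: for a linear map,
# ||f(z + s*u) - f(z)||^2 / s^2 equals ||J u||^2, and the expectation of
# ||J u||^2 over random unit directions u is proportional to the full
# Frobenius-norm penalty that 'grad_latent' computes exactly.
# ---------------------------------------------------------------------------
import numpy as np

rng_check = np.random.RandomState(0)
J = rng_check.randn(10, 100)              # a fixed linear "classifier": f(z) = J z


def f(z_):
    return J.dot(z_)


z0 = rng_check.randn(100)
u = rng_check.randn(100)
u /= np.linalg.norm(u)                    # unit perturbation, as in the script
scale = 1e-3

finite_diff = np.sum((f(z0 + scale * u) - f(z0)) ** 2) / scale ** 2
exact = np.sum(J.dot(u) ** 2)             # ||J u||^2 along the same direction
print(finite_diff, exact)                 # identical for a linear map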
def main(_): if not os.path.exists(FLAGS.logdir): os.makedirs(FLAGS.logdir) # Random seed rng = np.random.RandomState(FLAGS.seed) # seed labels rng_data = np.random.RandomState(FLAGS.seed_data) # seed shuffling # load CIFAR-10 trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir, 'train') # float [-1 1] images testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test') trainx_unl = trainx.copy() trainx_unl2 = trainx.copy() nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size) nr_batches_test = int(testx.shape[0] / FLAGS.batch_size) # select labeled data inds = rng_data.permutation(trainx.shape[0]) trainx = trainx[inds] trainy = trainy[inds] txs = [] tys = [] for j in range(10): txs.append(trainx[trainy == j][:FLAGS.labeled]) tys.append(trainy[trainy == j][:FLAGS.labeled]) txs = np.concatenate(txs, axis=0) tys = np.concatenate(tys, axis=0) config = FLAGS.__flags generator = DCGANGenerator(**config) discriminator = SNDCGAN_Discrminator(output_dim=10, features=True, **config) global_step = tf.Variable(0, name="global_step", trainable=False) increase_global_step = global_step.assign(global_step + 1) '''construct graph''' print('constructing graph') unl = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='unlabeled_data_input_pl') is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl') inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='labeled_data_input_pl') lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input_pl') # scalar pl lr_pl = tf.placeholder(tf.float32, [], name='learning_rate_pl') acc_train_pl = tf.placeholder(tf.float32, [], 'acc_train_pl') acc_test_pl = tf.placeholder(tf.float32, [], 'acc_test_pl') acc_test_pl_ema = tf.placeholder(tf.float32, [], 'acc_test_pl') random_z = tf.random_uniform([FLAGS.batch_size, 100], name='random_z') gen_inp = generator(random_z, is_training_pl) logits_gen, layer_fake = discriminator(gen_inp, update_collection=None, features=True) logits_unl, layer_real = discriminator(unl, update_collection="NO_OPS", features=True) logits_lab, _ = discriminator(inp, update_collection="NO_OPS") with tf.name_scope('loss_functions'): l_unl = tf.reduce_logsumexp(logits_unl, axis=1) l_gen = tf.reduce_logsumexp(logits_gen, axis=1) # discriminator loss_lab = tf.reduce_mean( tf.nn.sparse_softmax_cross_entropy_with_logits(labels=lbl, logits=logits_lab)) loss_unl = - 0.5 * tf.reduce_mean(l_unl) \ + 0.5 * tf.reduce_mean(tf.nn.softplus(l_unl)) \ + 0.5 * tf.reduce_mean(tf.nn.softplus(l_gen)) # generator m1 = tf.reduce_mean(layer_real, axis=0) m2 = tf.reduce_mean(layer_fake, axis=0) loss_gen = tf.reduce_mean(tf.abs(m1 - m2)) loss_dis = FLAGS.unl_weight * loss_unl + FLAGS.lbl_weight * loss_lab correct_pred = tf.equal(tf.cast(tf.argmax(logits_lab, 1), tf.int32), tf.cast(lbl, tf.int32)) accuracy_classifier = tf.reduce_mean(tf.cast(correct_pred, tf.float32)) with tf.name_scope('optimizers'): d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='critic') g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.adam_alpha, beta1=FLAGS.adam_beta1, beta2=FLAGS.adam_beta2) d_gvs = optimizer.compute_gradients(loss_dis, var_list=d_vars) g_gvs = optimizer.compute_gradients(loss_gen, var_list=g_vars) d_solver = optimizer.apply_gradients(d_gvs) g_solver = optimizer.apply_gradients(g_gvs) ema = tf.train.ExponentialMovingAverage(decay=FLAGS.ma_decay) maintain_averages_op = ema.apply(d_vars) with tf.control_dependencies([d_solver]): 
train_dis_op = tf.group(maintain_averages_op) logits_ema, _ = discriminator(inp, update_collection="NO_OPS", getter=get_getter(ema)) correct_pred_ema = tf.equal(tf.cast(tf.argmax(logits_ema, 1), tf.int32), tf.cast(lbl, tf.int32)) accuracy_ema = tf.reduce_mean(tf.cast(correct_pred_ema, tf.float32)) with tf.name_scope('summary'): with tf.name_scope('discriminator'): tf.summary.scalar('loss_discriminator', loss_dis, ['dis']) with tf.name_scope('generator'): tf.summary.scalar('loss_generator', loss_gen, ['gen']) with tf.name_scope('images'): tf.summary.image('gen_images', gen_inp, 10, ['image']) with tf.name_scope('epoch'): tf.summary.scalar('accuracy_train', acc_train_pl, ['epoch']) tf.summary.scalar('accuracy_test_moving_average', acc_test_pl_ema, ['epoch']) tf.summary.scalar('accuracy_test_raw', acc_test_pl, ['epoch']) tf.summary.scalar('learning_rate', lr_pl, ['epoch']) sum_op_dis = tf.summary.merge_all('dis') sum_op_gen = tf.summary.merge_all('gen') sum_op_im = tf.summary.merge_all('image') sum_op_epoch = tf.summary.merge_all('epoch') '''//////training //////''' print('start training') with tf.Session() as sess: tf.set_random_seed(rng.randint(2**10)) sess.run(tf.global_variables_initializer()) print('\ninitialization done') writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph) train_batch = 0 for epoch in tqdm(range(FLAGS.epoch)): begin = time.time() train_loss_lab = train_loss_unl = train_loss_gen = train_acc = test_acc = test_acc_ma = train_j_loss = 0 lr = FLAGS.learning_rate * linear_decay(FLAGS.decay_start, FLAGS.epoch, epoch) # construct randomly permuted batches trainx = [] trainy = [] for t in range( int(np.ceil( trainx_unl.shape[0] / float(txs.shape[0])))): # same size lbl and unlb inds = rng.permutation(txs.shape[0]) trainx.append(txs[inds]) trainy.append(tys[inds]) trainx = np.concatenate(trainx, axis=0) trainy = np.concatenate(trainy, axis=0) trainx_unl = trainx_unl[rng.permutation( trainx_unl.shape[0])] # shuffling unl dataset trainx_unl2 = trainx_unl2[rng.permutation(trainx_unl2.shape[0])] # training for t in tqdm(range(nr_batches_train)): ran_from = t * FLAGS.batch_size ran_to = (t + 1) * FLAGS.batch_size # train discriminator feed_dict = { unl: trainx_unl[ran_from:ran_to], is_training_pl: True, inp: trainx[ran_from:ran_to], lbl: trainy[ran_from:ran_to], lr_pl: lr } _, acc, lu, lb, sm = sess.run([ train_dis_op, accuracy_classifier, loss_lab, loss_unl, sum_op_dis ], feed_dict=feed_dict) train_loss_unl += lu train_loss_lab += lb train_acc += acc if (train_batch % FLAGS.step_print) == 0: writer.add_summary(sm, train_batch) # train generator _, lg, sm = sess.run( [g_solver, loss_gen, sum_op_gen], feed_dict={ unl: trainx_unl2[ran_from:ran_to], is_training_pl: True, lr_pl: lr }) train_loss_gen += lg if (train_batch % FLAGS.step_print) == 0: writer.add_summary(sm, train_batch) if (train_batch % FLAGS.freq_print == 0) & (train_batch != 0): ran_from = np.random.randint( 0, trainx_unl.shape[0] - FLAGS.batch_size) ran_to = ran_from + FLAGS.batch_size sm = sess.run(sum_op_im, feed_dict={ is_training_pl: True, unl: trainx_unl[ran_from:ran_to] }) writer.add_summary(sm, train_batch) train_batch += 1 train_loss_lab /= nr_batches_train train_loss_unl /= nr_batches_train train_loss_gen /= nr_batches_train train_acc /= nr_batches_train train_j_loss /= nr_batches_train # Testing moving averaged model and raw model if (epoch % FLAGS.freq_test == 0) | (epoch == FLAGS.epoch - 1): for t in range(nr_batches_test): ran_from = t * FLAGS.batch_size ran_to = (t + 1) * FLAGS.batch_size feed_dict = { 
inp: testx[ran_from:ran_to], lbl: testy[ran_from:ran_to], is_training_pl: False } acc, acc_ema = sess.run( [accuracy_classifier, accuracy_ema], feed_dict=feed_dict) test_acc += acc test_acc_ma += acc_ema test_acc /= nr_batches_test test_acc_ma /= nr_batches_test print( "Epoch %d | time = %ds | loss gen = %.4f | loss lab = %.4f | loss unl = %.4f " "| train acc = %.4f| test acc = %.4f | test acc ema = %0.4f" % (epoch, time.time() - begin, train_loss_gen, train_loss_lab, train_loss_unl, train_acc, test_acc, test_acc_ma))
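# ---------------------------------------------------------------------------
# Note on the update_collection arguments used in the script above (assuming
# the common SN-GAN TensorFlow implementation of spectral normalisation,
# which SNDCGAN_Discrminator appears to follow): passing None applies the
# power-iteration update of the singular-vector estimate inside the forward
# pass itself, while the dummy collection "NO_OPS" leaves it frozen. Only the
# pass on generated data therefore refreshes the estimate each step, and the
# reused passes on real/labeled data see consistent normalised weights:
#
#   logits_gen, layer_fake = discriminator(gen_inp, update_collection=None, features=True)
#   logits_unl, layer_real = discriminator(unl, update_collection="NO_OPS", features=True)
# ---------------------------------------------------------------------------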
def main(_): print("\nParameters:") for attr, value in sorted(FLAGS.__flags.items()): print("{}={}".format(attr.lower(), value)) print("") if not os.path.exists(FLAGS.logdir): os.makedirs(FLAGS.logdir) rng = np.random.RandomState(FLAGS.seed) # seed labels trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir, 'train') # float [-1 1] images testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test') trainx_unl = trainx.copy() # select labeled data inds = rng.permutation(trainx.shape[0]) trainx = trainx[inds] trainy = trainy[inds] txs = [] tys = [] for j in range(10): txs.append(trainx[trainy == j][:FLAGS.labeled]) tys.append(trainy[trainy == j][:FLAGS.labeled]) txs = np.concatenate(txs, axis=0) tys = np.concatenate(tys, axis=0) trainx = txs trainy = tys nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size) nr_batches_test = int(testx.shape[0] / FLAGS.batch_size) print(trainx.shape) inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='data_input') lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input') is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl') learning_rate_pl = tf.placeholder(tf.float32, [], name='adam_learning_rate_pl') logits = classifier(inp, is_training=is_training_pl) loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=lbl) correct_prediction = tf.equal(tf.cast(tf.argmax(logits, 1), tf.int32), lbl) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_pl) update_ops = tf.get_collection( tf.GraphKeys.UPDATE_OPS) # control dependencies for batch norm ops dvars = [ var for var in tf.trainable_variables() if 'classifier' in var.name ] [print(var) for var in dvars] with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss, var_list=dvars) ### ema ### ema = tf.train.ExponentialMovingAverage(decay=FLAGS.ma_decay) maintain_averages_op = ema.apply(dvars) with tf.control_dependencies([train_op]): train_dis_op = tf.group(maintain_averages_op) logits_ema = classifier(inp, is_training_pl, getter=get_getter(ema), reuse=True) correct_pred_ema = tf.equal(tf.cast(tf.argmax(logits_ema, 1), tf.int32), tf.cast(lbl, tf.int32)) accuracy_ema = tf.reduce_mean(tf.cast(correct_pred_ema, tf.float32)) def linear_decay(decay_start, decay_end, epoch): return min( -1 / (decay_end - decay_start) * epoch + 1 + decay_start / (decay_end - decay_start), 1) # all_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) # print("list des vars") # [print(var.name) for var in all_var] config = tf.ConfigProto(device_count={'GPU': 0}) config.gpu_options.allow_growth = True with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for epoch in tqdm(range(200)): begin = time.time() # randomly permuted minibatches inds = rng.permutation(trainx.shape[0]) trainx = trainx[inds] trainy = trainy[inds] train_loss = train_acc = test_acc = test_acc_ema = 0 lr = FLAGS.learning_rate * linear_decay(100, 200, epoch) for t in tqdm(range(nr_batches_train)): ran_from = t * FLAGS.batch_size ran_to = (t + 1) * FLAGS.batch_size feed_dict = { inp: trainx[ran_from:ran_to], lbl: trainy[ran_from:ran_to], is_training_pl: True, learning_rate_pl: lr } _, ls, acc = sess.run([train_dis_op, loss, accuracy], feed_dict=feed_dict) train_loss += ls train_acc += acc train_loss /= nr_batches_train train_acc /= nr_batches_train * 100 for t in range(nr_batches_test): ran_from = t * FLAGS.batch_size ran_to = (t + 1) * FLAGS.batch_size feed_dict = { inp: 
testx[ran_from:ran_to], lbl: testy[ran_from:ran_to], is_training_pl: False } acc, acc_ema = sess.run([accuracy, accuracy_ema], feed_dict=feed_dict) test_acc += acc test_acc_ema += acc_ema test_acc /= nr_batches_test test_acc_ema /= nr_batches_test tqdm.write( "Epoch %03d | Time = %03ds | lr = %.3e | loss train = %.4f | train acc = %.2f | test acc = %.2f | test acc_ema = %.2f" % (epoch, time.time() - begin, lr, train_loss, train_acc * 100, test_acc * 100, test_acc_ema * 100)) if status_reporter: # report status for ray tune status_reporter(timesteps_total=epoch, mean_accuracy=test_acc)
def main(_): FLAGS._parse_flags() print("\nParameters:") for attr, value in sorted(FLAGS.__flags.items()): print("{}={}".format(attr.lower(), value)) print("") rng = np.random.RandomState(FLAGS.seed) # seed labels trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir, 'train') # float [-1 1] images testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test') # select labeled data inds = rng.permutation(trainx.shape[0]) trainx = trainx[inds] trainy = trainy[inds] txs = [] tys = [] for j in range(10): txs.append(trainx[trainy == j][:FLAGS.labeled]) tys.append(trainy[trainy == j][:FLAGS.labeled]) txs = np.concatenate(txs, axis=0) tys = np.concatenate(tys, axis=0) trainx = txs trainy = tys nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size) nr_batches_test = int(testx.shape[0] / FLAGS.batch_size) print(trainx.shape) inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='data_input') lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input') is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl') learning_rate_pl = tf.placeholder(tf.float32, [], name='adam_learning_rate_pl') classifier = DNN() logits = classifier(inp, is_training=is_training_pl) loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=lbl) correct_prediction = tf.equal(tf.cast(tf.argmax(logits, 1), tf.int32), lbl) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_pl) update_ops = tf.get_collection( tf.GraphKeys.UPDATE_OPS) # control dependencies for batch norm ops with tf.control_dependencies(update_ops): train_op = optimizer.minimize(loss) def linear_decay(decay_start, decay_end, epoch): return min( -1 / (decay_end - decay_start) * epoch + 1 + decay_start / (decay_end - decay_start), 1) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for epoch in tqdm(range(FLAGS.epoch), disable=not verbose): begin = time.time() # randomly permuted minibatches inds = rng.permutation(trainx.shape[0]) trainx = trainx[inds] trainy = trainy[inds] train_loss = train_acc = test_acc = 0 lr = FLAGS.learning_rate * linear_decay(100, 200, epoch) for t in tqdm(range(nr_batches_train), disable=not verbose): ran_from = t * FLAGS.batch_size ran_to = (t + 1) * FLAGS.batch_size feed_dict = { inp: trainx[ran_from:ran_to], lbl: trainy[ran_from:ran_to], is_training_pl: True, learning_rate_pl: lr } _, ls, acc = sess.run([train_op, loss, accuracy], feed_dict=feed_dict) train_loss += ls train_acc += acc train_loss /= nr_batches_train train_acc /= nr_batches_train for t in range(nr_batches_test): ran_from = t * FLAGS.batch_size ran_to = (t + 1) * FLAGS.batch_size feed_dict = { inp: testx[ran_from:ran_to], lbl: testy[ran_from:ran_to], is_training_pl: False } test_acc += sess.run(accuracy, feed_dict=feed_dict) test_acc /= nr_batches_test tqdm.write( "Epoch %03d | Time = %03ds | lr = %.4f | loss train = %.4f | train acc = %.4f | test acc = %.4f" % (epoch, time.time() - begin, lr, train_loss, train_acc, test_acc)) if status_reporter: # report status for ray tune status_reporter(timesteps_total=epoch, mean_accuracy=test_acc)
flags.DEFINE_float('adam_beta1', 0.5, 'beta1 in Adam')
flags.DEFINE_float('adam_beta2', 0.999, 'beta2 in Adam')
flags.DEFINE_integer('n_dis', 1, 'n discriminator train')
flags.DEFINE_string('snapshot', '/tmp/snaphots', 'snapshot directory')
flags.DEFINE_string('data_dir', './tmp/data/cifar-10-python/', 'data directory')
flags.DEFINE_integer('seed', 10, 'seed numpy')
flags.DEFINE_integer('labeled', 400, 'labeled data per class')
flags.DEFINE_string('logdir', './log', 'log directory')
flags.DEFINE_float('reg_w', 1e-3, 'weight regularization')

mkdir('tmp')

##############################################
trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir, 'train')  # float [-1 1] images
testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test')
trainx_unl = trainx.copy()

# select labeled data
rng = np.random.RandomState(FLAGS.seed)  # seed labels
inds = rng.permutation(trainx.shape[0])
trainx = trainx[inds]
trainy = trainy[inds]
txs = []
tys = []
for j in range(10):
    txs.append(trainx[trainy == j][:FLAGS.labeled])
    tys.append(trainy[trainy == j][:FLAGS.labeled])
txs = np.concatenate(txs, axis=0)
tys = np.concatenate(tys, axis=0)
trainx = txs
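# ---------------------------------------------------------------------------
# cifar10_input._get_dataset is used throughout these scripts but not shown.
# A minimal sketch of what it is assumed to do: read the python-version
# CIFAR-10 batches from data_dir and return NHWC float32 images rescaled to
# [-1, 1] together with int labels. The repo's own loader may differ.
# ---------------------------------------------------------------------------
import os
import pickle

import numpy as np


def _get_dataset(data_dir, split):
    files = (['data_batch_%d' % i for i in range(1, 6)]
             if split == 'train' else ['test_batch'])
    xs, ys = [], []
    for name in files:
        with open(os.path.join(data_dir, name), 'rb') as f:
            batch = pickle.load(f, encoding='latin1')
        xs.append(batch['data'])
        ys.append(np.asarray(batch['labels'], dtype=np.int32))
    x = np.concatenate(xs).reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    x = x.astype(np.float32) / 127.5 - 1.0  # rescale from [0, 255] to [-1, 1]
    return x, np.concatenate(ys)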