def main(_):
    print("\nParameters:")
    for attr, value in tf.app.flags.FLAGS.flag_values_dict().items():
        print("{}={}".format(attr, value))
    print("")

    os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.gpu)

    if not os.path.exists(FLAGS.logdir):
        os.makedirs(FLAGS.logdir)

    # Random seed
    rng = np.random.RandomState(FLAGS.seed)  # seed labels
    rng_data = np.random.RandomState(rng.randint(0, 2**10))  # seed shuffling

    # load CIFAR-10
    trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir, 'train')  # float [-1 1] images
    testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test')
    trainx_unl = trainx.copy()
    trainx_unl2 = trainx.copy()

    if FLAGS.validation:
        split = int(0.1 * trainx.shape[0])
        print("validation enabled")
        testx = trainx[:split]
        testy = trainy[:split]
        trainx = trainx[split:]
        trainy = trainy[split:]

    nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size)
    nr_batches_test = int(testx.shape[0] / FLAGS.batch_size)

    # select labeled data
    inds = rng_data.permutation(trainx.shape[0])
    trainx = trainx[inds]
    trainy = trainy[inds]
    txs = []
    tys = []
    for j in range(10):
        txs.append(trainx[trainy == j][:FLAGS.labeled])
        tys.append(trainy[trainy == j][:FLAGS.labeled])
    txs = np.concatenate(txs, axis=0)
    tys = np.concatenate(tys, axis=0)

    '''construct graph'''
    unl = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='unlabeled_data_input_pl')
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3], name='labeled_data_input_pl')
    lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input_pl')
    # scalar pl
    lr_pl = tf.placeholder(tf.float32, [], name='learning_rate_pl')
    acc_train_pl = tf.placeholder(tf.float32, [], 'acc_train_pl')
    acc_test_pl = tf.placeholder(tf.float32, [], 'acc_test_pl')
    acc_test_pl_ema = tf.placeholder(tf.float32, [], 'acc_test_pl_ema')

    random_z = tf.random_uniform([FLAGS.batch_size, 100], name='random_z')
    generator(random_z, is_training_pl, init=True)  # init of weightnorm weights
    gen_inp = generator(random_z, is_training_pl, init=False, reuse=True)
    discriminator(unl, is_training_pl, init=True)
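    # the two init=True calls above build the data-dependent initialization pass
    # for the weight-normalized layers, which is why the Supervisor's
    # init_feed_dict further below feeds a real batch when running the init op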
    logits_lab, _ = discriminator(inp, is_training_pl, init=False, reuse=True)
    logits_gen, layer_fake = discriminator(gen_inp, is_training_pl, init=False, reuse=True)
    logits_unl, layer_real = discriminator(unl, is_training_pl, init=False, reuse=True)

    with tf.name_scope('loss_functions'):
        # discriminator
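        # the K class logits double as the unsupervised discriminator: with
        # Z(x) = sum_k exp(logit_k) and l = log Z(x), D(x) = Z(x) / (Z(x) + 1),
        # so -log D(x) = softplus(l) - l and -log(1 - D(G(z))) = softplus(l_gen);
        # loss_unl below is 0.5 * (real term + fake term) of the standard GAN loss
        # (cf. Salimans et al., "Improved Techniques for Training GANs")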
        l_unl = tf.reduce_logsumexp(logits_unl, axis=1)
        l_gen = tf.reduce_logsumexp(logits_gen, axis=1)
        loss_lab = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=lbl, logits=logits_lab))
        loss_unl = - 0.5 * tf.reduce_mean(l_unl) \
                   + 0.5 * tf.reduce_mean(tf.nn.softplus(l_unl)) \
                   + 0.5 * tf.reduce_mean(tf.nn.softplus(l_gen))

        # generator
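        # feature matching: the generator is trained to match the mean of an
        # intermediate discriminator feature layer between real and generated
        # batches (L1 distance of the batch means) rather than to fool D directly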
        m1 = tf.reduce_mean(layer_real, axis=0)
        m2 = tf.reduce_mean(layer_fake, axis=0)

        loss_dis = FLAGS.unl_weight * loss_unl + FLAGS.lbl_weight * loss_lab
        loss_gen = tf.reduce_mean(tf.abs(m1 - m2))
        correct_pred = tf.equal(tf.cast(tf.argmax(logits_lab, 1), tf.int32), tf.cast(lbl, tf.int32))
        accuracy_classifier = tf.reduce_mean(tf.cast(correct_pred, tf.float32))


    with tf.name_scope('optimizers'):
        # control op dependencies for batch norm and trainable variables
        tvars = tf.trainable_variables()
        dvars = [var for var in tvars if 'discriminator_model' in var.name]
        gvars = [var for var in tvars if 'generator_model' in var.name]

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops_gen = [x for x in update_ops if ('generator_model' in x.name)]
        update_ops_dis = [x for x in update_ops if ('discriminator_model' in x.name)]
        optimizer_dis = tf.train.AdamOptimizer(learning_rate=lr_pl, beta1=0.5, name='dis_optimizer')
        optimizer_gen = tf.train.AdamOptimizer(learning_rate=lr_pl, beta1=0.5, name='gen_optimizer')

        with tf.control_dependencies(update_ops_gen):
            train_gen_op = optimizer_gen.minimize(loss_gen, var_list=gvars)

        dis_op = optimizer_dis.minimize(loss_dis, var_list=dvars)
        ema = tf.train.ExponentialMovingAverage(decay=FLAGS.ma_decay)
        maintain_averages_op = ema.apply(dvars)

        with tf.control_dependencies([dis_op]):
            train_dis_op = tf.group(maintain_averages_op)

        logits_ema, _ = discriminator(inp, is_training_pl, getter=get_getter(ema), reuse=True)
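        # get_getter(ema) used above is a helper defined elsewhere in this repo;
        # it supplies a custom variable getter so this second classifier head
        # reads the EMA shadow weights. A minimal sketch of such a getter, as an
        # assumption for illustration:
        #
        #   def get_getter(ema):
        #       def ema_getter(getter, name, *args, **kwargs):
        #           var = getter(name, *args, **kwargs)
        #           ema_var = ema.average(var)
        #           return ema_var if ema_var is not None else var
        #       return ema_getter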
        correct_pred_ema = tf.equal(tf.cast(tf.argmax(logits_ema, 1), tf.int32), tf.cast(lbl, tf.int32))
        accuracy_ema = tf.reduce_mean(tf.cast(correct_pred_ema, tf.float32))

    with tf.name_scope('summary'):
        with tf.name_scope('discriminator'):
            tf.summary.scalar('loss_discriminator', loss_dis, ['dis'])

        with tf.name_scope('generator'):
            tf.summary.scalar('loss_generator', loss_gen, ['gen'])

        with tf.name_scope('images'):
            tf.summary.image('gen_images', gen_inp, 10, ['image'])

        with tf.name_scope('epoch'):
            tf.summary.scalar('accuracy_train', acc_train_pl, ['epoch'])
            tf.summary.scalar('accuracy_test_moving_average', acc_test_pl_ema, ['epoch'])
            tf.summary.scalar('accuracy_test', acc_test_pl, ['epoch'])
            tf.summary.scalar('learning_rate', lr_pl, ['epoch'])

        sum_op_dis = tf.summary.merge_all('dis')
        sum_op_gen = tf.summary.merge_all('gen')
        sum_op_im = tf.summary.merge_all('image')
        sum_op_epoch = tf.summary.merge_all('epoch')

    # training global variables
    global_epoch = tf.Variable(0, trainable=False, name='global_epoch')
    global_step = tf.Variable(0, trainable=False, name='global_step')
    inc_global_step = tf.assign(global_step, global_step+1)
    inc_global_epoch = tf.assign(global_epoch, global_epoch+1)

    # op initializer for session manager
    init_gen = [var.initializer for var in gvars][:-3]
    with tf.control_dependencies(init_gen):
        op = tf.global_variables_initializer()
    init_feed_dict = {inp: trainx_unl[:FLAGS.batch_size], unl: trainx_unl[:FLAGS.batch_size], is_training_pl: True}

    sv = tf.train.Supervisor(logdir=FLAGS.logdir, global_step=global_epoch, summary_op=None, save_model_secs=0,
                             init_op=op,init_feed_dict=init_feed_dict)

    '''//////training //////'''
    print('start training')
    with sv.managed_session() as sess:
        tf.set_random_seed(rng.randint(2 ** 10))
        print('\ninitialization done')
        print('Starting training from epoch :%d, step:%d \n'%(sess.run(global_epoch),sess.run(global_step)))

        writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph)

        while not sv.should_stop():
            epoch = sess.run(global_epoch)
            train_batch = sess.run(global_step)

            if (epoch >= FLAGS.epoch):
                print("Training done")
                sv.stop()
                break

            begin = time.time()
            train_loss_lab = train_loss_unl = train_loss_gen = train_acc = test_acc = test_acc_ma = 0
            lr = FLAGS.learning_rate * linear_decay(FLAGS.decay_start,FLAGS.epoch,epoch)
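            # linear_decay (imported helper, not shown in this example; presumably
            # the same ramp defined in the last example on this page) scales the
            # learning rate: 1.0 until decay_start, then linearly down to 0 at FLAGS.epoch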

            # construct randomly permuted batches
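            # the labeled pool (txs/tys) is much smaller than the unlabeled set, so
            # it is re-permuted and tiled until it covers at least one pass over
            # trainx_unl; the same batch index can then slice both labeled and
            # unlabeled data in the training loop below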
            trainx = []
            trainy = []
            for t in range(int(np.ceil(trainx_unl.shape[0] / float(txs.shape[0])))):  # same size lbl and unlb
                inds = rng.permutation(txs.shape[0])
                trainx.append(txs[inds])
                trainy.append(tys[inds])
            trainx = np.concatenate(trainx, axis=0)
            trainy = np.concatenate(trainy, axis=0)
            trainx_unl = trainx_unl[rng.permutation(trainx_unl.shape[0])]  # shuffling unl dataset
            trainx_unl2 = trainx_unl2[rng.permutation(trainx_unl2.shape[0])]

            # training
            for t in range(nr_batches_train):

                display_progression_epoch(t, nr_batches_train)
                ran_from = t * FLAGS.batch_size
                ran_to = (t + 1) * FLAGS.batch_size

                # train discriminator
                feed_dict = {unl: trainx_unl[ran_from:ran_to],
                             is_training_pl: True,
                             inp: trainx[ran_from:ran_to],
                             lbl: trainy[ran_from:ran_to],
                             lr_pl: lr}
                _, acc, lu, lb, sm = sess.run([train_dis_op, accuracy_classifier, loss_lab, loss_unl, sum_op_dis],
                                                  feed_dict=feed_dict)
                train_loss_unl += lu
                train_loss_lab += lb
                train_acc += acc
                if (train_batch % FLAGS.step_print) == 0:
                    writer.add_summary(sm, train_batch)

                # train generator
                _, lg, sm = sess.run([train_gen_op, loss_gen, sum_op_gen], feed_dict={unl: trainx_unl2[ran_from:ran_to],
                                                                                      is_training_pl: True,
                                                                                      lr_pl: lr})
                train_loss_gen += lg
                if (train_batch % FLAGS.step_print) == 0:
                    writer.add_summary(sm, train_batch)

                if (train_batch % FLAGS.freq_print == 0) & (train_batch != 0):
                    ran_from = np.random.randint(0, trainx_unl.shape[0] - FLAGS.batch_size)
                    ran_to = ran_from + FLAGS.batch_size
                    sm = sess.run(sum_op_im,
                                  feed_dict={is_training_pl: True, unl: trainx_unl[ran_from:ran_to]})
                    writer.add_summary(sm, train_batch)

                train_batch += 1
                sess.run(inc_global_step)

            train_loss_lab /= nr_batches_train
            train_loss_unl /= nr_batches_train
            train_loss_gen /= nr_batches_train
            train_acc /= nr_batches_train

            # Testing moving averaged model and raw model
            if (epoch % FLAGS.freq_test == 0) | (epoch == FLAGS.epoch-1):
                for t in range(nr_batches_test):
                    ran_from = t * FLAGS.batch_size
                    ran_to = (t + 1) * FLAGS.batch_size
                    feed_dict = {inp: testx[ran_from:ran_to],
                                 lbl: testy[ran_from:ran_to],
                                 is_training_pl: False}
                    acc, acc_ema = sess.run([accuracy_classifier, accuracy_ema], feed_dict=feed_dict)
                    test_acc += acc
                    test_acc_ma += acc_ema
                test_acc /= nr_batches_test
                test_acc_ma /= nr_batches_test

                sum = sess.run(sum_op_epoch, feed_dict={acc_train_pl: train_acc,
                                                        acc_test_pl: test_acc,
                                                        acc_test_pl_ema: test_acc_ma,
                                                        lr_pl: lr})
                writer.add_summary(sum, epoch)

                print(
                    "Epoch %d | time = %ds | loss gen = %.4f | loss lab = %.4f | loss unl = %.4f "
                    "| train acc = %.4f| test acc = %.4f | test acc ema = %0.4f"
                    % (epoch, time.time() - begin, train_loss_gen, train_loss_lab, train_loss_unl, train_acc,
                       test_acc, test_acc_ma))

            sess.run(inc_global_epoch)

            # save snapshots of model
            if ((epoch % FLAGS.freq_save == 0) & (epoch!=0) ) | (epoch == FLAGS.epoch-1):
                string = 'model-' + str(epoch)
                save_path = os.path.join(FLAGS.logdir, string)
                sv.saver.save(sess, save_path)
                print("Model saved in file: %s" % (save_path))
Example #2
def main(_):
    print("\nParameters:")
    for attr, value in tf.app.flags.FLAGS.flag_values_dict().items():
        print("{}={}".format(attr, value))
    print("")

    os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.gpu)

    if not os.path.exists(FLAGS.logdir):
        os.makedirs(FLAGS.logdir)

    # Random seed
    rng = np.random.RandomState(FLAGS.seed)  # seed labels
    rng_data = np.random.RandomState(rng.randint(0, 2**10))  # seed shuffling

    # load CIFAR-10
    trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir,
                                                'train')  # float [-1 1] images
    testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test')
    trainx_unl = trainx.copy()
    trainx_unl2 = trainx.copy()

    if FLAGS.validation:
        split = int(0.1 * trainx.shape[0])
        print("validation enabled")
        testx = trainx[:split]
        testy = trainy[:split]
        trainx = trainx[split:]
        trainy = trainy[split:]

    nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size)
    nr_batches_test = int(testx.shape[0] / FLAGS.batch_size)

    # select labeled data
    inds = rng_data.permutation(trainx.shape[0])
    trainx = trainx[inds]
    trainy = trainy[inds]
    txs = []
    tys = []
    for j in range(10):
        txs.append(trainx[trainy == j][:FLAGS.labeled])
        tys.append(trainy[trainy == j][:FLAGS.labeled])
    txs = np.concatenate(txs, axis=0)
    tys = np.concatenate(tys, axis=0)
    '''construct graph'''
    unl = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3],
                         name='unlabeled_data_input_pl')
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3],
                         name='labeled_data_input_pl')
    lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input_pl')
    # scalar pl
    lr_pl = tf.placeholder(tf.float32, [], name='learning_rate_pl')
    acc_train_pl = tf.placeholder(tf.float32, [], 'acc_train_pl')
    acc_test_pl = tf.placeholder(tf.float32, [], 'acc_test_pl')
    acc_test_pl_ema = tf.placeholder(tf.float32, [], 'acc_test_pl_ema')

    random_z = tf.random_uniform([FLAGS.batch_size, 100], name='random_z')
    generator(random_z, is_training_pl,
              init=True)  # init of weightnorm weights
    gen_inp = generator(random_z, is_training_pl, init=False, reuse=True)
    discriminator(unl, is_training_pl, init=True)
    logits_lab, _ = discriminator(inp, is_training_pl, init=False, reuse=True)
    logits_gen, layer_fake = discriminator(gen_inp,
                                           is_training_pl,
                                           init=False,
                                           reuse=True)
    logits_unl, layer_real = discriminator(unl,
                                           is_training_pl,
                                           init=False,
                                           reuse=True)

    with tf.name_scope('loss_functions'):
        # discriminator
        l_unl = tf.reduce_logsumexp(logits_unl, axis=1)
        l_gen = tf.reduce_logsumexp(logits_gen, axis=1)
        loss_lab = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=lbl,
                                                           logits=logits_lab))
        loss_unl = - 0.5 * tf.reduce_mean(l_unl) \
                   + 0.5 * tf.reduce_mean(tf.nn.softplus(l_unl)) \
                   + 0.5 * tf.reduce_mean(tf.nn.softplus(l_gen))

        # generator
        m1 = tf.reduce_mean(layer_real, axis=0)
        m2 = tf.reduce_mean(layer_fake, axis=0)

        loss_dis = FLAGS.unl_weight * loss_unl + FLAGS.lbl_weight * loss_lab
        loss_gen = tf.reduce_mean(tf.abs(m1 - m2))
        correct_pred = tf.equal(tf.cast(tf.argmax(logits_lab, 1), tf.int32),
                                tf.cast(lbl, tf.int32))
        accuracy_classifier = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    with tf.name_scope('optimizers'):
        # control op dependencies for batch norm and trainable variables
        tvars = tf.trainable_variables()
        dvars = [var for var in tvars if 'discriminator_model' in var.name]
        gvars = [var for var in tvars if 'generator_model' in var.name]

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops_gen = [
            x for x in update_ops if ('generator_model' in x.name)
        ]
        update_ops_dis = [
            x for x in update_ops if ('discriminator_model' in x.name)
        ]
        optimizer_dis = tf.train.AdamOptimizer(learning_rate=lr_pl,
                                               beta1=0.5,
                                               name='dis_optimizer')
        optimizer_gen = tf.train.AdamOptimizer(learning_rate=lr_pl,
                                               beta1=0.5,
                                               name='gen_optimizer')

        with tf.control_dependencies(update_ops_gen):
            train_gen_op = optimizer_gen.minimize(loss_gen, var_list=gvars)

        dis_op = optimizer_dis.minimize(loss_dis, var_list=dvars)
        ema = tf.train.ExponentialMovingAverage(decay=FLAGS.ma_decay)
        maintain_averages_op = ema.apply(dvars)

        with tf.control_dependencies([dis_op]):
            train_dis_op = tf.group(maintain_averages_op)

        logits_ema, _ = discriminator(inp,
                                      is_training_pl,
                                      getter=get_getter(ema),
                                      reuse=True)
        correct_pred_ema = tf.equal(
            tf.cast(tf.argmax(logits_ema, 1), tf.int32),
            tf.cast(lbl, tf.int32))
        accuracy_ema = tf.reduce_mean(tf.cast(correct_pred_ema, tf.float32))

    with tf.name_scope('summary'):
        with tf.name_scope('discriminator'):
            tf.summary.scalar('loss_discriminator', loss_dis, ['dis'])

        with tf.name_scope('generator'):
            tf.summary.scalar('loss_generator', loss_gen, ['gen'])

        with tf.name_scope('images'):
            tf.summary.image('gen_images', gen_inp, 10, ['image'])

        with tf.name_scope('epoch'):
            tf.summary.scalar('accuracy_train', acc_train_pl, ['epoch'])
            tf.summary.scalar('accuracy_test_moving_average', acc_test_pl_ema,
                              ['epoch'])
            tf.summary.scalar('accuracy_test', acc_test_pl, ['epoch'])
            tf.summary.scalar('learning_rate', lr_pl, ['epoch'])

        sum_op_dis = tf.summary.merge_all('dis')
        sum_op_gen = tf.summary.merge_all('gen')
        sum_op_im = tf.summary.merge_all('image')
        sum_op_epoch = tf.summary.merge_all('epoch')

    # training global variables
    global_epoch = tf.Variable(0, trainable=False, name='global_epoch')
    global_step = tf.Variable(0, trainable=False, name='global_step')
    inc_global_step = tf.assign(global_step, global_step + 1)
    inc_global_epoch = tf.assign(global_epoch, global_epoch + 1)

    # op initializer for session manager
    init_gen = [var.initializer for var in gvars][:-3]
    with tf.control_dependencies(init_gen):
        op = tf.global_variables_initializer()
    init_feed_dict = {
        inp: trainx_unl[:FLAGS.batch_size],
        unl: trainx_unl[:FLAGS.batch_size],
        is_training_pl: True
    }

    sv = tf.train.Supervisor(logdir=FLAGS.logdir,
                             global_step=global_epoch,
                             summary_op=None,
                             save_model_secs=0,
                             init_op=op,
                             init_feed_dict=init_feed_dict)
    '''//////training //////'''
    print('start training')
    with sv.managed_session() as sess:
        tf.set_random_seed(rng.randint(2**10))
        print('\ninitialization done')
        print('Starting training from epoch :%d, step:%d \n' %
              (sess.run(global_epoch), sess.run(global_step)))

        writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph)

        while not sv.should_stop():
            epoch = sess.run(global_epoch)
            train_batch = sess.run(global_step)

            if (epoch >= FLAGS.epoch):
                print("Training done")
                sv.stop()
                break

            begin = time.time()
            train_loss_lab = train_loss_unl = train_loss_gen = train_acc = test_acc = test_acc_ma = 0
            lr = FLAGS.learning_rate * linear_decay(FLAGS.decay_start,
                                                    FLAGS.epoch, epoch)

            # construct randomly permuted batches
            trainx = []
            trainy = []
            for t in range(
                    int(np.ceil(
                        trainx_unl.shape[0] /
                        float(txs.shape[0])))):  # same size lbl and unlb
                inds = rng.permutation(txs.shape[0])
                trainx.append(txs[inds])
                trainy.append(tys[inds])
            trainx = np.concatenate(trainx, axis=0)
            trainy = np.concatenate(trainy, axis=0)
            trainx_unl = trainx_unl[rng.permutation(
                trainx_unl.shape[0])]  # shuffling unl dataset
            trainx_unl2 = trainx_unl2[rng.permutation(trainx_unl2.shape[0])]

            # training
            for t in range(nr_batches_train):

                display_progression_epoch(t, nr_batches_train)
                ran_from = t * FLAGS.batch_size
                ran_to = (t + 1) * FLAGS.batch_size

                # train discriminator
                feed_dict = {
                    unl: trainx_unl[ran_from:ran_to],
                    is_training_pl: True,
                    inp: trainx[ran_from:ran_to],
                    lbl: trainy[ran_from:ran_to],
                    lr_pl: lr
                }
                _, acc, lu, lb, sm = sess.run([
                    train_dis_op, accuracy_classifier, loss_lab, loss_unl,
                    sum_op_dis
                ],
                                              feed_dict=feed_dict)
                train_loss_unl += lu
                train_loss_lab += lb
                train_acc += acc
                if (train_batch % FLAGS.step_print) == 0:
                    writer.add_summary(sm, train_batch)

                # train generator
                _, lg, sm = sess.run(
                    [train_gen_op, loss_gen, sum_op_gen],
                    feed_dict={
                        unl: trainx_unl2[ran_from:ran_to],
                        is_training_pl: True,
                        lr_pl: lr
                    })
                train_loss_gen += lg
                if (train_batch % FLAGS.step_print) == 0:
                    writer.add_summary(sm, train_batch)

                if (train_batch % FLAGS.freq_print == 0) & (train_batch != 0):
                    ran_from = np.random.randint(
                        0, trainx_unl.shape[0] - FLAGS.batch_size)
                    ran_to = ran_from + FLAGS.batch_size
                    sm = sess.run(sum_op_im,
                                  feed_dict={
                                      is_training_pl: True,
                                      unl: trainx_unl[ran_from:ran_to]
                                  })
                    writer.add_summary(sm, train_batch)

                train_batch += 1
                sess.run(inc_global_step)

            train_loss_lab /= nr_batches_train
            train_loss_unl /= nr_batches_train
            train_loss_gen /= nr_batches_train
            train_acc /= nr_batches_train

            # Testing moving averaged model and raw model
            if (epoch % FLAGS.freq_test == 0) | (epoch == FLAGS.epoch - 1):
                for t in range(nr_batches_test):
                    ran_from = t * FLAGS.batch_size
                    ran_to = (t + 1) * FLAGS.batch_size
                    feed_dict = {
                        inp: testx[ran_from:ran_to],
                        lbl: testy[ran_from:ran_to],
                        is_training_pl: False
                    }
                    acc, acc_ema = sess.run(
                        [accuracy_classifier, accuracy_ema],
                        feed_dict=feed_dict)
                    test_acc += acc
                    test_acc_ma += acc_ema
                test_acc /= nr_batches_test
                test_acc_ma /= nr_batches_test

                sum = sess.run(sum_op_epoch,
                               feed_dict={
                                   acc_train_pl: train_acc,
                                   acc_test_pl: test_acc,
                                   acc_test_pl_ema: test_acc_ma,
                                   lr_pl: lr
                               })
                writer.add_summary(sum, epoch)

                print(
                    "Epoch %d | time = %ds | loss gen = %.4f | loss lab = %.4f | loss unl = %.4f "
                    "| train acc = %.4f| test acc = %.4f | test acc ema = %0.4f"
                    % (epoch, time.time() - begin, train_loss_gen,
                       train_loss_lab, train_loss_unl, train_acc, test_acc,
                       test_acc_ma))

            sess.run(inc_global_epoch)

            # save snapshots of model
            if ((epoch % FLAGS.freq_save == 0) &
                (epoch != 0)) | (epoch == FLAGS.epoch - 1):
                string = 'model-' + str(epoch)
                save_path = os.path.join(FLAGS.logdir, string)
                sv.saver.save(sess, save_path)
                print("Model saved in file: %s" % (save_path))
Example #3
def main(_):
    if not os.path.exists(FLAGS.logdir):
        os.makedirs(FLAGS.logdir)

    # Random seed
    rng = np.random.RandomState(FLAGS.seed)  # seed labels
    rng_data = np.random.RandomState(FLAGS.seed_data)  # seed shuffling

    # load CIFAR-10
    trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir,
                                                'train')  # float [-1 1] images
    testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test')
    trainx_unl = trainx.copy()
    trainx_unl2 = trainx.copy()
    nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size)
    nr_batches_test = int(testx.shape[0] / FLAGS.batch_size)

    # select labeled data
    inds = rng_data.permutation(trainx.shape[0])
    trainx = trainx[inds]
    trainy = trainy[inds]
    txs = []
    tys = []
    for j in range(10):
        txs.append(trainx[trainy == j][:FLAGS.labeled])
        tys.append(trainy[trainy == j][:FLAGS.labeled])
    txs = np.concatenate(txs, axis=0)
    tys = np.concatenate(tys, axis=0)

    print("Data:")
    print('train examples %d, batch %d, test examples %d, batch %d' \
          % (trainx.shape[0], nr_batches_train, testx.shape[0], nr_batches_test))
    print('histogram train', np.histogram(trainy, bins=10)[0])
    print('histogram test ', np.histogram(testy, bins=10)[0])
    print("histogram labeled", np.histogram(tys, bins=10)[0])
    print("")
    '''construct graph'''
    print('constructing graph')
    unl = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3],
                         name='unlabeled_data_input_pl')
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3],
                         name='labeled_data_input_pl')
    lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input_pl')
    # scalar pl
    lr_pl = tf.placeholder(tf.float32, [], name='learning_rate_pl')
    acc_train_pl = tf.placeholder(tf.float32, [], 'acc_train_pl')
    acc_test_pl = tf.placeholder(tf.float32, [], 'acc_test_pl')
    acc_test_pl_ema = tf.placeholder(tf.float32, [], 'acc_test_pl_ema')
    kl_weight = tf.placeholder(tf.float32, [], 'kl_weight')

    random_z = tf.random_uniform([FLAGS.batch_size, 100], name='random_z')
    perturb = tf.random_normal([FLAGS.batch_size, 100], mean=0, stddev=0.01)
    random_z_pert = random_z + FLAGS.scale * perturb / (
        tf.expand_dims(tf.norm(perturb, axis=1), axis=1) * tf.ones([1, 100]))
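    # manifold regularization: random_z_pert is random_z plus a random direction
    # rescaled to norm FLAGS.scale, so j_loss below penalizes how much the
    # classifier outputs change under small moves along the generator manifold
    # (a finite-difference proxy for the Jacobian norm w.r.t. z)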
    generator(random_z, is_training_pl,
              init=True)  # init of weightnorm weights
    gen_inp = generator(random_z, is_training_pl, init=False, reuse=True)
    gen_inp_pert = generator(random_z_pert,
                             is_training_pl,
                             init=False,
                             reuse=True)

    discriminator(unl, is_training_pl, init=True)
    logits_lab, _ = discriminator(inp, is_training_pl, init=False, reuse=True)
    logits_gen, layer_fake = discriminator(gen_inp,
                                           is_training_pl,
                                           init=False,
                                           reuse=True)
    logits_unl, layer_real = discriminator(unl,
                                           is_training_pl,
                                           init=False,
                                           reuse=True)
    logits_gen_perturb, layer_fake_perturb = discriminator(gen_inp_pert,
                                                           is_training_pl,
                                                           init=False,
                                                           reuse=True)

    with tf.name_scope('loss_functions'):
        l_unl = tf.reduce_logsumexp(logits_unl, axis=1)
        l_gen = tf.reduce_logsumexp(logits_gen, axis=1)
        # discriminator
        loss_lab = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=lbl,
                                                           logits=logits_lab))
        loss_unl = - 0.5 * tf.reduce_mean(l_unl) \
                   + 0.5 * tf.reduce_mean(tf.nn.softplus(l_unl)) \
                   + 0.5 * tf.reduce_mean(tf.nn.softplus(l_gen))

        # generator
        m1 = tf.reduce_mean(layer_real, axis=0)
        m2 = tf.reduce_mean(layer_fake, axis=0)

        j_loss = tf.reduce_mean(
            tf.reduce_sum(tf.square(logits_gen - logits_gen_perturb), axis=1))

        if FLAGS.nabla:
            loss_dis = FLAGS.unl_weight * loss_unl + FLAGS.lbl_weight * loss_lab + kl_weight * j_loss
            loss_gen = tf.reduce_mean(tf.abs(m1 - m2))
            print('manifold reg enabled')
        else:
            loss_dis = FLAGS.unl_weight * loss_unl + FLAGS.lbl_weight * loss_lab
            loss_gen = tf.reduce_mean(tf.abs(m1 - m2))

        correct_pred = tf.equal(tf.cast(tf.argmax(logits_lab, 1), tf.int32),
                                tf.cast(lbl, tf.int32))
        accuracy_classifier = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    with tf.name_scope('optimizers'):
        # control op dependencies for batch norm and trainable variables
        tvars = tf.trainable_variables()

        dvars = [var for var in tvars if 'discriminator_model' in var.name]
        gvars = [var for var in tvars if 'generator_model' in var.name]

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        update_ops_gen = [
            x for x in update_ops if ('generator_model' in x.name)
        ]
        update_ops_dis = [
            x for x in update_ops if ('discriminator_model' in x.name)
        ]

        optimizer_dis = tf.train.AdamOptimizer(learning_rate=lr_pl,
                                               beta1=0.5,
                                               name='dis_optimizer')
        optimizer_gen = tf.train.AdamOptimizer(learning_rate=lr_pl,
                                               beta1=0.5,
                                               name='gen_optimizer')

        with tf.control_dependencies(update_ops_gen):
            train_gen_op = optimizer_gen.minimize(loss_gen, var_list=gvars)

        dis_op = optimizer_dis.minimize(loss_dis, var_list=dvars)

        ema = tf.train.ExponentialMovingAverage(decay=FLAGS.ma_decay)
        maintain_averages_op = ema.apply(dvars)

        with tf.control_dependencies([dis_op]):
            train_dis_op = tf.group(maintain_averages_op)

        logits_ema, _ = discriminator(inp,
                                      is_training_pl,
                                      getter=get_getter(ema),
                                      reuse=True)
        correct_pred_ema = tf.equal(
            tf.cast(tf.argmax(logits_ema, 1), tf.int32),
            tf.cast(lbl, tf.int32))
        accuracy_ema = tf.reduce_mean(tf.cast(correct_pred_ema, tf.float32))

    with tf.name_scope('summary'):
        with tf.name_scope('discriminator'):
            tf.summary.scalar('loss_discriminator', loss_dis, ['dis'])
            tf.summary.scalar('kl_loss', j_loss, ['dis'])

        with tf.name_scope('generator'):
            tf.summary.scalar('loss_generator', loss_gen, ['gen'])

        with tf.name_scope('images'):
            tf.summary.image('gen_images', gen_inp, 10, ['image'])

        with tf.name_scope('epoch'):
            tf.summary.scalar('accuracy_train', acc_train_pl, ['epoch'])
            tf.summary.scalar('accuracy_test_moving_average', acc_test_pl_ema,
                              ['epoch'])
            tf.summary.scalar('accuracy_test_raw', acc_test_pl, ['epoch'])
            tf.summary.scalar('learning_rate', lr_pl, ['epoch'])
            tf.summary.scalar('j_weight', kl_weight, ['epoch'])

        sum_op_dis = tf.summary.merge_all('dis')
        sum_op_gen = tf.summary.merge_all('gen')
        sum_op_im = tf.summary.merge_all('image')
        sum_op_epoch = tf.summary.merge_all('epoch')

    # training global variables
    global_epoch = tf.Variable(0, trainable=False, name='global_epoch')
    global_step = tf.Variable(0, trainable=False, name='global_step')
    inc_global_step = tf.assign(global_step, global_step + 1)
    inc_global_epoch = tf.assign(global_epoch, global_epoch + 1)

    # op initializer for session manager
    init_gen = [var.initializer for var in gvars][:-3]
    with tf.control_dependencies(init_gen):
        op = tf.global_variables_initializer()
    init_feed_dict = {
        inp: trainx_unl[:FLAGS.batch_size],
        unl: trainx_unl[:FLAGS.batch_size],
        is_training_pl: True,
        kl_weight: 0
    }

    sv = tf.train.Supervisor(logdir=FLAGS.logdir,
                             global_step=global_epoch,
                             summary_op=None,
                             save_model_secs=0,
                             init_op=op,
                             init_feed_dict=init_feed_dict)

    inception_scores = []
    '''//////training //////'''
    print('start training')
    with sv.managed_session() as sess:
        tf.set_random_seed(rng.randint(2**10))
        print('\ninitialization done')
        print('Starting training from epoch :%d, step:%d \n' %
              (sess.run(global_epoch), sess.run(global_step)))

        writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph)

        while not sv.should_stop():
            epoch = sess.run(global_epoch)
            train_batch = sess.run(global_step)

            if (epoch >= FLAGS.epoch):
                print("Training done")
                sv.stop()
                break

            begin = time.time()
            train_loss_lab = train_loss_unl = train_loss_gen = train_acc = test_acc = test_acc_ma = train_j_loss = 0
            lr = FLAGS.learning_rate * linear_decay(FLAGS.decay_start,
                                                    FLAGS.epoch, epoch)
            klw = FLAGS.nabla_w

            # construct randomly permuted batches
            trainx = []
            trainy = []
            for t in range(
                    int(np.ceil(
                        trainx_unl.shape[0] /
                        float(txs.shape[0])))):  # same size lbl and unlb
                inds = rng.permutation(txs.shape[0])
                trainx.append(txs[inds])
                trainy.append(tys[inds])
            trainx = np.concatenate(trainx, axis=0)
            trainy = np.concatenate(trainy, axis=0)
            trainx_unl = trainx_unl[rng.permutation(
                trainx_unl.shape[0])]  # shuffling unl dataset
            trainx_unl2 = trainx_unl2[rng.permutation(trainx_unl2.shape[0])]

            # training
            for t in range(nr_batches_train):

                display_progression_epoch(t, nr_batches_train)
                ran_from = t * FLAGS.batch_size
                ran_to = (t + 1) * FLAGS.batch_size

                # train discriminator
                feed_dict = {
                    unl: trainx_unl[ran_from:ran_to],
                    is_training_pl: True,
                    inp: trainx[ran_from:ran_to],
                    lbl: trainy[ran_from:ran_to],
                    lr_pl: lr,
                    kl_weight: klw
                }
                _, acc, lu, lb, jl, sm = sess.run([
                    train_dis_op, accuracy_classifier, loss_lab, loss_unl,
                    j_loss, sum_op_dis
                ],
                                                  feed_dict=feed_dict)
                train_loss_unl += lu
                train_loss_lab += lb
                train_acc += acc
                train_j_loss += jl
                if (train_batch % FLAGS.step_print) == 0:
                    writer.add_summary(sm, train_batch)

                # train generator
                _, lg, sm = sess.run(
                    [train_gen_op, loss_gen, sum_op_gen],
                    feed_dict={
                        unl: trainx_unl2[ran_from:ran_to],
                        is_training_pl: True,
                        lr_pl: lr,
                        kl_weight: klw
                    })
                train_loss_gen += lg
                if (train_batch % FLAGS.step_print) == 0:
                    writer.add_summary(sm, train_batch)

                if (train_batch % FLAGS.freq_print == 0) & (train_batch != 0):
                    ran_from = np.random.randint(
                        0, trainx_unl.shape[0] - FLAGS.batch_size)
                    ran_to = ran_from + FLAGS.batch_size
                    sm = sess.run(sum_op_im,
                                  feed_dict={
                                      is_training_pl: True,
                                      unl: trainx_unl[ran_from:ran_to]
                                  })
                    writer.add_summary(sm, train_batch)

                train_batch += 1
                sess.run(inc_global_step)

            train_loss_lab /= nr_batches_train
            train_loss_unl /= nr_batches_train
            train_loss_gen /= nr_batches_train
            train_acc /= nr_batches_train
            train_j_loss /= nr_batches_train

            # Testing moving averaged model and raw model
            if (epoch % FLAGS.freq_test == 0) | (epoch == FLAGS.epoch - 1):
                for t in range(nr_batches_test):
                    ran_from = t * FLAGS.batch_size
                    ran_to = (t + 1) * FLAGS.batch_size
                    feed_dict = {
                        inp: testx[ran_from:ran_to],
                        lbl: testy[ran_from:ran_to],
                        is_training_pl: False
                    }
                    acc, acc_ema = sess.run(
                        [accuracy_classifier, accuracy_ema],
                        feed_dict=feed_dict)
                    test_acc += acc
                    test_acc_ma += acc_ema
                test_acc /= nr_batches_test
                test_acc_ma /= nr_batches_test

                sum = sess.run(sum_op_epoch,
                               feed_dict={
                                   acc_train_pl: train_acc,
                                   acc_test_pl: test_acc,
                                   acc_test_pl_ema: test_acc_ma,
                                   lr_pl: lr,
                                   kl_weight: klw
                               })
                writer.add_summary(sum, epoch)

                print(
                    "Epoch %d | time = %ds | loss gen = %.4f | loss lab = %.4f | loss unl = %.4f "
                    "| train acc = %.4f| test acc = %.4f | test acc ema = %0.4f"
                    % (epoch, time.time() - begin, train_loss_gen,
                       train_loss_lab, train_loss_unl, train_acc, test_acc,
                       test_acc_ma))

            sess.run(inc_global_epoch)

            # save snapshot of model
            if ((epoch % FLAGS.freq_save == 0) &
                (epoch != 0)) | (epoch == FLAGS.epoch - 1):
                string = 'model-' + str(epoch)
                save_path = os.path.join(FLAGS.logdir, string)
                sv.saver.save(sess, save_path)
                print("Model saved in file: %s" % (save_path))

            print("saving images...")
            sample_images = sess.run(gen_inp,
                                     feed_dict={is_training_pl: False})
            save_images(sample_images,
                        os.path.join(FLAGS.logdir, '{:06d}.png'.format(epoch)))
            print('images saved @ ' +
                  os.path.join(FLAGS.logdir, '{:06d}.png'.format(epoch)))
            num_images_to_eval = 50000
            eval_images = []
            num_batches = num_images_to_eval // FLAGS.batch_size + 1
            print("Calculating Inception Score. Sampling {} images...".format(
                num_images_to_eval))
            np.random.seed(0)
            for _ in range(num_batches):
                images = sess.run(gen_inp, feed_dict={is_training_pl: False})
                eval_images.append(images)
            np.random.seed()
            eval_images = np.vstack(eval_images)
            eval_images = eval_images[:num_images_to_eval]
            eval_images = np.clip((eval_images + 1.0) * 127.5, 0.0,
                                  255.0).astype(np.uint8)
            # Calc Inception score
            eval_images = list(eval_images)
            inception_score_mean, inception_score_std = get_inception_score(
                eval_images)
            print("Inception Score: Mean = {} \tStd = {}.".format(
                inception_score_mean, inception_score_std))
            inception_scores.append(
                dict(mean=inception_score_mean, std=inception_score_std))
            with open(INCEPTION_FILENAME, 'wb') as f:
                pickle.dump(inception_scores, f)
Example #4
def main(_):
    if not os.path.exists(FLAGS.log_dir):
        os.mkdir(FLAGS.log_dir)

    # Random seed
    rng = np.random.RandomState(FLAGS.seed)

    # load CIFAR-10
    trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir,
                                                'train')  # float [0 1] images
    testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test')
    # overfitting test
    # trainx = trainx[:10000]
    # trainy = trainy[:10000]

    nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size)
    nr_batches_test = int(testx.shape[0] / FLAGS.batch_size)

    # whiten data
    print('Starting preprocessing')
    begin = time.time()
    m = np.mean(trainx, axis=0)
    std = np.std(trainx, axis=0)
    trainx -= m
    # trainx /= std
    testx -= m
    # testx /= std
    trainx, testx = zca_whiten(trainx, testx, epsilon=1e-8)
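    # zca_whiten (project helper, not shown here) presumably fits a ZCA transform
    # on the flattened training images and applies it to both splits; a rough
    # sketch under that assumption:
    #
    #   flat = trainx.reshape(len(trainx), -1)
    #   cov = np.cov(flat, rowvar=False)
    #   U, S, _ = np.linalg.svd(cov)
    #   W = U @ np.diag(1.0 / np.sqrt(S + epsilon)) @ U.T
    #   trainx = (flat @ W).reshape(trainx.shape)
    #   testx = (testx.reshape(len(testx), -1) @ W).reshape(testx.shape)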
    print('Preprocessing done in : %ds' % (time.time() - begin))
    '''construct graph'''
    inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3],
                         name='data_input')
    lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input')
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    accuracy_epoch = tf.placeholder(tf.float32, [], name='epoch_pl')
    adam_learning_rate_pl = tf.placeholder(tf.float32, [],
                                           name='adam_learning_rate_pl')
    adam_momentum_pl = tf.placeholder(tf.float32, [], name='adam_momentum_pl')

    with tf.variable_scope('cnn_model'):
        logits = cifar_model.inference(inp, is_training_pl)

    with tf.name_scope('loss_function'):
        loss = tf.losses.sparse_softmax_cross_entropy(logits=logits,
                                                      labels=lbl)
        correct_prediction = tf.equal(tf.cast(tf.argmax(logits, 1), tf.int32),
                                      lbl)
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        eval_correct = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))

    optimizer = tf.train.AdamOptimizer(learning_rate=adam_learning_rate_pl,
                                       beta1=adam_momentum_pl)

    update_ops = tf.get_collection(
        tf.GraphKeys.UPDATE_OPS)  # control dependencies for batch norm ops
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss)

    # Summaries
    with tf.name_scope('per_batch_summary'):
        tf.summary.scalar('loss', loss, ['batch'])
        tf.summary.scalar('accuracy', accuracy, ['batch'])
        tf.summary.scalar('adam learning rate', adam_learning_rate_pl,
                          ['batch'])
        tf.summary.scalar('adam momentum', adam_momentum_pl, ['batch'])

    with tf.name_scope('per_epoch_summary'):
        tf.summary.scalar('accuracy epoch', accuracy_epoch, ['per_epoch'])
        tf.summary.merge(
            tf.contrib.layers.summarize_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES), ['per_epoch'])
        with tf.name_scope('input_data'):
            tf.summary.image('input image', inp, 10, ['per_epoch'])
            tf.summary.histogram('first input image', tf.reshape(inp[0], [-1]),
                                 ['per_epoch'])
            tf.summary.histogram('input labels', lbl, ['per_epoch'])
            tf.summary.histogram('output logits', tf.argmax(logits, axis=0),
                                 ['per_epoch'])

    sum_op = tf.summary.merge_all('batch')
    sum_epoch_op = tf.summary.merge_all('per_epoch')
    '''//////perform training //////'''
    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init, {adam_momentum_pl: 0.9})
        train_batch = 0
        train_writer = tf.summary.FileWriter(
            os.path.join(FLAGS.log_dir, 'train'), sess.graph)
        test_writer = tf.summary.FileWriter(
            os.path.join(FLAGS.log_dir, 'test'), sess.graph)

        for epoch in tqdm(range(200)):
            begin = time.time()

            # randomly permuted minibatches
            inds = rng.permutation(trainx.shape[0])
            trainx = trainx[inds]
            trainy = trainy[inds]

            train_loss, train_tp, test_tp = [0, 0, 0]

            for t in range(nr_batches_train):
                ran_from = t * FLAGS.batch_size
                ran_to = (t + 1) * FLAGS.batch_size
                feed_dict = {
                    inp: trainx[ran_from:ran_to],
                    lbl: trainy[ran_from:ran_to],
                    is_training_pl: True,
                    adam_learning_rate_pl: decayed_lr(epoch),
                    adam_momentum_pl: momentum(epoch)
                }

                _, ls, tp, sm = sess.run(
                    [train_op, loss, eval_correct, sum_op],
                    feed_dict=feed_dict)

                train_loss += ls
                train_tp += tp
                train_batch += 1
                train_writer.add_summary(sm, train_batch)

            train_loss /= nr_batches_train
            train_tp /= trainx.shape[0]

            for t in range(nr_batches_test):
                ran_from = t * FLAGS.batch_size
                ran_to = (t + 1) * FLAGS.batch_size
                feed_dict = {
                    inp: testx[ran_from:ran_to],
                    lbl: testy[ran_from:ran_to],
                    is_training_pl: False
                }

                test_tp += sess.run(eval_correct, feed_dict=feed_dict)

            test_tp /= testx.shape[0]
            '''/////epoch summary/////'''
            sm = sess.run(
                sum_epoch_op, {
                    accuracy_epoch: train_tp,
                    inp: trainx[:FLAGS.batch_size],
                    lbl: trainy[:FLAGS.batch_size],
                    is_training_pl: False
                })
            train_writer.add_summary(sm, epoch)
            x = np.random.randint(
                0, testx.shape[0] -
                FLAGS.batch_size)  # random batch extracted in testx
            sm = sess.run(
                sum_epoch_op, {
                    accuracy_epoch: test_tp,
                    inp: testx[x:x + FLAGS.batch_size],
                    lbl: testy[x:x + FLAGS.batch_size],
                    is_training_pl: False
                })
            test_writer.add_summary(sm, epoch)
            # print("Epoch %d--Batch %d--Time = %ds | loss train = %.4f | train acc = %.4f | test acc = %.4f" %
            #       (epoch, train_batch, time.time() - begin, train_loss, train_tp, test_tp))
            tqdm.write(
                "Epoch %d--Batch %d--Time = %ds | loss train = %.4f | train acc = %.4f | test acc = %.4f"
                % (epoch, train_batch, time.time() - begin, train_loss,
                   train_tp, test_tp))
Example #5
def main(_):
    print("\nParameters:")
    for attr, value in FLAGS.__flags.items():
        print("{}={}".format(attr.lower(), value))
    print("")
    if not os.path.exists(FLAGS.logdir):
        os.makedirs(FLAGS.logdir)
    rng = np.random.RandomState(FLAGS.seed)  # seed labels

    trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir,
                                                'train')  # float [-1 1] images
    testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test')

    # select labeled data
    inds = rng.permutation(trainx.shape[0])
    trainx = trainx[inds]
    trainy = trainy[inds]
    print("first labels trainy: ", trainy[:10])

    txs = []
    tys = []
    for j in range(10):
        txs.append(trainx[trainy == j][:FLAGS.labeled])
        tys.append(trainy[trainy == j][:FLAGS.labeled])
    txs = np.concatenate(txs, axis=0)
    tys = np.concatenate(tys, axis=0)
    trainx = txs
    trainy = tys
    nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size)
    nr_batches_test = int(testx.shape[0] / FLAGS.batch_size)
    print("trainx shape:", trainx.shape)

    # placeholder model
    inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3],
                         name='data_input')
    lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input')
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    gan_is_training_pl = tf.placeholder(tf.bool, [], name='gan_is_training_pl')
    learning_rate_pl = tf.placeholder(tf.float32, [],
                                      name='adam_learning_rate_pl')

    acc_train_pl = tf.placeholder(tf.float32, [], 'acc_train_pl')
    acc_test_pl = tf.placeholder(tf.float32, [], 'acc_test_pl')
    acc_test_pl_ema = tf.placeholder(tf.float32, [], 'acc_test_pl_ema')

    generator = DCGANGenerator(batch_size=FLAGS.mc_size)
    latent_dim = generator.generate_noise().shape[1]
    z = tf.placeholder(tf.float32, shape=[FLAGS.mc_size, latent_dim])

    if not FLAGS.tiny_cnn:
        from dnn import classifier as classifier
        print("standard cnn loaded")
    else:
        from dnn import tiny_classifier as classifier
        print("tiny cnn loaded")

    x_hat = generator(z, is_training=gan_is_training_pl)
    logits = classifier(inp, is_training=is_training_pl)
    logits_gen = classifier(x_hat, is_training=is_training_pl, reuse=True)

    def get_jacobian(y, x):
        with tf.name_scope("jacob"):
            grads = tf.stack(
                [tf.gradients(yi, x)[0] for yi in tf.unstack(y, axis=1)],
                axis=2)
        return grads
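    # get_jacobian stacks d y_k / d x for each of the K logits, giving a batch
    # Jacobian of shape [batch, dim(x), K]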

    if FLAGS.grad == 'stochastic':
        print('stochastic reg enabled ...')
        perturb = tf.random_normal([FLAGS.mc_size, latent_dim],
                                   mean=0,
                                   stddev=0.01)
        z_pert = z + FLAGS.scale * perturb / (tf.expand_dims(
            tf.norm(perturb, axis=1), axis=1) * tf.ones([1, latent_dim]))
        x_pert = generator(z_pert, is_training=gan_is_training_pl, reuse=True)
        logits_gen_perturb = classifier(x_pert,
                                        is_training=is_training_pl,
                                        reuse=True)
        j_loss = tf.reduce_mean(
            tf.reduce_sum(tf.square(logits_gen - logits_gen_perturb), axis=1))
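        # note: the exact-Jacobian expression below is never assigned, so its
        # value does not enter the loss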
        tf.reduce_mean(
            tf.reduce_sum(tf.square(get_jacobian(logits_gen, z)), axis=[1, 2]))

    elif FLAGS.grad == 'stochastic_v2':
        print('stochastic v2 reg enabled ...')
        perturb = tf.nn.l2_normalize(tf.random_normal(
            [FLAGS.mc_size, latent_dim], mean=0, stddev=0.01),
                                     dim=[1])
        x_pert = generator(z + FLAGS.scale * perturb,
                           is_training=gan_is_training_pl,
                           reuse=True)
        logits_gen_perturb = classifier(x_pert,
                                        is_training=is_training_pl,
                                        reuse=True)
        j_loss = tf.reduce_mean(
            tf.reduce_sum(tf.square(logits_gen - logits_gen_perturb), axis=1))

    elif FLAGS.grad == 'isotropic_mc':
        print('isotropic mc reg enabled ...')
        perturb = tf.nn.l2_normalize(
            tf.random_normal([FLAGS.mc_size] + inp.get_shape().as_list()[-3:],
                             mean=0,
                             stddev=0.01),
            dim=[1, 2, 3])  # gaussian noise [mc_size, 32,32,3]
        x_pert = x_hat + FLAGS.scale * perturb
        logits_gen_pert = classifier(x_pert,
                                     is_training=is_training_pl,
                                     reuse=True)
        j_loss = tf.reduce_mean(
            tf.reduce_sum(tf.square(logits_gen - logits_gen_pert), axis=1))

    elif FLAGS.grad == 'isotropic_inp':
        print('isotropic inp reg enabled ...')
        perturb = tf.nn.l2_normalize(
            tf.random_normal([FLAGS.mc_size] + inp.get_shape().as_list()[-3:],
                             mean=0,
                             stddev=0.01),
            dim=[1, 2, 3])  # gaussian noise [mc_size, 32,32,3]
        x_pert = inp + FLAGS.scale * perturb
        logits_inp_pert = classifier(x_pert,
                                     is_training=is_training_pl,
                                     reuse=True)
        j_loss = tf.reduce_mean(
            tf.reduce_sum(tf.square(logits_gen - logits_inp_pert), axis=1))

    elif FLAGS.grad == 'isotropic_rnd':
        print('isotropic rnd reg enabled ...')
        epsilon = tf.random_normal(
            [FLAGS.mc_size] + inp.get_shape().as_list()[-3:],
            mean=0,
            stddev=0.01)  # gaussian noise [mc_size, 32,32,3]
        epsilon_hat = tf.nn.l2_normalize(
            epsilon, dim=[1, 2,
                          3])  # normalised gaussian noise [mc_size, 32,32,3]
        rnd_img = tf.random_uniform(shape=[FLAGS.mc_size] +
                                    inp.get_shape().as_list()[-3:],
                                    minval=-1,
                                    maxval=1)
        x_pert = rnd_img + FLAGS.scale * epsilon_hat
        logits_pert = classifier(x_pert,
                                 is_training=is_training_pl,
                                 reuse=True)
        j_loss = tf.reduce_mean(
            tf.reduce_sum(tf.square(logits_gen - logits_pert), axis=1))

    elif FLAGS.grad == 'grad_latent':
        print('grad latent enabled ...')
        grad = get_jacobian(logits_gen, z)
        j_loss = tf.reduce_mean(tf.reduce_sum(tf.square(grad), axis=[1, 2]))

    elif FLAGS.grad == 'grad_mc':
        print('grad mc enabled ...')
        grad = get_jacobian(logits_gen, x_hat)
        j_loss = tf.reduce_mean(tf.reduce_sum(tf.square(grad), axis=[1, 2]))

    elif FLAGS.grad == 'grad_inp':
        print('grad inp enabled ...')
        grad = get_jacobian(logits, inp)
        j_loss = tf.reduce_mean(tf.reduce_sum(tf.square(grad), axis=[1, 2]))

    elif FLAGS.grad == 'grad_old':
        print('old grad enabled ...')
        k = []
        for j in range(10):
            grad = tf.gradients(logits_gen[:, j], z)
            k.append(grad)
        J = tf.stack(k)
        J = tf.squeeze(J)
        J = tf.transpose(J, perm=[1, 0, 2])  # jacobian
        j_n = tf.reduce_sum(tf.square(J), axis=[1, 2])
        j_loss = tf.reduce_mean(j_n)

    elif FLAGS.grad == 'comb':
        jac_manifold = []
        jac_ambient = []
        for yi in tf.unstack(logits_gen, axis=1):
            g1, g2 = tf.gradients(yi, [z, x_hat])
            jac_ambient.append(g2)
            jac_manifold.append(g1)
        jm = tf.square(tf.stack(jac_manifold))
        ja = tf.square(tf.stack(jac_ambient))

        j_manifold = tf.reduce_mean(tf.reduce_sum(jm, axis=[0, 2]))
        j_ambient = tf.reduce_mean(tf.reduce_sum(ja, axis=[0, 2, 3, 4]))
        j_loss = tf.constant(0.)

    ######## loss function #######
    xentropy = tf.losses.sparse_softmax_cross_entropy(logits=logits,
                                                      labels=lbl)
    if not FLAGS.reg:
        print('reg disabled')
        loss = xentropy
    else:
        print('laplacian reg enabled')
        loss = xentropy + FLAGS.reg_ambient * j_ambient + FLAGS.reg_manifold * j_manifold

    correct_prediction = tf.equal(tf.cast(tf.argmax(logits, 1), tf.int32), lbl)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # g_vars = tf.global_variables(scope='generator')
    g_vars = [var for var in tf.global_variables() if 'generator' in var.name]
    dnn_vars = [var for var in tf.trainable_variables() if var not in g_vars]

    # [print(var.name) for var in dnn_vars]

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_pl)
    update_ops = tf.get_collection(
        tf.GraphKeys.UPDATE_OPS)  # control dependencies for batch norm ops
    # with tf.control_dependencies(update_ops):
    #     train_op = optimizer.minimize(loss, var_list=dnn_vars)

    dvars = [
        var for var in tf.trainable_variables() if 'classifier' in var.name
    ]
    # [print(var) for var in dvars]
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, var_list=dvars)

    ### ema ###
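    # maintain an exponential moving average of the classifier weights; the averaged copy
    # is evaluated below through a custom variable getter (get_getter(ema))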
    ema = tf.train.ExponentialMovingAverage(decay=FLAGS.ma_decay)
    maintain_averages_op = ema.apply(dvars)

    with tf.control_dependencies([train_op]):
        train_dis_op = tf.group(maintain_averages_op)

    logits_ema = classifier(inp,
                            is_training_pl,
                            getter=get_getter(ema),
                            reuse=True)
    correct_pred_ema = tf.equal(tf.cast(tf.argmax(logits_ema, 1), tf.int32),
                                tf.cast(lbl, tf.int32))
    accuracy_ema = tf.reduce_mean(tf.cast(correct_pred_ema, tf.float32))

    def linear_decay(decay_start, decay_end, epoch):
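        # returns 1.0 until epoch reaches decay_start, then falls linearly to 0.0 at decay_end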
        return min(
            -1 / (decay_end - decay_start) * epoch + 1 + decay_start /
            (decay_end - decay_start), 1)

    # all_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    # print("list des vars")
    # [print(var.name) for var in all_var]
    # print("name var")
    # [print(var.name for var in g_vars)]

    with tf.name_scope('summary'):
        with tf.name_scope('discriminator'):
            tf.summary.scalar('xentropy', xentropy, ['dis'])
            tf.summary.scalar('laplacian_loss', j_loss, ['dis'])

        with tf.name_scope('images'):
            tf.summary.image('gen_images', x_hat, 4, ['image'])
            # tf.summary.image('gen_pert', x_pert, 4, ['image'])

        with tf.name_scope('epoch'):
            tf.summary.scalar('accuracy_train', acc_train_pl, ['epoch'])
            tf.summary.scalar('accuracy_test_moving_average', acc_test_pl_ema,
                              ['epoch'])
            tf.summary.scalar('accuracy_test_raw', acc_test_pl, ['epoch'])
            tf.summary.scalar('learning_rate', learning_rate_pl, ['epoch'])

        sum_op_dis = tf.summary.merge_all('dis')
        sum_op_im = tf.summary.merge_all('image')
        sum_op_epoch = tf.summary.merge_all('epoch')

    print("batch size monte carlo: ", generator.generate_noise().shape)
    print("")

    saver = tf.train.Saver(var_list=g_vars)
    var_init = [var for var in tf.global_variables() if var not in g_vars]
    init_op = tf.variables_initializer(var_list=var_init)

    # config = tf.ConfigProto(device_count={'GPU': 0})
    # config.gpu_options.allow_growth = True

    with tf.Session() as sess:
        writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph)

        sess.run(init_op)
        if tf.train.latest_checkpoint(FLAGS.snapshot) is not None:
            saver.restore(sess, tf.train.latest_checkpoint(FLAGS.snapshot))
            print("model restored @ %s" % FLAGS.snapshot)
        train_batch = 0
        for epoch in tqdm(range(FLAGS.epoch), disable=not FLAGS.verbose):
            begin = time.time()

            # randomly permuted minibatches
            inds = rng.permutation(trainx.shape[0])
            trainx = trainx[inds]
            trainy = trainy[inds]
            train_loss = train_acc = test_acc = train_j = test_acc_ema = 0

            lr = FLAGS.learning_rate * linear_decay(FLAGS.decay, FLAGS.epoch,
                                                    epoch)

            for t in tqdm(range(nr_batches_train), disable=not FLAGS.verbose):
                ran_from = t * FLAGS.batch_size
                ran_to = (t + 1) * FLAGS.batch_size
                feed_dict = {
                    inp: trainx[ran_from:ran_to],
                    lbl: trainy[ran_from:ran_to],
                    is_training_pl: True,
                    gan_is_training_pl: False,
                    learning_rate_pl: lr,
                    z: generator.generate_noise()
                }

                _, ls, acc, j, sm = sess.run(
                    [train_dis_op, loss, accuracy, j_loss, sum_op_dis],
                    feed_dict=feed_dict)

                train_loss += ls
                train_acc += acc
                train_j += j
                writer.add_summary(sm, train_batch)
                train_batch += 1

            train_loss /= nr_batches_train
            train_acc /= nr_batches_train
            train_j /= nr_batches_train

            if (train_batch % FLAGS.freq_print == 0) & (train_batch != 0):
                sm = sess.run(sum_op_im,
                              feed_dict={
                                  gan_is_training_pl: False,
                                  z: generator.generate_noise(),
                                  inp: trainx[:FLAGS.batch_size]
                              })
                writer.add_summary(sm, train_batch)

            if (epoch % FLAGS.freq_test == 0):
                for t in range(nr_batches_test):
                    ran_from = t * FLAGS.batch_size
                    ran_to = (t + 1) * FLAGS.batch_size
                    feed_dict = {
                        inp: testx[ran_from:ran_to],
                        lbl: testy[ran_from:ran_to],
                        is_training_pl: False
                    }

                    acc, acc_ema = sess.run([accuracy, accuracy_ema],
                                            feed_dict=feed_dict)
                    test_acc += acc
                    test_acc_ema += acc_ema
                test_acc /= nr_batches_test
                test_acc_ema /= nr_batches_test

                sm = sess.run(sum_op_epoch,
                              feed_dict={
                                  acc_train_pl: train_acc,
                                  acc_test_pl: test_acc,
                                  acc_test_pl_ema: test_acc_ema,
                                  learning_rate_pl: lr
                              })
                writer.add_summary(sm, epoch)

                tqdm.write(
                    "Epoch %03d | Time = %03ds | lr = %.3e | loss train = %.4f | train acc = %.2f | test acc = %.2f | test acc_ema = %.2f"
                    % (epoch, time.time() - begin, lr, train_loss,
                       train_acc * 100, test_acc * 100, test_acc_ema * 100))

                if status_reporter:  # report status for ray tune
                    status_reporter(timesteps_total=epoch,
                                    mean_accuracy=test_acc_ema)
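
The grad_latent, grad_mc and grad_inp branches above call a get_jacobian helper that is not defined in this snippet. A minimal sketch, consistent with the per-class gradient loop of the grad_old branch (the 10-class default and the stacking axis are assumptions of mine), could look like this:

def get_jacobian(y, x, n_classes=10):
    # assumes `import tensorflow as tf` (TF 1.x), as in the rest of this file
    # one backward pass per class; each gradient has the same shape as x
    grads = [tf.gradients(y[:, j], x)[0] for j in range(n_classes)]
    return tf.stack(grads, axis=1)  # [batch, n_classes, *x.shape[1:]]

With x = z of shape [batch, 100] this yields a [batch, 10, 100] tensor, which matches the tf.reduce_sum(..., axis=[1, 2]) reduction used in the grad_latent branch.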
Example #6
def main(_):
    if not os.path.exists(FLAGS.logdir):
        os.makedirs(FLAGS.logdir)

    # Random seed
    rng = np.random.RandomState(FLAGS.seed)  # seed labels
    rng_data = np.random.RandomState(FLAGS.seed_data)  # seed shuffling

    # load CIFAR-10
    trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir,
                                                'train')  # float [-1 1] images
    testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test')
    trainx_unl = trainx.copy()
    trainx_unl2 = trainx.copy()
    nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size)
    nr_batches_test = int(testx.shape[0] / FLAGS.batch_size)

    # select labeled data
    inds = rng_data.permutation(trainx.shape[0])
    trainx = trainx[inds]
    trainy = trainy[inds]
    txs = []
    tys = []
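    # keep the first FLAGS.labeled examples of each of the 10 CIFAR-10 classes as the labeled set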
    for j in range(10):
        txs.append(trainx[trainy == j][:FLAGS.labeled])
        tys.append(trainy[trainy == j][:FLAGS.labeled])
    txs = np.concatenate(txs, axis=0)
    tys = np.concatenate(tys, axis=0)

    config = FLAGS.__flags
    generator = DCGANGenerator(**config)
    discriminator = SNDCGAN_Discrminator(output_dim=10,
                                         features=True,
                                         **config)

    global_step = tf.Variable(0, name="global_step", trainable=False)
    increase_global_step = global_step.assign(global_step + 1)
    '''construct graph'''
    print('constructing graph')
    unl = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3],
                         name='unlabeled_data_input_pl')
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3],
                         name='labeled_data_input_pl')
    lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input_pl')

    # scalar pl
    lr_pl = tf.placeholder(tf.float32, [], name='learning_rate_pl')
    acc_train_pl = tf.placeholder(tf.float32, [], 'acc_train_pl')
    acc_test_pl = tf.placeholder(tf.float32, [], 'acc_test_pl')
    acc_test_pl_ema = tf.placeholder(tf.float32, [], 'acc_test_pl')

    random_z = tf.random_uniform([FLAGS.batch_size, 100], name='random_z')
    gen_inp = generator(random_z, is_training_pl)
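    # spectral-norm convention: the first discriminator call (update_collection=None) runs the
    # power-iteration update in place; the reuse calls pass "NO_OPS" so the singular-vector
    # estimate is refreshed only once per step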
    logits_gen, layer_fake = discriminator(gen_inp,
                                           update_collection=None,
                                           features=True)
    logits_unl, layer_real = discriminator(unl,
                                           update_collection="NO_OPS",
                                           features=True)
    logits_lab, _ = discriminator(inp, update_collection="NO_OPS")

    with tf.name_scope('loss_functions'):
        l_unl = tf.reduce_logsumexp(logits_unl, axis=1)
        l_gen = tf.reduce_logsumexp(logits_gen, axis=1)
        # discriminator
        loss_lab = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=lbl,
                                                           logits=logits_lab))
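        # unlabeled GAN loss: with D(x) = Z(x)/(Z(x)+1) and Z(x) = sum_k exp(logits_k),
        # this equals 0.5*E[-log D(x_unl)] + 0.5*E[-log(1 - D(G(z)))]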
        loss_unl = - 0.5 * tf.reduce_mean(l_unl) \
                   + 0.5 * tf.reduce_mean(tf.nn.softplus(l_unl)) \
                   + 0.5 * tf.reduce_mean(tf.nn.softplus(l_gen))

        # generator
        m1 = tf.reduce_mean(layer_real, axis=0)
        m2 = tf.reduce_mean(layer_fake, axis=0)
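        # feature-matching generator loss: L1 distance between the mean intermediate
        # discriminator features of real and generated batches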
        loss_gen = tf.reduce_mean(tf.abs(m1 - m2))
        loss_dis = FLAGS.unl_weight * loss_unl + FLAGS.lbl_weight * loss_lab

        correct_pred = tf.equal(tf.cast(tf.argmax(logits_lab, 1), tf.int32),
                                tf.cast(lbl, tf.int32))
        accuracy_classifier = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    with tf.name_scope('optimizers'):
        d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='critic')
        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='generator')

        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.adam_alpha,
                                           beta1=FLAGS.adam_beta1,
                                           beta2=FLAGS.adam_beta2)

        d_gvs = optimizer.compute_gradients(loss_dis, var_list=d_vars)
        g_gvs = optimizer.compute_gradients(loss_gen, var_list=g_vars)
        d_solver = optimizer.apply_gradients(d_gvs)
        g_solver = optimizer.apply_gradients(g_gvs)

    ema = tf.train.ExponentialMovingAverage(decay=FLAGS.ma_decay)
    maintain_averages_op = ema.apply(d_vars)

    with tf.control_dependencies([d_solver]):
        train_dis_op = tf.group(maintain_averages_op)

    logits_ema, _ = discriminator(inp,
                                  update_collection="NO_OPS",
                                  getter=get_getter(ema))
    correct_pred_ema = tf.equal(tf.cast(tf.argmax(logits_ema, 1), tf.int32),
                                tf.cast(lbl, tf.int32))
    accuracy_ema = tf.reduce_mean(tf.cast(correct_pred_ema, tf.float32))

    with tf.name_scope('summary'):
        with tf.name_scope('discriminator'):
            tf.summary.scalar('loss_discriminator', loss_dis, ['dis'])

        with tf.name_scope('generator'):
            tf.summary.scalar('loss_generator', loss_gen, ['gen'])

        with tf.name_scope('images'):
            tf.summary.image('gen_images', gen_inp, 10, ['image'])

        with tf.name_scope('epoch'):
            tf.summary.scalar('accuracy_train', acc_train_pl, ['epoch'])
            tf.summary.scalar('accuracy_test_moving_average', acc_test_pl_ema,
                              ['epoch'])
            tf.summary.scalar('accuracy_test_raw', acc_test_pl, ['epoch'])
            tf.summary.scalar('learning_rate', lr_pl, ['epoch'])

        sum_op_dis = tf.summary.merge_all('dis')
        sum_op_gen = tf.summary.merge_all('gen')
        sum_op_im = tf.summary.merge_all('image')
        sum_op_epoch = tf.summary.merge_all('epoch')
    '''//////training //////'''
    print('start training')
    with tf.Session() as sess:
        tf.set_random_seed(rng.randint(2**10))
        sess.run(tf.global_variables_initializer())
        print('\ninitialization done')

        writer = tf.summary.FileWriter(FLAGS.logdir, sess.graph)

        train_batch = 0

        for epoch in tqdm(range(FLAGS.epoch)):
            begin = time.time()

            train_loss_lab = train_loss_unl = train_loss_gen = train_acc = test_acc = test_acc_ma = train_j_loss = 0
            lr = FLAGS.learning_rate * linear_decay(FLAGS.decay_start,
                                                    FLAGS.epoch, epoch)

            # construct randomly permuted batches
            trainx = []
            trainy = []
            for t in range(
                    int(np.ceil(
                        trainx_unl.shape[0] /
                        float(txs.shape[0])))):  # same size lbl and unlb
                inds = rng.permutation(txs.shape[0])
                trainx.append(txs[inds])
                trainy.append(tys[inds])
            trainx = np.concatenate(trainx, axis=0)
            trainy = np.concatenate(trainy, axis=0)
            trainx_unl = trainx_unl[rng.permutation(
                trainx_unl.shape[0])]  # shuffling unl dataset
            trainx_unl2 = trainx_unl2[rng.permutation(trainx_unl2.shape[0])]

            # training
            for t in tqdm(range(nr_batches_train)):

                ran_from = t * FLAGS.batch_size
                ran_to = (t + 1) * FLAGS.batch_size

                # train discriminator
                feed_dict = {
                    unl: trainx_unl[ran_from:ran_to],
                    is_training_pl: True,
                    inp: trainx[ran_from:ran_to],
                    lbl: trainy[ran_from:ran_to],
                    lr_pl: lr
                }
                _, acc, lu, lb, sm = sess.run([
                    train_dis_op, accuracy_classifier, loss_lab, loss_unl,
                    sum_op_dis
                ],
                                              feed_dict=feed_dict)
                train_loss_unl += lu
                train_loss_lab += lb
                train_acc += acc
                if (train_batch % FLAGS.step_print) == 0:
                    writer.add_summary(sm, train_batch)

                # train generator
                _, lg, sm = sess.run(
                    [g_solver, loss_gen, sum_op_gen],
                    feed_dict={
                        unl: trainx_unl2[ran_from:ran_to],
                        is_training_pl: True,
                        lr_pl: lr
                    })
                train_loss_gen += lg
                if (train_batch % FLAGS.step_print) == 0:
                    writer.add_summary(sm, train_batch)

                if (train_batch % FLAGS.freq_print == 0) & (train_batch != 0):
                    ran_from = np.random.randint(
                        0, trainx_unl.shape[0] - FLAGS.batch_size)
                    ran_to = ran_from + FLAGS.batch_size
                    sm = sess.run(sum_op_im,
                                  feed_dict={
                                      is_training_pl: True,
                                      unl: trainx_unl[ran_from:ran_to]
                                  })
                    writer.add_summary(sm, train_batch)

                train_batch += 1

            train_loss_lab /= nr_batches_train
            train_loss_unl /= nr_batches_train
            train_loss_gen /= nr_batches_train
            train_acc /= nr_batches_train
            train_j_loss /= nr_batches_train

            # Testing moving averaged model and raw model
            if (epoch % FLAGS.freq_test == 0) | (epoch == FLAGS.epoch - 1):
                for t in range(nr_batches_test):
                    ran_from = t * FLAGS.batch_size
                    ran_to = (t + 1) * FLAGS.batch_size
                    feed_dict = {
                        inp: testx[ran_from:ran_to],
                        lbl: testy[ran_from:ran_to],
                        is_training_pl: False
                    }
                    acc, acc_ema = sess.run(
                        [accuracy_classifier, accuracy_ema],
                        feed_dict=feed_dict)
                    test_acc += acc
                    test_acc_ma += acc_ema
                test_acc /= nr_batches_test
                test_acc_ma /= nr_batches_test

                print(
                    "Epoch %d | time = %ds | loss gen = %.4f | loss lab = %.4f | loss unl = %.4f "
                    "| train acc = %.4f| test acc = %.4f | test acc ema = %0.4f"
                    % (epoch, time.time() - begin, train_loss_gen,
                       train_loss_lab, train_loss_unl, train_acc, test_acc,
                       test_acc_ma))
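
The get_getter helper used together with the exponential moving average (here and in the examples above) is also not defined in this excerpt. A common pattern, given here only as a sketch of what such a helper usually looks like, returns a custom variable getter that substitutes each variable with its EMA shadow copy:

def get_getter(ema):
    def ema_getter(getter, name, *args, **kwargs):
        var = getter(name, *args, **kwargs)
        ema_var = ema.average(var)  # shadow variable created by ema.apply(...)
        return ema_var if ema_var is not None else var
    return ema_getter

The discriminator/classifier is then expected to pass this getter to tf.variable_scope(..., custom_getter=...) when it is reused, so the evaluation graph reads the averaged weights instead of the raw ones.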
Example #7
def main(_):
    print("\nParameters:")
    for attr, value in sorted(FLAGS.__flags.items()):
        print("{}={}".format(attr.lower(), value))
    print("")
    if not os.path.exists(FLAGS.logdir):
        os.makedirs(FLAGS.logdir)
    rng = np.random.RandomState(FLAGS.seed)  # seed labels

    trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir,
                                                'train')  # float [-1 1] images
    testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test')
    trainx_unl = trainx.copy()

    # select labeled data
    inds = rng.permutation(trainx.shape[0])
    trainx = trainx[inds]
    trainy = trainy[inds]
    txs = []
    tys = []
    for j in range(10):
        txs.append(trainx[trainy == j][:FLAGS.labeled])
        tys.append(trainy[trainy == j][:FLAGS.labeled])
    txs = np.concatenate(txs, axis=0)
    tys = np.concatenate(tys, axis=0)
    trainx = txs
    trainy = tys
    nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size)
    nr_batches_test = int(testx.shape[0] / FLAGS.batch_size)
    print(trainx.shape)

    inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3],
                         name='data_input')
    lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input')
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    learning_rate_pl = tf.placeholder(tf.float32, [],
                                      name='adam_learning_rate_pl')

    logits = classifier(inp, is_training=is_training_pl)

    loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=lbl)
    correct_prediction = tf.equal(tf.cast(tf.argmax(logits, 1), tf.int32), lbl)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_pl)
    update_ops = tf.get_collection(
        tf.GraphKeys.UPDATE_OPS)  # control dependencies for batch norm ops

    dvars = [
        var for var in tf.trainable_variables() if 'classifier' in var.name
    ]
    for var in dvars:
        print(var)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, var_list=dvars)

    ### ema ###
    ema = tf.train.ExponentialMovingAverage(decay=FLAGS.ma_decay)
    maintain_averages_op = ema.apply(dvars)

    with tf.control_dependencies([train_op]):
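        # the gradient step runs first (control dependency), then the moving averages are refreshed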
        train_dis_op = tf.group(maintain_averages_op)

    logits_ema = classifier(inp,
                            is_training_pl,
                            getter=get_getter(ema),
                            reuse=True)
    correct_pred_ema = tf.equal(tf.cast(tf.argmax(logits_ema, 1), tf.int32),
                                tf.cast(lbl, tf.int32))
    accuracy_ema = tf.reduce_mean(tf.cast(correct_pred_ema, tf.float32))

    def linear_decay(decay_start, decay_end, epoch):
        return min(
            -1 / (decay_end - decay_start) * epoch + 1 + decay_start /
            (decay_end - decay_start), 1)

    # all_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    # print("list des vars")
    # [print(var.name) for var in all_var]
    config = tf.ConfigProto(device_count={'GPU': 0})
    config.gpu_options.allow_growth = True

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())

        for epoch in tqdm(range(200)):
            begin = time.time()

            # randomly permuted minibatches
            inds = rng.permutation(trainx.shape[0])
            trainx = trainx[inds]
            trainy = trainy[inds]
            train_loss = train_acc = test_acc = test_acc_ema = 0

            lr = FLAGS.learning_rate * linear_decay(100, 200, epoch)

            for t in tqdm(range(nr_batches_train)):
                ran_from = t * FLAGS.batch_size
                ran_to = (t + 1) * FLAGS.batch_size
                feed_dict = {
                    inp: trainx[ran_from:ran_to],
                    lbl: trainy[ran_from:ran_to],
                    is_training_pl: True,
                    learning_rate_pl: lr
                }

                _, ls, acc = sess.run([train_dis_op, loss, accuracy],
                                      feed_dict=feed_dict)

                train_loss += ls
                train_acc += acc

            train_loss /= nr_batches_train
            train_acc /= nr_batches_train

            for t in range(nr_batches_test):
                ran_from = t * FLAGS.batch_size
                ran_to = (t + 1) * FLAGS.batch_size
                feed_dict = {
                    inp: testx[ran_from:ran_to],
                    lbl: testy[ran_from:ran_to],
                    is_training_pl: False
                }

                acc, acc_ema = sess.run([accuracy, accuracy_ema],
                                        feed_dict=feed_dict)
                test_acc += acc
                test_acc_ema += acc_ema

            test_acc /= nr_batches_test
            test_acc_ema /= nr_batches_test

            tqdm.write(
                "Epoch %03d | Time = %03ds | lr = %.3e | loss train = %.4f | train acc = %.2f | test acc = %.2f | test acc_ema = %.2f"
                % (epoch, time.time() - begin, lr, train_loss, train_acc * 100,
                   test_acc * 100, test_acc_ema * 100))

            if status_reporter:  # report status for ray tune
                status_reporter(timesteps_total=epoch, mean_accuracy=test_acc)
Example #8
def main(_):
    FLAGS._parse_flags()
    print("\nParameters:")
    for attr, value in sorted(FLAGS.__flags.items()):
        print("{}={}".format(attr.lower(), value))
    print("")

    rng = np.random.RandomState(FLAGS.seed)  # seed labels

    trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir,
                                                'train')  # float [-1 1] images
    testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test')

    # select labeled data
    inds = rng.permutation(trainx.shape[0])
    trainx = trainx[inds]
    trainy = trainy[inds]
    txs = []
    tys = []
    for j in range(10):
        txs.append(trainx[trainy == j][:FLAGS.labeled])
        tys.append(trainy[trainy == j][:FLAGS.labeled])
    txs = np.concatenate(txs, axis=0)
    tys = np.concatenate(tys, axis=0)
    trainx = txs
    trainy = tys
    nr_batches_train = int(trainx.shape[0] / FLAGS.batch_size)
    nr_batches_test = int(testx.shape[0] / FLAGS.batch_size)
    print(trainx.shape)

    inp = tf.placeholder(tf.float32, [FLAGS.batch_size, 32, 32, 3],
                         name='data_input')
    lbl = tf.placeholder(tf.int32, [FLAGS.batch_size], name='lbl_input')
    is_training_pl = tf.placeholder(tf.bool, [], name='is_training_pl')
    learning_rate_pl = tf.placeholder(tf.float32, [],
                                      name='adam_learning_rate_pl')

    classifier = DNN()
    logits = classifier(inp, is_training=is_training_pl)

    loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=lbl)
    correct_prediction = tf.equal(tf.cast(tf.argmax(logits, 1), tf.int32), lbl)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate_pl)
    update_ops = tf.get_collection(
        tf.GraphKeys.UPDATE_OPS)  # control dependencies for batch norm ops
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss)

    def linear_decay(decay_start, decay_end, epoch):
        return min(
            -1 / (decay_end - decay_start) * epoch + 1 + decay_start /
            (decay_end - decay_start), 1)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in tqdm(range(FLAGS.epoch), disable=not verbose):
            begin = time.time()

            # randomly permuted minibatches
            inds = rng.permutation(trainx.shape[0])
            trainx = trainx[inds]
            trainy = trainy[inds]
            train_loss = train_acc = test_acc = 0

            lr = FLAGS.learning_rate * linear_decay(100, 200, epoch)

            for t in tqdm(range(nr_batches_train), disable=not verbose):
                ran_from = t * FLAGS.batch_size
                ran_to = (t + 1) * FLAGS.batch_size
                feed_dict = {
                    inp: trainx[ran_from:ran_to],
                    lbl: trainy[ran_from:ran_to],
                    is_training_pl: True,
                    learning_rate_pl: lr
                }

                _, ls, acc = sess.run([train_op, loss, accuracy],
                                      feed_dict=feed_dict)

                train_loss += ls
                train_acc += acc

            train_loss /= nr_batches_train
            train_acc /= nr_batches_train

            for t in range(nr_batches_test):
                ran_from = t * FLAGS.batch_size
                ran_to = (t + 1) * FLAGS.batch_size
                feed_dict = {
                    inp: testx[ran_from:ran_to],
                    lbl: testy[ran_from:ran_to],
                    is_training_pl: False
                }

                test_acc += sess.run(accuracy, feed_dict=feed_dict)
            test_acc /= nr_batches_test

            tqdm.write(
                "Epoch %03d | Time = %03ds | lr = %.4f | loss train = %.4f | train acc = %.4f | test acc = %.4f"
                % (epoch, time.time() - begin, lr, train_loss, train_acc,
                   test_acc))

            if status_reporter:  # report status for ray tune
                status_reporter(timesteps_total=epoch, mean_accuracy=test_acc)
Example #9
flags.DEFINE_float('adam_beta1', 0.5, 'beta1 in Adam')
flags.DEFINE_float('adam_beta2', 0.999, 'beta2 in Adam')
flags.DEFINE_integer('n_dis', 1, 'number of discriminator training steps')
flags.DEFINE_string('snapshot', '/tmp/snapshots', 'snapshot directory')
flags.DEFINE_string('data_dir', './tmp/data/cifar-10-python/',
                    'data directory')
flags.DEFINE_integer('seed', 10, 'seed numpy')
flags.DEFINE_integer('labeled', 400, 'labeled data per class')

flags.DEFINE_string('logdir', './log', 'log directory')
flags.DEFINE_float('reg_w', 1e-3, 'weight regularization')

mkdir('tmp')

##############################################
trainx, trainy = cifar10_input._get_dataset(FLAGS.data_dir,
                                            'train')  # float [-1 1] images
testx, testy = cifar10_input._get_dataset(FLAGS.data_dir, 'test')
trainx_unl = trainx.copy()
# select labeled data
rng = np.random.RandomState(FLAGS.seed)  # seed labels
inds = rng.permutation(trainx.shape[0])
trainx = trainx[inds]
trainy = trainy[inds]
txs = []
tys = []
for j in range(10):
    txs.append(trainx[trainy == j][:FLAGS.labeled])
    tys.append(trainy[trainy == j][:FLAGS.labeled])
txs = np.concatenate(txs, axis=0)
tys = np.concatenate(tys, axis=0)
trainx = txs