Example #1
def train(n_exp, h, w, c, nclass, batch_size=100, tgmodel=True):
    graph = tf.Graph()
    with graph.as_default():
        # nr_train = data(n_exp, h, w, c, nclass, batch_size)
        X_data = np.random.rand(n_exp, h, w, c)
        y_data = np.random.rand(n_exp, nclass)
        data_iter = tg.SequentialIterator(X_data, y_data, batchsize=batch_size)

        X_ph = tf.placeholder('float32', [None, h, w, c])
        y_ph = tf.placeholder('float32', [None, nclass])

        if tgmodel:
            # tensorgraph model
            print('..using graph model')
            seq = TGModel(h, w, c, nclass)
            y_train_sb = seq.train_fprop(X_ph)

        else:
            # tensorflow model
            print('..using tensorflow model')
            y_train_sb = TFModel(X_ph, h, w, c, nclass)

        loss_train_sb = tg.cost.mse(y_train_sb, y_ph)
        accu_train_sb = tg.cost.accuracy(y_train_sb, y_ph)

        opt = tf.train.RMSPropOptimizer(0.001)

        # required for BatchNormalization layer
        update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
        with ops.control_dependencies(update_ops):
            train_op = opt.minimize(loss_train_sb)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

    # Pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with tf.Session(graph=graph, config=config) as sess:
        sess.run(init_op)

        for epoch in range(2):
            pbar = tg.ProgressBar(n_exp)
            ttl_train_loss = 0
            # for i in range(0, n_exp, batch_size):
            i = 0
            for X_batch, y_batch in data_iter:
                pbar.update(i)
                i += len(X_batch)
                _, loss_train = sess.run([train_op, loss_train_sb],
                                          feed_dict={X_ph:X_batch, y_ph:y_batch})
                # weight by the actual batch length so a smaller final batch is not over-counted
                ttl_train_loss += loss_train * len(X_batch)
            pbar.update(n_exp)
            ttl_train_loss /= n_exp
            print('epoch {}, train loss {}'.format(epoch, ttl_train_loss))
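
# --- Usage sketch (not part of the original snippet) ---
# A minimal, hedged way to smoke-test train() above on random data; it assumes
# the snippet's imports (numpy as np, tensorflow as tf, tensorgraph as tg) and
# that the TGModel class it references is defined.  The tiny dimensions below
# are arbitrary placeholders.
if __name__ == '__main__':
    train(n_exp=200, h=8, w=8, c=1, nclass=3, batch_size=50, tgmodel=True)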
Example #2
def test_DataBlocks():
    X = np.random.rand(1000, 200)
    with open('X.npy', 'wb') as f:
        np.save(f, X)

    db = tg.DataBlocks(['X.npy'] * 10, batchsize=32, allow_preload=False)
    for train_blk, valid_blk in db:
        n_exp = 0
        pbar = tg.ProgressBar(len(train_blk))
        for batch in train_blk:
            n_exp += len(batch[0])
            time.sleep(0.05)
            pbar.update(n_exp)
        print()
        pbar = tg.ProgressBar(len(valid_blk))
        n_exp = 0
        for batch in valid_blk:
            n_exp += len(batch[0])
            time.sleep(0.05)
            pbar.update(n_exp)
        print()
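
# --- Usage sketch (not part of the original snippet) ---
# test_DataBlocks() above writes X.npy into the working directory; removing the
# file afterwards keeps repeated runs from accumulating stale data.  Assumes the
# snippet's imports (numpy as np, tensorgraph as tg, time) plus os.
import os

if __name__ == '__main__':
    test_DataBlocks()
    if os.path.exists('X.npy'):
        os.remove('X.npy')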
Example #3
        def Rotate3D(data, angle, axis):
            axis_ = [(1, 2), (0, 2), (0, 1)]
            return rotate(data, angle, axes=axis_[axis], reshape=False)

        # NOTE: the original line calls Rotate3D() with no arguments, so it
        # cannot run as-is; the data volume, angle and axis to rotate about
        # still have to be supplied here.
        X_train = np.array([Rotate3D()])

        iter_train = tg.SequentialIterator(X_train,
                                           y_train,
                                           batchsize=batchsize)
        iter_test = tg.SequentialIterator(X_test, y_test, batchsize=batchsize)

        best_valid_accu = 0
        for epoch in range(max_epoch):
            print('epoch:', epoch)
            pbar = tg.ProgressBar(len(iter_train))
            ttl_train_cost = 0
            ttl_examples = 0
            print('..training')
            #for i in range(10):
            for XX, yy in iter_train:
                #                X_tr = X_train[i]
                #                y_tr = y_train[i]
                #                y_tr = np.array(y_tr, dtype='int8')
                #
                #                X_tr = X_tr.reshape((1,)+X_tr.shape+(1,))
                #                y_tr = y_tr.reshape((1,)+y_tr.shape+(1,))
                feed_dict = {X_ph: XX, y_ph: yy, phase: 1}

                _, train_cost = sess.run([optimizer, train_cost_sb],
                                         feed_dict=feed_dict)
Example #4
def train(model,
          data,
          epoch_look_back=5,
          max_epoch=100,
          percent_decrease=0,
          batch_size=64,
          learning_rate=0.001,
          weight_regularize=True,
          save_dir=None,
          restore=False):

    if save_dir:
        logdir = '{}/log'.format(save_dir)
        if not os.path.exists(logdir):
            os.makedirs(logdir)
        model_dir = "{}/model".format(save_dir)
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)

    train_tf, n_train, valid_tf, n_valid = data(create_tfrecords=True,
                                                batch_size=batch_size)

    y_train_sb = model._train_fprop(train_tf['X'])
    y_valid_sb = model._test_fprop(valid_tf['X'])

    loss_train_sb = tg.cost.mse(y_train_sb, train_tf['y'])

    if weight_regularize:
        loss_reg = tc.layers.apply_regularization(
            tc.layers.l2_regularizer(2.5e-5),
            weights_list=[
                var for var in tf.global_variables()
                if __MODEL_VARSCOPE__ in var.name
            ])
        loss_train_sb = loss_train_sb + loss_reg

    accu_train_sb = tg.cost.accuracy(y_train_sb, train_tf['y'])
    accu_valid_sb = tg.cost.accuracy(y_valid_sb, valid_tf['y'])

    tf.summary.scalar('train', accu_train_sb)

    if save_dir:
        sav_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope=__MODEL_VARSCOPE__ +
                                     '/TemplateModel')
        saver = tf.train.Saver(sav_vars)

    # opt = tf.train.RMSPropOptimizer(learning_rate)
    opt = tf.train.AdamOptimizer(learning_rate)
    # opt = hvd.DistributedOptimizer(opt)

    # required for BatchNormalization layer
    update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
    with ops.control_dependencies(update_ops):
        train_op = opt.minimize(loss_train_sb)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    # bcast = hvd.broadcast_global_variables(0)

    # Pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # config.gpu_options.visible_device_list = str(hvd.local_rank())

    with tf.Session(config=config) as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        sess.run(init_op)
        if restore:
            logger.info('restoring model')
            saver.restore(sess, restore)
        # guard on save_dir, since logdir is only defined when save_dir is given
        if save_dir:
            train_writer = tf.summary.FileWriter('{}/train'.format(logdir),
                                                 sess.graph)
        # bcast.run()
        # merge = tf.summary.merge_all()
        es = tg.EarlyStopper(max_epoch, epoch_look_back, percent_decrease)
        epoch = 0
        best_valid_accu = 0
        while True:
            epoch += 1

            pbar = tg.ProgressBar(n_train)
            ttl_train_loss = 0
            for i in range(0, n_train, batch_size):
                pbar.update(i)
                _, loss_train = sess.run([train_op, loss_train_sb])
                # _, loss_train, merge_v = sess.run([train_op, loss_train_sb, merge])
                ttl_train_loss += loss_train * batch_size
                # train_writer.add_summary(merge_v, i)
            pbar.update(n_train)
            ttl_train_loss /= n_train
            print('')
            logger.info('epoch {}, train loss {}'.format(
                epoch, ttl_train_loss))

            pbar = tg.ProgressBar(n_valid)
            ttl_valid_accu = 0
            for i in range(0, n_valid, batch_size):
                pbar.update(i)
                loss_accu = sess.run(accu_valid_sb)
                ttl_valid_accu += loss_accu * batch_size
            pbar.update(n_valid)
            ttl_valid_accu /= n_valid
            print('')
            logger.info('epoch {}, valid accuracy {}'.format(
                epoch, ttl_valid_accu))
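            # EarlyStopper minimises the value it is given, so the validation
            # accuracy is negated below to turn "higher is better" into a
            # quantity that should decrease.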
            if es.continue_learning(-ttl_valid_accu, epoch=epoch):
                logger.info('best epoch last update: {}'.format(
                    es.best_epoch_last_update))
                logger.info('best valid last update: {}'.format(
                    es.best_valid_last_update))

                if ttl_valid_accu > best_valid_accu:
                    best_valid_accu = ttl_valid_accu
                    if save_dir:
                        save_path = saver.save(sess, model_dir + '/model.tf')
                        print("Best model saved in file: %s" % save_path)

            else:
                logger.info('training done!')
                break

        coord.request_stop()
        coord.join(threads)
Example #5
def train():
    graph = tf.Graph()
    with graph.as_default():
        batch_size = 100
        nr_train, n_train, nr_test, n_test = cifar10(create_tfrecords=True, batch_size=batch_size)
        seq = cifar10_allcnn.model(nclass=10, h=32, w=32, c=3)

        y_train_sb = seq.train_fprop(nr_train['X'])
        y_test_sb = seq.test_fprop(nr_test['X'])

        loss_train_sb = tg.cost.mse(y_train_sb, nr_train['y'])
        accu_train_sb = tg.cost.accuracy(y_train_sb, nr_train['y'])
        accu_test_sb = tg.cost.accuracy(y_test_sb, nr_test['y'])

        opt = tf.train.RMSPropOptimizer(0.001)
        opt = hvd.DistributedOptimizer(opt)

        # required for BatchNormalization layer
        update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
        with ops.control_dependencies(update_ops):
            train_op = opt.minimize(loss_train_sb)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        bcast = hvd.broadcast_global_variables(0)

    # Pin GPU to be used to process local rank (one GPU per process)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = str(hvd.local_rank())

    with tf.Session(graph=graph, config=config) as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        sess.run(init_op)
        bcast.run()

        for epoch in range(100):
            pbar = tg.ProgressBar(n_train)
            ttl_train_loss = 0
            for i in range(0, n_train, batch_size):
                pbar.update(i)
                _, loss_train = sess.run([train_op, loss_train_sb])
                ttl_train_loss += loss_train * batch_size
            pbar.update(n_train)
            ttl_train_loss /= n_train
            print('epoch {}, train loss {}'.format(epoch, ttl_train_loss))

            pbar = tg.ProgressBar(n_test)
            ttl_test_loss = 0
            for i in range(0, n_test, batch_size):
                pbar.update(i)
                loss_test = sess.run(accu_test_sb)
                ttl_test_loss += loss_test * batch_size
            pbar.update(n_test)
            ttl_test_loss /= n_test
            print('epoch {}, test accuracy {}'.format(epoch, ttl_test_loss))


        coord.request_stop()
        coord.join(threads)
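
# --- Hedged note (assumption) ---
# The snippet relies on Horovod already being initialised before train() runs;
# hvd.local_rank() and hvd.DistributedOptimizer fail otherwise.  A typical
# entry point might look like this:
import horovod.tensorflow as hvd

if __name__ == '__main__':
    hvd.init()   # must be called before any other hvd.* call
    train()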
Example #6
def train(modelclass, dt=None):

    batchsize = 64
    gen_learning_rate = 0.001
    dis_learning_rate = 0.001
    bottleneck_dim = 300

    max_epoch = 1000
    epoch_look_back = 3
    percent_decrease = 0
    noise_factor = 0.05
    max_outputs = 10

    noise_type = 'normal'

    print('gen_learning_rate:', gen_learning_rate)
    print('dis_learning_rate:', dis_learning_rate)
    print('noise_factor:', noise_factor)
    print('noise_type:', noise_type)


    if dt is None:
        timestamp = tg.utils.ts()
    else:
        timestamp = dt
    save_path = './save/{}/model'.format(timestamp)
    logdir = './log/{}'.format(timestamp)

    #X_train, y_train, X_valid, y_valid = Cifar10()
    X_train, y_train, X_valid, y_valid = data_char()
    _, h, w, c = X_train.shape
    _, nclass = y_train.shape
    # c = 1
    # train_embed, test_embed = text_embed(ch_embed_dim, sent_len, word_len)    
    
    data_train = tg.SequentialIterator(X_train, y_train, batchsize=batchsize)
    data_valid = tg.SequentialIterator(X_valid, y_valid, batchsize=batchsize)
    # gan = AuGan(h, w, nclass, bottleneck_dim)
    gan = getattr(model, modelclass)(h, w, c, nclass, bottleneck_dim)

    y_ph, noise_ph, G_train_sb, G_test_sb, gen_var_list = gan.generator()
    real_ph, real_train, real_valid, fake_train, fake_valid, dis_var_list = gan.discriminator()

    print('..using model:', gan.__class__.__name__)

    print('Generator Variables')
    for var in gen_var_list:
        print(var.name)

    print('\nDiscriminator Variables')
    for var in dis_var_list:
        print(var.name)

    with gan.tf_graph.as_default():
        # X_oh = ph2onehot(X_ph)


        # train_mse = tf.reduce_mean((X_ph - G_train_s)**2)
        # valid_mse = tf.reduce_mean((X_ph - G_valid_s)**2)
        # gen_train_cost_sb = generator_cost(class_train_sb, judge_train_sb)
        # gen_valid_cost_sb = generator_cost(class_test_sb, judge_test_sb)
        gen_train_cost_sb = generator_cost(y_ph, real_train, fake_train)
        fake_clss, fake_judge = fake_train

        dis_train_cost_sb = discriminator_cost(y_ph, real_train, fake_train)
        # dis_train_cost_sb = discriminator_cost(class_train_sb, judge_train_sb)
        # dis_valid_cost_sb = disciminator_cost(class_test_sb, judge_test_sb)

        # gen_train_img = put_kernels_on_grid(G_train_sb, batchsize)
        #
        gen_train_sm = tf.summary.image('gen_train_img', G_train_sb, max_outputs=max_outputs)
        gen_train_mg = tf.summary.merge([gen_train_sm])

        gen_train_cost_sm = tf.summary.scalar('gen_cost', gen_train_cost_sb)
        dis_train_cost_sm = tf.summary.scalar('dis_cost', dis_train_cost_sb)
        cost_train_mg = tf.summary.merge([gen_train_cost_sm, dis_train_cost_sm])


        # gen_optimizer = tf.train.RMSPropOptimizer(gen_learning_rate).minimize(gen_train_cost_sb, var_list=gen_var_list)
        # dis_optimizer = tf.train.RMSPropOptimizer(dis_learning_rate).minimize(dis_train_cost_sb, var_list=dis_var_list)

        gen_optimizer = tf.train.AdamOptimizer(gen_learning_rate).minimize(gen_train_cost_sb, var_list=gen_var_list)
        dis_optimizer = tf.train.AdamOptimizer(dis_learning_rate).minimize(dis_train_cost_sb, var_list=dis_var_list)

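        # clip_D clamps every discriminator variable to [-0.01, 0.01]
        # (WGAN-style weight clipping); it is run together with dis_optimizer
        # inside the training loop below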
        clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in dis_var_list]



        init = tf.global_variables_initializer()
        gan.sess.run(init)
        es = tg.EarlyStopper(max_epoch=max_epoch,
                             epoch_look_back=epoch_look_back,
                             percent_decrease=percent_decrease)



        ttl_iter = 0
        error_writer = tf.summary.FileWriter(logdir + '/experiment', gan.sess.graph)

        img_writer = tf.summary.FileWriter('{}/orig_img'.format(logdir))
        orig_sm = tf.summary.image('orig_img', real_ph, max_outputs=max_outputs)
        # import pdb; pdb.set_trace()
        img_writer.add_summary(orig_sm.eval(session=gan.sess, feed_dict={real_ph:data_train[:100].data[0]}))
        img_writer.flush()
        img_writer.close()

        for epoch in range(1, max_epoch):
            print('epoch:', epoch)
            print('..training')
            print('..logdir', logdir)
            pbar = tg.ProgressBar(len(data_train))
            n_exp = 0
            ttl_mse = 0
            ttl_gen_cost = 0
            ttl_dis_cost = 0
            error_writer.reopen()
            for X_batch, y_batch in data_train:

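                # three discriminator updates (with weight clipping) for every
                # single generator update below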
                for i in range(3):
                    if noise_type == 'normal':
                        noise = np.random.normal(loc=0, scale=noise_factor, size=(len(X_batch), bottleneck_dim))
                    else:
                        noise = np.random.uniform(-1,1, size=(len(X_batch), bottleneck_dim)) * noise_factor

                    feed_dict = {noise_ph:noise, real_ph:X_batch, y_ph:y_batch}
                    gan.sess.run([dis_optimizer, clip_D], feed_dict=feed_dict)

                for i in range(1):
                    if noise_type == 'normal':
                        noise = np.random.normal(loc=0, scale=noise_factor, size=(len(X_batch), bottleneck_dim))
                    else:
                        noise = np.random.uniform(-1,1, size=(len(X_batch), bottleneck_dim)) * noise_factor

                    feed_dict = {noise_ph:noise, real_ph:X_batch, y_ph:y_batch}
                    gan.sess.run(gen_optimizer, feed_dict={noise_ph:noise, real_ph:X_batch, y_ph:y_batch})


                fake_judge_v, cost_train, gen_cost, dis_cost = gan.sess.run([fake_judge, cost_train_mg, gen_train_cost_sb, dis_train_cost_sb],
                                                               feed_dict=feed_dict)


                ttl_gen_cost += gen_cost * len(X_batch)
                ttl_dis_cost += dis_cost * len(X_batch)
                n_exp += len(X_batch)
                pbar.update(n_exp)
                error_writer.add_summary(cost_train, n_exp + ttl_iter)
                error_writer.flush()
            error_writer.close()


            ttl_iter += n_exp

            mean_gan_cost = ttl_gen_cost / n_exp
            mean_dis_cost = ttl_dis_cost / n_exp
            print('\nmean train gen cost:', mean_gan_cost)
            print('mean train dis cost:', mean_dis_cost)


            if save_path:
                # print('\n..saving best model to: {}'.format(save_path))
                dname = os.path.dirname(save_path)
                if not os.path.exists(dname):
                    os.makedirs(dname)
                print('saved to {}'.format(dname))
                # gan.save(save_path)

                for X_batch, y_batch in data_train:

                    if noise_type == 'normal':
                        noise = np.random.normal(loc=0, scale=noise_factor, size=(len(X_batch), bottleneck_dim))
                    else:
                        noise = np.random.uniform(-1,1, size=(len(X_batch), bottleneck_dim)) * noise_factor

                    feed_dict = {noise_ph:noise, real_ph:X_batch, y_ph:y_batch}
                    G_train, G_img = gan.sess.run([G_train_sb, gen_train_mg], feed_dict=feed_dict)
                    train_writer = tf.summary.FileWriter('{}/experiment/{}'.format(logdir, epoch))

                    train_writer.add_summary(G_img)

                    train_writer.flush()
                    train_writer.close()

                    break



    return save_path
Example #7
        saver = tf.train.Saver(var_list)

    # merged = tf.summary.merge_all()
    # train_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/train',
    #                                       sess.graph)
    # test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/test')
    # tf.global_variables_initializer().run()
    print('hvd rank:', hvd.rank())

    # print('Initialized')

    for epoch in range(500):
        if hvd.rank() == 0:
            train_writer = tf.summary.FileWriter(
                logdir + '/train/{}'.format(epoch), sess.graph)
        pbar = tg.ProgressBar(n_exp)
        ttl_d_loss = 0
        ttl_g_loss = 0
        for i in range(0, n_exp, batch_size):
            pbar.update(i)
            # _, loss_train = sess.run([train_op, loss_train_sb])
            # ttl_train_loss += loss_train
            # for i in range(3):
            # sess.run(d_clip)
            z_val = np.random.uniform(size=(batch_size, z_dim), low=-1, high=1)
            sess.run(d_loss_op, feed_dict={z_ph: z_val})
            for _ in range(3):  # avoid shadowing the outer batch index i
                z_val = np.random.uniform(size=(batch_size, z_dim),
                                          low=-1,
                                          high=1)
                # import pdb; pdb.set_trace()
Example #8
def train(modelclass, dt=None):

    batchsize = 64
    gen_learning_rate = 0.001
    dis_learning_rate = 0.001
    enc_learning_rate = 0.001
    bottleneck_dim = 300

    max_epoch = 100
    epoch_look_back = 3
    percent_decrease = 0
    noise_factor = 0.1  #  20170616_1459: 0.05   20170616_1951: 0.01    
    max_outputs = 10

    noise_type = 'normal'

    print('gen_learning_rate:', gen_learning_rate)
    print('dis_learning_rate:', dis_learning_rate)
    print('noise_factor:', noise_factor)
    print('noise_type:', noise_type)


    if dt is None:
        timestamp = tg.utils.ts()
    else:
        timestamp = dt
    save_path = './save/{}/model'.format(timestamp)
    logdir = './log/{}'.format(timestamp)

    X_train, y_train, X_valid, y_valid = Mnist()  
    # X_train, y_train, X_valid, y_valid = X_train[0:10000], y_train[0:10000], X_valid[0:10000], y_valid[0:10000]
    # 0617_1346: 0.05   #0619_1033: 0.01   0619_1528:0.1  0619_1944: 0.3
    # X_train, y_train, X_valid, y_valid = Cifar100()
    # X_train, y_train, X_valid, y_valid = Cifar10(contrast_normalize=False, whiten=False)

    _, h, w, c = X_train.shape
    _, nclass = y_train.shape

    data_train = tg.SequentialIterator(X_train, y_train, batchsize=batchsize)
    data_valid = tg.SequentialIterator(X_valid, y_valid, batchsize=batchsize)
    
    gan = getattr(model, modelclass)(h, w, c, nclass, bottleneck_dim)

    y_ph, noise_ph, G_train_sb, G_test_sb, gen_var_list, G_train_enc, G_test_enc, G_train_embed, G_test_embed = gan.generator()
    real_ph, real_train, real_valid, fake_train, fake_valid, dis_var_list = gan.discriminator()
    # real_ph, real_train, real_valid, fake_train, fake_valid, dis_var_list = gan.discriminator_allconv()

    print('..using model:', gan.__class__.__name__)

    print('Generator Variables')
    for var in gen_var_list:
        print(var.name)

    print('\nDiscriminator Variables')
    for var in dis_var_list:
        print(var.name)
    with gan.tf_graph.as_default():

        gen_train_cost_sb = generator_cost(y_ph, real_train, fake_train)
        fake_clss, fake_judge = fake_train

        dis_train_cost_sb = discriminator_cost(y_ph, real_train, fake_train)
        
        enc_train_cost_sb = encoder_cost(y_ph, G_train_enc)

        gen_train_sm = tf.summary.image('gen_train_img', G_train_sb, max_outputs=max_outputs)
        gen_train_mg = tf.summary.merge([gen_train_sm])

        gen_train_cost_sm = tf.summary.scalar('gen_cost', gen_train_cost_sb)
        dis_train_cost_sm = tf.summary.scalar('dis_cost', dis_train_cost_sb)
        enc_train_cost_sm = tf.summary.scalar('enc_cost', enc_train_cost_sb)
        cost_train_mg = tf.summary.merge([gen_train_cost_sm, dis_train_cost_sm, enc_train_cost_sm])

        gen_optimizer = tf.train.AdamOptimizer(gen_learning_rate).minimize(gen_train_cost_sb, var_list=gen_var_list)
        dis_optimizer = tf.train.AdamOptimizer(dis_learning_rate).minimize(dis_train_cost_sb, var_list=dis_var_list)
        enc_optimizer = tf.train.AdamOptimizer(enc_learning_rate).minimize(enc_train_cost_sb)

        clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in dis_var_list]
        
        # embedding_var = tf.Variable(tf.zeros([60000, 300]), trainable=False, name="embedding")
        # prepare projector config
        
        # summary_writer = tf.summary.FileWriter(logdir)
        # saver = tf.train.Saver([embedding_var])
            
        
        
        init = tf.global_variables_initializer()
        gan.sess.run(init)
        # es = tg.EarlyStopper(max_epoch=max_epoch,
        #                      epoch_look_back=epoch_look_back,
        #                      percent_decrease=percent_decrease)

        ttl_iter = 0
        error_writer = tf.summary.FileWriter(logdir + '/experiment', gan.sess.graph)
        

        img_writer = tf.summary.FileWriter('{}/orig_img'.format(logdir))
        orig_sm = tf.summary.image('orig_img', real_ph, max_outputs=max_outputs)
        img_writer.add_summary(orig_sm.eval(session=gan.sess, feed_dict={real_ph:data_train[:100].data[0]}))
        img_writer.flush()
        img_writer.close()
        
        #embed = gan.sess.graph.get_tensor_by_name('Generator/genc4')
        # Create metadata
        # embeddir = logdir 
        # if not os.path.exists(embeddir):
        #     os.makedirs(embeddir)
        # metadata_path = os.path.join(embeddir, 'metadata.tsv')
        
        temp_acc = []
        
        for epoch in range(1, max_epoch):
            print('epoch:', epoch)
            print('..training')
            print('..logdir', logdir)
            pbar = tg.ProgressBar(len(data_train))
            n_exp = 0
            ttl_mse = 0
            ttl_gen_cost = 0
            ttl_dis_cost = 0
            ttl_enc_cost = 0
            error_writer.reopen()
            
            if epoch == max_epoch-1:
                output = np.empty([0,300], 'float32')
                labels = np.empty([0,10], 'int32')
            
            # metadata = open(metadata_path, 'w')
            # metadata.write("Name\tLabels\n")

            for X_batch, y_batch in data_train:

                for i in range(3):
                    if noise_type == 'normal':
                        noise = np.random.normal(loc=0, scale=noise_factor, size=(len(X_batch), bottleneck_dim))
                    else:
                        noise = np.random.uniform(-1,1, size=(len(X_batch), bottleneck_dim)) * noise_factor

                    feed_dict = {noise_ph:noise, real_ph:X_batch, y_ph:y_batch}
                    gan.sess.run([dis_optimizer, clip_D], feed_dict=feed_dict)

                for i in range(1):
                    if noise_type == 'normal':
                        noise = np.random.normal(loc=0, scale=noise_factor, size=(len(X_batch), bottleneck_dim))
                    else:
                        noise = np.random.uniform(-1,1, size=(len(X_batch), bottleneck_dim)) * noise_factor

                    feed_dict = {noise_ph:noise, real_ph:X_batch, y_ph:y_batch}
                    gan.sess.run([enc_optimizer, gen_optimizer], feed_dict={noise_ph:noise, real_ph:X_batch, y_ph:y_batch})
                                
                fake_judge_v, cost_train, enc_cost, gen_cost, dis_cost = gan.sess.run(
                    [fake_judge, cost_train_mg, enc_train_cost_sb, gen_train_cost_sb, dis_train_cost_sb],
                    feed_dict=feed_dict)

                ttl_gen_cost += gen_cost * len(X_batch)
                ttl_dis_cost += dis_cost * len(X_batch)
                ttl_enc_cost += enc_cost * len(X_batch)
                n_exp += len(X_batch)
                pbar.update(n_exp)
                error_writer.add_summary(cost_train, n_exp + ttl_iter)
                error_writer.flush()
                
                if epoch == max_epoch-1:
                    results = gan.sess.run(G_train_embed, feed_dict = {real_ph:X_batch, y_ph:y_batch})
                    output = np.concatenate((output, results), axis = 0)
                    labels = np.concatenate((labels, y_batch), axis = 0)
                # import pdb; pdb.set_trace()
                # for x_row, y_row in zip(X_batch, y_batch):
                #    metadata.write('{}\t{}\n'.format(x_row, y_row))
            # metadata.close()
            error_writer.close()
            
            # import pdb; pdb.set_trace()
            # for ot in output:
            #     temp = tf.stack(ot, axis = 0)
            
            #embedding_var = tf.Variable(temp)
            
            # sess.run(tf.variables_initializer([embedding_var]))
            
            # saver.save(gan.sess, os.path.join(embeddir, 'model.ckpt'))
            
            # config = projector.ProjectorConfig()
            # embedding = config.embeddings.add()
            # embedding.tensor_name = embedding_var.name
            # embedding.metadata_path = metadata_path  
            # save embedding_var
            # projector.visualize_embeddings(summary_writer, config)
            
            ttl_iter += n_exp

            mean_gan_cost = ttl_gen_cost / n_exp
            mean_dis_cost = ttl_dis_cost / n_exp
            mean_enc_cost = ttl_enc_cost / n_exp
            print('\nmean train gen cost:', mean_gan_cost)
            print('mean train dis cost:', mean_dis_cost)
            print('enc train dis cost:', mean_enc_cost)
            lab = []
            
            if epoch == max_epoch-1:
                embeddir = './genData/3'
                if not os.path.exists(embeddir):
                    os.makedirs(embeddir)
                lab = np.nonzero(labels)[1]
                # join with a path separator so the files land inside embeddir
                np.save(os.path.join(embeddir, 'embed.npy'), output)
                np.save(os.path.join(embeddir, 'label.npy'), lab)
            
                       
            valid_error = 0
            valid_accuracy = 0
            ttl_examples = 0
            for X_batch, ys in data_valid:
                # feed the validation labels ys (the original fed the stale
                # training y_batch here)
                feed_dict = {real_ph: X_batch, y_ph: ys}

                valid_outs = gan.sess.run(G_test_enc, feed_dict=feed_dict)
                valid_error += total_mse([valid_outs], [ys])[0]
                valid_accuracy += total_accuracy([valid_outs], [ys])[0]
                ttl_examples += len(X_batch)

            temp_acc.append(valid_accuracy/float(ttl_examples))
            print('max accuracy is:\t', max(temp_acc))
        print('max accuracy is:\t', max(temp_acc))

    return save_path
Example #9
File: main.py  Project: Shirlly/GAN
def train(modelclass, dt=None):

    batchsize = 64
    gen_learning_rate = 0.001
    dis_learning_rate = 0.001
    bottleneck_dim = 300

    max_epoch = 100
    epoch_look_back = 3
    percent_decrease = 0
    noise_factor = 0.3  #  20170616_1459: 0.05   20170616_1951: 0.01
    max_outputs = 10

    noise_type = 'normal'

    print('gen_learning_rate:', gen_learning_rate)
    print('dis_learning_rate:', dis_learning_rate)
    print('noise_factor:', noise_factor)
    print('noise_type:', noise_type)

    if dt is None:
        timestamp = tg.utils.ts()
    else:
        timestamp = dt
    save_path = './save/{}/model'.format(timestamp)
    logdir = './log/{}'.format(timestamp)

    X_train, y_train, X_valid, y_valid = Mnist()
    # 0617_1346: 0.05   #0619_1033: 0.01   0619_1528:0.1  0619_1944: 0.3
    # X_train, y_train, X_valid, y_valid = Cifar100()
    # X_train, y_train, X_valid, y_valid = Cifar10(contrast_normalize=False, whiten=False)

    _, h, w, c = X_train.shape
    _, nclass = y_train.shape

    data_train = tg.SequentialIterator(X_train, y_train, batchsize=batchsize)
    data_valid = tg.SequentialIterator(X_valid, y_valid, batchsize=batchsize)

    gan = getattr(model, modelclass)(h, w, c, nclass, bottleneck_dim)

    y_ph, noise_ph, G_train_sb, G_test_sb, gen_var_list = gan.generator()
    real_ph, real_train, real_valid, fake_train, fake_valid, dis_var_list = gan.discriminator()
    # real_ph, real_train, real_valid, fake_train, fake_valid, dis_var_list = gan.discriminator_allconv()

    print('..using model:', gan.__class__.__name__)

    print('Generator Variables')
    for var in gen_var_list:
        print(var.name)

    print('\nDiscriminator Variables')
    for var in dis_var_list:
        print(var.name)

    with gan.tf_graph.as_default():
        gen_train_cost_sb = generator_cost(y_ph, real_train, fake_train)
        fake_clss, fake_judge = fake_train

        dis_train_cost_sb = discriminator_cost(y_ph, real_train, fake_train)
        gen_train_sm = tf.summary.image('gen_train_img',
                                        G_train_sb,
                                        max_outputs=max_outputs)
        gen_train_mg = tf.summary.merge([gen_train_sm])

        gen_train_cost_sm = tf.summary.scalar('gen_cost', gen_train_cost_sb)
        dis_train_cost_sm = tf.summary.scalar('dis_cost', dis_train_cost_sb)
        cost_train_mg = tf.summary.merge(
            [gen_train_cost_sm, dis_train_cost_sm])

        gen_optimizer = tf.train.AdamOptimizer(gen_learning_rate).minimize(
            gen_train_cost_sb, var_list=gen_var_list)
        dis_optimizer = tf.train.AdamOptimizer(dis_learning_rate).minimize(
            dis_train_cost_sb, var_list=dis_var_list)

        clip_D = [
            p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in dis_var_list
        ]

        init = tf.global_variables_initializer()
        gan.sess.run(init)
        es = tg.EarlyStopper(max_epoch=max_epoch,
                             epoch_look_back=epoch_look_back,
                             percent_decrease=percent_decrease)

        ttl_iter = 0
        error_writer = tf.summary.FileWriter(logdir + '/experiment',
                                             gan.sess.graph)

        img_writer = tf.summary.FileWriter('{}/orig_img'.format(logdir))
        orig_sm = tf.summary.image('orig_img',
                                   real_ph,
                                   max_outputs=max_outputs)
        img_writer.add_summary(
            orig_sm.eval(session=gan.sess,
                         feed_dict={real_ph: data_train[:100].data[0]}))
        img_writer.flush()
        img_writer.close()

        for epoch in range(1, max_epoch):
            print('epoch:', epoch)
            print('..training')
            print('..logdir', logdir)
            pbar = tg.ProgressBar(len(data_train))
            n_exp = 0
            ttl_mse = 0
            ttl_gen_cost = 0
            ttl_dis_cost = 0
            error_writer.reopen()
            batch_iter = 1
            for X_batch, y_batch in data_train:

                for i in range(3):
                    if noise_type == 'normal':
                        noise = np.random.normal(loc=0,
                                                 scale=noise_factor,
                                                 size=(len(X_batch),
                                                       bottleneck_dim))
                    else:
                        noise = np.random.uniform(
                            -1, 1,
                            size=(len(X_batch), bottleneck_dim)) * noise_factor

                    feed_dict = {
                        noise_ph: noise,
                        real_ph: X_batch,
                        y_ph: y_batch
                    }
                    gan.sess.run([dis_optimizer, clip_D], feed_dict=feed_dict)

                for i in range(1):
                    if noise_type == 'normal':
                        noise = np.random.normal(loc=0,
                                                 scale=noise_factor,
                                                 size=(len(X_batch),
                                                       bottleneck_dim))
                    else:
                        noise = np.random.uniform(
                            -1, 1,
                            size=(len(X_batch), bottleneck_dim)) * noise_factor

                    feed_dict = {
                        noise_ph: noise,
                        real_ph: X_batch,
                        y_ph: y_batch
                    }
                    gan.sess.run(gen_optimizer,
                                 feed_dict={
                                     noise_ph: noise,
                                     real_ph: X_batch,
                                     y_ph: y_batch
                                 })

                if batch_iter == 1:
                    G_train, G_img = gan.sess.run([G_train_sb, gen_train_mg],
                                                  feed_dict=feed_dict)
                    gen_writer = tf.summary.FileWriter(
                        '{}/generator/{}'.format(logdir, epoch))
                    gen_writer.add_summary(G_img)
                    gen_writer.flush()
                    gen_writer.close()
                    batch_iter = 0

                fake_judge_v, cost_train, gen_cost, dis_cost = gan.sess.run(
                    [
                        fake_judge, cost_train_mg, gen_train_cost_sb,
                        dis_train_cost_sb
                    ],
                    feed_dict=feed_dict)

                ttl_gen_cost += gen_cost * len(X_batch)
                ttl_dis_cost += dis_cost * len(X_batch)
                n_exp += len(X_batch)
                pbar.update(n_exp)
                error_writer.add_summary(cost_train, n_exp + ttl_iter)
                error_writer.flush()

            error_writer.close()

            ttl_iter += n_exp

            mean_gan_cost = ttl_gen_cost / n_exp
            mean_dis_cost = ttl_dis_cost / n_exp
            print('\nmean train gen cost:', mean_gan_cost)
            print('mean train dis cost:', mean_dis_cost)

            if save_path:
                # print('\n..saving best model to: {}'.format(save_path))
                dname = os.path.dirname(save_path)
                if not os.path.exists(dname):
                    os.makedirs(dname)
                print('saved to {}'.format(dname))
                # gan.save(save_path)

                for X_batch, y_batch in data_train:

                    if noise_type == 'normal':
                        noise = np.random.normal(loc=0,
                                                 scale=noise_factor,
                                                 size=(len(X_batch),
                                                       bottleneck_dim))
                    else:
                        noise = np.random.uniform(
                            -1, 1,
                            size=(len(X_batch), bottleneck_dim)) * noise_factor

                    feed_dict = {
                        noise_ph: noise,
                        real_ph: X_batch,
                        y_ph: y_batch
                    }
                    # print '---- Before ----'
                    # print '--Number of threads running ', threading.active_count()
                    G_train, G_img = gan.sess.run([G_train_sb, gen_train_mg],
                                                  feed_dict=feed_dict)
                    train_writer = tf.summary.FileWriter(
                        '{}/experiment/{}'.format(logdir, epoch))
                    # print '---- After ----'
                    # print '--Number of threads running ', threading.active_count()
                    train_writer.add_summary(G_img)
                    train_writer.flush()
                    train_writer.close()

                    break

    return save_path
Example #10
File: main_out_lab.py  Project: Shirlly/GAN
def train(modelclass, dt=None):

    batchsize = 64
    gen_learning_rate = 0.001
    dis_learning_rate = 0.001
    bottleneck_dim = 300

    max_epoch = 2
    epoch_look_back = 3
    percent_decrease = 0
    noise_factor = 0.1  #  20170616_1459: 0.05   20170616_1951: 0.01    
    max_outputs = 10

    noise_type = 'normal'

    print('gen_learning_rate:', gen_learning_rate)
    print('dis_learning_rate:', dis_learning_rate)
    print('noise_factor:', noise_factor)
    print('noise_type:', noise_type)


    if dt is None:
        timestamp = tg.utils.ts()
    else:
        timestamp = dt
    save_path = './save/{}/model'.format(timestamp)
    logdir = './log/{}'.format(timestamp)

    X_train, y_train, X_valid, y_valid = Mnist()  
    AuX_train = X_train
    Auy_train = y_train
    aux = np.empty((0, 28, 28, 1), 'float32')
    auy = np.empty((0, 10), 'int32')
    # 0617_1346: 0.05   #0619_1033: 0.01   0619_1528:0.1  0619_1944: 0.3
    # X_train, y_train, X_valid, y_valid = Cifar100()
    # X_train, y_train, X_valid, y_valid = Cifar10(contrast_normalize=False, whiten=False)

    _, h, w, c = X_train.shape
    _, nclass = y_train.shape

    data_train = tg.SequentialIterator(X_train, y_train, batchsize=batchsize)
    data_valid = tg.SequentialIterator(X_valid, y_valid, batchsize=batchsize)
    
    print('\n====== Before augment data size ', X_train.shape, ' ======\n')
    
    gan = getattr(model, modelclass)(h, w, c, nclass, bottleneck_dim)

    y_ph, noise_ph, G_train_sb, G_test_sb, gen_var_list = gan.generator()
    real_ph, real_train, real_valid, fake_train, fake_valid, dis_var_list = gan.discriminator()
    # real_ph, real_train, real_valid, fake_train, fake_valid, dis_var_list = gan.discriminator_allconv()

    print('..using model:', gan.__class__.__name__)

    print('Generator Variables')
    for var in gen_var_list:
        print(var.name)

    print('\nDiscriminator Variables')
    for var in dis_var_list:
        print(var.name)

    with gan.tf_graph.as_default():

        gen_train_cost_sb = generator_cost(y_ph, real_train, fake_train)
        fake_clss, fake_judge = fake_train

        dis_train_cost_sb = discriminator_cost(y_ph, real_train, fake_train)

        gen_train_sm = tf.summary.image('gen_train_img', G_train_sb, max_outputs=max_outputs)
        gen_train_mg = tf.summary.merge([gen_train_sm])

        gen_train_cost_sm = tf.summary.scalar('gen_cost', gen_train_cost_sb)
        dis_train_cost_sm = tf.summary.scalar('dis_cost', dis_train_cost_sb)
        cost_train_mg = tf.summary.merge([gen_train_cost_sm, dis_train_cost_sm])

        gen_optimizer = tf.train.AdamOptimizer(gen_learning_rate).minimize(gen_train_cost_sb, var_list=gen_var_list)
        dis_optimizer = tf.train.AdamOptimizer(dis_learning_rate).minimize(dis_train_cost_sb, var_list=dis_var_list)

        clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in dis_var_list]

        init = tf.global_variables_initializer()
        gan.sess.run(init)
        es = tg.EarlyStopper(max_epoch=max_epoch,
                             epoch_look_back=epoch_look_back,
                             percent_decrease=percent_decrease)

        ttl_iter = 0
        error_writer = tf.summary.FileWriter(logdir + '/experiment', gan.sess.graph)
        
        img_writer = tf.summary.FileWriter('{}/orig_img'.format(logdir))
        orig_sm = tf.summary.image('orig_img', real_ph, max_outputs=max_outputs)
        img_writer.add_summary(orig_sm.eval(session=gan.sess, feed_dict={real_ph:data_train[:100].data[0]}))
        img_writer.flush()
        img_writer.close()

        for epoch in range(1, max_epoch):
            print('epoch:', epoch)
            print('..training')
            print('..logdir', logdir)
            pbar = tg.ProgressBar(len(data_train))
            n_exp = 0
            ttl_mse = 0
            ttl_gen_cost = 0
            ttl_dis_cost = 0
            error_writer.reopen()
            for X_batch, y_batch in data_train:

                for i in range(3):
                    if noise_type == 'normal':
                        noise = np.random.normal(loc=0, scale=noise_factor, size=(len(X_batch), bottleneck_dim))
                    else:
                        noise = np.random.uniform(-1,1, size=(len(X_batch), bottleneck_dim)) * noise_factor

                    feed_dict = {noise_ph:noise, real_ph:X_batch, y_ph:y_batch}
                    gan.sess.run([dis_optimizer, clip_D], feed_dict=feed_dict)

                for i in range(1):
                    if noise_type == 'normal':
                        noise = np.random.normal(loc=0, scale=noise_factor, size=(len(X_batch), bottleneck_dim))
                    else:
                        noise = np.random.uniform(-1,1, size=(len(X_batch), bottleneck_dim)) * noise_factor

                    feed_dict = {noise_ph:noise, real_ph:X_batch, y_ph:y_batch}
                    gan.sess.run(gen_optimizer, feed_dict={noise_ph:noise, real_ph:X_batch, y_ph:y_batch})
                                
                fake_judge_v, cost_train, gen_cost, dis_cost = gan.sess.run([fake_judge, cost_train_mg, gen_train_cost_sb, dis_train_cost_sb],
                                                               feed_dict=feed_dict)

                ttl_gen_cost += gen_cost * len(X_batch)
                ttl_dis_cost += dis_cost * len(X_batch)
                n_exp += len(X_batch)
                pbar.update(n_exp)
                error_writer.add_summary(cost_train, n_exp + ttl_iter)
                error_writer.flush()
                
            error_writer.close()

            ttl_iter += n_exp

            mean_gan_cost = ttl_gen_cost / n_exp
            mean_dis_cost = ttl_dis_cost / n_exp
            print('\nmean train gen cost:', mean_gan_cost)
            print('mean train dis cost:', mean_dis_cost)


            if save_path and epoch == max_epoch-1:
                # print('\n..saving best model to: {}'.format(save_path))
                dname = os.path.dirname(save_path)
                if not os.path.exists(dname):
                    os.makedirs(dname)
                print('saved to {}'.format(dname))
                train_writer = tf.summary.FileWriter('{}/experiment/{}'.format(logdir, epoch))
                
                for X_batch, y_batch in data_train:
                    #import pdb; pdb.set_trace()

                    if noise_type == 'normal':
                        noise = np.random.normal(loc=0, scale=noise_factor, size=(len(X_batch), bottleneck_dim))
                    else:
                        noise = np.random.uniform(-1,1, size=(len(X_batch), bottleneck_dim)) * noise_factor

                    feed_dict = {noise_ph:noise, real_ph:X_batch, y_ph:y_batch}
                    G_train, G_img, fake_dis = gan.sess.run([G_train_sb, gen_train_mg, fake_train], feed_dict=feed_dict)
                    fake_class_dis, fake_judge_dis = fake_dis
                    idx = [i for i,v in enumerate(fake_judge_dis) if v>0.5]
                    aux = np.concatenate((aux, G_train[idx]), axis = 0)
                    auy = np.concatenate((auy, fake_class_dis[idx]), axis = 0)
                    AuX_train = np.concatenate((G_train, AuX_train), axis = 0)
                    Auy_train = np.concatenate((y_batch, Auy_train), axis = 0)
                    # temp_data = zip(G_img, y_batch)
                    # aug_data.append(temp_data)
                    train_writer.add_summary(G_img)                    
                    train_writer.flush()
                train_writer.close()
                xname = 'genx.npy'
                yname = 'geny.npy'
                np.save('{}/{}'.format(logdir, xname), aux)
                np.save('{}/{}'.format(logdir, yname), auy)
        
        print('\n====== Augment data size ', AuX_train.shape, ' ======\n')
        print('\n====== Augment label size ', Auy_train.shape, ' ======\n')
        

    return save_path, X_train, y_train, X_valid, y_valid, AuX_train, Auy_train, aux, auy
Example #11
def train():
    learning_rate = 0.001
    batchsize = 32

    max_epoch = 300
    es = tg.EarlyStopper(max_epoch=max_epoch,
                         epoch_look_back=3,
                         percent_decrease=0)

    seq = model()
    X_train, y_train, X_test, y_test = Mnist(flatten=False,
                                             onehot=True,
                                             binary=True,
                                             datadir='.')
    iter_train = tg.SequentialIterator(X_train, y_train, batchsize=batchsize)
    iter_test = tg.SequentialIterator(X_test, y_test, batchsize=batchsize)

    X_ph = tf.placeholder('float32', [None, 28, 28, 1])
    y_ph = tf.placeholder('float32', [None, 10])

    y_train_sb = seq.train_fprop(X_ph)
    y_test_sb = seq.test_fprop(X_ph)

    train_cost_sb = entropy(y_ph, y_train_sb)
    test_cost_sb = entropy(y_ph, y_test_sb)
    test_accu_sb = accuracy(y_ph, y_test_sb)

    # required for BatchNormalization layer
    optimizer = tf.train.AdamOptimizer(learning_rate)
    update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
    with ops.control_dependencies(update_ops):
        train_ops = optimizer.minimize(train_cost_sb)

    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)

        best_valid_accu = 0
        for epoch in range(max_epoch):
            print('epoch:', epoch)
            pbar = tg.ProgressBar(len(iter_train))
            ttl_train_cost = 0
            ttl_examples = 0
            print('..training')
            for X_batch, y_batch in iter_train:
                feed_dict = {X_ph: X_batch, y_ph: y_batch}
                _, train_cost = sess.run([train_ops, train_cost_sb],
                                         feed_dict=feed_dict)
                ttl_train_cost += len(X_batch) * train_cost
                ttl_examples += len(X_batch)
                pbar.update(ttl_examples)
            mean_train_cost = ttl_train_cost / float(ttl_examples)
            print('\ntrain cost', mean_train_cost)

            ttl_valid_cost = 0
            ttl_valid_accu = 0
            ttl_examples = 0
            pbar = tg.ProgressBar(len(iter_test))
            print('..validating')
            for X_batch, y_batch in iter_test:
                feed_dict = {X_ph: X_batch, y_ph: y_batch}
                valid_cost, valid_accu = sess.run([test_cost_sb, test_accu_sb],
                                                  feed_dict=feed_dict)
                ttl_valid_cost += len(X_batch) * valid_cost
                ttl_valid_accu += len(X_batch) * valid_accu
                ttl_examples += len(X_batch)
                pbar.update(ttl_examples)
            mean_valid_cost = ttl_valid_cost / float(ttl_examples)
            mean_valid_accu = ttl_valid_accu / float(ttl_examples)
            print('\nvalid cost', mean_valid_cost)
            print('valid accu', mean_valid_accu)
            if best_valid_accu < mean_valid_accu:
                best_valid_accu = mean_valid_accu

            if es.continue_learning(valid_error=mean_valid_cost, epoch=epoch):
                print('epoch', epoch)
                print('best epoch last update:', es.best_epoch_last_update)
                print('best valid last update:', es.best_valid_last_update)
                print('best valid accuracy:', best_valid_accu)
            else:
                print('training done!')
                break
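
# --- Hedged sketch (assumption) ---
# entropy() and accuracy() are imported from elsewhere and not shown in the
# snippet.  Assuming the model outputs unnormalised logits, they could be
# defined along these lines:
import tensorflow as tf

def entropy(y_true, y_logits):
    # mean softmax cross-entropy over the batch
    return tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_logits))

def accuracy(y_true, y_logits):
    # fraction of examples whose arg-max prediction matches the one-hot label
    correct = tf.equal(tf.argmax(y_logits, 1), tf.argmax(y_true, 1))
    return tf.reduce_mean(tf.cast(correct, 'float32'))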
Example #12
def train():
    learning_rate = 0.001
    batchsize = 64
    max_epoch = 300
    es = tg.EarlyStopper(max_epoch=max_epoch,
                         epoch_look_back=None,
                         percent_decrease=0)

    X_train, y_train, X_test, y_test = Cifar10(contrast_normalize=False,
                                               whiten=False)
    _, h, w, c = X_train.shape
    _, nclass = y_train.shape

    seq = model(nclass=nclass, h=h, w=w, c=c)
    iter_train = tg.SequentialIterator(X_train, y_train, batchsize=batchsize)
    iter_test = tg.SequentialIterator(X_test, y_test, batchsize=batchsize)

    X_ph = tf.placeholder('float32', [None, h, w, c])
    y_ph = tf.placeholder('float32', [None, nclass])

    y_train_sb = seq.train_fprop(X_ph)
    y_test_sb = seq.test_fprop(X_ph)

    train_cost_sb = entropy(y_ph, y_train_sb)
    test_cost_sb = entropy(y_ph, y_test_sb)
    test_accu_sb = accuracy(y_ph, y_test_sb)

    # required for BatchNormalization layer
    optimizer = tf.train.AdamOptimizer(learning_rate)
    update_ops = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
    with ops.control_dependencies(update_ops):
        train_ops = optimizer.minimize(train_cost_sb)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        init = tf.global_variables_initializer()
        sess.run(init)

        best_valid_accu = 0
        for epoch in range(max_epoch):
            print('epoch:', epoch)
            pbar = tg.ProgressBar(len(iter_train))
            ttl_train_cost = 0
            ttl_examples = 0
            print('..training')
            for X_batch, y_batch in iter_train:
                feed_dict = {X_ph: X_batch, y_ph: y_batch}
                _, train_cost = sess.run([train_ops, train_cost_sb],
                                         feed_dict=feed_dict)
                ttl_train_cost += len(X_batch) * train_cost
                ttl_examples += len(X_batch)
                pbar.update(ttl_examples)
            mean_train_cost = ttl_train_cost / float(ttl_examples)
            print('\ntrain cost', mean_train_cost)

            ttl_valid_cost = 0
            ttl_valid_accu = 0
            ttl_examples = 0
            pbar = tg.ProgressBar(len(iter_test))
            print('..validating')
            for X_batch, y_batch in iter_test:
                feed_dict = {X_ph: X_batch, y_ph: y_batch}
                valid_cost, valid_accu = sess.run([test_cost_sb, test_accu_sb],
                                                  feed_dict=feed_dict)
                ttl_valid_cost += len(X_batch) * valid_cost
                ttl_valid_accu += len(X_batch) * valid_accu
                ttl_examples += len(X_batch)
                pbar.update(ttl_examples)
            mean_valid_cost = ttl_valid_cost / float(ttl_examples)
            mean_valid_accu = ttl_valid_accu / float(ttl_examples)
            print('\nvalid cost', mean_valid_cost)
            print('valid accu', mean_valid_accu)
            if best_valid_accu < mean_valid_accu:
                best_valid_accu = mean_valid_accu

            if es.continue_learning(valid_error=mean_valid_cost, epoch=epoch):
                print('epoch', epoch)
                print('best epoch last update:', es.best_epoch_last_update)
                print('best valid last update:', es.best_valid_last_update)
                print('best valid accuracy:', best_valid_accu)
            else:
                print('training done!')
                break