Example #1
import os

import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API
from scipy import misc  # needs SciPy < 1.2, where misc.imsave still exists


def run(config, dataset, model, gpu):
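    # restrict this process to the requested GPU; must be set before the
    # session is created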
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    batch1, batch2, batch1_dst = dataset.input()

    # Checkpoint only the trainable variables at first; the batch-norm
    # statistics and RMSProp slot variables are collected separately and
    # added to the saver after the restore below.
    save_variables = tf.trainable_variables()
    g_list = tf.global_variables()
    bn_variables = [g for g in g_list if 'moving_mean' in g.name] \
                   + [g for g in g_list if 'moving_variance' in g.name] \
                   + [g for g in g_list if 'RMSProp' in g.name]
    saver = tf.train.Saver(var_list=save_variables)

    # image summary
    Ax_op = tf.summary.image('Ax', model.Ax, max_outputs=30)
    Be_op = tf.summary.image('Be', model.Be, max_outputs=30)
    Ax2_op = tf.summary.image('Ax2', model.Ax2, max_outputs=30)
    Be2_op = tf.summary.image('Be2', model.Be2, max_outputs=30)
    Bx_op = tf.summary.image('Bx', model.Bx, max_outputs=30)
    Ae_op = tf.summary.image('Ae', model.Ae, max_outputs=30)

    # G loss summary
    for key in model.G_loss.keys():
        tf.summary.scalar(key, model.G_loss[key])

    loss_G_nodecay_op = tf.summary.scalar('loss_G_nodecay',
                                          model.loss_G_nodecay)
    loss_G_decay_op = tf.summary.scalar('loss_G_decay', model.loss_G_decay)
    loss_G_op = tf.summary.scalar('loss_G', model.loss_G)

    # D loss summary
    for key in model.D_loss.keys():
        tf.summary.scalar(key, model.D_loss[key])

    loss_D_op = tf.summary.scalar('loss_D', model.loss_D)

    # learning rate summary
    g_lr_op = tf.summary.scalar('g_learning_rate', model.g_lr)
    d_lr_op = tf.summary.scalar('d_learning_rate', model.d_lr)

    merged_op = tf.summary.merge_all()

    # start training
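    # cap this process at 30% of GPU memory so the device can be shared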
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = 0.3
    sess = tf.Session(config=sess_config)
    sess.run(tf.global_variables_initializer())

    # Restore weights from the earlier run; those checkpoints apparently do
    # not contain the BN statistics or RMSProp slots, which is why the saver
    # above excludes them.
    ckpt = tf.train.get_checkpoint_state(
        config.model_dir.replace('train_log_saveBN', 'train_log'))
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        print('restored')

    # From here on, include the BN/RMSProp variables in new checkpoints.
    save_variables += bn_variables
    saver = tf.train.Saver(var_list=save_variables)

    # FileWriter(logdir, graph) already writes the graph definition
    writer = tf.summary.FileWriter(config.log_dir, sess.graph)

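    # launch the threads that feed the input queues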
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    def next_feed(step):
        # dequeue one fresh batch per input tensor and build the feed
        # for a single update
        ax, be, ae_dst = sess.run([batch1, batch2, batch1_dst])
        return {
            model.Ax: ax,
            model.Be: be,
            model.Ae_dst: ae_dst,
            model.g_lr: config.g_lr(epoch=step),
            model.d_lr: config.d_lr(epoch=step),
        }

    for i in range(config.max_iter):
        # run 100 extra discriminator updates every 500 iterations,
        # otherwise one per generator update
        d_num = 100 if i % 500 == 0 else 1

        # update D with weight clipping
        for j in range(d_num):
            _, loss_D_sum, _ = sess.run(
                [model.d_opt, model.loss_D, model.clip_d],
                feed_dict=next_feed(i))

        # update G
        _, loss_G_sum = sess.run([model.g_opt, model.loss_G],
                                 feed_dict=next_feed(i))

        print('iter: {:06d},   g_loss: {}    d_loss: {}'.format(
            i, loss_G_sum, loss_D_sum))

        if i % 20 == 0:
            merged_summary = sess.run(merged_op, feed_dict=next_feed(i))

            writer.add_summary(merged_summary, i)

        if i % 500 == 0:
            saver.save(
                sess,
                os.path.join(config.model_dir, 'model_{:06d}.ckpt'.format(i)))

            img_Ax, img_Be, img_Ae, img_Bx, img_Ax2, img_Be2 = sess.run(
                [model.Ax, model.Be, model.Ae, model.Bx, model.Ax2, model.Be2],
                feed_dict=next_feed(i))

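            # tile the six outputs side by side for the first five samples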
            for j in range(5):
                img = np.concatenate((img_Ax[j], img_Be[j], img_Ae[j],
                                      img_Bx[j], img_Ax2[j], img_Be2[j]),
                                     axis=1)
                misc.imsave(
                    os.path.join(config.sample_img_dir,
                                 'iter_{:06d}_{}.jpg'.format(i, j)), img)

    writer.close()
    saver.save(sess, os.path.join(config.model_dir, 'model.ckpt'))

    coord.request_stop()
    coord.join(threads)
Example #2
import os

import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API
from scipy import misc  # needs SciPy < 1.2 for misc.imsave


def run(config, dataset, model, gpu):
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    batchs, labels = dataset.input()

    saver = tf.train.Saver()

    # image summary: one op per feature for each of the six image tensors
    image_summary_ops = []
    for name, tensors in [('Ax', model.Axs), ('Be', model.Bes),
                          ('Ax2', model.Axs2), ('Be2', model.Bes2),
                          ('Ae', model.Aes), ('Bx', model.Bxs)]:
        image_summary_ops += [
            tf.summary.image('{}_{}'.format(name, i), tensors[i],
                             max_outputs=30)
            for i in range(model.n_feat)
        ]

    # G loss summary
    for key in model.G_loss.keys():
        tf.summary.scalar(key, model.G_loss[key])

    loss_G_nodecay_op = tf.summary.scalar('loss_G_nodecay',
                                          model.loss_G_nodecay)
    loss_G_decay_op = tf.summary.scalar('loss_G_decay', model.loss_G_decay)
    loss_G_op = tf.summary.scalar('loss_G', model.loss_G)

    # D loss summary
    for key in model.D_loss.keys():
        tf.summary.scalar(key, model.D_loss[key])

    loss_D_op = tf.summary.scalar('loss_D', model.loss_D)

    # learning rate summary
    g_lr_op = tf.summary.scalar('g_learning_rate', model.g_lr)
    d_lr_op = tf.summary.scalar('d_learning_rate', model.d_lr)

    merged_op = tf.summary.merge_all()

    # start training
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    ckpt = tf.train.get_checkpoint_state(config.model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)

    writer = tf.summary.FileWriter(config.log_dir, sess.graph)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(config.max_iter):
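        # run 100 extra discriminator updates every 500 iterations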
        d_num = 100 if i % 500 == 0 else 1

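        # inputs arrive interleaved: even indices feed the Ax placeholders,
        # odd indices the Be placeholders (labels likewise)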
        batch_images, batch_labels = sess.run([batchs, labels])
        feed_dict = {
            model.g_lr: config.g_lr(epoch=i),
            model.d_lr: config.d_lr(epoch=i),
        }
        for j in range(model.n_feat):
            feed_dict[model.Axs[j]] = batch_images[2 * j]
            feed_dict[model.Bes[j]] = batch_images[2 * j + 1]
            feed_dict[model.label_Axs[j]] = batch_labels[2 * j]
            feed_dict[model.label_Bes[j]] = batch_labels[2 * j + 1]

        # update D with clipping
        for j in range(d_num):
            _, loss_D_sum, _ = sess.run(
                [model.d_opt, model.loss_D, model.clip_d], feed_dict=feed_dict)

        # update G
        _, loss_G_sum = sess.run([model.g_opt, model.loss_G],
                                 feed_dict=feed_dict)

        print('iter: {:06d},   g_loss: {}    d_loss: {}'.format(
            i, loss_G_sum, loss_D_sum))

        if i % 20 == 0:
            merged_summary = sess.run(merged_op, feed_dict=feed_dict)
            writer.add_summary(merged_summary, i)

        if i % 500 == 0:
            saver.save(
                sess,
                os.path.join(config.model_dir, 'model_{:06d}.ckpt'.format(i)))

            img_Axs, img_Bes, img_Aes, img_Bxs, img_Axs2, img_Bes2 = sess.run(
                [
                    model.Axs, model.Bes, model.Aes, model.Bxs, model.Axs2,
                    model.Bes2
                ],
                feed_dict=feed_dict)

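            # tile the six outputs side by side for each feature, using the
            # first five samples of the batch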
            for k in range(model.n_feat):
                for j in range(5):
                    img = np.concatenate(
                        (img_Axs[k][j], img_Bes[k][j], img_Aes[k][j],
                         img_Bxs[k][j], img_Axs2[k][j], img_Bes2[k][j]),
                        axis=1)
                    misc.imsave(
                        os.path.join(
                            config.sample_img_dir,
                            'iter_{:06d}_{}_{}.jpg'.format(
                                i, j, model.feature_list[k])), img)

    writer.close()
    saver.save(sess, os.path.join(config.model_dir, 'model.ckpt'))

    coord.request_stop()
    coord.join(threads)