Example no. 1
(score: 0)
def freeze(ckpt_path):
    """Freeze a trained CycleGAN generator checkpoint into a single .pb file.

    Restores variables from ``ckpt_path``, constant-folds the graph, and
    writes it as 'cyclegan.pb' next to the checkpoint.  The frozen graph
    exposes the nodes 'input' (1 x H x W x 3 float32) and 'output'.

    NOTE(review): reads generator hyper-parameters from a module-level
    ``args`` namespace -- confirm it is populated before this is called.

    Args:
        ckpt_path: Path to the TF1 checkpoint to restore.
    """
    # Lightweight options record consumed by generator_resnet.
    OPTIONS = namedtuple(
        'OPTIONS', 'batch_size image_size gf_dim df_dim output_c_dim')
    options = OPTIONS._make(
        (args.batch_size, args.fine_size, args.ngf, args.ndf, args.output_nc))

    # Batch is pinned to 1 for inference; spatial size stays dynamic.
    inp_content_image = tf.placeholder(tf.float32,
                                       shape=(1, None, None, 3),
                                       name='input')

    out_image = generator_resnet(inp_content_image,
                                 options,
                                 name='generatorA2B')
    # Named identity gives the frozen graph a stable output node name.
    out_image = tf.identity(out_image, name='output')

    init_op = tf.global_variables_initializer()

    restore_saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init_op)
        restore_saver.restore(sess, ckpt_path)
        # Replace every variable with a constant so the .pb is self-contained.
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, output_node_names=['output'])

        # BUGFIX: build the output path with os.path.join instead of raw
        # string concatenation so it is constructed portably.
        out_path = os.path.join(os.path.dirname(ckpt_path), 'cyclegan.pb')
        with open(out_path, 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())
Example no. 2
(score: 0)
    def __init__(self, args):
        """Configure the model from parsed command-line arguments.

        Copies hyper-parameters off ``args``, selects the generator /
        discriminator builders and the GAN criterion, builds the graph,
        and wires up Adam optimizers with rotating checkpoint managers.
        """
        # --- hyper-parameters lifted straight from the CLI namespace ---
        self.batch_size = args.batch_size
        self.image_width = args.image_width
        self.image_height = args.image_height
        self.input_c_dim = args.input_nc
        self.output_c_dim = args.output_nc
        self.L1_lambda = args.L1_lambda
        self.Lg_lambda = args.Lg_lambda
        self.dataset_dir = args.dataset_dir
        self.segment_class = args.segment_class
        # Reciprocal GAN-to-seg weighting; disabled (0) for non-positive ratios.
        self.alpha_recip = 1. / args.ratio_gan2seg if args.ratio_gan2seg > 0 else 0

        self.use_pix2pix = args.use_pix2pix

        # --- network builders (construction order kept deliberately:
        # default discriminator is created first, even when pix2pix
        # later replaces it) ---
        self.discriminator = discriminator()
        if args.use_resnet:
            self.generator = generator_resnet()
        elif args.use_pix2pix:
            # pix2pix swaps in its matching discriminator as well.
            self.generator = generator_pix2pix()
            self.discriminator = discriminator_pix2pix()
        else:
            self.generator = generator_unet()

        # Least-squares GAN -> MAE loss; vanilla GAN -> sigmoid CE loss.
        self.criterionGAN = mae_criterion if args.use_lsgan else sce_criterion

        # tf.keras.utils.plot_model(self.discriminator, 'multi_input_and_output_model.png', show_shapes=True)
        # input("")

        # Immutable bag of options handed to the graph-building helpers.
        option_fields = ('batch_size image_height image_width '
                         'gf_dim df_dim output_c_dim is_training segment_class')
        Options = namedtuple('OPTIONS', option_fields)
        self.options = Options(args.batch_size, args.image_height,
                               args.image_width, args.ngf, args.ndf,
                               args.output_nc, args.phase == 'train',
                               args.segment_class)

        self._build_model(args)
        self.pool = ImagePool(args.max_size)

        #### [ADDED] CHECKPOINT MANAGER
        # Optimizers plus checkpoints that keep only the 3 newest snapshots.
        self.lr = 0.001
        self.d_optim = tf.keras.optimizers.Adam(learning_rate=self.lr,
                                                beta_1=args.beta1)
        self.g_optim = tf.keras.optimizers.Adam(learning_rate=self.lr,
                                                beta_1=args.beta1)

        self.gen_ckpt = tf.train.Checkpoint(optimizer=self.g_optim,
                                            net=self.generator)
        self.disc_ckpt = tf.train.Checkpoint(optimizer=self.d_optim,
                                             net=self.discriminator)
        self.gen_ckpt_manager = tf.train.CheckpointManager(
            self.gen_ckpt, './checkpoint/gta/gen_ckpts', max_to_keep=3)
        self.disc_ckpt_manager = tf.train.CheckpointManager(
            self.disc_ckpt, './checkpoint/gta/disc_ckpts', max_to_keep=3)
def main(args=None):
    """Build and run the A->B segmentation GAN pipeline.

    Depending on ``flags.mode`` this either trains the Generator_A2B /
    Discriminator_B pair with a WGAN-GP objective, or runs inference over
    the validation set, writing .mat predictions and PNG visualizations.

    Args:
        args: ``sys.argv``-style list; ``args[0]`` (the script path) is
            parsed to derive the network name.  NOTE(review): defaults to
            None yet ``args[0]`` is accessed unconditionally, so calling
            ``main()`` with no argument raises -- confirm intended.
    """
    print(args)
    tf.reset_default_graph()
    """
    Read dataset parser     
    """
    # Derive the network name from the entry-point filename,
    # e.g. 'path/main_foo.py' -> 'foo', and point the log dir at it.
    flags.network_name = args[0].split('/')[-1].split('.')[0].split(
        'main_')[-1]
    flags.logs_dir = './logs_' + flags.network_name
    dataset_parser = GANParser(flags=flags)
    """
    Transform data to TFRecord format (Only do once.)     
    """
    # One-time raw-data -> TFRecord conversion, permanently disabled with a
    # literal False; flip by hand when the records need regenerating.
    if False:
        dataset_parser.load_paths(is_jpg=True, load_val=True)
        dataset_parser.data2record(name='{}_train.tfrecords'.format(
            dataset_parser.dataset_name),
                                   set_type='train',
                                   test_num=None)
        dataset_parser.data2record(name='{}_val.tfrecords'.format(
            dataset_parser.dataset_name),
                                   set_type='val',
                                   test_num=None)
        # coco_parser.data2record_test(name='coco_stuff2017_test-dev_all_label.tfrecords', is_dev=True, test_num=None)
        # coco_parser.data2record_test(name='coco_stuff2017_test_all_label.tfrecords', is_dev=False, test_num=None)
        return
    """
    Build Graph
    """
    with tf.Graph().as_default():
        """
        Input (TFRecord)
        """
        with tf.name_scope('TFRecord'):
            # DatasetA
            training_a_dataset = dataset_parser.tfrecord_get_dataset(
                name='{}_trainA.tfrecords'.format(dataset_parser.dataset_name),
                batch_size=flags.batch_size,
                shuffle_size=None)
            val_a_dataset = dataset_parser.tfrecord_get_dataset(
                name='{}_valA.tfrecords'.format(dataset_parser.dataset_name),
                batch_size=flags.batch_size,
                need_flip=(flags.mode == 'train'))
            # DatasetB
            training_b_dataset = dataset_parser.tfrecord_get_dataset(
                name='{}_trainB.tfrecords'.format(dataset_parser.dataset_name),
                batch_size=flags.batch_size,
                shuffle_size=None)
            val_b_dataset = dataset_parser.tfrecord_get_dataset(
                name='{}_valB.tfrecords'.format(dataset_parser.dataset_name),
                batch_size=flags.batch_size,
                is_label=True,
                need_flip=(flags.mode == 'train'))
            # A feed-able iterator: the string handle fed at run time picks
            # whether training or validation records flow through the graph.
            with tf.name_scope('RealA'):
                handle_a = tf.placeholder(tf.string, shape=[])
                iterator_a = tf.contrib.data.Iterator.from_string_handle(
                    handle_a, training_a_dataset.output_types,
                    training_a_dataset.output_shapes)
                real_a, real_a_name, real_a_shape = iterator_a.get_next()
            with tf.name_scope('RealB'):
                handle_b = tf.placeholder(tf.string, shape=[])
                iterator_b = tf.contrib.data.Iterator.from_string_handle(
                    handle_b, training_b_dataset.output_types,
                    training_b_dataset.output_shapes)
                real_b, real_b_name, real_b_shape = iterator_b.get_next()
            with tf.name_scope('InitialA_op'):
                training_a_iterator = training_a_dataset.make_initializable_iterator(
                )
                validation_a_iterator = val_a_dataset.make_initializable_iterator(
                )
            with tf.name_scope('InitialB_op'):
                training_b_iterator = training_b_dataset.make_initializable_iterator(
                )
                validation_b_iterator = val_b_dataset.make_initializable_iterator(
                )
        """
        Network (Computes predictions from the inference model)
        """
        with tf.name_scope('Network'):
            # Input
            # global_step is advanced manually via global_step_update_op,
            # not by an optimizer.
            global_step = tf.Variable(0,
                                      trainable=False,
                                      name='global_step',
                                      dtype=tf.int32)
            global_step_update_op = tf.assign_add(global_step,
                                                  1,
                                                  name='global_step_update_op')
            # mean_rgb = tf.constant((123.68, 116.78, 103.94), dtype=tf.float32)
            # Placeholder for history-pooled fake B images fed back to the
            # discriminator (classic GAN image-pool trick).
            fake_b_pool = tf.placeholder(tf.float32,
                                         shape=[
                                             None, flags.image_height,
                                             flags.image_width, flags.c_in_dim
                                         ],
                                         name='fake_B_pool')
            # H * W * C, used to flatten images before the discriminator.
            image_linear_shape = tf.constant(
                flags.image_height * flags.image_width * flags.c_in_dim,
                dtype=tf.int32,
                name='image_linear_shape')

            # A -> B
            '''
            with tf.name_scope('Generator'):
                with slim.arg_scope(vgg.vgg_arg_scope()):
                    net, end_points = vgg.vgg_16(real_a - mean_rgb, num_classes=1, is_training=True, 
                    spatial_squeeze=False)
                    print(net)
                    return

                with tf.variable_scope('Generator_A2B'):
                    pred = tf.layers.conv2d(tf.nn.relu(net), 1, 1, 1)
                    pred_upscale = tf.image.resize_bilinear(pred, (flags.image_height, flags.image_width), 
                    name='up_scale')
                    segment_a = tf.nn.sigmoid(pred_upscale, name='segment_a')

            # sigmoid cross entropy Loss
            with tf.name_scope('loss_gen_a2b'):
                loss_gen_a2b = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=pred_upscale, labels=real_b/255.0, name='sigmoid'), name='mean')
            '''

            # A -> B
            # adjusted_a = tf.zeros_like(real_a, tf.float32, name='mask', optimize=True)
            # adjusted_a = high_light(real_a, name='high_light')
            # Blurred copy of A (11x11 average pool, stride 1) used as the
            # background layer when compositing the fake B image.
            adjusted_a = tf.layers.average_pooling2d(real_a,
                                                     11,
                                                     strides=1,
                                                     padding='same',
                                                     name='adjusted_a')
            logits_a = generator_resnet(real_a,
                                        flags,
                                        False,
                                        name="Generator_A2B")
            # tanh squashes the generator logits into [-1, 1] as a soft mask.
            segment_a = tf.nn.tanh(logits_a, name='segment_a')

            # Resize the mask logits back to the original image resolution.
            # NOTE(review): height comes from real_a_shape but width from
            # real_b_shape -- looks like a copy/paste slip; confirm whether
            # both should be taken from real_a_shape.
            logits_a_ori = tf.image.resize_bilinear(
                logits_a, (real_a_shape[0][0], real_b_shape[0][1]),
                name='logits_a_ori')
            segment_a_ori = tf.nn.tanh(logits_a_ori, name='segment_a_ori')

            # fake_b = mask * real_a + (1 - mask) * blurred_a, clipped to
            # valid 8-bit pixel range.
            with tf.variable_scope('Fake_B'):
                foreground = tf.multiply(real_a, segment_a, name='foreground')
                background = tf.multiply(adjusted_a, (1 - segment_a),
                                         name='background')
                fake_b_logits = tf.add(foreground,
                                       background,
                                       name='fake_b_logits')
                fake_b = tf.clip_by_value(fake_b_logits, 0, 255, name='fake_b')

            #
            # Flatten images to vectors for the fully-connected WGAN-GP
            # discriminator; reuse=True shares one set of weights.
            fake_b_f = tf.reshape(fake_b, [-1, image_linear_shape],
                                  name='fake_b_f')
            fake_b_pool_f = tf.reshape(fake_b_pool, [-1, image_linear_shape],
                                       name='fake_b_pool_f')
            real_b_f = tf.reshape(real_b, [-1, image_linear_shape],
                                  name='real_b_f')
            dis_fake_b = discriminator_se_wgangp(fake_b_f,
                                                 flags,
                                                 reuse=False,
                                                 name="Discriminator_B")
            dis_fake_b_pool = discriminator_se_wgangp(fake_b_pool_f,
                                                      flags,
                                                      reuse=True,
                                                      name="Discriminator_B")
            dis_real_b = discriminator_se_wgangp(real_b_f,
                                                 flags,
                                                 reuse=True,
                                                 name="Discriminator_B")

            # WGAN Loss
            # Generator maximizes the critic score of its fresh fakes.
            with tf.name_scope('loss_gen_a2b'):
                loss_gen_a2b = -tf.reduce_mean(dis_fake_b)

            # Critic: E[D(fake_pool)] - E[D(real)] plus gradient penalty.
            with tf.name_scope('loss_dis_b'):
                loss_dis_b_adv_real = -tf.reduce_mean(dis_real_b)
                loss_dis_b_adv_fake = tf.reduce_mean(dis_fake_b_pool)
                loss_dis_b = tf.reduce_mean(dis_fake_b_pool) - tf.reduce_mean(
                    dis_real_b)
                with tf.name_scope('wgan-gp'):
                    # One mixing coefficient per sample; images are already
                    # flattened, hence shape [batch, 1].
                    alpha = tf.random_uniform(shape=[flags.batch_size, 1],
                                              minval=0.,
                                              maxval=1.)
                    differences = fake_b_pool_f - real_b_f
                    interpolates = real_b_f + (alpha * differences)
                    gradients = tf.gradients(
                        discriminator_se_wgangp(interpolates,
                                                flags,
                                                reuse=True,
                                                name="Discriminator_B"),
                        [interpolates])[0]
                    # reduction_indices is the deprecated alias of axis.
                    slopes = tf.sqrt(
                        tf.reduce_sum(tf.square(gradients),
                                      reduction_indices=[1]))
                    # Penalize deviation of the critic gradient norm from 1
                    # (two-sided WGAN-GP penalty).
                    gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
                    loss_dis_b += flags.lambda_gp * gradient_penalty

            # Optimizer
            '''
            trainable_var_resnet = tf.get_collection(
                key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_16')
            trainable_var_gen_a2b = tf.get_collection(
                key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator_A2B') + trainable_var_resnet
            slim.model_analyzer.analyze_vars(trainable_var_gen_a2b, print_info=True)
            '''
            # Each train op only updates its own network's variables.
            trainable_var_gen_a2b = tf.get_collection(
                key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator_A2B')
            trainable_var_dis_b = tf.get_collection(
                key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator_B')
            # Linear learning-rate decay from flags.learning_rate to 0 over
            # training_iter steps, floored at 0.
            with tf.name_scope('learning_rate_decay'):
                decay = tf.maximum(
                    0.,
                    1. -
                    (tf.cast(global_step, tf.float32) / flags.training_iter),
                    name='decay')
                learning_rate = tf.multiply(flags.learning_rate,
                                            decay,
                                            name='learning_rate')
            train_op_gen_a2b = train_op(loss_gen_a2b,
                                        learning_rate,
                                        flags,
                                        trainable_var_gen_a2b,
                                        name='gen_a2b')
            train_op_dis_b = train_op(loss_dis_b,
                                      learning_rate,
                                      flags,
                                      trainable_var_dis_b,
                                      name='dis_b')

        saver = tf.train.Saver(max_to_keep=2)
        # Graph Logs
        with tf.name_scope('GEN_a2b'):
            tf.summary.scalar("loss/gen_a2b/all", loss_gen_a2b)
        with tf.name_scope('DIS_b'):
            tf.summary.scalar("loss/dis_b/all", loss_dis_b)
            tf.summary.scalar("loss/dis_b/adv_real", loss_dis_b_adv_real)
            tf.summary.scalar("loss/dis_b/adv_fake", loss_dis_b_adv_fake)
        summary_op = tf.summary.merge_all()
        """
        Session
        """
        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        tfconfig.gpu_options.allow_growth = True
        with tf.Session(config=tfconfig) as sess:
            # Resume from the latest checkpoint if one exists, otherwise
            # initialize all variables from scratch.
            with tf.name_scope('Initial'):
                ckpt = tf.train.get_checkpoint_state(
                    dataset_parser.checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    print("Model restored: {}".format(
                        ckpt.model_checkpoint_path))
                    saver.restore(sess, ckpt.model_checkpoint_path)
                else:
                    print("No Model found.")
                    init_op = tf.group(tf.global_variables_initializer(),
                                       tf.local_variables_initializer())
                    sess.run(init_op)

                    # init_fn = slim.assign_from_checkpoint_fn('./pretrained/vgg_16.ckpt',
                    #                                          slim.get_model_variables('vgg_16'))
                    # init_fn(sess)
                summary_writer = tf.summary.FileWriter(dataset_parser.logs_dir,
                                                       sess.graph)
            """
            Training Mode
            """
            if flags.mode == 'train':
                print('Training mode! Batch size:{:d}'.format(
                    flags.batch_size))
                with tf.variable_scope('Input_port'):
                    training_a_handle = sess.run(
                        training_a_iterator.string_handle())
                    training_b_handle = sess.run(
                        training_b_iterator.string_handle())
                    # val_a_handle = sess.run(validation_a_iterator.string_handle())
                    # val_b_handle = sess.run(validation_b_iterator.string_handle())
                    image_pool_a, image_pool_b = ImagePool(
                        flags.pool_size), ImagePool(flags.pool_size)

                print('Start Training!')
                start_time = time.time()
                sess.run([
                    training_a_iterator.initializer,
                    training_b_iterator.initializer
                ])
                feed_dict_train = {
                    handle_a: training_a_handle,
                    handle_b: training_b_handle
                }
                # feed_dict_valid = {is_training: False}
                global_step_sess = sess.run(global_step)
                while global_step_sess < flags.training_iter:
                    try:
                        # Update gen_A2B, gen_B2A
                        # One generator step; also fetch fake_b so it can be
                        # pushed into the history pool below.
                        _, fake_b_sess = sess.run([train_op_gen_a2b, fake_b],
                                                  feed_dict=feed_dict_train)

                        # _, loss_gen_a2b_sess = sess.run([train_op_gen_a2b, loss_gen_a2b], feed_dict=feed_dict_train)

                        # Update dis_B, dis_A
                        # The critic trains on pooled (historical) fakes
                        # rather than only the freshest generator output.
                        fake_b_pool_query = image_pool_b.query(fake_b_sess)
                        _ = sess.run(train_op_dis_b,
                                     feed_dict={
                                         fake_b_pool: fake_b_pool_query,
                                         handle_b: training_b_handle
                                     })

                        sess.run(global_step_update_op)
                        global_step_sess, learning_rate_sess = sess.run(
                            [global_step, learning_rate])
                        print(
                            'global step:[{:d}/{:d}], learning rate:{:f}, time:{:4.4f}'
                            .format(global_step_sess, flags.training_iter,
                                    learning_rate_sess,
                                    time.time() - start_time))

                        # Logging the events
                        # '% log_freq == 1' fires on the first step after
                        # each log_freq boundary.
                        if global_step_sess % flags.log_freq == 1:
                            print('Logging the events')
                            summary_op_sess = sess.run(summary_op,
                                                       feed_dict={
                                                           handle_a:
                                                           training_a_handle,
                                                           handle_b:
                                                           training_b_handle,
                                                           fake_b_pool:
                                                           fake_b_pool_query
                                                       })
                            summary_writer.add_summary(summary_op_sess,
                                                       global_step_sess)
                            # summary_writer.flush()

                        # Observe training situation (For debugging.)
                        if flags.debug and global_step_sess % flags.observe_freq == 1:
                            real_a_sess, real_b_sess, adjusted_a_sess, segment_a_sess, fake_b_sess, \
                                real_a_name_sess, real_b_name_sess = \
                                sess.run([real_a, real_b, adjusted_a, segment_a, fake_b,
                                          real_a_name, real_b_name],
                                         feed_dict={handle_a: training_a_handle, handle_b: training_b_handle})
                            print('Logging training images.')
                            dataset_parser.visualize_data(
                                real_a=real_a_sess,
                                real_b=real_b_sess,
                                adjusted_a=adjusted_a_sess,
                                segment_a=segment_a_sess,
                                fake_b=fake_b_sess,
                                shape=(1, 1),
                                global_step=global_step_sess,
                                logs_dir=dataset_parser.logs_image_train_dir,
                                real_a_name=real_a_name_sess[0].decode(),
                                real_b_name=real_b_name_sess[0].decode())
                        """
                        Saving the checkpoint
                        """
                        if global_step_sess % flags.save_freq == 0:
                            print('Saving model...')
                            saver.save(sess,
                                       dataset_parser.checkpoint_dir +
                                       '/model.ckpt',
                                       global_step=global_step_sess)

                    except tf.errors.OutOfRangeError:
                        # Dataset exhausted: re-initialize both iterators and
                        # keep training until training_iter steps are done.
                        print(
                            '----------------One epochs finished!----------------'
                        )
                        sess.run([
                            training_a_iterator.initializer,
                            training_b_iterator.initializer
                        ])
            elif flags.mode == 'test':
                from PIL import Image
                import scipy.ndimage.filters
                import scipy.io as sio
                import numpy as np
                print('Start Testing!')
                '''
                with tf.variable_scope('Input_port'):
                    val_a_handle = sess.run(validation_a_iterator.string_handle())
                    val_b_handle = sess.run(validation_b_iterator.string_handle())
                sess.run([validation_a_iterator.initializer, validation_b_iterator.initializer])
                '''
                with tf.variable_scope('Input_port'):
                    val_a_handle = sess.run(
                        validation_a_iterator.string_handle())
                    val_b_handle = sess.run(
                        validation_b_iterator.string_handle())
                sess.run([
                    validation_a_iterator.initializer,
                    validation_b_iterator.initializer
                ])
                feed_dict_test = {
                    handle_a: val_a_handle,
                    handle_b: val_b_handle
                }
                image_idx = 0
                # Iterate the validation set once; the OutOfRangeError at
                # the end of the dataset breaks the loop.
                while True:
                    try:
                        segment_a_ori_sess, real_a_name_sess, real_b_sess, real_a_sess, fake_b_sess = \
                            sess.run([segment_a_ori, real_a_name, real_b, real_a, fake_b], feed_dict=feed_dict_test)
                        # Map tanh output [-1, 1] -> [0, 255].
                        segment_a_np = (np.squeeze(segment_a_ori_sess) +
                                        1.0) * 127.5
                        binary_a = np.zeros_like(segment_a_np, dtype=np.uint8)

                        # binary_a[segment_a_np > 127.5] = 255
                        # Adaptive threshold: midpoint between the mean of
                        # the above-mean and below-mean intensities, applied
                        # to a Gaussian-blurred (sigma=11) prediction.
                        binary_mean = np.mean(segment_a_np)
                        binary_a_high = np.mean(
                            segment_a_np[segment_a_np > binary_mean])
                        binary_a_low = np.mean(
                            segment_a_np[segment_a_np < binary_mean])
                        binary_a_ave = (binary_a_high + binary_a_low) / 2.0
                        segment_a_np_blur = scipy.ndimage.filters.gaussian_filter(
                            segment_a_np, sigma=11)
                        binary_a[segment_a_np_blur > binary_a_ave] = 255

                        # Save raw prediction + binarized mask as .mat.
                        sio.savemat(
                            '{}/{}.mat'.format(
                                dataset_parser.logs_mat_output_dir,
                                real_a_name_sess[0].decode()), {
                                    'pred': segment_a_np,
                                    'binary': binary_a
                                })

                        # -----------------------------------------------------------------------------
                        # PNG dumps: input, prediction, binary mask, fake B,
                        # ground truth.  '% 1 == 0' saves every image.
                        if image_idx % 1 == 0:
                            real_a_sess = np.squeeze(real_a_sess)
                            x_png = Image.fromarray(
                                real_a_sess.astype(np.uint8))
                            x_png.save('{}/{}_0_img.png'.format(
                                dataset_parser.logs_image_val_dir,
                                real_a_name_sess[0].decode()),
                                       format='PNG')
                            x_png = Image.fromarray(
                                segment_a_np.astype(np.uint8))
                            x_png.save('{}/{}_1_pred.png'.format(
                                dataset_parser.logs_image_val_dir,
                                real_a_name_sess[0].decode()),
                                       format='PNG')
                            x_png = Image.fromarray(binary_a.astype(np.uint8))
                            x_png.save('{}/{}_2_binary.png'.format(
                                dataset_parser.logs_image_val_dir,
                                real_a_name_sess[0].decode()),
                                       format='PNG')
                            fake_b_sess = np.squeeze(fake_b_sess)
                            x_png = Image.fromarray(
                                fake_b_sess.astype(np.uint8))
                            x_png.save('{}/{}_3_fake.png'.format(
                                dataset_parser.logs_image_val_dir,
                                real_a_name_sess[0].decode()),
                                       format='PNG')
                            real_b_sess = np.squeeze(real_b_sess)
                            x_png = Image.fromarray(
                                real_b_sess.astype(np.uint8))
                            x_png.save('{}/{}_4_gt.png'.format(
                                dataset_parser.logs_image_val_dir,
                                real_a_name_sess[0].decode()),
                                       format='PNG')

                        print(image_idx)
                        image_idx += 1
                    except tf.errors.OutOfRangeError:
                        print(
                            '----------------One epochs finished!----------------'
                        )
                        break