def build_model(self):
    """Construct the GAN graph: inputs, losses, optimizers, sampler, summaries.

    Sets ``self.inputs``, ``self.z``, ``self.d_loss``, ``self.g_loss``,
    ``self.d_optim``, ``self.g_optim``, ``self.fake_images``, ``self.d_sum``
    and ``self.g_sum``.
    """
    # ---- Graph input: real images ----
    if not self.custom_dataset:
        # Real images are supplied by the caller through a placeholder.
        real_shape = [self.batch_size, self.img_size, self.img_size, self.c_dim]
        self.inputs = tf.placeholder(tf.float32, real_shape, name='real_images')
    else:
        # Real images come from a tf.data pipeline built over self.data.
        preprocessor = ImageData(self.img_size, self.c_dim)
        pipeline = tf.data.Dataset.from_tensor_slices(self.data)

        target_device = '/gpu:0'
        pipeline = pipeline.apply(shuffle_and_repeat(self.dataset_num))
        pipeline = pipeline.apply(
            map_and_batch(preprocessor.image_processing,
                          self.batch_size,
                          num_parallel_batches=8,
                          drop_remainder=True))
        pipeline = pipeline.apply(
            prefetch_to_device(target_device, self.batch_size))

        self.inputs = pipeline.make_one_shot_iterator().get_next()

    # ---- Graph input: noise fed to the generator ----
    self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim], name='z')

    # ---- Loss functions ----
    # Discriminator on real images, then on generated ones (sharing weights).
    real_logits = self.discriminator(self.inputs)
    fake_images = self.generator(self.z)
    fake_logits = self.discriminator(fake_images, reuse=True)

    # A gradient penalty is only added for the wgan-* family and dragan.
    uses_penalty = 'wgan' in self.gan_type or self.gan_type == 'dragan'
    if uses_penalty:
        GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
    else:
        GP = 0

    self.d_loss = discriminator_loss(
        self.gan_type, real=real_logits, fake=fake_logits) + GP
    self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

    # ---- Training ops ----
    # Trainable variables are split between D and G by substring of their name.
    trainables = tf.trainable_variables()
    d_vars = [v for v in trainables if 'discriminator' in v.name]
    g_vars = [v for v in trainables if 'generator' in v.name]

    d_adam = tf.train.AdamOptimizer(
        self.d_learning_rate, beta1=self.beta1, beta2=self.beta2)
    self.d_optim = d_adam.minimize(self.d_loss, var_list=d_vars)

    g_adam = tf.train.AdamOptimizer(
        self.g_learning_rate, beta1=self.beta1, beta2=self.beta2)
    self.g_optim = g_adam.minimize(self.g_loss, var_list=g_vars)

    # ---- Testing: generator output with is_training=False, reusing weights ----
    self.fake_images = self.generator(self.z, is_training=False, reuse=True)

    # ---- Summaries ----
    self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
    self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
def _input_fn(channel):
    """Build a batched Dataset that reads from a SageMaker PipeMode channel.

    Args:
        channel: channel identifier handed to ``PipeModeDataset``.

    Returns:
        A dataset of ``({'data': <float64 tensor>}, labels)`` batches, sized by
        the module-level ``BATCH_SIZE``.
    """
    feature_spec = {
        'data': tf.FixedLenFeature([], tf.string),
        'labels': tf.FixedLenFeature([], tf.int64),
    }

    def _decode_example(serialized):
        # The 'data' feature holds raw bytes that are reinterpreted as float64.
        example = tf.parse_single_example(serialized, feature_spec)
        decoded = tf.decode_raw(example['data'], tf.float64)
        return ({'data': decoded}, example['labels'])

    dataset = PipeModeDataset(channel)
    # Only repeat when more than one epoch is requested.
    dataset = dataset.repeat(EPOCHS) if EPOCHS > 1 else dataset
    dataset = dataset.prefetch(PREFETCH_SIZE)

    batcher = map_and_batch(_decode_example,
                            batch_size=BATCH_SIZE,
                            num_parallel_batches=NUM_PARALLEL_BATCHES)
    return dataset.apply(batcher)
示例#3
0
    def input_fn(params):
        """Return the TFRecord-backed dataset for this input function.

        The batch size comes from ``params["batch_size"]`` when ``use_tpu`` is
        set and from the enclosing ``bsz`` otherwise.
        """
        batch_size = params["batch_size"] if use_tpu else bsz

        dataset = tf.data.TFRecordDataset(input_file)
        # Training repeats indefinitely and shuffles; evaluation reads in order.
        if is_training:
            dataset = dataset.repeat().shuffle(buffer_size=100)

        def _decode(record):
            return _decode_record(record, name_to_features)

        batcher = contrib_data.map_and_batch(_decode,
                                             batch_size=batch_size,
                                             drop_remainder=drop_remainder)
        return dataset.apply(batcher)
示例#4
0
    def input_fn():
        """Return an endlessly repeating, batched TFRecord dataset.

        The batch size is ``opts["micro_batch_size"]``; incomplete batches are
        dropped. Shuffling is applied only when ``is_training`` is set.
        """
        micro_batch_size = opts["micro_batch_size"]

        dataset = tf.data.TFRecordDataset(input_file)

        # Original author's note: not sure the repeat is in the right place;
        # "only for training". As written, repeat() below is applied
        # unconditionally, after batching.
        if is_training:
            dataset = dataset.shuffle(buffer_size=opts['buffer_size'])

        def _decode(record):
            return _decode_record(record, name_to_features)

        batched = dataset.apply(
            map_and_batch(_decode,
                          batch_size=micro_batch_size,
                          drop_remainder=True))
        return batched.repeat()
示例#5
0
def _input_fn(channel):
    """Create the input pipeline for one SageMaker PipeMode channel.

    Records are parsed into ``({"data": float64 tensor}, labels)`` pairs and
    batched with the module-level ``BATCH_SIZE``.
    """
    schema = {
        "data": tf.FixedLenFeature([], tf.string),
        "labels": tf.FixedLenFeature([], tf.int64),
    }

    def _to_features_and_label(raw_record):
        fields = tf.parse_single_example(raw_record, schema)
        # Raw bytes of the "data" feature are decoded as float64 values.
        payload = {"data": tf.decode_raw(fields["data"], tf.float64)}
        return (payload, fields["labels"])

    stream = PipeModeDataset(channel)
    if EPOCHS > 1:
        # A single epoch needs no repeat() at all.
        stream = stream.repeat(EPOCHS)
    stream = stream.prefetch(PREFETCH_SIZE)

    parse_and_batch = map_and_batch(
        _to_features_and_label,
        batch_size=BATCH_SIZE,
        num_parallel_batches=NUM_PARALLEL_BATCHES)
    stream = stream.apply(parse_and_batch)
    return stream
def _input_fn():
    """Build the PipeMode input pipeline described by the module-level ``config``.

    Reads ``config.channel`` with ``benchmark=True``, optionally repeats and
    prefetches, then parses and batches records into
    ``({'data': float64 tensor}, labels)`` pairs.
    """
    feature_spec = {
        'data': tf.FixedLenFeature([], tf.string),
        'labels': tf.FixedLenFeature([], tf.int64),
    }

    def _decode_example(serialized):
        example = tf.parse_single_example(serialized, feature_spec)
        decoded = tf.decode_raw(example['data'], tf.float64)
        return ({'data': decoded}, example['labels'])

    dataset = PipeModeDataset(config.channel, benchmark=True)

    # repeat() and prefetch() are both skipped when their setting makes them a no-op.
    if config.epochs > 1:
        dataset = dataset.repeat(config.epochs)
    if config.prefetch_size > 0:
        dataset = dataset.prefetch(config.prefetch_size)

    batcher = map_and_batch(
        _decode_example,
        batch_size=config.batch_size,
        num_parallel_batches=config.parallel_transform_calls)
    return dataset.apply(batcher)
    def input_fn(params):
        """Return a batched dataset read from the files matching ``input_file``.

        Files are listed in a fixed order and interleaved 32 at a time; the
        interleave is allowed to be non-deterministic ("sloppy") only when
        ``is_training`` is set. The batch size is ``params["batch_size"]``.
        """
        batch_size = params["batch_size"]

        file_names = tf.data.Dataset.list_files(input_file, shuffle=False)

        # Each listed file is opened as a TFRecordDataset with the configured
        # compression type.
        open_records = functools.partial(
            tf.data.TFRecordDataset,
            compression_type=FLAGS.compression_type)
        dataset = file_names.apply(
            contrib_data.parallel_interleave(open_records,
                                             cycle_length=32,
                                             sloppy=is_training))

        # Training repeats indefinitely and shuffles; evaluation does neither.
        if is_training:
            dataset = dataset.repeat().shuffle(buffer_size=100)

        def _decode(record):
            return _decode_record(record, name_to_features)

        batcher = contrib_data.map_and_batch(_decode,
                                             batch_size=batch_size,
                                             drop_remainder=drop_remainder)
        return dataset.apply(batcher)
示例#8
0
    def build_model(self):
        """Build the graph for the current phase.

        Train phase (``self.phase == 'train'``):
            One tower is built per device returned by
            ``self.get_available_gpus()``; variables are created by the first
            tower and reused by the others. Per-tower generator and
            discriminator losses are collected in ``self.Generator_loss_total``
            and ``self.Discriminator_loss_total`` and averaged into
            ``self.Generator_loss`` / ``self.Discriminator_loss``, which the
            Adam ops ``self.G_optim`` / ``self.D_optim`` minimize.

            Placeholders created: ``self.lr``, ``self.manual_d_loss`` and
            ``self.fool_discriminator_counter``. Whenever the fed counter is a
            multiple of 100, the real and fake logits passed to
            ``discriminator_loss`` are swapped.

            Merged summaries are stored in ``self.G_loss`` and ``self.D_loss``.
            The individual per-term scalars (``G_A_gan``, ``D_A_loss``, ...)
            are taken from the LAST tower only.

        Any other phase:
            Creates ``self.test_domain_A`` / ``self.test_domain_B``
            placeholders (batch of one image) and the translated outputs
            ``self.test_fake_B`` / ``self.test_fake_A``.

        Raises:
            ValueError: in the train phase when no GPU device is available.
        """
        if self.phase == 'train':
            self.lr = tf.placeholder(tf.float32, name='learning_rate')
            self.manual_d_loss = tf.placeholder(tf.float32,
                                                name='manual_d_loss')

            # ---- Input images ----
            Image_Data_Class = ImageData(self.img_size, self.img_ch,
                                         self.augment_flag)

            # Per-tower total losses; averaged once all towers are built.
            self.Discriminator_loss_total = []
            self.Generator_loss_total = []

            # TODO(jhhuang): we should change how we load data here
            trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
            trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

            # TODO(jhhuang): change GPU devices (prefetch_to_device is
            # currently not applied to either pipeline).
            trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(Image_Data_Class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True))
            trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(Image_Data_Class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True))

            trainA_iterator = trainA.make_one_shot_iterator()
            trainB_iterator = trainB.make_one_shot_iterator()

            # One placeholder shared by every tower. Creating it inside the
            # tower loop produced a separate placeholder per GPU while only
            # the last one stayed reachable through `self`.
            self.fool_discriminator_counter = tf.placeholder(
                tf.int32, name="fool_discriminator_counter")
            # tf.equal builds a boolean tensor. Python `==` on a Tensor in
            # TF1 graph mode compares object identity and yields a constant
            # False, which silently disabled the swapped branch.
            fool_dis = tf.equal(self.fool_discriminator_counter % 100, 0)

            def adversarial_d_loss(real_logit, fake_logit, real_cam_logit,
                                   fake_cam_logit, gp, gp_cam):
                """Discriminator adversarial loss; real/fake swapped when fool_dis holds."""

                def swapped():
                    return (discriminator_loss(self.gan_type, fake_logit,
                                               real_logit) +
                            discriminator_loss(self.gan_type, fake_cam_logit,
                                               real_cam_logit) + gp + gp_cam)

                def regular():
                    return (discriminator_loss(self.gan_type, real_logit,
                                               fake_logit) +
                            discriminator_loss(self.gan_type, real_cam_logit,
                                               fake_cam_logit) + gp + gp_cam)

                return tf.cond(fool_dis, swapped, regular)

            # TODO(jhhuang): revisit the for-loop
            reuse_vars = False
            available_gpus = self.get_available_gpus()
            print(str(available_gpus))
            if not available_gpus:
                # Without at least one tower no variables or losses exist and
                # the optimizer setup below cannot succeed.
                raise ValueError(
                    "build_model requires at least one available GPU device")

            for gpu in available_gpus:
                print("Current GPU: " + str(gpu))
                with tf.device(
                        self.assign_to_device(gpu, ps_device='/cpu:0')):
                    domain_A = trainA_iterator.get_next()
                    domain_B = trainB_iterator.get_next()

                    # ---- Generators ----
                    # Translations A->B and B->A of the real batches.
                    x_ab, cam_ab = self.generate_a2b(domain_A,
                                                     reuse=reuse_vars)
                    x_ba, cam_ba = self.generate_b2a(domain_B,
                                                     reuse=reuse_vars)

                    # Cycle: translate the translations back.
                    x_aba, _ = self.generate_b2a(x_ab, reuse=True)
                    x_bab, _ = self.generate_a2b(x_ba, reuse=True)

                    # Identity: each generator applied to its target domain.
                    x_aa, cam_aa = self.generate_b2a(domain_A, reuse=True)
                    x_bb, cam_bb = self.generate_a2b(domain_B, reuse=True)

                    # ---- Discriminators ----
                    (real_A_logit, real_A_cam_logit, real_B_logit,
                     real_B_cam_logit) = self.discriminate_real(
                         domain_A, domain_B, reuse=reuse_vars)
                    (fake_A_logit, fake_A_cam_logit, fake_B_logit,
                     fake_B_cam_logit) = self.discriminate_fake(x_ba, x_ab)

                    # NOTE(review): in TF1 graph mode tf.print only creates an
                    # op; nothing is printed unless that op is run. Kept as in
                    # the original - confirm whether it is still wanted.
                    for label, logit in (
                            ("real_A_logit: ", real_A_logit),
                            ("real_A_cam_logit: ", real_A_cam_logit),
                            ("real_B_logit: ", real_B_logit),
                            ("real_B_cam_logit: ", real_B_cam_logit),
                            ("fake_A_logit: ", fake_A_logit),
                            ("fake_A_cam_logit: ", fake_A_cam_logit),
                            ("fake_B_logit: ", fake_B_logit),
                            ("fake_B_cam_logit: ", fake_B_cam_logit),
                    ):
                        tf.print(label, logit)

                    # ---- Losses ----
                    # Gradient penalty only for the wgan-* family and dragan.
                    if 'wgan' in self.gan_type or self.gan_type == 'dragan':
                        GP_A, GP_CAM_A = self.gradient_panalty(
                            real=domain_A, fake=x_ba, scope="discriminator_A")
                        GP_B, GP_CAM_B = self.gradient_panalty(
                            real=domain_B, fake=x_ab, scope="discriminator_B")
                    else:
                        GP_A, GP_CAM_A = 0, 0
                        GP_B, GP_CAM_B = 0, 0

                    G_ad_loss_A = (
                        generator_loss(self.gan_type, fake_A_logit) +
                        generator_loss(self.gan_type, fake_A_cam_logit))
                    G_ad_loss_B = (
                        generator_loss(self.gan_type, fake_B_logit) +
                        generator_loss(self.gan_type, fake_B_cam_logit))

                    D_ad_loss_A = adversarial_d_loss(
                        real_A_logit, fake_A_logit, real_A_cam_logit,
                        fake_A_cam_logit, GP_A, GP_CAM_A)
                    D_ad_loss_B = adversarial_d_loss(
                        real_B_logit, fake_B_logit, real_B_cam_logit,
                        fake_B_cam_logit, GP_B, GP_CAM_B)

                    # Cycle-reconstruction and identity L1 terms.
                    reconstruction_A = L1_loss(x_aba, domain_A)
                    reconstruction_B = L1_loss(x_bab, domain_B)

                    identity_A = L1_loss(x_aa, domain_A)
                    identity_B = L1_loss(x_bb, domain_B)

                    cam_A = cam_loss(source=cam_ba, non_source=cam_aa)
                    cam_B = cam_loss(source=cam_ab, non_source=cam_bb)

                    # Weighted generator terms. Note that generator A is
                    # paired with reconstruction_B and generator B with
                    # reconstruction_A, exactly as in the original code.
                    Generator_A_gan = self.adv_weight * G_ad_loss_A
                    Generator_A_cycle = self.cycle_weight * reconstruction_B
                    Generator_A_identity = self.identity_weight * identity_A
                    Generator_A_cam = self.cam_weight * cam_A

                    Generator_B_gan = self.adv_weight * G_ad_loss_B
                    Generator_B_cycle = self.cycle_weight * reconstruction_A
                    Generator_B_identity = self.identity_weight * identity_B
                    Generator_B_cam = self.cam_weight * cam_B

                    Generator_A_loss = (Generator_A_gan + Generator_A_cycle +
                                        Generator_A_identity + Generator_A_cam)
                    Generator_B_loss = (Generator_B_gan + Generator_B_cycle +
                                        Generator_B_identity + Generator_B_cam)

                    Discriminator_A_loss = self.adv_weight * D_ad_loss_A
                    Discriminator_B_loss = self.adv_weight * D_ad_loss_B

                    self.Generator_loss_total.append(
                        Generator_A_loss + Generator_B_loss +
                        regularization_loss('generator'))
                    self.Discriminator_loss_total.append(
                        Discriminator_A_loss + Discriminator_B_loss +
                        regularization_loss('discriminator') +
                        self.manual_d_loss)

                    # Every tower after the first reuses the variables.
                    reuse_vars = True

            # ---- Training ----
            # Variables are assigned to G or D by substring of their name.
            t_vars = tf.trainable_variables()
            G_vars = [var for var in t_vars if 'generator' in var.name]
            D_vars = [var for var in t_vars if 'discriminator' in var.name]

            # Average the per-tower losses.
            self.Generator_loss = tf.reduce_mean(self.Generator_loss_total)
            self.Discriminator_loss = tf.reduce_mean(
                self.Discriminator_loss_total)

            print("G_vars : ")
            for element in G_vars:
                print("element from G_vars : " + str(element))

            self.G_optim = tf.train.AdamOptimizer(
                self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss,
                                                          var_list=G_vars)
            self.D_optim = tf.train.AdamOptimizer(
                self.lr, beta1=0.5,
                beta2=0.999).minimize(self.Discriminator_loss,
                                      var_list=D_vars)

            # ---- Summaries ----
            self.all_G_loss = tf.summary.scalar("Generator_loss",
                                                self.Generator_loss)
            self.all_D_loss = tf.summary.scalar("Discriminator_loss",
                                                self.Discriminator_loss)

            # The per-term tensors below are the ones left over from the last
            # tower built in the loop above.
            self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
            self.G_A_gan = tf.summary.scalar("G_A_gan", Generator_A_gan)
            self.G_A_cycle = tf.summary.scalar("G_A_cycle", Generator_A_cycle)
            self.G_A_identity = tf.summary.scalar("G_A_identity",
                                                  Generator_A_identity)
            self.G_A_cam = tf.summary.scalar("G_A_cam", Generator_A_cam)

            self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
            self.G_B_gan = tf.summary.scalar("G_B_gan", Generator_B_gan)
            self.G_B_cycle = tf.summary.scalar("G_B_cycle", Generator_B_cycle)
            self.G_B_identity = tf.summary.scalar("G_B_identity",
                                                  Generator_B_identity)
            self.G_B_cam = tf.summary.scalar("G_B_cam", Generator_B_cam)

            self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
            self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)

            # Histogram plus min/max/mean scalars for every trainable variable
            # whose name contains 'rho'.
            self.rho_var = []
            for var in tf.trainable_variables():
                if 'rho' in var.name:
                    self.rho_var.append(tf.summary.histogram(var.name, var))
                    self.rho_var.append(
                        tf.summary.scalar(var.name + "_min",
                                          tf.reduce_min(var)))
                    self.rho_var.append(
                        tf.summary.scalar(var.name + "_max",
                                          tf.reduce_max(var)))
                    self.rho_var.append(
                        tf.summary.scalar(var.name + "_mean",
                                          tf.reduce_mean(var)))

            g_summary_list = [
                self.G_A_loss, self.G_A_gan, self.G_A_cycle, self.G_A_identity,
                self.G_A_cam, self.G_B_loss, self.G_B_gan, self.G_B_cycle,
                self.G_B_identity, self.G_B_cam, self.all_G_loss
            ]
            g_summary_list.extend(self.rho_var)
            d_summary_list = [self.D_A_loss, self.D_B_loss, self.all_D_loss]

            self.G_loss = tf.summary.merge(g_summary_list)
            self.D_loss = tf.summary.merge(d_summary_list)

        else:
            # ---- Test graph: one image per domain, translated once ----
            self.test_domain_A = tf.placeholder(
                tf.float32, [1, self.img_size, self.img_size, self.img_ch],
                name='test_domain_A')
            self.test_domain_B = tf.placeholder(
                tf.float32, [1, self.img_size, self.img_size, self.img_ch],
                name='test_domain_B')

            self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
            self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
示例#9
0
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-s',
        '--sample',
        type=str,
        default='',
        help='Samples based on input song. Empty string means training.')
    parser.add_argument('-c',
                        '--concat',
                        dest='concat',
                        action='store_true',
                        help='Enable concatenation.')
    parser.add_argument('--no-concat',
                        dest='concat',
                        action='store_false',
                        help='Disable concatenation.')
    parser.set_defaults(concat=False)
    parser.add_argument('-r',
                        '--record',
                        dest='record',
                        action='store_true',
                        help='Enable recording.')
    parser.add_argument('--no-record',
                        dest='record',
                        action='store_false',
                        help='Disable recording.')
    parser.set_defaults(
        record=False)  # Warning: Windows kills python if enabled.
    args = parser.parse_args()
    sampling = args.sample != ''

    if not os.path.exists('Checkpoints_v1'):
        os.makedirs('Checkpoints_v1')
    if not os.path.exists('Logs_v1'):
        os.makedirs('Logs_v1')
    if not os.path.exists('Samples_v1'):
        os.makedirs('Samples_v1')
    if not os.path.exists('Timelines_v1'):
        os.makedirs('Timelines_v1')

    filename = 'Dataset/cond_dataset.tfrecord'
    dataset = tf.data.TFRecordDataset(filename, num_parallel_reads=8)

    def _parse(example_proto):
        feature = {
            'label': tf.FixedLenFeature([], tf.int64),
            'data': tf.FixedLenFeature([], tf.string)
        }
        parsed = tf.parse_single_example(example_proto, feature)
        data = tf.decode_raw(parsed['data'], tf.uint8)
        label = tf.cast(parsed['label'], tf.uint8)
        data = tf.py_func(func=np.unpackbits, inp=[data], Tout=tf.uint8)
        label = tf.py_func(func=np.unpackbits, inp=[label], Tout=tf.uint8)
        data = tf.cast(data, tf.float32)
        label = tf.cast(label, tf.float32)
        data = tf.reshape(data, [CHANNEL_NUM, CLASS_NUM, INPUT_LENGTH])
        label.set_shape([8])
        label = label[:CHANNEL_NUM]
        label = tf.expand_dims(tf.expand_dims(label, axis=-1), axis=-1)
        data = data * 2 - 1
        return {'data': data, 'label': label}

    dataset = dataset.apply(data.shuffle_and_repeat(buffer_size=16384))
    dataset = dataset.apply(
        data.map_and_batch(_parse,
                           batch_size=BATCH_SIZE,
                           num_parallel_batches=16,
                           drop_remainder=True))
    dataset = dataset.prefetch(32)
    dataset = dataset.apply(data.prefetch_to_device('/gpu:0'))
    iterator = dataset.make_one_shot_iterator()
    next_data = iterator.get_next()
    real_input_3 = next_data['data']
    label = next_data['label']

    input_noise = tf.placeholder(dtype=tf.float32,
                                 shape=[None, NOISE_LENGTH],
                                 name='input_noise')

    real_input_2 = tf.layers.average_pooling2d(inputs=real_input_3,
                                               pool_size=2,
                                               strides=2,
                                               padding='same',
                                               data_format='channels_first',
                                               name='real_input_2')
    real_input_2 -= tf.reduce_min(real_input_2)
    real_input_2 /= (tf.reduce_max(real_input_2) + EPSILON)
    real_input_2 = 2 * real_input_2 - 1
    real_input_1 = tf.layers.average_pooling2d(inputs=real_input_2,
                                               pool_size=2,
                                               strides=2,
                                               padding='same',
                                               data_format='channels_first',
                                               name='real_input_1')
    real_input_1 -= tf.reduce_min(real_input_1)
    real_input_1 /= (tf.reduce_max(real_input_1) + EPSILON)
    real_input_1 = 2 * real_input_1 - 1
    train = tf.placeholder(dtype=tf.bool, name='train')
    # shape: [None, 6, 64, 256]
    encode = encoder(inputs=real_input_3,
                     update_collection=SPECTRAL_UPDATE_OPS,
                     train=train)
    # shape: [None, 64]
    tf.summary.histogram('encode', encode)
    print('Encoder set')

    real_input_3_image = tf.expand_dims(real_input_3[:1],
                                        axis=-1,
                                        name='real_input_3_image')
    # shape: [1, 6, 128, 512, 1]
    real_input_2_image = tf.expand_dims(real_input_2[:1],
                                        axis=-1,
                                        name='real_input_2_image')
    # shape: [1, 6, 64, 256, 1]
    real_input_1_image = tf.expand_dims(real_input_1[:1],
                                        axis=-1,
                                        name='real_input_1_image')
    # shape: [1, 6, 32, 128, 1]
    for i in range(CHANNEL_NUM):
        tf.summary.image('real_input_1_%d' % i, real_input_1_image[:, i])
        tf.summary.image('real_input_2_%d' % i, real_input_2_image[:, i])
        tf.summary.image('real_input_3_%d' % i, real_input_3_image[:, i])

    shared_output = shared_gen(noise=input_noise,
                               label=label,
                               update_collection=SPECTRAL_UPDATE_OPS,
                               train=train)
    # shape: [None, 64, 16, 64]
    gen1 = generator1(inputs=shared_output,
                      label=label,
                      update_collection=SPECTRAL_UPDATE_OPS,
                      train=train)
    output_gen1 = process(gen1, 1, train, SPECTRAL_UPDATE_OPS) * label
    # shape: [None, 6, 32, 128]
    gen2 = generator2(inputs=gen1,
                      label=label,
                      update_collection=SPECTRAL_UPDATE_OPS,
                      train=train)
    output_gen2 = process(gen2, 2, train, SPECTRAL_UPDATE_OPS) * label
    # shape: [None, 6, 64, 256]
    gen3 = generator3(inputs=gen2,
                      label=label,
                      update_collection=SPECTRAL_UPDATE_OPS,
                      train=train)
    output_gen3 = process(gen3, 3, train, SPECTRAL_UPDATE_OPS) * label
    # shape: [None, 6, 128, 512]
    print('Generators set')
    dis1_real = discriminator1(inputs=real_input_1,
                               encode=encode,
                               update_collection=SPECTRAL_UPDATE_OPS)
    dis1_gen = discriminator1(inputs=output_gen1,
                              encode=encode,
                              update_collection=NO_OPS)
    dis2_real = discriminator2(inputs=real_input_2,
                               encode=encode,
                               update_collection=SPECTRAL_UPDATE_OPS)
    dis2_gen = discriminator2(inputs=output_gen2,
                              encode=encode,
                              update_collection=NO_OPS)
    dis3_real = discriminator3(inputs=real_input_3,
                               encode=encode,
                               update_collection=SPECTRAL_UPDATE_OPS)
    dis3_gen = discriminator3(inputs=output_gen3,
                              encode=encode,
                              update_collection=NO_OPS)
    print('Discriminators set')
    loss_dis1 = tf.reduce_mean(dis1_gen - dis1_real) + lipschitz_penalty(
        real=real_input_1,
        gen=output_gen1,
        encode=encode,
        discriminator=discriminator1) + DRIFT * tf.reduce_mean(
            tf.square(dis1_real))
    loss_gen1 = -tf.reduce_mean(dis1_gen)
    loss_dis2 = tf.reduce_mean(dis2_gen - dis2_real) + lipschitz_penalty(
        real=real_input_2,
        gen=output_gen2,
        encode=encode,
        discriminator=discriminator2) + DRIFT * tf.reduce_mean(
            tf.square(dis2_real))
    loss_gen2 = -tf.reduce_mean(dis2_gen)
    loss_dis3 = tf.reduce_mean(dis3_gen - dis3_real) + lipschitz_penalty(
        real=real_input_3,
        gen=output_gen3,
        encode=encode,
        discriminator=discriminator3) + DRIFT * tf.reduce_mean(
            tf.square(dis3_real))
    loss_gen3 = -tf.reduce_mean(dis3_gen)
    loss_gen = tf.add_n([loss_gen1, loss_gen2, loss_gen3]) / 3
    tf.summary.scalar('loss_dis1', loss_dis1)
    tf.summary.scalar('loss_gen1', loss_gen1)
    tf.summary.scalar('dis1_real', tf.reduce_mean(dis1_real))
    tf.summary.scalar('loss_dis2', loss_dis2)
    tf.summary.scalar('loss_gen2', loss_gen2)
    tf.summary.scalar('dis2_real', tf.reduce_mean(dis2_real))
    tf.summary.scalar('loss_dis3', loss_dis3)
    tf.summary.scalar('loss_gen3', loss_gen3)
    tf.summary.scalar('dis3_real', tf.reduce_mean(dis3_real))
    tf.summary.scalar('loss_gen', loss_gen)
    print('Losses set')
    gen_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                scope='Shared_generator')
    #gen_var += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Encoder')
    for i in range(CHANNEL_NUM):
        gen_var += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope='Generator1_%d' % i)
        gen_var += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope='Generator2_%d' % i)
        gen_var += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope='Generator3_%d' % i)
    dis1_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope='Discriminator1')
    dis1_var += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='Encoder')
    dis2_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope='Discriminator2')
    dis2_var += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='Encoder')
    dis3_var = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 scope='Discriminator3')
    dis3_var += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                  scope='Encoder')
    gen_extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                             scope='Shared_generator')
    #gen_extra_update_ops += tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='Encoder')
    for i in range(CHANNEL_NUM):
        gen_extra_update_ops += tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                  scope='Generator1_%d' % i)
        gen_extra_update_ops += tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                  scope='Generator2_%d' % i)
        gen_extra_update_ops += tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                  scope='Generator3_%d' % i)
    dis1_extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                              scope='Discriminator1')
    dis1_extra_update_ops += tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               scope='Encoder')
    dis2_extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                              scope='Discriminator2')
    dis2_extra_update_ops += tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               scope='Encoder')
    dis3_extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                              scope='Discriminator3')
    dis3_extra_update_ops += tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                               scope='Encoder')
    spectral_norm_update_ops = tf.get_collection(SPECTRAL_UPDATE_OPS)
    with tf.name_scope('optimizers'):
        with tf.control_dependencies(dis1_extra_update_ops):
            dis1_train = tf.train.AdamOptimizer(learning_rate=0.0004,
                                                beta1=0.5,
                                                beta2=0.9).minimize(
                                                    loss=loss_dis1,
                                                    var_list=dis1_var,
                                                    name='dis1_train')
        print('dis1_train setup')
        with tf.control_dependencies(dis2_extra_update_ops):
            dis2_train = tf.train.AdamOptimizer(learning_rate=0.0004,
                                                beta1=0.5,
                                                beta2=0.9).minimize(
                                                    loss=loss_dis2,
                                                    var_list=dis2_var,
                                                    name='dis2_train')
        print('dis2_train setup')
        with tf.control_dependencies(dis3_extra_update_ops):
            dis3_train = tf.train.AdamOptimizer(learning_rate=0.0004,
                                                beta1=0.5,
                                                beta2=0.9).minimize(
                                                    loss=loss_dis3,
                                                    var_list=dis3_var,
                                                    name='dis3_train')
        print('dis3_train setup')
        with tf.control_dependencies(gen_extra_update_ops):
            gen_train = tf.train.AdamOptimizer(learning_rate=0.0001,
                                               beta1=0.5,
                                               beta2=0.9).minimize(
                                                   loss=loss_gen,
                                                   var_list=gen_var,
                                                   name='gen_train')
        print('gen_train setup')
    print('Optimizers set')
    gpu_options = tf.GPUOptions(allow_growth=True, allocator_type='BFC')
    config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        merged = tf.summary.merge_all()
        sess.run(tf.global_variables_initializer())
        if tf.train.latest_checkpoint('Checkpoints_v1') is not None:
            print('Restoring...')
            saver.restore(sess, tf.train.latest_checkpoint('Checkpoints_v1'))
        feed_dict = {input_noise: None, train: True}
        print('preparing complete')
        if sampling:
            feed_dict = {input_noise: None, real_input_3: None, train: False}
            path = args.sample
            try:
                feed_dict[real_input_3] = roll(path)[:BATCH_SIZE]
            except:
                print('Error while opening file.')
                exit()
            feed_dict[input_noise] = get_noise([BATCH_SIZE, NOISE_LENGTH])
            samples = sess.run(output_gen3, feed_dict=feed_dict)
            path = path.split('/')[-1]
            if not os.path.exists('Samples_v1/sample_%s'):
                os.mkdir('Samples_v1/sample_%s' % path)
            np.save(file='Samples_v1/sample_%s' % path + '/%s' % path,
                    arr=samples)
            unpack_sample(name='Samples_v1/sample_%s' % path +
                          '/%s.npy' % path,
                          concat=args.concat)
            exit()
        writer = tf.summary.FileWriter('Logs_v1', sess.graph)
        run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
        epoch_num = 100001
        for train_count in tqdm(range(epoch_num), smoothing=0.7):
            for i in range(TRAIN_RATIO_DIS):
                feed_dict[input_noise] = get_noise([BATCH_SIZE, NOISE_LENGTH])
                feed_dict[train] = True
                *_, loss_val_dis1, loss_val_dis2, loss_val_dis3 = sess.run(
                    [
                        dis1_train, dis2_train, dis3_train, loss_dis1,
                        loss_dis2, loss_dis3
                    ],
                    feed_dict=feed_dict,
                    options=run_options)
            for i in range(TRAIN_RATIO_GEN):
                feed_dict[input_noise] = get_noise([BATCH_SIZE, NOISE_LENGTH])
                feed_dict[train] = True
                summary, _, loss_val_gen = sess.run(
                    [merged, gen_train, loss_gen],
                    feed_dict=feed_dict,
                    options=run_options)
            sess.run(spectral_norm_update_ops)
            writer.add_summary(summary, train_count)
            tqdm.write('%06d' % train_count, end=' ')
            tqdm.write('Discriminator1 loss : %.7f' % loss_val_dis1, end=' ')
            tqdm.write('Discriminator2 loss : %.7f' % loss_val_dis2, end=' ')
            tqdm.write('Discriminator3 loss : %.7f' % loss_val_dis3, end=' ')
            tqdm.write('Generator loss : %.7f' % loss_val_gen)
            if train_count % 1000 == 0:
                feed_dict[input_noise] = get_noise([BATCH_SIZE, NOISE_LENGTH])
                feed_dict[train] = False
                samples = sess.run(output_gen3, feed_dict=feed_dict)
                np.save(file='Samples_v1/song_%06d' % train_count, arr=samples)
                unpack_sample('Samples_v1/song_%06d' % train_count)
                save_path = saver.save(
                    sess, 'Checkpoints_v1/song_%06d' % train_count + '.ckpt')
                tqdm.write('Model Saved: %s' % save_path)
                if args.record:
                    trace_options = tf.RunOptions(
                        trace_level=tf.RunOptions.FULL_TRACE)  # pylint: disable=E1101
                    run_metadata = tf.RunMetadata()
                    sess.run([dis1_train, dis2_train, dis3_train, gen_train],
                             feed_dict=feed_dict,
                             options=trace_options,
                             run_metadata=run_metadata)
                    writer.add_run_metadata(run_metadata,
                                            'run_%d' % train_count)
                    tl = timeline.Timeline(run_metadata.step_stats)  # pylint: disable=E1101
                    ctf = tl.generate_chrome_trace_format()
                    with open('Timelines_v1/timeline_%d.json' % train_count,
                              'w') as f:
                        f.write(ctf)
        writer.close()
示例#10
0
    def build_model(self):
        """Build the TensorFlow graph for the configured phase.

        When ``self.phase == 'train'`` this wires up the batched input
        pipeline, the generator/discriminator outputs, the losses, one Adam
        optimizer per network and the scalar loss summaries.  For any other
        phase only the test-time placeholders and generator outputs are
        created.  Everything is stored on ``self``; nothing is returned.
        """
        label_fix_onehot_list = []
        is_celeba = self.dataset_name in ('celebA-HQ', 'celebA')

        # ---- Input images ----
        if is_celeba:
            img_class = ImageData_celebA(self.img_height, self.img_width,
                                         self.img_ch, self.dataset_path,
                                         self.label_list, self.augment_flag)
            img_class.preprocess(self.phase)
        else:
            img_class = Image_data(self.img_height, self.img_width,
                                   self.img_ch, self.dataset_path,
                                   self.label_list, self.augment_flag)
            img_class.preprocess()

            # Insert an axis at position 1 and repeat it batch_size times.
            expanded_labels = tf.expand_dims(img_class.label_onehot_list,
                                             axis=1)
            label_fix_onehot_list = tf.tile(expanded_labels,
                                            [1, self.batch_size, 1])

        dataset_num = len(img_class.image)
        print("Dataset number : ", dataset_num)

        if self.phase == 'train':
            self.lr = tf.placeholder(tf.float32, name='learning_rate')

            # The celebA datasets additionally carry a per-sample list of
            # fixed one-hot labels.
            if is_celeba:
                tensors = (img_class.image, img_class.label,
                           img_class.train_label_onehot_list)
            else:
                tensors = (img_class.image, img_class.label)

            # shuffle/repeat -> image_processing + batch -> prefetch to GPU
            pipeline = tf.data.Dataset.from_tensor_slices(tensors)
            pipeline = pipeline.apply(shuffle_and_repeat(dataset_num))
            pipeline = pipeline.apply(
                map_and_batch(img_class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True))
            pipeline = pipeline.apply(prefetch_to_device('/gpu:0', None))
            next_batch = pipeline.make_one_shot_iterator().get_next()

            if is_celeba:
                self.x_real, label_org, batch_fix_labels = next_batch
            else:
                self.x_real, label_org = next_batch

            # Target domain labels: the batch's own labels, shuffled.
            label_trg = tf.random_shuffle(label_org)
            if is_celeba:
                label_fix_onehot_list = tf.transpose(batch_fix_labels,
                                                     perm=[1, 0, 2])

            # ---- Generator / discriminator outputs ----
            style_shape = [self.batch_size, self.style_dim]

            # Real images pushed towards the shuffled target labels.
            fake_style_code = tf.random_normal(shape=style_shape)
            x_fake = self.generator(self.x_real, label_trg, fake_style_code)

            # Generated images pushed back towards their original labels.
            recon_style_code = tf.random_normal(shape=style_shape)
            x_recon = self.generator(x_fake,
                                     label_org,
                                     recon_style_code,
                                     reuse=True)

            real_logit, real_cls, _ = self.discriminator(self.x_real)
            fake_logit, fake_cls, fake_noise = self.discriminator(x_fake,
                                                                  reuse=True)

            # ---- Losses ----
            # Only the wgan variants and dragan add a gradient penalty.
            if 'wgan' in self.gan_type or self.gan_type == 'dragan':
                GP = self.gradient_panalty(real=self.x_real, fake=x_fake)
            else:
                GP = 0

            g_adv_loss = self.adv_weight * generator_loss(
                self.gan_type, fake_logit)
            g_cls_loss = self.cls_weight * classification_loss(
                logit=fake_cls, label=label_trg)
            g_rec_loss = self.rec_weight * L1_loss(self.x_real, x_recon)
            g_noise_loss = self.noise_weight * L1_loss(
                fake_style_code, fake_noise)

            d_adv_loss = self.adv_weight * discriminator_loss(
                self.gan_type, real_logit, fake_logit) + GP
            d_cls_loss = self.cls_weight * classification_loss(
                logit=real_cls, label=label_org)
            d_noise_loss = self.noise_weight * L1_loss(
                fake_style_code, fake_noise)

            self.d_loss = d_adv_loss + d_cls_loss + d_noise_loss
            self.g_loss = g_adv_loss + g_cls_loss + g_rec_loss + g_noise_loss

            # ---- Result images ----
            # One entry per style: the real batch mapped over every fixed
            # label with a freshly drawn style code.
            self.x_fake_list = []
            for _ in range(self.num_style):
                random_style_code = tf.random_normal(shape=style_shape)
                styled_outputs = tf.map_fn(
                    lambda c: self.generator(
                        self.x_real, c, random_style_code, reuse=True),
                    label_fix_onehot_list,
                    dtype=tf.float32)
                self.x_fake_list.append(styled_outputs)

            # ---- Training ops ----
            # Variables are assigned to a network by name substring.
            t_vars = tf.trainable_variables()
            G_vars = [v for v in t_vars if 'generator' in v.name]
            D_vars = [v for v in t_vars if 'discriminator' in v.name]

            g_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
            self.g_optimizer = g_adam.minimize(self.g_loss, var_list=G_vars)
            d_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
            self.d_optimizer = d_adam.minimize(self.d_loss, var_list=D_vars)

            # ---- Summaries ----
            self.Generator_loss = tf.summary.scalar("g_loss", self.g_loss)
            self.Discriminator_loss = tf.summary.scalar("d_loss", self.d_loss)

            self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
            self.g_cls_loss = tf.summary.scalar("g_cls_loss", g_cls_loss)
            self.g_rec_loss = tf.summary.scalar("g_rec_loss", g_rec_loss)
            self.g_noise_loss = tf.summary.scalar("g_noise_loss", g_noise_loss)

            self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss)
            self.d_cls_loss = tf.summary.scalar("d_cls_loss", d_cls_loss)
            self.d_noise_loss = tf.summary.scalar("d_noise_loss", d_noise_loss)

            generator_summaries = [
                self.Generator_loss, self.g_adv_loss, self.g_cls_loss,
                self.g_rec_loss, self.g_noise_loss
            ]
            discriminator_summaries = [
                self.Discriminator_loss, self.d_adv_loss, self.d_cls_loss,
                self.d_noise_loss
            ]
            self.g_summary_loss = tf.summary.merge(generator_summaries)
            self.d_summary_loss = tf.summary.merge(discriminator_summaries)

        else:
            # ---- Test graph ----
            single_image_shape = [
                1, self.img_height, self.img_width, self.img_ch
            ]

            if is_celeba:
                test_tensors = (img_class.test_image, img_class.test_label,
                                img_class.test_label_onehot_list)
                test_pipeline = tf.data.Dataset.from_tensor_slices(
                    test_tensors)
                dataset_num = len(img_class.test_image)

                test_pipeline = test_pipeline.apply(
                    shuffle_and_repeat(dataset_num))
                test_pipeline = test_pipeline.apply(
                    map_and_batch(img_class.image_processing,
                                  batch_size=self.batch_size,
                                  num_parallel_batches=16,
                                  drop_remainder=True))
                test_pipeline = test_pipeline.apply(
                    prefetch_to_device('/gpu:0', None))
                test_batch = test_pipeline.make_one_shot_iterator().get_next()

                self.x_test, _, self.test_label_fix_onehot_list = test_batch
                self.test_img_placeholder = tf.placeholder(
                    tf.float32, single_image_shape)
                self.test_label_fix_placeholder = tf.placeholder(
                    tf.float32, [self.c_dim, 1, self.c_dim])

            # Custom image input and its label list; built in both test
            # variants, in the same position as before.
            self.custom_image = tf.placeholder(tf.float32,
                                               single_image_shape,
                                               name='custom_image')
            onehot_labels = np.expand_dims(label2onehot(self.label_list),
                                           axis=0)
            custom_label_fix_onehot_list = tf.transpose(
                onehot_labels, perm=[1, 0, 2])  # [c_dim, bs, c_dim]

            test_random_style_code = tf.random_normal(
                shape=[1, self.style_dim])

            if is_celeba:
                # First generator call in this graph, hence no reuse flag.
                self.x_test_fake_list = tf.map_fn(
                    lambda c: self.generator(self.test_img_placeholder, c,
                                             test_random_style_code),
                    self.test_label_fix_placeholder,
                    dtype=tf.float32)
                self.custom_fake_image = tf.map_fn(
                    lambda c: self.generator(self.custom_image, c,
                                             test_random_style_code,
                                             reuse=True),
                    custom_label_fix_onehot_list,
                    dtype=tf.float32)
            else:
                self.custom_fake_image = tf.map_fn(
                    lambda c: self.generator(self.custom_image, c,
                                             test_random_style_code),
                    custom_label_fix_onehot_list,
                    dtype=tf.float32)
示例#11
0
    def build_model(self):
        """Build the training and test graph for the generator/discriminator.

        Creates the two batched input pipelines (A and B), the generator and
        discriminator outputs, the generator/discriminator losses, one Adam
        optimizer per network, a test-time generator output fed by
        placeholders, and the scalar loss summaries.  Everything is stored on
        ``self``; nothing is returned.
        """
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        # ---- Input images ----
        Image_Data_Class = ImageData(self.img_size, self.img_ch,
                                     self.augment_flag)

        def batched_on_gpu(dataset):
            # shuffle/repeat -> image_processing + batch -> prefetch to GPU
            dataset = dataset.apply(shuffle_and_repeat(self.dataset_num))
            dataset = dataset.apply(
                map_and_batch(Image_Data_Class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True))
            return dataset.apply(
                prefetch_to_device('/gpu:0', self.batch_size))

        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
        trainA = batched_on_gpu(trainA)
        trainB = batched_on_gpu(trainB)

        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()

        # Three separate get_next() ops on the A iterator, one on B.
        self.identity_A = trainA_iterator.get_next()
        self.shape_A = trainA_iterator.get_next()
        self.other_A = trainA_iterator.get_next()
        self.shape_B = trainB_iterator.get_next()

        single_image_shape = [1, self.img_size, self.img_size, self.img_ch]
        self.test_identity_A = tf.placeholder(tf.float32,
                                              single_image_shape,
                                              name='test_identity_A')
        self.test_shape_B = tf.placeholder(tf.float32,
                                           single_image_shape,
                                           name='test_shape_B')

        # ---- Generator / discriminator outputs ----
        self.fake_same = self.generator(x_identity=self.identity_A,
                                        x_shape=self.shape_A)
        self.fake_diff = self.generator(x_identity=self.identity_A,
                                        x_shape=self.shape_B,
                                        reuse=True)
        fake_diff_shape = self.generator(x_identity=self.shape_B,
                                         x_shape=self.fake_diff,
                                         reuse=True)
        fake_diff_identity = self.generator(x_identity=self.fake_diff,
                                            x_shape=self.shape_B,
                                            reuse=True)

        real_logit = self.discriminator(x_identity=self.identity_A,
                                        x=self.other_A)
        fake_logit = self.discriminator(x_identity=self.identity_A,
                                        x=self.fake_diff,
                                        reuse=True)

        # ---- Losses ----
        adversarial_term = generator_loss(self.gan_type, fake_logit)
        g_identity_loss = self.adv_weight * adversarial_term * 64

        g_shape_loss_same = self.L1_weight * L1_loss(self.fake_same,
                                                     self.shape_A)

        # Two L1 terms against shape_B: the regenerated image and fake_diff.
        regenerated_vs_b = self.L1_weight * L1_loss(fake_diff_shape,
                                                    self.shape_B)
        fake_diff_vs_b = self.L1_weight * L1_loss(self.fake_diff,
                                                  self.shape_B)
        g_shape_loss_diff_shape = regenerated_vs_b + fake_diff_vs_b

        g_shape_loss_diff_identity = self.L1_weight * L1_loss(
            fake_diff_identity, self.fake_diff)

        self.Generator_loss = (g_identity_loss + g_shape_loss_same +
                               g_shape_loss_diff_shape +
                               g_shape_loss_diff_identity)
        self.Discriminator_loss = self.adv_weight * discriminator_loss(
            self.gan_type, real=real_logit, fake=fake_logit)

        # ---- Result image ----
        self.test_fake = self.generator(x_identity=self.test_identity_A,
                                        x_shape=self.test_shape_B,
                                        reuse=True)

        # ---- Training ops ----
        # Variables are assigned to a network by name substring.
        t_vars = tf.trainable_variables()
        G_vars = [v for v in t_vars if 'generator' in v.name]
        D_vars = [v for v in t_vars if 'discriminator' in v.name]

        g_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
        self.G_optim = g_adam.minimize(self.Generator_loss, var_list=G_vars)
        d_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
        self.D_optim = d_adam.minimize(self.Discriminator_loss,
                                       var_list=D_vars)

        # ---- Summaries ----
        self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
        self.D_loss = tf.summary.scalar("Discriminator_loss",
                                        self.Discriminator_loss)

        self.G_identity = tf.summary.scalar("G_identity", g_identity_loss)
        self.G_shape_loss_same = tf.summary.scalar("G_shape_loss_same",
                                                   g_shape_loss_same)
        self.G_shape_loss_diff_shape = tf.summary.scalar(
            "G_shape_loss_diff_shape", g_shape_loss_diff_shape)
        self.G_shape_loss_diff_identity = tf.summary.scalar(
            "G_shape_loss_diff_identity", g_shape_loss_diff_identity)

        generator_summaries = [
            self.G_loss, self.G_identity, self.G_shape_loss_same,
            self.G_shape_loss_diff_shape, self.G_shape_loss_diff_identity
        ]
        self.G_loss_merge = tf.summary.merge(generator_summaries)
        self.D_loss_merge = tf.summary.merge([self.D_loss])
示例#12
0
def get_squad_dataset(opts, is_training):
    """Return the batched SQuAD ``tf.data`` dataset, building its cache if absent.

    If the TFRecord file for the current configuration does not exist yet,
    the raw SQuAD file is read, converted to features, written to the
    TFRecord file, and the number of converted features is logged to a
    ``.metadata`` file next to it (only if that file is not already present).

    Args:
        opts: Mapping of options. Keys read here: ``seq_length``,
            ``micro_batch_size``, ``tfrecord_dir``, ``replicas``,
            ``gradient_accumulation_count``, ``version_2_with_negative``,
            ``doc_stride``, ``max_query_length``, ``train_file`` or
            ``predict_file``, ``vocab_file``, ``do_lower_case``,
            ``generated_data`` and, when training,
            ``distributed_worker_count`` / ``distributed_worker_index``.
            Side effect: ``opts['global_batch_size']`` is (re)written.
        is_training: Selects the train file and features (with start/end
            positions, sharded, shuffled and repeated) instead of the
            predict file.

    Returns:
        A dataset of decoded records batched to ``micro_batch_size`` with
        the remainder dropped.
    """
    seq_length = opts['seq_length']
    name_to_features = {
        "unique_ids": tf.FixedLenFeature([], tf.int64),
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
    }
    if is_training:
        name_to_features["start_positions"] = tf.FixedLenFeature([], tf.int64)
        name_to_features["end_positions"] = tf.FixedLenFeature([], tf.int64)

    micro_batch_size = opts['micro_batch_size']
    tfrecord_dir = opts['tfrecord_dir']
    # Written back into opts; it also becomes part of the eval cache name.
    opts['global_batch_size'] = opts['replicas'] * opts[
        'micro_batch_size'] * opts['gradient_accumulation_count']

    squad_tag = "SQuAD20" if opts["version_2_with_negative"] else "SQuAD11"
    name_prefix = f"{opts['seq_length']}_{opts['doc_stride']}_{opts['max_query_length']}"
    base_name_train = f"{name_prefix}_{squad_tag}"
    # The eval cache is padded to the global batch size, so that size is
    # part of its name.
    base_name_eval = f"{name_prefix}_{opts['global_batch_size']}_{squad_tag}"

    if is_training:
        filename = os.path.join(tfrecord_dir,
                                base_name_train + "_train.tfrecord")
        input_file = opts['train_file']
    else:
        filename = os.path.join(tfrecord_dir,
                                base_name_eval + "_eval.tfrecord")
        input_file = opts['predict_file']

    if not os.path.exists(filename):
        tf.logging.info(f'Preprocessing SQuAD input file: {filename}')

        cache_path = os.path.dirname(filename)
        # An empty dirname means the current directory: nothing to create.
        if cache_path and not os.path.exists(cache_path):
            tf.logging.info(f'Creating SQuAD cache in {cache_path}')
            # exist_ok: another worker may create the directory between the
            # check above and this call.
            os.makedirs(cache_path, exist_ok=True)

        examples = read_squad_examples(input_file=input_file,
                                       opts=opts,
                                       is_training=is_training)
        # Build the tokenizer before opening the writer so that a bad vocab
        # file does not leave an empty TFRecord behind.
        tokenizer = tokenization.FullTokenizer(
            vocab_file=opts['vocab_file'], do_lower_case=opts['do_lower_case'])

        if is_training:
            # Don't need to pad for repeated dataset
            padding_to = 1
        else:
            # Padding for pipeline depth
            padding_to = opts['global_batch_size']

        writer = FeatureWriter(filename=filename, is_training=is_training)
        try:
            # Features are streamed straight to the writer; they are not
            # kept in memory because nothing here reads them afterwards.
            num_of_features = convert_examples_to_features(
                examples=examples,
                tokenizer=tokenizer,
                max_seq_length=opts['seq_length'],
                doc_stride=opts['doc_stride'],
                max_query_length=opts['max_query_length'],
                is_training=is_training,
                output_fn=writer.process_feature,
                padding_to=padding_to)
        finally:
            # Close the writer even if conversion raised.
            writer.close()

        if is_training:
            split = 'train'
            metadatafile = os.path.join(
                tfrecord_dir, "train_" + base_name_train + ".metadata")
        else:
            split = 'eval'
            metadatafile = os.path.join(tfrecord_dir,
                                        "eval_" + base_name_eval + ".metadata")
        if not os.path.exists(metadatafile):
            tf.logging.info(
                f'Logging converted no. of SQuAD {split} features in {metadatafile}'
            )
            with open(metadatafile, 'w') as f:
                f.write(str(num_of_features) + '\n')

    d = tf.data.TFRecordDataset(filename)

    if is_training:
        if opts['distributed_worker_count'] > 1:
            d = d.shard(num_shards=opts['distributed_worker_count'],
                        index=opts['distributed_worker_index'])
        d = d.shuffle(buffer_size=100000, reshuffle_each_iteration=False)
        d = d.repeat()

    if opts['generated_data']:
        d = d.repeat()

    d = d.apply(
        map_and_batch(lambda record: _decode_record(record, name_to_features),
                      batch_size=micro_batch_size,
                      drop_remainder=True))

    return d
示例#13
0
    def build_model(self, A):
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        """ Input Image"""
        Image_data_class = ImageData(load_size=self.img_size, channels=self.img_ch, data_path=self.dataset_path, selected_attrs=self.selected_attrs, augment_flag=self.augment_flag)
        Image_data_class.preprocess()

        train_dataset_num = len(Image_data_class.train_dataset)
        test_dataset_num = len(Image_data_class.test_dataset)

        train_dataset = tf.data.Dataset.from_tensor_slices((Image_data_class.train_dataset, Image_data_class.train_dataset_label, Image_data_class.train_dataset_fix_label))
        test_dataset = tf.data.Dataset.from_tensor_slices((Image_data_class.test_dataset, Image_data_class.test_dataset_label, Image_data_class.test_dataset_fix_label))

        gpu_device = '/gpu:0'
        train_dataset = train_dataset.\
            apply(shuffle_and_repeat(train_dataset_num)).\
            apply(map_and_batch(Image_data_class.image_processing, self.batch_size, num_parallel_batches=8, drop_remainder=True)).\
            apply(prefetch_to_device(gpu_device, self.batch_size))

        test_dataset = test_dataset.\
            apply(shuffle_and_repeat(test_dataset_num)).\
            apply(map_and_batch(Image_data_class.image_processing, self.batch_size, num_parallel_batches=8, drop_remainder=True)).\
            apply(prefetch_to_device(gpu_device, self.batch_size))

        train_dataset_iterator = train_dataset.make_one_shot_iterator()
        test_dataset_iterator = test_dataset.make_one_shot_iterator()


        self.x_real, label_org, label_fix_list = train_dataset_iterator.get_next() # Input image / Original domain labels
        label_trg = tf.random_shuffle(label_org) # Target domain labels
        label_fix_list = tf.transpose(label_fix_list, perm=[1, 0, 2])

        self.x_test, test_label_org, test_label_fix_list = test_dataset_iterator.get_next()  # Input image / Original domain labels
        test_label_fix_list = tf.transpose(test_label_fix_list, perm=[1, 0, 2])

        self.custom_image = tf.placeholder(tf.float32, [1, self.img_size, self.img_size, self.img_ch], name='custom_image') # Custom Image
        custom_label_fix_list = tf.transpose(create_labels(self.custom_label, self.selected_attrs), perm=[1, 0, 2])

        """ Define Generator, Discriminator """
        # binary label transformation
        dist = tf_contrib.distributions.Categorical(probs=[0.25, 0.5, 0.25])
        hat_c = tf.cast(dist.sample([self.batch_size, self.c_dim]) - 1, dtype ='float32')

        x_fake, w_fake = self.generator(self.x_real, hat_c) # real a
        x_recon, w_recon = self.generator(x_fake, -hat_c, reuse=True) # real b

        real_logit, real_cls = self.discriminator(self.x_real)
        fake_logit, fake_cls = self.discriminator(x_fake, reuse=True)

        # warp cycle
        A_fake = tf_contrib.image.dense_image_warp(A, w_fake)
        A_cycle = tf_contrib.image.dense_image_warp(A_fake, w_recon)
        
        """ Define Loss """
        if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan' :
            GP = self.gradient_panalty(real=self.x_real, fake=x_fake)
        else :
            GP = 0

        g_adv_loss = generator_loss(loss_func=self.gan_type, fake=fake_logit)
        g_cls_loss = binary_label_loss(hat_c, fake_cls, self.c_dim)

        warp_cycle_loss = tf.reduce_mean((A_cycle - A)** 2, axis=[1,2,3])
        g_rec_loss = tf.reduce_mean(warp_cycle_loss)

        smooth_loss = tf.reduce_mean(total_variation(w_fake))
        d_adv_loss = discriminator_loss(loss_func=self.gan_type, real=real_logit, fake=fake_logit)
        d_cls_loss = classification_loss(logit=real_cls, label=label_org)

        self.d_loss = self.adv_weight * d_adv_loss + self.gp_weight * GP + self.cls_weight * d_cls_loss
        self.g_loss = self.adv_weight * g_adv_loss + self.cls_weight * g_cls_loss + self.rec_weight * g_rec_loss + self.smooth_weight * smooth_loss


        """ Result Image """
        self.x_fake_list = tf.map_fn(lambda x : self.generator(self.x_real, x, reuse=True)[0], label_fix_list, dtype=tf.float32)


        """ Test Image """
        self.x_test_fake_list = tf.map_fn(lambda x : self.generator(self.x_test, x, reuse=True)[0], test_label_fix_list, dtype=tf.float32)
        self.custom_fake_image = tf.map_fn(lambda x : self.generator(self.custom_image, x, reuse=True)[0], custom_label_fix_list, dtype=tf.float32)


        """ Training """
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        self.g_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.g_loss, var_list=G_vars)
        self.d_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999).minimize(self.d_loss, var_list=D_vars)


        """" Summary """
        self.Generator_loss = tf.summary.scalar("Generator_loss", self.g_loss)
        self.Discriminator_loss = tf.summary.scalar("Discriminator_loss", self.d_loss)

        self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
        self.g_cls_loss = tf.summary.scalar("g_cls_loss", g_cls_loss)
        self.g_rec_loss = tf.summary.scalar("g_rec_loss", g_rec_loss)

        self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss)
        self.d_cls_loss = tf.summary.scalar("d_cls_loss", d_cls_loss)

        self.g_summary_loss = tf.summary.merge([self.Generator_loss, self.g_adv_loss, self.g_cls_loss, self.g_rec_loss])
        self.d_summary_loss = tf.summary.merge([self.Discriminator_loss, self.d_adv_loss, self.d_cls_loss])
示例#14
0
    def build_model(self):
        """Assemble the whole graph: inputs, generator/discriminator, losses, optimizers, summaries.

        Takes no arguments and returns nothing. Results are stored on ``self``:
        the ``lr`` and ``custom_image`` placeholders, the ``x_real`` / ``x_test``
        batches, the label tensors, ``d_loss`` / ``g_loss``, the generated
        ``x_fake_list`` / ``x_test_fake_list`` / ``custom_fake_image``, both
        optimizer ops and the merged summary ops.
        """
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        # ---- Input images ----
        data = ImageData(load_size=self.img_size,
                         channels=self.img_ch,
                         data_path=self.dataset_path,
                         dataset_name=self.dataset_name,
                         selected_attrs=self.selected_attrs,
                         augment_flag=self.augment_flag)
        data.preprocess()

        n_train = len(data.train_dataset)
        n_test = len(data.test_dataset)

        train_slices = tf.data.Dataset.from_tensor_slices(
            (data.train_dataset, data.train_dataset_label,
             data.train_dataset_fix_label))
        test_slices = tf.data.Dataset.from_tensor_slices(
            (data.test_dataset, data.test_dataset_label,
             data.test_dataset_fix_label))

        def batch_on_gpu(dataset, shuffle_size):
            # shuffle + repeat -> parallel map + batch -> prefetch to '/gpu:0'
            dataset = dataset.apply(shuffle_and_repeat(shuffle_size))
            dataset = dataset.apply(
                map_and_batch(data.image_processing,
                              self.batch_size,
                              num_parallel_batches=8,
                              drop_remainder=True))
            return dataset.apply(prefetch_to_device('/gpu:0', self.batch_size))

        train_pipeline = batch_on_gpu(train_slices, n_train)
        test_pipeline = batch_on_gpu(test_slices, n_test)

        train_iter = train_pipeline.make_one_shot_iterator()
        test_iter = test_pipeline.make_one_shot_iterator()

        # Input image / original domain labels / fixed label list (train)
        self.x_real, self.label_org, raw_fix_labels = train_iter.get_next()
        # Target domain labels are a shuffle of the original ones
        self.label_trg = tf.random_shuffle(self.label_org)
        self.label_fix_list = tf.transpose(raw_fix_labels, perm=[1, 0, 2])

        label_org = self.label_org
        label_trg = self.label_trg
        label_fix_list = self.label_fix_list

        # Input image / original domain labels / fixed label list (test)
        self.x_test, _, raw_test_fix_labels = test_iter.get_next()
        test_label_fix_list = tf.transpose(raw_test_fix_labels,
                                           perm=[1, 0, 2])

        # Custom image fed by the caller
        self.custom_image = tf.placeholder(
            tf.float32, [1, self.img_size, self.img_size, self.img_ch],
            name='custom_image')
        custom_labels = create_labels(self.custom_label, self.selected_attrs,
                                      self.dataset_name)
        custom_label_fix_list = tf.transpose(custom_labels, perm=[1, 0, 2])

        # ---- Generator / Discriminator ----
        x_fake = self.generator(self.x_real, label_trg)  # real a
        x_recon = self.generator(x_fake, label_org, reuse=True)  # real b

        real_logit, real_cls = self.discriminator(self.x_real)
        fake_logit, fake_cls = self.discriminator(x_fake, reuse=True)

        # ---- Losses ----
        needs_gp = 'wgan' in self.gan_type or self.gan_type == 'dragan'
        GP = self.gradient_panalty(real=self.x_real,
                                   fake=x_fake) if needs_gp else 0

        g_adv_loss = generator_loss(loss_func=self.gan_type, fake=fake_logit)
        g_cls_loss = classification_loss(logit=fake_cls, label=label_trg)
        g_rec_loss = L1_loss(self.x_real, x_recon)

        d_adv_loss = GP_free_d_adv = discriminator_loss(
            loss_func=self.gan_type, real=real_logit, fake=fake_logit)
        d_adv_loss = GP_free_d_adv + GP
        d_cls_loss = classification_loss(logit=real_cls, label=label_org)

        self.d_loss = self.adv_weight * d_adv_loss + self.cls_weight * d_cls_loss
        self.g_loss = (self.adv_weight * g_adv_loss +
                       self.cls_weight * g_cls_loss +
                       self.rec_weight * g_rec_loss)

        # ---- Result / test images ----
        def translate(images, label_list):
            # Run the (reused) generator once per entry of label_list
            return tf.map_fn(
                lambda lbl: self.generator(images, lbl, reuse=True),
                label_list,
                dtype=tf.float32)

        self.x_fake_list = translate(self.x_real, label_fix_list)
        self.x_test_fake_list = translate(self.x_test, test_label_fix_list)
        self.custom_fake_image = translate(self.custom_image,
                                           custom_label_fix_list)

        # ---- Training ops ----
        trainable = tf.trainable_variables()
        G_vars = [v for v in trainable if 'generator' in v.name]
        D_vars = [v for v in trainable if 'discriminator' in v.name]

        def adam():
            return tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)

        self.g_optimizer = adam().minimize(self.g_loss, var_list=G_vars)
        self.d_optimizer = adam().minimize(self.d_loss, var_list=D_vars)

        # ---- Summaries ----
        self.Generator_loss = tf.summary.scalar("Generator_loss", self.g_loss)
        self.Discriminator_loss = tf.summary.scalar("Discriminator_loss",
                                                    self.d_loss)

        self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
        self.g_cls_loss = tf.summary.scalar("g_cls_loss", g_cls_loss)
        self.g_rec_loss = tf.summary.scalar("g_rec_loss", g_rec_loss)

        self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss)
        self.d_cls_loss = tf.summary.scalar("d_cls_loss", d_cls_loss)

        g_summaries = [
            self.Generator_loss, self.g_adv_loss, self.g_cls_loss,
            self.g_rec_loss
        ]
        d_summaries = [
            self.Discriminator_loss, self.d_adv_loss, self.d_cls_loss
        ]
        self.g_summary_loss = tf.summary.merge(g_summaries)
        self.d_summary_loss = tf.summary.merge(d_summaries)
示例#15
0
    def build_model(self):
        """Build either the multi-GPU training graph or the single-image test graph.

        Which graph is built depends on ``self.phase``. For ``'train'`` the
        method sets up the input pipeline, one tower per GPU, the losses,
        the optimizer ops (``G_optim`` / ``D_optim``), an
        ``ExponentialMovingAverage`` over the generator-side variables and
        the summaries. For any other phase it only creates the test
        placeholders and ``test_fake_img``. Nothing is returned.
        """
        if self.phase != 'train':
            # ---- Test graph ----
            self.ema = tf.train.ExponentialMovingAverage(decay=self.ema_decay)
            self.test_content_img = tf.placeholder(
                tf.float32, [1, self.img_height, self.img_width, self.img_ch])
            self.test_class_img = tf.placeholder(
                tf.float32,
                [self.K, self.img_height, self.img_width, self.img_ch])

            test_content_code = self.content_encoder(self.test_content_img)
            class_codes = self.class_encoder(self.test_class_img)
            # Average the class codes over the first axis of the class images
            test_style_class_code = tf.reduce_mean(class_codes,
                                                   axis=0,
                                                   keepdims=True)

            self.test_fake_img = self.generator(test_content_code,
                                                test_style_class_code)
            return

        # ---- Input images ----
        loader = Image_data(self.img_height, self.img_width, self.img_ch,
                            self.dataset_path, self.augment_flag)
        loader.preprocess()

        self.dataset_num = len(loader.image_list)

        pipeline = tf.data.Dataset.from_tensor_slices(
            (loader.image_list, loader.class_list))
        pipeline = pipeline.apply(shuffle_and_repeat(self.dataset_num))
        pipeline = pipeline.apply(
            map_and_batch(loader.image_processing,
                          batch_size=self.batch_size * self.gpu_num,
                          num_parallel_batches=16,
                          drop_remainder=True))
        pipeline = pipeline.apply(prefetch_to_device('/gpu:0', None))

        batch_iterator = pipeline.make_one_shot_iterator()

        # Two independent draws: one batch for content, one for style
        content_batch, content_labels = batch_iterator.get_next()
        style_batch, style_labels = batch_iterator.get_next()

        # One slice of every batch per GPU
        self.content_img = tf.split(content_batch,
                                    num_or_size_splits=self.gpu_num)
        self.content_class = tf.split(content_labels,
                                      num_or_size_splits=self.gpu_num)
        self.style_img = tf.split(style_batch,
                                  num_or_size_splits=self.gpu_num)
        self.style_class = tf.split(style_labels,
                                    num_or_size_splits=self.gpu_num)

        self.fake_img = []

        d_adv_losses = []
        g_adv_losses = []
        g_recon_losses = []
        g_feature_losses = []

        def spatial_mean(feature_map):
            # Mean over axis 2, then over axis 1
            return tf.reduce_mean(tf.reduce_mean(feature_map, axis=2), axis=1)

        for gpu_id in range(self.gpu_num):
            # Variables are created on the first tower and reused afterwards
            with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)), \
                    tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)):
                content_x = self.content_img[gpu_id]
                content_y = self.content_class[gpu_id]
                style_x = self.style_img[gpu_id]
                style_y = self.style_class[gpu_id]

                # ---- Generator / Discriminator ----
                content_code = self.content_encoder(content_x)
                style_class_code = self.class_encoder(style_x)
                content_class_code = self.class_encoder(content_x)

                fake_img = self.generator(content_code, style_class_code)
                recon_img = self.generator(content_code, content_class_code)

                real_logit, style_feature_map = self.discriminator(
                    style_x, style_y)
                fake_logit, fake_feature_map = self.discriminator(
                    fake_img, style_y)

                recon_logit, recon_feature_map = self.discriminator(
                    recon_img, content_y)
                _, content_feature_map = self.discriminator(
                    content_x, content_y)

                # ---- Losses ----
                d_adv_loss = self.adv_weight * discriminator_loss(
                    self.gan_type, real_logit, fake_logit, style_x)
                g_adv_loss = 0.5 * self.adv_weight * (
                    generator_loss(self.gan_type, fake_logit) +
                    generator_loss(self.gan_type, recon_logit))

                g_recon_loss = self.recon_weight * L1_loss(
                    content_x, recon_img)

                content_feature_map = spatial_mean(content_feature_map)
                recon_feature_map = spatial_mean(recon_feature_map)
                fake_feature_map = spatial_mean(fake_feature_map)
                style_feature_map = spatial_mean(style_feature_map)

                g_feature_loss = self.feature_weight * (
                    L1_loss(recon_feature_map, content_feature_map) +
                    L1_loss(fake_feature_map, style_feature_map))

                d_adv_losses.append(d_adv_loss)
                g_adv_losses.append(g_adv_loss)
                g_recon_losses.append(g_recon_loss)
                g_feature_losses.append(g_feature_loss)

                self.fake_img.append(fake_img)

        # Per-tower losses are averaged, then regularization is added
        self.g_loss = (tf.reduce_mean(g_adv_losses) +
                       tf.reduce_mean(g_recon_losses) +
                       tf.reduce_mean(g_feature_losses) +
                       regularization_loss('encoder') +
                       regularization_loss('generator'))

        self.d_loss = (tf.reduce_mean(d_adv_losses) +
                       regularization_loss('discriminator'))

        # ---- Training ops ----
        trainable = tf.trainable_variables()
        G_vars = [
            v for v in trainable
            if 'encoder' in v.name or 'generator' in v.name
        ]
        D_vars = [v for v in trainable if 'discriminator' in v.name]

        # Gradients are colocated with their ops unless exactly one GPU is
        # used (False is the default of minimize()).
        colocate = self.gpu_num != 1

        # Pytorch : decay=0.99, epsilon=1e-8
        prev_G_optim = tf.train.RMSPropOptimizer(
            self.lr, decay=0.99, epsilon=1e-8).minimize(
                self.g_loss,
                var_list=G_vars,
                colocate_gradients_with_ops=colocate)
        self.D_optim = tf.train.RMSPropOptimizer(
            self.lr, decay=0.99, epsilon=1e-8).minimize(
                self.d_loss,
                var_list=D_vars,
                colocate_gradients_with_ops=colocate)

        # The moving-average update runs only after the generator step
        self.ema = tf.train.ExponentialMovingAverage(decay=self.ema_decay)
        with tf.control_dependencies([prev_G_optim]):
            self.G_optim = self.ema.apply(G_vars)

        # ---- Summaries ----
        self.summary_g_loss = tf.summary.scalar("g_loss", self.g_loss)
        self.summary_d_loss = tf.summary.scalar("d_loss", self.d_loss)

        self.summary_g_adv_loss = tf.summary.scalar(
            "g_adv_loss", tf.reduce_mean(g_adv_losses))
        self.summary_g_recon_loss = tf.summary.scalar(
            "g_recon_loss", tf.reduce_mean(g_recon_losses))
        self.summary_g_feature_loss = tf.summary.scalar(
            "g_feature_loss", tf.reduce_mean(g_feature_losses))

        self.summary_merge_g_loss = tf.summary.merge([
            self.summary_g_loss,
            self.summary_g_adv_loss,
            self.summary_g_recon_loss,
            self.summary_g_feature_loss,
        ])
        self.summary_merge_d_loss = tf.summary.merge([self.summary_d_loss])
示例#16
0
def get_pretraining_dataset(opts, data_type, is_training=True, num_cpu_threads=4, use_static_mask=False):
    """Build the batched, endlessly repeating TFRecord dataset for pre-training.

    Args:
        opts: mapping read for ``'train_file'`` or ``'test_file'`` (a
            comma-separated list of glob patterns), ``'micro_batch_size'``,
            ``'seq_length'``, ``'max_predictions_per_seq'`` and, when training,
            ``'distributed_worker_count'`` (plus ``'distributed_worker_index'``
            and ``'seed'`` if that count is greater than 1).
        data_type: forwarded unchanged to ``_decode_record``.
        is_training: selects the training files and the parallel, shuffled
            reading path; otherwise files are read in order.
        num_cpu_threads: upper bound for the interleave cycle length and the
            number of parallel batches.
        use_static_mask: selects the feature schema in which masked tokens
            occupy the first ``max_predictions_per_seq`` positions.

    Returns:
        The dataset after ``map_and_batch`` (remainder dropped).
    """
    file_patterns = opts['train_file'] if is_training else opts['test_file']
    micro_batch_size = opts['micro_batch_size']
    max_seq_length = opts['seq_length']
    max_predictions_per_seq = opts['max_predictions_per_seq']

    input_files = []
    for pattern in file_patterns.split(","):
        input_files += tf.gfile.Glob(pattern)

    tf.logging.info("*** Input Files ***")
    for path in input_files:
        tf.logging.info("  %s", path)

    def int64_feature(length):
        return tf.FixedLenFeature([length], tf.int64)

    def float32_feature(length):
        return tf.FixedLenFeature([length], tf.float32)

    if not use_static_mask:
        # Default, the tokens have not been re-arranged.
        name_to_features = {
            "input_ids": int64_feature(max_seq_length),
            "input_mask": int64_feature(max_seq_length),
            "segment_ids": int64_feature(max_seq_length),
            "masked_lm_positions": int64_feature(max_predictions_per_seq),
            "masked_lm_ids": int64_feature(max_predictions_per_seq),
            "masked_lm_weights": float32_feature(max_predictions_per_seq),
            "next_sentence_labels": int64_feature(1),
        }
    else:
        # The masked tokens have been re-arranged to always be at the first
        # 'max_predictions_per_seq' positions.
        name_to_features = {
            "input_ids": int64_feature(max_seq_length),
            "input_position": int64_feature(max_seq_length),
            "segment_ids": int64_feature(max_seq_length),
            "mask_padding_index": int64_feature(1),
            "seq_padding_index": int64_feature(1),
            "masked_labels": int64_feature(max_predictions_per_seq),
            "masked_lm_weights": float32_feature(max_predictions_per_seq),
            "next_sentence_labels": int64_feature(1),
        }

    # For training, we want a lot of parallel reading and shuffling.
    # For eval, we want no shuffling and parallel reading doesn't matter.
    if not is_training:
        d = tf.data.TFRecordDataset(input_files).repeat()
    else:
        d = tf.data.Dataset.from_tensor_slices(tf.constant(input_files)).repeat()

        # `cycle_length` is the number of parallel files that get read.
        cycle_length = min(num_cpu_threads, len(input_files))

        # `sloppy` mode means that the interleaving is not exact. This adds
        # even more randomness to the training pipeline.
        interleave = parallel_interleave(tf.data.TFRecordDataset,
                                         sloppy=is_training,
                                         cycle_length=cycle_length)
        d = d.apply(interleave)

        # `buffer_size` should be set big enough to keep data shuffle sufficiently.
        worker_count = opts['distributed_worker_count']
        if worker_count > 1:
            # Each worker keeps its own shard and shuffles it with the given seed
            d = d.shard(num_shards=opts['distributed_worker_count'],
                        index=opts['distributed_worker_index'])
            d = d.shuffle(buffer_size=1000, seed=opts['seed'])
        else:
            d = d.shuffle(buffer_size=1000)

    def decode(record):
        return _decode_record(record, name_to_features, data_type)

    return d.apply(
        map_and_batch(decode,
                      batch_size=micro_batch_size,
                      num_parallel_batches=num_cpu_threads,
                      drop_remainder=True))
    def build_model(self):
        """Build the GAN graph: inputs, losses, optimizers, sampler and summaries.

        Real images come from a ``tf.data`` pipeline when
        ``self.custom_dataset`` is truthy and from a placeholder otherwise.
        Nothing is returned; ``inputs``, ``z``, ``d_loss``, ``g_loss``,
        ``d_optim``, ``g_optim``, ``fake_images``, ``d_sum`` and ``g_sum``
        are stored on ``self``.
        """
        batch = self.batch_size

        # ---- Real images ----
        if not self.custom_dataset:
            image_shape = [batch, self.img_size, self.img_size, self.c_dim]
            self.inputs = tf.placeholder(tf.float32,
                                         image_shape,
                                         name='real_images')
        else:
            processor = ImageData(self.img_size, self.c_dim)
            pipeline = tf.data.Dataset.from_tensor_slices(self.data)

            pipeline = pipeline.apply(shuffle_and_repeat(self.dataset_num))
            pipeline = pipeline.apply(
                map_and_batch(processor.image_processing,
                              batch,
                              num_parallel_batches=16,
                              drop_remainder=True))
            pipeline = pipeline.apply(prefetch_to_device('/gpu:0', batch))

            self.inputs = pipeline.make_one_shot_iterator().get_next()

        # ---- Noise ----
        self.z = tf.placeholder(tf.float32, [batch, 1, 1, self.z_dim],
                                name='z')

        # ---- Losses ----
        # Discriminator output for real images, then for generated ones
        real_logits = self.discriminator(self.inputs)

        generated = self.generator(self.z)
        fake_logits = self.discriminator(generated, reuse=True)

        # Gradient penalty only for the wgan variants and dragan
        use_gp = 'wgan' in self.gan_type or self.gan_type == 'dragan'
        GP = self.gradient_penalty(real=self.inputs,
                                   fake=generated) if use_gp else 0

        self.d_loss = discriminator_loss(
            self.gan_type, real=real_logits, fake=fake_logits) + GP
        self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

        # ---- Training ops ----
        # Split the trainable variables by the scope name they contain
        trainable = tf.trainable_variables()
        d_vars = [v for v in trainable if 'discriminator' in v.name]
        g_vars = [v for v in trainable if 'generator' in v.name]

        d_adam = tf.train.AdamOptimizer(self.d_learning_rate,
                                        beta1=self.beta1,
                                        beta2=self.beta2)
        self.d_optim = d_adam.minimize(self.d_loss, var_list=d_vars)

        g_adam = tf.train.AdamOptimizer(self.g_learning_rate,
                                        beta1=self.beta1,
                                        beta2=self.beta2)
        self.g_optim = g_adam.minimize(self.g_loss, var_list=g_vars)

        # ---- Sampling (generator reused with is_training=False) ----
        self.fake_images = self.generator(self.z,
                                          is_training=False,
                                          reuse=True)

        # ---- Summaries ----
        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
示例#18
0
    # Fragment of a larger function (its header is not part of this excerpt):
    # loads an export description, builds the training/test input pipelines
    # and creates the placeholders used further on.
    if not export_file:
        fail_for_missing_file()

    # NOTE(review): `load` is presumably json.load -- confirm against the imports.
    with open(export_file) as f:
        export_json = load(f)

    legend = export_json['legend']
    tfrecord_paths = export_json['tfrecord_paths']

    image_dim = 512
    # 20% of the number of TFRecord paths, rounded up
    test_set_size = math.ceil(0.20 * len(tfrecord_paths))

    # NOTE(review): skip()/take() act on dataset elements, while test_set_size
    # is derived from the number of paths -- confirm the split is as intended.
    # Training: everything after the first `test_set_size` elements, parsed,
    # shuffled (buffer 50) and repeated, then resized and batched by 8.
    training_dataset = (tf.data.TFRecordDataset(
        tfrecord_paths).skip(test_set_size).map(_parse_tfrecord).apply(
            shuffle_and_repeat(50)).apply(map_and_batch(_resize(image_dim),
                                                        8)))
    # Test: the first `test_set_size` elements, delivered as a single batch.
    test_dataset = (tf.data.TFRecordDataset(tfrecord_paths).take(
        test_set_size).map(_parse_tfrecord).apply(
            map_and_batch(_resize(image_dim), test_set_size)))

    training_iterator = training_dataset.make_one_shot_iterator()
    test_iterator = test_dataset.make_initializable_iterator()

    # A string-handle iterator: the dataset that feeds `images`/`labels` is
    # chosen at run time through the `handle` placeholder.
    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, training_dataset.output_types, training_dataset.output_shapes)
    images, labels = iterator.get_next()

    number_of_classes = len(legend) + 1  # +1 to include background class
    weight_decay = 0.0005
    # Scalar placeholder; its name suggests the dropout keep probability
    dropout_keep_prob = tf.placeholder(tf.float32, shape=[])
示例#19
0
    def build_model(self):
        """Build the graph: inputs, encoder/discriminator scores, losses, optimizers, sampler.

        The discriminator scores the difference of the encoder outputs for
        real and generated images, in both directions. Nothing is returned;
        ``inputs``, ``z``, ``d_loss``, ``g_loss``, ``d_optim``, ``g_optim``,
        ``fake_images``, ``d_sum`` and ``g_sum`` are stored on ``self``.
        """
        batch = self.batch_size

        # ---- Graph input: images ----
        if not self.custom_dataset:
            image_shape = [
                self.batch_size, self.output_height, self.output_height,
                self.c_dim
            ]
            self.inputs = tf.placeholder(tf.float32,
                                         image_shape,
                                         name='real_images')
        else:
            processor = ImageData(self.output_height, self.c_dim)
            pipeline = tf.data.Dataset.from_tensor_slices(self.data)

            pipeline = pipeline.apply(shuffle_and_repeat(self.dataset_num))
            pipeline = pipeline.apply(
                map_and_batch(processor.image_processing,
                              self.batch_size,
                              num_parallel_batches=8,
                              drop_remainder=True))
            pipeline = pipeline.apply(
                prefetch_to_device('/gpu:0', self.batch_size))

            self.inputs = pipeline.make_one_shot_iterator().get_next()

        # ---- Graph input: noise ----
        self.z = tf.placeholder(tf.float32, [batch, self.z_dim], name='z')

        # ---- Loss function ----
        x_fake = self.generator(self.z, is_training=True, reuse=False)
        real_code = self.encoder(self.inputs,
                                 is_training=True,
                                 reuse=False,
                                 sn=True)
        fake_code = self.encoder(x_fake,
                                 is_training=True,
                                 reuse=True,
                                 sn=True)

        # Code differences in both directions
        real_minus_fake = tf.subtract(real_code, fake_code)
        fake_minus_real = tf.subtract(fake_code, real_code)

        real_fake_score = self.discriminator(real_minus_fake,
                                             reuse=False,
                                             sn=True)
        fake_real_score = self.discriminator(fake_minus_real,
                                             reuse=True,
                                             sn=True)

        # Both losses receive both scores
        self.d_loss = discriminator_loss(self.loss_type,
                                         real=real_fake_score,
                                         fake=fake_real_score)
        self.g_loss = generator_loss(self.loss_type,
                                     real=real_fake_score,
                                     fake=fake_real_score)

        # ---- Training ----
        # The encoder variables are trained together with the discriminator
        trainable = tf.trainable_variables()
        d_vars = [
            v for v in trainable
            if 'discriminator' in v.name or 'encoder' in v.name
        ]
        g_vars = [v for v in trainable if 'generator' in v.name]

        # Both train ops depend on the ops collected under UPDATE_OPS
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            d_adam = tf.train.AdamOptimizer(self.learning_rate,
                                            beta1=self.beta1)
            self.d_optim = d_adam.minimize(self.d_loss, var_list=d_vars)

            # The generator uses five times the base learning rate
            g_adam = tf.train.AdamOptimizer(self.learning_rate * 5,
                                            beta1=self.beta1)
            self.g_optim = g_adam.minimize(self.g_loss, var_list=g_vars)

        # ---- Testing (generator reused with is_training=False) ----
        self.fake_images = self.generator(self.z,
                                          is_training=False,
                                          reuse=True)

        # ---- Summary ----
        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
示例#20
0
    def build_model(self):
        """Build the complete training and inference graph on ``self``.

        In order, this creates:

        * the learning-rate placeholder ``self.lr``;
        * one shuffled/batched/prefetched input pipeline per domain
          (``self.domain_A`` / ``self.domain_B``);
        * the encode -> decode -> re-encode -> decode (cycle) graph, plus the
          random-attribute branch used for latent regression;
        * image and content discriminator logits;
        * every generator / discriminator loss term and their weighted sums;
        * three Adam optimizers (generator, image discriminators, content
          discriminator with global-norm gradient clipping);
        * scalar summaries;
        * the single-image test graph and the guided-translation graph.

        Nothing is returned; all results are stored as attributes.
        """
        self.lr = tf.placeholder(tf.float32, name='lr')
        """ Input Image"""
        Image_Data_Class = ImageData(self.img_size, self.img_ch,
                                     self.augment_flag)

        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)

        # Both domains use the same pipeline: shuffle+repeat, fused
        # map+batch (incomplete batches dropped), then prefetch to the GPU.
        gpu_device = '/gpu:0'
        trainA = trainA.apply(shuffle_and_repeat(self.dataset_num)).apply(
            map_and_batch(Image_Data_Class.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True)).apply(
                              prefetch_to_device(gpu_device, self.batch_size))
        trainB = trainB.apply(shuffle_and_repeat(self.dataset_num)).apply(
            map_and_batch(Image_Data_Class.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True)).apply(
                              prefetch_to_device(gpu_device, self.batch_size))

        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()

        self.domain_A = trainA_iterator.get_next()
        self.domain_B = trainB_iterator.get_next()
        """ Define Encoder, Generator, Discriminator """
        # Attribute vector sampled from N(0, 1); shared by both random
        # decodes and used as the latent-regression target below.
        random_z = tf.random_normal(shape=[self.batch_size, self.n_z],
                                    mean=0.0,
                                    stddev=1.0,
                                    dtype=tf.float32)

        # encode
        content_a, attribute_a, mean_a, logvar_a = self.Encoder_A(
            self.domain_A)
        content_b, attribute_b, mean_b, logvar_b = self.Encoder_B(
            self.domain_B)

        # decode (fake, identity, random)
        # fake:   content of the other domain + own attribute
        fake_a = self.Decoder_A(content_B=content_b, attribute_A=attribute_a)
        fake_b = self.Decoder_B(content_A=content_a, attribute_B=attribute_b)

        # identity: own content + own attribute
        recon_a = self.Decoder_A(content_B=content_a,
                                 attribute_A=attribute_a,
                                 reuse=True)
        recon_b = self.Decoder_B(content_A=content_b,
                                 attribute_B=attribute_b,
                                 reuse=True)

        # random: content of the other domain + sampled attribute
        random_fake_a = self.Decoder_A(content_B=content_b,
                                       attribute_A=random_z,
                                       reuse=True)
        random_fake_b = self.Decoder_B(content_A=content_a,
                                       attribute_B=random_z,
                                       reuse=True)

        # encode & decode again for cycle-consistency
        content_fake_a, attribute_fake_a, _, _ = self.Encoder_A(fake_a,
                                                                reuse=True)
        content_fake_b, attribute_fake_b, _, _ = self.Encoder_B(fake_b,
                                                                reuse=True)

        cycle_a = self.Decoder_A(content_B=content_fake_b,
                                 attribute_A=attribute_fake_a,
                                 reuse=True)
        cycle_b = self.Decoder_B(content_A=content_fake_a,
                                 attribute_B=attribute_fake_b,
                                 reuse=True)

        # for latent regression
        _, attribute_fake_random_a, _, _ = self.Encoder_A(random_fake_a,
                                                          random_fake=True,
                                                          reuse=True)
        _, attribute_fake_random_b, _, _ = self.Encoder_B(random_fake_b,
                                                          random_fake=True,
                                                          reuse=True)

        # discriminate
        real_A_logit, real_B_logit = self.discriminate_real(
            self.domain_A, self.domain_B)
        fake_A_logit, fake_B_logit = self.discriminate_fake(fake_a, fake_b)
        random_fake_A_logit, random_fake_B_logit = self.discriminate_fake(
            random_fake_a, random_fake_b)
        content_A_logit, content_B_logit = self.discriminate_content(
            content_a, content_b)
        """ Define Loss """
        # Adversarial generator loss covers both the attribute-guided fake
        # and the random-attribute fake.
        g_adv_loss_a = generator_loss(self.gan_type,
                                      fake_A_logit) + generator_loss(
                                          self.gan_type, random_fake_A_logit)
        g_adv_loss_b = generator_loss(self.gan_type,
                                      fake_B_logit) + generator_loss(
                                          self.gan_type, random_fake_B_logit)

        g_con_loss_a = generator_loss(self.gan_type,
                                      content_A_logit,
                                      content=True)
        g_con_loss_b = generator_loss(self.gan_type,
                                      content_B_logit,
                                      content=True)

        g_cyc_loss_a = L1_loss(cycle_a, self.domain_A)
        g_cyc_loss_b = L1_loss(cycle_b, self.domain_B)

        g_rec_loss_a = L1_loss(recon_a, self.domain_A)
        g_rec_loss_b = L1_loss(recon_b, self.domain_B)

        g_latent_loss_a = L1_loss(attribute_fake_random_a, random_z)
        g_latent_loss_b = L1_loss(attribute_fake_random_b, random_z)

        # With self.concat the attribute is regularised through the KL term
        # on (mean, logvar); otherwise a plain L2 penalty is used instead.
        if self.concat:
            g_kl_loss_a = kl_loss(mean_a, logvar_a) + l2_regularize(content_a)
            g_kl_loss_b = kl_loss(mean_b, logvar_b) + l2_regularize(content_b)
        else:
            g_kl_loss_a = l2_regularize(attribute_a) + l2_regularize(content_a)
            g_kl_loss_b = l2_regularize(attribute_b) + l2_regularize(content_b)

        d_adv_loss_a = discriminator_loss(self.gan_type, real_A_logit,
                                          fake_A_logit)
        d_adv_loss_b = discriminator_loss(self.gan_type, real_B_logit,
                                          fake_B_logit)

        d_con_loss = discriminator_loss(self.gan_type, content_A_logit,
                                        content_B_logit)

        Generator_A_domain_loss = self.domain_adv_w * g_adv_loss_a
        Generator_A_content_loss = self.content_adv_w * g_con_loss_a
        # NOTE(review): generator A is charged with the B cycle term and
        # vice versa. The total generator loss is unaffected (both terms are
        # summed); only the per-generator summaries differ. Confirm intent.
        Generator_A_cycle_loss = self.cycle_w * g_cyc_loss_b
        Generator_A_recon_loss = self.recon_w * g_rec_loss_a
        Generator_A_latent_loss = self.latent_w * g_latent_loss_a
        Generator_A_kl_loss = self.kl_w * g_kl_loss_a

        Generator_A_loss = Generator_A_domain_loss + \
                           Generator_A_content_loss + \
                           Generator_A_cycle_loss + \
                           Generator_A_recon_loss + \
                           Generator_A_latent_loss + \
                           Generator_A_kl_loss

        Generator_B_domain_loss = self.domain_adv_w * g_adv_loss_b
        Generator_B_content_loss = self.content_adv_w * g_con_loss_b
        Generator_B_cycle_loss = self.cycle_w * g_cyc_loss_a
        Generator_B_recon_loss = self.recon_w * g_rec_loss_b
        Generator_B_latent_loss = self.latent_w * g_latent_loss_b
        Generator_B_kl_loss = self.kl_w * g_kl_loss_b

        Generator_B_loss = Generator_B_domain_loss + \
                           Generator_B_content_loss + \
                           Generator_B_cycle_loss + \
                           Generator_B_recon_loss + \
                           Generator_B_latent_loss + \
                           Generator_B_kl_loss

        Discriminator_A_loss = self.domain_adv_w * d_adv_loss_a
        Discriminator_B_loss = self.domain_adv_w * d_adv_loss_b
        Discriminator_content_loss = self.content_adv_w * d_con_loss

        self.Generator_loss = Generator_A_loss + Generator_B_loss
        self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss
        self.Discriminator_content_loss = Discriminator_content_loss
        """ Training """
        t_vars = tf.trainable_variables()
        # The filter previously tested only the misspelled substring
        # 'endoer', which would leave variables under an '...encoder...'
        # scope out of the generator update. The legacy spelling is still
        # accepted so nothing that was selected before is dropped.
        # NOTE(review): presumably Encoder_A/Encoder_B create their variables
        # under scopes containing 'encoder' - confirm against those methods.
        G_vars = [
            var for var in t_vars
            if 'encoder' in var.name or 'endoer' in var.name
            or 'generator' in var.name
        ]
        # Image discriminators only; the content discriminator is trained by
        # its own optimizer below.
        D_vars = [
            var for var in t_vars
            if 'discriminator' in var.name and 'content' not in var.name
        ]
        D_content_vars = [
            var for var in t_vars if 'content_discriminator' in var.name
        ]

        # Content-discriminator gradients are clipped to a global norm of 5
        # before being applied.
        grads, _ = tf.clip_by_global_norm(tf.gradients(
            self.Discriminator_content_loss, D_content_vars),
                                          clip_norm=5)

        self.G_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5, beta2=0.999).minimize(self.Generator_loss,
                                                      var_list=G_vars)
        self.D_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5, beta2=0.999).minimize(self.Discriminator_loss,
                                                      var_list=D_vars)
        self.D_content_optim = tf.train.AdamOptimizer(
            self.lr, beta1=0.5,
            beta2=0.999).apply_gradients(zip(grads, D_content_vars))
        """ Summary """
        self.lr_write = tf.summary.scalar("learning_rate", self.lr)

        self.all_G_loss = tf.summary.scalar("Generator_loss",
                                            self.Generator_loss)
        self.all_D_loss = tf.summary.scalar("Discriminator_loss",
                                            self.Discriminator_loss)

        self.G_A_loss = tf.summary.scalar("G_A_loss", Generator_A_loss)
        self.G_A_domain_loss = tf.summary.scalar("G_A_domain_loss",
                                                 Generator_A_domain_loss)
        self.G_A_content_loss = tf.summary.scalar("G_A_content_loss",
                                                  Generator_A_content_loss)
        self.G_A_cycle_loss = tf.summary.scalar("G_A_cycle_loss",
                                                Generator_A_cycle_loss)
        self.G_A_recon_loss = tf.summary.scalar("G_A_recon_loss",
                                                Generator_A_recon_loss)
        self.G_A_latent_loss = tf.summary.scalar("G_A_latent_loss",
                                                 Generator_A_latent_loss)
        self.G_A_kl_loss = tf.summary.scalar("G_A_kl_loss",
                                             Generator_A_kl_loss)

        self.G_B_loss = tf.summary.scalar("G_B_loss", Generator_B_loss)
        self.G_B_domain_loss = tf.summary.scalar("G_B_domain_loss",
                                                 Generator_B_domain_loss)
        self.G_B_content_loss = tf.summary.scalar("G_B_content_loss",
                                                  Generator_B_content_loss)
        self.G_B_cycle_loss = tf.summary.scalar("G_B_cycle_loss",
                                                Generator_B_cycle_loss)
        self.G_B_recon_loss = tf.summary.scalar("G_B_recon_loss",
                                                Generator_B_recon_loss)
        self.G_B_latent_loss = tf.summary.scalar("G_B_latent_loss",
                                                 Generator_B_latent_loss)
        self.G_B_kl_loss = tf.summary.scalar("G_B_kl_loss",
                                             Generator_B_kl_loss)

        self.D_A_loss = tf.summary.scalar("D_A_loss", Discriminator_A_loss)
        self.D_B_loss = tf.summary.scalar("D_B_loss", Discriminator_B_loss)

        self.G_loss = tf.summary.merge([
            self.G_A_loss, self.G_A_domain_loss, self.G_A_content_loss,
            self.G_A_cycle_loss, self.G_A_recon_loss, self.G_A_latent_loss,
            self.G_A_kl_loss, self.G_B_loss, self.G_B_domain_loss,
            self.G_B_content_loss, self.G_B_cycle_loss, self.G_B_recon_loss,
            self.G_B_latent_loss, self.G_B_kl_loss, self.all_G_loss
        ])

        self.D_loss = tf.summary.merge(
            [self.D_A_loss, self.D_B_loss, self.all_D_loss])

        self.D_content_loss = tf.summary.scalar(
            "Discriminator_content_loss", self.Discriminator_content_loss)
        """ Image """
        self.fake_A = fake_a
        self.fake_B = fake_b

        self.real_A = self.domain_A
        self.real_B = self.domain_B
        """ Test """
        # Single-image inference: the same placeholder is encoded by both
        # encoders and decoded with a freshly sampled attribute vector.
        self.test_image = tf.placeholder(
            tf.float32, [1, self.img_size, self.img_size, self.img_ch],
            name='test_image')
        self.test_random_z = tf.random_normal(shape=[1, self.n_z],
                                              mean=0.0,
                                              stddev=1.0,
                                              dtype=tf.float32)

        test_content_a, test_attribute_a, _, _ = self.Encoder_A(
            self.test_image, is_training=False, reuse=True)
        test_content_b, test_attribute_b, _, _ = self.Encoder_B(
            self.test_image, is_training=False, reuse=True)

        self.test_fake_A = self.Decoder_A(content_B=test_content_b,
                                          attribute_A=self.test_random_z,
                                          reuse=True)
        self.test_fake_B = self.Decoder_B(content_A=test_content_a,
                                          attribute_B=self.test_random_z,
                                          reuse=True)
        """ Guided Image Translation """
        # content_image goes through Encoder_A, attribute_image through
        # Encoder_B; the decoders then mix content and attribute across the
        # two inputs.
        self.content_image = tf.placeholder(
            tf.float32, [1, self.img_size, self.img_size, self.img_ch],
            name='content_image')
        self.attribute_image = tf.placeholder(
            tf.float32, [1, self.img_size, self.img_size, self.img_ch],
            name='guide_attribute_image')

        guide_content_A, guide_attribute_A, _, _ = self.Encoder_A(
            self.content_image, is_training=False, reuse=True)
        guide_content_B, guide_attribute_B, _, _ = self.Encoder_B(
            self.attribute_image, is_training=False, reuse=True)

        self.guide_fake_A = self.Decoder_A(content_B=guide_content_B,
                                           attribute_A=guide_attribute_A,
                                           reuse=True)
        self.guide_fake_B = self.Decoder_B(content_A=guide_content_A,
                                           attribute_B=guide_attribute_B,
                                           reuse=True)
示例#21
0
def _initialize_tf_dataset(file_names,
                           augment,
                           over_sample,
                           shuffle,
                           target_size,
                           normalize,
                           nb_workers=8,
                           batch_size=64,
                           shuffle_buffer_size=3000,
                           input_type="img",
                           nr_epochs=1):
    """Build a batched ``tf.data`` pipeline over TFRecord files.

    Each record must carry three features: ``label`` (int64 scalar),
    ``shape`` (raw int64 bytes) and ``image`` (raw uint8 bytes).

    Args:
        file_names: TFRecord path(s) handed to ``tf.data.TFRecordDataset``.
        augment: if true, apply random flips, rotation, shift (pad + crop)
            and zoom. The shift step crops to ``target_size + [3]``, so this
            path assumes 3-channel images.
        over_sample: accepted for interface compatibility; not used here.
        shuffle: if true, shuffle with ``shuffle_buffer_size`` and repeat
            ``nr_epochs`` times. Without it the data is neither shuffled nor
            repeated.
        target_size: spatial size of the stored image; any sequence, it is
            converted to a list.
        normalize: if true, cast to float32 and standardise with the fixed
            per-channel mean/std constants below.
        nb_workers: used both as ``num_parallel_batches`` and as the final
            prefetch buffer size.
        batch_size: batch size; incomplete final batches are dropped.
        shuffle_buffer_size: buffer size for ``shuffle_and_repeat``.
        input_type: when equal to ``".jpeg"`` a trailing channel axis of 3
            is appended to ``target_size`` for the reshape; otherwise
            ``target_size`` is used as the full shape.
        nr_epochs: repeat count, only honoured when ``shuffle`` is true.

    Returns:
        A dataset yielding ``({'shape': ..., 'image': ...}, label)`` batches.
    """
    import math

    import tensorflow as tf
    from tensorflow.contrib.data import map_and_batch
    from tensorflow.contrib.data import shuffle_and_repeat

    if not isinstance(target_size, list):
        target_size = list(target_size)

    # Full tensor shape of one decoded image, computed once instead of
    # repeating the input_type test at every reshape.
    if input_type == ".jpeg":
        image_shape = target_size + [3]
    else:
        image_shape = target_size

    with tf.name_scope('input_pipeline'):
        dataset = tf.data.TFRecordDataset(file_names)
        if shuffle:
            dataset = dataset.apply(
                shuffle_and_repeat(shuffle_buffer_size, nr_epochs))

        def _decode_and_augment_image(example_proto):
            """Parse one serialized example into (features, label)."""
            keys_to_features = {
                'label': tf.FixedLenFeature([], tf.int64),
                'shape': tf.FixedLenFeature([], tf.string),
                'image': tf.FixedLenFeature([], tf.string),
            }
            tfrecord_features = tf.parse_single_example(
                example_proto, keys_to_features)

            image = tf.decode_raw(tfrecord_features['image'], tf.uint8)
            shape = tf.decode_raw(tfrecord_features['shape'], tf.int64)
            image = tf.reshape(image, image_shape)
            label = tfrecord_features['label']

            if augment:
                image = tf.image.random_flip_left_right(image)
                image = tf.image.random_flip_up_down(image)

                # tf.contrib.image.rotate expects radians; the angle is drawn
                # in degrees, so convert before use.
                degrees = tf.random_uniform((), minval=-180, maxval=180)
                radians = degrees * (math.pi / 180.0)
                image = tf.contrib.image.rotate(image, radians)

                # Random shift of up to 5%: zero-pad symmetrically, then take
                # a random crop back at the original size.
                width_shift = tf.random_uniform((), minval=0, maxval=0.05)
                height_shift = tf.random_uniform((), minval=0, maxval=0.05)

                horizontal_pad = tf.cast(tf.ceil(width_shift * target_size[0]),
                                         tf.int32)
                vertical_pad = tf.cast(tf.ceil(height_shift * target_size[1]),
                                       tf.int32)

                # One (before, after) row per axis; the channel axis is not
                # padded.
                padding = tf.stack([
                    horizontal_pad, horizontal_pad, vertical_pad, vertical_pad,
                    tf.constant(0),
                    tf.constant(0)
                ])
                padding = tf.reshape(padding, (3, 2))

                image = tf.pad(image, padding)
                image = tf.random_crop(image, target_size + [3])

                # Random zoom of +/-10%: crop or pad to a scaled square, then
                # resize back to target_size.
                zoom = tf.random_uniform((), minval=-0.1, maxval=0.1)
                new_dim = tf.cast(tf.ceil((1 - zoom) * target_size[0]),
                                  dtype=tf.int32)

                image = tf.image.resize_image_with_crop_or_pad(
                    image, new_dim, new_dim)

                image = tf.image.resize_images(
                    image, target_size, method=tf.image.ResizeMethod.BILINEAR)

            if normalize:
                # Constants are expanded to shape [1, 1, 3] so they broadcast
                # over the two leading image axes.
                std = tf.constant(np.array(
                    [70.53946096, 51.71475228, 43.03428563]),
                                  dtype=tf.float32)
                std = tf.expand_dims(tf.expand_dims(std, axis=0), axis=0)

                mean = tf.constant(np.array(
                    [108.64628601, 75.86886597, 54.34005736]),
                                   dtype=tf.float32)
                mean = tf.expand_dims(tf.expand_dims(mean, axis=0), axis=0)

                image = (tf.cast(image, dtype=tf.float32) - mean) / std

            label = tf.reshape(label, [1])
            # Re-assert the static shape after the augmentation ops.
            image = tf.reshape(image, image_shape)

            return {'shape': shape, 'image': image}, label

        dataset = dataset \
            .apply(map_and_batch(_decode_and_augment_image, batch_size=batch_size, num_parallel_batches=nb_workers,
                                 drop_remainder=True)) \
            .prefetch(nb_workers)

        return dataset
    def build_model(self):
        """Build the graph for the current ``self.phase``.

        For ``'train'`` this creates the input pipeline, text encoder,
        generator, multi-scale discriminator logits, losses, optimizers and
        summaries. For any other phase it builds only the inference path and
        stores ``self.test_real_img`` / ``self.test_fake_img``.
        """
        """ Input Image"""
        data_helper = Image_data(self.img_height, self.img_width, self.img_ch, self.dataset_path, self.augment_flag)
        (train_captions, train_images,
         test_captions, test_images,
         idx_to_word, word_to_idx) = data_helper.preprocess()
        """
        train_captions: (8855, 10, 18), test_captions: (2933, 10, 18)
        train_images: (8855,), test_images: (2933,)
        idx_to_word : 5450 5450
        """

        def next_image_and_sentence(images, captions):
            # Pipeline: shuffle+repeat -> fused map+batch -> prefetch to GPU,
            # then pick one caption index at random for the whole batch.
            pairs = tf.data.Dataset.from_tensor_slices((images, captions))

            gpu_device = '/gpu:0'
            pairs = pairs.apply(shuffle_and_repeat(self.dataset_num))
            pairs = pairs.apply(
                map_and_batch(data_helper.image_processing, batch_size=self.batch_size,
                              num_parallel_batches=16, drop_remainder=True))
            pairs = pairs.apply(prefetch_to_device(gpu_device, None))

            batch_iter = pairs.make_one_shot_iterator()
            images_256, caption_batch = batch_iter.get_next()
            sentence_idx = tf.random_uniform(shape=[], minval=0, maxval=10, dtype=tf.int32)
            return images_256, tf.gather(caption_batch, sentence_idx, axis=1)

        vocab_size = len(idx_to_word)

        if self.phase != 'train':
            """ Test """
            self.dataset_num = len(test_captions)
            real_256, sentence = next_image_and_sentence(test_images, test_captions)

            word_emb, sent_emb, mask = self.rnn_encoder(
                sentence, n_words=vocab_size, embed_dim=self.embed_dim,
                drop_rate=0.5, n_hidden=128, n_layers=1, bidirectional=True,
                rnn_type='lstm', is_training=False)

            noise = tf.random_normal(shape=[self.batch_size, self.z_dim], mean=0.0, stddev=1.0)
            fake_imgs, _, _, _ = self.generator(noise, sent_emb, word_emb, mask, is_training=False)

            self.test_real_img = real_256
            self.test_fake_img = fake_imgs[2]
            return

        self.lr = tf.placeholder(tf.float32, name='learning_rate')
        self.dataset_num = len(train_images)
        real_256, sentence = next_image_and_sentence(train_images, train_captions)

        word_emb, sent_emb, mask = self.rnn_encoder(
            sentence, n_words=vocab_size, embed_dim=self.embed_dim,
            drop_rate=0.5, n_hidden=128, n_layers=1, bidirectional=True,
            rnn_type='lstm')

        noise = tf.random_normal(shape=[self.batch_size, self.z_dim], mean=0.0, stddev=1.0)
        fake_imgs, _, mu, logvar = self.generator(noise, sent_emb, word_emb, mask)

        # Real images at the three scales the discriminator consumes.
        real_64 = resize(real_256, target_size=[64, 64])
        real_128 = resize(real_256, target_size=[128, 128])
        fake_64, fake_128, fake_256 = fake_imgs[0], fake_imgs[1], fake_imgs[2]

        uncond_real_logits, cond_real_logits = self.discriminator([real_64, real_128, real_256], sent_emb)

        # "Wrong" pairs: images 0..B-2 are matched with the sentence of the
        # next sample (1..B-1), giving mismatched image/text inputs.
        last = self.batch_size - 1
        mismatched_imgs = [real_64[:last], real_128[:last], real_256[:last]]
        _, cond_wrong_logits = self.discriminator(mismatched_imgs, sent_emb[1:self.batch_size])

        uncond_fake_logits, cond_fake_logits = self.discriminator([fake_64, fake_128, fake_256], sent_emb)

        self.g_adv_loss, self.d_adv_loss = 0, 0
        for scale in range(3):
            g_uncond = generator_loss(self.gan_type, uncond_fake_logits[scale])
            g_cond = generator_loss(self.gan_type, cond_fake_logits[scale])
            self.g_adv_loss += self.adv_weight * (g_uncond + g_cond)

            uncond_real_loss, uncond_fake_loss = discriminator_loss(
                self.gan_type, uncond_real_logits[scale], uncond_fake_logits[scale])
            cond_real_loss, cond_fake_loss = discriminator_loss(
                self.gan_type, cond_real_logits[scale], cond_fake_logits[scale])
            _, cond_wrong_loss = discriminator_loss(self.gan_type, None, cond_wrong_logits[scale])

            # Mean of the two real terms plus mean of the three fake terms.
            real_term = (uncond_real_loss + cond_real_loss) / 2
            fake_term = (uncond_fake_loss + cond_fake_loss + cond_wrong_loss) / 3
            self.d_adv_loss += self.adv_weight * (real_term + fake_term)

        self.g_kl_loss = self.kl_weight * kl_loss(mu, logvar)

        self.g_loss = self.g_adv_loss + self.g_kl_loss
        self.d_loss = self.d_adv_loss

        self.real_img = real_256
        self.fake_img = fake_256

        """ Training """
        trainable = tf.trainable_variables()
        G_vars = [v for v in trainable if 'generator' in v.name]
        D_vars = [v for v in trainable if 'discriminator' in v.name]

        g_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
        self.g_optim = g_adam.minimize(self.g_loss, var_list=G_vars)
        d_adam = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
        self.d_optim = d_adam.minimize(self.d_loss, var_list=D_vars)

        """" Summary """
        self.summary_g_loss = tf.summary.scalar("g_loss", self.g_loss)
        self.summary_d_loss = tf.summary.scalar("d_loss", self.d_loss)

        self.summary_g_adv_loss = tf.summary.scalar("g_adv_loss", self.g_adv_loss)
        self.summary_g_kl_loss = tf.summary.scalar("g_kl_loss", self.g_kl_loss)

        self.summary_d_adv_loss = tf.summary.scalar("d_adv_loss", self.d_adv_loss)

        self.summary_merge_g_loss = tf.summary.merge(
            [self.summary_g_loss, self.summary_g_adv_loss, self.summary_g_kl_loss])
        self.summary_merge_d_loss = tf.summary.merge(
            [self.summary_d_loss, self.summary_d_adv_loss])
示例#23
0
def get_batch(datasets,
              preprocess_name,
              is_training,
              batch_size,
              num_gpu=1,
              seed=None):
    """Build the input pipeline on the CPU and return a one-shot iterator.

    Args:
        datasets: object exposing ``num_class``, ``source`` (record file
            names), ``feature`` (parsing spec), ``decoder`` and
            ``description``.
        preprocess_name: key passed to ``get_preprocess_fn``.
        is_training: enables file/record shuffling and is forwarded to the
            preprocessing function.
        batch_size: batch size, also used as the first prefetch buffer size.
        num_gpu: accepted but not used in this function.
        seed: random seed for both shuffles.

    Returns:
        A one-shot iterator over ``(image, one_hot_label)`` batches.
    """
    with tf.device('/cpu:0'):
        class_count = datasets.num_class
        record_files = datasets.source
        feature_spec = datasets.feature
        decode_image = datasets.decoder
        # Looked up but never used below; kept so a missing key still fails.
        _dataset_name = datasets.description['name']

        preprocess = get_preprocess_fn(preprocess_name)

        def map_func(record):
            # Parse, decode and preprocess one record; labels become one-hot.
            parsed = tf.parse_single_example(record, feature_spec)
            image = preprocess(decoder_output(parsed),
                               datasets,
                               is_training=is_training)
            return image, tf.one_hot(parsed['image/class/label'], class_count)

        def decoder_output(parsed):
            return decode_image(parsed['image/encoded'])

        records = tf.data.Dataset.from_tensor_slices(record_files)

        if is_training:
            # Shuffle the input files.
            records = records.shuffle(len(record_files),
                                      seed=seed,
                                      reshuffle_each_iteration=True)

        # Convert to individual records. Up to 10 files are read and
        # deserialized in parallel (fewer if there are fewer files): low
        # enough to avoid contention on small systems, high enough to benefit
        # from parallelism. Consider raising it on machines with many cores.
        interleave_width = min(10, len(record_files))
        records = records.apply(
            data.parallel_interleave(tf.data.TFRecordDataset,
                                     cycle_length=interleave_width))

        # Prefetch one batch worth of records to smooth out file-loading
        # latency ahead of shuffling and processing.
        records = records.prefetch(buffer_size=batch_size)

        records = (
            records.apply(data.shuffle_and_repeat(buffer_size=10000, seed=seed))
            if is_training else records.repeat())

        # Parse raw records into images and labels. Per the original author,
        # num_parallel_batches > 1 gave no throughput gain because batch_size
        # is almost always much larger than the number of CPU cores.
        records = records.apply(
            data.map_and_batch(map_func=map_func,
                               batch_size=batch_size,
                               num_parallel_batches=1))

        # Anything between the final prefetch and get_next() runs
        # synchronously at run time, so prefetch again to push all of the
        # work above into the background.
        records = records.prefetch(buffer_size=32)
        return records.make_one_shot_iterator()
# Single op that evaluates every summary registered in the graph so far.
merged_summary_op = tf.summary.merge_all()
# Checkpoint saver; max_to_keep=3 limits how many recent checkpoints are kept.
saver = tf.train.Saver(max_to_keep=3)

# Source of serialized examples; decoded by _parse further down.
dataset = tf.data.TFRecordDataset('Dataset/dataset_1.tfrecord')
def _parse(example_proto):
    """Decode one serialized example into a float32 tensor of shape [CLASS_NUM, 600].

    The 'roll' feature is read as raw bytes, expanded to individual bits with
    `np.unpackbits`, and rescaled from {0, 1} to {-1, 1}.
    """
    schema = {'roll' : tf.FixedLenFeature([], tf.string)}
    example = tf.parse_single_example(example_proto, schema)
    packed_bytes = tf.decode_raw(example['roll'], tf.uint8)
    # py_func output carries no static shape; the reshape below restores it.
    bits = tf.py_func(func=np.unpackbits, inp=[packed_bytes], Tout=tf.uint8)
    roll = tf.reshape(tf.cast(bits, tf.float32), [CLASS_NUM, 600])
    return roll * 2 - 1

# Shuffle through a 60000-element buffer for 3 passes over the data, then
# decode each record with `_parse` and group the results into batches.
dataset = (
    dataset
    .apply(data.shuffle_and_repeat(buffer_size=60000, count=3))
    .apply(data.map_and_batch(_parse,
                              batch_size=batch_size,
                              num_parallel_batches=2))
)
# The prefetch buffer is counted in elements of the batched dataset.
iterator = dataset.prefetch(batch_size).make_one_shot_iterator()
real_input_next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # NOTE(review): checkpoints written by tf.train.Saver are usually several
    # files that share `model_path` as a prefix, in which case a plain
    # os.path.exists() on the prefix is False -- confirm what `model_path` is.
    if os.path.exists(model_path):
        # Saver.restore() has no return value, so report the path that was
        # restored instead of printing its (None) result.
        saver.restore(sess, model_path)
        print("Model restored from file: %s" % model_path)

    for i in range(total_batch):
        # Pull the next batch from the input pipeline as a NumPy array.
        tfdata = sess.run(real_input_next_element)
        reshape_tfdata = tfdata.reshape([-1, CLASS_NUM, 600])
        # Keep only the first INPUT_LENGTH columns of every row, then flatten
        # each example to a single vector.
        cuted_tfdata = reshape_tfdata[:, :, :INPUT_LENGTH]
        reshape_cuted_tfdata = cuted_tfdata.reshape([-1, CLASS_NUM*INPUT_LENGTH])
示例#25
0
    def build_model(self):
        """Build the graph: per-resolution training ops, or a test-time sampler.

        When `self.phase` is 'train', every resolution from `self.start_res`
        onwards gets its own step counter, `alpha` variable, per-GPU
        generator/discriminator towers, Adam optimizers and scalar summaries;
        the results are stored in dicts keyed by resolution.  For any other
        phase only `self.fake_images` is created.
        """
        if self.phase != 'train':
            # Testing: sample at `self.img_size` with alpha pinned to 0.
            test_z = tf.random_normal(shape=[self.batch_size, self.z_dim])
            alpha = tf.constant(0.0, dtype=tf.float32, shape=[])
            self.fake_images = self.generator(test_z,
                                              alpha=alpha,
                                              target_img_size=self.img_size,
                                              is_training=False)
            return

        # Per-resolution outputs, all keyed by `res`.
        self.d_loss_per_res = {}
        self.g_loss_per_res = {}
        self.generator_optim = {}
        self.discriminator_optim = {}
        self.alpha_summary_per_res = {}
        self.d_summary_per_res = {}
        self.g_summary_per_res = {}
        self.train_fake_images = {}

        first_idx = self.resolutions.index(self.start_res)
        for res in self.resolutions[first_idx:]:
            tower_g_losses = []
            tower_d_losses = []
            tower_fake_images = []

            batch_size = self.batch_sizes.get(res, self.batch_size_base)
            global_step = tf.get_variable(
                'global_step_{}'.format(res),
                shape=[],
                dtype=tf.float32,
                initializer=tf.initializers.zeros(),
                trainable=False,
                aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER)
            alpha_const, zero_constant = get_alpha_const(
                self.iteration // 2, batch_size * self.gpu_num, global_step)

            # Smooth transition variable: starts at 1 and follows
            # `alpha_const` when this resolution trains with a transition,
            # otherwise starts at 0 and is assigned `zero_constant`.
            do_train_trans = self.train_with_trans[res]
            if do_train_trans:
                alpha_init = tf.initializers.ones()
            else:
                alpha_init = tf.initializers.zeros()

            alpha = tf.get_variable(
                'alpha_{}'.format(res),
                shape=[],
                dtype=tf.float32,
                initializer=alpha_init,
                trainable=False,
                aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER)

            alpha_target = alpha_const if do_train_trans else zero_constant
            alpha_assign_op = tf.assign(alpha, alpha_target)

            # Ops created inside this context depend on the alpha assignment.
            with tf.control_dependencies([alpha_assign_op]):
                for gpu_id in range(self.gpu_num):
                    tower_device = tf.DeviceSpec(device_type="GPU",
                                                 device_index=gpu_id)
                    # Tower 0 creates the variables; later towers reuse them.
                    share_vars = gpu_id > 0
                    with tf.device(tower_device), \
                            tf.variable_scope(tf.get_variable_scope(),
                                              reuse=share_vars):
                        # images
                        gpu_device = '/gpu:{}'.format(gpu_id)
                        image_class = ImageData(res)
                        inputs = tf.data.Dataset.from_tensor_slices(
                            self.dataset)
                        inputs = inputs.apply(
                            shuffle_and_repeat(self.dataset_num))
                        inputs = inputs.apply(
                            map_and_batch(image_class.image_processing,
                                          batch_size,
                                          num_parallel_batches=16,
                                          drop_remainder=True))
                        # When using dataset.prefetch, use buffer_size=None
                        # to let it detect optimal buffer size
                        inputs = inputs.apply(
                            prefetch_to_device(gpu_device, None))

                        real_img = inputs.make_one_shot_iterator().get_next()
                        z = tf.random_normal(shape=[batch_size, self.z_dim])

                        fake_img = self.generator(z, alpha, res)
                        real_img = smooth_crossfade(real_img, alpha)

                        real_logit = self.discriminator(real_img, alpha, res)
                        fake_logit = self.discriminator(fake_img, alpha, res)

                        # compute loss
                        d_loss, g_loss = compute_loss(real_img, real_logit,
                                                      fake_logit)

                        tower_d_losses.append(d_loss)
                        tower_g_losses.append(g_loss)
                        tower_fake_images.append(fake_img)

            print("Create graph for {} resolution".format(res))

            # prepare appropriate training vars
            d_vars, g_vars = filter_trainable_variables(res)

            # Average the per-GPU losses into one loss per network.
            d_loss = tf.reduce_mean(tower_d_losses)
            g_loss = tf.reduce_mean(tower_g_losses)

            d_lr = self.d_learning_rates.get(res, self.learning_rate_base)
            g_lr = self.g_learning_rates.get(res, self.learning_rate_base)

            # Gradient colocation is switched on only with more than one GPU.
            colocate_grad = not (self.gpu_num == 1)

            d_optim = tf.train.AdamOptimizer(
                d_lr, beta1=0.0, beta2=0.99, epsilon=1e-8).minimize(
                    d_loss,
                    var_list=d_vars,
                    colocate_gradients_with_ops=colocate_grad)

            # Only the generator step advances this resolution's global_step.
            g_optim = tf.train.AdamOptimizer(
                g_lr, beta1=0.0, beta2=0.99, epsilon=1e-8).minimize(
                    g_loss,
                    var_list=g_vars,
                    global_step=global_step,
                    colocate_gradients_with_ops=colocate_grad)

            self.discriminator_optim[res] = d_optim
            self.generator_optim[res] = g_optim

            self.d_loss_per_res[res] = d_loss
            self.g_loss_per_res[res] = g_loss

            self.train_fake_images[res] = tf.concat(tower_fake_images, axis=0)

            # Summary
            self.alpha_summary_per_res[res] = tf.summary.scalar(
                "alpha_{}".format(res), alpha)
            self.d_summary_per_res[res] = tf.summary.scalar(
                "d_loss_{}".format(res), d_loss)
            self.g_summary_per_res[res] = tf.summary.scalar(
                "g_loss_{}".format(res), g_loss)
    def build_model(self):
        """Assemble the graph: input pipelines, G/D losses, optimizers, summaries.

        Also creates the placeholders and generator outputs used at test time
        (random-style and guided translation of a single segmentation map).
        """
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        # ---- Input images ----
        img_class = Image_data(self.img_height, self.img_width, self.img_ch,
                               self.segmap_ch, self.dataset_path,
                               self.augment_flag)
        img_class.preprocess()

        self.dataset_num = len(img_class.image)
        self.test_dataset_num = len(img_class.segmap_test)

        img_and_segmap = tf.data.Dataset.from_tensor_slices(
            (img_class.image, img_class.segmap))
        segmap_test = tf.data.Dataset.from_tensor_slices(img_class.segmap_test)

        gpu_device = '/gpu:0'

        # Training pairs: shuffle/repeat, preprocess + batch, prefetch to GPU.
        img_and_segmap = img_and_segmap.apply(
            shuffle_and_repeat(self.dataset_num))
        img_and_segmap = img_and_segmap.apply(
            map_and_batch(img_class.image_processing,
                          self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        img_and_segmap = img_and_segmap.apply(
            prefetch_to_device(gpu_device, self.batch_size))

        # Test segmaps go through the same three stages; the shuffle buffer
        # is sized with the training-set count, as in the original.
        segmap_test = segmap_test.apply(shuffle_and_repeat(self.dataset_num))
        segmap_test = segmap_test.apply(
            map_and_batch(img_class.test_image_processing,
                          batch_size=self.batch_size,
                          num_parallel_batches=16,
                          drop_remainder=True))
        segmap_test = segmap_test.apply(
            prefetch_to_device(gpu_device, self.batch_size))

        img_and_segmap_iterator = img_and_segmap.make_one_shot_iterator()
        segmap_test_iterator = segmap_test.make_one_shot_iterator()

        train_batch = img_and_segmap_iterator.get_next()
        self.real_x, self.real_x_segmap, self.real_x_segmap_onehot = train_batch
        test_batch = segmap_test_iterator.get_next()
        self.real_x_segmap_test, self.real_x_segmap_test_onehot = test_batch

        # ---- Define generator, discriminator ----
        fake_x, x_mean, x_var = self.image_translate(
            segmap_img=self.real_x_segmap_onehot, x_img=self.real_x)
        real_logit, fake_logit = self.image_discriminate(
            segmap_img=self.real_x_segmap_onehot,
            real_img=self.real_x,
            fake_img=fake_x)

        # A gradient penalty is only added for wgan-family and dragan losses.
        if 'wgan' in self.gan_type or self.gan_type == 'dragan':
            grad_penalty = self.gradient_penalty(
                real=self.real_x,
                segmap=self.real_x_segmap_onehot,
                fake=fake_x)
        else:
            grad_penalty = 0

        # ---- Define loss ----
        g_adv_loss = self.adv_weight * generator_loss(self.gan_type,
                                                      fake_logit)
        g_kl_loss = self.kl_weight * kl_loss(x_mean, x_var)
        g_vgg_loss = self.vgg_weight * VGGLoss()(self.real_x, fake_x)
        g_feature_loss = self.feature_weight * feature_loss(
            real_logit, fake_logit)
        g_reg_loss = regularization_loss('generator') + regularization_loss(
            'encoder')

        d_adv_loss = self.adv_weight * (
            discriminator_loss(self.gan_type, real_logit, fake_logit)
            + grad_penalty)
        d_reg_loss = regularization_loss('discriminator')

        self.g_loss = g_adv_loss + g_kl_loss + g_vgg_loss + g_feature_loss + g_reg_loss
        self.d_loss = d_adv_loss + d_reg_loss

        # ---- Result image ----
        self.fake_x = fake_x
        self.random_fake_x, _, _ = self.image_translate(
            segmap_img=self.real_x_segmap_onehot,
            random_style=True,
            reuse=True)

        # ---- Test ----
        # Single-image placeholders; the segmap has one channel per entry of
        # `img_class.color_value_dict`.
        segmap_shape = [
            1, self.img_height, self.img_width,
            len(img_class.color_value_dict)
        ]
        self.test_segmap_image = tf.placeholder(tf.float32, segmap_shape)
        self.random_test_fake_x, _, _ = self.image_translate(
            segmap_img=self.test_segmap_image, random_style=True, reuse=True)

        guide_shape = [1, self.img_height, self.img_width, self.img_ch]
        self.test_guide_image = tf.placeholder(tf.float32, guide_shape)
        self.guide_test_fake_x, _, _ = self.image_translate(
            segmap_img=self.test_segmap_image,
            x_img=self.test_guide_image,
            reuse=True)

        # ---- Training ----
        # Split trainable variables by name: encoder + generator vs discriminator.
        t_vars = tf.trainable_variables()
        G_vars = [
            var for var in t_vars
            if any(tag in var.name for tag in ('encoder', 'generator'))
        ]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        if self.TTUR:
            # Two-rate setup: fixed betas, G at half and D at double `self.lr`.
            beta1, beta2 = 0.0, 0.9
            g_lr = self.lr / 2
            d_lr = self.lr * 2
        else:
            beta1, beta2 = self.beta1, self.beta2
            g_lr = self.lr
            d_lr = self.lr

        g_optimizer = tf.train.AdamOptimizer(g_lr, beta1=beta1, beta2=beta2)
        self.G_optim = g_optimizer.minimize(self.g_loss, var_list=G_vars)
        d_optimizer = tf.train.AdamOptimizer(d_lr, beta1=beta1, beta2=beta2)
        self.D_optim = d_optimizer.minimize(self.d_loss, var_list=D_vars)

        # ---- Summary ----
        scalar = tf.summary.scalar
        self.summary_g_loss = scalar("g_loss", self.g_loss)
        self.summary_d_loss = scalar("d_loss", self.d_loss)

        self.summary_g_adv_loss = scalar("g_adv_loss", g_adv_loss)
        self.summary_g_kl_loss = scalar("g_kl_loss", g_kl_loss)
        self.summary_g_vgg_loss = scalar("g_vgg_loss", g_vgg_loss)
        self.summary_g_feature_loss = scalar("g_feature_loss", g_feature_loss)

        self.G_loss = tf.summary.merge([
            self.summary_g_loss,
            self.summary_g_adv_loss,
            self.summary_g_kl_loss,
            self.summary_g_vgg_loss,
            self.summary_g_feature_loss,
        ])
        self.D_loss = tf.summary.merge([self.summary_d_loss])
示例#27
0
文件: model.py 项目: WYu-Feng/ISMC
    def build_model(self):
        """Build the TF graph for the current phase.

        For phase 'train': creates the input pipelines for domains A and B,
        the a2b / b2a generator outputs, the discriminator logits, the
        adversarial / cycle / identity / CAM loss terms and two Adam
        optimizers.  For any other phase: creates single-image placeholders
        and the two translation outputs.

        NOTE(review): the train branch reads `imgba` / `imgab`, which are not
        defined anywhere in this method (see the note at those lines).
        """
        if self.phase == 'train':
            # Learning rate (step size), fed at run time.
            self.lr = tf.placeholder(tf.float32, name='learning_rate')
            """ Input Image"""
            Image_Data_Class = ImageData(self.img_size, self.img_ch,
                                         self.augment_flag)
            trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
            trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
            # One dataset element per entry of the input lists (original
            # author's note: the entries are image paths).

            gpu_device = '/gpu:0'
            trainA = trainA.apply(shuffle_and_repeat(100)).apply(
                map_and_batch(Image_Data_Class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True)).apply(
                                  prefetch_to_device(gpu_device, None))
            trainB = trainB.apply(shuffle_and_repeat(100)).apply(
                map_and_batch(Image_Data_Class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True)).apply(
                                  prefetch_to_device(gpu_device, None))
            # Both pipelines: shuffle (buffer of 100) and repeat, then map and
            # batch, then prefetch to the GPU device.
            # map_and_batch arguments, as used above:
            #   1: function mapping one nested structure of tensors to another
            #   2: number of consecutive elements combined into one batch
            #   3: number of batches created in parallel; higher values help
            #      with stragglers but can increase contention when CPU is
            #      scarce
            #   4: whether to drop the last batch if it is smaller than
            #      requested
            # Original author's note: the result is a randomly augmented
            # 32*64*64*3 array -- presumably batch_size=32, img_size=64,
            # img_ch=3; TODO confirm.
            trainA_iterator = trainA.make_one_shot_iterator()
            trainB_iterator = trainB.make_one_shot_iterator()

            self.domain_A = trainA_iterator.get_next()
            self.domain_B = trainB_iterator.get_next()
            """ Define Generator, Discriminator """
            x_ab, cam_ab = self.generate_a2b(self.domain_A)  # real a
            # Original author's notes: self.domain_A holds cartoon images;
            # x_ab is domain_A after down-sampling followed by up-sampling
            # (not verifiable from this block).
            x_ba, cam_ba = self.generate_b2a(self.domain_B)  # real b
            # Original author's note: generate_a2b and generate_b2a use two
            # separate sets of parameters.

            x_aba, _ = self.generate_b2a(x_ab, reuse=True)  # real b
            x_bab, _ = self.generate_a2b(x_ba, reuse=True)  # real a
            # With the generators called again with reuse=True, the output of
            # one generator is fed through the other.
            # Original author's notes: generate_b2a tries to turn real-person
            # photos into cartoons, generate_a2b tries to turn cartoons into
            # real-person photos; the two can be viewed as inverse transforms.

            x_aa, cam_aa = self.generate_b2a(self.domain_A,
                                             reuse=True)  # fake b
            x_bb, cam_bb = self.generate_a2b(self.domain_B,
                                             reuse=True)  # fake a
            # Each generator is also applied (reuse=True) to images already
            # in its target domain.
            # Original author's note: cartoon -> cartoon; meant to ensure that
            # color regions stay unchanged through generate_b2a/generate_a2b.
            real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit = self.discriminate_real(
                self.domain_A, self.domain_B)
            # Discriminator outputs for the real domain-A and domain-B images.
            fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit = self.discriminate_fake(
                x_ba, x_ab)
            # Discriminator outputs for the generated images x_ba and x_ab.
            # Original author's note: the outputs are a 32*8*8*1 tensor from
            # the convolutions and a (32, 2) concatenation of pooled results
            # -- TODO confirm against discriminate_fake.
            """ Define Loss """
            if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
                GP_A, GP_CAM_A = self.gradient_panalty(real=self.domain_A,
                                                       fake=x_ba,
                                                       scope="discriminator_A")
                GP_B, GP_CAM_B = self.gradient_panalty(real=self.domain_B,
                                                       fake=x_ab,
                                                       scope="discriminator_B")
            else:
                GP_A, GP_CAM_A = 0, 0
                GP_B, GP_CAM_B = 0, 0

            # Below: loss terms for the discriminators (discriminate_fake /
            # discriminate_real outputs) and for the generators.
            # Section 1 -- adversarial loss.  Original author's note: a
            # least-squares loss is used (depends on gan_type; TODO confirm).
            '''
            一、对抗损失(T)  
                用的最小二乘损失
            '''
            # Original author's notes (not verifiable from this block): the
            # generator loss is the sum of squared differences between
            # fake_A_logit[0], fake_A_logit[1] and 1, plus the same for
            # fake_A_cam_logit[0], fake_A_cam_logit[1].  From the generator's
            # point of view, real images are judged 0 by the discriminator
            # and generated images (1 - fake_*) should also be judged as
            # close to 0 as possible.
            G_ad_loss_A = (generator_loss(self.gan_type, fake_A_logit) +
                           generator_loss(self.gan_type, fake_A_cam_logit))
            # Generator adversarial loss on the fake domain-A images.
            G_ad_loss_B = (generator_loss(self.gan_type, fake_B_logit) +
                           generator_loss(self.gan_type, fake_B_cam_logit))
            # Generator adversarial loss on the fake domain-B images.

            D_ad_loss_A = (
                discriminator_loss(self.gan_type, real_A_logit, fake_A_logit) +
                discriminator_loss(self.gan_type, real_A_cam_logit,
                                   fake_A_cam_logit) + GP_A + GP_CAM_A)
            # Discriminator loss for domain A.  Original author's note:
            # discriminator_loss pushes real images towards 1 and fake images
            # towards 0.
            D_ad_loss_B = (
                discriminator_loss(self.gan_type, real_B_logit, fake_B_logit) +
                discriminator_loss(self.gan_type, real_B_cam_logit,
                                   fake_B_cam_logit) + GP_B + GP_CAM_B)
            # Discriminator loss for domain B.
            # Section 2 -- cycle loss.
            '''
            二、Cycle 损失(T)
                 the image should be successfully translated back to the original domain
            '''
            reconstruction_A = L1_loss(x_aba, self.domain_A)  # reconstruction
            # L1 distance between the A -> B -> A round trip and the original A.
            reconstruction_B = L1_loss(x_bab, self.domain_B)  # reconstruction
            # L1 distance between the B -> A -> B round trip and the original B.
            # Section 3 -- identity loss.  Original author's note: ensures
            # identity information is not lost when translating between A and B.
            '''
            三、Identity 损失
                确保在A B相互变化时,身份信息不丢失
            '''
            identity_A = L1_loss(x_aa, self.domain_A)
            identity_B = L1_loss(x_bb, self.domain_B)
            # L1 distance between each same-domain translation and its input.
            # Section 4 -- CAM loss.  Original author's notes: an auxiliary
            # classifier tells G and D where to focus the transformation; the
            # heat map should be pronounced for A -> B and absent for A -> A.
            '''
            四、CAM 损失
                利用辅助分类器,使得G和D知道在哪里进行强化变换
                在A变B时,热力图应该有明显显示
                在A变A时,热力图应该没有显示
            '''
            cam_A = cam_loss(source=cam_ba, non_source=cam_aa)
            # cam_ba comes from generate_b2a(domain_B), cam_aa from
            # generate_b2a(domain_A).  Original author's note: these are the
            # fully connected outputs (computed twice, with different pooling).
            cam_B = cam_loss(source=cam_ab, non_source=cam_bb)

            # TODO(original author): how were the initial loss weights chosen?
            Generator_A_gan = self.adv_weight * G_ad_loss_A
            # Original author's note: weight value 1.
            Generator_A_cycle = self.cycle_weight * reconstruction_B
            # Original author's note: weight value 10.  reconstruction_B
            # compares generate_a2b(generate_b2a(domain_B)) with domain_B.
            Generator_A_identity = self.identity_weight * identity_A
            # Original author's note: weight value 10.
            Generator_A_cam = self.cam_weight * cam_A
            # Original author's note: weight value 1000.

            Generator_B_gan = self.adv_weight * G_ad_loss_B
            Generator_B_cycle = self.cycle_weight * reconstruction_A
            Generator_B_identity = self.identity_weight * identity_B
            Generator_B_cam = self.cam_weight * cam_B
            # Leftover debug progress marker written to stdout.
            print('ok 5')

            Generator_A_loss = Generator_A_gan + Generator_A_cycle + Generator_A_identity + Generator_A_cam
            # Sum of all weighted generator terms for side A.
            Generator_B_loss = Generator_B_gan + Generator_B_cycle + Generator_B_identity + Generator_B_cam

            Discriminator_A_loss = self.adv_weight * D_ad_loss_A
            # Weighted discriminator loss for A (logit and CAM-logit terms
            # plus the gradient-penalty terms).
            Discriminator_B_loss = self.adv_weight * D_ad_loss_B

            self.Generator_loss = Generator_A_loss + Generator_B_loss + regularization_loss(
                'generator')
            # Total generator loss.
            self.Discriminator_loss = Discriminator_A_loss + Discriminator_B_loss + regularization_loss(
                'discriminator')
            # Total discriminator loss.
            # Leftover debug progress marker written to stdout.
            print('55')
            """ Result Image """
            # Generated images (original author's note: kept for saving).
            self.fake_A = x_ba
            self.fake_B = x_ab

            # Real input images.
            self.real_A = self.domain_A
            self.real_B = self.domain_B

            # NOTE(review): `imgba` and `imgab` are never assigned in this
            # method; unless they exist as module-level names these two lines
            # raise NameError when the train graph is built -- confirm intent.
            self.imgba = imgba
            self.imgab = imgab
            """ Training """
            t_vars = tf.trainable_variables()
            G_vars = [var for var in t_vars if 'generator' in var.name]
            D_vars = [var for var in t_vars if 'discriminator' in var.name]

            self.G_optim = tf.train.AdamOptimizer(self.lr,
                                                  beta1=0.5,
                                                  beta2=0.999).minimize(
                                                      self.Generator_loss,
                                                      var_list=G_vars)
            self.D_optim = tf.train.AdamOptimizer(self.lr,
                                                  beta1=0.5,
                                                  beta2=0.999).minimize(
                                                      self.Discriminator_loss,
                                                      var_list=D_vars)
            # var_list: the set of variables each optimizer updates per step.
            """" Summary """
            # The scalar summaries for the generator/discriminator losses are
            # disabled in this variant (their code was commented out).
            # The string below is the original "plotting" section marker.
            '''
            画图
            '''
            # The 'rho' histogram/min/max/mean summaries and the merged
            # summary ops (self.G_loss / self.D_loss) are likewise disabled,
            # so this method does not create them.

        else:
            """ Test """
            self.test_domain_A = tf.placeholder(
                tf.float32, [1, self.img_size, self.img_size, self.img_ch],
                name='test_domain_A')
            self.test_domain_B = tf.placeholder(
                tf.float32, [1, self.img_size, self.img_size, self.img_ch],
                name='test_domain_B')

            self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
            self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
示例#28
0
    def build_model(self):
        """Create the graph for the current phase.

        Outside the 'train' phase only two single-image placeholders and
        their translations are built.  In the 'train' phase this wires the
        A/B input pipelines, the a2b / b2a generator outputs, the
        discriminator logits, the adversarial / cycle / identity / CAM
        losses, the Adam optimizers and the merged summaries.
        """
        if self.phase != 'train':
            # Test: translate one image per domain fed through placeholders.
            test_shape = [1, self.img_size, self.img_size, self.img_ch]
            self.test_domain_A = tf.placeholder(tf.float32,
                                                test_shape,
                                                name='test_domain_A')
            self.test_domain_B = tf.placeholder(tf.float32,
                                                test_shape,
                                                name='test_domain_B')

            self.test_fake_B, _ = self.generate_a2b(self.test_domain_A)
            self.test_fake_A, _ = self.generate_b2a(self.test_domain_B)
            return

        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        # ---- Input images ----
        Image_Data_Class = ImageData(self.img_size, self.img_ch,
                                     self.augment_flag)
        gpu_device = '/gpu:0'

        def to_batches(source):
            # Shuffle/repeat, preprocess + batch, then prefetch to the GPU.
            shuffled = source.apply(shuffle_and_repeat(self.dataset_num))
            batched = shuffled.apply(
                map_and_batch(Image_Data_Class.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True))
            return batched.apply(prefetch_to_device(gpu_device, None))

        trainA = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        trainB = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
        trainA = to_batches(trainA)
        trainB = to_batches(trainB)

        trainA_iterator = trainA.make_one_shot_iterator()
        trainB_iterator = trainB.make_one_shot_iterator()

        self.domain_A = trainA_iterator.get_next()
        self.domain_B = trainB_iterator.get_next()

        # ---- Define generator, discriminator ----
        # Cross-domain translations (these first calls create the variables).
        x_ab, cam_ab = self.generate_a2b(self.domain_A)  # real a
        x_ba, cam_ba = self.generate_b2a(self.domain_B)  # real b

        # Round trips: A -> B -> A and B -> A -> B.
        x_aba, _ = self.generate_b2a(x_ab, reuse=True)  # real b
        x_bab, _ = self.generate_a2b(x_ba, reuse=True)  # real a

        # Each generator applied to images already in its target domain.
        x_aa, cam_aa = self.generate_b2a(self.domain_A, reuse=True)  # fake b
        x_bb, cam_bb = self.generate_a2b(self.domain_B, reuse=True)  # fake a

        real_logits = self.discriminate_real(self.domain_A, self.domain_B)
        real_A_logit, real_A_cam_logit, real_B_logit, real_B_cam_logit = real_logits
        fake_logits = self.discriminate_fake(x_ba, x_ab)
        fake_A_logit, fake_A_cam_logit, fake_B_logit, fake_B_cam_logit = fake_logits

        # ---- Define loss ----
        # A gradient penalty is only added for wgan-family and dragan losses.
        if 'wgan' in self.gan_type or self.gan_type == 'dragan':
            GP_A, GP_CAM_A = self.gradient_panalty(real=self.domain_A,
                                                   fake=x_ba,
                                                   scope="discriminator_A")
            GP_B, GP_CAM_B = self.gradient_panalty(real=self.domain_B,
                                                   fake=x_ab,
                                                   scope="discriminator_B")
        else:
            GP_A = GP_CAM_A = 0
            GP_B = GP_CAM_B = 0

        gan_type = self.gan_type
        g_adv_A = (generator_loss(gan_type, fake_A_logit) +
                   generator_loss(gan_type, fake_A_cam_logit))
        g_adv_B = (generator_loss(gan_type, fake_B_logit) +
                   generator_loss(gan_type, fake_B_cam_logit))

        d_adv_A = (discriminator_loss(gan_type, real_A_logit, fake_A_logit) +
                   discriminator_loss(gan_type, real_A_cam_logit,
                                      fake_A_cam_logit) + GP_A + GP_CAM_A)
        d_adv_B = (discriminator_loss(gan_type, real_B_logit, fake_B_logit) +
                   discriminator_loss(gan_type, real_B_cam_logit,
                                      fake_B_cam_logit) + GP_B + GP_CAM_B)

        reconstruction_A = L1_loss(x_aba, self.domain_A)  # reconstruction
        reconstruction_B = L1_loss(x_bab, self.domain_B)  # reconstruction

        identity_A = L1_loss(x_aa, self.domain_A)
        identity_B = L1_loss(x_bb, self.domain_B)

        cam_A = cam_loss(source=cam_ba, non_source=cam_aa)
        cam_B = cam_loss(source=cam_ab, non_source=cam_bb)

        # Weighted generator terms.  Side A's cycle term uses the B round
        # trip and side B's uses the A round trip.
        g_A_gan = self.adv_weight * g_adv_A
        g_A_cycle = self.cycle_weight * reconstruction_B
        g_A_identity = self.identity_weight * identity_A
        g_A_cam = self.cam_weight * cam_A

        g_B_gan = self.adv_weight * g_adv_B
        g_B_cycle = self.cycle_weight * reconstruction_A
        g_B_identity = self.identity_weight * identity_B
        g_B_cam = self.cam_weight * cam_B

        g_A_total = g_A_gan + g_A_cycle + g_A_identity + g_A_cam
        g_B_total = g_B_gan + g_B_cycle + g_B_identity + g_B_cam

        d_A_total = self.adv_weight * d_adv_A
        d_B_total = self.adv_weight * d_adv_B

        self.Generator_loss = g_A_total + g_B_total + regularization_loss(
            'generator')
        self.Discriminator_loss = d_A_total + d_B_total + regularization_loss(
            'discriminator')

        # ---- Result image ----
        self.fake_A = x_ba
        self.fake_B = x_ab

        self.real_A = self.domain_A
        self.real_B = self.domain_B

        # ---- Training ----
        # Split trainable variables by name between the two optimizers.
        t_vars = tf.trainable_variables()
        G_vars = [var for var in t_vars if 'generator' in var.name]
        D_vars = [var for var in t_vars if 'discriminator' in var.name]

        g_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
        self.G_optim = g_optimizer.minimize(self.Generator_loss,
                                            var_list=G_vars)
        d_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)
        self.D_optim = d_optimizer.minimize(self.Discriminator_loss,
                                            var_list=D_vars)

        # ---- Summary ----
        scalar = tf.summary.scalar
        self.all_G_loss = scalar("Generator_loss", self.Generator_loss)
        self.all_D_loss = scalar("Discriminator_loss",
                                 self.Discriminator_loss)

        self.G_A_loss = scalar("G_A_loss", g_A_total)
        self.G_A_gan = scalar("G_A_gan", g_A_gan)
        self.G_A_cycle = scalar("G_A_cycle", g_A_cycle)
        self.G_A_identity = scalar("G_A_identity", g_A_identity)
        self.G_A_cam = scalar("G_A_cam", g_A_cam)

        self.G_B_loss = scalar("G_B_loss", g_B_total)
        self.G_B_gan = scalar("G_B_gan", g_B_gan)
        self.G_B_cycle = scalar("G_B_cycle", g_B_cycle)
        self.G_B_identity = scalar("G_B_identity", g_B_identity)
        self.G_B_cam = scalar("G_B_cam", g_B_cam)

        self.D_A_loss = scalar("D_A_loss", d_A_total)
        self.D_B_loss = scalar("D_B_loss", d_B_total)

        # Histogram plus min/max/mean scalars for every trainable variable
        # whose name contains 'rho'.
        self.rho_var = []
        for var in tf.trainable_variables():
            if 'rho' not in var.name:
                continue
            self.rho_var.extend([
                tf.summary.histogram(var.name, var),
                scalar(var.name + "_min", tf.reduce_min(var)),
                scalar(var.name + "_max", tf.reduce_max(var)),
                scalar(var.name + "_mean", tf.reduce_mean(var)),
            ])

        g_summary_list = [
            self.G_A_loss, self.G_A_gan, self.G_A_cycle, self.G_A_identity,
            self.G_A_cam, self.G_B_loss, self.G_B_gan, self.G_B_cycle,
            self.G_B_identity, self.G_B_cam, self.all_G_loss
        ] + self.rho_var
        d_summary_list = [self.D_A_loss, self.D_B_loss, self.all_D_loss]

        self.G_loss = tf.summary.merge(g_summary_list)
        self.D_loss = tf.summary.merge(d_summary_list)
示例#29
0
def _initialize_tf_dataset(file_names,
                           augment,
                           over_sample,
                           shuffle,
                           target_size,
                           normalize,
                           nb_workers=8,
                           batch_size=64,
                           shuffle_buffer_size=3000,
                           input_type="img",
                           nr_epochs=1):
  """Build a batched `tf.data` input pipeline over TFRecord files.

  Every record is parsed with three features: `label` (int64 scalar),
  `shape` (raw bytes decoded as int64) and `image` (raw bytes decoded as
  uint8 and reshaped to the requested size).

  Args:
    file_names: File name(s) handed to `tf.data.TFRecordDataset`.
    augment: If truthy, apply random flips, a random rotation, a random
      shift (pad followed by random crop) and a random zoom to each image.
      The crop uses `target_size + [3]`, so this path assumes 3-channel
      images.
    over_sample: Not used by this function; kept so existing callers keep
      working.
    shuffle: If truthy, shuffle with a buffer of `shuffle_buffer_size`
      elements and repeat the data `nr_epochs` times.
    target_size: Sequence describing the image size; converted to a list.
    normalize: If truthy, subtract a fixed per-channel mean and divide by a
      fixed per-channel standard deviation (three channels).
    nb_workers: Number of parallel batches for `map_and_batch`, also used as
      the prefetch buffer size.
    batch_size: Number of examples per batch; incomplete batches are dropped.
    shuffle_buffer_size: Buffer size for shuffling.
    input_type: When equal to ".jpeg" the image is reshaped to
      `target_size + [3]`, otherwise to `target_size`.
    nr_epochs: Repeat count. NOTE: only takes effect when `shuffle` is
      truthy; without shuffling the data is read once.

  Returns:
    A dataset yielding `({'shape': ..., 'image': ...}, label)` batches.
  """
  import math

  import tensorflow as tf
  from tensorflow.contrib.data import map_and_batch
  from tensorflow.contrib.data import shuffle_and_repeat

  # Accept tuples (or any other sequence); list concatenation is used below.
  if not isinstance(target_size, list):
    target_size = list(target_size)

  # Static shape every decoded image is reshaped to.
  if input_type == ".jpeg":
    image_shape = target_size + [3]
  else:
    image_shape = target_size

  with tf.name_scope('input_pipeline'):
    dataset = tf.data.TFRecordDataset(file_names)
    if shuffle:
      dataset = dataset.apply(
          shuffle_and_repeat(shuffle_buffer_size, nr_epochs))

    def _decode_and_augment_image(example_proto):
      """Parse one serialized example and return (features, label)."""
      keys_to_features = {
          'label': tf.FixedLenFeature([], tf.int64),
          'shape': tf.FixedLenFeature([], tf.string),
          'image': tf.FixedLenFeature([], tf.string),
      }
      tfrecord_features = tf.parse_single_example(example_proto,
                                                  keys_to_features)

      image = tf.decode_raw(tfrecord_features['image'], tf.uint8)
      shape = tf.decode_raw(tfrecord_features['shape'], tf.int64)
      image = tf.reshape(image, image_shape)
      label = tfrecord_features['label']

      if augment:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)

        # The angle is sampled in degrees, but tf.contrib.image.rotate
        # interprets its argument as radians, so convert before rotating.
        degrees = tf.random_uniform((), minval=-180, maxval=180)
        image = tf.contrib.image.rotate(image, degrees * (math.pi / 180.0))

        # Random shift: pad each spatial side by up to 5% of its size, then
        # take a random crop back to the target size.
        width_shift = tf.random_uniform((), minval=0, maxval=0.05)
        height_shift = tf.random_uniform((), minval=0, maxval=0.05)

        horizontal_pad = tf.cast(
            tf.ceil(width_shift * target_size[0]), tf.int32)
        vertical_pad = tf.cast(tf.ceil(height_shift * target_size[1]), tf.int32)

        # Rows of the (3, 2) matrix are (before, after) padding for axis 0,
        # axis 1 and the channel axis (never padded).
        padding = tf.stack([
            horizontal_pad, horizontal_pad, vertical_pad, vertical_pad,
            tf.constant(0),
            tf.constant(0)
        ])
        padding = tf.reshape(padding, (3, 2))

        image = tf.pad(image, padding)
        image = tf.random_crop(image, target_size + [3])

        # Random zoom of up to +/-10%: crop or pad to a square of side
        # new_dim, then resize back to the target size.
        zoom = tf.random_uniform((), minval=-0.1, maxval=0.1)
        new_dim = tf.cast(tf.ceil((1 - zoom) * target_size[0]), dtype=tf.int32)

        image = tf.image.resize_image_with_crop_or_pad(image, new_dim, new_dim)

        image = tf.image.resize_images(
            image, target_size, method=tf.image.ResizeMethod.BILINEAR)

      if normalize:
        # Fixed per-channel statistics, shaped [1, 1, 3] so they broadcast
        # over the two spatial axes. Plain Python lists are used so the
        # function does not rely on a module-level numpy import.
        std = tf.constant([70.53946096, 51.71475228, 43.03428563],
                          dtype=tf.float32)
        std = tf.expand_dims(tf.expand_dims(std, axis=0), axis=0)

        mean = tf.constant([108.64628601, 75.86886597, 54.34005736],
                           dtype=tf.float32)
        mean = tf.expand_dims(tf.expand_dims(mean, axis=0), axis=0)

        image = (tf.cast(image, dtype=tf.float32) - mean) / std

      label = tf.reshape(label, [1])
      # Re-assert the static shape, which the augmentation ops may have lost.
      image = tf.reshape(image, image_shape)

      return {'shape': shape, 'image': image}, label

    dataset = dataset \
        .apply(map_and_batch(_decode_and_augment_image, batch_size=batch_size, num_parallel_batches=nb_workers,
                             drop_remainder=True)) \
        .prefetch(nb_workers)

    return dataset
示例#30
0
    def build_model(self):
        """Build the TensorFlow graph for the phase selected by `self.phase`.

        * 'train': creates the input pipeline, splits each batch across
          `self.gpu_num` GPUs, builds generator / style-encoder /
          mapping-network / discriminator losses one sample at a time,
          creates the Adam optimizers, the moving-average update ops, the
          loss summaries and a set of sample images (`self.x_fake_list`).
        * 'refer_test': translates `self.custom_image` using styles computed
          by the style encoder from `self.refer_image`.
        * any other value: translates `self.custom_image` using styles
          computed by the mapping network from a random code.
        """

        # Moving-average tracker; applied to the G/E/F variables after the
        # generator-side optimizer steps (see the control_dependencies block).
        self.ema = tf.train.ExponentialMovingAverage(decay=self.ema_decay)

        if self.phase == 'train' :
            """ Input Image"""
            img_class = Image_data(self.img_height, self.img_width, self.img_ch, self.dataset_path, self.label_list,
                                   self.augment_flag)
            img_class.preprocess()

            dataset_num = len(img_class.image)
            print("Dataset number : ", dataset_num)

            # Both values are fed at run time; ds_weight scales the
            # style-diversification term of the generator loss.
            self.lr = tf.placeholder(tf.float32, name='learning_rate')
            self.ds_weight_placeholder = tf.placeholder(tf.float32, name='ds_weight')


            img_and_label = tf.data.Dataset.from_tensor_slices((img_class.image, img_class.label))

            # One global batch carries self.batch_size samples for each GPU;
            # the shuffle buffer covers the whole dataset.
            gpu_device = '/gpu:0'
            img_and_label = img_and_label.apply(shuffle_and_repeat(dataset_num)).apply(
                map_and_batch(img_class.image_processing, self.batch_size * self.gpu_num, num_parallel_batches=16,
                              drop_remainder=True)).apply(prefetch_to_device(gpu_device, None))

            img_and_label_iterator = img_and_label.make_one_shot_iterator()

            self.x_real, label_org = img_and_label_iterator.get_next() # [bs, 256, 256, 3], [bs, 1]
            # label_trg = tf.random_shuffle(label_org)  # Target domain labels
            # Target labels are drawn uniformly from the integers in [0, c_dim).
            label_trg = tf.random_uniform(shape=tf.shape(label_org), minval=0, maxval=self.c_dim, dtype=tf.int32) # Target domain labels

            """ split """
            # One slice of the global batch per GPU.
            x_real_gpu_split = tf.split(self.x_real, num_or_size_splits=self.gpu_num, axis=0)
            label_org_gpu_split = tf.split(label_org, num_or_size_splits=self.gpu_num, axis=0)
            label_trg_gpu_split = tf.split(label_trg, num_or_size_splits=self.gpu_num, axis=0)

            # Per-GPU scalar losses, averaged over GPUs after the loop.
            g_adv_loss_per_gpu = []
            g_sty_recon_loss_per_gpu = []
            g_sty_diverse_loss_per_gpu = []
            g_cyc_loss_per_gpu = []
            g_loss_per_gpu = []

            d_adv_loss_per_gpu = []
            d_loss_per_gpu = []

            for gpu_id in range(self.gpu_num):
                with tf.device(tf.DeviceSpec(device_type="GPU", device_index=gpu_id)):
                    # Variable reuse is switched on for every GPU after the first.
                    with tf.variable_scope(tf.get_variable_scope(), reuse=(gpu_id > 0)):

                        # Split this GPU's slice into batch_size pieces so the
                        # networks are applied to one sample at a time.
                        x_real_split = tf.split(x_real_gpu_split[gpu_id], num_or_size_splits=self.batch_size, axis=0)
                        label_org_split = tf.split(label_org_gpu_split[gpu_id], num_or_size_splits=self.batch_size, axis=0)
                        label_trg_split = tf.split(label_trg_gpu_split[gpu_id], num_or_size_splits=self.batch_size, axis=0)

                        # Accumulators: set by the first sample, extended by
                        # tf.concat for each following sample.
                        g_adv_loss = None
                        g_sty_recon_loss = None
                        g_sty_diverse_loss = None
                        g_cyc_loss = None

                        d_adv_loss = None
                        d_simple_gp = None
                        d_gp = None

                        for each_bs in range(self.batch_size) :
                            """ Define Generator, Discriminator """
                            x_real_each = x_real_split[each_bs] # [1, 256, 256, 3]
                            label_org_each = tf.squeeze(label_org_split[each_bs], axis=[0, 1]) # [1, 1] -> []
                            label_trg_each = tf.squeeze(label_trg_split[each_bs], axis=[0, 1])

                            # Three independent latent codes: one for the
                            # adversarial path, two for style diversification.
                            random_style_code = tf.random_normal(shape=[1, self.style_dim])
                            random_style_code_1 = tf.random_normal(shape=[1, self.style_dim])
                            random_style_code_2 = tf.random_normal(shape=[1, self.style_dim])

                            # tf.gather picks the entry at the label index from
                            # the network output (presumably one output per
                            # domain - confirm against mapping_network).
                            random_style = tf.gather(self.mapping_network(random_style_code), label_trg_each)
                            random_style_1 = tf.gather(self.mapping_network(random_style_code_1), label_trg_each)
                            random_style_2 = tf.gather(self.mapping_network(random_style_code_2), label_trg_each)

                            x_fake = self.generator(x_real_each, random_style) # for adversarial objective
                            x_fake_1 = self.generator(x_real_each, random_style_1) # for style diversification
                            x_fake_2 = self.generator(x_real_each, random_style_2) # for style diversification

                            x_real_each_style = tf.gather(self.style_encoder(x_real_each), label_org_each) # for cycle consistency
                            x_fake_style = tf.gather(self.style_encoder(x_fake), label_trg_each) # for style reconstruction

                            x_cycle = self.generator(x_fake, x_real_each_style) # for cycle consistency

                            # The discriminator output is indexed by the
                            # original label for real and the target label for fake.
                            real_logit = tf.gather(self.discriminator(x_real_each), label_org_each)
                            fake_logit = tf.gather(self.discriminator(x_fake), label_trg_each)

                            """ Define loss """
                            # Extra gradient penalty only for wgan-style and
                            # dragan losses; otherwise a zero constant.
                            # NOTE(review): the method name is spelled
                            # `gradient_panalty`; it has to match the
                            # definition elsewhere in the class.
                            if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
                                GP = self.gradient_panalty(real=x_real_each, fake=x_fake, real_label=label_org_each)
                            else:
                                GP = tf.constant([0], tf.float32)

                            if each_bs == 0 :
                                # First sample initialises the accumulators.
                                g_adv_loss = self.adv_weight * generator_loss(self.gan_type, fake_logit)
                                g_sty_recon_loss = self.sty_weight * L1_loss(random_style, x_fake_style)
                                g_sty_diverse_loss = self.ds_weight_placeholder * L1_loss(x_fake_1, x_fake_2)
                                g_cyc_loss = self.cyc_weight * L1_loss(x_real_each, x_cycle)

                                d_adv_loss = self.adv_weight * discriminator_loss(self.gan_type, real_logit, fake_logit)
                                d_simple_gp = self.adv_weight * simple_gp(real_logit, fake_logit, x_real_each, x_fake, r1_gamma=self.r1_weight, r2_gamma=0.0)
                                d_gp = self.adv_weight * GP

                            else :
                                # Later samples are appended along axis 0.
                                g_adv_loss = tf.concat([g_adv_loss, self.adv_weight * generator_loss(self.gan_type, fake_logit)], axis=0)
                                g_sty_recon_loss = tf.concat([g_sty_recon_loss, self.sty_weight * L1_loss(random_style, x_fake_style)], axis=0)
                                g_sty_diverse_loss = tf.concat([g_sty_diverse_loss, self.ds_weight_placeholder * L1_loss(x_fake_1, x_fake_2)], axis=0)
                                g_cyc_loss = tf.concat([g_cyc_loss, self.cyc_weight * L1_loss(x_real_each, x_cycle)], axis=0)

                                d_adv_loss = tf.concat([d_adv_loss, self.adv_weight * discriminator_loss(self.gan_type, real_logit, fake_logit)], axis=0)
                                d_simple_gp = tf.concat([d_simple_gp, self.adv_weight * simple_gp(real_logit, fake_logit, x_real_each, x_fake, r1_gamma=self.r1_weight, r2_gamma=0.0)], axis=0)
                                d_gp = tf.concat([d_gp, self.adv_weight * GP], axis=0)


                        # Collapse the per-sample values into one scalar per
                        # loss term for this GPU.
                        g_adv_loss = tf.reduce_mean(g_adv_loss)
                        g_sty_recon_loss = tf.reduce_mean(g_sty_recon_loss)
                        g_sty_diverse_loss = tf.reduce_mean(g_sty_diverse_loss)
                        g_cyc_loss = tf.reduce_mean(g_cyc_loss)

                        d_adv_loss = tf.reduce_mean(d_adv_loss)
                        # Sum over axes 1-3 of each sample's penalty, then
                        # average over samples.
                        d_simple_gp = tf.reduce_mean(tf.reduce_sum(d_simple_gp, axis=[1, 2, 3]))
                        d_gp = tf.reduce_mean(d_gp)

                        # The diversification term is subtracted, so minimising
                        # g_loss increases the distance between the two fakes.
                        g_loss = g_adv_loss + g_sty_recon_loss - g_sty_diverse_loss + g_cyc_loss
                        d_loss = d_adv_loss + d_simple_gp + d_gp

                        g_adv_loss_per_gpu.append(g_adv_loss)
                        g_sty_recon_loss_per_gpu.append(g_sty_recon_loss)
                        g_sty_diverse_loss_per_gpu.append(g_sty_diverse_loss)
                        g_cyc_loss_per_gpu.append(g_cyc_loss)

                        # The two penalty terms are only tracked through
                        # d_loss; they have no per-GPU list of their own.
                        d_adv_loss_per_gpu.append(d_adv_loss)

                        g_loss_per_gpu.append(g_loss)
                        d_loss_per_gpu.append(d_loss)

            # Average every loss term over the GPUs.
            g_adv_loss = tf.reduce_mean(g_adv_loss_per_gpu)
            g_sty_recon_loss = tf.reduce_mean(g_sty_recon_loss_per_gpu)
            g_sty_diverse_loss = tf.reduce_mean(g_sty_diverse_loss_per_gpu)
            g_cyc_loss = tf.reduce_mean(g_cyc_loss_per_gpu)
            self.g_loss = tf.reduce_mean(g_loss_per_gpu)

            d_adv_loss = tf.reduce_mean(d_adv_loss_per_gpu)
            self.d_loss = tf.reduce_mean(d_loss_per_gpu)


            """ Training """
            # Trainable variables are grouped by a substring of their name.
            t_vars = tf.trainable_variables()
            G_vars = [var for var in t_vars if 'generator' in var.name]
            E_vars = [var for var in t_vars if 'encoder' in var.name]
            F_vars = [var for var in t_vars if 'mapping' in var.name]
            D_vars = [var for var in t_vars if 'discriminator' in var.name]

            # G, E and F all minimise g_loss with their own Adam optimizer;
            # the mapping-network group (F) uses 1% of the learning rate.
            if self.gpu_num == 1 :
                prev_g_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0, beta2=0.99).minimize(self.g_loss, var_list=G_vars)
                prev_e_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0, beta2=0.99).minimize(self.g_loss, var_list=E_vars)
                prev_f_optimizer = tf.train.AdamOptimizer(self.lr * 0.01, beta1=0, beta2=0.99).minimize(self.g_loss, var_list=F_vars)

                self.d_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0, beta2=0.99).minimize(self.d_loss, var_list=D_vars)

            else :
                # Same optimizers as above, with gradient ops colocated with
                # the forward ops they belong to.
                prev_g_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0, beta2=0.99).minimize(self.g_loss, var_list=G_vars,
                                                                                                 colocate_gradients_with_ops=True)
                prev_e_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0, beta2=0.99).minimize(self.g_loss,
                                                                                                 var_list=E_vars,
                                                                                                 colocate_gradients_with_ops=True)
                prev_f_optimizer = tf.train.AdamOptimizer(self.lr * 0.01, beta1=0, beta2=0.99).minimize(self.g_loss,
                                                                                                        var_list=F_vars,
                                                                                                        colocate_gradients_with_ops=True)

                self.d_optimizer = tf.train.AdamOptimizer(self.lr, beta1=0, beta2=0.99).minimize(self.d_loss,
                                                                                                 var_list=D_vars,
                                                                                                 colocate_gradients_with_ops=True)

            # The public g/e/f "optimizer" ops are the moving-average updates;
            # the control dependency makes each of them run the three Adam
            # steps first.
            with tf.control_dependencies([prev_g_optimizer, prev_e_optimizer, prev_f_optimizer]):
                self.g_optimizer = self.ema.apply(G_vars)
                self.e_optimizer = self.ema.apply(E_vars)
                self.f_optimizer = self.ema.apply(F_vars)

            """" Summary """
            self.Generator_loss = tf.summary.scalar("g_loss", self.g_loss)
            self.Discriminator_loss = tf.summary.scalar("d_loss", self.d_loss)

            # From here on these attributes hold summary ops, not loss tensors.
            self.g_adv_loss = tf.summary.scalar("g_adv_loss", g_adv_loss)
            self.g_sty_recon_loss = tf.summary.scalar("g_sty_recon_loss", g_sty_recon_loss)
            self.g_sty_diverse_loss = tf.summary.scalar("g_sty_diverse_loss", g_sty_diverse_loss)
            self.g_cyc_loss = tf.summary.scalar("g_cyc_loss", g_cyc_loss)

            self.d_adv_loss = tf.summary.scalar("d_adv_loss", d_adv_loss)

            g_summary_list = [self.Generator_loss, self.g_adv_loss, self.g_sty_recon_loss, self.g_sty_diverse_loss, self.g_cyc_loss]
            d_summary_list = [self.Discriminator_loss, self.d_adv_loss]

            self.g_summary_loss = tf.summary.merge(g_summary_list)
            self.d_summary_loss = tf.summary.merge(d_summary_list)

            """ Result Image """
            # Thin wrapper so the generator call can be used inside tf.map_fn.
            def return_g_images(generator, image, code):
                x = generator(image, code)
                return x

            # Sample outputs: the first image of the batch is translated with
            # every label index in range(c_dim), once per random style code
            # (num_style codes in total).
            self.x_fake_list = []
            first_x_real = tf.expand_dims(self.x_real[0], axis=0)

            label_fix_list = tf.constant([idx for idx in range(self.c_dim)])

            for _ in range(self.num_style):
                random_style_code = tf.truncated_normal(shape=[1, self.style_dim])
                self.x_fake_list.append(tf.map_fn(
                    lambda c: return_g_images(self.generator,
                                              first_x_real,
                                              tf.gather(self.mapping_network(random_style_code), c)),
                    label_fix_list, dtype=tf.float32))

        elif self.phase == 'refer_test':
            """ Test """

            # Thin wrapper so the generator call can be used inside tf.map_fn.
            def return_g_images(generator, image, code):
                x = generator(image, code)
                return x

            # Single-image placeholders: the image to translate and the
            # reference image the style is taken from.
            self.custom_image = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch], name='custom_image')
            self.refer_image = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch], name='refer_image')


            label_fix_list = tf.constant([idx for idx in range(self.c_dim)])

            # One output per label index, each using the style-encoder output
            # of the reference image gathered at that index.
            self.refer_fake_image = tf.map_fn(
                lambda c : return_g_images(self.generator,
                                           self.custom_image,
                                           tf.gather(self.style_encoder(self.refer_image), c)),
                label_fix_list, dtype=tf.float32)

        else :
            """ Test """

            # Thin wrapper so the generator call can be used inside tf.map_fn.
            def return_g_images(generator, image, code):
                x = generator(image, code)
                return x

            self.custom_image = tf.placeholder(tf.float32, [1, self.img_height, self.img_width, self.img_ch], name='custom_image')
            label_fix_list = tf.constant([idx for idx in range(self.c_dim)])

            # One output per label index, all derived from the same random
            # latent code passed through the mapping network.
            random_style_code = tf.truncated_normal(shape=[1, self.style_dim])
            self.custom_fake_image = tf.map_fn(
                lambda c : return_g_images(self.generator,
                                           self.custom_image,
                                           tf.gather(self.mapping_network(random_style_code), c)),
                label_fix_list, dtype=tf.float32)
    def build_model(self):
        """Assemble the graph: inputs, networks, losses, optimizers, summaries.

        Three datasets (trainA, trainB, trainB_smooth) are batched and
        prefetched to the GPU. The generator maps `real_A` to `fake_B`; the
        discriminator scores `real_B`, `fake_B` and `real_B_smooth`. Results
        are stored on `self` (losses, optimizer ops, summary ops and a
        single-image test path).
        """
        self.lr = tf.placeholder(tf.float32, name='learning_rate')

        # ---- Input images ----
        preprocessor = ImageData(self.img_size, self.img_ch, self.augment_flag)

        source_a = tf.data.Dataset.from_tensor_slices(self.trainA_dataset)
        source_b = tf.data.Dataset.from_tensor_slices(self.trainB_dataset)
        source_b_smooth = tf.data.Dataset.from_tensor_slices(self.trainB_smooth_dataset)

        gpu_device = '/gpu:0'

        def to_gpu_batches(dataset):
            # shuffle + repeat -> preprocess and batch -> prefetch to the GPU
            dataset = dataset.apply(shuffle_and_repeat(self.dataset_num))
            dataset = dataset.apply(
                map_and_batch(preprocessor.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True))
            return dataset.apply(prefetch_to_device(gpu_device, self.batch_size))

        batches_a = to_gpu_batches(source_a)
        batches_b = to_gpu_batches(source_b)
        batches_b_smooth = to_gpu_batches(source_b_smooth)

        iterator_a = batches_a.make_one_shot_iterator()
        iterator_b = batches_b.make_one_shot_iterator()
        iterator_b_smooth = batches_b_smooth.make_one_shot_iterator()

        self.real_A = iterator_a.get_next()
        self.real_B = iterator_b.get_next()
        self.real_B_smooth = iterator_b_smooth.get_next()

        self.test_real_A = tf.placeholder(
            tf.float32,
            [1, self.img_size, self.img_size, self.img_ch],
            name='test_real_A')

        # ---- Generator / discriminator ----
        self.fake_B = self.generator(self.real_A)

        real_B_logit = self.discriminator(self.real_B)
        fake_B_logit = self.discriminator(self.fake_B, reuse=True)
        real_B_smooth_logit = self.discriminator(self.real_B_smooth, reuse=True)

        # ---- Losses ----
        # A gradient penalty is added only for the gp / lp / dragan variants.
        uses_penalty = any(tag in self.gan_type for tag in ('gp', 'lp', 'dragan'))
        if uses_penalty:
            penalty_fake = self.gradient_panalty(real=self.real_B, fake=self.fake_B)
            penalty_smooth = self.gradient_panalty(self.real_B, fake=self.real_B_smooth)
            GP = penalty_fake + penalty_smooth
        else:
            GP = 0.0

        content_loss = self.vgg_weight * vgg_loss(self.real_A, self.fake_B)
        adv_g_loss = self.adv_weight * generator_loss(self.gan_type, fake_B_logit)
        adv_d_loss = self.adv_weight * discriminator_loss(
            self.gan_type, real_B_logit, fake_B_logit, real_B_smooth_logit) + GP

        self.Vgg_loss = content_loss
        self.Generator_loss = adv_g_loss + content_loss
        self.Discriminator_loss = adv_d_loss

        # ---- Result image ----
        self.test_fake_B = self.generator(self.test_real_A, reuse=True)

        # ---- Training ops ----
        trainable = tf.trainable_variables()
        G_vars = [v for v in trainable if 'generator' in v.name]
        D_vars = [v for v in trainable if 'discriminator' in v.name]

        def new_adam():
            # Every optimizer shares the same hyper-parameters.
            return tf.train.AdamOptimizer(self.lr, beta1=0.5, beta2=0.999)

        # init_optim trains the generator on the VGG loss alone.
        self.init_optim = new_adam().minimize(self.Vgg_loss, var_list=G_vars)
        self.G_optim = new_adam().minimize(self.Generator_loss, var_list=G_vars)
        self.D_optim = new_adam().minimize(self.Discriminator_loss, var_list=D_vars)

        # ---- Summaries ----
        self.G_loss = tf.summary.scalar("Generator_loss", self.Generator_loss)
        self.D_loss = tf.summary.scalar("Discriminator_loss", self.Discriminator_loss)

        self.G_gan = tf.summary.scalar("G_gan", adv_g_loss)
        self.G_vgg = tf.summary.scalar("G_vgg", content_loss)

        self.V_loss_merge = tf.summary.merge([self.G_vgg])
        self.G_loss_merge = tf.summary.merge([self.G_loss, self.G_gan, self.G_vgg])
        self.D_loss_merge = tf.summary.merge([self.D_loss])
    def build_model(self):
        """Assemble the GAN graph: inputs, losses, optimizers, test output.

        Real images come from a `tf.data` pipeline when
        `self.custom_dataset` is truthy and from a placeholder otherwise.
        Latent codes are fed through the `z` placeholder. Losses, optimizer
        ops, the test-time generator output and two scalar summaries are
        stored on `self`.
        """
        # ---- Graph inputs: real images ----
        if self.custom_dataset:
            preprocessor = ImageData(self.img_size, self.c_dim)
            dataset = tf.data.Dataset.from_tensor_slices(self.data)

            gpu_device = '/gpu:0'
            # shuffle + repeat -> preprocess and batch -> prefetch to the GPU
            dataset = dataset.apply(shuffle_and_repeat(self.dataset_num))
            dataset = dataset.apply(
                map_and_batch(preprocessor.image_processing,
                              self.batch_size,
                              num_parallel_batches=16,
                              drop_remainder=True))
            dataset = dataset.apply(
                prefetch_to_device(gpu_device, self.batch_size))

            self.inputs = dataset.make_one_shot_iterator().get_next()
        else:
            image_shape = [
                self.batch_size, self.img_size, self.img_size, self.c_dim
            ]
            self.inputs = tf.placeholder(tf.float32,
                                         image_shape,
                                         name='real_images')

        # ---- Graph inputs: latent noise ----
        noise_shape = [self.batch_size, 1, 1, self.z_dim]
        self.z = tf.placeholder(tf.float32, noise_shape, name='z')

        # ---- Losses ----
        # Discriminator output for real images, then for generated ones.
        real_logits = self.discriminator(self.inputs)
        fake_images = self.generator(self.z)
        fake_logits = self.discriminator(fake_images, reuse=True)

        # A gradient penalty is added only for the gp / lp / dragan variants.
        if any(tag in self.gan_type for tag in ('gp', 'lp', 'dragan')):
            GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
        else:
            GP = 0

        # Both loss helpers take the Ra flag and the real and fake logits.
        adversarial_d = discriminator_loss(self.Ra,
                                           self.gan_type,
                                           real=real_logits,
                                           fake=fake_logits)
        self.d_loss = adversarial_d + GP

        self.g_loss = generator_loss(self.Ra,
                                     self.gan_type,
                                     real=real_logits,
                                     fake=fake_logits)

        # ---- Training ops ----
        # Split the trainable variables into a D group and a G group by name.
        trainable = tf.trainable_variables()
        d_vars = [v for v in trainable if 'discriminator' in v.name]
        g_vars = [v for v in trainable if 'generator' in v.name]

        d_adam = tf.train.AdamOptimizer(self.d_learning_rate,
                                        beta1=self.beta1,
                                        beta2=self.beta2)
        self.d_optim = d_adam.minimize(self.d_loss, var_list=d_vars)

        g_adam = tf.train.AdamOptimizer(self.g_learning_rate,
                                        beta1=self.beta1,
                                        beta2=self.beta2)
        self.g_optim = g_adam.minimize(self.g_loss, var_list=g_vars)

        # ---- Testing ----
        # Generator output with is_training=False, reusing the trained weights.
        self.fake_images = self.generator(self.z,
                                          is_training=False,
                                          reuse=True)

        # ---- Summaries ----
        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)
示例#33
0
    def build_model(self):

        if self.phase == 'train':
            self.lr = tf.placeholder(tf.float32, name='learning_rate')
            """ Input Image"""
            img_data_class = Image_data(self.img_height, self.img_width,
                                        self.img_ch, self.dataset_path,
                                        self.augment_flag)
            img_data_class.preprocess()

            self.dataset_num = len(img_data_class.image_list)

            img_and_embedding = tf.data.Dataset.from_tensor_slices(
                (img_data_class.image_list, img_data_class.embedding))

            gpu_device = '/gpu:0'
            img_and_embedding = img_and_embedding.apply(
                shuffle_and_repeat(self.dataset_num)).apply(
                    map_and_batch(img_data_class.image_processing,
                                  batch_size=self.batch_size,
                                  num_parallel_batches=16,
                                  drop_remainder=True)).apply(
                                      prefetch_to_device(gpu_device, None))

            img_and_embedding_iterator = img_and_embedding.make_one_shot_iterator(
            )

            self.real_img_256, self.embedding = img_and_embedding_iterator.get_next(
            )
            sentence_index = tf.random.uniform(shape=[],
                                               minval=0,
                                               maxval=10,
                                               dtype=tf.int32)
            self.embedding = tf.gather(self.embedding,
                                       indices=sentence_index,
                                       axis=1)  #[bs, 1024]

            noise = tf.random_normal(shape=[self.batch_size, self.z_dim])
            self.fake_img_64, mu_64, logvar_64 = self.generator_1(
                self.embedding, noise)
            self.fake_img_256, mu_256, logvar_256 = self.generator_2(
                self.fake_img_64, self.embedding)
            self.real_img_64 = tf.image.resize_bilinear(self.real_img_256,
                                                        size=[64, 64])

            self.real_img = [self.real_img_64, self.real_img_256]
            self.fake_img = [self.fake_img_64, self.fake_img_256]

            real_logit_64 = self.discriminator_1(self.real_img_64, mu_64)
            fake_logit_64 = self.discriminator_1(self.fake_img_64, mu_64)

            real_logit_256 = self.discriminator_2(self.real_img_256, mu_256)
            fake_logit_256 = self.discriminator_2(self.fake_img_256, mu_256)

            g_adv_loss_64 = generator_loss(self.gan_type,
                                           fake_logit_64) * self.adv_weight
            g_kl_loss_64 = kl_loss(mu_64, logvar_64) * self.kl_weight

            d_adv_loss_64 = discriminator_loss(self.gan_type, real_logit_64,
                                               fake_logit_64) * self.adv_weight

            g_loss_64 = g_adv_loss_64 + g_kl_loss_64
            d_loss_64 = d_adv_loss_64

            g_adv_loss_256 = generator_loss(self.gan_type,
                                            fake_logit_256) * self.adv_weight
            g_kl_loss_256 = kl_loss(mu_256, logvar_256) * self.kl_weight

            d_adv_loss_256 = discriminator_loss(
                self.gan_type, real_logit_256,
                fake_logit_256) * self.adv_weight

            g_loss_256 = g_adv_loss_256 + g_kl_loss_256
            d_loss_256 = d_adv_loss_256

            self.g_loss = [g_loss_64, g_loss_256]
            self.d_loss = [d_loss_64, d_loss_256]
            """ Training """
            t_vars = tf.trainable_variables()
            G1_vars = [var for var in t_vars if 'generator_1' in var.name]
            G2_vars = [var for var in t_vars if 'generator_2' in var.name]
            D1_vars = [var for var in t_vars if 'discriminator_1' in var.name]
            D2_vars = [var for var in t_vars if 'discriminator_2' in var.name]

            g1_optim = tf.train.AdamOptimizer(
                self.lr, beta1=0.5, beta2=0.999).minimize(g_loss_64,
                                                          var_list=G1_vars)
            g2_optim = tf.train.AdamOptimizer(
                self.lr, beta1=0.5, beta2=0.999).minimize(g_loss_256,
                                                          var_list=G2_vars)

            d1_optim = tf.train.AdamOptimizer(
                self.lr, beta1=0.5, beta2=0.999).minimize(d_loss_64,
                                                          var_list=D1_vars)
            d2_optim = tf.train.AdamOptimizer(
                self.lr, beta1=0.5, beta2=0.999).minimize(d_loss_256,
                                                          var_list=D2_vars)

            self.g_optim = [g1_optim, g2_optim]
            self.d_optim = [d1_optim, d2_optim]
            """" Summary """
            self.summary_g_loss_64 = tf.summary.scalar("g_loss_64", g_loss_64)
            self.summary_g_loss_256 = tf.summary.scalar(
                "g_loss_256", g_loss_256)
            self.summary_d_loss_64 = tf.summary.scalar("d_loss_64", d_loss_64)
            self.summary_d_loss_256 = tf.summary.scalar(
                "d_loss_256", d_loss_256)

            self.summary_g_adv_loss_64 = tf.summary.scalar(
                "g_adv_loss_64", g_adv_loss_64)
            self.summary_g_adv_loss_256 = tf.summary.scalar(
                "g_adv_loss_256", g_adv_loss_256)
            self.summary_g_kl_loss_64 = tf.summary.scalar(
                "g_kl_loss_64", g_kl_loss_64)
            self.summary_g_kl_loss_256 = tf.summary.scalar(
                "g_kl_loss_256", g_kl_loss_256)

            self.summary_d_adv_loss_64 = tf.summary.scalar(
                "d_adv_loss_64", d_adv_loss_64)
            self.summary_d_adv_loss_256 = tf.summary.scalar(
                "d_adv_loss_256", d_adv_loss_256)

            g_summary_list = [
                self.summary_g_loss_64, self.summary_g_loss_256,
                self.summary_g_adv_loss_64, self.summary_g_adv_loss_256,
                self.summary_g_kl_loss_64, self.summary_g_kl_loss_256
            ]

            d_summary_list = [
                self.summary_d_loss_64, self.summary_d_loss_256,
                self.summary_d_adv_loss_64, self.summary_d_adv_loss_256
            ]

            self.summary_merge_g_loss = tf.summary.merge(g_summary_list)
            self.summary_merge_d_loss = tf.summary.merge(d_summary_list)

        else:
            """ Test """
            """ Input Image"""
            img_data_class = Image_data(self.img_height,
                                        self.img_width,
                                        self.img_ch,
                                        self.dataset_path,
                                        augment_flag=False)
            img_data_class.preprocess()

            self.dataset_num = len(img_data_class.image_list)

            img_and_embedding = tf.data.Dataset.from_tensor_slices(
                (img_data_class.image_list, img_data_class.embedding))

            gpu_device = '/gpu:0'
            img_and_embedding = img_and_embedding.apply(
                shuffle_and_repeat(self.dataset_num)).apply(
                    map_and_batch(img_data_class.image_processing,
                                  batch_size=32,
                                  num_parallel_batches=16,
                                  drop_remainder=True)).apply(
                                      prefetch_to_device(gpu_device, None))

            img_and_embedding_iterator = img_and_embedding.make_one_shot_iterator(
            )

            self.real_img_256, self.embedding = img_and_embedding_iterator.get_next(
            )
            sentence_index = tf.random.uniform(shape=[],
                                               minval=0,
                                               maxval=10,
                                               dtype=tf.int32)
            self.embedding = tf.gather(self.embedding,
                                       indices=sentence_index,
                                       axis=1)  # [bs, 1024]

            noise = tf.random_normal(shape=[self.batch_size, self.z_dim])
            self.fake_img_64, mu_64, logvar_64 = self.generator_1(
                self.embedding, noise, is_training=False)
            self.fake_img_256, mu_256, logvar_256 = self.generator_2(
                self.fake_img_64, self.embedding, is_training=False)

            self.test_fake_img = self.fake_img_256
            self.test_real_img = self.real_img_256
            self.test_fake_img_64 = self.fake_img_64