Example #1
    def __init__(self, opt):
        self.opt = opt

        self.genA2B = Generator(opt)
        self.genB2A = Generator(opt)

        if opt.training:
            self.discA = Discriminator(opt)
            self.discB = Discriminator(opt)
            self.learning_rate = tf.contrib.eager.Variable(
                opt.lr, dtype=tf.float32, name='learning_rate')
            self.disc_optim = tf.train.AdamOptimizer(self.learning_rate,
                                                     beta1=opt.beta1)
            self.gen_optim = tf.train.AdamOptimizer(self.learning_rate,
                                                    beta1=opt.beta1)
            self.global_step = tf.train.get_or_create_global_step()
            # Initialize history buffers:
            self.discA_buffer = ImageHistoryBuffer(opt)
            self.discB_buffer = ImageHistoryBuffer(opt)
        # Restore latest checkpoint:
        self.initialize_checkpoint()
        if not opt.training or opt.load_checkpoint:
            self.restore_checkpoint()
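For context: this `__init__` comes from a CycleGAN-style trainer class whose name is not shown in the snippet. A minimal usage sketch follows, with the option names (`training`, `lr`, `beta1`, `load_checkpoint`) inferred from the attributes read above; `Generator`, `Discriminator`, and `ImageHistoryBuffer` will read further fields of `opt` not shown here:

# hypothetical usage sketch -- the class name and default values are assumptions
from types import SimpleNamespace

opt = SimpleNamespace(
    training=True,           # build discriminators, optimizers, history buffers
    lr=2e-4, beta1=0.5,      # Adam settings commonly used for CycleGAN
    load_checkpoint=False,   # True would restore the latest checkpoint
)
# trainer = CycleGANTrainer(opt)   # assumed class name wrapping the __init__ above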
Example #2
def adversarial_training(synthetic_generator,
                         real_generator,
                         refiner_model_path=None,
                         discriminator_model_path=None):
    """Adversarial training of refiner network Rθ and discriminator network Dφ."""
    #
    # define model input and output tensors
    #

    synthetic_image_tensor = layers.Input(shape=(img_height, img_width,
                                                 img_channels))
    refined_image_tensor = refiner_network(synthetic_image_tensor)

    refined_or_real_image_tensor = layers.Input(shape=(img_height, img_width,
                                                       img_channels))
    discriminator_output = discriminator_network(refined_or_real_image_tensor)

    #
    # define models
    #

    refiner_model = models.Model(input=synthetic_image_tensor,
                                 output=refined_image_tensor,
                                 name='refiner')
    discriminator_model = models.Model(input=refined_or_real_image_tensor,
                                       output=discriminator_output,
                                       name='discriminator')

    # combined must output the refined image along w/ the disc's classification of it for the refiner's self-reg loss
    refiner_model_output = refiner_model(synthetic_image_tensor)
    combined_output = discriminator_model(refiner_model_output)
    combined_model = models.Model(
        input=synthetic_image_tensor,
        output=[refiner_model_output, combined_output],
        name='combined')

    discriminator_model_output_shape = discriminator_model.output_shape

    print(refiner_model.summary())
    print(discriminator_model.summary())
    print(combined_model.summary())

    #
    # define custom l1 loss function for the refiner
    #

    def self_regularization_loss(y_true, y_pred):
        delta = 0.0001  # FIXME: need to figure out an appropriate value for this
        return tf.multiply(delta, tf.reduce_sum(tf.abs(y_pred - y_true)))
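    # This is the paper's self-regularization term delta * ||R(x) - x||_1.
    # During pre-training the synthetic batch is passed as both input and
    # target (see the loop below), so the refiner starts near the identity map.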

    #
    # define custom local adversarial loss (softmax for each image section) for the discriminator
    # the adversarial loss function is the sum of the cross-entropy losses over the local patches
    #

    def local_adversarial_loss(y_true, y_pred):
        # y_true and y_pred have shape (batch_size, # of local patches, 2), but really we just want to average over
        # the local patches and batch size so we can reshape to (batch_size * # of local patches, 2)
        y_true = tf.reshape(y_true, (-1, 2))
        y_pred = tf.reshape(y_pred, (-1, 2))
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true,
                                                       logits=y_pred)

        return tf.reduce_mean(loss)

    #
    # compile models
    #

    sgd = optimizers.SGD(lr=0.001)

    refiner_model.compile(optimizer=sgd, loss=self_regularization_loss)
    discriminator_model.compile(optimizer=sgd, loss=local_adversarial_loss)
    discriminator_model.trainable = False
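    # Note: Keras reads `trainable` at compile time. The discriminator was
    # compiled *before* this flag was flipped and the combined model is
    # compiled *after*, so the discriminator still updates when trained
    # directly but stays frozen inside `combined_model` -- the usual GAN setup.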
    combined_model.compile(
        optimizer=sgd, loss=[self_regularization_loss, local_adversarial_loss])

    def get_image_batch(generator):
        """keras generators may generate an incomplete batch for the last batch"""
        img_batch = generator.next()
        if len(img_batch) != batch_size:
            img_batch = generator.next()

        assert len(img_batch) == batch_size

        return img_batch

    # the target labels for the cross-entropy loss layer are 0 for every yj (real) and 1 for every xi (refined)
    y_real = np.array([[[1.0, 0.0]] * discriminator_model_output_shape[1]] *
                      batch_size)
    y_refined = np.array([[[0.0, 1.0]] * discriminator_model_output_shape[1]] *
                         batch_size)
    assert y_real.shape == (batch_size, discriminator_model_output_shape[1], 2)
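    # e.g. with batch_size=4 and a discriminator emitting 9 local patches,
    # y_real has shape (4, 9, 2) with every patch labelled [1.0, 0.0], and
    # y_refined is the complementary one-hot [0.0, 1.0].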

    if not refiner_model_path:
        # we first train the Rθ network with just self-regularization loss for 1,000 steps
        print('pre-training the refiner network...')
        gen_loss = np.zeros(shape=len(refiner_model.metrics_names))

        for i in range(1000):
            synthetic_image_batch = get_image_batch(synthetic_generator)
            gen_loss = np.add(
                refiner_model.train_on_batch(synthetic_image_batch,
                                             synthetic_image_batch), gen_loss)

            # log every `log_interval` steps
            if not i % log_interval:
                figure_name = 'refined_image_batch_pre_train_step_{}.png'.format(i)
                print('Saving batch of refined images during pre-training at step: {}.'.format(i))

                synthetic_image_batch = get_image_batch(synthetic_generator)
                plot_image_batch_w_labels.plot_batch(
                    np.concatenate((synthetic_image_batch[:, :, :, :3],
                                    refiner_model.predict_on_batch(
                                        synthetic_image_batch)[:, :, :, :3])),
                    os.path.join(cache_dir, figure_name),
                    label_batch=['Synthetic'] * batch_size +
                    ['Refined'] * batch_size)

                print('Refiner model self regularization loss: {}.'.format(
                    gen_loss / log_interval))
                gen_loss = np.zeros(shape=len(refiner_model.metrics_names))

        refiner_model.save(
            os.path.join(cache_dir, 'refiner_model_pre_trained.h5'))
    else:
        refiner_model.load_weights(refiner_model_path)

    if not discriminator_model_path:
        # and Dφ for 200 steps (one mini-batch for refined images, another for real)
        print('pre-training the discriminator network...')
        disc_loss = np.zeros(shape=len(discriminator_model.metrics_names))

        for _ in range(100):
            real_image_batch = get_image_batch(real_generator)
            disc_loss = np.add(
                discriminator_model.train_on_batch(real_image_batch, y_real),
                disc_loss)

            synthetic_image_batch = get_image_batch(synthetic_generator)
            refined_image_batch = refiner_model.predict_on_batch(
                synthetic_image_batch)
            disc_loss = np.add(
                discriminator_model.train_on_batch(refined_image_batch,
                                                   y_refined), disc_loss)

        discriminator_model.save(
            os.path.join(cache_dir, 'discriminator_model_pre_trained.h5'))

        # hard-coded for now
        print('Discriminator model loss: {}.'.format(disc_loss / (100 * 2)))
    else:
        discriminator_model.load_weights(discriminator_model_path)

    # TODO: what is an appropriate size for the image history buffer?
    image_history_buffer = ImageHistoryBuffer(
        (0, img_height, img_width, img_channels), batch_size * 100, batch_size)
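    # The buffer implements the paper's "history of refined images": it keeps
    # past refiner outputs and replaces half of each discriminator mini-batch
    # with them (see the k_d loop below), so the discriminator does not forget
    # artifacts the refiner produced earlier in training.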

    combined_loss = np.zeros(shape=len(combined_model.metrics_names))
    disc_loss_real = np.zeros(shape=len(discriminator_model.metrics_names))
    disc_loss_refined = np.zeros(shape=len(discriminator_model.metrics_names))

    # see Algorithm 1 in https://arxiv.org/pdf/1612.07828v1.pdf
    for i in range(nb_steps):
        print('Step: {} of {}.'.format(i, nb_steps))

        # train the refiner
        for _ in range(k_g * 2):
            # sample a mini-batch of synthetic images
            synthetic_image_batch = get_image_batch(synthetic_generator)

            # update θ by taking an SGD step on mini-batch loss LR(θ)
            combined_loss = np.add(
                combined_model.train_on_batch(synthetic_image_batch,
                                              [synthetic_image_batch, y_real]),
                combined_loss)

        for _ in range(k_d):
            # sample a mini-batch of synthetic and real images
            synthetic_image_batch = get_image_batch(synthetic_generator)
            real_image_batch = get_image_batch(real_generator)

            # refine the synthetic images w/ the current refiner
            refined_image_batch = refiner_model.predict_on_batch(
                synthetic_image_batch)

            # use a history of refined images
            half_batch_from_image_history = image_history_buffer.get_from_image_history_buffer()
            image_history_buffer.add_to_image_history_buffer(refined_image_batch)

            if len(half_batch_from_image_history):
                refined_image_batch[:batch_size // 2] = half_batch_from_image_history

            # update φ by taking an SGD step on mini-batch loss LD(φ)
            disc_loss_real = np.add(
                discriminator_model.train_on_batch(real_image_batch, y_real),
                disc_loss_real)
            disc_loss_refined = np.add(
                discriminator_model.train_on_batch(refined_image_batch,
                                                   y_refined),
                disc_loss_refined)

        if not i % log_interval:
            # plot batch of refined images w/ current refiner
            figure_name = 'refined_image_batch_step_{}.png'.format(i)
            print('Saving batch of refined images at adversarial step: {}.'.format(i))

            synthetic_image_batch = get_image_batch(synthetic_generator)
            plot_image_batch_w_labels.plot_batch(
                np.concatenate((synthetic_image_batch[:, :, :, :3],
                                refiner_model.predict_on_batch(
                                    synthetic_image_batch)[:, :, :, :3])),
                os.path.join(cache_dir, figure_name),
                label_batch=['Synthetic'] * batch_size +
                ['Refined'] * batch_size)

            # log loss summary
            print('Refiner model loss: {}.'.format(combined_loss /
                                                   (log_interval * k_g * 2)))
            print('Discriminator model loss real: {}.'.format(
                disc_loss_real / (log_interval * k_d * 2)))
            print('Discriminator model loss refined: {}.'.format(
                disc_loss_refined / (log_interval * k_d * 2)))

            combined_loss = np.zeros(shape=len(combined_model.metrics_names))
            disc_loss_real = np.zeros(
                shape=len(discriminator_model.metrics_names))
            disc_loss_refined = np.zeros(
                shape=len(discriminator_model.metrics_names))

            # save model checkpoints
            model_checkpoint_base_name = os.path.join(cache_dir,
                                                      '{}_model_step_{}.h5')
            refiner_model.save(model_checkpoint_base_name.format('refiner', i))
            discriminator_model.save(
                model_checkpoint_base_name.format('discriminator', i))
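A minimal invocation sketch for this example, assuming the module-level constants (img_height, img_width, img_channels, batch_size, cache_dir, ...) are already configured; the directory paths and the rescale factor are placeholders, and the generator construction mirrors the flow_from_directory pattern used in Example #6 below:

# hypothetical driver for adversarial_training()
from keras.preprocessing import image

datagen = image.ImageDataGenerator(rescale=1. / 255, data_format='channels_last')
flow_from_directory_params = {'target_size': (img_height, img_width),
                              'class_mode': None,
                              'batch_size': batch_size}
synthetic_generator = datagen.flow_from_directory('data/synthetic', **flow_from_directory_params)
real_generator = datagen.flow_from_directory('data/real', **flow_from_directory_params)

adversarial_training(synthetic_generator, real_generator)  # pre-train, then run Algorithm 1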
Example #3
def adversarial_training(synthesis_eyes_dir, mpii_gaze_dir, refiner_model_path=None, discriminator_model_path=None):
    """Adversarial training of refiner network Rθ and discriminator network Dφ."""
    synthetic_image_tensor = layers.Input(shape=(img_height, img_width, img_channels))
    refined_image_tensor = encoder_decoder_network(synthetic_image_tensor)

    refined_or_real_image_tensor = layers.Input(shape=(img_height, img_width,img_channels))
    discriminator_output = discriminator_network(refined_or_real_image_tensor)

    #
    # define models
    #

    refiner_model = models.Model(input=synthetic_image_tensor, output=refined_image_tensor, name='refiner')
    discriminator_model = models.Model(input=refined_or_real_image_tensor, output=discriminator_output,
                                       name='discriminator')

    # combined must output the refined image along w/ the disc's classification of it for the refiner's self-reg loss
    refiner_model_output = refiner_model(synthetic_image_tensor)
    print(np.shape(refiner_model_output))
    combined_output = discriminator_model(refiner_model_output)
    combined_model = models.Model(input=synthetic_image_tensor, output=[refiner_model_output, combined_output],
                                  name='combined')

    discriminator_model_output_shape = discriminator_model.output_shape

    print(refiner_model.summary())
    print(discriminator_model.summary())
    print(combined_model.summary())

    #
    # define custom l1 loss function for the refiner
    #

    def self_regularization_loss(y_true, y_pred):
        delta = 0.0001  # FIXME: need to figure out an appropriate value for this
        return tf.multiply(delta, tf.reduce_sum(tf.abs(y_pred - y_true)))

    #
    # define custom local adversarial loss (softmax for each image section) for the discriminator
    # the adversarial loss function is the sum of the cross-entropy losses over the local patches
    #

    def local_adversarial_loss(y_true, y_pred):
        # y_true and y_pred have shape (batch_size, # of local patches, 2), but really we just want to average over
        # the local patches and batch size so we can reshape to (batch_size * # of local patches, 2)
        y_true = tf.reshape(y_true, (-1, 2))
        y_pred = tf.reshape(y_pred, (-1, 2))
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)

        return tf.reduce_mean(loss)

    def binary_crossentropy(target, output, from_logits=False):
        """Binary crossentropy between an output tensor and a target tensor.
          Arguments:
              target: A tensor with the same shape as `output`.
              output: A tensor.
              from_logits: Whether `output` is expected to be a logits tensor.
                  By default, we consider that `output`
                  encodes a probability distribution.
          Returns:
              A tensor.
          """
        # Note: tf.nn cross-entropy ops expect logits,
        # while Keras outputs are probabilities by default.
        _EPSILON = 10e-8
        if not from_logits:
            # transform back to logits
            epsilon_ = tf.convert_to_tensor(_EPSILON, output.dtype.base_dtype)
            output = clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_)
            output = math_ops.log(output / (1 - output))
        return nn.weighted_cross_entropy_with_logits(target, output, .15, name=None)
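    # The transform above is the inverse sigmoid (logit): for a clipped
    # probability p, log(p / (1 - p)) recovers the pre-activation value, so
    # weighted_cross_entropy_with_logits receives actual logits as required.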

    #
    # compile models
    #

    sgd = optimizers.SGD(lr=0.001)

    
    refiner_model.compile(optimizer=optimizers.Adam(lr=1e-4), loss=binary_crossentropy, metrics=['accuracy'])
    discriminator_model.compile(optimizer=sgd, loss=local_adversarial_loss)
    discriminator_model.trainable = False
    combined_model.compile(optimizer=sgd, loss=[binary_crossentropy, local_adversarial_loss])

    
    # data generators

    # TODO: decide on a normalization scheme (scale to [0, 1]?)

    # datagen = image.ImageDataGenerator(
    #     #preprocessing_function=applications.xception.preprocess_input,
    #     rescale = 1/3500.,
    #     data_format="channels_last")

    # flow_from_directory_params = {'target_size': (img_height, img_width),
    #                               'color_mode': 'grayscale',
    #                               'class_mode': None,
    #                               'batch_size': batch_size}

    # synthetic_generator = datagen.flow_from_directory(
    #     directory=synthesis_eyes_dir,
    #     **flow_from_directory_params
    # )

    # real_generator = datagen.flow_from_directory(
    #     directory=mpii_gaze_dir,
    #     **flow_from_directory_params
    # )
    samples = gen_samples1("/media/drc/DATA/chris_labelfusion/RGBDCNN/")  # directories=['2017-06-16-30']
    train = generate_data_custom_depth(samples, img_height=480, img_width=640)

    print(np.shape(get_image_batch(synthetic_generator)))
    # the target labels for the cross-entropy loss layer are 0 for every yj (real) and 1 for every xi (refined)
    y_real = np.array([[[1.0, 0.0]] * discriminator_model_output_shape[1]] * batch_size)
    print(np.shape(y_real))
    print(y_real)
    y_refined = np.array([[[0.0, 1.0]] * discriminator_model_output_shape[1]] * batch_size)
    assert y_real.shape == (batch_size, discriminator_model_output_shape[1], 2)
    if not refiner_model_path:
        print('pre-training the refiner network...')
        gen_loss = np.zeros(shape=len(refiner_model.metrics_names))

        for i in range(500):
            synthetic_image_batch = get_image_batch(synthetic_generator)
            gen_loss = np.add(refiner_model.train_on_batch(synthetic_image_batch, synthetic_image_batch), gen_loss)
            # log every `log_interval` steps
            print(i)
            if not i % log_interval:
                figure_name = 'refined_image_batch_pre_train_step_{}.png'.format(i)
                print('Saving batch of refined images during pre-training at step: {}.'.format(i))

                synthetic_image_batch = get_image_batch(synthetic_generator)
                plot_image_batch_w_labels.plot_batch(
                    np.concatenate((synthetic_image_batch, refiner_model.predict_on_batch(synthetic_image_batch))),
                    os.path.join(cache_dir, figure_name),
                    label_batch=['Synthetic'] * batch_size + ['Refined'] * batch_size)

                print('Refiner model self regularization loss: {}.'.format(gen_loss / log_interval))
                gen_loss = np.zeros(shape=len(refiner_model.metrics_names))

        refiner_model.save(os.path.join(cache_dir, 'refiner_model_pre_trained.h5'))
    else:
        refiner_model.load_weights(refiner_model_path)

    print("pretrained refiner network")
    
    if not discriminator_model_path:
        print('pre-training the discriminator network...')
        disc_loss = np.zeros(shape=len(discriminator_model.metrics_names))

        for _ in range(200):
            real_image_batch = get_image_batch(real_generator)
            disc_loss = np.add(discriminator_model.train_on_batch(real_image_batch, y_real), disc_loss)
            synthetic_image_batch = get_image_batch(synthetic_generator)
            refined_image_batch = refiner_model.predict_on_batch(synthetic_image_batch)
            disc_loss = np.add(discriminator_model.train_on_batch(refined_image_batch, y_refined), disc_loss)

        discriminator_model.save(os.path.join(cache_dir, 'discriminator_model_pre_trained.h5'))

        # average over 200 iterations x 2 mini-batches per iteration
        print('Discriminator model loss: {}.'.format(disc_loss / (200 * 2)))
    else:
        discriminator_model.load_weights(discriminator_model_path)

    # TODO: what is an appropriate size for the image history buffer?
    image_history_buffer = ImageHistoryBuffer((0, img_height, img_width, img_channels), batch_size * 100, batch_size)

    combined_loss = np.zeros(shape=len(combined_model.metrics_names))
    disc_loss_real = np.zeros(shape=len(discriminator_model.metrics_names))
    disc_loss_refined = np.zeros(shape=len(discriminator_model.metrics_names))

    # see Algorithm 1 in https://arxiv.org/pdf/1612.07828v1.pdf
    for i in range(nb_steps):
        print('Step: {} of {}.'.format(i, nb_steps))

        # train the refiner
        for _ in range(k_g * 2):
            # sample a mini-batch of synthetic images
            synthetic_image_batch = get_image_batch(synthetic_generator)


            # update θ by taking an SGD step on mini-batch loss LR(θ)
            combined_loss = np.add(combined_model.train_on_batch(synthetic_image_batch,
                                                                 [synthetic_image_batch, y_real]), combined_loss)

        for _ in range(k_d):
            # sample a mini-batch of synthetic and real images
            synthetic_image_batch = get_image_batch(synthetic_generator)
            real_image_batch = get_image_batch(real_generator)
            
            # refine the synthetic images w/ the current refiner
            refined_image_batch = refiner_model.predict_on_batch(synthetic_image_batch)

            # use a history of refined images
            half_batch_from_image_history = image_history_buffer.get_from_image_history_buffer()
            image_history_buffer.add_to_image_history_buffer(refined_image_batch)

            if len(half_batch_from_image_history):
                refined_image_batch[:batch_size // 2] = half_batch_from_image_history

            disc_loss_real = np.add(discriminator_model.train_on_batch(real_image_batch, y_real), disc_loss_real)
            disc_loss_refined = np.add(discriminator_model.train_on_batch(refined_image_batch, y_refined),
                                       disc_loss_refined)

        if not i % log_interval:
            # plot batch of refined images w/ current refiner
            figure_name = 'refined_image_batch_step_{}.png'.format(i)
            print('Saving batch of refined images at adversarial step: {}.'.format(i))

            synthetic_image_batch = get_image_batch(synthetic_generator)
            plot_image_batch_w_labels.plot_batch(
                np.concatenate((synthetic_image_batch, refiner_model.predict_on_batch(synthetic_image_batch))),
                os.path.join(cache_dir, figure_name),
                label_batch=['Synthetic'] * batch_size + ['Refined'] * batch_size)

            # log loss summary
            print('Refiner model loss: {}.'.format(combined_loss / (log_interval * k_g * 2)))
            print('Discriminator model loss real: {}.'.format(disc_loss_real / (log_interval * k_d * 2)))
            print('Discriminator model loss refined: {}.'.format(disc_loss_refined / (log_interval * k_d * 2)))

            combined_loss = np.zeros(shape=len(combined_model.metrics_names))
            disc_loss_real = np.zeros(shape=len(discriminator_model.metrics_names))
            disc_loss_refined = np.zeros(shape=len(discriminator_model.metrics_names))

            # save model checkpoints
            model_checkpoint_base_name = os.path.join(cache_dir, '{}_model_step_{}.h5')
            refiner_model.save(model_checkpoint_base_name.format('refiner', i))
            discriminator_model.save(model_checkpoint_base_name.format('discriminator', i))
Example #4
def adversarial_training(synthetic_data_dir, real_data_dir, refiner_model_path=None, discriminator_model_path=None):
    """Adversarial training of refiner network Rθ and discriminator network Dφ."""

    # Define model input and output tensors
    synthetic_image_tensor = layers.Input(shape=(img_height, img_width, img_channels))
    refined_image_tensor = refiner_network(synthetic_image_tensor)

    refined_or_real_image_tensor = layers.Input(shape=(img_height, img_width, img_channels))
    discriminator_output = discriminator_network(refined_or_real_image_tensor)

    # Define models
    refiner_model = models.Model(
        input=synthetic_image_tensor, output=refined_image_tensor,
        name='refiner'
    )
    discriminator_model = models.Model(
        input=refined_or_real_image_tensor, output=discriminator_output,
        name='discriminator'
    )

    # Combined must output the refined image along w/ the disc's classification of it for the refiner's self-reg loss
    refiner_model_output = refiner_model(synthetic_image_tensor)
    combined_output = discriminator_model(refiner_model_output)
    combined_model = models.Model(
        input=synthetic_image_tensor,
        output=[refiner_model_output, combined_output],
        name='combined'
    )

    discriminator_model_output_shape = discriminator_model.output_shape

    """
    print refiner_model.summary()
    print discriminator_model.summary()
    print combined_model.summary()
    """

    # Define custom l1 loss function for the refiner
    def self_regularization_loss(y_true, y_pred):
        delta = 0.0001  # FIXME: need to figure out an appropriate value for this
        return tf.multiply(delta, tf.reduce_sum(tf.abs(y_pred - y_true)))

    # Define custom local adversarial loss (softmax for each image section) for the discriminator
    # the adversarial loss function is the sum of the cross-entropy losses over the local patches
    def local_adversarial_loss(y_true, y_pred):
        # y_true and y_pred have shape (batch_size, # of local patches, 2), but really we just want to average over
        # the local patches and batch size so we can reshape to (batch_size * # of local patches, 2)
        y_true = tf.reshape(y_true, (-1, 2))
        y_pred = tf.reshape(y_pred, (-1, 2))
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)

        return tf.reduce_mean(loss)

    # Compile models
    sgd = optimizers.SGD(lr=FLAGS.learning_rate)

    refiner_model.compile(optimizer=sgd, loss=self_regularization_loss)
    discriminator_model.compile(optimizer=sgd, loss=local_adversarial_loss)
    discriminator_model.trainable = False
    combined_model.compile(optimizer=sgd, loss=[self_regularization_loss, local_adversarial_loss])

    synthetic_generator, real_generator = get_data_generator(synthetic_data_dir, real_data_dir)

    # TODO Should use one-sided smoothing suggested by Goodfellow
    # the target labels for the cross-entropy loss layer are 0 for every yj (real) and 1 for every xi (refined)
    y_real = np.array([[[1.0, 0.0]] * discriminator_model_output_shape[1]] * batch_size)
    y_refined = np.array([[[0.0, 1.0]] * discriminator_model_output_shape[1]] * batch_size)
    assert y_real.shape == (batch_size, discriminator_model_output_shape[1], 2)

    if refiner_model_path is None:
        pretrain_refiner(refiner_model, synthetic_generator)
        refiner_model.save(FLAGS.model_dir + 'refiner_model_pre_trained.h5')
    else:
        tf.logging.info("Loading refiner model from {}".format(FLAGS.model_dir + refiner_model_path))
        refiner_model.load_weights(FLAGS.model_dir + refiner_model_path)

    if discriminator_model_path is None:
        pretrain_discriminator(discriminator_model, real_generator, synthetic_generator)
        discriminator_model.save(FLAGS.model_dir + 'discriminator_model_pre_trained.h5')
    else:
        tf.logging.info("Loading discriminator model from {}".format(FLAGS.model_dir + discriminator_model_path))
        discriminator_model.load_weights(FLAGS.model_dir + discriminator_model_path)

    # TODO: what is an appropriate size for the image history buffer?
    image_history_buffer = ImageHistoryBuffer((0, img_height, img_width, img_channels), batch_size * 100, batch_size)

    combined_loss = np.zeros(shape=len(combined_model.metrics_names))
    disc_loss_real = np.zeros(shape=len(discriminator_model.metrics_names))
    disc_loss_refined = np.zeros(shape=len(discriminator_model.metrics_names))

    # see Algorithm 1 in https://arxiv.org/pdf/1612.07828v1.pdf
    for i in range(FLAGS.start_iter, FLAGS.start_iter + FLAGS.max_iter):
        print('Step: {} of {}.'.format(i, FLAGS.start_iter + FLAGS.max_iter))

        # train the refiner
        for _ in range(FLAGS.kg * 2):
            # sample a mini-batch of synthetic images
            synthetic_image_batch = get_image_batch(synthetic_generator)

            # update θ by taking an SGD step on mini-batch loss LR(θ)
            combined_loss = np.add(combined_model.train_on_batch(synthetic_image_batch,
                                                                 [synthetic_image_batch, y_real]), combined_loss)

        for _ in range(FLAGS.kd):
            # sample a mini-batch of synthetic and real images
            synthetic_image_batch = get_image_batch(synthetic_generator)
            real_image_batch = get_image_batch(real_generator)

            # refine the synthetic images w/ the current refiner
            refined_image_batch = refiner_model.predict_on_batch(synthetic_image_batch)

            # use a history of refined images
            half_batch_from_image_history = image_history_buffer.get_from_image_history_buffer()
            image_history_buffer.add_to_image_history_buffer(refined_image_batch)

            if len(half_batch_from_image_history):
                refined_image_batch[:batch_size // 2] = half_batch_from_image_history

            # update φ by taking an SGD step on mini-batch loss LD(φ)
            disc_loss_real = np.add(discriminator_model.train_on_batch(real_image_batch, y_real), disc_loss_real)
            disc_loss_refined = np.add(discriminator_model.train_on_batch(refined_image_batch, y_refined),
                                       disc_loss_refined)

        if i % log_interval == 0:
            # plot batch of refined images w/ current refiner
            if FLAGS.display:
                fig_name = 'refined_image_batch_step_{}.png'.format(i)
                print('Saving batch of refined images at adversarial step: {}.'.format(i))
                visualize_training_procedure(fig_name, synthetic_generator, refiner_model)

            # log loss summary
            print('Refiner model loss: {}.'.format(combined_loss / (log_interval * FLAGS.kg * 2)))
            print('Discriminator model loss real: {}.'.format(disc_loss_real / (log_interval * FLAGS.kd * 2)))
            print('Discriminator model loss refined: {}.'.format(disc_loss_refined / (log_interval * FLAGS.kd * 2)))

            combined_loss = np.zeros(shape=len(combined_model.metrics_names))
            disc_loss_real = np.zeros(shape=len(discriminator_model.metrics_names))
            disc_loss_refined = np.zeros(shape=len(discriminator_model.metrics_names))

            # save model checkpoints
            """
Example #5
def adversarial_training(WLoss_AdverLoss="W", refiner_model_path=None, discriminator_model_path=None):
    """Adversarial training of refiner network Rθ and discriminator network Dφ."""
    refiner_cp = 0
    discriminator_cp = 0
    if WLoss_AdverLoss == "W":
        print("--------Running with Wasserstein Loss------")
        # define model input and output tensors
        real_image_tensor = layers.Input(shape=(img_height, img_width, img_channels))

        synthetic_image_tensor = layers.Input(shape=(img_height, img_width, img_channels))
        refined_image_tensor = refiner_network(synthetic_image_tensor)

        refined_or_real_image_tensor = layers.Input(shape=(img_height, img_width, img_channels))
        discriminator_output = discriminator_network(refined_or_real_image_tensor)
        # define models
        # Construct a weighted average between real and refined images
        interpolated_img = RandomWeightedAverage()([real_image_tensor, refined_or_real_image_tensor])
        validity_interp = discriminator_network(interpolated_img)

        refiner_model = models.Model(input=synthetic_image_tensor, output=refined_image_tensor, name='refiner')
        discriminator_model = models.Model(input=[real_image_tensor,refined_or_real_image_tensor], outputs=[discriminator_output,validity_interp],
                                           name='discriminator')

        # combined must output the refined image along w/ the disc's classification of it for the refiner's self-reg loss
        refiner_model_output = refiner_model(synthetic_image_tensor)
        [combined_output1,combined_output2] = discriminator_model([real_image_tensor,refiner_model_output])
        combined_model = models.Model(input=[real_image_tensor,synthetic_image_tensor], outputs=[refiner_model_output, combined_output1],name='combined')

        discriminator_model_output_shape = discriminator_model.output_shape

        print(refiner_model.summary())
        print(discriminator_model.summary())
        print(combined_model.summary())
    else:
        print("--------Running with adverserial Loss------")
        synthetic_image_tensor = layers.Input(shape=(img_height, img_width, img_channels))
        refined_image_tensor = refiner_network(synthetic_image_tensor)

        refined_or_real_image_tensor = layers.Input(shape=(img_height, img_width, img_channels))
        discriminator_output = discriminator_network(refined_or_real_image_tensor)
        # define modelS
        refiner_model = models.Model(input=synthetic_image_tensor, output=refined_image_tensor, name='refiner')
        discriminator_model = models.Model(input=refined_or_real_image_tensor, output=discriminator_output,
                                           name='discriminator')
        # combined must output the refined image along w/ the disc's classification of it for the refiner's self-reg loss
        refiner_model_output = refiner_model(synthetic_image_tensor)
        combined_output = discriminator_model(refiner_model_output)
        combined_model = models.Model(input=synthetic_image_tensor, output=[refiner_model_output, combined_output],
                                      name='combined')
        discriminator_model_output_shape = discriminator_model.output_shape
        print(refiner_model.summary())
        print(discriminator_model.summary())
        print(combined_model.summary())

    def gradient_penalty_loss(y_true, y_pred, averaged_samples):
        """
        Computes gradient penalty based on prediction and weighted real / fake samples
        """
        gradients = K.gradients(y_pred, averaged_samples)[0]
        # compute the euclidean norm by squaring ...
        gradients_sqr = K.square(gradients)
        #   ... summing over the rows ...
        gradients_sqr_sum = K.sum(gradients_sqr,
                                  axis=np.arange(1, len(gradients_sqr.shape)))
        #   ... and sqrt
        gradient_l2_norm = K.sqrt(gradients_sqr_sum)
        # compute lambda * (1 - ||grad||)^2 still for each single sample
        gradient_penalty = K.square(1 - gradient_l2_norm)
        # return the mean as loss over all the batch samples
        return K.mean(gradient_penalty)
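    # Note: the usual WGAN-GP coefficient (lambda = 10) is not applied inside
    # this function; it enters via loss_weights=[1, 10] when the discriminator
    # model is compiled below.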

    def wasserstein_loss(y_true, y_pred):
        y_true = tf.reshape(y_true, (-1, 2))
        y_pred = tf.reshape(y_pred, (-1, 2))
        return K.mean(y_true * y_pred)

    # define custom l1 loss function for the refiner
    def self_regularization_loss(y_true, y_pred):
        delta = 0.0001  # FIXME: need to figure out an appropriate value for this
        return tf.multiply(delta, tf.reduce_sum(tf.abs(y_pred - y_true)))

    # define custom local adversarial loss (softmax for each image section) for the discriminator
    # the adversarial loss function is the sum of the cross-entropy losses over the local patches
    def local_adversarial_loss(y_true, y_pred):
        # y_true and y_pred have shape (batch_size, # of local patches, 2), but really we just want to average over
        # the local patches and batch size so we can reshape to (batch_size * # of local patches, 2)
        y_true = tf.reshape(y_true, (-1, 2))
        y_pred = tf.reshape(y_pred, (-1, 2))
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
        return tf.reduce_mean(loss)

    # compile models
    if WLoss_AdverLoss == "W":
        optimizer = RMSprop(lr=0.00005)
        # Use Python partial to provide loss function with additional
        # 'averaged_samples' argument
        partial_gp_loss = partial(gradient_penalty_loss,
                                  averaged_samples=interpolated_img)
        partial_gp_loss.__name__ = 'gradient_penalty'  # Keras requires function names

        refiner_model.compile(optimizer=optimizer, loss=self_regularization_loss)
        discriminator_model.compile(optimizer=optimizer,
                                    loss=[wasserstein_loss, partial_gp_loss],
                                    loss_weights=[1, 10])
        discriminator_model.trainable = False
        combined_model.compile(optimizer=optimizer, loss=[self_regularization_loss, wasserstein_loss])
    else:
        sgd = optimizers.SGD(lr=0.001)

        refiner_model.compile(optimizer=sgd, loss=self_regularization_loss)
        discriminator_model.compile(optimizer=sgd, loss=local_adversarial_loss)
        discriminator_model.trainable = False
        combined_model.compile(optimizer=sgd, loss=[self_regularization_loss, local_adversarial_loss])

    # data generators
    datagen = image.ImageDataGenerator(
        preprocessing_function=applications.xception.preprocess_input,
        data_format='channels_last')

    flow_from_directory_params = {'target_size': (img_height, img_width),
                                  'color_mode': 'grayscale' if img_channels == 1 else 'rgb',
                                  'class_mode': None,
                                  'batch_size': batch_size}
    flow_params = {'batch_size': batch_size}

    synthetic_generator = datagen.flow(
        x=syn_image_stack,
        **flow_params
    )
    sample_generator = datagen.flow(
        x=sample_images,
        **flow_params
    )
    samples = sample_generator.next()

    real_generator = datagen.flow(
        x=real_image_stack,
        **flow_params
    )

    def get_image_batch(generator):
        """keras generators may generate an incomplete batch for the last batch"""
        img_batch = generator.next()
        if len(img_batch) != batch_size:
            img_batch = generator.next()

        assert len(img_batch) == batch_size

        return img_batch

    # Adversarial ground truths
    # the target labels for the cross-entropy loss layer are 0 for every yj (real) and 1 for every xi (refined)
    if WLoss_AdverLoss =="W":
        y_real = np.array([[[1.0, 0.0]] * discriminator_model_output_shape[1][1]] * batch_size)
        y_refined = np.array([[[0.0, 1.0]] * discriminator_model_output_shape[1][1]] * batch_size)
        y_dummy = np.array([[[0.0, 0.0]] * discriminator_model_output_shape[1][1]] * batch_size)
        assert y_real.shape == (batch_size, discriminator_model_output_shape[1][1], 2)
    else:
        # the target labels for the cross-entropy loss layer are 0 for every yj (real) and 1 for every xi (refined)
        y_real = np.array([[[1.0, 0.0]] * discriminator_model_output_shape[1]] * batch_size)
        y_refined = np.array([[[0.0, 1.0]] * discriminator_model_output_shape[1]] * batch_size)
        assert y_real.shape == (batch_size, discriminator_model_output_shape[1], 2)

    if not refiner_model_path:
        # we first train the Rθ network with just self-regularization loss for 1,000 steps
        print('pre-training the refiner network...')
        gen_loss = np.zeros(shape=len(refiner_model.metrics_names))
        gen_loss_pre_training_vec = list()
        if WLoss_AdverLoss == "W":
            folder_name = 'pre training refiner loss wasserstein'
            plot_name = 'pre-training refiner loss wasserstein.png'

        else:
            folder_name = 'pre training refiner loss original'
            plot_name = 'pre-training refiner loss.png'

        if not os.path.exists(folder_name):
            os.makedirs(folder_name)
        for i in range(1000):
            synthetic_image_batch = get_image_batch(synthetic_generator)
            gen_loss = np.add(refiner_model.train_on_batch(synthetic_image_batch, synthetic_image_batch), gen_loss)

            if not i % log_interval and i != 0:
                sub_folder_name = 'compare_images_batch_train_step_{}'.format(i)
                sub_folder_name2 = 'compare_same_images'

                print('Saving batch of refined images during pre-training at step: {}'.format(i))

                synthetic_image_batch = get_image_batch(synthetic_generator)
                refined_image_batch = refiner_model.predict_on_batch(synthetic_image_batch)
                refined_images_compare = refiner_model.predict_on_batch(samples)
                plot_images_compare_2.plot_compare2(samples, refined_images_compare,
                                                    os.path.join(folder_name, sub_folder_name2), i)
                plot_images_compare.plot_compare(synthetic_image_batch, refined_image_batch,
                                                 os.path.join(folder_name, sub_folder_name), number_of_plots=2)
                gen_loss_pre_training_vec.append(gen_loss)
                plot_loss.plot_loss_vec(gen_loss_pre_training_vec, plot_name)

                print('Refiner model self regularization loss: {}.'.format(gen_loss / log_interval))
                gen_loss = np.zeros(shape=len(refiner_model.metrics_names))
        refiner_model.save(os.path.join(cache_dir, 'refiner_model_pre_trained.h5'))
    else:
        refiner_model.load_weights(refiner_model_path)
        if "pre" not in refiner_model_path:
            refiner_cp = int(((refiner_model_path.split('_step_'))[1].split('.h5'))[0])
        else:
            refiner_cp = 0
        print(refiner_cp)

    if not discriminator_model_path:
        # and Dφ for 200 steps (one mini-batch for refined images, another for real)
        print('pre-training the discriminator network...')
        disc_loss = np.zeros(shape=len(discriminator_model.metrics_names))
        disc_loss_pre_training_vec = list()
        for _ in range(100):
            real_image_batch = get_image_batch(real_generator)
            if WLoss_AdverLoss == "W":
                plot_name = 'pre-training discriminator_model_path loss wasserstein.png'
                disc_loss_pre = discriminator_model.train_on_batch([real_image_batch,real_image_batch], [y_real,y_dummy])
                disc_loss = np.add(disc_loss_pre[0] , disc_loss)
                
                synthetic_image_batch = get_image_batch(synthetic_generator)
                refined_image_batch = refiner_model.predict_on_batch(synthetic_image_batch)
                disc_loss_pre = discriminator_model.train_on_batch([real_image_batch,refined_image_batch], [y_refined,y_dummy])
                disc_loss =  np.add(disc_loss_pre[0], disc_loss)
            else:
                plot_name = 'pre-training discriminator_model_path loss.png'
                disc_loss = np.add(discriminator_model.train_on_batch(real_image_batch, y_real), disc_loss)

                synthetic_image_batch = get_image_batch(synthetic_generator)
                refined_image_batch = refiner_model.predict_on_batch(synthetic_image_batch)
                disc_loss = np.add(discriminator_model.train_on_batch(refined_image_batch, y_refined), disc_loss)
            
            disc_loss_pre_training_vec.append(disc_loss)
            plot_loss.plot_loss_vec(disc_loss_pre_training_vec, plot_name)
            
        discriminator_model.save(os.path.join(cache_dir, 'discriminator_model_pre_trained.h5'))
        print('Discriminator model loss: {}.'.format(disc_loss / (100 * 2)))
    else:
        discriminator_model.load_weights(discriminator_model_path)

    image_history_buffer = ImageHistoryBuffer((0, img_height, img_width, img_channels), batch_size * 100, batch_size)

    combined_loss = np.zeros(shape=len(combined_model.metrics_names))
    disc_loss_real = np.zeros(shape=len(discriminator_model.metrics_names))
    disc_loss_refined = np.zeros(shape=len(discriminator_model.metrics_names))
    if WLoss_AdverLoss == "W":
        folder_name = 'train_loss_wasserstein'
        plot_name1 = 'training refiner loss_w.png'
        plot_name2 = 'real training discriminator loss_w.png'
        plot_name3 = 'refined training discriminator loss_w.png'
    else:
        folder_name = 'train_loss_original'
        plot_name1 = 'training refiner loss_a.png'
        plot_name2 = 'real training discriminator loss_a.png'
        plot_name3 = 'refined training discriminator loss_a.png'
    if not os.path.exists(folder_name):
        os.makedirs(folder_name)

    # see Algorithm 1 in https://arxiv.org/pdf/1612.07828v1.pdf
    refiner_loss_training_vec = list()
    disc_loss_real_training_vec = list()
    disc_loss_syn_training_vec = list()

    for i in range(refiner_cp, nb_steps):
        # train the refiner
        for _ in range(k_g * 2):
            # sample a mini-batch of synthetic images
            synthetic_image_batch = get_image_batch(synthetic_generator)

            # update θ by taking an SGD step on mini-batch loss LR(θ)
            if WLoss_AdverLoss == "W":
                combined_loss = np.add(combined_model.train_on_batch([synthetic_image_batch, synthetic_image_batch],
                                                                     [synthetic_image_batch, y_real]), combined_loss)
            else:
                combined_loss = np.add(combined_model.train_on_batch(synthetic_image_batch,
                                                                     [synthetic_image_batch, y_real]), combined_loss)

        for _ in range(k_d):
            if WLoss_AdverLoss == "W":
                # sample a mini-batch of synthetic and real images
                synthetic_image_batch = get_image_batch(synthetic_generator)
                real_image_batch = get_image_batch(real_generator)

                # refine the synthetic images w/ the current refiner
                refined_image_batch = refiner_model.predict_on_batch(synthetic_image_batch)

                # use a history of refined images
                half_batch_from_image_history = image_history_buffer.get_from_image_history_buffer()
                image_history_buffer.add_to_image_history_buffer(refined_image_batch)

                if len(half_batch_from_image_history):
                    refined_image_batch[:batch_size // 2] = half_batch_from_image_history

                # update φ by taking an SGD step on mini-batch loss LD(φ)
                disc_loss_pre = discriminator_model.train_on_batch([real_image_batch,real_image_batch], [y_real,y_dummy])
                disc_loss_real = np.add(disc_loss_pre[0], disc_loss_real)
                disc_loss_pre = discriminator_model.train_on_batch([real_image_batch, refined_image_batch], [y_refined,y_dummy])
                disc_loss_refined = np.add(disc_loss_pre[0],disc_loss_refined)
            else:
                synthetic_image_batch = get_image_batch(synthetic_generator)
                real_image_batch = get_image_batch(real_generator)

                # refine the synthetic images w/ the current refiner
                refined_image_batch = refiner_model.predict_on_batch(synthetic_image_batch)

                # use a history of refined images
                half_batch_from_image_history = image_history_buffer.get_from_image_history_buffer()
                image_history_buffer.add_to_image_history_buffer(refined_image_batch)

                if len(half_batch_from_image_history):
                    refined_image_batch[:batch_size // 2] = half_batch_from_image_history

                # update φ by taking an SGD step on mini-batch loss LD(φ)
                disc_loss_real = np.add(discriminator_model.train_on_batch(real_image_batch, y_real), disc_loss_real)
                disc_loss_refined = np.add(discriminator_model.train_on_batch(refined_image_batch, y_refined),
                                           disc_loss_refined)
        
        if not i % log_interval and i != 0:
            # plot batch of refined images w/ current refiner
            sub_folder_name = 'refined_image_batch_step_{}'.format(i)
            sub_folder_name2 = 'compare_same_images'
            print('Saving batch of refined images at adversarial step: {}.'.format(i))

            synthetic_image_batch = get_image_batch(synthetic_generator)
            refined_image_batch = refiner_model.predict_on_batch(synthetic_image_batch)

            refined_images_compare = refiner_model.predict_on_batch(samples)
            plot_images_compare_2.plot_compare2(samples, refined_images_compare,
                                                os.path.join(folder_name, sub_folder_name2), i)
            plot_images_compare.plot_compare(synthetic_image_batch, refined_image_batch,
                                             os.path.join(folder_name, sub_folder_name), number_of_plots=5)

            # log loss summary
            print('Refiner model loss: {}.'.format(combined_loss / (log_interval * k_g * 2)))
            print('Discriminator model loss real: {}.'.format(disc_loss_real / (log_interval * k_d * 2)))
            print('Discriminator model loss refined: {}.'.format(disc_loss_refined / (log_interval * k_d * 2)))
            refiner_loss_training_vec.append(combined_loss[1])
            disc_loss_real_training_vec.append(disc_loss_real[0])
            disc_loss_syn_training_vec.append(disc_loss_refined[0])
            plot_loss.plot_loss_vec(refiner_loss_training_vec, plot_name1)
            plot_loss.plot_loss_vec(disc_loss_real_training_vec, plot_name2)
            plot_loss.plot_loss_vec(disc_loss_syn_training_vec, plot_name3)

            combined_loss = np.zeros(shape=len(combined_model.metrics_names))
            disc_loss_real = np.zeros(shape=len(discriminator_model.metrics_names))
            disc_loss_refined = np.zeros(shape=len(discriminator_model.metrics_names))

            # save model checkpoints
            if WLoss_AdverLoss == "W":
                model_checkpoint_base_name = os.path.join(cache_dir, '{}_model_w_step_{}.h5')
            else:
                model_checkpoint_base_name = os.path.join(cache_dir, '{}_model_a_step_{}.h5')
            refiner_model.save(model_checkpoint_base_name.format('refiner', i))
            discriminator_model.save(model_checkpoint_base_name.format('discriminator', i))
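Usage sketch for this variant (the loss is selected by the first argument; everything else is assumed to be configured at module level):

adversarial_training("W")   # Wasserstein loss with gradient penalty
adversarial_training("A")   # original local adversarial loss (any non-"W" value)
# resuming from a checkpoint: the step index is parsed from the filename
adversarial_training("W", refiner_model_path='refiner_model_w_step_500.h5')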
Example #6
def adversarial_training(synthesis_eyes_dir, mpii_gaze_dir, refiner_model_path=None, discriminator_model_path=None):
    """Adversarial training of refiner network Rθ and discriminator network Dφ."""
    synthetic_image_tensor = layers.Input(shape=(img_height, img_width, img_channels))
    refined_image_tensor = refiner_network(synthetic_image_tensor)

    refined_or_real_image_tensor = layers.Input(shape=(img_height, img_width,img_channels))
    discriminator_output = discriminator_network(refined_or_real_image_tensor)

    #
    # define models
    #

    refiner_model = models.Model(input=synthetic_image_tensor, output=refined_image_tensor, name='refiner')
    discriminator_model = models.Model(input=refined_or_real_image_tensor, output=discriminator_output,
                                       name='discriminator')

    # combined must output the refined image along w/ the disc's classification of it for the refiner's self-reg loss
    refiner_model_output = refiner_model(synthetic_image_tensor)
    print(np.shape(refiner_model_output))
    combined_output = discriminator_model(refiner_model_output)
    combined_model = models.Model(input=synthetic_image_tensor, output=[refiner_model_output, combined_output],
                                  name='combined')

    discriminator_model_output_shape = discriminator_model.output_shape

    print(refiner_model.summary())
    print(discriminator_model.summary())
    print(combined_model.summary())

    #
    # define custom l1 loss function for the refiner
    #

    def self_regularization_loss(y_true, y_pred):
        delta = 0.0001  # FIXME: need to figure out an appropriate value for this
        return tf.multiply(delta, tf.reduce_sum(tf.abs(y_pred - y_true)))

    #
    # define custom local adversarial loss (softmax for each image section) for the discriminator
    # the adversarial loss function is the sum of the cross-entropy losses over the local patches
    #

    def local_adversarial_loss(y_true, y_pred):
        # y_true and y_pred have shape (batch_size, # of local patches, 2), but really we just want to average over
        # the local patches and batch size so we can reshape to (batch_size * # of local patches, 2)
        y_true = tf.reshape(y_true, (-1, 2))
        y_pred = tf.reshape(y_pred, (-1, 2))
        loss = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)

        return tf.reduce_mean(loss)

    #
    # compile models
    #

    sgd = optimizers.SGD(lr=0.001)

    refiner_model.compile(optimizer=sgd, loss=self_regularization_loss)
    discriminator_model.compile(optimizer=sgd, loss=local_adversarial_loss)
    discriminator_model.trainable = False
    combined_model.compile(optimizer=sgd, loss=[self_regularization_loss, local_adversarial_loss])

    
    # data generators

    # TODO: decide on a normalization scheme (scale to [0, 1]?)

    datagen = image.ImageDataGenerator(
        # preprocessing_function=applications.xception.preprocess_input,
        rescale=1 / 3500.,
        data_format="channels_last")

    flow_from_directory_params = {'target_size': (img_height, img_width),
                                  'color_mode': 'grayscale',
                                  'class_mode': None,
                                  'batch_size': batch_size}

    synthetic_generator = datagen.flow_from_directory(
        directory=synthesis_eyes_dir,
        **flow_from_directory_params
    )

    real_generator = datagen.flow_from_directory(
        directory=mpii_gaze_dir,
        **flow_from_directory_params
    )
    def hot_vectorize(x,value = 0):
        zero_mask = x==value
        non_zero_mask = x!=value
        x[zero_mask] = 1
        x[non_zero_mask] = 0
        return x.astype(float)
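    # e.g. hot_vectorize(np.array([0, 7, 0])) -> array([1., 0., 1.]):
    # entries equal to `value` become 1, everything else 0 (modifies x in place).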

    def stack_frames(frames,img_height,img_width,channels):
        stack  = np.zeros((1,img_height,img_width,channels))
        index = 0
        for i in frames:
            num_chan = 1 if len(np.shape(i)) ==2 else np.shape(i)[2]
            stack[0,:,:,index:index+num_chan] = np.reshape(i,(img_height,img_width,num_chan))
            index += num_chan
        return stack

    def grab_frame(files,i,path,func=None):
        img = misc.imread(path+files[i])
        if func:
            return func(img)
        return img

    def grab_frame1(path,func=None):
        img = misc.imread(path)
        if func:
            return func(img)
        return img

    def normalize(x):
        x=x.astype(float)/3500.
        x*=2
        x-=1
        return x

    def normalize_depth(x):
        return x.astype(float)/3000.

    def convert_rgb_normal(img):
        return (img/255.*2)-1

    def bounding_box(img,size = 100):
        h,w = np.shape(img)
        non_zeros = np.nonzero(img)
        x_min = np.min(non_zeros[0])
        x_max = np.max(non_zeros[0])
        y_min = np.min(non_zeros[1])
        y_max = np.max(non_zeros[1])
        out = (x_min, x_min + size, y_min, y_min + size)
        # if size else (x_min, x_max, y_min, y_max)  # minus or plus coordinates
        if x_min < 0 or x_min + size > h or y_min < 0 or y_min + size > w:
            return None
        return out

    def crop(x,x1 = 100,x2 = 500,y1 = 50, y2 = 450):
        x = x[y1:y2,x1:x2]
        return x
  
    def gen_samples(directory,key,shuffle = True):
            samples = []
            dirs = os.listdir(directory)
            for i in dirs:
                path = os.path.join(directory, i)+"/"
                if os.access(path, os.R_OK):
                    depth = sorted(glob.glob(path+"*"+key))
                    samples.extend(depth)
            if shuffle:
                random.shuffle(samples)
            return samples

    def generate_data_custom_depth(samples,img_height=img_height,img_width=img_width,batch_size=4,func=[]):
        i = 0
        while True:
            stack1 = np.zeros((batch_size,img_height,img_width,1))
            j=0
            while j<batch_size:
                try: 
                    depth = samples[i]
                    depth_img = grab_frame1(depth,normalize)
                    depth_img = crop(depth_img)

                    stack1[j] = np.reshape(depth_img,(img_height,img_width,1))
                    j+=1
                    i= (i+1)%len(samples)
                except Exception:
                    i=(i+1)%len(samples)
            yield stack1 
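    # The generator yields batches forever, so it can be consumed like the
    # Keras generators above: next(real_generator) returns an array of shape
    # (batch_size, img_height, img_width, 1); unreadable samples are skipped
    # by the try/except, and the index wraps around via the modulo.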

    # samples = gen_samples("/media/drc/DATA/chris_labelfusion/RGBDCNN/",key="_truth.png")
    # synthetic_generator = generate_data_custom_depth(samples,img_height=img_height,img_width=img_width,batch_size= batch_size)

    # samples = gen_samples("/media/drc/DATA/chris_labelfusion/RGBDCNN/",key ="_depth.png")
    # real_generator = generate_data_custom_depth(samples,img_height=img_height,img_width=img_width,batch_size= batch_size)
 

    def get_image_batch(generator):
        """keras generators may generate an incomplete batch for the last batch"""
        img_batch = next(generator)
        while len(img_batch) != batch_size:
            img_batch = next(generator)
        return img_batch
    print(np.shape(get_image_batch(synthetic_generator)))

    # the target labels for the cross-entropy loss layer are 0 for every yj (real) and 1 for every xi (refined)
    y_real = np.array([[[1.0, 0.0]] * discriminator_model_output_shape[1]] * batch_size)
    y_refined = np.array([[[0.0, 1.0]] * discriminator_model_output_shape[1]] * batch_size)
    assert y_real.shape == (batch_size, discriminator_model_output_shape[1], 2)
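    # e.g. with two local patches per image, y_real[b] == [[1., 0.], [1., 0.]]
    # for every image b in the batch: class 0 (real) for each patch, matching
    # the local adversarial loss defined above.
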
    if not refiner_model_path:
        print('pre-training the refiner network...')
        gen_loss = np.zeros(shape=len(refiner_model.metrics_names))

        for i in range(500):
            synthetic_image_batch = get_image_batch(synthetic_generator)
            gen_loss = np.add(refiner_model.train_on_batch(synthetic_image_batch, synthetic_image_batch), gen_loss)
            # log every `log_interval` steps
            if not i % log_interval:
                figure_name = 'refined_image_batch_pre_train_step_{}.png'.format(i)
                print('Saving batch of refined images during pre-training at step: {}.'.format(i))

                synthetic_image_batch = get_image_batch(synthetic_generator)
                plot_image_batch_w_labels.plot_batch(
                    np.concatenate((synthetic_image_batch, refiner_model.predict_on_batch(synthetic_image_batch))),
                    os.path.join(cache_dir, figure_name),
                    label_batch=['Synthetic'] * batch_size + ['Refined'] * batch_size)

                print('Refiner model self regularization loss: {}.'.format(gen_loss / log_interval))
                gen_loss = np.zeros(shape=len(refiner_model.metrics_names))

        refiner_model.save(os.path.join(cache_dir, 'refiner_model_pre_trained.h5'))
    else:
        refiner_model.load_weights(refiner_model_path)

    print("pretrained refiner network")
    
    if not discriminator_model_path:
        print('pre-training the discriminator network...')
        disc_loss = np.zeros(shape=len(discriminator_model.metrics_names))

        for _ in range(200):
            real_image_batch = get_image_batch(real_generator)
            disc_loss = np.add(discriminator_model.train_on_batch(real_image_batch, y_real), disc_loss)
            synthetic_image_batch = get_image_batch(synthetic_generator)
            refined_image_batch = refiner_model.predict_on_batch(synthetic_image_batch)
            disc_loss = np.add(discriminator_model.train_on_batch(refined_image_batch, y_refined), disc_loss)

        discriminator_model.save(os.path.join(cache_dir, 'discriminator_model_pre_trained.h5'))

        # average over 200 iterations x 2 discriminator updates each
        print('Discriminator model loss: {}.'.format(disc_loss / (200 * 2)))
    else:
        discriminator_model.load_weights(discriminator_model_path)

    # TODO: what is an appropriate size for the image history buffer?
    image_history_buffer = ImageHistoryBuffer((0, img_height, img_width, img_channels), batch_size * 100, batch_size)
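    # The buffer holds refined images produced by earlier versions of the
    # refiner; training the discriminator on half current / half historical
    # fakes stabilizes adversarial training (see the SimpleImageHistoryBuffer
    # sketch at the end of this example).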

    combined_loss = np.zeros(shape=len(combined_model.metrics_names))
    disc_loss_real = np.zeros(shape=len(discriminator_model.metrics_names))
    disc_loss_refined = np.zeros(shape=len(discriminator_model.metrics_names))

    # see Algorithm 1 in https://arxiv.org/pdf/1612.07828v1.pdf
    for i in range(nb_steps):
        print('Step: {} of {}.'.format(i, nb_steps))

        # train the refiner
        for _ in range(k_g * 2):
            # sample a mini-batch of synthetic images
            synthetic_image_batch = get_image_batch(synthetic_generator)

            # the combined model trains the refiner against both the
            # self-regularization target and the discriminator's "real" label
            combined_loss = np.add(
                combined_model.train_on_batch(synthetic_image_batch,
                                              [synthetic_image_batch, y_real]),
                combined_loss)

        for _ in range(k_d):
            # sample a mini-batch of synthetic and real images
            synthetic_image_batch = get_image_batch(synthetic_generator)
            real_image_batch = get_image_batch(real_generator)
            
            # refine the synthetic images w/ the current refiner
            refined_image_batch = refiner_model.predict_on_batch(synthetic_image_batch)

            # use a history of refined images
            half_batch_from_image_history = image_history_buffer.get_from_image_history_buffer()
            image_history_buffer.add_to_image_history_buffer(refined_image_batch)

            if len(half_batch_from_image_history):
                refined_image_batch[:batch_size // 2] = half_batch_from_image_history

            disc_loss_real = np.add(discriminator_model.train_on_batch(real_image_batch, y_real), disc_loss_real)
            disc_loss_refined = np.add(discriminator_model.train_on_batch(refined_image_batch, y_refined),
                                       disc_loss_refined)

        if not i % log_interval:
            # plot batch of refined images w/ current refiner
            figure_name = 'refined_image_batch_step_{}.png'.format(i)
            print('Saving batch of refined images at adversarial step: {}.'.format(i))

            synthetic_image_batch = get_image_batch(synthetic_generator)
            plot_image_batch_w_labels.plot_batch(
                np.concatenate((synthetic_image_batch, refiner_model.predict_on_batch(synthetic_image_batch))),
                os.path.join(cache_dir, figure_name),
                label_batch=['Synthetic'] * batch_size + ['Refined'] * batch_size)

            # log loss summary
            print('Refiner model loss: {}.'.format(combined_loss / (log_interval * k_g * 2)))
            print('Discriminator model loss real: {}.'.format(disc_loss_real / (log_interval * k_d)))
            print('Discriminator model loss refined: {}.'.format(disc_loss_refined / (log_interval * k_d)))

            combined_loss = np.zeros(shape=len(combined_model.metrics_names))
            disc_loss_real = np.zeros(shape=len(discriminator_model.metrics_names))
            disc_loss_refined = np.zeros(shape=len(discriminator_model.metrics_names))

            # save model checkpoints
            model_checkpoint_base_name = os.path.join(cache_dir, '{}_model_step_{}.h5')
            refiner_model.save(model_checkpoint_base_name.format('refiner', i))
            discriminator_model.save(model_checkpoint_base_name.format('discriminator', i))
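
# A minimal sketch of the history-buffer idea used above (an assumption --
# the real ImageHistoryBuffer is imported from its own module and may differ):
# keep a fixed-size pool of images refined by earlier versions of the refiner
# and serve half-batches from it, so the discriminator is not trained only
# against the latest refiner.

class SimpleImageHistoryBuffer(object):
    def __init__(self, max_size, batch_size):
        self.buffer = []              # pool of single images, each (H, W, C)
        self.max_size = max_size
        self.batch_size = batch_size

    def add_to_image_history_buffer(self, image_batch):
        # Add half a batch of freshly refined images, evicting random old
        # entries once the pool is full.
        for image in image_batch[:self.batch_size // 2]:
            if len(self.buffer) < self.max_size:
                self.buffer.append(image)
            else:
                self.buffer[np.random.randint(self.max_size)] = image

    def get_from_image_history_buffer(self):
        # Return half a batch of historical images, or an empty array while
        # the pool is still warming up.
        if len(self.buffer) < self.batch_size // 2:
            return np.asarray([])
        idx = np.random.choice(len(self.buffer), self.batch_size // 2, replace=False)
        return np.asarray([self.buffer[k] for k in idx])
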
Beispiel #7
0
class CycleGANModel(object):
    """
    CycleGAN model class, responsible for checkpointing and the forward and backward pass.
    Inspired by:
    https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/cycle_gan_model.py
    """
    def __init__(self, opt):
        self.opt = opt

        self.genA2B = Generator(opt)
        self.genB2A = Generator(opt)

        if opt.training:
            self.discA = Discriminator(opt)
            self.discB = Discriminator(opt)
            self.learning_rate = tf.contrib.eager.Variable(
                opt.lr, dtype=tf.float32, name='learning_rate')
            self.disc_optim = tf.train.AdamOptimizer(self.learning_rate,
                                                     beta1=opt.beta1)
            self.gen_optim = tf.train.AdamOptimizer(self.learning_rate,
                                                    beta1=opt.beta1)
            self.global_step = tf.train.get_or_create_global_step()
            # Initialize history buffers:
            self.discA_buffer = ImageHistoryBuffer(opt)
            self.discB_buffer = ImageHistoryBuffer(opt)
        # Restore latest checkpoint:
        self.initialize_checkpoint()
        if not opt.training or opt.load_checkpoint:
            self.restore_checkpoint()

    def initialize_checkpoint(self):
        if self.opt.training:
            self.checkpoint = tf.train.Checkpoint(
                discA=self.discA,
                discB=self.discB,
                genA2B=self.genA2B,
                genB2A=self.genB2A,
                disc_optim=self.disc_optim,
                gen_optim=self.gen_optim,
                learning_rate=self.learning_rate,
                global_step=self.global_step)
        else:
            self.checkpoint = tf.train.Checkpoint(genA2B=self.genA2B,
                                                  genB2A=self.genB2A)

    def restore_checkpoint(self):
        checkpoint_dir = os.path.join(self.opt.save_dir, 'checkpoints')
        latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        if (not self.opt.training
                or self.opt.load_checkpoint) and latest_checkpoint is not None:
            # Use assert_existing_objects_matched() instead of assert_consumed() here because
            # optimizers aren't fully initialized until the first gradient update.
            # This will throw an exception if the checkpoint does not restore the model weights.
            self.checkpoint.restore(
                latest_checkpoint).assert_existing_objects_matched()
            print("Checkpoint restored from ", latest_checkpoint)
        else:
            print("No checkpoint restored; initializing model from scratch.")

    def set_input(self, inputs):
        # Get next batches:
        self.dataA = inputs["A"].get_next()
        self.dataB = inputs["B"].get_next()

    def forward(self):
        # Gen output shape: (batch_size, img_size, img_size, 3)
        self.fakeB = self.genA2B(self.dataA)
        self.reconstructedA = self.genB2A(self.fakeB)

        self.fakeA = self.genB2A(self.dataB)
        self.reconstructedB = self.genA2B(self.fakeA)

    def backward_D(self, netD, real, fake, tape):
        # Disc output shape: (batch_size, img_size/8, img_size/8, 1)
        pred_real = netD(real)
        pred_fake = netD(tf.stop_gradient(fake))  # Detaches generator from D
        disc_loss = discriminator_loss(pred_real, pred_fake, self.opt.gan_mode)
        if self.opt.gan_mode == 'wgangp':  # GRADIENT PENALTY
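            # WGAN-GP (Gulrajani et al., 2017): penalize the critic's gradient
            # norm on random interpolations X_hat between real and fake
            # samples, pushing ||grad D(X_hat)|| toward 1.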
            with tape.stop_recording():
                # BATCH_SIZE was undefined here; use the batch dimension of
                # `real` instead.
                epsilon = tf.random_uniform(shape=[tf.shape(real)[0], 1, 1, 1],
                                            minval=0.,
                                            maxval=1.)
                X_hat = real + epsilon * (fake - real)

                def gp_func(X_hat):
                    return netD(X_hat)

                gp_grad_func = tf.contrib.eager.gradients_function(gp_func)
                grad_critic_X_hat = gp_grad_func(X_hat)[0]
            slopes = tf.sqrt(
                tf.reduce_sum(tf.square(grad_critic_X_hat), axis=[1, 2, 3]))
            gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
            disc_loss += 10 * gradient_penalty  # Lambda = 10 in gradient penalty
        return disc_loss

    def backward_discA(self, tape):
        # Sample from history buffer of 50 images:
        fake_A = self.discA_buffer.query(self.fakeA)
        self.discA_loss = self.backward_D(self.discA, self.dataA, fake_A, tape)
        return self.discA_loss

    def backward_discB(self, tape):
        # Sample from history buffer of 50 images:
        fake_B = self.discB_buffer.query(self.fakeB)
        self.discB_loss = self.backward_D(self.discB, self.dataB, fake_B, tape)
        return self.discB_loss

    def backward_G(self):
        if self.opt.identity_lambda > 0:
            identityA = self.genB2A(self.dataA)
            self.id_lossA = identity_loss(
                self.dataA,
                identityA) * self.opt.cyc_lambda * self.opt.identity_lambda

            identityB = self.genA2B(self.dataB)
            self.id_lossB = identity_loss(
                self.dataB,
                identityB) * self.opt.cyc_lambda * self.opt.identity_lambda
        else:
            self.id_lossA, self.id_lossB = 0, 0

        self.genA2B_loss = generator_loss(self.discB(self.dataB),
                                          self.discB(self.fakeB),
                                          self.opt.gan_mode)
        self.genB2A_loss = generator_loss(self.discA(self.dataA),
                                          self.discA(self.fakeA),
                                          self.opt.gan_mode)

        self.cyc_lossA = cycle_loss(self.dataA,
                                    self.reconstructedA) * self.opt.cyc_lambda
        self.cyc_lossB = cycle_loss(self.dataB,
                                    self.reconstructedB) * self.opt.cyc_lambda

        gen_loss = self.genA2B_loss + self.genB2A_loss + self.cyc_lossA + self.cyc_lossB + self.id_lossA + self.id_lossB
        return gen_loss

    def optimize_parameters(self):
        for net in (self.discA, self.discB):
            for layer in net.layers:
                layer.trainable = False

        with tf.GradientTape() as genTape:
            genTape.watch([self.genA2B.variables, self.genB2A.variables])

            self.forward()
            gen_loss = self.backward_G()

        gen_variables = [self.genA2B.variables, self.genB2A.variables]
        gen_gradients = genTape.gradient(gen_loss, gen_variables)
        self.gen_optim.apply_gradients(
            list(zip(gen_gradients[0], gen_variables[0]))
            + list(zip(gen_gradients[1], gen_variables[1])),
            global_step=self.global_step)

        for net in (self.discA, self.discB):
            for layer in net.layers:
                layer.trainable = True

        with tf.GradientTape(persistent=True) as discTape:
            discTape.watch([self.discA.variables, self.discB.variables])
            self.forward()
            discA_loss = self.backward_discA(discTape)
            discB_loss = self.backward_discB(discTape)

        discA_gradients = discTape.gradient(discA_loss, self.discA.variables)
        discB_gradients = discTape.gradient(discB_loss, self.discB.variables)
        self.disc_optim.apply_gradients(zip(discA_gradients,
                                            self.discA.variables),
                                        global_step=self.global_step)
        self.disc_optim.apply_gradients(zip(discB_gradients,
                                            self.discB.variables),
                                        global_step=self.global_step)

    def save_model(self):
        checkpoint_prefix = os.path.join(self.opt.save_dir, 'checkpoints',
                                         'ckpt')
        checkpoint_path = self.checkpoint.save(file_prefix=checkpoint_prefix)
        print("Checkpoint saved at ", checkpoint_path)

    def test(self):
        self.fakeA = self.genB2A(self.dataB)
        self.fakeB = self.genA2B(self.dataA)
        return [self.dataA, self.fakeA, self.dataB, self.fakeB]

    def update_learning_rate(self, batches_per_epoch):
        new_lr = self._get_learning_rate(batches_per_epoch)
        self.learning_rate.assign(new_lr)

    def _get_learning_rate(self, batches_per_epoch):
        # Divide by 3 because there are 3 gradient updates per batch
        # (one generator step and two discriminator steps).
        global_step = self.global_step.numpy() / 3
        total_epochs = global_step // batches_per_epoch
        # Keep the base rate for the first `niter` epochs, then decay
        # linearly to zero over roughly the next `niter_decay` epochs.
        learning_rate_lambda = 1.0 - max(
            0, total_epochs - self.opt.niter) / float(self.opt.niter_decay + 1)
        return self.opt.lr * max(0, learning_rate_lambda)
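
# A minimal sketch of a loop driving CycleGANModel (illustrative only; the
# option fields and the structure of `datasets` -- a dict of "A"/"B" iterators
# exposing get_next(), as set_input() expects -- are assumptions based on how
# the class is used above, not a confirmed API). Assumes eager execution was
# enabled at program start (tf.enable_eager_execution()), since the model
# relies on eager-mode GradientTapes.

def train(opt, datasets, epochs, batches_per_epoch):
    model = CycleGANModel(opt)
    for _ in range(epochs):
        for _ in range(batches_per_epoch):
            model.set_input(datasets)    # pull the next A/B batches
            model.optimize_parameters()  # one generator + two discriminator updates
        model.update_learning_rate(batches_per_epoch)  # linear lr decay
        model.save_model()               # checkpoint once per epoch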