Example #1
def create_model(inputs, targets, max_steps):
    model = Pix2Pix()

    out_channels = int(targets.get_shape()[-1])
    outputs = model.get_generator(inputs, out_channels, ngf=args.ngf,
                                  conv_type=args.conv_type,
                                  channel_multiplier=args.channel_multiplier,
                                  padding='SAME',
                                  net_type=args.net_type, reuse=False,
                                  upsampe_method=args.upsampe_method)

    # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
    predict_real = model.get_discriminator(inputs, targets, ndf=args.ndf,
                                           spectral_normed=True,
                                           update_collection=None,
                                           conv_type=args.conv_type,
                                           channel_multiplier=args.channel_multiplier,
                                           padding='VALID',
                                           net_type=args.net_type, reuse=False)

    # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
    predict_fake = model.get_discriminator(inputs, outputs, ndf=args.ndf,
                                           spectral_normed=True,
                                           update_collection='NO_OPS',
                                           conv_type=args.conv_type,
                                           channel_multiplier=args.channel_multiplier,
                                           padding='VALID',
                                           net_type=args.net_type, reuse=True)

    with tf.name_scope("d_loss"):
        # discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
        discrim_loss, _ = lib.misc.get_loss(predict_real, predict_fake, loss_type=args.loss_type)

        if args.loss_type == 'WGAN-GP':
            # Gradient Penalty
            alpha = tf.random_uniform(shape=[args.batch_size, 1, 1, 1], minval=0., maxval=1.)
            differences = outputs - targets
            interpolates = targets + (alpha * differences)
            # with tf.variable_scope("discriminator", reuse=True):
            gradients = tf.gradients(
                model.get_discriminator(inputs, interpolates, ndf=args.ndf,
                                        spectral_normed=True,
                                        update_collection=None,
                                        conv_type=args.conv_type,
                                        channel_multiplier=args.channel_multiplier,
                                        padding='VALID',
                                        net_type=args.net_type, reuse=True), [interpolates])[0]
            slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-10)
            gradient_penalty = 10 * tf.reduce_mean(tf.square((slopes - 1.)))
            discrim_loss += gradient_penalty

    with tf.name_scope("g_loss"):
        # gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
        _, gen_loss_GAN = lib.misc.get_loss(predict_real, predict_fake, loss_type=args.loss_type)

        if args.g_bce:
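            # Per-pixel binary cross-entropy on the deprocessed images (assumed to
            # lie in [0, 1]); the clipping keeps both log terms finite.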
            outputs_ = deprocess(outputs)
            targets_ = deprocess(targets)
            gen_loss_content = -tf.reduce_mean(
                targets_ * tf.log(tf.clip_by_value(outputs_, 1e-10, 1.0 - 1e-10)) +
                (1.0 - targets_) * tf.log(tf.clip_by_value(1.0 - outputs_, 1e-10, 1.0 - 1e-10)))
            # gen_loss_content = -tf.reduce_mean(
            #     targets * tf.log(tf.clip_by_value(outputs, 1e-10, 1.0)) +
            #     (1.0 - targets) * tf.log(tf.clip_by_value(1.0 - outputs, 1e-10, 1.0)))
        else:
            gen_loss_content = tf.reduce_mean(tf.abs(targets - outputs))

        gen_loss = gen_loss_GAN * args.gan_weight + gen_loss_content * args.l1_weight

    with tf.name_scope('global_step'):
        global_step = tf.train.get_or_create_global_step()
    # with tf.name_scope("global_step_summary"):
    #     tf.summary.scalar("global_step", global_step)

    with tf.name_scope('lr_decay'):
        # learning_rate = tf.train.polynomial_decay(
        #     learning_rate=args.initial_lr,
        #     global_step=global_step,
        #     decay_steps=max_steps,
        #     end_learning_rate=args.end_lr
        # )
        decay = 1.
        # decay = tf.where(
        #     tf.less(global_step, 23600), tf.maximum(0., 1. - (tf.cast(global_step, tf.float32) / 47200)), 0.5)
        # decay = tf.where(
        #     tf.less(global_step, int(max_steps * 0.5)),
        #     1.,
        #     tf.maximum(0., 1. - ((tf.cast(global_step, tf.float32) - int(max_steps * 0.5)) / max_steps)))
        if args.TTUR:
            print('\nUsing TTUR!\n')
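            # TTUR (two time-scale update rule, Heusel et al. 2017): the
            # discriminator gets a larger learning rate (4e-4) than the generator (1e-4).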
            LR_D = tf.constant(0.0004)  # 2e-4  # Initial learning rate
            LR_G = tf.constant(0.0001)  # 2e-4  # Initial learning rate
            lr_d = LR_D * decay
            lr_g = LR_G * decay
        else:
            print('\nNot using TTUR!\n')
            LR_D = tf.constant(0.0002)  # 2e-4  # Initial learning rate
            LR_G = tf.constant(0.0002)  # 2e-4  # Initial learning rate
            lr_d = LR_D * decay
            lr_g = LR_G * decay

    # with tf.name_scope("lr_summary"):
    #     tf.summary.scalar("lr", learning_rate)

    with tf.name_scope("d_train"):
        discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("d_net")]
        discrim_optim = tf.train.AdamOptimizer(lr_d, beta1=args.beta1, beta2=args.beta2)
        # discrim_optim = tf.train.AdamOptimizer(learning_rate, beta1=args.beta1, beta2=args.beta2)
        discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
        discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

    with tf.name_scope("g_train"):
        gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("g_net")]
        gen_optim = tf.train.AdamOptimizer(lr_g, beta1=args.beta1, beta2=args.beta2)
        # gen_optim = tf.train.AdamOptimizer(learning_rate, beta1=args.beta1, beta2=args.beta2)
        gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
        gen_train = gen_optim.apply_gradients(gen_grads_and_vars, global_step=global_step)

    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_content])

    # global_step = tf.train.get_or_create_global_step()
    # incr_global_step = tf.assign(global_step, global_step + 1)

    return Model(
        lr=lr_d + lr_g,
        outputs=outputs,
        predict_real=predict_real,
        predict_fake=predict_fake,
        discrim_loss=ema.average(discrim_loss),
        discrim_grads_and_vars=discrim_grads_and_vars,
        gen_loss_GAN=ema.average(gen_loss_GAN),
        gen_loss_content=ema.average(gen_loss_content),
        gen_grads_and_vars=gen_grads_and_vars,
        d_train=discrim_train,
        g_train=gen_train,
        losses=update_losses,
        global_step=global_step
    )
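
Example #1 switches the conditional pix2pix objective to WGAN-GP when args.loss_type == 'WGAN-GP'. Below is a minimal, self-contained sketch of just the gradient-penalty term, assuming TensorFlow 1.x graph mode and a hypothetical critic callable standing in for model.get_discriminator with its conditioning input already bound:

import tensorflow as tf

def wgan_gp_penalty(critic, real, fake, batch_size, weight=10.0):
    # Sample a random point on the line segment between each real/fake pair.
    alpha = tf.random_uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=1.)
    interpolates = real + alpha * (fake - real)
    # Gradient of the critic score with respect to the interpolated images.
    grads = tf.gradients(critic(interpolates), [interpolates])[0]
    # Per-sample L2 norm; the small epsilon keeps the sqrt differentiable at zero.
    slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-10)
    # Penalize deviations of the gradient norm from 1 (soft 1-Lipschitz constraint).
    return weight * tf.reduce_mean(tf.square(slopes - 1.))

In Example #1 the critic is the spectrally normalized discriminator reused with reuse=True, and the resulting penalty is added directly to discrim_loss.
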
Example #2
def create_model(inputs, targets, max_steps):
    model = Pix2Pix()

    out_channels = int(targets.get_shape()[-1])
    outputs = model.get_generator(inputs, out_channels, ngf=args.ngf,
                                  conv_type=args.conv_type,
                                  channel_multiplier=args.channel_multiplier,
                                  padding='SAME',
                                  net_type=args.net_type, reuse=False,
                                  upsampe_method=args.upsampe_method)

    # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
    predict_real = model.get_discriminator(inputs, targets, ndf=args.ndf,
                                           spectral_normed=True,
                                           update_collection=None,
                                           conv_type=args.conv_type,
                                           channel_multiplier=args.channel_multiplier,
                                           padding='VALID',
                                           net_type=args.net_type, reuse=False)

    # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
    predict_fake = model.get_discriminator(inputs, outputs, ndf=args.ndf,
                                           spectral_normed=True,
                                           update_collection=None,
                                           conv_type=args.conv_type,
                                           channel_multiplier=args.channel_multiplier,
                                           padding='VALID',
                                           net_type=args.net_type, reuse=True)

    with tf.name_scope("d_loss"):
        # discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
        discrim_loss, _ = lib.misc.get_loss(predict_real, predict_fake, loss_type=args.loss_type)

        if args.loss_type == 'WGAN-GP':
            # Gradient Penalty
            alpha = tf.random_uniform(shape=[args.batch_size, 1, 1, 1], minval=0., maxval=1.)
            differences = outputs - targets
            interpolates = targets + (alpha * differences)
            # with tf.variable_scope("discriminator", reuse=True):
            gradients = tf.gradients(
                model.get_discriminator(inputs, interpolates, ndf=args.ndf,
                                        spectral_normed=True,
                                        update_collection=None,
                                        conv_type=args.conv_type,
                                        channel_multiplier=args.channel_multiplier,
                                        padding='VALID',
                                        net_type=args.net_type, reuse=True), [interpolates])[0]
            slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1, 2, 3]) + 1e-10)
            gradient_penalty = 10 * tf.reduce_mean(tf.square((slopes - 1.)))
            discrim_loss += gradient_penalty

    with tf.name_scope("g_loss"):
        # gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
        _, gen_loss_GAN = lib.misc.get_loss(predict_real, predict_fake, loss_type=args.loss_type)
        gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
        gen_loss = gen_loss_GAN * args.gan_weight + gen_loss_L1 * args.l1_weight

    with tf.name_scope('global_step'):
        global_step = tf.train.get_or_create_global_step()
    # with tf.name_scope("global_step_summary"):
    #     tf.summary.scalar("global_step", global_step)

    with tf.name_scope('lr_decay'):
        learning_rate = tf.train.polynomial_decay(
            learning_rate=args.initial_lr,
            global_step=global_step,
            decay_steps=max_steps,
            end_learning_rate=args.end_lr
        )
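        # NOTE: learning_rate is constructed here but not used below; the
        # optimizers are built with fixed rates (4e-4 for D, 1e-4 for G) instead.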

    # with tf.name_scope("lr_summary"):
    #     tf.summary.scalar("lr", learning_rate)

    with tf.name_scope("d_train"):
        discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("d_net")]
        discrim_optim = tf.train.AdamOptimizer(0.0004, beta1=args.beta1, beta2=args.beta2)
        # discrim_optim = tf.train.AdamOptimizer(learning_rate, beta1=args.beta1, beta2=args.beta2)
        discrim_grads_and_vars = discrim_optim.compute_gradients(discrim_loss, var_list=discrim_tvars)
        discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

    with tf.name_scope("g_train"):
        gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("g_net")]
        gen_optim = tf.train.AdamOptimizer(0.0001, beta1=args.beta1, beta2=args.beta2)
        # gen_optim = tf.train.AdamOptimizer(learning_rate, beta1=args.beta1, beta2=args.beta2)
        gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
        gen_train = gen_optim.apply_gradients(gen_grads_and_vars, global_step=global_step)

    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])

    # global_step = tf.train.get_or_create_global_step()
    # incr_global_step = tf.assign(global_step, global_step + 1)

    return Model(
        outputs=outputs,
        predict_real=predict_real,
        predict_fake=predict_fake,
        discrim_loss=ema.average(discrim_loss),
        discrim_grads_and_vars=discrim_grads_and_vars,
        gen_loss_GAN=ema.average(gen_loss_GAN),
        gen_loss_L1=ema.average(gen_loss_L1),
        gen_grads_and_vars=gen_grads_and_vars,
        d_train=discrim_train,
        g_train=gen_train,
        losses=update_losses,
        global_step=global_step
    )
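
Both create_model variants return a Model namedtuple whose definition is not shown in these examples. A plausible declaration for the fields used by Example #2 (hypothetical; the actual module may name or order them differently) is sketched below, with Example #1 additionally needing lr and gen_loss_content fields in place of gen_loss_L1:

import collections

# Hypothetical namedtuple matching the keyword arguments used in Example #2.
Model = collections.namedtuple("Model", [
    "outputs", "predict_real", "predict_fake",
    "discrim_loss", "discrim_grads_and_vars",
    "gen_loss_GAN", "gen_loss_L1", "gen_grads_and_vars",
    "d_train", "g_train", "losses", "global_step",
])
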
Example #3
def train():
    if args.seed is None:
        args.seed = random.randint(0, 2**31 - 1)

    tf.set_random_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    if args.mode == "test" or args.mode == "export":
        if args.checkpoint_dir is None:
            raise Exception("checkpoint required for test mode")

        # load some options from the checkpoint
        # options = {"which_direction", "ngf", "ndf", "lab_colorization"}
        # with open(os.path.join(args.checkpoint_dir, "options.json"), 'r') as f:
        #     for key, val in json.loads(f.read()).items():
        #         if key in options:
        #             print("loaded", key, "=", val)
        #             setattr(args, key, val)
        # disable these features in test mode
        args.scale_size = CROP_SIZE
        args.flip = False

    for k, v in args._get_kwargs():
        print(k, "=", v)

    with open(os.path.join(args.output_dir, "options.json"), "w") as f:
        f.write(json.dumps(vars(args), sort_keys=True, indent=4))

    if args.mode == "export":
        # export the generator to a meta graph that can be imported later for standalone generation
        if args.lab_colorization:
            raise Exception("export not supported for lab_colorization")

        inputs = tf.placeholder(tf.string, shape=[1])
        input_data = tf.decode_base64(inputs[0])
        input_image = tf.image.decode_png(input_data)

        # remove alpha channel if present
        input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 4),
                              lambda: input_image[:, :, :3],
                              lambda: input_image)
        # convert grayscale to RGB
        input_image = tf.cond(tf.equal(tf.shape(input_image)[2], 1),
                              lambda: tf.image.grayscale_to_rgb(input_image),
                              lambda: input_image)

        input_image = tf.image.convert_image_dtype(input_image,
                                                   dtype=tf.float32)
        input_image.set_shape([CROP_SIZE, CROP_SIZE, 3])
        batch_input = tf.expand_dims(input_image, axis=0)

        model_ = Pix2Pix()
        batch_output = deprocess(
            model_.get_generator(preprocess(batch_input),
                                 3,
                                 ngf=args.ngf,
                                 conv_type=args.conv_type,
                                 channel_multiplier=args.channel_multiplier,
                                 padding='SAME'))
        # with tf.variable_scope("generator"):
        #     batch_output = deprocess(model.get_generator(preprocess(batch_input), 3))

        output_image = tf.image.convert_image_dtype(batch_output,
                                                    dtype=tf.uint8)[0]
        if args.output_filetype == "png":
            output_data = tf.image.encode_png(output_image)
        elif args.output_filetype == "jpeg":
            output_data = tf.image.encode_jpeg(output_image, quality=80)
        else:
            raise Exception("invalid filetype")
        output = tf.convert_to_tensor([tf.encode_base64(output_data)])

        key = tf.placeholder(tf.string, shape=[1])
        inputs = {"key": key.name, "input": inputs.name}
        tf.add_to_collection("inputs", json.dumps(inputs))
        outputs = {
            "key": tf.identity(key).name,
            "output": output.name,
        }
        tf.add_to_collection("outputs", json.dumps(outputs))

        init_op = tf.global_variables_initializer()
        restore_saver = tf.train.Saver()
        export_saver = tf.train.Saver()

        with tf.Session() as sess:
            sess.run(init_op)
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(args.checkpoint_dir)
            restore_saver.restore(sess, checkpoint)
            print("exporting model")
            export_saver.export_meta_graph(
                filename=os.path.join(args.output_dir, "export.meta"))
            export_saver.save(sess,
                              os.path.join(args.output_dir, "export"),
                              write_meta_graph=False)

        return

    examples = load_examples()
    print("examples count = %d" % examples.count)

    max_steps = 2**32
    if args.max_epochs is not None:
        max_steps = examples.steps_per_epoch * args.max_epochs
    if args.max_steps is not None:
        max_steps = args.max_steps
    # inputs and targets are [batch_size, height, width, channels]
    modelNamedtuple = create_model(examples.inputs, examples.targets,
                                   max_steps)

    # undo colorization splitting on images that we use for display/output
    if args.lab_colorization:
        if args.which_direction == "AtoB":
            # inputs is brightness, this will be handled fine as a grayscale image
            # need to augment targets and outputs with brightness
            targets = augment(examples.targets, examples.inputs)
            outputs = augment(modelNamedtuple.outputs, examples.inputs)
            # inputs can be deprocessed normally and handled as if they are single channel
            # grayscale images
            inputs = deprocess(examples.inputs)
        elif args.which_direction == "BtoA":
            # inputs will be color channels only, get brightness from targets
            inputs = augment(examples.inputs, examples.targets)
            targets = deprocess(examples.targets)
            outputs = deprocess(modelNamedtuple.outputs)
        else:
            raise Exception("invalid direction")
    else:
        inputs = deprocess(examples.inputs)
        targets = deprocess(examples.targets)
        outputs = deprocess(modelNamedtuple.outputs)

    def convert(image):
        if args.aspect_ratio != 1.0:
            # upscale to correct aspect ratio
            size = [CROP_SIZE, int(round(CROP_SIZE * args.aspect_ratio))]
            image = tf.image.resize_images(
                image, size=size, method=tf.image.ResizeMethod.BICUBIC)

        return tf.image.convert_image_dtype(image,
                                            dtype=tf.uint8,
                                            saturate=True)

    # reverse any processing on images so they can be written to disk or displayed to user
    with tf.name_scope("convert_inputs"):
        converted_inputs = convert(inputs)

    with tf.name_scope("convert_targets"):
        converted_targets = convert(targets)

    with tf.name_scope("convert_outputs"):
        converted_outputs = convert(outputs)

    with tf.name_scope("encode_images"):
        if args.multiple_A:
            # channels = converted_inputs.shape.as_list()[3]
            converted_inputs = tf.split(converted_inputs, 2, 3)[1]
            print('\n----642----: {}\n'.format(
                converted_inputs.shape.as_list()))

        display_fetches = {
            "paths": examples.paths,
            "inputs": tf.map_fn(tf.image.encode_png, converted_inputs,
                                dtype=tf.string, name="input_pngs"),
            "targets": tf.map_fn(tf.image.encode_png, converted_targets,
                                 dtype=tf.string, name="target_pngs"),
            "outputs": tf.map_fn(tf.image.encode_png, converted_outputs,
                                 dtype=tf.string, name="output_pngs"),
        }

    # summaries
    # with tf.name_scope("inputs_summary"):
    #     tf.summary.image("inputs", converted_inputs)

    with tf.name_scope("targets_summary"):
        tf.summary.image("targets", converted_targets)

    with tf.name_scope("outputs_summary"):
        tf.summary.image("outputs", converted_outputs)

    tf.summary.scalar("discriminator_loss", modelNamedtuple.discrim_loss)
    tf.summary.scalar("generator_loss_GAN", modelNamedtuple.gen_loss_GAN)
    tf.summary.scalar("generator_loss_L1", modelNamedtuple.gen_loss_L1)

    # for var in tf.trainable_variables():
    #     tf.summary.histogram(var.op.name + "/values", var)

    for grad, var in modelNamedtuple.discrim_grads_and_vars + modelNamedtuple.gen_grads_and_vars:
        tf.summary.histogram(var.op.name + "/gradients", grad)

    with tf.name_scope("parameter_count"):
        parameter_count = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    summary_op = tf.summary.merge_all()
    saver = tf.train.Saver(max_to_keep=5)

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        summary_writer = tf.summary.FileWriter(args.output_dir, sess.graph)
        sess.run(tf.global_variables_initializer())
        print("parameter_count =", sess.run(parameter_count))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        if args.checkpoint_dir is not None:
            print("loading model from checkpoint")
            checkpoint = tf.train.latest_checkpoint(args.checkpoint_dir)
            saver.restore(sess, checkpoint)

        # max_steps = 2 ** 32
        # if args.max_epochs is not None:
        #     max_steps = examples.steps_per_epoch * args.max_epochs
        # if args.max_steps is not None:
        #     max_steps = args.max_steps

        if args.mode == "test":
            # testing
            # at most, process the test data once
            start = time.time()
            max_steps = min(examples.steps_per_epoch, max_steps)
            for step in range(max_steps):
                results = sess.run(display_fetches)
                filesets = save_images(results)
                for f in filesets:
                    print("evaluated image", f["name"])
                index_path = append_index(filesets)
            print("wrote index at", index_path)
            print("rate", (time.time() - start) / max_steps)
        else:
            # training
            start = time.time()

            for step in range(max_steps):

                def should(freq):
                    return freq > 0 and ((step + 1) % freq == 0
                                         or step == max_steps - 1)

                # Run args.n_dis discriminator updates for every generator update.
                for i in range(args.n_dis):
                    sess.run(modelNamedtuple.d_train)

                fetches = {
                    "g_train": modelNamedtuple.g_train,
                    "losses": modelNamedtuple.losses,
                    "global_step": modelNamedtuple.global_step,
                }

                if should(args.progress_freq):
                    fetches["discrim_loss"] = modelNamedtuple.discrim_loss
                    fetches["gen_loss_GAN"] = modelNamedtuple.gen_loss_GAN
                    fetches["gen_loss_L1"] = modelNamedtuple.gen_loss_L1

                if should(args.summary_freq):
                    fetches["summary"] = summary_op

                if should(args.display_freq):
                    fetches["display"] = display_fetches

                # results = sess.run(fetches, options=options, run_metadata=run_metadata)
                results = sess.run(fetches)

                if should(args.summary_freq):
                    # print("recording summary")
                    summary_writer.add_summary(results["summary"],
                                               results["global_step"])

                if should(args.display_freq):
                    # print("saving display images")
                    filesets = save_images(results["display"],
                                           step=results["global_step"])
                    append_index(filesets, step=True)

                if should(args.progress_freq):
                    # global_step will have the correct step count if we resume from a checkpoint
                    train_epoch = math.ceil(results["global_step"] /
                                            examples.steps_per_epoch)
                    train_step = (results["global_step"] -
                                  1) % examples.steps_per_epoch + 1
                    rate = (step + 1) * args.batch_size / (time.time() - start)
                    remaining = (max_steps - step) * args.batch_size / rate

                    print(
                        "progress  epoch %d  step %d  image/sec %0.1f  remaining %dm"
                        % (train_epoch, train_step, rate, remaining / 60))
                    print("discrim_loss", results["discrim_loss"])
                    print("gen_loss_GAN", results["gen_loss_GAN"])
                    print("gen_loss_L1", results["gen_loss_L1"])

                if should(args.save_freq):
                    print("saving model...")
                    saver.save(sess,
                               os.path.join(args.output_dir, "model"),
                               global_step=modelNamedtuple.global_step)

        coord.request_stop()
        coord.join(threads)
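
The training branch in Example #3 drives all periodic work (summaries, display images, progress prints, checkpoints) through the small should(freq) closure defined inside the step loop. A standalone sketch of the same scheduling pattern, with illustrative numbers:

def should(step, max_steps, freq):
    # Fire every `freq` steps (counting from 1) and always on the final step;
    # a non-positive freq disables the corresponding action.
    return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1)

# With max_steps=10 and freq=4 this fires at steps 3, 7 and 9.
print([s for s in range(10) if should(s, 10, 4)])  # -> [3, 7, 9]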