Example #1
0
    def discriminate(self, conv, reuse=False, pg=1, t=False, alpha_trans=0.01):
        """Build the progressive-growing discriminator graph.

        Args:
            conv: input image tensor (layout presumably NHWC -- TODO confirm
                against conv2d/downscale2d helpers).
            reuse: if True, reuse the variables of the "discriminator" scope.
            pg: current progressive-growing stage (1 = base resolution).
            t: True while fading in a newly added resolution block.
            alpha_trans: blend factor between the new high-res path and the
                downscaled identity path during the fade-in.

        Returns:
            Tuple ``(tf.nn.sigmoid(logits), logits)`` for the batch.
        """
        #dis_as_v = []
        with tf.variable_scope("discriminator") as scope:

            if reuse == True:
                scope.reuse_variables()
            if t:
                # Fade-in skip path: downscale the input and map it to
                # features with a 1x1 "from RGB" conv of the previous stage.
                conv_iden = downscale2d(conv)
                #from RGB
                conv_iden = lrelu(conv2d(conv_iden, output_dim= self.get_nf(pg - 2), k_w=1, k_h=1, d_h=1, d_w=1, use_wscale=self.use_wscale,
                           name='dis_y_rgb_conv_{}'.format(conv_iden.shape[1])))
            # fromRGB
            conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 1), k_w=1, k_h=1, d_w=1, d_h=1, use_wscale=self.use_wscale, name='dis_y_rgb_conv_{}'.format(conv.shape[1])))

            # One conv/conv/downscale block per stage above the base
            # resolution; layer names embed the spatial size so each stage
            # gets distinct variables.
            for i in range(pg - 1):
                conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 1 - i), d_h=1, d_w=1, use_wscale=self.use_wscale,
                                    name='dis_n_conv_1_{}'.format(conv.shape[1])))
                conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 2 - i), d_h=1, d_w=1, use_wscale=self.use_wscale,
                                                      name='dis_n_conv_2_{}'.format(conv.shape[1])))
                conv = downscale2d(conv)
                if i == 0 and t:
                    # Blend the freshly grown path with the identity path.
                    conv = alpha_trans * conv + (1 - alpha_trans) * conv_iden

            # Minibatch statistics plus final 3x3 and 4x4 (VALID) convs at
            # the base resolution, then flatten per batch element.
            conv = MinibatchstateConcat(conv)
            conv = lrelu(
                conv2d(conv, output_dim=self.get_nf(1), k_w=3, k_h=3, d_h=1, d_w=1, use_wscale=self.use_wscale, name='dis_n_conv_1_{}'.format(conv.shape[1])))
            conv = lrelu(
                conv2d(conv, output_dim=self.get_nf(1), k_w=4, k_h=4, d_h=1, d_w=1, use_wscale=self.use_wscale, padding='VALID', name='dis_n_conv_2_{}'.format(conv.shape[1])))
            conv = tf.reshape(conv, [self.batch_size, -1])

            #for D
            output = fully_connect(conv, output_size=1, use_wscale=self.use_wscale, gain=1, name='dis_n_fully')

            return tf.nn.sigmoid(output), output
Example #2
0
    def discriminator(self, incom_x, local_x, pg=1, is_trans=False, alpha_trans=0.01, reuse=False):
        """Two-branch (global + local) progressive discriminator.

        The global branch consumes the full image ``incom_x`` and the local
        branch the patch ``local_x``; each is reduced to a 256-d feature
        vector, and the concatenation feeds a single-unit logit layer.
        """
        with tf.variable_scope("discriminator") as scope:
            if reuse:
                scope.reuse_variables()

            # ---- global branch ------------------------------------------
            feat = incom_x
            if is_trans:
                # Fade-in skip path: downscale, then 1x1 from-RGB conv.
                skip = downscale2d(feat)
                skip = lrelu(conv2d(skip, output_dim=self.get_nf(pg - 2), k_w=1, k_h=1, d_h=1, d_w=1, use_sp=self.use_sp,
                                    name='dis_rgb_g_{}'.format(skip.shape[1])))
            feat = lrelu(conv2d(feat, output_dim=self.get_nf(pg - 1), k_w=1, k_h=1, d_w=1, d_h=1, use_sp=self.use_sp,
                                name='dis_rgb_g_{}'.format(feat.shape[1])))
            for stage in range(pg - 1):
                feat = lrelu(conv2d(feat, output_dim=self.get_nf(pg - 2 - stage), d_h=1, d_w=1, use_sp=self.use_sp,
                                    name='dis_conv_g_{}'.format(feat.shape[1])))
                feat = downscale2d(feat)
                if stage == 0 and is_trans:
                    feat = alpha_trans * feat + (1 - alpha_trans) * skip
            feat = lrelu(conv2d(feat, output_dim=self.get_nf(1), k_h=3, k_w=3, d_h=1, d_w=1, use_sp=self.use_sp,
                                name='dis_conv_g_1_{}'.format(feat.shape[1])))
            feat = tf.reshape(feat, [self.batch_size, -1])
            global_code = fully_connect(feat, output_size=256, use_sp=self.use_sp, name='dis_conv_g_fully')

            # ---- local branch -------------------------------------------
            feat = local_x
            if is_trans:
                # Same fade-in scheme as the global branch.
                skip = downscale2d(feat)
                skip = lrelu(conv2d(skip, output_dim=self.get_nf(pg - 2), k_w=1, k_h=1, d_h=1, d_w=1, use_sp=self.use_sp,
                                    name='dis_rgb_l_{}'.format(skip.shape[1])))
            feat = lrelu(conv2d(feat, output_dim=self.get_nf(pg - 1), k_w=1, k_h=1, d_w=1, d_h=1, use_sp=self.use_sp,
                                name='dis_rgb_l_{}'.format(feat.shape[1])))
            for stage in range(pg - 1):
                feat = lrelu(conv2d(feat, output_dim=self.get_nf(pg - 2 - stage), d_h=1, d_w=1, use_sp=self.use_sp,
                                    name='dis_conv_l_{}'.format(feat.shape[1])))
                feat = downscale2d(feat)
                if stage == 0 and is_trans:
                    feat = alpha_trans * feat + (1 - alpha_trans) * skip
            feat = lrelu(conv2d(feat, output_dim=self.get_nf(1), k_h=3, k_w=3, d_h=1, d_w=1, use_sp=self.use_sp,
                                name='dis_conv_l_1_{}'.format(feat.shape[1])))
            feat = tf.reshape(feat, [self.batch_size, -1])
            local_code = fully_connect(feat, output_size=256, use_sp=self.use_sp, name='dis_conv_l_fully')

            # Fuse both 256-d codes into one logit.
            logits = fully_connect(tf.concat([global_code, local_code], axis=1), output_size=1, use_sp=self.use_sp, name='dis_conv_fully')

            return logits
Example #3
0
    def encode_decode2(self, x, img_mask, pg=1, is_trans=False, alpha_trans=0.01, reuse=False):
        """Encoder-decoder generator (inpainting) at progressive stage ``pg``.

        Args:
            x: input image tensor.
            img_mask: mask tensor, concatenated channel-wise with ``x``.
            pg: progressive-growing stage (1 = base resolution).
            is_trans: True while fading in the newest resolution block.
            alpha_trans: fade-in blend factor.
            reuse: if True, reuse variables of the "ed" scope.

        Returns:
            The reconstructed RGB tensor at the stage's output resolution.
        """
        with tf.variable_scope("ed") as scope:

            if reuse == True:
                scope.reuse_variables()

            x = tf.concat([x, img_mask], axis=3)
            if is_trans:
                x_trans = downscale2d(x)
                #fromrgb
                x_trans = tf.nn.relu(instance_norm(conv2d(x_trans, output_dim=self.get_nf(pg - 2), k_w=1, k_h=1, d_h=1, d_w=1,
                                       name='gen_rgb_e_{}'.format(x_trans.shape[1])), scope='gen_rgb_e_in_{}'.format(x_trans.shape[1])))
            #fromrgb
            x = tf.nn.relu(instance_norm(conv2d(x, output_dim=self.get_nf(pg - 1), k_w=1, k_h=1, d_h=1, d_w=1,
                             name='gen_rgb_e_{}'.format(x.shape[1])), scope='gen_rgb_e_in_{}'.format(x.shape[1])))
            for i in range(pg - 1):
                # print-function form keeps the module importable on Python 3
                # (the original used Python 2 print statements, a SyntaxError
                # on Python 3); output text is unchanged.
                print("encode {}".format(x.shape))
                x = tf.nn.relu(instance_norm(conv2d(x, output_dim=self.get_nf(pg - 2 - i), d_h=1, d_w=1,
                                 name='gen_conv_e_{}'.format(x.shape[1])), scope='gen_conv_e_in_{}'.format(x.shape[1])))
                x = downscale2d(x)
                if i == 0 and is_trans:
                    x = alpha_trans * x + (1 - alpha_trans) * x_trans
            # Bottleneck: dilated conv widens the receptive field without
            # further downscaling.
            up_x = tf.nn.relu(
                instance_norm(dilated_conv2d(x, output_dim=512, k_w=3, k_h=3, rate=4, name='gen_conv_dilated'),
                              scope='gen_conv_in'))
            up_x = tf.nn.relu(instance_norm(conv2d(up_x, output_dim=self.get_nf(1), d_w=1, d_h=1, name='gen_conv_d'),
                                       scope='gen_conv_d_in_{}'.format(x.shape[1])))
            for i in range(pg - 1):
                print("decode {}".format(up_x.shape))
                if i == pg - 2 and is_trans:
                    #torgb
                    up_x_trans = conv2d(up_x, output_dim=self.channel, k_w=1, k_h=1, d_w=1, d_h=1,
                                        name='gen_rgb_d_{}'.format(up_x.shape[1]))
                    up_x_trans = upscale(up_x_trans, 2)

                up_x = upscale(up_x, 2)
                up_x = tf.nn.relu(instance_norm(conv2d(up_x, output_dim=self.get_nf(i + 1), d_w=1, d_h=1,
                                name='gen_conv_d_{}'.format(up_x.shape[1])), scope='gen_conv_d_in_{}'.format(up_x.shape[1])))
            #torgb
            up_x = conv2d(up_x, output_dim=self.channel, k_w=1, k_h=1, d_w=1, d_h=1,
                        name='gen_rgb_d_{}'.format(up_x.shape[1]))
            # Fade in the new top block by blending with the upscaled skip
            # output. (The original if/else chain had two no-op branches:
            # it only changed up_x when pg != 1 and is_trans.)
            if pg != 1 and is_trans:
                up_x = (1 - alpha_trans) * up_x_trans + alpha_trans * up_x
            return up_x
Example #4
0
def enc(x, start_res, end_res, scope='Encoder'):
    """Encoder: maps an image to ``(mean, std)`` latent statistics.

    When ``end_res`` exceeds ``start_res``, the freshly grown high-res path
    is blended with a downscaled low-res path via the non-trainable
    '<res>_t' interpolation variable (progressive-growing fade-in).
    """
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        cur = end_res
        if cur > start_res:
            # Low-res path: downscale first, then from-RGB at half res.
            low = ops.downscale2d(x, 'NHWC')
            low = ops.from_rgb('rgb_' + rname(cur // 2), low, fn(cur // 2),
                               'NHWC')
            # High-res path: from-RGB at full res plus one down block.
            high = ops.from_rgb('rgb_' + rname(cur), x, fn(cur // 2), 'NHWC')
            blend = tf.get_variable(
                rname(cur) + '_t',
                shape=[],
                dtype=tf.float32,
                collections=[tf.GraphKeys.GLOBAL_VARIABLES, "lerp"],
                initializer=tf.zeros_initializer(),
                trainable=False)
            high = block_dn(high, fn(cur), fn(cur // 2), 3, rname(cur))
            x = ops.lerp_clip(low, high, blend)
            cur //= 2
        else:
            x = ops.from_rgb('rgb_' + rname(cur), x, fn(cur), 'NHWC')

        # Standard down blocks until the 4x4 base resolution is reached.
        while cur >= 4:
            x = block_dn(x, fn(cur), fn(cur // 2), 3, rname(cur))
            cur //= 2

        flat = tf.layers.flatten(x)
        flat = ops.dense('fc1', flat, 512, 'NHWC')
        mean, std = tf.split(flat, 2, 1)
        return mean, std
Example #5
0
def discriminator(x, resolution, cfg, is_training=True, scope='Discriminator'):
    """Progressive-GAN discriminator: maps an image batch to one logit each.

    During a resolution transition the newly grown high-res path is
    interpolated with a downscaled low-res path through the non-trainable
    '<res>x<res>_t' blend variable.
    """
    assert cfg.data_format in ('NCHW', 'NHWC')

    def rname(resolution):
        # e.g. 64 -> "64x64"; tags per-resolution variables.
        return str(resolution) + 'x' + str(resolution)

    def fmap(resolution):
        # Filter count configured for this resolution.
        return cfg.resolution_to_filt_num[resolution]

    shape = utils.int_shape(x)
    spatial_axis = 1 if cfg.data_format == 'NHWC' else 3
    assert resolution == shape[spatial_axis]
    assert resolution == shape[2]
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if resolution > cfg.starting_resolution:
            # Low-res path: downscale, then from-RGB at half resolution.
            low = ops.downscale2d(x, cfg.data_format)
            low = ops.from_rgb('from_rgb_' + rname(resolution // 2), low,
                               fmap(resolution // 2), cfg.data_format)
            # High-res path: from-RGB at full resolution plus one down block.
            high = ops.from_rgb('from_rgb_' + rname(resolution), x,
                                fmap(resolution // 2), cfg.data_format)
            blend = tf.get_variable(
                rname(resolution) + '_t',
                shape=[],
                dtype=tf.float32,
                collections=[tf.GraphKeys.GLOBAL_VARIABLES, "lerp"],
                initializer=tf.zeros_initializer(),
                trainable=False)
            high = dblock(rname(resolution), high,
                          [fmap(resolution), fmap(resolution // 2)],
                          cfg.data_format)
            x = ops.lerp_clip(low, high, blend)
            resolution = resolution // 2
        else:
            x = ops.from_rgb('from_rgb_' + rname(resolution), x,
                             fmap(resolution), cfg.data_format)
        while resolution >= 4:
            if resolution == 4:
                # Minibatch stddev feature right before the final block.
                x = ops.minibatch_stddev_layer(x, cfg.data_format)
            x = dblock(rname(resolution), x,
                       [fmap(resolution), fmap(resolution // 2)],
                       cfg.data_format)
            resolution = resolution // 2

        x = ops.dense('2x2', x, fmap(resolution), cfg.data_format)
        x = ops.leaky_relu(x)

        return ops.dense('output', x, 1, cfg.data_format)
Example #6
0
def main(unused_argv):
    """Train the SR4RS super-resolution GAN.

    Builds the dataset pipeline, generator/discriminator graphs, losses
    (L1 + L2 + optional VGG, and LSGAN or WGAN-GP adversarial terms),
    optimizers and summaries; runs optional pre-training then adversarial
    training; finally exports a SavedModel if requested.
    """
    logging.info("************ Parameters summary ************")
    logging.info("Number of epochs     : " + str(params.epochs))
    logging.info("Batch size           : " + str(params.batchsize))
    logging.info("Adam learning rate   : " + str(params.adam_lr))
    logging.info("Adam beta1           : " + str(params.adam_b1))
    logging.info("L1 loss weight       : " + str(params.l1weight))
    logging.info("L2 loss weight       : " + str(params.l2weight))
    logging.info("GAN loss weight      : " + str(params.ganweight))
    if params.vggfile is not None:
        logging.info("VGG file             : " + str(params.vggfile))
        logging.info("VGG loss weight      : " + str(params.vggweight))
        logging.info("VGG features         : " + str(params.vggfeatures))
    logging.info("Base depth           : " + str(params.depth))
    logging.info("Number of ResBlocks  : " + str(params.nresblocks))
    logging.info("Low-res image scale  : " + str(params.lr_scale))
    logging.info("Hi-res image scale   : " + str(params.hr_scale))
    logging.info("********************************************")

    # Preview
    # Optional low-res image used to render preview summaries each epoch.
    lr_image_for_prev = None
    if params.preview is not None:
        lr_image_for_prev = otbtf.read_as_np_arr(otbtf.gdal_open(params.preview), False)

    with tf.Graph().as_default():
        # dataset and iterator
        ds = otbtf.DatasetFromPatchesImages(filenames_dict={constants.hr_key: params.hr_patches,
                                                            constants.lr_key: params.lr_patches},
                                            use_streaming=params.streaming)
        tf_ds = ds.get_tf_dataset(batch_size=params.batchsize)
        # Reinitializable iterator: re-initialized at the start of each epoch
        # inside _do() below.
        iterator = tf.compat.v1.data.Iterator.from_structure(ds.output_types)
        iterator_init = iterator.make_initializer(tf_ds)
        dataset_inputs = iterator.get_next()

        # model inputs
        def _get_input(key, name):
            # Placeholder-with-default: fed from the dataset during training,
            # but can be overridden (e.g. the preview feed below).
            default_input = dataset_inputs[key]
            shape = (None, None, None, ds.output_shapes[key][-1])
            return tf.compat.v1.placeholder_with_default(default_input, shape=shape, name=name)

        lr_image = _get_input(constants.lr_key, constants.lr_input_name)
        hr_image = _get_input(constants.hr_key, constants.hr_input_name)

        # model
        hr_nch = ds.output_shapes[constants.hr_key][-1]
        generator = partial(network.generator, scope=constants.gen_scope, nchannels=hr_nch,
                            nresblocks=params.nresblocks, dim=params.depth)
        discriminator = partial(network.discriminator, scope=constants.dis_scope, dim=params.depth)

        # Multi-scale targets: one downscaled HR reference per factor.
        hr_images_real = {factor: params.hr_scale * downscale2d(hr_image, factor=factor) for factor in constants.factors}
        hr_images_fake = generator(params.lr_scale * lr_image)

        # model outputs
        # De-scale generator outputs and expose one cropped output per pad.
        gen = {factor: (1.0 / params.hr_scale) * hr_images_fake[factor] for factor in constants.factors}
        for pad in constants.pads:
            tf.identity(gen[1][:, pad:-pad, pad:-pad, :], name="{}{}".format(constants.outputs_prefix, pad))
        if lr_image_for_prev is not None:
            for factor in constants.factors:
                prev = network.nice_preview(gen[factor])
                tf.compat.v1.summary.image("preview_factor{}".format(factor), prev, collections=[constants.epoch_key])

        # discriminator
        dis_real = discriminator(hr_images=hr_images_real)
        dis_fake = discriminator(hr_images=hr_images_fake)

        # l1 loss
        gen_loss_l1 = tf.add_n([tf.reduce_mean(tf.abs(hr_images_fake[factor] -
                                                      hr_images_real[factor])) for factor in constants.factors])

        # l2 loss
        gen_loss_l2 = tf.add_n([tf.reduce_mean(tf.square(hr_images_fake[factor] -
                                                         hr_images_real[factor])) for factor in constants.factors])

        # VGG loss
        gen_loss_vgg = 0.0
        if params.vggfile is not None:
            gen_loss_vgg = tf.add_n([compute_vgg_loss(hr_images_real[factor],
                                                      hr_images_fake[factor],
                                                      params.vggfeatures,
                                                      params.vggfile) for factor in constants.factors])

        # GAN Losses
        if params.losstype == "LSGAN":
            dis_loss = tf.reduce_mean(tf.square(dis_real - 1) + tf.square(dis_fake))
            gen_loss_gan = tf.reduce_mean(tf.square(dis_fake - 1))
        elif params.losstype == "WGAN-GP":
            # WGAN-GP: gradient penalty on random interpolates between real
            # and fake images, computed per scale and averaged.
            dis_loss = dis_fake - dis_real
            alpha = tf.random_uniform(shape=[params.batchsize, 1, 1, 1], minval=0., maxval=1.)
            differences = {factor: hr_images_fake[factor] - hr_images_real[factor] for factor in constants.factors}
            interpolates_scales = {factor: hr_images_real[factor] +
                                           alpha * differences[factor] for factor in constants.factors}
            mixed_loss = tf.reduce_sum(discriminator(interpolates_scales))
            mixed_grads = tf.gradients(mixed_loss, list(interpolates_scales.values()))
            mixed_norms = [tf.sqrt(tf.reduce_sum(tf.square(gradient), reduction_indices=[1, 2, 3])) for gradient in
                           mixed_grads]
            gradient_penalties = [tf.reduce_mean(tf.square(slope - 1.0)) for slope in mixed_norms]
            gradient_penalty = tf.reduce_mean(gradient_penalties)
            dis_loss += 10 * gradient_penalty
            # Epsilon drift penalty keeps dis_real from drifting far from 0.
            epsilon_penalty = tf.reduce_mean(tf.square(dis_real))
            dis_loss += 0.001 * epsilon_penalty
            gen_loss_gan = -1.0 * tf.reduce_mean(dis_fake)
            dis_loss = tf.reduce_mean(dis_loss)
        else:
            raise Exception("Please select an available cost function")

        # Total losses
        def _new_loss(value, name, collections=None):
            # Registers a scalar summary for the loss and returns it as-is.
            tf.compat.v1.summary.scalar(name, value, collections)
            return value

        train_collections = [constants.train_key]
        all_collections = [constants.pretrain_key, constants.train_key]
        gen_loss = _new_loss(params.ganweight * gen_loss_gan, "gen_loss_gan", train_collections)
        gen_loss += _new_loss(params.vggweight * gen_loss_vgg, "gen_loss_vgg", train_collections)
        pretrain_loss = _new_loss(params.l1weight * gen_loss_l1, "gen_loss_l1", all_collections)
        pretrain_loss += _new_loss(params.l2weight * gen_loss_l2, "gen_loss_l2", all_collections)
        gen_loss += pretrain_loss
        dis_loss = _new_loss(dis_loss, "dis_loss", train_collections)

        # discriminator optimizer
        dis_optim = tf.compat.v1.train.AdamOptimizer(learning_rate=params.adam_lr, beta1=params.adam_b1)
        dis_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, constants.dis_scope)
        dis_grads_and_vars = dis_optim.compute_gradients(dis_loss, var_list=dis_tvars)
        with tf.compat.v1.variable_scope("apply_dis_gradients", reuse=tf.compat.v1.AUTO_REUSE):
            dis_train = dis_optim.apply_gradients(dis_grads_and_vars)

        # generator optimizer
        # The control dependency makes each generator step run after the
        # discriminator update.
        with tf.control_dependencies([dis_train]):
            gen_optim = tf.compat.v1.train.AdamOptimizer(learning_rate=params.adam_lr, beta1=params.adam_b1)
            gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, constants.gen_scope)
            gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
            with tf.compat.v1.variable_scope("apply_gen_gradients", reuse=tf.compat.v1.AUTO_REUSE):
                gen_train = gen_optim.apply_gradients(gen_grads_and_vars)

        pretrain_op = tf.compat.v1.train.AdamOptimizer(learning_rate=params.adam_lr).minimize(pretrain_loss)
        train_nodes = [gen_train]
        if params.losstype == "LSGAN":
            ema = tf.train.ExponentialMovingAverage(decay=0.995)
            update_losses = ema.apply([dis_loss, gen_loss])
            train_nodes.append(update_losses)
        train_op = tf.group(train_nodes, name="optimizer")

        merged_losses_summaries = tf.compat.v1.summary.merge_all(key=constants.train_key)
        merged_pretrain_summaries = tf.compat.v1.summary.merge_all(key=constants.pretrain_key)
        merged_preview_summaries = tf.compat.v1.summary.merge_all(key=constants.epoch_key)

        init = tf.global_variables_initializer()
        saver = tf.compat.v1.train.Saver(max_to_keep=5)

        sess = tf.Session()

        # Writer
        def _append_desc(key, value):
            # Builds "_<key><value>" run-name fragments; zero values are
            # omitted entirely.
            if value == 0:
                return ""
            return "_{}{}".format(key, value)

        now = datetime.datetime.now()
        summaries_fn = "SR4RS"
        summaries_fn += _append_desc("E", params.epochs)
        summaries_fn += _append_desc("B", params.batchsize)
        summaries_fn += _append_desc("LR", params.adam_lr)
        summaries_fn += _append_desc("Gan", params.ganweight)
        summaries_fn += _append_desc("L1-", params.l1weight)
        summaries_fn += _append_desc("L2-", params.l2weight)
        summaries_fn += _append_desc("VGG", params.vggweight)
        summaries_fn += _append_desc("VGGFeat", params.vggfeatures)
        summaries_fn += _append_desc("Loss", params.losstype)
        summaries_fn += _append_desc("D", params.depth)
        summaries_fn += _append_desc("RB", params.nresblocks)
        summaries_fn += _append_desc("LRSC", params.lr_scale)
        summaries_fn += _append_desc("HRSC", params.hr_scale)
        if params.pretrain:
            summaries_fn += "pretrained"
        summaries_fn += "_{}{}_{}h{}min".format(now.day, now.strftime("%b"), now.hour, now.minute)

        train_writer = None
        if params.logdir is not None:
            train_writer = tf.summary.FileWriter(params.logdir + summaries_fn, sess.graph)

        def _add_summary(summarized, _step):
            if train_writer is not None:
                train_writer.add_summary(summarized, _step)

        sess.run(init)
        if params.load_ckpt is not None:
            saver.restore(sess, params.load_ckpt)

        # preview
        def _preview(_step):
            # NOTE(review): the condition reads the module-level `step`
            # (per-batch counter mutated in _do) while the summary is
            # recorded at `_step`, which the call site passes as the epoch
            # index -- confirm this step/epoch mix is intended.
            if lr_image_for_prev is not None and step % params.previews_step == 0:
                summary_pe = sess.run(merged_preview_summaries, {lr_image: lr_image_for_prev})
                _add_summary(summary_pe, _step)


        def _do(_train_op, _summary_op, name):
            # Runs `_train_op` for params.epochs epochs, writing summaries
            # and checkpoints. `step` is assumed to be a module-level
            # counter initialized elsewhere -- TODO confirm.
            global step
            for curr_epoch in range(params.epochs):
                logging.info("{} Epoch #{}".format(name, curr_epoch))
                sess.run(iterator_init)
                try:
                    while True:
                        _, _summary = sess.run([_train_op, _summary_op])
                        _add_summary(_summary, step)
                        _preview(curr_epoch)
                        step += 1
                except tf.errors.OutOfRangeError:
                    # End of dataset for this epoch.
                    fs_stall_duration = ds.get_total_wait_in_seconds()
                    logging.info("{}: one epoch done. Total FS stall: {:.2f}s".format(name, fs_stall_duration))
                    pass
                saver.save(sess, params.save_ckpt + summaries_fn, global_step=curr_epoch)

        # pre training
        if params.pretrain:
            _do(pretrain_op, merged_pretrain_summaries, "pre-training")

        # training
        _do(train_op, merged_losses_summaries, "training")

        # cleaning
        if train_writer is not None:
            train_writer.close()

        # Export SavedModel
        if params.savedmodel is not None:
            logging.info("Export SavedModel in {}".format(params.savedmodel))
            outputs = ["{}{}:0".format(constants.outputs_prefix, pad) for pad in constants.pads]
            inputs = ["{}:0".format(constants.lr_input_name)]
            graph = tf.get_default_graph()
            tf.saved_model.simple_save(sess,
                                       params.savedmodel,
                                       inputs={i: graph.get_tensor_by_name(i) for i in inputs},
                                       outputs={o: graph.get_tensor_by_name(o) for o in outputs})

    quit()