def discriminate(self, conv, reuse=False, pg=1, t=False, alpha_trans=0.01):
    """Progressive-GAN discriminator.

    Args:
        conv: input image tensor (NHWC).
        reuse: reuse variables of a previously-built discriminator.
        pg: progressive-growing stage (1 = lowest resolution).
        t: True while fading in the newest resolution block.
        alpha_trans: fade-in blending factor in [0, 1].

    Returns:
        (sigmoid(logits), logits) for the batch.
    """
    with tf.variable_scope("discriminator") as scope:
        if reuse:  # idiomatic truthiness test (was `reuse == True`)
            scope.reuse_variables()
        if t:
            # Fade-in path: downscale the input, then 1x1 from-RGB at the
            # previous stage's filter count.
            conv_iden = downscale2d(conv)
            conv_iden = lrelu(conv2d(conv_iden, output_dim=self.get_nf(pg - 2), k_w=1, k_h=1,
                                     d_h=1, d_w=1, use_wscale=self.use_wscale,
                                     name='dis_y_rgb_conv_{}'.format(conv_iden.shape[1])))
        # from RGB at the current stage's filter count
        conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 1), k_w=1, k_h=1,
                            d_w=1, d_h=1, use_wscale=self.use_wscale,
                            name='dis_y_rgb_conv_{}'.format(conv.shape[1])))
        # One two-conv + downscale block per growing stage.
        for i in range(pg - 1):
            conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 1 - i), d_h=1, d_w=1,
                                use_wscale=self.use_wscale,
                                name='dis_n_conv_1_{}'.format(conv.shape[1])))
            conv = lrelu(conv2d(conv, output_dim=self.get_nf(pg - 2 - i), d_h=1, d_w=1,
                                use_wscale=self.use_wscale,
                                name='dis_n_conv_2_{}'.format(conv.shape[1])))
            conv = downscale2d(conv)
            if i == 0 and t:
                # Blend the newest (just-downscaled) path with the identity path.
                conv = alpha_trans * conv + (1 - alpha_trans) * conv_iden
        # Minibatch statistics + final 3x3 / 4x4 convs, then a 1-unit head.
        conv = MinibatchstateConcat(conv)
        conv = lrelu(
            conv2d(conv, output_dim=self.get_nf(1), k_w=3, k_h=3, d_h=1, d_w=1,
                   use_wscale=self.use_wscale, name='dis_n_conv_1_{}'.format(conv.shape[1])))
        conv = lrelu(
            conv2d(conv, output_dim=self.get_nf(1), k_w=4, k_h=4, d_h=1, d_w=1,
                   use_wscale=self.use_wscale, padding='VALID',
                   name='dis_n_conv_2_{}'.format(conv.shape[1])))
        conv = tf.reshape(conv, [self.batch_size, -1])
        # for D
        output = fully_connect(conv, output_size=1, use_wscale=self.use_wscale, gain=1,
                               name='dis_n_fully')
        return tf.nn.sigmoid(output), output
def discriminator(self, incom_x, local_x, pg=1, is_trans=False, alpha_trans=0.01, reuse=False):
    """Global + local progressive discriminator.

    The global tower sees the full image (`incom_x`) and the local tower the
    cropped region (`local_x`); their 256-d features are concatenated and
    mapped to a single logit. Both towers were previously duplicated inline;
    they are now one helper parameterized by the variable-name tag, producing
    byte-identical variable names ('..._g_...' / '..._l_...').

    Args:
        incom_x: full input image tensor (NHWC).
        local_x: local (cropped) input image tensor (NHWC).
        pg: progressive-growing stage (1 = lowest resolution).
        is_trans: True while fading in the newest resolution block.
        alpha_trans: fade-in blending factor in [0, 1].

    Returns:
        The joint discriminator logits, shape [batch_size, 1].
    """
    def _tower(x, tag):
        # One discriminator tower; `tag` ('g' or 'l') keeps variable names distinct.
        if is_trans:
            x_trans = downscale2d(x)
            # from rgb (half-resolution fade-in path)
            x_trans = lrelu(conv2d(x_trans, output_dim=self.get_nf(pg - 2), k_w=1, k_h=1,
                                   d_h=1, d_w=1, use_sp=self.use_sp,
                                   name='dis_rgb_{}_{}'.format(tag, x_trans.shape[1])))
        x = lrelu(conv2d(x, output_dim=self.get_nf(pg - 1), k_w=1, k_h=1, d_w=1, d_h=1,
                         use_sp=self.use_sp, name='dis_rgb_{}_{}'.format(tag, x.shape[1])))
        for i in range(pg - 1):
            x = lrelu(conv2d(x, output_dim=self.get_nf(pg - 2 - i), d_h=1, d_w=1,
                             use_sp=self.use_sp, name='dis_conv_{}_{}'.format(tag, x.shape[1])))
            x = downscale2d(x)
            if i == 0 and is_trans:
                # Blend the newest path with the identity (from-RGB) path.
                x = alpha_trans * x + (1 - alpha_trans) * x_trans
        x = lrelu(conv2d(x, output_dim=self.get_nf(1), k_h=3, k_w=3, d_h=1, d_w=1,
                         use_sp=self.use_sp, name='dis_conv_{}_1_{}'.format(tag, x.shape[1])))
        x = tf.reshape(x, [self.batch_size, -1])
        return fully_connect(x, output_size=256, use_sp=self.use_sp,
                             name='dis_conv_{}_fully'.format(tag))

    with tf.variable_scope("discriminator") as scope:
        if reuse:  # idiomatic truthiness test (was `reuse == True`)
            scope.reuse_variables()
        # global discriminator
        x_g = _tower(incom_x, 'g')
        # local discriminator
        x_l = _tower(local_x, 'l')
        logits = fully_connect(tf.concat([x_g, x_l], axis=1), output_size=1,
                               use_sp=self.use_sp, name='dis_conv_fully')
        return logits
def encode_decode2(self, x, img_mask, pg=1, is_trans=False, alpha_trans=0.01, reuse=False):
    """Progressive encoder-decoder generator (inpainting-style: image + mask in).

    Encoder: from-RGB then (pg - 1) conv + downscale blocks; bottleneck uses a
    dilated conv; decoder mirrors with (pg - 1) upscale + conv blocks and a
    final to-RGB. During fade-in (`is_trans`), the encoder input and the
    decoder output are each blended with an auxiliary half-resolution path.

    Returns:
        The generated RGB tensor (self.channel channels).
    """
    with tf.variable_scope("ed") as scope:
        if reuse:  # idiomatic truthiness test (was `reuse == True`)
            scope.reuse_variables()
        # Concatenate the mask as extra input channel(s).
        x = tf.concat([x, img_mask], axis=3)
        if is_trans:
            x_trans = downscale2d(x)
            # fromrgb (half-resolution fade-in path)
            x_trans = tf.nn.relu(instance_norm(
                conv2d(x_trans, output_dim=self.get_nf(pg - 2), k_w=1, k_h=1, d_h=1, d_w=1,
                       name='gen_rgb_e_{}'.format(x_trans.shape[1])),
                scope='gen_rgb_e_in_{}'.format(x_trans.shape[1])))
        # fromrgb
        x = tf.nn.relu(instance_norm(
            conv2d(x, output_dim=self.get_nf(pg - 1), k_w=1, k_h=1, d_h=1, d_w=1,
                   name='gen_rgb_e_{}'.format(x.shape[1])),
            scope='gen_rgb_e_in_{}'.format(x.shape[1])))
        for i in range(pg - 1):
            # Works under both Python 2 and 3 (was a Python-2-only print statement).
            print("encode {}".format(x.shape))
            x = tf.nn.relu(instance_norm(
                conv2d(x, output_dim=self.get_nf(pg - 2 - i), d_h=1, d_w=1,
                       name='gen_conv_e_{}'.format(x.shape[1])),
                scope='gen_conv_e_in_{}'.format(x.shape[1])))
            x = downscale2d(x)
            if i == 0 and is_trans:
                x = alpha_trans * x + (1 - alpha_trans) * x_trans
        # Bottleneck: dilated conv widens the receptive field without downscaling.
        up_x = tf.nn.relu(
            instance_norm(dilated_conv2d(x, output_dim=512, k_w=3, k_h=3, rate=4,
                                         name='gen_conv_dilated'),
                          scope='gen_conv_in'))
        up_x = tf.nn.relu(instance_norm(
            conv2d(up_x, output_dim=self.get_nf(1), d_w=1, d_h=1, name='gen_conv_d'),
            scope='gen_conv_d_in_{}'.format(x.shape[1])))
        for i in range(pg - 1):
            print("decode {}".format(up_x.shape))
            if i == pg - 2 and is_trans:
                # torgb of the pre-final features: target of the output fade-in blend
                up_x_trans = conv2d(up_x, output_dim=self.channel, k_w=1, k_h=1, d_w=1, d_h=1,
                                    name='gen_rgb_d_{}'.format(up_x.shape[1]))
                up_x_trans = upscale(up_x_trans, 2)
            up_x = upscale(up_x, 2)
            up_x = tf.nn.relu(instance_norm(
                conv2d(up_x, output_dim=self.get_nf(i + 1), d_w=1, d_h=1,
                       name='gen_conv_d_{}'.format(up_x.shape[1])),
                scope='gen_conv_d_in_{}'.format(up_x.shape[1])))
        # torgb
        up_x = conv2d(up_x, output_dim=self.channel, k_w=1, k_h=1, d_w=1, d_h=1,
                      name='gen_rgb_d_{}'.format(up_x.shape[1]))
        # Fade-in blend of the decoder output; the original's `up_x = up_x`
        # no-op branches are removed.
        if pg != 1 and is_trans:
            up_x = (1 - alpha_trans) * up_x_trans + alpha_trans * up_x
        return up_x
def enc(x, start_res, end_res, scope='Encoder'):
    """Progressive-growing encoder.

    Downsamples `x` from `end_res` to 4x4 through `block_dn` stages (with a
    fade-in blend when `end_res > start_res`), flattens, projects to 512
    units, and splits the code in two halves.

    Returns:
        (mean, std): the two 256-d halves of the dense output.
    """
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        cur = end_res
        if cur > start_res:
            # Fade-in: blend a half-resolution from-RGB path with the
            # full-resolution path processed by the newest block.
            half = ops.downscale2d(x, 'NHWC')
            half = ops.from_rgb('rgb_' + rname(cur // 2), half, fn(cur // 2), 'NHWC')
            full = ops.from_rgb('rgb_' + rname(cur), x, fn(cur // 2), 'NHWC')
            # Non-trainable blending factor, driven externally via the "lerp"
            # collection.
            blend = tf.get_variable(
                rname(cur) + '_t',
                shape=[],
                dtype=tf.float32,
                collections=[tf.GraphKeys.GLOBAL_VARIABLES, "lerp"],
                initializer=tf.zeros_initializer(),
                trainable=False)
            full = block_dn(full, fn(cur), fn(cur // 2), 3, rname(cur))
            x = ops.lerp_clip(half, full, blend)
            cur //= 2
        else:
            x = ops.from_rgb('rgb_' + rname(cur), x, fn(cur), 'NHWC')
        # Downsample stage by stage until the 4x4 resolution is consumed.
        while cur >= 4:
            x = block_dn(x, fn(cur), fn(cur // 2), 3, rname(cur))
            cur //= 2
        flat = tf.layers.flatten(x)
        code = ops.dense('fc1', flat, 512, 'NHWC')
        mean, std = tf.split(code, 2, 1)
        return mean, std
def discriminator(x, resolution, cfg, is_training=True, scope='Discriminator'):
    """Progressive-growing discriminator head.

    Downsamples `x` from `resolution` to 4x4 via `dblock` stages (with a
    fade-in blend when above `cfg.starting_resolution`), applies a minibatch
    standard-deviation layer at 4x4, and maps to a single logit.

    Returns:
        A 1-unit logit tensor for the batch.
    """
    assert cfg.data_format in ('NCHW', 'NHWC')

    def rname(resolution):
        return str(resolution) + 'x' + str(resolution)

    def fmap(resolution):
        return cfg.resolution_to_filt_num[resolution]

    # Sanity-check that the declared resolution matches the tensor's
    # spatial dimensions for the given layout.
    shape = utils.int_shape(x)
    spatial_axis = 1 if cfg.data_format == 'NHWC' else 3
    assert resolution == shape[spatial_axis]
    assert resolution == shape[2]

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if resolution > cfg.starting_resolution:
            # Fade-in: blend the downscaled from-RGB path with the
            # full-resolution path processed by the newest block.
            coarse = ops.downscale2d(x, cfg.data_format)
            coarse = ops.from_rgb('from_rgb_' + rname(resolution // 2), coarse,
                                  fmap(resolution // 2), cfg.data_format)
            fine = ops.from_rgb('from_rgb_' + rname(resolution), x,
                                fmap(resolution // 2), cfg.data_format)
            # Non-trainable blending factor, driven externally via the
            # "lerp" collection.
            fade_t = tf.get_variable(
                rname(resolution) + '_t',
                shape=[],
                dtype=tf.float32,
                collections=[tf.GraphKeys.GLOBAL_VARIABLES, "lerp"],
                initializer=tf.zeros_initializer(),
                trainable=False)
            fine = dblock(rname(resolution), fine,
                          [fmap(resolution), fmap(resolution // 2)], cfg.data_format)
            x = ops.lerp_clip(coarse, fine, fade_t)
            resolution //= 2
        else:
            x = ops.from_rgb('from_rgb_' + rname(resolution), x,
                             fmap(resolution), cfg.data_format)
        # Downsample stage by stage; inject minibatch stddev at 4x4.
        while resolution >= 4:
            if resolution == 4:
                x = ops.minibatch_stddev_layer(x, cfg.data_format)
            x = dblock(rname(resolution), x,
                       [fmap(resolution), fmap(resolution // 2)], cfg.data_format)
            resolution //= 2
        x = ops.dense('2x2', x, fmap(resolution), cfg.data_format)
        x = ops.leaky_relu(x)
        return ops.dense('output', x, 1, cfg.data_format)
def main(unused_argv):
    """Build and train the SR4RS super-resolution GAN, optionally exporting a SavedModel.

    Hyper-parameters come from the module-level `params` object and shared
    names/keys from `constants` (both defined outside this view). Steps:
    log the configuration, build the dataset pipeline and generator /
    discriminator graphs, assemble L1/L2/VGG/GAN losses, create the
    optimizers, run (optional) pre-training then adversarial training,
    and finally export a SavedModel if requested.
    """
    # ---- configuration banner -------------------------------------------
    logging.info("************ Parameters summary ************")
    logging.info("Number of epochs : " + str(params.epochs))
    logging.info("Batch size : " + str(params.batchsize))
    logging.info("Adam learning rate : " + str(params.adam_lr))
    logging.info("Adam beta1 : " + str(params.adam_b1))
    logging.info("L1 loss weight : " + str(params.l1weight))
    logging.info("L2 loss weight : " + str(params.l2weight))
    logging.info("GAN loss weight : " + str(params.ganweight))
    if params.vggfile is not None:
        logging.info("VGG file : " + str(params.vggfile))
        logging.info("VGG loss weight : " + str(params.vggweight))
        logging.info("VGG features : " + str(params.vggfeatures))
    logging.info("Base depth : " + str(params.depth))
    logging.info("Number of ResBlocks : " + str(params.nresblocks))
    logging.info("Low-res image scale : " + str(params.lr_scale))
    logging.info("Hi-res image scale : " + str(params.hr_scale))
    logging.info("********************************************")

    # Preview: optional low-res image rendered into TensorBoard each epoch.
    lr_image_for_prev = None
    if params.preview is not None:
        lr_image_for_prev = otbtf.read_as_np_arr(otbtf.gdal_open(params.preview), False)

    with tf.Graph().as_default():
        # dataset and iterator
        ds = otbtf.DatasetFromPatchesImages(filenames_dict={constants.hr_key: params.hr_patches,
                                                            constants.lr_key: params.lr_patches},
                                            use_streaming=params.streaming)
        tf_ds = ds.get_tf_dataset(batch_size=params.batchsize)
        iterator = tf.compat.v1.data.Iterator.from_structure(ds.output_types)
        iterator_init = iterator.make_initializer(tf_ds)
        dataset_inputs = iterator.get_next()

        # model inputs
        def _get_input(key, name):
            # Placeholder that defaults to the dataset tensor, so the same
            # graph serves both training (dataset) and inference (feed_dict).
            default_input = dataset_inputs[key]
            shape = (None, None, None, ds.output_shapes[key][-1])
            return tf.compat.v1.placeholder_with_default(default_input, shape=shape, name=name)

        lr_image = _get_input(constants.lr_key, constants.lr_input_name)
        hr_image = _get_input(constants.hr_key, constants.hr_input_name)

        # model
        hr_nch = ds.output_shapes[constants.hr_key][-1]
        generator = partial(network.generator, scope=constants.gen_scope, nchannels=hr_nch,
                            nresblocks=params.nresblocks, dim=params.depth)
        discriminator = partial(network.discriminator, scope=constants.dis_scope,
                                dim=params.depth)
        # Multi-scale pyramids of real and generated HR images, keyed by factor.
        hr_images_real = {factor: params.hr_scale * downscale2d(hr_image, factor=factor)
                          for factor in constants.factors}
        hr_images_fake = generator(params.lr_scale * lr_image)

        # model outputs (un-scaled back to the input value range)
        gen = {factor: (1.0 / params.hr_scale) * hr_images_fake[factor]
               for factor in constants.factors}
        # Named, border-cropped output tensors for each supported padding.
        for pad in constants.pads:
            tf.identity(gen[1][:, pad:-pad, pad:-pad, :],
                        name="{}{}".format(constants.outputs_prefix, pad))
        if lr_image_for_prev is not None:
            for factor in constants.factors:
                prev = network.nice_preview(gen[factor])
                tf.compat.v1.summary.image("preview_factor{}".format(factor), prev,
                                           collections=[constants.epoch_key])

        # discriminator
        dis_real = discriminator(hr_images=hr_images_real)
        dis_fake = discriminator(hr_images=hr_images_fake)

        # l1 loss (summed over all scales)
        gen_loss_l1 = tf.add_n([tf.reduce_mean(tf.abs(hr_images_fake[factor] - hr_images_real[factor]))
                                for factor in constants.factors])
        # l2 loss
        gen_loss_l2 = tf.add_n([tf.reduce_mean(tf.square(hr_images_fake[factor] - hr_images_real[factor]))
                                for factor in constants.factors])
        # VGG loss (perceptual), only when a VGG checkpoint is supplied
        gen_loss_vgg = 0.0
        if params.vggfile is not None:
            gen_loss_vgg = tf.add_n([compute_vgg_loss(hr_images_real[factor],
                                                      hr_images_fake[factor],
                                                      params.vggfeatures,
                                                      params.vggfile)
                                     for factor in constants.factors])

        # GAN Losses
        if params.losstype == "LSGAN":
            # Least-squares GAN (Mao et al.)
            dis_loss = tf.reduce_mean(tf.square(dis_real - 1) + tf.square(dis_fake))
            gen_loss_gan = tf.reduce_mean(tf.square(dis_fake - 1))
        elif params.losstype == "WGAN-GP":
            # Wasserstein GAN with gradient penalty (Gulrajani et al.)
            dis_loss = dis_fake - dis_real
            # Random interpolation coefficient per sample, broadcast over H/W/C.
            alpha = tf.random_uniform(shape=[params.batchsize, 1, 1, 1], minval=0., maxval=1.)
            differences = {factor: hr_images_fake[factor] - hr_images_real[factor]
                           for factor in constants.factors}
            interpolates_scales = {factor: hr_images_real[factor] + alpha * differences[factor]
                                   for factor in constants.factors}
            mixed_loss = tf.reduce_sum(discriminator(interpolates_scales))
            mixed_grads = tf.gradients(mixed_loss, list(interpolates_scales.values()))
            # Per-scale gradient norms; the penalty drives each norm towards 1.
            mixed_norms = [tf.sqrt(tf.reduce_sum(tf.square(gradient), reduction_indices=[1, 2, 3]))
                           for gradient in mixed_grads]
            gradient_penalties = [tf.reduce_mean(tf.square(slope - 1.0)) for slope in mixed_norms]
            gradient_penalty = tf.reduce_mean(gradient_penalties)
            dis_loss += 10 * gradient_penalty
            # Epsilon drift penalty keeps the critic output near zero.
            epsilon_penalty = tf.reduce_mean(tf.square(dis_real))
            dis_loss += 0.001 * epsilon_penalty
            gen_loss_gan = -1.0 * tf.reduce_mean(dis_fake)
            dis_loss = tf.reduce_mean(dis_loss)
        else:
            raise Exception("Please select an available cost function")

        # Total losses
        def _new_loss(value, name, collections=None):
            # Register a scalar summary for the loss and pass the value through.
            tf.compat.v1.summary.scalar(name, value, collections)
            return value

        train_collections = [constants.train_key]
        all_collections = [constants.pretrain_key, constants.train_key]
        gen_loss = _new_loss(params.ganweight * gen_loss_gan, "gen_loss_gan", train_collections)
        gen_loss += _new_loss(params.vggweight * gen_loss_vgg, "gen_loss_vgg", train_collections)
        pretrain_loss = _new_loss(params.l1weight * gen_loss_l1, "gen_loss_l1", all_collections)
        pretrain_loss += _new_loss(params.l2weight * gen_loss_l2, "gen_loss_l2", all_collections)
        gen_loss += pretrain_loss
        dis_loss = _new_loss(dis_loss, "dis_loss", train_collections)

        # discriminator optimizer
        dis_optim = tf.compat.v1.train.AdamOptimizer(learning_rate=params.adam_lr,
                                                     beta1=params.adam_b1)
        dis_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, constants.dis_scope)
        dis_grads_and_vars = dis_optim.compute_gradients(dis_loss, var_list=dis_tvars)
        with tf.compat.v1.variable_scope("apply_dis_gradients", reuse=tf.compat.v1.AUTO_REUSE):
            dis_train = dis_optim.apply_gradients(dis_grads_and_vars)

        # generator optimizer — the control dependency makes each train step
        # update the discriminator first, then the generator.
        with tf.control_dependencies([dis_train]):
            gen_optim = tf.compat.v1.train.AdamOptimizer(learning_rate=params.adam_lr,
                                                         beta1=params.adam_b1)
            gen_tvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, constants.gen_scope)
            gen_grads_and_vars = gen_optim.compute_gradients(gen_loss, var_list=gen_tvars)
            with tf.compat.v1.variable_scope("apply_gen_gradients", reuse=tf.compat.v1.AUTO_REUSE):
                gen_train = gen_optim.apply_gradients(gen_grads_and_vars)

        # Separate optimizer for the L1+L2-only pre-training phase.
        pretrain_op = tf.compat.v1.train.AdamOptimizer(learning_rate=params.adam_lr).minimize(pretrain_loss)
        train_nodes = [gen_train]
        if params.losstype == "LSGAN":
            # Smoothed loss tracking (EMA) only used in the LSGAN setting.
            ema = tf.train.ExponentialMovingAverage(decay=0.995)
            update_losses = ema.apply([dis_loss, gen_loss])
            train_nodes.append(update_losses)
        train_op = tf.group(train_nodes, name="optimizer")

        merged_losses_summaries = tf.compat.v1.summary.merge_all(key=constants.train_key)
        merged_pretrain_summaries = tf.compat.v1.summary.merge_all(key=constants.pretrain_key)
        merged_preview_summaries = tf.compat.v1.summary.merge_all(key=constants.epoch_key)

        init = tf.global_variables_initializer()
        saver = tf.compat.v1.train.Saver(max_to_keep=5)
        sess = tf.Session()

        # Writer — build a run name that encodes the non-default hyper-parameters.
        def _append_desc(key, value):
            if value == 0:
                return ""
            return "_{}{}".format(key, value)

        now = datetime.datetime.now()
        summaries_fn = "SR4RS"
        summaries_fn += _append_desc("E", params.epochs)
        summaries_fn += _append_desc("B", params.batchsize)
        summaries_fn += _append_desc("LR", params.adam_lr)
        summaries_fn += _append_desc("Gan", params.ganweight)
        summaries_fn += _append_desc("L1-", params.l1weight)
        summaries_fn += _append_desc("L2-", params.l2weight)
        summaries_fn += _append_desc("VGG", params.vggweight)
        summaries_fn += _append_desc("VGGFeat", params.vggfeatures)
        summaries_fn += _append_desc("Loss", params.losstype)
        summaries_fn += _append_desc("D", params.depth)
        summaries_fn += _append_desc("RB", params.nresblocks)
        summaries_fn += _append_desc("LRSC", params.lr_scale)
        summaries_fn += _append_desc("HRSC", params.hr_scale)
        if params.pretrain:
            summaries_fn += "pretrained"
        summaries_fn += "_{}{}_{}h{}min".format(now.day, now.strftime("%b"), now.hour, now.minute)

        train_writer = None
        if params.logdir is not None:
            train_writer = tf.summary.FileWriter(params.logdir + summaries_fn, sess.graph)

        def _add_summary(summarized, _step):
            # No-op when no logdir was configured.
            if train_writer is not None:
                train_writer.add_summary(summarized, _step)

        sess.run(init)
        if params.load_ckpt is not None:
            saver.restore(sess, params.load_ckpt)

        # preview
        def _preview(_step):
            # NOTE(review): the throttle condition reads the module-level
            # `step` counter while summaries are logged at `_step` (the epoch
            # number at the call site) — confirm this mismatch is intended.
            if lr_image_for_prev is not None and step % params.previews_step == 0:
                summary_pe = sess.run(merged_preview_summaries, {lr_image: lr_image_for_prev})
                _add_summary(summary_pe, _step)

        def _do(_train_op, _summary_op, name):
            # Run `params.epochs` epochs of `_train_op`, logging summaries and
            # checkpointing after every epoch. Relies on a module-level `step`
            # counter defined outside this view.
            global step
            for curr_epoch in range(params.epochs):
                logging.info("{} Epoch #{}".format(name, curr_epoch))
                sess.run(iterator_init)
                try:
                    while True:
                        _, _summary = sess.run([_train_op, _summary_op])
                        _add_summary(_summary, step)
                        _preview(curr_epoch)
                        step += 1
                except tf.errors.OutOfRangeError:
                    # End of dataset = end of epoch; report filesystem stalls.
                    fs_stall_duration = ds.get_total_wait_in_seconds()
                    logging.info("{}: one epoch done. Total FS stall: {:.2f}s".format(name, fs_stall_duration))
                    pass
                saver.save(sess, params.save_ckpt + summaries_fn, global_step=curr_epoch)

        # pre training
        if params.pretrain:
            _do(pretrain_op, merged_pretrain_summaries, "pre-training")
        # training
        _do(train_op, merged_losses_summaries, "training")

        # cleaning
        if train_writer is not None:
            train_writer.close()

        # Export SavedModel
        if params.savedmodel is not None:
            logging.info("Export SavedModel in {}".format(params.savedmodel))
            outputs = ["{}{}:0".format(constants.outputs_prefix, pad) for pad in constants.pads]
            inputs = ["{}:0".format(constants.lr_input_name)]
            graph = tf.get_default_graph()
            tf.saved_model.simple_save(sess,
                                       params.savedmodel,
                                       inputs={i: graph.get_tensor_by_name(i) for i in inputs},
                                       outputs={o: graph.get_tensor_by_name(o) for o in outputs})
        # NOTE(review): `quit()` is the site-module helper intended for the
        # interactive REPL; `sys.exit()` would be the conventional choice here.
        quit()