def build_model(self):
    """Build the style-transfer training graph and initialise its variables.

    The 'transform' generator is applied to the training batch; the processed
    generator output, the training images and the style images are then
    concatenated along axis 0 and passed through ``vgg.vgg_16`` in one call.
    The content loss is computed from the third returned output (``f3``) and
    the style loss from all four outputs via ``model.styleloss``.

    VGG weights are restored from ``self.args.vgg_model``; every other global
    variable is initialised with ``tf.variables_initializer``. The Adam step
    and the saver only cover the variables returned by ``model.gen_net``.

    Sets ``self.gen_img``, ``self.global_step``, ``self.content_loss``,
    ``self.style_loss``, ``self.loss``, ``self.opt`` and ``self.save``.
    """
    # NOTE(review): scope boundaries below were reconstructed from a source
    # without reliable indentation -- confirm against the intended layout.
    args = self.args
    data_path = args.train_data_path
    content_batch = load_data.get_loader(
        data_path, self.batch_size, self.img_size)
    style_batch = load_style_img(args.style_data_path)

    with slim.arg_scope(model.arg_scope()):
        generated, gen_vars = model.gen_net(
            content_batch, reuse=False, name='transform')

    with slim.arg_scope(vgg.vgg_arg_scope()):
        # Process each generated image separately before it enters VGG.
        per_image = tf.unstack(generated, axis=0, num=self.batch_size)
        processed = [load_data.img_process(img, True) for img in per_image]
        vgg_input = tf.concat(
            [processed, content_batch, style_batch], axis=0)
        f1, f2, f3, f4, exclude = vgg.vgg_16(vgg_input)

    # f3 is split into three equal parts along axis 0: generated / content /
    # style; only the first two are needed for the content loss.
    gen_f, img_f, _ = tf.split(f3, 3, 0)
    squared_error = tf.nn.l2_loss(gen_f - img_f)
    content_loss = squared_error / tf.to_float(tf.size(gen_f))
    style_loss = model.styleloss(f1, f2, f3, f4)

    # Restore the VGG weights from the checkpoint, skipping the variables
    # that vgg_16 reported in `exclude`.
    vgg_checkpoint = args.vgg_model
    restorable = slim.get_variables_to_restore(
        include=['vgg_16'], exclude=exclude)
    restore_vgg = slim.assign_from_checkpoint_fn(vgg_checkpoint, restorable)
    restore_vgg(self.sess)
    print('vgg s weights load done')

    self.gen_img = generated
    self.global_step = tf.Variable(0, name="global_step", trainable=False)
    self.content_loss = content_loss
    self.style_loss = style_loss * args.style_w
    self.loss = self.content_loss + self.style_loss
    optimizer = tf.train.AdamOptimizer(0.0001)
    self.opt = optimizer.minimize(
        self.loss, global_step=self.global_step, var_list=gen_vars)

    # Initialise everything that was not restored from the VGG checkpoint.
    non_vgg_vars = [
        var for var in tf.global_variables() if 'vgg_16' not in var.name
    ]
    self.sess.run(tf.variables_initializer(var_list=non_vgg_vars))
    self.save = tf.train.Saver(var_list=gen_vars)
def test(self, model_path='tmp/model/model.ckpt-50',
         output_path='images/test.jpg'):
    """Run the 'transform' network on the test image and save the result.

    Args:
        model_path: Checkpoint from which the variables under the
            'transform' scope are restored. Defaults to the path that was
            previously hard-coded.
        output_path: Destination passed to ``tools.save_images``. Defaults
            to the path that was previously hard-coded.
    """
    print('test model')
    test_img = tools.load_test_img(TEST_IMAGE_PATH)
    # Evaluate the loaded image in the session before feeding it to the net.
    test_img = self.sess.run(test_img)

    with slim.arg_scope(model.arg_scope()):
        gen_img, _ = model.inference(test_img, reuse=False, name='transform')

    # Named `transform_vars` rather than `vars` to avoid shadowing the builtin.
    transform_vars = slim.get_variables_to_restore(include=['transform'])
    init_fn = slim.assign_from_checkpoint_fn(model_path, transform_vars)
    init_fn(self.sess)

    gen_img = self.sess.run(gen_img)
    tools.save_images(gen_img, output_path)
def _build_model(self):
    """Build the generator, its scaled L2 loss and the Adam training op.

    Inputs and hyper-parameters are read from ``self.params``; the generator
    constructor is chosen by ``self.config.model``. The results are written
    back into ``self.params`` under the keys ``"denoised"``, ``"residual"``,
    ``"L2_loss"`` and ``"optimizer"``.

    Raises:
        NotImplementedError: if ``self.config.model`` is not one of
            ``"base"``, ``"residual"``, ``"base-skip"`` or
            ``"residual-skip"``.
    """
    # NOTE(review): scope boundaries below were reconstructed from a source
    # without reliable indentation -- confirm against the intended layout.
    config = self.config
    params = self.params
    (is_training, learning_rate, global_step,
     artifact_im, reference_im) = [
        params[key] for key in (
            "is_training", "learning_rate", "global_step",
            "artifact_im", "reference_im")
    ]

    # Map each supported config value to the attribute of `model` that builds
    # it; the attribute is only looked up for the entry that matches.
    builders = (
        ("base", "base"),
        ("residual", "residual"),
        ("base-skip", "base_skip"),
        ("residual-skip", "residual_skip"),
    )
    for model_name, attr_name in builders:
        if config.model == model_name:
            model_fn = getattr(model, attr_name)
            break
    else:
        raise NotImplementedError("There is no such {} model".format(
            config.model))

    with slim.arg_scope(model.arg_scope(is_training)):
        G_dn, G_residual, _ = model_fn(
            artifact_im,
            scope="generator",
            num_channels=config.num_channels)

    # Report every trainable variable and the total number of parameters.
    num_params = 0
    for var in tf.trainable_variables():
        shape = var.get_shape()
        print(var.name, shape)
        num_params += reduce(mul, shape.as_list(), 1)
    print("Total parameter: {}".format(num_params))

    with tf.variable_scope("Loss"):
        # The generator's residual output is regressed onto the difference
        # between the artifact image and the reference image.
        R_residual = artifact_im - reference_im
        # Scaled by 1000 so the value is comparable with the L2 loss in GAN.
        mse = tf.losses.mean_squared_error(
            labels=R_residual, predictions=G_residual)
        L2_loss = 1000 * mse

    with tf.variable_scope("Optimizer"):
        adam = tf.train.AdamOptimizer(
            learning_rate, beta1=config.beta1, epsilon=config.epsilon)
        optimizer = adam.minimize(L2_loss, global_step)

    params["denoised"] = G_dn
    params["residual"] = G_residual
    params["L2_loss"] = L2_loss
    params["optimizer"] = optimizer
def test(self):
    """Stylise the test image with a trained 'transform' network.

    Loads the image at ``self.args.test_data_path``, builds the generator,
    restores the variables under the 'transform' scope from
    ``self.args.transfer_model`` and saves the generated image under
    ``self.args.new_img_name``.
    """
    print('test model')
    test_img_path = self.args.test_data_path
    test_img = load_test_img(test_img_path)
    # Evaluate the loaded image in the session before feeding it to the net.
    test_img = self.sess.run(test_img)

    with slim.arg_scope(model.arg_scope()):
        gen_img, _ = model.gen_net(test_img, reuse=False, name='transform')

    # Restore the trained generator weights. Named `transform_vars` rather
    # than `vars` to avoid shadowing the builtin.
    model_path = self.args.transfer_model
    transform_vars = slim.get_variables_to_restore(include=['transform'])
    init_fn = slim.assign_from_checkpoint_fn(model_path, transform_vars)
    init_fn(self.sess)
    # The previous message claimed VGG weights were loaded, but this function
    # only restores the 'transform' checkpoint.
    print('transform model weights load done')

    gen_img = self.sess.run(gen_img)
    save_img.save_images(gen_img, self.args.new_img_name)
def build_model(self):
    """Build the style-transfer training graph with a decaying learning rate.

    The 'transform' generator is applied to the training batch; each generated
    image is standardised with ``tf.image.per_image_standardization`` and then
    concatenated along axis 0 with the training and style images for a single
    ``vgg.vgg_16`` call. The content loss is computed from the fourth returned
    output (``f4``) and the style loss from all four outputs via
    ``model.styleloss``.

    VGG weights are restored from ``VGG_MODEL_PATH``; every other global
    variable is initialised with ``tf.variables_initializer``. The Adam step
    and the saver only cover the variables returned by ``model.inference``.

    Sets ``self.gen_img``, ``self.global_step``, ``self.content_loss``,
    ``self.style_loss``, ``self.loss``, ``self.learn_rate``, ``self.opt`` and
    ``self.save``.
    """
    # NOTE(review): scope boundaries below were reconstructed from a source
    # without reliable indentation -- confirm against the intended layout.
    content_batch = tools.load_train_img(
        TRAIN_DATA_DIR, self.batch_size, self.img_size)
    style_batch = tools.load_style_img(STYLE_IMAGE_PATH)

    with slim.arg_scope(model.arg_scope()):
        generated, gen_vars = model.inference(
            content_batch, reuse=False, name='transform')

    with slim.arg_scope(vgg.vgg_arg_scope()):
        per_image = tf.unstack(generated, axis=0, num=self.batch_size)
        standardized = [
            tf.image.per_image_standardization(img) for img in per_image
        ]
        vgg_input = tf.concat(
            [standardized, content_batch, style_batch], axis=0)
        f1, f2, f3, f4, exclude = vgg.vgg_16(vgg_input)

    # f4 is split into three equal parts along axis 0: generated / content /
    # style; only the first two are needed for the content loss.
    gen_f, img_f, _ = tf.split(f4, 3, 0)
    squared_error = tf.nn.l2_loss(gen_f - img_f)
    content_loss = squared_error / tf.to_float(tf.size(gen_f))
    style_loss = model.styleloss(f1, f2, f3, f4)

    # Restore the VGG weights from the checkpoint, skipping the variables
    # that vgg_16 reported in `exclude`.
    vgg_checkpoint = VGG_MODEL_PATH
    restorable = slim.get_variables_to_restore(
        include=['vgg_16'], exclude=exclude)
    restore_vgg = slim.assign_from_checkpoint_fn(vgg_checkpoint, restorable)
    restore_vgg(self.sess)
    print("vgg's weights load done")

    self.gen_img = generated
    self.global_step = tf.Variable(0, name="global_step", trainable=False)
    self.content_loss = content_loss
    self.style_loss = style_loss * self.style_w
    self.loss = self.content_loss + self.style_loss

    # The base rate is replaced on `self` by the decayed-rate tensor
    # (decay_steps=1, staircase mode).
    self.learn_rate = tf.train.exponential_decay(
        self.learn_rate_base,
        self.global_step,
        1,
        self.learn_rate_decay,
        staircase=True)
    optimizer = tf.train.AdamOptimizer(self.learn_rate)
    self.opt = optimizer.minimize(
        self.loss, global_step=self.global_step, var_list=gen_vars)

    # Initialise everything that was not restored from the VGG checkpoint.
    non_vgg_vars = [
        var for var in tf.global_variables() if 'vgg_16' not in var.name
    ]
    self.sess.run(tf.variables_initializer(var_list=non_vgg_vars))
    self.save = tf.train.Saver(var_list=gen_vars)
def main(_):
    """Build the siamese similarity graph and run the train/eval loop.

    Creates the input placeholders, the shared-weight siamese branches, the
    classification head, the loss and a Nesterov momentum optimizer; then
    either initialises the variables or restores the latest checkpoint found
    in ``ckpt_dir`` and alternates one training step with one evaluation step
    per batch for ``FLAGS.train_iter`` passes.

    Args:
        _: Unused positional argument.
    """
    # NOTE(review): scope boundaries below were reconstructed from a source
    # without reliable indentation -- confirm against the intended layout.
    import setproctitle
    setproctitle.setproctitle('python_train')

    is_training = tf.placeholder(tf.bool, shape=())
    left_in = tf.placeholder(
        tf.float32, [None, FLAGS.patch_size, FLAGS.patch_size, 1],
        name='left')
    righ_in = tf.placeholder(
        tf.float32, [None, FLAGS.patch_size, FLAGS.patch_size, 1],
        name='right')
    with tf.name_scope('similarity'):
        label = tf.placeholder(tf.int32, [
            None,
        ], name='label')

    arg_scope = model.arg_scope(weight_decay=FLAGS.weight_decay,
                                data_format=DATA_FORMAT)
    global_steps = tf.Variable(0, trainable=False)

    with tf.device(gpus):
        with tf.name_scope('model'):
            with slim.arg_scope(arg_scope):
                # The second branch is built with reuse=True so both inputs
                # go through the same siamese_net variables.
                left_out = model.siamese_net(left_in, reuse=False)
                righ_out = model.siamese_net(righ_in, reuse=True)
                cls_out = model.cls_net(left_out, righ_out)
        loss = model.net_loss(cls_out, label)
        tf.summary.scalar('loss', loss)

        lr = FLAGS.learning_rate
        mm = FLAGS.momentum
        optimizer = tf.train.MomentumOptimizer(learning_rate=lr,
                                               momentum=mm,
                                               use_nesterov=True)
        # Passing global_step=global_steps to minimize() makes the optimizer
        # increment global_steps by one on every training step, so it
        # accumulates the steps of all epochs.
        train_op = optimizer.minimize(loss=loss, global_step=global_steps)

    variables_to_train = tf.trainable_variables()
    for var in variables_to_train:
        tf.summary.histogram(var.op.name, var)

    saver = tf.train.Saver(max_to_keep=500)

    # Create a session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = False
    config.gpu_options.per_process_gpu_memory_fraction = 0.6
    config.log_device_placement = False
    config.allow_soft_placement = True
    sess = tf.Session(config=config)

    # Add summary writers
    merged = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter('log/train', sess.graph)
    eval_writer = tf.summary.FileWriter('log/eval')

    # Init variables, or resume from the latest checkpoint when one exists.
    # latest_checkpoint() returns None when the directory is missing, empty
    # or holds no checkpoint, so all of those cases start from scratch instead
    # of failing in os.listdir() or in saver.restore(sess, None).
    latest_ckpt = tf.train.latest_checkpoint(ckpt_dir)
    if latest_ckpt is None:
        init = tf.global_variables_initializer()
        # Run the init operation
        sess.run(init, {is_training: True})
    else:
        saver.restore(sess, latest_ckpt)

    # Dictionary of the placeholders, tensors and operations of the graph
    # that train_one_step() and eval_one_step() need; bundling them keeps the
    # signatures of those functions short.
    ops = {
        'is_training_pl': is_training,
        'left_in_pl': left_in,
        'righ_in_pl': righ_in,
        'label_pl': label,
        'loss': loss,
        'train_op': train_op,
        'merged': merged,
        'global_steps': global_steps
    }

    # train iter
    for i in range(FLAGS.train_iter):
        np.random.shuffle(DATA_PATHS)
        np.random.shuffle(EVAL_DATA_PATHS)
        for b in range(n_batches):
            step, train_loss = train_one_step(b, sess, ops, train_writer,
                                              is_training=True)
            eval_loss = eval_one_step(b, sess, ops, eval_writer,
                                      is_training=False)
            # Single-argument print() call: valid on Python 2 and Python 3.
            print("global step %d: train loss = %f; eval loss = %f" % (
                step, train_loss, eval_loss))
        # NOTE(review): checkpoint is written once per pass over the data;
        # confirm this was not meant to run after every batch.
        saver.save(sess, "model_with_eval/model.ckpt", global_steps)