def _loss_function(self, a, b, train_config):
    """Return the weighted L1 reconstruction loss between tensors `a` and `b`.

    `train_config` is passed in so that more advanced, configurable losses
    can be implemented here later without changing the call sites; currently
    only the "l1_loss_weight" entry is consulted (1 in the default config).
    """
    l1_weight = train_config["l1_loss_weight"]
    with tf.variable_scope("loss_function"):
        return L1_loss(a, b) * l1_weight
def _build_training_graph(self, train_config):
    """Build the TF1 graph-mode training pipeline: input queues, forward
    pass, losses, summaries, and the checkpoint saver.

    Side effects (attributes set on self): global_step, loss_function,
    output, reg_loss, loss, loss_sum, image_sum, image_comp_sum,
    image_orig_comp_sum, saver.

    NOTE(review): assumes frames decode to NHWC batches — the channel-axis
    concat below uses axis=3 — and that self.image_transformer populates
    self.texture_a / self.shape_b as a side effect before they are read
    here; confirm against the transformer implementation.
    """
    self.global_step = tf.Variable(0, trainable=False)
    # Queue of TFRecord shards; num_epochs bounds how long training runs.
    filename_queue = tf.train.string_input_producer(
        [os.path.join(train_config["dataset_dir"], 'train.tfrecords')],
        num_epochs=train_config["num_epochs"])
    # Each example is a triplet of frames (A, B, C), the ground-truth
    # magnified frame, and the scalar amplification factor.
    frameA, frameB, frameC, frameAmp, amplification_factor = \
        read_and_decode_3frames(filename_queue,
                                (train_config["image_height"],
                                 train_config["image_width"],
                                 self.n_channels))
    # Shuffle-batch buffer sizing: capacity follows the rule of thumb from
    # the TF1 input-pipeline docs (min buffer + headroom per thread).
    min_after_dequeue = 1000
    num_threads = 16
    capacity = min_after_dequeue + \
        (num_threads + 2) * train_config["batch_size"]
    frameA, frameB, frameC, frameAmp, amplification_factor = \
        tf.train.shuffle_batch(
            [frameA, frameB, frameC, frameAmp, amplification_factor],
            batch_size=train_config["batch_size"],
            capacity=capacity,
            num_threads=num_threads,
            min_after_dequeue=min_after_dequeue)
    # Note: frameAmp (the ground truth) is deliberately NOT preprocessed —
    # presumably the transformer output is compared in preprocessed space;
    # verify against preprocess_image.
    frameA = preprocess_image(frameA, train_config)
    frameB = preprocess_image(frameB, train_config)
    frameC = preprocess_image(frameC, train_config)
    # Bind the config once so the loss can be called as loss_function(a, b).
    self.loss_function = partial(self._loss_function,
                                 train_config=train_config)
    # Forward pass: magnify the A->B motion by amplification_factor.
    # Trailing flags are positional: True/False — presumably
    # is_training/reuse; confirm against image_transformer's signature.
    self.output = self.image_transformer(
        frameA, frameB, amplification_factor,
        [train_config["image_height"], train_config["image_width"]],
        self.arch_config, True, False)
    # Optional weight decay from any regularizers registered by the layers.
    self.reg_loss = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if self.reg_loss and train_config["weight_decay"] > 0.0:
        print("Adding Regularization Weights.")
        self.loss = self.loss_function(self.output, frameAmp) + \
            train_config["weight_decay"] * tf.add_n(self.reg_loss)
    else:
        print("No Regularization Weights.")
        self.loss = self.loss_function(self.output, frameAmp)
    # Add regularization more
    # TODO: Hardcoding the network name scope here.
    # Re-run the shared encoder on frame C with reuse=True so no new
    # variables are created.
    with tf.variable_scope('ynet_3frames/encoder', reuse=True):
        texture_c, shape_c = self._encoder(frameC)
    # Consistency losses: C's texture should match A's, C's shape should
    # match B's (texture_a / shape_b were cached by the transformer above).
    self.loss = self.loss + \
        train_config["texture_loss_weight"] * L1_loss(texture_c,
                                                      self.texture_a) + \
        train_config["shape_loss_weight"] * L1_loss(shape_c, self.shape_b)
    # TensorBoard summaries.
    self.loss_sum = tf.summary.scalar('train_loss', self.loss)
    # Input B and the network output side by side (concat along width).
    self.image_sum = tf.summary.image(
        'train_B_OUT',
        tf.concat([frameB, self.output], axis=2),
        max_outputs=2)
    if self.n_channels == 3:
        # RGB: visualize signed differences directly.
        self.image_comp_sum = tf.summary.image(
            'train_GT_OUT',
            frameAmp - self.output,
            max_outputs=2)
        self.image_orig_comp_sum = tf.summary.image(
            'train_ORIG_OUT',
            frameA - self.output,
            max_outputs=2)
    else:
        # Single-channel: pack (GT, output, GT) into the channel axis so
        # tf.summary.image gets a 3-channel tensor it can render.
        self.image_comp_sum = tf.summary.image(
            'train_GT_OUT',
            tf.concat([frameAmp, self.output, frameAmp], axis=3),
            max_outputs=2)
        self.image_orig_comp_sum = tf.summary.image(
            'train_ORIG_OUT',
            tf.concat([frameA, self.output, frameA], axis=3),
            max_outputs=2)
    # Keep only the most recent checkpoints.
    self.saver = tf.train.Saver(max_to_keep=train_config["ckpt_to_keep"])