def build_loss(self):
    """Build the training loss, optimizer op and summary op for the model.

    Combines per-frame image reconstruction cost, (optionally) state
    prediction cost, a tracking frame-matching cost and a transformation
    agreement cost, then creates an Adam train op on the averaged loss.

    Side effects: sets self.global_step, self.lr, self.loss, self.train_op
    and self.summ_op on the instance.

    NOTE(review): this block was reconstructed from whitespace-collapsed
    source; the nesting of the trailing statements under
    `if not self.trafo_pix:` is inferred — confirm against the original file.
    """
    summaries = []
    self.global_step = tf.Variable(0, trainable=False)
    if self.conf['learning_rate'] == 'scheduled' and not self.visualize:
        # Piecewise-constant schedule driven by global_step; boundaries and
        # values come from the experiment configuration.
        print('using scheduled learning rate')
        self.lr = tf.train.piecewise_constant(self.global_step, self.conf['lr_boundaries'], self.conf['lr_values'])
    else:
        # Fixed learning rate, overridable at session run time via feed_dict.
        self.lr = tf.placeholder_with_default(self.conf['learning_rate'], ())
    if not self.trafo_pix:
        # L2 loss, PSNR for eval.
        # NOTE(review): true_fft_list, pred_fft_list and psnr_all are
        # assigned but never used in this block.
        true_fft_list, pred_fft_list = [], []
        loss, psnr_all = 0.0, 0.0
        total_recon_cost = 0
        # Image reconstruction cost: compare each predicted frame with the
        # corresponding ground-truth frame after the context frames.
        for i, x, gx in zip(
                list(range(len(self.gen_images))),
                self.images[self.conf['context_frames']:],
                self.gen_images[self.conf['context_frames'] - 1:]):
            recon_cost_mse = self.mean_squared_error(x, gx)
            summaries.append(tf.summary.scalar('recon_cost' + str(i), recon_cost_mse))
            total_recon_cost += recon_cost_mse
        summaries.append(tf.summary.scalar('total reconst cost', total_recon_cost))
        loss += total_recon_cost
        if ('ignore_state_action' not in self.conf) and ('ignore_state' not in self.conf):
            # State prediction cost, down-weighted by 1e-4 and gated by the
            # 'use_state' config flag (0 disables it entirely).
            for i, state, gen_state in zip(
                    list(range(len(self.gen_states))),
                    self.states[self.conf['context_frames']:],
                    self.gen_states[self.conf['context_frames'] - 1:]):
                state_cost = self.mean_squared_error(state, gen_state) * 1e-4 * self.conf['use_state']
                summaries.append(tf.summary.scalar('state_cost' + str(i), state_cost))
                loss += state_cost
        # Tracking frame-matching cost: tracking-generated images must match
        # the true frames (zip_equal raises on length mismatch).
        total_frame_match_cost = 0
        for i, im, gen_im in zip_equal(
                list(range(len(self.tracking_gen_images))),
                self.images[1:],
                self.tracking_gen_images):
            cost = self.mean_squared_error(im, gen_im) * self.conf['track_agg_fact']
            total_frame_match_cost += cost
        summaries.append(tf.summary.scalar('total_frame_match_cost', total_frame_match_cost))
        loss += total_frame_match_cost
        # Transformation agreement cost: tracking kernels should agree with
        # the predicted kernels.
        # NOTE(review): this reuses 'track_agg_fact' as the weight — confirm
        # a separate factor was not intended.
        total_trans_agg_cost = 0
        for i, k1, k2 in zip_equal(
                list(range(len(self.tracking_kerns))),
                self.tracking_kerns,
                self.pred_kerns):
            cost = self.mean_squared_error(k1, k2) * self.conf['track_agg_fact']
            total_trans_agg_cost += cost
        summaries.append(tf.summary.scalar('total_trans_agg_cost', total_trans_agg_cost))
        loss += total_trans_agg_cost
        # Average over the number of predicted time steps.
        self.loss = loss = loss / np.float32(len(self.images) - self.conf['context_frames'])
        summaries.append(tf.summary.scalar('loss', loss))
        self.train_op = tf.train.AdamOptimizer(self.lr).minimize(loss, self.global_step)
        self.summ_op = tf.summary.merge(summaries)
def fuse_trafos(self, enc6, background_image, transformed, scope, extra_masks):
    """Composite the background and transformed images with learned masks.

    A 1x1 transposed convolution on `enc6` produces one logit map per mask;
    a per-pixel softmax normalizes them so they sum to one. Mask 0 weights
    the background image, the remaining masks weight the transformed layers.

    Returns:
        (fused image tensor, list of mask tensors).
    """
    # One logit channel per mask; the extra masks cover background and
    # generated pixels in addition to the num_masks transformation outputs.
    mask_logits = slim.layers.conv2d_transpose(
        enc6, (self.conf['num_masks'] + extra_masks), 1, stride=1, scope=scope)

    img_height, img_width = 64, 64
    num_masks = self.conf['num_masks']
    n_total = num_masks + extra_masks

    if self.conf['model'] == 'DNA':
        if num_masks != 1:
            raise ValueError('Only one mask is supported for DNA model.')

    # Softmax over the mask dimension, then restore the spatial layout.
    normalized = tf.nn.softmax(tf.reshape(mask_logits, [-1, n_total]))
    masks = tf.reshape(
        normalized,
        [int(self.batch_size), int(img_height), int(img_width), n_total])
    mask_list = tf.split(axis=3, num_or_size_splits=n_total, value=masks)

    # Start from the masked background, then add each masked transformed layer.
    fused = mask_list[0] * background_image
    for layer, mask in zip_equal(transformed, mask_list[1:]):
        fused += layer * mask
    return fused, mask_list
def mix_datasets(dataset0, dataset1, batch_size, ratio_01):
    """Sample a batch mixing examples from two datasets.

    Fixes the stale docstring, which documented parameters
    (`ground_truth_x`, `generated_x`, `num_set0`) that do not exist in the
    signature; graph-building code is unchanged.

    Args:
        dataset0: list of tensors, each with leading batch dimension
            `batch_size`.
        dataset1: list of tensors with the same structure as `dataset0`.
        batch_size: number of examples in a batch.
        ratio_01: fraction in [0, 1] of each batch to draw from `dataset0`;
            the remainder is drawn from `dataset1`.

    Returns:
        A list of tensors in which roughly `ratio_01 * batch_size` examples
        per batch come from `dataset0` and the rest from `dataset1`, stitched
        together at randomly shuffled positions.
    """
    # Number of examples taken from dataset0 (truncated toward zero).
    num_set0 = tf.cast(int(batch_size) * ratio_01, tf.int64)
    # Random permutation of batch positions; the first num_set0 positions
    # receive dataset0 examples, the rest receive dataset1 examples.
    idx = tf.random_shuffle(tf.range(int(batch_size)))
    set0_idx = tf.gather(idx, tf.range(num_set0))
    set1_idx = tf.gather(idx, tf.range(num_set0, int(batch_size)))
    output = []
    # zip_equal raises if the two dataset structures differ in length.
    for set0, set1 in zip_equal(dataset0, dataset1):
        dataset0_examps = tf.gather(set0, set0_idx)
        dataset1_examps = tf.gather(set1, set1_idx)
        # Scatter both sub-batches back into a single batch at the shuffled
        # positions.
        output.append(
            tf.dynamic_stitch([set0_idx, set1_idx],
                              [dataset0_examps, dataset1_examps]))
    return output
def fuse_pix_movebckgd(self, mask_list, transf_pix, transf_backgd_pix):
    """Fuse transformed pixel distributions using the given masks.

    Background distributions come first so their order matches `mask_list`
    (zip_equal raises on a length mismatch).
    """
    all_distribs = transf_backgd_pix + transf_pix
    return sum(pix * mask for pix, mask in zip_equal(all_distribs, mask_list))
def fuse_trafos_movbckgd(self, enc6, moved_background_image, transformed, scope, extra_masks, reuse):
    """Fuse a moved background and transformed images via softmax masks.

    Returns:
        (fused image, list of masks, list of per-layer masked contributions).
    """
    print('moving backgd')
    n_channels = self.conf['num_masks'] + extra_masks
    height, width = 64, 64

    # 1x1 transposed conv -> per-pixel logits over the mask channels.
    logits = slim.layers.conv2d_transpose(
        enc6, n_channels, 1, stride=1, scope=scope)
    # Normalize across masks so they sum to one at every pixel.
    masks = tf.reshape(
        tf.nn.softmax(tf.reshape(logits, [-1, n_channels])),
        [int(self.batch_size), int(height), int(width), n_channels])
    mask_list = tf.split(axis=3, num_or_size_splits=n_channels, value=masks)

    # Moved background layer(s) first, then the transformed layers, matching
    # the mask ordering; zip_equal raises on a length mismatch.
    layers = moved_background_image + transformed
    moved_parts = [layer * mask
                   for layer, mask in zip_equal(layers, mask_list)]
    fused = 0
    for part in moved_parts:
        fused += part
    return fused, mask_list, moved_parts
def fuse_trafos(self, enc6, transf_history, scope, total_masks):
    """Combine the history of transformed images with learned softmax masks.

    Returns:
        (fused image tensor, list of mask tensors).
    """
    # One logit channel per mask via a 1x1 transposed convolution.
    logits = slim.layers.conv2d_transpose(
        enc6, total_masks, 1, stride=1, scope=scope)
    # Softmax over the mask dimension so per-pixel masks sum to one, then
    # restore the (batch, 64, 64, total_masks) layout.
    flat_masks = tf.nn.softmax(tf.reshape(logits, [-1, total_masks]))
    masks = tf.reshape(flat_masks,
                       [int(self.batch_size), 64, 64, total_masks])
    mask_list = tf.split(axis=3, num_or_size_splits=total_masks, value=masks)

    # Weighted sum of the history images; zip_equal raises on mismatch.
    fused = 0
    for img, mask in zip_equal(transf_history, mask_list):
        fused = fused + img * mask
    return fused, mask_list