def get_embeddings(images, model_options, embedding_dimension):
  """Extracts embedding vectors for images. Should only be used for inference.

  Args:
    images: A tensor of shape [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    embedding_dimension: Integer, the dimension of the embedding.

  Returns:
    embeddings: A tensor of shape [batch, height, width, embedding_dimension].
  """
  features, end_points = model.extract_features(
      images, model_options, is_training=False)

  if model_options.decoder_output_stride is not None:
    if model_options.crop_size is None:
      height = tf.shape(images)[1]
      width = tf.shape(images)[2]
    else:
      height, width = model_options.crop_size
    decoder_height = model.scale_dimension(
        height, 1.0 / model_options.decoder_output_stride)
    decoder_width = model.scale_dimension(
        width, 1.0 / model_options.decoder_output_stride)
    features = model.refine_by_decoder(
        features,
        end_points,
        decoder_height=decoder_height,
        decoder_width=decoder_width,
        decoder_use_separable_conv=model_options.decoder_use_separable_conv,
        model_variant=model_options.model_variant,
        is_training=False)

  with tf.variable_scope('embedding'):
    embeddings = split_separable_conv2d_with_identity_initializer(
        features, embedding_dimension, scope='split_separable_conv2d')
  return embeddings
def get_embeddings(images, model_options, embedding_dimension):
  """Extracts embedding vectors for images. Should only be used for inference.

  Args:
    images: A tensor of shape [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    embedding_dimension: Integer, the dimension of the embedding.

  Returns:
    embeddings: A tensor of shape [batch, height, width, embedding_dimension].
  """
  features, end_points = model.extract_features(
      images, model_options, is_training=False)

  if model_options.decoder_output_stride is not None:
    decoder_output_stride = min(model_options.decoder_output_stride)
    if model_options.crop_size is None:
      height = tf.shape(images)[1]
      width = tf.shape(images)[2]
    else:
      height, width = model_options.crop_size
    features = model.refine_by_decoder(
        features,
        end_points,
        crop_size=[height, width],
        decoder_output_stride=[decoder_output_stride],
        decoder_use_separable_conv=model_options.decoder_use_separable_conv,
        model_variant=model_options.model_variant,
        is_training=False)

  with tf.variable_scope('embedding'):
    embeddings = split_separable_conv2d_with_identity_initializer(
        features, embedding_dimension, scope='split_separable_conv2d')
  return embeddings
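
# Minimal usage sketch for get_embeddings (not part of the original source):
# the placeholder tensor, the 19-class Cityscapes setting and the embedding
# size of 100 are illustrative assumptions. With crop_size=None the spatial
# size is taken from tf.shape(images) at run time.
input_images = tf.placeholder(tf.float32, [None, None, None, 3])
inference_options = common.ModelOptions(
    outputs_to_num_classes={common.OUTPUT_TYPE: 19},
    crop_size=None,
    atrous_rates=None,
    output_stride=16)
image_embeddings = get_embeddings(
    input_images, inference_options, embedding_dimension=100)
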
def train(l_args):
    """Trains the model."""
    if l_args.verbose:
        tf.logging.set_verbosity(tf.logging.INFO)

    # Create input data pipeline.
    x_train_files = sorted(
        glob.glob('/datatmp/Datasets/Cityscapes/leftImg8bit/train/*/*.png'))
    x_label_files = sorted(
        glob.glob(
            '/datatmp/Datasets/Cityscapes/gtFine/train/*/*_labelIds.png'))
    y_train_files = sorted(
        glob.glob(
            '/datatmp/Experiments/semantic_compression/{}/lambda_{}/leftImg8bit/train/*/*.png'
            .format(l_args.images_dir, l_args.lmbda)))

    # Sanity checks: originals, labels and compressed images must line up.
    print(len(x_train_files), len(y_train_files))
    assert len(x_train_files) == len(y_train_files)
    assert x_train_files[0].split("/")[-1] == y_train_files[0].split("/")[-1]
    assert x_train_files[-1].split("/")[-1] == y_train_files[-1].split("/")[-1]
    print(x_train_files[0].split("/")[-1][:-16],
          x_label_files[0].split("/")[-1].split("_gtFine_labelIds.png")[0])
    print(x_train_files[-1].split("/")[-1][:-16],
          x_label_files[-1].split("/")[-1].split("_gtFine_labelIds.png")[0])
    assert len(x_label_files) == len(x_train_files)
    assert x_train_files[0].split("/")[-1][:-16] == x_label_files[0].split(
        "/")[-1].split("_gtFine_labelIds.png")[0]
    assert x_train_files[-1].split("/")[-1][:-16] == x_label_files[-1].split(
        "/")[-1].split("_gtFine_labelIds.png")[0]

    train_dataset = tf.data.Dataset.from_tensor_slices(
        (x_train_files, x_label_files, y_train_files))
    train_dataset = train_dataset.shuffle(
        buffer_size=len(x_train_files)).repeat()
    train_dataset = train_dataset.map(
        read_pngs, num_parallel_calls=l_args.preprocess_threads)
    # Each example is a 7-channel stack: RGB image (3) + label IDs (1) +
    # compressed RGB image (3).
    train_dataset = train_dataset.map(lambda x: tf.random_crop(
        x, [int(z) for z in l_args.patchsize.split(",")] + [7]))
    train_dataset = train_dataset.batch(l_args.batchsize)
    train_dataset = train_dataset.prefetch(l_args.batchsize)
    train_batch = train_dataset.make_one_shot_iterator().get_next()
    train_x, _, train_y = (train_batch[:, :, :, :3],
                           train_batch[:, :, :, 3:4],
                           train_batch[:, :, :, 4:])
    scaled_train_x, scaled_train_y = train_x / 255., train_y / 255.

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: 19},
        crop_size=[int(z) for z in l_args.patchsize.split(",")],
        atrous_rates=None,
        output_stride=16)
    x_features, _ = model.extract_features(train_x, model_options)

    # The segmentation backbone is restored from a pre-trained checkpoint and
    # kept frozen: its variables are excluded from the optimizer's var_list.
    exclude_list = ['global_step']
    variables_to_restore = tf.contrib.framework.get_variables_to_restore(
        exclude=exclude_list)
    seg_saver = tf.train.Saver(variables_to_restore)
    print(variables_to_restore)

    layers = RDN()
    scaled_x_tilde_hat = layers(scaled_train_y)
    x_tilde_hat = 255.0 * scaled_x_tilde_hat
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        x_tilde_hat_features, _ = model.extract_features(
            x_tilde_hat, model_options)

    mse = tf.reduce_mean(
        tf.squared_difference(scaled_train_x, scaled_x_tilde_hat)) * 255**2
    ssim = tf.reduce_mean(
        1 - tf.image.ssim_multiscale(scaled_x_tilde_hat, scaled_train_x, 1))
    l1 = tf.reduce_mean(tf.math.abs(scaled_train_x - scaled_x_tilde_hat))
    distortion = {"mse": mse, "l1": l1, "msssim": ssim}[l_args.loss_type]
    distillation = tf.reduce_mean(
        tf.squared_difference(x_features, x_tilde_hat_features))
    train_loss = distortion + l_args.mu * distillation

    var_list = [
        var for var in tf.trainable_variables()
        if var not in variables_to_restore
    ]
    print()
    print()
    print(var_list)

    step = tf.train.get_or_create_global_step()
    main_optimizer = tf.train.AdamOptimizer(learning_rate=l_args.lr)
    train_op = main_optimizer.minimize(train_loss,
                                       var_list=var_list,
                                       global_step=step)

    log_all_summaries(train_x, scaled_x_tilde_hat, scaled_train_y, None, None,
                      train_loss, None, mse, None, ssim, "train")
    #log_all_summaries(val_x, valid_x_tilde, None, None, valid_loss, None,
    #                  valid_mse, None, valid_ssim, "val")

    def load_pretrain(scaffold, sess):
        seg_saver.restore(sess, save_path=PATH_TO_TRAINED_MODEL)

    hooks = [
        tf.train.StopAtStepHook(last_step=l_args.last_step),
        tf.train.NanTensorHook(train_loss),
    ]
    with tf.train.MonitoredTrainingSession(
            hooks=hooks,
            checkpoint_dir=l_args.checkpoint_dir,
            save_checkpoint_secs=1200,
            save_summaries_secs=60,
            scaffold=tf.train.Scaffold(init_fn=load_pretrain)) as sess:
        while not sess.should_stop():
            sess.run(train_op)
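
# Hedged sketch (not part of the original code): the objective built above,
# isolated into a standalone helper for clarity. Argument names are
# hypothetical; `mu` and `loss_type` mirror l_args.mu and l_args.loss_type.
def restoration_loss(scaled_x, scaled_x_hat, x_feat, x_hat_feat, mu,
                     loss_type="mse"):
    mse = tf.reduce_mean(
        tf.squared_difference(scaled_x, scaled_x_hat)) * 255**2
    l1 = tf.reduce_mean(tf.abs(scaled_x - scaled_x_hat))
    msssim = tf.reduce_mean(
        1 - tf.image.ssim_multiscale(scaled_x_hat, scaled_x, 1))
    distortion = {"mse": mse, "l1": l1, "msssim": msssim}[loss_type]
    # Feature distillation: match frozen-backbone features of the original and
    # the restored image.
    distillation = tf.reduce_mean(tf.squared_difference(x_feat, x_hat_feat))
    return distortion + mu * distillation
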
def get_logits_with_matching(images,
                             model_options,
                             weight_decay=0.0001,
                             reuse=None,
                             is_training=False,
                             fine_tune_batch_norm=False,
                             reference_labels=None,
                             batch_size=None,
                             num_frames_per_video=None,
                             embedding_dimension=None,
                             max_neighbors_per_object=0,
                             k_nearest_neighbors=1,
                             use_softmax_feedback=True,
                             initial_softmax_feedback=None,
                             embedding_seg_feature_dimension=256,
                             embedding_seg_n_layers=4,
                             embedding_seg_kernel_size=7,
                             embedding_seg_atrous_rates=None,
                             normalize_nearest_neighbor_distances=True,
                             also_attend_to_previous_frame=True,
                             damage_initial_previous_frame_mask=False,
                             use_local_previous_frame_attention=True,
                             previous_frame_attention_window_size=15,
                             use_first_frame_matching=True,
                             also_return_embeddings=False,
                             ref_embeddings=None):
  """Gets the logits by atrous/image spatial pyramid pooling using attention.

  Args:
    images: A tensor of size [batch, height, width, channels].
    model_options: A ModelOptions instance to configure models.
    weight_decay: The weight decay for model variables.
    reuse: Reuse the model variables or not.
    is_training: Is training or not.
    fine_tune_batch_norm: Fine-tune the batch norm parameters or not.
    reference_labels: The segmentation labels of the reference frame on which
      attention is applied.
    batch_size: Integer, the number of videos in a batch.
    num_frames_per_video: Integer, the number of frames per video.
    embedding_dimension: Integer, the dimension of the embedding.
    max_neighbors_per_object: Integer, the maximum number of candidates for
      the nearest neighbor query per object after subsampling. Can be 0 for no
      subsampling.
    k_nearest_neighbors: Integer, the number of nearest neighbors to use.
    use_softmax_feedback: Boolean, whether to give the softmax predictions of
      the last frame as additional input to the segmentation head.
    initial_softmax_feedback: List of Float32 tensors, or None. Can be used to
      initialize the softmax predictions used for the feedback loop. Only has
      an effect if use_softmax_feedback is True.
    embedding_seg_feature_dimension: Integer, the dimensionality used in the
      segmentation head layers.
    embedding_seg_n_layers: Integer, the number of layers in the segmentation
      head.
    embedding_seg_kernel_size: Integer, the kernel size used in the
      segmentation head.
    embedding_seg_atrous_rates: List of integers of length
      embedding_seg_n_layers, the atrous rates to use for the segmentation
      head.
    normalize_nearest_neighbor_distances: Boolean, whether to normalize the
      nearest neighbor distances to [0, 1] using sigmoid, scale and shift.
    also_attend_to_previous_frame: Boolean, whether to also use nearest
      neighbor attention with respect to the previous frame.
    damage_initial_previous_frame_mask: Boolean, whether to artificially
      damage the initial previous frame mask. Only has an effect if
      also_attend_to_previous_frame is True.
    use_local_previous_frame_attention: Boolean, whether to restrict the
      previous frame attention to a local search window. Only has an effect,
      if also_attend_to_previous_frame is True.
    previous_frame_attention_window_size: Integer, the window size used for
      local previous frame attention, if use_local_previous_frame_attention
      is True.
    use_first_frame_matching: Boolean, whether to extract features by matching
      to the reference frame. This should always be true except for ablation
      experiments.
    also_return_embeddings: Boolean, whether to return the embeddings as well.
    ref_embeddings: Tuple of (first_frame_embeddings,
      previous_frame_embeddings), each of shape
      [batch, height, width, embedding_dimension], or None.

  Returns:
    outputs_to_logits: A map from output_type to logits. If
      also_return_embeddings is True, it will also return an embeddings tensor
      of shape [batch, height, width, embedding_dimension].
  """
  features, end_points = model.extract_features(
      images,
      model_options,
      weight_decay=weight_decay,
      reuse=reuse,
      is_training=is_training,
      fine_tune_batch_norm=fine_tune_batch_norm)

  if model_options.decoder_output_stride:
    decoder_output_stride = min(model_options.decoder_output_stride)
    if model_options.crop_size is None:
      height = tf.shape(images)[1]
      width = tf.shape(images)[2]
    else:
      height, width = model_options.crop_size
    decoder_height = model.scale_dimension(height,
                                           1.0 / decoder_output_stride)
    decoder_width = model.scale_dimension(width, 1.0 / decoder_output_stride)
    features = model.refine_by_decoder(
        features,
        end_points,
        crop_size=[height, width],
        decoder_output_stride=[decoder_output_stride],
        decoder_use_separable_conv=model_options.decoder_use_separable_conv,
        model_variant=model_options.model_variant,
        weight_decay=weight_decay,
        reuse=reuse,
        is_training=is_training,
        fine_tune_batch_norm=fine_tune_batch_norm)

  with tf.variable_scope('embedding', reuse=reuse):
    embeddings = split_separable_conv2d_with_identity_initializer(
        features, embedding_dimension, scope='split_separable_conv2d')
    embeddings = tf.identity(embeddings, name='embeddings')
  scaled_reference_labels = tf.image.resize_nearest_neighbor(
      reference_labels,
      resolve_shape(embeddings, 4)[1:3],
      align_corners=True)
  h, w = decoder_height, decoder_width
  if num_frames_per_video is None:
    num_frames_per_video = tf.size(embeddings) // (
        batch_size * h * w * embedding_dimension)
  new_labels_shape = tf.stack([batch_size, -1, h, w, 1])
  reshaped_reference_labels = tf.reshape(scaled_reference_labels,
                                         new_labels_shape)
  new_embeddings_shape = tf.stack(
      [batch_size, num_frames_per_video, h, w, embedding_dimension])
  reshaped_embeddings = tf.reshape(embeddings, new_embeddings_shape)
  all_nn_features = []
  all_ref_obj_ids = []
  # To keep things simple, we do all this separate for each sequence for now.
  for n in range(batch_size):
    embedding = reshaped_embeddings[n]
    if ref_embeddings is None:
      n_chunks = 100
      reference_embedding = embedding[0]
      if also_attend_to_previous_frame or use_softmax_feedback:
        queries_embedding = embedding[2:]
      else:
        queries_embedding = embedding[1:]
    else:
      if USE_CORRELATION_COST:
        n_chunks = 20
      else:
        n_chunks = 500
      reference_embedding = ref_embeddings[0][n]
      queries_embedding = embedding
    reference_labels = reshaped_reference_labels[n][0]
    nn_features_n, ref_obj_ids = nearest_neighbor_features_per_object(
        reference_embedding,
        queries_embedding,
        reference_labels,
        max_neighbors_per_object,
        k_nearest_neighbors,
        n_chunks=n_chunks)
    if normalize_nearest_neighbor_distances:
      nn_features_n = (tf.nn.sigmoid(nn_features_n) - 0.5) * 2
    all_nn_features.append(nn_features_n)
    all_ref_obj_ids.append(ref_obj_ids)

  feat_dim = resolve_shape(features)[-1]
  features = tf.reshape(
      features, tf.stack([batch_size, num_frames_per_video, h, w, feat_dim]))
  if ref_embeddings is None:
    # Strip the features for the reference frame.
    if also_attend_to_previous_frame or use_softmax_feedback:
      features = features[:, 2:]
    else:
      features = features[:, 1:]

  # To keep things simple, we do all this separate for each sequence for now.
  outputs_to_logits = {
      output: [] for output in model_options.outputs_to_num_classes}
  for n in range(batch_size):
    features_n = features[n]
    nn_features_n = all_nn_features[n]
    nn_features_n_tr = tf.transpose(nn_features_n, [3, 0, 1, 2, 4])
    n_objs = tf.shape(nn_features_n_tr)[0]
    # Repeat features for every object.
    features_n_tiled = tf.tile(features_n[tf.newaxis],
                               multiples=[n_objs, 1, 1, 1, 1])
    prev_frame_labels = None
    if also_attend_to_previous_frame:
      prev_frame_labels = reshaped_reference_labels[n, 1]
      if is_training and damage_initial_previous_frame_mask:
        # Damage the previous frame masks.
        prev_frame_labels = mask_damaging.damage_masks(prev_frame_labels,
                                                       dilate=False)
      tf.summary.image('prev_frame_labels',
                       tf.cast(prev_frame_labels[tf.newaxis], tf.uint8) * 32)
      initial_softmax_feedback_n = create_initial_softmax_from_labels(
          prev_frame_labels, reshaped_reference_labels[n][0],
          decoder_output_stride=None, reduce_labels=True)
    elif initial_softmax_feedback is not None:
      initial_softmax_feedback_n = initial_softmax_feedback[n]
    else:
      initial_softmax_feedback_n = None
    if initial_softmax_feedback_n is None:
      last_softmax = tf.zeros((n_objs, h, w, 1), dtype=tf.float32)
    else:
      last_softmax = tf.transpose(initial_softmax_feedback_n,
                                  [2, 0, 1])[..., tf.newaxis]
    assert len(model_options.outputs_to_num_classes) == 1
    output = list(model_options.outputs_to_num_classes.keys())[0]
    logits = []
    n_ref_frames = 1
    prev_frame_nn_features_n = None
    if also_attend_to_previous_frame or use_softmax_feedback:
      n_ref_frames += 1
    if ref_embeddings is not None:
      n_ref_frames = 0
    for t in range(num_frames_per_video - n_ref_frames):
      to_concat = [features_n_tiled[:, t]]
      if use_first_frame_matching:
        to_concat.append(nn_features_n_tr[:, t])
      if use_softmax_feedback:
        to_concat.append(last_softmax)
      if also_attend_to_previous_frame:
        assert normalize_nearest_neighbor_distances, (
            'previous frame attention currently only works when normalized '
            'distances are used')
        embedding = reshaped_embeddings[n]
        if ref_embeddings is None:
          last_frame_embedding = embedding[t + 1]
          query_embeddings = embedding[t + 2, tf.newaxis]
        else:
          last_frame_embedding = ref_embeddings[1][0]
          query_embeddings = embedding
        if use_local_previous_frame_attention:
          assert query_embeddings.shape[0] == 1
          prev_frame_nn_features_n = (
              local_previous_frame_nearest_neighbor_features_per_object(
                  last_frame_embedding,
                  query_embeddings[0],
                  prev_frame_labels,
                  all_ref_obj_ids[n],
                  max_distance=previous_frame_attention_window_size))
        else:
          prev_frame_nn_features_n, _ = (
              nearest_neighbor_features_per_object(
                  last_frame_embedding,
                  query_embeddings,
                  prev_frame_labels,
                  max_neighbors_per_object,
                  k_nearest_neighbors,
                  gt_ids=all_ref_obj_ids[n]))
          prev_frame_nn_features_n = (tf.nn.sigmoid(
              prev_frame_nn_features_n) - 0.5) * 2
        prev_frame_nn_features_n_sq = tf.squeeze(prev_frame_nn_features_n,
                                                 axis=0)
        prev_frame_nn_features_n_tr = tf.transpose(
            prev_frame_nn_features_n_sq, [2, 0, 1, 3])
        to_concat.append(prev_frame_nn_features_n_tr)
      features_n_concat_t = tf.concat(to_concat, axis=-1)
      embedding_seg_features_n_t = (
          create_embedding_segmentation_features(
              features_n_concat_t, embedding_seg_feature_dimension,
              embedding_seg_n_layers, embedding_seg_kernel_size,
              reuse or n > 0, atrous_rates=embedding_seg_atrous_rates))
      logits_t = model.get_branch_logits(
          embedding_seg_features_n_t,
          1,
          model_options.atrous_rates,
          aspp_with_batch_norm=model_options.aspp_with_batch_norm,
          kernel_size=model_options.logits_kernel_size,
          weight_decay=weight_decay,
          reuse=reuse or n > 0 or t > 0,
          scope_suffix=output)
      logits.append(logits_t)
      prev_frame_labels = tf.transpose(tf.argmax(logits_t, axis=0), [2, 0, 1])
      last_softmax = tf.nn.softmax(logits_t, axis=0)
    logits = tf.stack(logits, axis=1)
    logits_shape = tf.stack(
        [n_objs, num_frames_per_video - n_ref_frames] +
        resolve_shape(logits)[2:-1])
    logits_reshaped = tf.reshape(logits, logits_shape)
    logits_transposed = tf.transpose(logits_reshaped, [1, 2, 3, 0])
    outputs_to_logits[output].append(logits_transposed)

    add_image_summaries(
        images[n * num_frames_per_video:(n + 1) * num_frames_per_video],
        nn_features_n,
        logits_transposed,
        batch_size=1,
        prev_frame_nn_features=prev_frame_nn_features_n)

  if also_return_embeddings:
    return outputs_to_logits, embeddings
  else:
    return outputs_to_logits
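
# Hedged usage sketch (not from the original source): a single video of three
# frames, where frames 0 and 1 serve as first/previous reference frames and the
# default flags strip two reference frames from the queries. The tensors
# `frames` (shape [3, height, width, 3]) and `first_and_prev_labels` (shape
# [2, height, width, 1]) as well as all sizes are hypothetical, and a decoder
# output stride is assumed to be configured via the usual DeepLab flags.
matching_options = common.ModelOptions(
    outputs_to_num_classes={common.OUTPUT_TYPE: 2},
    crop_size=None,
    atrous_rates=None,
    output_stride=16)
outputs_to_logits, frame_embeddings = get_logits_with_matching(
    frames,
    matching_options,
    reference_labels=first_and_prev_labels,
    batch_size=1,
    num_frames_per_video=3,
    embedding_dimension=100,
    also_return_embeddings=True)
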
def train(l_args):
    """Trains the model."""
    if l_args.verbose:
        tf.logging.set_verbosity(tf.logging.INFO)

    # Create input data pipeline.
    x_train_files = sorted(
        glob.glob('/datatmp/Datasets/Cityscapes/leftImg8bit/train/*/*.png'))
    x_label_files = sorted(
        glob.glob(
            '/datatmp/Datasets/Cityscapes/gtFine/train/*/*_labelIds.png'))
    y_train_files = sorted(
        glob.glob(
            '/datatmp/Experiments/semantic_compression/{}/lambda_{}/leftImg8bit/train/*/*.png'
            .format(l_args.images_dir, l_args.lmbda)))

    # Sanity checks: originals, labels and compressed images must line up.
    print(len(x_train_files), len(y_train_files))
    assert len(x_train_files) == len(y_train_files)
    assert x_train_files[0].split("/")[-1] == y_train_files[0].split("/")[-1]
    assert x_train_files[-1].split("/")[-1] == y_train_files[-1].split("/")[-1]
    print(x_train_files[0].split("/")[-1][:-16],
          x_label_files[0].split("/")[-1].split("_gtFine_labelIds.png")[0])
    print(x_train_files[-1].split("/")[-1][:-16],
          x_label_files[-1].split("/")[-1].split("_gtFine_labelIds.png")[0])
    assert len(x_label_files) == len(x_train_files)
    assert x_train_files[0].split("/")[-1][:-16] == x_label_files[0].split(
        "/")[-1].split("_gtFine_labelIds.png")[0]
    assert x_train_files[-1].split("/")[-1][:-16] == x_label_files[-1].split(
        "/")[-1].split("_gtFine_labelIds.png")[0]

    train_dataset = tf.data.Dataset.from_tensor_slices(
        (x_train_files, x_label_files, y_train_files))
    train_dataset = train_dataset.shuffle(
        buffer_size=len(x_train_files)).repeat()
    train_dataset = train_dataset.map(
        read_pngs,
        num_parallel_calls=min(l_args.preprocess_threads, l_args.batchsize))
    if l_args.resize_images:
        train_dataset = train_dataset.map(
            lambda x: tf.image.resize_images(x, [512, 1024]))
    # Each example is a 7-channel stack: RGB image (3) + label IDs (1) +
    # compressed RGB image (3).
    train_dataset = train_dataset.map(lambda x: tf.random_crop(
        x, [int(z) for z in l_args.patchsize.split(",")] + [7]))
    train_dataset = train_dataset.batch(l_args.batchsize)
    train_dataset = train_dataset.prefetch(l_args.batchsize)
    train_batch = train_dataset.make_one_shot_iterator().get_next()
    train_x, _, train_y = (train_batch[:, :, :, :3],
                           train_batch[:, :, :, 3:4],
                           train_batch[:, :, :, 4:])
    scaled_train_x, scaled_train_y = train_x / 255., train_y / 255.

    x_val_files = sorted(
        glob.glob('/datatmp/Datasets/Cityscapes/leftImg8bit/val/*/*.png'))
    x_label_files = sorted(
        glob.glob('/datatmp/Datasets/Cityscapes/gtFine/val/*/*_labelIds.png'))
    y_val_files = sorted(
        glob.glob(
            '/datatmp/Experiments/semantic_compression/{}/lambda_{}/leftImg8bit/val/*/*.png'
            .format(l_args.images_dir, l_args.lmbda)))

    print(len(x_val_files), len(y_val_files))
    assert len(x_val_files) == len(y_val_files)
    assert x_val_files[0].split("/")[-1] == y_val_files[0].split("/")[-1]
    assert x_val_files[-1].split("/")[-1] == y_val_files[-1].split("/")[-1]
    print(x_val_files[0].split("/")[-1][:-16],
          x_label_files[0].split("/")[-1].split("_gtFine_labelIds.png")[0])
    print(x_val_files[-1].split("/")[-1][:-16],
          x_label_files[-1].split("/")[-1].split("_gtFine_labelIds.png")[0])
    assert len(x_label_files) == len(x_val_files)
    assert x_val_files[0].split("/")[-1][:-16] == x_label_files[0].split(
        "/")[-1].split("_gtFine_labelIds.png")[0]
    assert x_val_files[-1].split("/")[-1][:-16] == x_label_files[-1].split(
        "/")[-1].split("_gtFine_labelIds.png")[0]

    def set_shape(x):
        x.set_shape([1024, 2048, 7])
        return x

    val_dataset = tf.data.Dataset.from_tensor_slices(
        (x_val_files, x_label_files, y_val_files))
    val_dataset = val_dataset.map(read_pngs, num_parallel_calls=1)
    val_dataset = val_dataset.map(set_shape, num_parallel_calls=1)
    val_dataset = val_dataset.batch(1)
    val_dataset = val_dataset.prefetch(1)
    val_batch = val_dataset.make_one_shot_iterator().get_next()
    val_x, _, val_y = (val_batch[:, :, :, :3], val_batch[:, :, :, 3:4],
                       val_batch[:, :, :, 4:])
    scaled_val_x, scaled_val_y = val_x / 255., val_y / 255.

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: 19},
        crop_size=[int(z) for z in l_args.patchsize.split(",")],
        atrous_rates=None,
        output_stride=16)
    x_features, _ = model.extract_features(train_x, model_options)

    # The segmentation backbone is restored from a pre-trained checkpoint and
    # kept frozen: its variables are excluded from the optimizers' var_lists.
    exclude_list = ['global_step']
    variables_to_restore = tf.contrib.framework.get_variables_to_restore(
        exclude=exclude_list)
    seg_saver = tf.train.Saver(variables_to_restore)
    print(variables_to_restore)

    rdn = RDN()
    scaled_x_tilde_hat = rdn(scaled_train_y)
    x_tilde_hat = 255.0 * scaled_x_tilde_hat
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        x_tilde_hat_features, _ = model.extract_features(
            x_tilde_hat, model_options)

    var_list = [
        var for var in tf.trainable_variables()
        if var not in variables_to_restore
    ]
    print()
    print()
    print(var_list)

    discriminator = PatchDiscriminator(l_args.disc_patchsize)
    fake = tf.reduce_mean(discriminator(scaled_x_tilde_hat, scaled_train_y))
    real = tf.reduce_mean(discriminator(scaled_train_x, scaled_train_y))
    generator_loss = -1.0 * fake
    wasserstein_distance = real - fake
    discriminator_loss = -1.0 * wasserstein_distance

    mse = tf.reduce_mean(
        tf.squared_difference(scaled_train_x, scaled_x_tilde_hat)) * 255**2
    ssim = tf.reduce_mean(
        1 - tf.image.ssim_multiscale(scaled_x_tilde_hat, scaled_train_x, 1))
    l1 = tf.reduce_mean(tf.math.abs(scaled_train_x - scaled_x_tilde_hat))
    distortion = {
        "mse": mse,
        "l1": l1,
        "msssim": ssim,
        "msssim_l1": 2 * l1 + ssim
    }[l_args.loss_type]
    distillation = tf.reduce_mean(
        tf.squared_difference(x_features, x_tilde_hat_features))
    train_loss = (l_args.rho * distortion + l_args.mu * distillation +
                  generator_loss)

    rdn_weights = var_list
    discriminator_weights = discriminator.weights
    print()
    print()
    print(discriminator_weights)
    print()
    print()
    print([
        var for var in tf.trainable_variables()
        if var not in variables_to_restore + rdn_weights +
        discriminator_weights
    ])

    step = tf.train.get_or_create_global_step()
    generator_optimizer = tf.train.AdamOptimizer(learning_rate=l_args.lr,
                                                 beta1=0,
                                                 beta2=0.9)
    generator_op = generator_optimizer.minimize(train_loss,
                                                var_list=rdn_weights,
                                                global_step=step)
    discriminator_optimizer = tf.train.AdamOptimizer(
        learning_rate=l_args.disc_lr, beta1=0, beta2=0.9)
    discriminator_op = discriminator_optimizer.minimize(
        discriminator_loss, var_list=discriminator_weights)

    train_summary = log_all_summaries(train_x, scaled_x_tilde_hat,
                                      scaled_train_y, None, None, train_loss,
                                      None, mse, None, ssim, distillation,
                                      wasserstein_distance, l1, "train")

    scaled_x_val_hat = rdn(scaled_val_y)
    val_fake = tf.reduce_mean(discriminator(scaled_x_val_hat, scaled_val_y))
    val_real = tf.reduce_mean(discriminator(scaled_val_x, scaled_val_y))
    val_generator_loss = -1.0 * val_fake
    val_wasserstein = val_real - val_fake
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        val_x_features, _ = model.extract_features(val_x, model_options)
        x_val_hat_features, _ = model.extract_features(
            255.0 * scaled_x_val_hat, model_options)
    val_distillation = tf.reduce_mean(
        tf.squared_difference(val_x_features, x_val_hat_features))
    val_mse = tf.reduce_mean(
        tf.squared_difference(scaled_val_x, scaled_x_val_hat)) * 255**2
    val_ssim = tf.reduce_mean(
        1 - tf.image.ssim_multiscale(scaled_x_val_hat, scaled_val_x, 1))
    val_l1 = tf.reduce_mean(tf.math.abs(scaled_val_x - scaled_x_val_hat))
    val_distortion = {
        "mse": val_mse,
        "l1": val_l1,
        "msssim": val_ssim,
        "msssim_l1": 2 * val_l1 + val_ssim
    }[l_args.loss_type]
    val_loss = (l_args.rho * val_distortion + l_args.mu * val_distillation +
                val_generator_loss)
    #valid_summary = log_all_summaries(val_x, scaled_x_val_hat, scaled_val_y,
    #    None, None, val_loss, None, val_mse, None, val_ssim, val_distillation,
    #    val_wasserstein, val_l1, "val")

    def load_pretrain(scaffold, sess):
        seg_saver.restore(sess, save_path=PATH_TO_TRAINED_MODEL)

    hooks = [
        tf.train.StopAtStepHook(last_step=l_args.last_step),
        tf.train.NanTensorHook(train_loss),
        #tf.train.SummarySaverHook(save_secs=120,
        #                          output_dir=l_args.checkpoint_dir,
        #                          summary_op=valid_summary),
        tf.train.SummarySaverHook(save_secs=60,
                                  output_dir=l_args.checkpoint_dir,
                                  summary_op=train_summary),
    ]
    with tf.train.MonitoredTrainingSession(
            hooks=hooks,
            checkpoint_dir=l_args.checkpoint_dir,
            save_checkpoint_secs=1200,
            save_summaries_steps=None,
            save_summaries_secs=None,
            scaffold=tf.train.Scaffold(init_fn=load_pretrain)) as sess:
        while not sess.should_stop():
            # Alternate one discriminator update and one generator/RDN update
            # per loop iteration.
            sess.run(discriminator_op)
            sess.run(generator_op)
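
# Hedged sketch (not part of the original code): the adversarial terms above in
# isolation. `critic` stands for any conditional discriminator such as the
# PatchDiscriminator used here; it maps (image, conditioning) pairs to scores.
def wgan_losses(critic, real_images, fake_images, conditioning):
    fake = tf.reduce_mean(critic(fake_images, conditioning))
    real = tf.reduce_mean(critic(real_images, conditioning))
    generator_loss = -1.0 * fake          # generator pushes critic scores up
    wasserstein_distance = real - fake    # critic maximizes this estimate
    discriminator_loss = -1.0 * wasserstein_distance
    return generator_loss, discriminator_loss, wasserstein_distance
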
def experiment(l_args):
    """Visualizes backbone features of original vs. compressed val images."""
    x_val_files = sorted(
        glob.glob('/datatmp/Datasets/Cityscapes/leftImg8bit/val/*/*.png'))
    x_label_files = sorted(
        glob.glob('/datatmp/Datasets/Cityscapes/gtFine/val/*/*_labelIds.png'))
    y_val_files = sorted(
        glob.glob(
            '/datatmp/Experiments/semantic_compression/{}/lambda_{}/leftImg8bit/val/*/*.png'
            .format(l_args.images_dir, l_args.lmbda)))

    print(len(x_val_files), len(y_val_files))
    assert len(x_val_files) == len(y_val_files)
    assert x_val_files[0].split("/")[-1] == y_val_files[0].split("/")[-1]
    assert x_val_files[-1].split("/")[-1] == y_val_files[-1].split("/")[-1]
    print(x_val_files[0].split("/")[-1][:-16],
          x_label_files[0].split("/")[-1].split("_gtFine_labelIds.png")[0])
    print(x_val_files[-1].split("/")[-1][:-16],
          x_label_files[-1].split("/")[-1].split("_gtFine_labelIds.png")[0])
    assert len(x_label_files) == len(x_val_files)
    assert x_val_files[0].split("/")[-1][:-16] == x_label_files[0].split(
        "/")[-1].split("_gtFine_labelIds.png")[0]
    assert x_val_files[-1].split("/")[-1][:-16] == x_label_files[-1].split(
        "/")[-1].split("_gtFine_labelIds.png")[0]

    def set_shape(x):
        x.set_shape([1024, 2048, 7])
        return x

    val_dataset = tf.data.Dataset.from_tensor_slices(
        (x_val_files, x_label_files, y_val_files))
    val_dataset = val_dataset.map(
        read_pngs, num_parallel_calls=l_args.preprocess_threads)
    val_dataset = val_dataset.map(
        set_shape, num_parallel_calls=l_args.preprocess_threads)
    val_dataset = val_dataset.batch(1)
    val_dataset = val_dataset.prefetch(1)
    val_batch = val_dataset.make_one_shot_iterator().get_next()
    val_x, _, val_y = (val_batch[:, :, :, :3], val_batch[:, :, :, 3:4],
                       val_batch[:, :, :, 4:])
    scaled_val_x, scaled_val_y = val_x / 255., val_y / 255.

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: 19},
        crop_size=[int(z) for z in l_args.patchsize.split(",")],
        atrous_rates=None,
        output_stride=16)
    x_features, _ = model.extract_features(val_x, model_options)
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        y_features, _ = model.extract_features(val_y, model_options)
    exclude_list = ['global_step']
    variables_to_restore = tf.contrib.framework.get_variables_to_restore(
        exclude=exclude_list)
    seg_saver = tf.train.Saver(variables_to_restore)
    diff_features = x_features - y_features

    sess = tf.Session()
    seg_saver.restore(sess, save_path=PATH_TO_TRAINED_MODEL)
    #while not sess.should_stop():
    feature_index = 256
    for i in range(500):
        x, y, fx, fy, dxy = sess.run(
            [val_x, val_y, x_features, y_features, diff_features])
        print(i)
        rxy = np.maximum(fx, 1e-6) / np.maximum(fy, 1e-6)
        #for i in range(fx.shape[-1]):
        # Plot the original and compressed images together with one feature
        # channel of each, their absolute difference and their ratio.
        fig = plt.figure(figsize=(30, 20))
        ax1 = fig.add_subplot(3, 2, 1)
        ax1.imshow(x[0].astype(np.uint8))
        ax2 = fig.add_subplot(3, 2, 2)
        ax2.imshow(y[0].astype(np.uint8))
        ax3 = fig.add_subplot(3, 2, 3)
        ax3.imshow(fx[0, :, :, feature_index] /
                   np.max(fx[0, :, :, feature_index]))
        ax4 = fig.add_subplot(3, 2, 4)
        ax4.imshow(fy[0, :, :, feature_index] /
                   np.max(fy[0, :, :, feature_index]))
        ax5 = fig.add_subplot(3, 2, 5)
        ax5.imshow(
            np.abs(dxy[0, :, :, feature_index]) /
            np.max(np.abs(dxy[0, :, :, feature_index])))
        ax6 = fig.add_subplot(3, 2, 6)
        ax6.imshow(rxy[0, :, :, feature_index] /
                   np.max(rxy[0, :, :, feature_index]))
        plt.savefig("feats/{}.png".format(i))
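
# Hedged sketch (not part of the original code): a scalar companion to the
# per-channel plots above, computing the mean absolute feature gap for every
# channel of the backbone features. `fx` and `fy` are the numpy arrays
# returned by sess.run above.
def per_channel_feature_gap(fx, fy):
    # fx, fy: [1, height, width, channels] feature maps.
    return np.mean(np.abs(fx - fy), axis=(0, 1, 2))  # shape: [channels]
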