def call(self, inputs, states, training=None):
    """Run one step of the paired prior (p) / posterior (q) recurrences.

    Args:
        inputs: Step input; concatenated into the posterior input.
        states: Pair ``(count, state_in)``. ``state_in`` is a packed tensor
            that is split along its last axis into the prior and posterior
            ``zs``, ``mus``, ``sigmas``, ``hs`` and flattened memory.
        training: Passed through unchanged to ``self._call_one_layer``.

    Returns:
        A pair ``(outputs, new_states)``. ``outputs`` is a dict with keys
        ``p_zs``, ``p_mus``, ``p_sigmas``, ``q_zs``, ``q_mus``, ``q_sigmas``;
        ``new_states`` is ``[count + 1.0, state_out]`` with ``state_out``
        packed in the same order as ``state_in``.
    """
    count, state_in = states

    # Both branches use the same layout: four `units`-wide slices followed by
    # the flattened memory.
    branch_sizes = [
        self.units, self.units, self.units, self.units,
        self.units * self.num_memory_slots,
    ]
    (p_zs, p_mus, p_sigmas, p_hs, p_flatten_memory,
     q_zs, q_mus, q_sigmas, q_hs, q_flatten_memory) = array_ops.split(
         state_in, branch_sizes + branch_sizes, axis=-1)

    # The prior is driven by both hidden states.
    prior_inputs = K.concatenate([q_hs, p_hs])
    (p_new_zs, p_new_mus, p_new_sigmas, p_new_hs,
     p_new_flatten_memory) = self._call_one_layer(
         prior_inputs, p_flatten_memory, training, self.p_ws)

    # The prior only advances when count is a multiple of time_scale;
    # on every other step its previous state is carried over.
    hold_prior = tf.floormod(count, self.time_scale) > 0
    p_next_zs = tf.where_v2(hold_prior, p_zs, p_new_zs)
    p_next_mus = tf.where_v2(hold_prior, p_mus, p_new_mus)
    p_next_sigmas = tf.where_v2(hold_prior, p_sigmas, p_new_sigmas)
    p_next_hs = tf.where_v2(hold_prior, p_hs, p_new_hs)
    p_next_flatten_memory = tf.where_v2(
        hold_prior, p_flatten_memory, p_new_flatten_memory)

    posterior_inputs = K.concatenate(
        [inputs, p_next_mus, p_next_sigmas, q_zs])
    (q_next_zs, q_next_mus, q_next_sigmas, q_next_hs,
     q_next_flatten_memory) = self._call_one_layer(
         posterior_inputs, q_flatten_memory, training, self.q_ws)

    state_out = K.concatenate([
        p_next_zs, p_next_mus, p_next_sigmas, p_next_hs,
        p_next_flatten_memory,
        q_next_zs, q_next_mus, q_next_sigmas, q_next_hs,
        q_next_flatten_memory,
    ])
    outputs = {
        "p_zs": p_next_zs,
        "p_mus": p_next_mus,
        "p_sigmas": p_next_sigmas,
        "q_zs": q_next_zs,
        "q_mus": q_next_mus,
        "q_sigmas": q_next_sigmas,
    }
    return (outputs, [count + 1.0, state_out])
def nms_tf(dets, thresh):
    """Non-maximum suppression with tf graph mode.

    Args:
        dets: 2-D tensor; columns 0-3 are read as ``x1, y1, x2, y2`` and
            column 4 as the score.
        thresh: Overlap threshold. Remaining boxes whose overlap with the
            box just kept is above it are discarded.

    Returns:
        The stacked int32 indices of the kept rows, in selection order.
    """
    left = dets[:, 0]
    top = dets[:, 1]
    right = dets[:, 2]
    bottom = dets[:, 3]
    scores = dets[:, 4]

    # Box areas use the inclusive-pixel (+1) convention.
    areas = (right - left + 1) * (bottom - top + 1)
    order = tf.argsort(scores, direction='DESCENDING')

    kept = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
    slot = 0
    while tf.shape(order)[0] > 0:
        # The highest-scoring remaining box is always kept.
        best = order[0]
        kept = kept.write(slot, best)

        rest = order[1:]
        inter_left = tf.maximum(left[best], tf.gather(left, rest))
        inter_top = tf.maximum(top[best], tf.gather(top, rest))
        inter_right = tf.minimum(right[best], tf.gather(right, rest))
        inter_bottom = tf.minimum(bottom[best], tf.gather(bottom, rest))

        inter_w = tf.maximum(0.0, inter_right - inter_left + 1)
        inter_h = tf.maximum(0.0, inter_bottom - inter_top + 1)
        intersection = inter_w * inter_h
        overlap = intersection / (
            areas[best] + tf.gather(areas, rest) - intersection)

        # Positions are relative to `rest`, hence the +1 to index `order`.
        survivors = tf.where_v2(overlap <= thresh)
        order = tf.concat(tf.gather(order, survivors + 1), axis=1)
        order = tf.squeeze(order, axis=-1)
        slot += 1
    return kept.stack()
def edge_parameters(graph, truth, triplets):
    """Return the stored values of the sparse triplet parameter tensor.

    A per-node averaging path used to follow the return statement but could
    never execute; it has been removed together with the unused
    ``edges``/``senders``/``receivers`` locals. The returned value is
    unchanged.

    Args:
        graph: Graph object; only ``graph.nodes`` is read, to obtain the
            node count.
        truth: Values cast to float64 and passed to
            ``triplet_parameter_tensor``.
        triplets: Indexable with four entries, read in order as incoming
            edge ids, node ids, outgoing edge ids (cast to int64) and
            radii (cast to float64).

    Returns:
        The ``values`` attribute of the tensor built by
        ``triplet_parameter_tensor``.
    """
    n_nodes = tf.shape(graph.nodes)[0]
    edgei = tf.cast(triplets[0], tf.int64)
    node = tf.cast(triplets[1], tf.int64)
    edgeo = tf.cast(triplets[2], tf.int64)
    radius = tf.cast(triplets[3], tf.float64)
    triplet_tensor = triplet_parameter_tensor(
        n_nodes, edgei, node, edgeo, tf.cast(truth, tf.float64), radius)
    return triplet_tensor.values
def _binarize(actions, pivots):
    """Discretize continuous actions into one-hot bins.

    Each of the four action dimensions is binned by counting how many of its
    pivots the value exceeds. The first two dimensions are merged into a
    single joint index; the other two stay separate.

    Args:
        actions: Tensor with static shape ``[B, T, 4]``.
        pivots: Sequence of four pivot sequences, one per action dimension.

    Returns:
        A list of three one-hot tensors: joint xy, z and theta.
    """
    xy_stride = len(pivots[0]) + 1
    n_xy = xy_stride * (len(pivots[1]) + 1)
    n_z = len(pivots[2]) + 1
    n_theta = len(pivots[3]) + 1

    static_shape = actions.get_shape().as_list()
    B, T, input_adim = static_shape[0], static_shape[1], static_shape[2]
    assert input_adim == 4, "only supports [x,y,z,theta] action space for now!"
    assert len(pivots) == input_adim, "bad discretization pivots array!"

    binned_actions = []
    for a in range(input_adim):
        # Bin index = number of pivots strictly below the action value.
        bins = tf.zeros((B, T), dtype=tf.int32)
        for pivot in pivots[a]:
            bins = tf.where_v2(actions[:, :, a] > pivot, bins + 1, bins)
        binned_actions.append(bins)

    xy_act = binned_actions[0] + xy_stride * binned_actions[1]
    z_act = binned_actions[2]
    theta_act = binned_actions[3]

    return [
        tf.one_hot(xy_act, n_xy),
        tf.one_hot(z_act, n_z),
        tf.one_hot(theta_act, n_theta),
    ]
def rgba_to_rgb(model_config, rgba):
    """Converts rgba to rgb images.

    The background is selected by ``model_config.hparams.bg``: the first
    character is ``'b'`` (black) or ``'w'`` (white), and an optional second
    character ``'s'`` requests smooth alpha blending.

    Args:
        model_config: A ModelConfig object.
        rgba: A tensor with shape [..., 4]. Should be in the [0, 1] range
            and of type tf.float32.

    Returns:
        A tensor with shape [..., 3].
    """
    assert rgba.get_shape().as_list()[-1] == 4
    assert rgba.dtype == tf.float32

    alpha = rgba[..., 3:4]
    background = model_config.hparams.bg
    color_code = background[0]
    assert color_code in ['b', 'w']
    fill = 1.0 if color_code == 'w' else 0.0
    is_smooth = len(background) > 1 and background[1] == 's'

    color = rgba[..., :3]
    if is_smooth:
        # Standard "over" compositing onto a constant background.
        return color * alpha + fill * (1 - alpha)
    if color_code == 'b':
        return color
    # White, not smooth: any non-zero alpha keeps the pixel, the rest
    # becomes white.
    return tf.where_v2(tf.cast(alpha, dtype=tf.bool), color, 1.0)
def _compute_ndcg(labels, logits, features, topn=100):
    """Mean NDCG with logits masked out wherever ``features`` is non-zero.

    Args:
        labels: Passed through to ``Metrics.compute_ndcg``.
        logits: Predicted scores.
        features: Tensor compared against zero; non-zero positions have
            their logit replaced by -1000.
        topn: Cut-off forwarded to ``Metrics.compute_ndcg``.

    Returns:
        The result of ``tf.metrics.mean`` over the first value returned by
        ``Metrics.compute_ndcg``.
    """
    masked_out = tf.not_equal(features, tf.constant(0, dtype=tf.float32))
    floor_score = tf.constant(-1000, dtype=tf.float32)
    predictions = tf.where_v2(masked_out, floor_score, logits)
    ndcg_values, _ = Metrics.compute_ndcg(labels, predictions, topn)
    return tf.metrics.mean(ndcg_values)
def where(self, condition, x, y, v2=True):
    """Select between ``x.value`` and ``y.value`` according to ``condition``.

    Args:
        condition: A ``tf.Tensor``; anything else is rejected.
        x: Object whose ``value`` is used where the condition holds.
        y: Object whose ``value`` is used elsewhere.
        v2: When true (default) use ``tf.where_v2``, otherwise ``tf.where``.

    Returns:
        A ``DenseTensor`` wrapping the selected values.

    Raises:
        TypeError: If ``condition`` is not a ``tf.Tensor``.
    """
    if not isinstance(condition, tf.Tensor):
        raise TypeError(
            "Don't know how to handle `condition` of type {}".format(
                type(condition)))

    select = tf.where_v2 if v2 else tf.where
    return DenseTensor(select(condition, x.value, y.value))
def add_noise_to_xyz(xyz, stddev):
    """Add zero-mean Gaussian noise to every point that is not all-zero.

    Args:
        xyz: Point tensor; the last axis holds the coordinates.
        stddev: Noise standard deviation. ``0.0`` returns ``xyz`` untouched.

    Returns:
        ``xyz`` itself when ``stddev`` is zero; otherwise a tensor in which
        points whose coordinates are all exactly 0.0 are left unchanged and
        all other points are perturbed.
    """
    if stddev == 0.0:
        return xyz

    jitter = tf.random.normal(shape=xyz.shape, mean=0.0, stddev=stddev)
    # All-zero points are left exactly as they are.
    is_origin = tf.reduce_all(tf.equal(xyz, 0.0), axis=-1, keepdims=True)
    return tf.where_v2(is_origin, xyz, xyz + jitter)
def calculate_adacos_logits(embds, labels, one_hot, embedding_size, class_num,
                            is_dynamic=True):
    """Compute cosine logits scaled by a fixed or adaptively updated factor.

    Embeddings and class weights are L2-normalised so their product is a
    cosine similarity. With ``is_dynamic`` false the logits are scaled by the
    constant ``sqrt(2) * log(class_num - 1)``. Otherwise the non-trainable
    variable ``adacos_s_value`` is reassigned on every evaluation from the
    batch statistics and used as the scale.

    Args:
        embds: Embedding tensor, normalised along axis 1.
        labels: Class ids, stacked with ``tf.range`` to index ``theta``
            (presumably 1-D int32 -- TODO confirm against callers).
        one_hot: One-hot labels; entries ``< 1`` mark the non-target classes.
        embedding_size: First dimension of the ``final_dense`` weight.
        class_num: Number of classes (second weight dimension).
        is_dynamic: Whether to use the adaptive scale.

    Returns:
        The scaled logits tensor.
    """
    weights = tf.get_variable(
        name='final_dense',
        shape=[embedding_size, class_num],
        dtype=tf.float32,
        initializer=tf.contrib.layers.xavier_initializer(uniform=False),
        trainable=True)
    init_s = math.sqrt(2) * math.log(class_num - 1)
    adacos_s = tf.get_variable(
        name='adacos_s_value',
        dtype=tf.float32,
        initializer=tf.constant(init_s),
        trainable=False,
        aggregation=tf.VariableAggregation.MEAN)

    embds = tf.nn.l2_normalize(embds, axis=1, name='normed_embd')
    weights = tf.nn.l2_normalize(weights, axis=0)
    logits_before_s = tf.matmul(embds, weights, name='adacos_logits_before_s')

    if not is_dynamic:
        return tf.multiply(init_s, logits_before_s,
                           name='adacos_fixed_logits')

    theta = tf.acos(
        tf.clip_by_value(logits_before_s, -1.0 + 1e-10, 1.0 - 1e-10))

    # Batch mean of the summed exponentiated logits of the non-target classes.
    B_avg = tf.where_v2(tf.less(one_hot, 1),
                        tf.exp(adacos_s * logits_before_s),
                        tf.zeros_like(logits_before_s))
    B_avg = tf.reduce_mean(tf.reduce_sum(B_avg, axis=1), name='B_avg')

    # Median angle between each embedding and its own class weight.
    theta_class = tf.gather_nd(
        theta,
        tf.stack([tf.range(tf.shape(labels)[0]), labels], axis=1),
        name='theta_class')
    theta_med = tf.contrib.distributions.percentile(theta_class, q=50)

    with tf.control_dependencies([theta_med, B_avg]):
        temp_s = tf.log(B_avg) / tf.cos(tf.minimum(math.pi / 4, theta_med))
        # The returned logits use the freshly assigned scale.
        adacos_s = tf.assign(adacos_s, temp_s)
        output = tf.multiply(adacos_s, logits_before_s,
                             name='adacos_dynamic_logits')
    return output
def sample_bounding_box(self):
    """The bounding box of the uniform samples.

    The box is built once from ``self.full_uniform_samples`` and stored in
    ``self._bounding_box``; every call returns it through
    ``self._finite_wrapper``.
    """
    if self._bounding_box is None:
        samples = self.full_uniform_samples
        xyz = samples[..., :3]
        # Use only the inside samples (fourth channel negative); all other
        # samples are replaced by zeros.
        inside = samples[..., 3:4] < 0.0
        self._bounding_box = BoundingBox.from_samples(
            tf.where_v2(inside, xyz, 0.0))
    return self._finite_wrapper(self._bounding_box)
def _forward(x):
    """Forward.

    Binarizes ``x`` to its sign (zero counts as positive) and scales the
    result by the mean absolute value of the whole tensor.

    Args:
        x (tf.Variable): The input to be quantized, weights normally.

    Returns:
        tf.Variable: The quantized input.
    """
    mean_magnitude = tf.reduce_mean(tf.abs(x))
    signs = tf.where_v2((x >= 0), 1.0, -1.0)
    return signs * mean_magnitude
def element_center_lowres_grid_inside_loss(model_config, training_example,
                                           structured_implicit):
    """Loss that element centers should lie within a voxel of the GT inside.

    The ground-truth grid is interpolated at the element centers. Values
    above the ``igt`` hyperparameter are penalised quadratically; all other
    values contribute only a constant.

    Args:
        model_config: Config whose ``hparams.igt`` (threshold) and
            ``hparams.ig`` (loss weight) are read.
        training_example: Provides ``grid`` and ``world2grid``.
        structured_implicit: Provides ``element_centers``.

    Returns:
        The scalar loss, which is also reported through
        ``summarize.summarize_loss``.
    """
    hparams = model_config.hparams
    sdf_at_centers, _ = interpolate_util.interpolate(
        training_example.grid,
        structured_implicit.element_centers,
        training_example.world2grid)
    penalised = tf.where_v2(sdf_at_centers > hparams.igt, sdf_at_centers, 0.0)
    loss = hparams.ig * tf.reduce_mean(tf.square(penalised + 1e-04)) + 1e-05
    summarize.summarize_loss(model_config, loss, 'lowres_grid_inside_loss')
    return loss
def apply_noise_to_depth(depth, stddev):
    """Applies independent random gaussian noise to each valid depth pixel.

    The noise is multiplicative: a Gaussian draw ``g`` becomes the factor
    ``1 + g`` when ``g >= 0`` and ``1 / (1 + |g|)`` otherwise. Pixels whose
    depth is not positive are multiplied by 1.0.

    Args:
        depth: Depth tensor.
        stddev: Standard deviation of the Gaussian draw; must be >= 0.
            ``0.0`` returns ``depth`` unchanged.

    Returns:
        ``depth`` itself when ``stddev`` is zero, else the perturbed depth.
    """
    if stddev == 0.0:
        return depth
    assert stddev > 0.0

    gaussian = tf.random.normal(shape=depth.shape, mean=0.0, stddev=stddev)
    is_increase = gaussian >= 0.0
    grow = gaussian + 1.0
    shrink = 1.0 / (1.0 + tf.abs(gaussian))
    factor = tf.where(is_increase, grow, shrink)
    # Pixels with non-positive depth get a factor of 1.0.
    factor = tf.where_v2(depth > 0.0, factor, 1.0)
    return depth * factor
def where(condition, x, y):
    """Return an array with elements from `x` or `y`, depending on condition.

    Args:
        condition: array_like, converted to bool. Where True, yield `x`,
            otherwise yield `y`.
        x: array_like. Values taken where `condition` is True.
        y: array_like. Values taken where `condition` is False. `x`, `y` and
            `condition` need to be broadcastable to some shape.

    Returns:
        An array.
    """
    mask = array_creation.asarray(condition, dtype=np.bool_)
    # Bring both value arguments to a common type before selecting.
    if_true, if_false = array_creation.promote_args_types(x, y)
    selected = tf.where_v2(mask.data, if_true.data, if_false.data)
    return utils.tensor_to_ndarray(selected)
def _forward(x):
    """Forward.

    Binarizes ``x`` to its sign (zero counts as positive) and scales each
    slice along the last axis by that slice's mean absolute value.

    Args:
        x (tf.Variable): The input to be quantized, weights normally.

    Returns:
        tf.Variable: The quantized input.
    """
    # x kernel shape is [height, width, in_channels, out_channels], so
    # reducing over the first three axes leaves one factor per output channel.
    per_channel_scale = tf.reduce_mean(tf.abs(x), axis=[0, 1, 2])
    # TODO(wakisaka): tensorflow raise error.
    # tf.compat.v1.summary.histogram("scaling_factor", scaling_factor)
    signs = tf.where_v2((x >= 0), 1.0, -1.0)
    return signs * per_channel_scale
def sample_loss(model_config, gt_sdf, structured_implicit, global_samples,
                name, apply_ucf):
    """Computes an l2 loss for predicted-vs-gt insidedness at samples.

    Args:
        model_config: Config object. ``hparams.lrf`` selects the mode:
            ``'l'`` (per-element local decisions), ``'g'`` (global decision)
            or ``'x'`` (implicit values on local views). ``hparams.sc``,
            ``hparams.spc`` and ``hparams.ucf`` are read as needed.
        gt_sdf: Ground-truth SDF values at the samples.
        structured_implicit: Object providing ``class_at_samples``,
            ``world2local`` and ``implicit_values``.
        global_samples: Sample locations.
        name: Summary name; if None the outside-fraction summary is skipped.
        apply_ucf: Whether to weight inside samples (gt class <= 0.5) by
            ``hparams.ucf``; outside samples keep weight 1.0.

    Returns:
        The scalar mean weighted l2 loss.

    Raises:
        ValueError: If ``hparams.lrf`` is not one of ``'l'``, ``'g'``, ``'x'``
            (previously this surfaced as a NameError further down).
    """
    gt_class = sdf_util.apply_class_transfer(
        gt_sdf, model_config, soft_transfer=False, offset=0.0)
    print('[HERE: ldif.training.loss.sample_loss] structured_implicit type:',
          type(structured_implicit))  # StructuredImplicit

    lrf = model_config.hparams.lrf
    if lrf == 'l':
        _, local_outputs = structured_implicit.class_at_samples(
            global_samples)
        local_decisions, local_weights = local_outputs
        predicted_class = local_decisions
        # One copy of the ground truth per element.
        gt_class = tf.tile(
            tf.expand_dims(gt_class, axis=1),
            [1, model_config.hparams.sc, 1, 1])
        weights = tf.stop_gradient(local_weights)
    elif lrf == 'g':  # default setting for training
        global_decisions, _ = structured_implicit.class_at_samples(
            global_samples)
        predicted_class = global_decisions
        weights = 1.0
    elif lrf == 'x':
        # TODO(kgenova) Don't forget we need more samples if lrf='x' than
        # otherwise.
        local_views = geom_util.local_views_of_shape(
            global_samples,
            structured_implicit.world2local,
            local_point_count=model_config.hparams.spc,
            global_features=gt_class)
        # Index instead of unpacking: the helper returns points, normals,
        # features and a trailing validity mask, so a three-name unpack
        # raises "too many values to unpack".
        local_samples, local_gt = local_views[0], local_views[2]
        # This is an important distinction: With lrf='x', the implicit values
        # are required to be a classification decision *on their own*.
        predicted_class = structured_implicit.implicit_values(local_samples)
        gt_class = local_gt
        weights = 1.0
    else:
        raise ValueError('Unrecognized lrf hyperparameter: %r' % (lrf,))

    if apply_ucf:
        is_outside = gt_class > 0.5
        is_outside_frac = tf.reduce_mean(
            tf.cast(is_outside, dtype=tf.float32))
        if name is not None:
            tf.summary.scalar(
                '%s-%s-outside-frac' % (model_config.inputs['split'], name),
                is_outside_frac)
        weights *= tf.where_v2(is_outside, 1.0, model_config.hparams.ucf)

    loss = weighted_l2_loss(gt_class, predicted_class, weights)
    return tf.reduce_mean(loss)
def query_ball_point(radius, nsample, xyz, new_xyz):
    """Collect, per query point, up to ``nsample`` indices within ``radius``.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]

    Indices are returned in ascending order. When fewer than ``nsample``
    points fall inside the ball, the remaining slots repeat the first
    (lowest) in-range index. If no point is in range, every slot holds the
    sentinel value ``N``.

    The previous implementation compared against ``1`` instead of ``N`` and
    built scatter indices from the update values, which produced mismatched
    shapes. This version follows the reference semantics:
        group_idx[sqrdists > radius ** 2] = N
        group_idx = sort(group_idx)[:, :, :nsample]
        group_idx[group_idx == N] = group_first[group_idx == N]
    """
    B, N, _ = xyz.get_shape().as_list()
    S = new_xyz.get_shape().as_list()[1]

    # Every query starts with the full candidate list 0..N-1.
    group_idx = tf.tile(tf.reshape(tf.range(0, N), (1, 1, N)), [B, S, 1])
    sentinel = tf.constant(N, dtype=group_idx.dtype)

    sqrdists = square_distance(new_xyz, xyz)
    out_of_range = tf.greater(sqrdists, tf.square(radius))

    # Out-of-range candidates become N so that they sort to the end.
    group_idx = tf.where_v2(out_of_range, sentinel, group_idx)
    group_idx = tf.sort(group_idx, axis=-1)[:, :, :nsample]

    # Pad unfilled slots with the first in-range index; the [B, S, 1] slice
    # broadcasts across the sample axis.
    group_first = group_idx[:, :, :1]
    unfilled = tf.equal(group_idx, sentinel)
    group_idx = tf.where_v2(unfilled, group_first, group_idx)
    return group_idx
def _add_losses(self, sigma_rpn=3.0):
    """Build the RPN and detection-head losses and store them.

    Four terms are created inside the ``"loss_" + self._tag`` variable scope:
    RPN classification cross-entropy (anchors labelled -1 are ignored), RPN
    box smooth-L1, head classification cross-entropy and head box smooth-L1.
    Each term and their sum are written to ``self._losses``.

    Args:
        sigma_rpn: ``sigma`` passed to ``self._smooth_l1_loss`` for the RPN
            box term.

    Returns:
        The total loss tensor (also stored as ``self._losses["total_loss"]``).
    """
    with tf.compat.v1.variable_scope("loss_" + self._tag):
        anchor_targets = self._anchor_targets
        proposal_targets = self._proposal_targets

        # RPN, class loss: keep only the anchors that carry a real label.
        rpn_scores = tf.reshape(
            self._predictions["rpn_cls_score_reshape"], [-1, 2])
        rpn_labels = tf.reshape(anchor_targets["rpn_labels"], [-1])
        labelled = tf.where_v2(tf.not_equal(rpn_labels, -1))
        rpn_scores = tf.reshape(tf.gather(rpn_scores, labelled), [-1, 2])
        rpn_labels = tf.reshape(tf.gather(rpn_labels, labelled), [-1])
        rpn_cross_entropy = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=rpn_scores, labels=rpn_labels))

        # RPN, bbox loss.
        rpn_loss_box = self._smooth_l1_loss(
            self._predictions["rpn_bbox_pred"],
            anchor_targets["rpn_bbox_targets"],
            anchor_targets["rpn_bbox_inside_weights"],
            anchor_targets["rpn_bbox_outside_weights"],
            sigma=sigma_rpn,
            dim=[1, 2, 3])

        # Detection head, class loss.
        head_scores = self._predictions["cls_score"]
        head_labels = tf.reshape(proposal_targets["labels"], [-1])
        cross_entropy = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=tf.reshape(head_scores, [-1, self._num_classes]),
                labels=head_labels))

        # Detection head, bbox loss.
        loss_box = self._smooth_l1_loss(
            self._predictions["bbox_pred"],
            proposal_targets["bbox_targets"],
            proposal_targets["bbox_inside_weights"],
            proposal_targets["bbox_outside_weights"])

        self._losses["cross_entropy"] = cross_entropy
        self._losses["loss_box"] = loss_box
        self._losses["rpn_cross_entropy"] = rpn_cross_entropy
        self._losses["rpn_loss_box"] = rpn_loss_box
        self._losses["total_loss"] = (
            cross_entropy + loss_box + rpn_cross_entropy + rpn_loss_box)
    return self._losses["total_loss"]
def zero_by_mask(mask, vals, replace_with=0.0):
    """Sets the invalid part of vals to the value of replace_with.

    Args:
        mask: Boolean tensor with shape [..., 1].
        vals: Tensor with shape [..., channel_count].
        replace_with: Value to put in invalid locations, if not 0.0. Dtype
            should be compatible with that of vals. Can be a scalar tensor.

    Returns:
        Tensor with shape [..., channel_count].
    """
    # Compare with tf.bool directly: the `np.bool` alias used previously has
    # been removed from NumPy and raises AttributeError there.
    assert mask.dtype == tf.bool
    mask_shape = mask.get_shape().as_list()
    vals_shape = vals.get_shape().as_list()
    # Cross-check the leading dimensions of the two inputs against each other.
    mask = tf.ensure_shape(mask, vals_shape[:-1] + [1])
    vals = tf.ensure_shape(vals, mask_shape[:-1] + [vals_shape[-1]])
    return tf.where_v2(mask, vals, replace_with)
def transform_depth_dodeca_to_xyz_dodeca(depth_dodeca):
    """Lifts a dodecahedron of depth images to world space.

    Args:
        depth_dodeca: Depth tensor with a static batch size on axis 0 and the
            20 views on axis 1.

    Returns:
        The per-view xyz images stacked on axis 1, with zeros wherever the
        input depth is not positive.
    """
    batch_size = depth_dodeca.get_shape().as_list()[0]

    # Invert the fixed camera poses once in NumPy, then replicate per batch.
    cam2world = np.reshape(
        get_dodeca_camera_to_worlds(), [1, 20, 4, 4]).astype(np.float32)
    world2cams = np.tile(np.linalg.inv(cam2world), [batch_size, 1, 1, 1])
    world2cam_per_view = tf.unstack(
        tf.constant(world2cams, dtype=tf.float32), axis=1)
    depth_per_view = tf.unstack(depth_dodeca, axis=1)
    assert len(depth_per_view) == 20
    assert len(world2cam_per_view) == 20

    xyz_per_view = [
        depth_image_to_xyz_image(depth_im, world2cam, xfov=0.5)[0]
        for depth_im, world2cam in zip(depth_per_view, world2cam_per_view)
    ]
    xyz_images = tf.stack(xyz_per_view, axis=1)
    return tf.where_v2(depth_dodeca > 0.0, xyz_images, 0.0)
def iou_class(self, confusion_matrix):
    """Compute the per-class intersection-over-union.

    :param confusion_matrix: running confusion matrix
    :return: iou: calculated IOU per class. Classes with an empty union get
        an IOU of 0, because the denominator is replaced by 1 there.
    """
    pred_areas = tf.cast(tf.reduce_sum(confusion_matrix, axis=0), tf.float32)
    labels_areas = tf.cast(
        tf.reduce_sum(confusion_matrix, axis=1), tf.float32)
    intersection = tf.cast(
        tf.linalg.diag_part(confusion_matrix), tf.float32)
    union = pred_areas + labels_areas - intersection
    # Guard against division by zero for classes that never occur.
    safe_union = tf.where_v2(
        tf.greater(union, 0), union, tf.ones_like(union))
    # tf.divide replaces the deprecated tf.div; both are true division for
    # float32 operands.
    return tf.divide(intersection, safe_union)
def _load(self, dataset):
    """Defines how to yield batches from dataset.

    Args:
        dataset: Dataset whose elements are ``(features, labels)`` pairs,
            with ``features`` a mutable mapping.

    Returns:
        ``(features, labels)`` for the next batch. ``features`` gains a
        ``'weight'`` entry and, when a schema is set, a copy of every
        forwarded feature under the key ``<name>_``.
    """
    # NOTE(review): block nesting was reconstructed from flattened source;
    # the weight logic is assumed to sit outside `if self.schema` -- confirm.
    features, labels = dataset.prefetch(1).make_one_shot_iterator().get_next()

    if self.schema:
        for forwarded in self.schema.features_to_forward:
            features[forwarded + '_'] = tf.identity(features[forwarded])

    if self.is_train:
        class_weights = self._normalize_weight()
        # Label 0 receives class_weights[1]; every other label gets
        # class_weights[0].
        is_zero_label = tf.equal(tf.cast(labels, tf.int64), 0)
        features['weight'] = tf.where_v2(
            is_zero_label, class_weights[1], class_weights[0])
    else:
        non_label_columns = [
            col for col in self.schema.names if col != self.schema.label
        ]
        # Unit weights shaped like the first non-label feature.
        features['weight'] = tf.ones_like(
            features[non_label_columns[0]], dtype=tf.float32)
    return features, labels
def G_main(
    latents_in,                     # First input: Latent vectors (Z) [minibatch, latent_size].
    labels_in,                      # Second input: Conditioning labels [minibatch].
    truncation_psi=None,            # Style strength multiplier for the truncation trick. None = disable.
    truncation_cutoff=None,         # Number of layers for which to apply the truncation trick. None = disable.
    truncation_psi_val=None,        # Value for truncation_psi to use during validation.
    truncation_cutoff_val=None,     # Value for truncation_cutoff to use during validation.
    dlatent_avg_beta=0.995,         # Decay for tracking the moving average of W during training. None = disable.
    style_mixing_prob=0.9,          # Probability of mixing styles during training. None = disable.
    is_training=False,              # Network is under training? Enables and disables specific features.
    is_validation=False,            # Network is under validation? Chooses which value to use for truncation_psi.
    return_dlatents=False,          # Return dlatents in addition to the images?
    is_template_graph=False,        # True = template graph constructed by the Network class, False = actual evaluation.
    components=dnnlib.EasyDict(),   # Container for sub-networks. Retained between calls.
    mapping_func='G_mapping',       # Build func name for the mapping network.
    synthesis_func='G_synthesis_stylegan2',  # Build func name for the synthesis network.
    **kwargs):                      # Arguments for sub-networks (mapping and synthesis).
    """Build the generator graph: mapping network followed by synthesis.

    The mapping network turns ``latents_in``/``labels_in`` into per-layer
    dlatents. Depending on the flags, the dlatents then feed a moving-average
    update (training only), are mixed per sample with a second set of
    dlatents (training only), and are pulled towards the tracked average
    (truncation, never during training) before being passed to the synthesis
    network.

    Returns:
        ``images_out``, or ``(images_out, dlatents)`` when
        ``return_dlatents`` is true.
    """

    # Validate arguments.
    assert not is_training or not is_validation
    assert isinstance(components, dnnlib.EasyDict)
    if is_validation:
        truncation_psi = truncation_psi_val
        truncation_cutoff = truncation_cutoff_val
    # Each feature below is switched off (set to None) when the mode forbids
    # it or when its Python-constant value makes it a no-op.
    if is_training or (truncation_psi is not None and not tflib.is_tf_expression(truncation_psi) and truncation_psi == 1):
        truncation_psi = None
    if is_training:
        truncation_cutoff = None
    if not is_training or (dlatent_avg_beta is not None and not tflib.is_tf_expression(dlatent_avg_beta) and dlatent_avg_beta == 1):
        dlatent_avg_beta = None
    if not is_training or (style_mixing_prob is not None and not tflib.is_tf_expression(style_mixing_prob) and style_mixing_prob <= 0):
        style_mixing_prob = None

    # Setup components.
    # Sub-networks are created only once and then reused through `components`.
    if 'synthesis' not in components:
        components.synthesis = tflib.Network('G_synthesis', func_name=globals()[synthesis_func], **kwargs)
    num_layers = components.synthesis.input_shape[1]
    dlatent_size = components.synthesis.input_shape[2]
    if 'mapping' not in components:
        components.mapping = tflib.Network('G_mapping', func_name=globals()[mapping_func], dlatent_broadcast=num_layers, **kwargs)

    # Setup variables.
    lod_in = tf.get_variable('lod', initializer=np.float32(0), trainable=False)
    dlatent_avg = tf.get_variable('dlatent_avg', shape=[dlatent_size], initializer=tf.initializers.zeros(), trainable=False)

    # Evaluate mapping network.
    dlatents = components.mapping.get_output_for(latents_in, labels_in, is_training=is_training, **kwargs)
    dlatents = tf.cast(dlatents, tf.float32)

    # Update moving average of W.
    if dlatent_avg_beta is not None:
        with tf.variable_scope('DlatentAvg'):
            batch_avg = tf.reduce_mean(dlatents[:, 0], axis=0)
            update_op = tf.assign(dlatent_avg, tflib.lerp(batch_avg, dlatent_avg, dlatent_avg_beta))
            # The identity makes every downstream use of dlatents trigger the update.
            with tf.control_dependencies([update_op]):
                dlatents = tf.identity(dlatents)

    # Perform style mixing regularization.
    if style_mixing_prob is not None:
        with tf.variable_scope('StyleMix'):
            latents2 = tf.random_normal(tf.shape(latents_in))
            dlatents2 = components.mapping.get_output_for(latents2, labels_in, is_training=is_training, **kwargs)
            dlatents2 = tf.cast(dlatents2, tf.float32)
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            cur_layers = num_layers - tf.cast(lod_in, tf.int32) * 2
            # One cutoff per sample: a random layer in [1, cur_layers) with
            # probability style_mixing_prob, otherwise cur_layers.
            mixing_cutoff = tf.where_v2(
                tf.random_uniform([tf.shape(dlatents)[0]], 0.0, 1.0) < style_mixing_prob,
                tf.random_uniform([tf.shape(dlatents)[0]], 1, cur_layers, dtype=tf.int32),
                cur_layers[np.newaxis])[:, np.newaxis, np.newaxis]
            # Layers below the cutoff keep dlatents, the rest take dlatents2.
            dlatents = tf.where(tf.broadcast_to(layer_idx < mixing_cutoff, tf.shape(dlatents)), dlatents, dlatents2)

    # Apply truncation trick.
    if truncation_psi is not None:
        with tf.variable_scope('Truncation'):
            layer_idx = np.arange(num_layers)[np.newaxis, :, np.newaxis]
            layer_psi = np.ones(layer_idx.shape, dtype=np.float32)
            if truncation_cutoff is None:
                layer_psi *= truncation_psi
            else:
                # Only layers below the cutoff are truncated.
                layer_psi = tf.where(layer_idx < truncation_cutoff, layer_psi * truncation_psi, layer_psi)
            dlatents = tflib.lerp(dlatent_avg, dlatents, layer_psi)

    # Evaluate synthesis network.
    deps = []
    # Copy this network's lod into the synthesis network before it runs.
    if 'lod' in components.synthesis.vars:
        deps.append(tf.assign(components.synthesis.vars['lod'], lod_in))
    with tf.control_dependencies(deps):
        images_out = components.synthesis.get_output_for(dlatents, is_training=is_training, force_clean_graph=is_template_graph, **kwargs)

    # Return requested outputs.
    images_out = tf.identity(images_out, name='images_out')
    if return_dlatents:
        return images_out, dlatents
    return images_out
def local_views_of_shape(global_points,
                         world2local,
                         local_point_count,
                         global_normals=None,
                         global_features=None,
                         zeros_invalid=False,
                         zero_threshold=1e-6,
                         expand_region=True,
                         threshold=4.0):
    """Computes a set of local point cloud observations from a global one.

    It is assumed for optimization purposes that
    global_point_count >> local_point_count.

    Selection rule (for a positive ``threshold``): points whose distance to
    the local frame origin is below ``threshold`` are always preferred and
    are chosen among uniformly at random. If fewer than ``local_point_count``
    such points exist, the remainder is filled with the nearest points
    outside the threshold.

    Args:
        global_points: Tensor with shape [batch_size, global_point_count, 3].
            The input observation point cloud in world space.
        world2local: Tensor with shape [batch_size, frame_count, 4, 4]. Each
            4x4 matrix maps from points in world space to points in a local
            frame.
        local_point_count: Integer. The number of points to output in each
            local frame.
        global_normals: Tensor with shape [batch_size, global_point_count, 3].
            The input observation point cloud's normals in world space.
            Optional.
        global_features: Tensor with shape [batch_size, global_point_count,
            feature_count]. The input observation point cloud features, in
            any space. Optional.
        zeros_invalid: Whether to consider the vector [0, 0, 0] to be
            invalid. Invalid points are moved to (100, 100, 100) before the
            transform.
        zero_threshold: Values less than this in magnitude are considered to
            be 0.
        expand_region: Whether to expand outward from the threshold region.
            If false, fill with zeros.
        threshold: The distance threshold.

    Returns:
        local_points: Tensor with shape [batch_size, frame_count,
            local_point_count, 3].
        local_normals: Tensor with shape [batch_size, frame_count,
            local_point_count, 3]. None if global_normals not provided.
        local_features: Tensor with shape [batch_size, frame_count,
            local_point_count, feature_count]. Unlike the local normals and
            points, these are not transformed because there may or may not be
            a good transformation to apply, depending on what the features
            are. But they will be the features associated with the local
            points that were chosen. None if global_features not provided.
        points_valid: Boolean tensor with shape [batch_size, frame_count,
            local_point_count, 1]. True where the chosen point lies within
            ``threshold`` of the local frame origin.
    """
    # Example use case: batch_size = 64, global_point_count = 100000
    # local_point_count = 1000, frame_count = 25. Then:
    # global_points has size 64*100000*3*4 = 73mb
    # local_points has size 64*1000*25*3*4 = 18mb
    # If we made an intermediate tensor with shape [batch_size, frame_count,
    # global_point_count, 3] -> 64 * 25 * 100000 * 3 * 4 = 1.8 Gb -> bad.
    batch_size, _, _ = global_points.get_shape().as_list()
    if zeros_invalid:
        # If we just set the global points to be very far away, they won't be
        # a nearest neighbor.
        is_zero = tf.reduce_all(
            tf.abs(global_points) < zero_threshold, axis=-1, keepdims=True)
        global_points = tf.where_v2(is_zero, 100.0, global_points)
    _, frame_count, _, _ = world2local.get_shape().as_list()

    # *sigh* oh well, guess we have to do the transform:
    tiled_global = tf.tile(
        tf.expand_dims(to_homogeneous(global_points, is_point=True), axis=1),
        [1, frame_count, 1, 1])
    all_local_points = tf.matmul(tiled_global, world2local, transpose_b=True)
    distances = tf.norm(all_local_points, axis=-1)

    # TODO(kgenova) This is potentially a problem because it could introduce
    # randomness into the pipeline at inference time.
    probabilities = tf.random.uniform(distances.get_shape().as_list())
    is_valid = distances < threshold
    # In-range points score in [0, 1); out-of-range points score -distance,
    # so they rank below every in-range point, nearest first.
    sample_order = tf.where(is_valid, probabilities, -distances)
    _, top_indices = tf.math.top_k(
        sample_order, k=local_point_count, sorted=False)

    local_points = tf.gather(
        all_local_points, top_indices, batch_dims=2, axis=-2)
    local_points = tf.ensure_shape(
        local_points[..., :3],
        [batch_size, frame_count, local_point_count, 3])

    is_valid = tf.expand_dims(is_valid, axis=-1)
    points_valid = tf.gather(is_valid, top_indices, batch_dims=2, axis=-2)
    points_valid = tf.ensure_shape(
        points_valid, [batch_size, frame_count, local_point_count, 1])
    if not expand_region:
        local_points = tf.where_v2(points_valid, local_points, 0.0)

    if global_normals is not None:
        # The inverse is only needed for normals, so build it only here.
        local2world = tf.matrix_inverse(world2local)
        tiled_global_normals = tf.tile(
            tf.expand_dims(
                to_homogeneous(global_normals, is_point=False), axis=1),
            [1, frame_count, 1, 1])
        # Normals get transformed by the inverse-transpose matrix:
        all_local_normals = tf.matmul(
            tiled_global_normals, local2world, transpose_b=False)
        local_normals = tf.gather(
            all_local_normals, top_indices, batch_dims=2, axis=-2)
        # Remove the homogeneous coordinate now. It isn't a bug to normalize
        # with it since it's zero, but it's confusing.
        local_normals = tf.math.l2_normalize(local_normals[..., :3], axis=-1)
        local_normals = tf.ensure_shape(
            local_normals, [batch_size, frame_count, local_point_count, 3])
    else:
        local_normals = None

    if global_features is not None:
        feature_count = global_features.get_shape().as_list()[-1]
        local_features = tf.gather(
            global_features, top_indices, batch_dims=1, axis=-2)
        local_features = tf.ensure_shape(
            local_features,
            [batch_size, frame_count, local_point_count, feature_count])
    else:
        local_features = None
    return local_points, local_normals, local_features, points_valid
def train_student_net(label1, label2, label3, label4):
    """Train a four-class student from two frozen two-class teachers.

    Training images for ``[label1, label2]`` and ``[label3, label4]`` are
    pooled and shuffled, and the last 50 are dropped. Two non-trainable
    teachers restored from ``./0_1.h5`` and ``./2_3.h5`` and a student share
    the same input tensor. The loss combines a feature-alignment block loss
    and a soft-logit loss, each taken per sample from the teacher with the
    lower (or equal) entropy. The student weights are written to
    ``./0_1_2_3.h5``.

    Args:
        label1: First class handled by teacher 1.
        label2: Second class handled by teacher 1.
        label3: First class handled by teacher 2.
        label4: Second class handled by teacher 2.
    """
    # Only the training images are used; labels and test splits are ignored.
    total_x_train = []
    for label_pair in ([label1, label2], [label3, label4]):
        x_train, _, _, _ = return_part_mnist(label_pair)
        total_x_train.extend(x_train)

    rnd_idxs = list(range(len(total_x_train)))
    np.random.shuffle(rnd_idxs)
    # Builtin float is what the removed `np.float` alias referred to.
    total_x_train = np.asarray(total_x_train, float)
    total_x_train = total_x_train[rnd_idxs]
    total_x_train = total_x_train[:-50]

    batch_size = 64
    train_ds = tf.data.Dataset.from_tensor_slices(total_x_train)
    # batch() adds the batch dimension to the dataset elements.
    train_ds = train_ds.batch(batch_size).repeat()
    train_it = train_ds.make_one_shot_iterator()
    x_train_it = train_it.get_next()
    org_input_tensor = tf.keras.layers.Input(tensor=x_train_it)

    teacher_net1, t1_output_blocks, t1_logits = build_base_model(
        'vgg16', include_top=False, classes=2, name='0_1',
        restore_path='./0_1.h5', trainable=False,
        org_input_tensor=org_input_tensor)
    teacher_net2, t2_output_blocks, t2_logits = build_base_model(
        'vgg16', include_top=False, classes=2, name='2_3',
        restore_path='./2_3.h5', trainable=False,
        org_input_tensor=org_input_tensor)
    student_net, s_output_blocks, s_logits = build_base_model(
        'vgg16', include_top=False, classes=4, name='0_1_2_3',
        org_input_tensor=org_input_tensor)

    t1_entropy = cal_entropy(t1_logits)
    t2_entropy = cal_entropy(t2_logits)
    t1_probs = tf.nn.softmax(t1_logits)
    t1_pred_labels = tf.argmax(t1_probs, axis=1)
    t2_probs = tf.nn.softmax(t2_logits)
    t2_pred_labels = tf.argmax(t2_probs, axis=1)
    s_probs = tf.nn.softmax(s_logits)
    s_labels = tf.argmax(s_probs, axis=1)

    image_results = tf_put_text(org_input_tensor, t1_pred_labels,
                                t2_pred_labels, s_labels, t1_entropy,
                                t2_entropy)
    tf.summary.image('mnist/image', org_input_tensor, max_outputs=3)
    tf.summary.image(
        'mnist/results',
        tf.cast(tf.expand_dims(image_results, axis=3), tf.uint8),
        max_outputs=3)

    # One alignment module per teacher block, sized by its channel count.
    t1_blocks = [
        FeatureAlignmentModule(block.get_shape().as_list()[-1])
        for block in t1_output_blocks
    ]
    t2_blocks = [
        FeatureAlignmentModule(block.get_shape().as_list()[-1])
        for block in t2_output_blocks
    ]
    t1_cb_outs = [
        t1_blocks[i].call(t1_output_blocks[i], s_output_blocks[i])
        for i in range(len(t1_output_blocks))
    ]
    t2_cb_outs = [
        t2_blocks[i].call(t2_output_blocks[i], s_output_blocks[i])
        for i in range(len(t2_output_blocks))
    ]
    t1_cb_losses = [
        FeatureAlignmentModule.block_loss(*outs) for outs in t1_cb_outs
    ]
    t2_cb_losses = [
        FeatureAlignmentModule.block_loss(*outs) for outs in t2_cb_outs
    ]
    for idx, t1_cb_loss in enumerate(t1_cb_losses):
        tf.summary.scalar('t1_block_loss/level' + str(idx),
                          tf.reduce_mean(t1_cb_loss))
    for idx, t2_cb_loss in enumerate(t2_cb_losses):
        tf.summary.scalar('t2_block_loss/level_' + str(idx),
                          tf.reduce_mean(t2_cb_loss))
    t1_cb_losses = tf.add_n(t1_cb_losses) / len(t1_cb_losses)
    t2_cb_losses = tf.add_n(t2_cb_losses) / len(t2_cb_losses)
    tf.summary.scalar('t1_block_loss/total', tf.reduce_mean(t1_cb_losses))
    tf.summary.scalar('t2_block_loss/total', tf.reduce_mean(t2_cb_losses))

    # Per sample, use the block loss of the teacher with the lower entropy.
    block_loss_layer = tf.keras.layers.Lambda(
        lambda paras: (lambda entropy1, entropy2, losses1, losses2: tf.where(
            tf.keras.backend.less_equal(entropy1, entropy2),
            losses1,
            losses2,
            name='cal_block_loss'))(*paras))
    block_loss_tensor = block_loss_layer(
        (t1_entropy, t2_entropy, t1_cb_losses, t2_cb_losses))

    # Start computing the soft loss. Each teacher's two logits are padded
    # with zeros to the student's four-class layout.
    new_t1_logits = tf.keras.backend.concatenate(
        [t1_logits, tf.keras.backend.zeros_like(t1_logits)], axis=1)
    new_t2_logits = tf.keras.backend.concatenate(
        [tf.keras.backend.zeros_like(t2_logits), t2_logits], axis=1)
    soft_loss = tf.where_v2(
        tf.keras.backend.less_equal(t1_entropy, t2_entropy),
        cal_soft_loss(new_t1_logits, s_logits),
        cal_soft_loss(new_t2_logits, s_logits))
    soft_loss = soft_loss * 200.
    print('soft loss is ', soft_loss)
    tf.summary.scalar('soft_loss', tf.reduce_mean(soft_loss))

    total_loss_layer = tf.keras.layers.Lambda(
        lambda x: (lambda xx, yy: tf.reduce_mean(xx + yy))(*x))
    total_loss_tensor = total_loss_layer((block_loss_tensor, soft_loss))

    target_model = tf.keras.Model(inputs=org_input_tensor,
                                  outputs=[t1_logits, t2_logits, s_logits])
    target_model.add_loss(total_loss_tensor)
    target_model.compile(optimizer=keras.optimizers.Adam(lr=0.0001))
    tensorboard_callback = Tensorboard(summary_op=tf.summary.merge_all(),
                                       log_dir='./log/',
                                       batch_interval=5)
    # Integer division: steps_per_epoch is a step count, not a float.
    target_model.fit(None, None,
                     epochs=5,
                     verbose=1,
                     batch_size=batch_size,
                     steps_per_epoch=24704 // batch_size,
                     callbacks=[tensorboard_callback])
    target_model.save_weights('./0_1_2_3.h5')
def get_relu_gate(name, x, alpha=0.):
    """Build a gate tensor that is 1.0 where ``x`` is positive.

    Args:
        name: Name scope under which the ops are created.
        x: Tensor compared element-wise against 0.
        alpha: Gate value used wherever ``x`` is not greater than 0.

    Returns:
        The ``tf.where_v2`` selection between the constant 1.0 and the
        constant ``alpha``, driven by the ``x > 0`` condition.
    """
    with tf.name_scope(name):
        is_positive = tf.greater(x, 0.)
        open_value = tf.constant(1., dtype=tf.float32)
        closed_value = tf.constant(alpha, dtype=tf.float32)
        return tf.where_v2(is_positive, open_value, closed_value)
def interpolate(grid, samples, world2grid):
    """Returns the trilinearly interpolated function values on the grid.

    Args:
        grid: Tensor with shape [batch_size, depth, height, width]. The
            function to interpolate.
        samples: Tensor with shape [batch_size, sample_count, 3]. The xyz
            triplets. batch_size and sample_count are read from the static
            shape and fed to numpy, so they must be known at graph
            construction time.
        world2grid: Tensor with shape [batch_size, 4, 4]. A rigid body
            transform mapping from sample coordinates to grid coordinates.

    Returns:
        sdf: Tensor with shape [batch_size, sample_count, 1]. The ground
            truth sdf at the sample locations. Differentiable w.r.t. samples.
        invalid: Tensor with shape [batch_size, sample_count, 1] and type
            tf.bool. True where the input samples map outside the supplied
            grid. The sdf tensor is set to the constant 1e-5 at these
            locations and there will be no gradient.
    """
    # Append a homogeneous coordinate w=1 so the 4x4 transform can be applied.
    xyzw_samples = tf.pad(samples,
                          paddings=tf.constant([[0, 0], [0, 0], [0, 1]]),
                          mode='CONSTANT',
                          constant_values=1)
    # ensure_shape is a helper defined elsewhere; presumably it validates the
    # static shape (-1 meaning "any") -- TODO confirm.
    ensure_shape(samples, [-1, -1, 3])
    batch_size, sample_count = samples.get_shape().as_list()[:2]
    ensure_shape(grid, [batch_size, -1, -1, -1])
    xyzw_samples = tf.ensure_shape(xyzw_samples,
                                   [batch_size, sample_count, 4])
    # Map samples into the grid frame and drop the homogeneous coordinate.
    grid_frame_samples = tf.matmul(xyzw_samples, world2grid,
                                   transpose_b=True)[..., :3]
    lower_coords = tf.floor(grid_frame_samples)
    # Fractional position inside the cell, kept strictly inside (0, 1).
    alpha = grid_frame_samples - lower_coords
    min_alpha = 1e-05
    max_alpha = 1 - 1e-05
    alpha = tf.clip_by_value(alpha, min_alpha, max_alpha)
    lower_coords = tf.cast(lower_coords, tf.int32)
    upper_coords = tf.cast(tf.ceil(grid_frame_samples), tf.int32)
    depth, height, width = grid.get_shape().as_list()[1:]
    # Largest valid index per axis, in the same x, y, z order as the samples.
    max_vals = np.array([[[width, height, depth]]], dtype=np.int32) - 1
    max_vals = tf.constant(max_vals)
    # A sample is invalid if any of its eight corner indices leaves the grid.
    is_invalid = tf.logical_or(
        tf.reduce_any(lower_coords < 0, axis=-1, keep_dims=True),
        tf.reduce_any(upper_coords > max_vals, axis=-1, keep_dims=True))
    log.info('is_invalid vs lower_coords: %s vs %s' %
             (repr(is_invalid.get_shape().as_list()),
              repr(lower_coords.get_shape().as_list())))
    # Point invalid samples at index 0 so the gathers below stay in bounds;
    # their interpolated value is overwritten at the end.
    lower_coords = tf.where_v2(is_invalid, 0, lower_coords)
    log.info('Post-where lower_coords: %s' %
             repr(lower_coords.get_shape().as_list()))
    upper_coords = tf.where_v2(is_invalid, 0, upper_coords)
    # Split per axis and reverse x, y, z -> z, y, x so that index 0/1/2 lines
    # up with the grid's depth/height/width axes.
    lca = tf.split(lower_coords, 3, axis=-1)[::-1]
    uca = tf.split(upper_coords, 3, axis=-1)[::-1]
    aca = tf.unstack(alpha, axis=-1)[::-1]
    lca[0] = tf.ensure_shape(lca[0], [batch_size, sample_count, 1])
    lca[1] = tf.ensure_shape(lca[1], [batch_size, sample_count, 1])
    lca[2] = tf.ensure_shape(lca[2], [batch_size, sample_count, 1])
    # Constant [batch_size, sample_count, 1] tensor holding each sample's
    # batch index; it is prepended to the spatial indices for gather_nd.
    batch_indices = np.arange(batch_size, dtype=np.int32)
    batch_indices = np.reshape(batch_indices, [batch_size, 1, 1])
    batch_indices = np.tile(batch_indices, [1, sample_count, 1])
    batch_indices = tf.constant(batch_indices, dtype=tf.int32)

    def batch_gather_nd(source, index_list):
        """Gathers from source using batch_indices plus the given indices."""
        return tf.gather_nd(source,
                            tf.concat([batch_indices] + index_list, axis=-1))

    def lerp(lval, uval, alpha):
        """Linearly blends lval (alpha=0) and uval (alpha=1)."""
        return lval * (1 - alpha) + uval * (alpha)

    def lookup_and_lerp(lidx, uidx, alpha):
        """Looks up the grid at both index lists and blends the results."""
        return lerp(batch_gather_nd(grid, lidx), batch_gather_nd(grid, uidx),
                    alpha)

    # Blend along axis 0 of the reversed lists (grid depth) at the four
    # combinations of the remaining two axes ...
    c00 = lookup_and_lerp([lca[0], lca[1], lca[2]],
                          [uca[0], lca[1], lca[2]], aca[0])
    c01 = lookup_and_lerp([lca[0], lca[1], uca[2]],
                          [uca[0], lca[1], uca[2]], aca[0])
    c10 = lookup_and_lerp([lca[0], uca[1], lca[2]],
                          [uca[0], uca[1], lca[2]], aca[0])
    c11 = lookup_and_lerp([lca[0], uca[1], uca[2]],
                          [uca[0], uca[1], uca[2]], aca[0])
    # ... then along axis 1 (height) and finally axis 2 (width).
    c0 = lerp(c00, c10, aca[1])
    c1 = lerp(c01, c11, aca[1])
    sdf = tf.expand_dims(lerp(c0, c1, aca[2]), axis=-1)
    log.info('is_invalid vs sdf coords: %s vs %s' %
             (repr(is_invalid.get_shape().as_list()),
              repr(sdf.get_shape().as_list())))
    # Out-of-grid samples get the constant 1e-5 instead of the value that was
    # interpolated from the placeholder index 0.
    sdf = tf.where_v2(is_invalid, 1e-5, sdf)
    sdf = tf.ensure_shape(sdf, [batch_size, sample_count, 1])
    return sdf, is_invalid
def call(self, inputs, states, training=None):
    """Advance the prior and posterior cores by one step.

    The prior core runs on ``[q_hs, p_hs]``. Its results are only adopted
    when ``floormod(count, self.time_scale)`` is 0; on every other step the
    previous prior values are carried over unchanged. The posterior core
    then runs on ``[inputs, p_next_mus, p_next_sigmas, q_zs]``.

    Args:
        inputs: Tensor for the current step; concatenated into the
            posterior core's input.
        states: Pair ``(count, state_in)``. ``state_in`` is split along its
            last axis into the prior pieces (zs, mus, sigmas, hs, flattened
            memory) followed by the same five posterior pieces. The first
            four pieces of each group have width ``self.units``; the memory
            piece has width ``self.units`` times the group's memory slots.
        training: Forwarded unchanged to the selected core.

    Returns:
        A pair ``(outputs, new_states)``. ``outputs`` is a dict with keys
        ``p_zs``, ``p_mus``, ``p_sigmas``, ``q_zs``, ``q_mus``, ``q_sigmas``.
        ``new_states`` is ``[count + 1.0, state_out]`` where ``state_out``
        concatenates the ten updated pieces in the input order.

    Raises:
        ValueError: If ``self.prior_core`` or ``self.posterior_core`` is
            neither ``'rmc'`` nor ``'lstm'``.
    """
    count, state_in = states
    units = self.units
    split_sizes = [
        units, units, units, units, units * self.prior_memory_slots,
        units, units, units, units, units * self.posterior_memory_slots,
    ]
    (p_zs, p_mus, p_sigmas, p_hs, p_flatten_memory,
     q_zs, q_mus, q_sigmas, q_hs, q_flatten_memory) = array_ops.split(
         state_in, split_sizes, axis=-1)

    def select_core(core_name):
        """Return the bound core method for ``core_name``."""
        if core_name == 'rmc':
            return self._call_rmc_core
        if core_name == 'lstm':
            return self._call_lstm_core
        raise ValueError("Cannot Support %s core" % core_name)

    prior_core = select_core(self.prior_core)
    prior_inputs = K.concatenate([q_hs, p_hs])
    (p_next_zs, p_next_mus, p_next_sigmas, p_next_hs,
     p_next_flatten_memory) = prior_core(
         prior_inputs, p_flatten_memory, training, self.p_ws,
         self.prior_memory_slots)

    # Build the gating condition once and reuse it: True means "keep the
    # previous prior values for this step".
    hold_prior = tf.floormod(count, self.time_scale) > 0
    p_next_zs = tf.where_v2(hold_prior, p_zs, p_next_zs)
    p_next_mus = tf.where_v2(hold_prior, p_mus, p_next_mus)
    p_next_sigmas = tf.where_v2(hold_prior, p_sigmas, p_next_sigmas)
    p_next_hs = tf.where_v2(hold_prior, p_hs, p_next_hs)
    p_next_flatten_memory = tf.where_v2(
        hold_prior, p_flatten_memory, p_next_flatten_memory)

    # Resolved only after the prior step, so an unsupported posterior core
    # is reported at the same point as before.
    posterior_core = select_core(self.posterior_core)
    posterior_inputs = K.concatenate(
        [inputs, p_next_mus, p_next_sigmas, q_zs])
    (q_next_zs, q_next_mus, q_next_sigmas, q_next_hs,
     q_next_flatten_memory) = posterior_core(
         posterior_inputs, q_flatten_memory, training, self.q_ws,
         self.posterior_memory_slots)

    state_out = K.concatenate(
        [p_next_zs, p_next_mus, p_next_sigmas, p_next_hs,
         p_next_flatten_memory,
         q_next_zs, q_next_mus, q_next_sigmas, q_next_hs,
         q_next_flatten_memory])
    outputs = {
        "p_zs": p_next_zs,
        "p_mus": p_next_mus,
        "p_sigmas": p_next_sigmas,
        "q_zs": q_next_zs,
        "q_mus": q_next_mus,
        "q_sigmas": q_next_sigmas,
    }
    return (outputs, [count + 1.0, state_out])
def model_fn(features, labels, mode, params):
    """Build the EstimatorSpec for TRAIN or PREDICT mode.

    Names such as bert_config, num_labels, init_checkpoint, hvd,
    learning_rate, num_train_steps, num_warmup_steps and amp are not defined
    here; they are resolved from the surrounding scope.

    Args:
        features: Dict of tensors. Reads "input_ids", "input_mask",
            "segment_ids" and "span_mask"; in TRAIN mode also "gold_labels".
        labels: Not used in this function.
        mode: A tf.estimator.ModeKeys value. Only TRAIN and PREDICT are
            handled.
        params: Not used in this function.

    Returns:
        A tf.estimator.EstimatorSpec. In TRAIN mode it carries the loss, the
        train op and a LoggingTensorHook reporting positive/negative
        accuracy every 50 iterations. In PREDICT mode it carries a "score"
        prediction.

    Raises:
        ValueError: If mode is neither TRAIN nor PREDICT.
    """
    tf.compat.v1.logging.info("*** Features ***")
    for name in sorted(features.keys()):
        tf.compat.v1.logging.info(" name = %s, shape = %s" %
                                  (name, features[name].shape))
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    span_mask = features["span_mask"]
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    candidate_ner_scores = create_model(bert_config, is_training, input_ids,
                                        input_mask, segment_ids, span_mask,
                                        num_labels)
    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    # With Horovod, only rank 0 restores from the checkpoint here.
    if init_checkpoint and (hvd is None or hvd.rank() == 0):
        (assignment_map, initialized_variable_names
         ) = modeling.get_assignment_map_from_checkpoint(
             tvars, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
    tf.compat.v1.logging.info("**** Trainable Variables ****")
    for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
            init_string = ", *INIT_FROM_CKPT*"
        tf.compat.v1.logging.info(" name = %s, shape = %s%s", var.name,
                                  var.shape, init_string)
    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        gold_labels = features['gold_labels']
        # Drop entries labelled -1 (presumably padding -- TODO confirm).
        gold_labels = tf.boolean_mask(gold_labels,
                                      tf.not_equal(gold_labels, -1))
        # Real entities (label != 0).
        true_labels = tf.boolean_mask(gold_labels,
                                      tf.not_equal(gold_labels, 0))
        pred_labels = tf.boolean_mask(candidate_ner_scores,
                                      tf.not_equal(gold_labels, 0))
        # Measure accuracy on real entities only; otherwise the accuracy is
        # inflated.
        accuracy = tf.metrics.accuracy(
            true_labels, tf.arg_max(pred_labels, dimension=-1))
        negative_labels = tf.boolean_mask(gold_labels,
                                          tf.equal(gold_labels, 0))
        negative_pred_labels = tf.boolean_mask(candidate_ner_scores,
                                               tf.equal(gold_labels, 0))
        # Accuracy on the label-0 entries, tracked separately from the real
        # entities above.
        negative_accuracy = tf.metrics.accuracy(
            negative_labels, tf.arg_max(negative_pred_labels, dimension=-1))
        # Index 1 of each metric pair is what gets logged, scaled to percent.
        tensor_to_log = {
            "positive_accuracy": accuracy[1] * 100,
            "negative_accuracy": negative_accuracy[1] * 100
        }
        if FLAGS.focal_loss:
            # gold_labels is rebound to its one-hot encoding from here on.
            gold_labels = tf.one_hot(gold_labels,
                                     depth=num_labels,
                                     dtype=tf.float32)
            total_loss = focal_loss(candidate_ner_scores, gold_labels)
        elif FLAGS.dice_loss:
            total_loss = self_adjust_dice_loss(candidate_ner_scores,
                                               gold_labels)
        else:
            total_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=gold_labels, logits=candidate_ner_scores)
        # NOTE(review): this block is laid out as a sibling of the loss
        # selection above; confirm it is not meant to sit under the final
        # else. As a sibling, with FLAGS.focal_loss set gold_labels is the
        # one-hot tensor here, so the mask is built per class entry rather
        # than per label -- confirm that is intended.
        if 0.0 < FLAGS.neg_sample < 1.0:
            # Down-sample the negative examples: a label <= 0 entry is
            # zeroed out when its random draw is >= FLAGS.neg_sample.
            sample_vals = tf.random.uniform(shape=tf.shape(gold_labels))
            masks = tf.where_v2(
                tf.logical_and(gold_labels <= 0,
                               sample_vals >= FLAGS.neg_sample), 0.0, 1.0)
            total_loss = masks * total_loss
        # The summed loss is normalised by the number of rows in input_ids.
        batch_size = tf.shape(input_ids)[0]
        total_loss = tf.reduce_sum(total_loss) / tf.to_float(batch_size)
        train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                 num_train_steps,
                                                 num_warmup_steps, hvd, amp)
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            training_hooks=[
                tf.train.LoggingTensorHook(tensor_to_log, every_n_iter=50)
            ])
    elif mode == tf.estimator.ModeKeys.PREDICT:
        output_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={"score": tf.expand_dims(candidate_ner_scores, 0)
                         }  # boolean_mask was used, so the original batch structure is lost
        )
    else:
        raise ValueError("Only TRAIN and PREDICT modes are supported: %s" %
                         (mode))
    return output_spec
def main():
    """Create the model and start the evaluation process.

    Reads the options from ``get_arguments()``, feeds one image through the
    compression model and the selected cResNet network, and writes the
    decoded mask as a PNG into ``args.save_dir``. With ``args.save_original``
    the input image is also copied there and the matching ground-truth mask
    is decoded and saved as ``<name>_gt.png``.

    With ``args.auto`` the model name, class count and compression level are
    derived from the stem of ``args.model_weights``.

    Raises:
        Exception: If ``args.model`` is not one of the supported model names.
    """
    args = get_arguments()

    if args.save_original:
        mask_dir = args.data_path + "SegmentationClassAug/"
        img_dir = args.data_path + "JPEGImages/"
    else:
        img_dir = ""
        mask_dir = ""

    if args.auto:
        model = Path(args.model_weights).stem
        args.num_classes = 2 if "-bin" in model else 21
        args.level = 1 if "-lvl1" in model else args.level
        model_parts = model.split("-")
        args.model = model_parts[0]
        if "sigma" in model:
            args.model += "-"+model_parts[1]+"-"+model_parts[2]
        print(args.model)

    # Prepare image so we can feed it to the compression module.
    image = tf.image.decode_jpeg(tf.io.read_file(img_dir + args.img_path),
                                 channels=3)
    image = tf.expand_dims(image, 0)
    image = tf.cast(image, tf.uint8)

    # Get compression model.
    compressor = get_model_for_level(args.level, latent=True,
                                     sigma="sigma" in args.model)

    # Extract latent space.
    latent_batch = tf.cast(compressor(image), tf.float32)

    # Create network.
    if args.model == "cResNet":
        net = cResNet91({'data': latent_batch[0]},
                        num_classes=args.num_classes)
    elif args.model == "cResNet40":
        net = cResNet40({'data': latent_batch[0]},
                        num_classes=args.num_classes)
    elif args.model == "cResNet42":
        net = cResNet42({'data': latent_batch[0]},
                        num_classes=args.num_classes)
    elif args.model == "cResNet-sigma-conc":
        net = cResNet_sigma_conc({'y_hat': latent_batch[0],
                                  'sigma_hat': latent_batch[1]},
                                 num_classes=args.num_classes)
    elif args.model == "cResNet-sigma-add":
        net = cResNet_sigma_add({'y_hat': latent_batch[0],
                                 'sigma_hat': latent_batch[1]},
                                num_classes=args.num_classes)
    elif args.model == "cResNet-sigma-resblock":
        net = cResNet_sigma_resblock({'y_hat': latent_batch[0],
                                      'sigma_hat': latent_batch[1]},
                                     num_classes=args.num_classes)
    else:
        raise Exception(f"Invalid model : {args.model}")

    # Which variables to load.
    restore_var = tf.global_variables()

    # Predictions: upsample to the input resolution, then take the arg-max
    # class over the channel axis.
    raw_output = net.layers['fc1_voc12']
    raw_output_up = tf.image.resize_bilinear(raw_output,
                                             tf.shape(image)[1:3,])
    raw_output_up = tf.argmax(raw_output_up, axis=3)
    pred = tf.expand_dims(raw_output_up, axis=3)

    # Set up TF session and initialize variables.
    if args.no_gpu:
        config = tf.ConfigProto(device_count={'GPU': 0})
    else:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    try:
        init = tf.global_variables_initializer()
        sess.run(init)

        # Load weights.
        loader = tf.train.Saver(var_list=restore_var)
        load(loader, sess, args.model_weights)

        # Perform inference.
        preds = sess.run(pred)

        file_name = Path(args.img_path).stem
        msk = decode_labels(preds, num_classes=args.num_classes)
        im = Image.fromarray(msk[0])
        # exist_ok avoids the check-then-create race on the output folder.
        os.makedirs(args.save_dir, exist_ok=True)
        output_file = (args.save_dir + file_name +
                       f'_mask_{args.model}_{args.level}.png')
        im.save(output_file)

        if args.save_original:
            with Image.open(img_dir + args.img_path) as source_image:
                source_image.save(args.save_dir + args.img_path)

            # NOTE(review): a .png file is read through decode_jpeg here;
            # confirm the TF version in use accepts PNG data in this op.
            original = tf.image.decode_jpeg(
                tf.io.read_file(mask_dir + file_name + '.png'), channels=1)
            if args.num_classes == 2:
                # Remember the 255 pixels, squash every other label to
                # {0, 1}, then put the 255 pixels back.
                mask = tf.equal(original, 255)
                original = tf.cast(original, dtype=tf.float32)
                original = tf.clip_by_value(original, 0, 1)
                original = tf.cast(original, dtype=tf.int32)
                original = tf.where_v2(mask, 255, original)
            original = tf.expand_dims(original, 0)
            original = sess.run(original)
            original = decode_labels(original, num_classes=args.num_classes,
                                     include=True)
            im = Image.fromarray(original[0])
            im.save(args.save_dir + file_name + '_gt.png')
    finally:
        # Release the session's resources even if inference or saving fails.
        sess.close()

    print('The output file has been saved to {}'.format(output_file))