def compute_loss(self, embs, steps, seq_lens, global_step, training,
                 frame_labels, seq_labels):
    """Return the mean per-frame classification cross-entropy.

    Embeddings and sub-sampled frame labels are both rearranged so that the
    `num_steps` entries along axis 1 are stacked along axis 0, i.e. every
    sampled frame becomes its own row before it is classified.

    Args:
      embs: Embedding tensor whose axis 1 has exactly `num_steps` entries
        (each split piece is squeezed on axis 1).
      steps: Not used by this loss.
      seq_lens: Not used by this loss.
      global_step: Not used by this loss.
      training: Truthy selects `CONFIG.TRAIN.NUM_FRAMES` as the number of
        steps, otherwise `CONFIG.EVAL.NUM_FRAMES` is used.
      frame_labels: Integer per-frame labels, indexed as `[batch, frame]`.
      seq_labels: Not used by this loss.

    Returns:
      Scalar tensor holding the mean softmax cross-entropy over all rows.
    """
    num_steps = (CONFIG.TRAIN.NUM_FRAMES if training
                 else CONFIG.EVAL.NUM_FRAMES)

    def stack_steps_on_batch(tensor):
        # [B, num_steps, ...] -> [num_steps * B, ...]
        pieces = tf.split(tensor, num_steps, axis=1)
        return tf.squeeze(tf.concat(pieces, axis=0), axis=1)

    logits = self.model['classifier'](stack_steps_on_batch(embs))

    # Keep the label of the last frame in every window of `stride` frames.
    stride = CONFIG.DATA.NUM_STEPS
    sampled_labels = frame_labels[:, stride - 1::stride]
    targets = tf.one_hot(stack_steps_on_batch(sampled_labels),
                         self._num_classes)

    per_row_loss = tf.keras.losses.categorical_crossentropy(
        y_true=targets,
        y_pred=logits,
        from_logits=True,
        label_smoothing=CONFIG.CLASSIFICATION.LABEL_SMOOTHING)
    return tf.reduce_mean(per_row_loss)
def _compute_carry_and_output(self, x, h_tm1, c_tm1):
    """Compute the new carry state and output gate with Flipout perturbations.

    When the recurrent kernel is not a `RandomVariable`, the parent-class
    implementation is used unchanged. Otherwise each gate's recurrent term is
    the sum of a product with the kernel's distribution mean and a
    sign-perturbed product with the sampled deviation from that mean.

    Args:
      x: 4-tuple of input projections, ordered (input, forget, cell, output).
      h_tm1: 4-tuple of previous hidden states, one per gate, same order.
      c_tm1: Previous carry state.

    Returns:
      A `(c, o)` pair: the new carry state and the output-gate activation.
    """
    if not isinstance(self.recurrent_kernel, random_variable.RandomVariable):
        return super(LSTMCellFlipout, self)._compute_carry_and_output(
            x, h_tm1, c_tm1)

    def per_gate(tensor):
        return tf.split(tensor, num_or_size_splits=4, axis=1)

    x_i, x_f, x_c, x_o = x
    h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1

    mean_kernel = self.recurrent_kernel.distribution.mean()
    deviation = self.recurrent_kernel - mean_kernel
    mean_parts = per_gate(mean_kernel)
    deviation_parts = per_gate(deviation)
    sign_in_parts = per_gate(self.recurrent_sign_input)
    sign_out_parts = per_gate(self.recurrent_sign_output)

    dot = tf.keras.backend.dot
    pre_activations = []
    gates = zip((x_i, x_f, x_c, x_o),
                (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o),
                mean_parts, deviation_parts, sign_in_parts, sign_out_parts)
    for x_g, h_g, k_g, p_g, sign_in, sign_out in gates:
        mean_term = dot(h_g, k_g)
        perturbed_term = dot(h_g * sign_in, p_g) * sign_out
        pre_activations.append(x_g + mean_term + perturbed_term)
    z_i, z_f, z_c, z_o = pre_activations

    input_gate = self.recurrent_activation(z_i)
    forget_gate = self.recurrent_activation(z_f)
    carry = forget_gate * c_tm1 + input_gate * self.activation(z_c)
    output_gate = self.recurrent_activation(z_o)
    return carry, output_gate
def intersection(boxlist1, boxlist2, scope=None):
    """Compute pairwise intersection areas between boxes.

    Args:
      boxlist1: BoxList holding N boxes.
      boxlist2: BoxList holding M boxes.
      scope: name scope; a falsy value selects 'Intersection'.

    Returns:
      A tensor with shape [N, M] representing pairwise intersections.
    """

    def pairwise_extent(lo1, hi1, lo2, hi2):
        # [N, 1] columns against transposed [1, M] rows broadcast to [N, M].
        smallest_hi = tf.minimum(hi1, tf.transpose(a=hi2))
        largest_lo = tf.maximum(lo1, tf.transpose(a=lo2))
        return tf.maximum(0.0, smallest_hi - largest_lo)

    with tf.name_scope(scope or 'Intersection'):
        y_min1, x_min1, y_max1, x_max1 = tf.split(
            value=boxlist1.get(), num_or_size_splits=4, axis=1)
        y_min2, x_min2, y_max2, x_max2 = tf.split(
            value=boxlist2.get(), num_or_size_splits=4, axis=1)
        heights = pairwise_extent(y_min1, y_max1, y_min2, y_max2)
        widths = pairwise_extent(x_min1, x_max1, x_min2, x_max2)
        return heights * widths
def denormalize_boxes(boxes, image_shape):
    """Scale boxes normalized by [height, width] to pixel coordinates.

    Args:
      boxes: a tensor whose last dimension is 4, holding box coordinates in
        ymin, xmin, ymax, xmax order.
      image_shape: either a Python list/tuple `(height, width)`, or a
        tensor-like whose last dimension is 2 (`[height, width]`) and whose
        leading dimensions broadcast against `boxes`. A tensor-like value is
        cast to the dtype of `boxes`.

    Returns:
      denormalized_boxes: a tensor with the same shape as `boxes` holding the
        coordinates multiplied by the image height / width.

    Note: no explicit shape validation is performed here; an unsuitable last
    dimension is left for `tf.split` to reject.
    """
    with tf.name_scope('denormalize_boxes'):
        if isinstance(image_shape, (list, tuple)):
            height, width = image_shape
        else:
            image_shape = tf.cast(image_shape, dtype=boxes.dtype)
            height, width = tf.split(image_shape, 2, axis=-1)

        ymin, xmin, ymax, xmax = tf.split(boxes, 4, axis=-1)
        scaled = [ymin * height, xmin * width, ymax * height, xmax * width]
        return tf.concat(scaled, axis=-1)
def matched_intersection(boxlist1, boxlist2, scope=None):
    """Compute intersection areas between corresponding boxes in two boxlists.

    Args:
      boxlist1: BoxList holding N boxes.
      boxlist2: BoxList holding N boxes.
      scope: name scope; `None` selects 'MatchedIntersection'.

    Returns:
      A tensor with shape [N] holding the intersection area of each pair of
      corresponding boxes.
    """
    # `tf.name_scope(scope, default_name)` exists only in the TF1 API; the TF2
    # `tf.name_scope` takes a single name. Resolving the default here keeps
    # the call valid under both signatures.
    if scope is None:
        scope = 'MatchedIntersection'
    with tf.name_scope(scope):
        y_min1, x_min1, y_max1, x_max1 = tf.split(
            value=boxlist1.get(), num_or_size_splits=4, axis=1)
        y_min2, x_min2, y_max2, x_max2 = tf.split(
            value=boxlist2.get(), num_or_size_splits=4, axis=1)

        min_ymax = tf.minimum(y_max1, y_max2)
        max_ymin = tf.maximum(y_min1, y_min2)
        intersect_heights = tf.maximum(0.0, min_ymax - max_ymin)

        min_xmax = tf.minimum(x_max1, x_max2)
        max_xmin = tf.maximum(x_min1, x_min2)
        intersect_widths = tf.maximum(0.0, min_xmax - max_xmin)

        # Flatten the [N, 1] column produced by the splits to shape [N].
        return tf.reshape(intersect_heights * intersect_widths, [-1])
def _compute_carry_and_output(self, x, h_tm1, c_tm1):
    """Compute the new carry state and output gate with perturbed recurrence.

    Each gate's recurrent term is a matmul of the previous hidden state with
    that gate's slice of the recurrent kernel, perturbed by the sampled
    `alpha` (on the matmul input) and `gamma` (on the matmul output). The
    perturbation is additive when `self.use_additive_perturbation` is set and
    multiplicative otherwise.

    Args:
      x: 4-tuple of input projections, ordered (input, forget, cell, output).
      h_tm1: 4-tuple of previous hidden states, one per gate, same order.
      c_tm1: Previous carry state.

    Returns:
      A `(c, o)` pair: the new carry state and the output-gate activation.
    """
    x_i, x_f, x_c, x_o = x
    h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1

    kernels = tf.split(self.recurrent_kernel, num_or_size_splits=4, axis=1)
    alpha = self.recurrent_alpha_sample
    gammas = tf.split(
        self.recurrent_gamma_sample, num_or_size_splits=4, axis=1)

    if self.use_additive_perturbation:
        def recurrent_term(hidden, kernel, gamma):
            return tf.linalg.matmul(hidden + alpha, kernel) + gamma
    else:
        def recurrent_term(hidden, kernel, gamma):
            return tf.linalg.matmul(hidden * alpha, kernel) * gamma

    hiddens = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
    rec_i, rec_f, rec_c, rec_o = [
        recurrent_term(hidden, kernel, gamma)
        for hidden, kernel, gamma in zip(hiddens, kernels, gammas)
    ]

    input_gate = self.recurrent_activation(x_i + rec_i)
    forget_gate = self.recurrent_activation(x_f + rec_f)
    carry = forget_gate * c_tm1 + input_gate * self.activation(x_c + rec_c)
    output_gate = self.recurrent_activation(x_o + rec_o)
    return carry, output_gate
def forward(self, x, name='forward'):
    """Split `x` along `self.axis`; equivalent to `tf.split(x)`.

    Args:
      x: `Tensor`. The input to the 'forward' evaluation.
      name: The name to give this op.

    Returns:
      List of `Tensor`s: `self.num_splits` equal pieces when
      `self.split_sizes` is `None`, otherwise pieces of sizes
      `self.split_sizes`.

    Raises:
      TypeError: if `self.dtype` is specified and `x.dtype` is not
        `self.dtype`.
      ValueError: if the sum of `split_sizes` does not equal the size of the
        `axis` dimension of `x`.
    """
    with self._name_and_control_scope(name):
        x = tf.convert_to_tensor(x, dtype_hint=self.dtype, name='x')
        self._maybe_assert_dtype(x)

        # Prefer static validation; fall back to runtime assertions only when
        # the static check was inconclusive and validation is enabled.
        assertions = []
        statically_valid = self._validate_input_shape(x.shape)
        if not statically_valid and self.validate_args:
            assertions = self._validate_input_shape_tensor(
                prefer_static.shape(x))

        with tf.control_dependencies(assertions):
            if self.split_sizes is None:
                pieces = self.num_splits
            else:
                pieces = self.split_sizes
            return tf.split(x, pieces, axis=self.axis)
def _forward(self, x):
    """Split `x` along `self.axis`; equivalent to `tf.split(x)`.

    Args:
      x: `Tensor`. The input to the 'forward' evaluation.

    Returns:
      List of `Tensor`s: `self.num_splits` equal pieces when
      `self.split_sizes` is `None`, otherwise pieces of sizes
      `self.split_sizes`.

    Raises:
      ValueError: if the sum of `split_sizes` does not equal the size of the
        `axis` dimension of `x`.
    """
    # Prefer static validation; fall back to runtime assertions only when the
    # static check was inconclusive and validation is enabled.
    assertions = []
    statically_valid = self._validate_input_shape(x.shape)
    if not statically_valid and self.validate_args:
        assertions = self._validate_input_shape_tensor(
            prefer_static.shape(x))

    with tf.control_dependencies(assertions):
        if self.split_sizes is None:
            pieces = self.num_splits
        else:
            pieces = self.split_sizes
        return tf.split(x, pieces, axis=self.axis)
def train_mnist_model_batch_sharded(model, optimizer, mesh, num_epochs,
                                    steps_per_epoch, global_batch_size):
    """Train `model` with batch-sharded DTensor inputs.

    Each global batch is split into one piece per local device of `mesh` and
    packed into DTensors sharded along the "batch" mesh dimension before
    `train_step` is called.

    Args:
      model: Model passed through to `train_step`.
      optimizer: Optimizer passed through to `train_step`.
      mesh: DTensor mesh used for the input layouts.
      num_epochs: Number of epochs to run.
      steps_per_epoch: Number of batches consumed per epoch.
      global_batch_size: Batch size handed to `get_mnist_datasets`.

    Returns:
      A list with one entry per epoch: the summed step loss divided by
      `steps_per_epoch`, reduced with `tf.reduce_mean`.
    """
    dataset, _ = get_mnist_datasets(NUM_CLASS, global_batch_size)
    image_layout = dtensor.Layout.batch_sharded(mesh, "batch", rank=4)
    label_layout = dtensor.Layout.batch_sharded(mesh, "batch", rank=2)
    loss_obj = losses.CategoricalCrossentropy()
    num_local_devices = mesh.num_local_devices()
    batches = iter(dataset)

    def run_one_step():
        images, labels = next(batches)
        image_shards = tf.split(images, num_local_devices)
        label_shards = tf.split(labels, num_local_devices)
        d_images = dtensor.pack(image_shards, image_layout)
        d_labels = dtensor.pack(label_shards, label_layout)
        return train_step(model, d_images, d_labels, loss_obj, optimizer)

    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.00
        for _ in range(steps_per_epoch):
            running_loss += run_one_step()
        mean_loss = tf.reduce_mean(running_loss / steps_per_epoch)
        logging.info("Epoch %d, Loss: %f", epoch, mean_loss)
        epoch_losses.append(mean_loss)
    return epoch_losses
def call(self, inputs, states, training=None):
    """Run one step of the convolutional LSTM cell.

    Args:
      inputs: Input tensor for this time step.
      states: Sequence whose first two entries are the previous memory state
        and the previous carry state.
      training: Forwarded to the dropout-mask helpers.

    Returns:
      `(h, [h, c])`: the new output and the new `[memory, carry]` states.
    """
    prev_h = states[0]  # previous memory state
    prev_c = states[1]  # previous carry state

    # Dropout masks for the input units and for the recurrent units, one per
    # gate (input, forget, cell, output).
    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        prev_h, training, count=4)

    if 0 < self.dropout < 1.:
        gate_inputs = [inputs * dp_mask[g] for g in range(4)]
    else:
        gate_inputs = [inputs] * 4

    if 0 < self.recurrent_dropout < 1.:
        gate_states = [prev_h * rec_dp_mask[g] for g in range(4)]
    else:
        gate_states = [prev_h] * 4

    # Gate-specific kernels live side by side on axis 3 of the kernels.
    kernels = tf.split(self.kernel, 4, axis=3)
    recurrent_kernels = tf.split(self.recurrent_kernel, 4, axis=3)
    if self.use_bias:
        biases = tf.split(self.bias, 4)
    else:
        biases = [None, None, None, None]

    x_i, x_f, x_c, x_o = [
        self.input_conv(gate_in, kernel, bias, padding=self.padding)
        for gate_in, kernel, bias in zip(gate_inputs, kernels, biases)
    ]
    h_i, h_f, h_c, h_o = [
        self.recurrent_conv(gate_state, rec_kernel)
        for gate_state, rec_kernel in zip(gate_states, recurrent_kernels)
    ]

    input_gate = self.recurrent_activation(x_i + h_i)
    forget_gate = self.recurrent_activation(x_f + h_f)
    c = forget_gate * prev_c + input_gate * self.activation(x_c + h_c)
    output_gate = self.recurrent_activation(x_o + h_o)
    h = output_gate * self.activation(c)
    return h, [h, c]
def call(self, inputs, states, training=None):
    """Run one step of the LSTM cell.

    With `self.implementation == 1` the four gates are projected separately
    (each with its own dropout mask) and combined by
    `_compute_carry_and_output`. Otherwise a single fused projection is
    computed and split, and `_compute_carry_and_output_fused` is used.

    Args:
      inputs: Input tensor for this time step.
      states: Sequence whose first two entries are the previous memory state
        and the previous carry state.
      training: Forwarded to the dropout-mask helpers.

    Returns:
      `(h, [h, c])`: the new output and the new `[memory, carry]` states.
    """
    prev_h = states[0]  # previous memory state
    prev_c = states[1]  # previous carry state

    dp_mask = self.get_dropout_mask_for_cell(inputs, training, count=4)
    rec_dp_mask = self.get_recurrent_dropout_mask_for_cell(
        prev_h, training, count=4)

    if self.implementation == 1:
        if 0 < self.dropout < 1.:
            gate_inputs = [inputs * dp_mask[g] for g in range(4)]
        else:
            gate_inputs = [inputs] * 4

        kernels = tf.split(self.kernel, num_or_size_splits=4, axis=1)
        projections = [
            backend.dot(gate_in, kernel)
            for gate_in, kernel in zip(gate_inputs, kernels)
        ]
        if self.use_bias:
            biases = tf.split(self.bias, num_or_size_splits=4, axis=0)
            projections = [
                backend.bias_add(projection, bias)
                for projection, bias in zip(projections, biases)
            ]

        if 0 < self.recurrent_dropout < 1.:
            gate_states = [prev_h * rec_dp_mask[g] for g in range(4)]
        else:
            gate_states = [prev_h] * 4

        c, o = self._compute_carry_and_output(
            tuple(projections), tuple(gate_states), prev_c)
    else:
        # Fused path: one mask, one projection, split into gates afterwards.
        if 0. < self.dropout < 1.:
            inputs = inputs * dp_mask[0]
        z = backend.dot(inputs, self.kernel)
        z += backend.dot(prev_h, self.recurrent_kernel)
        if self.use_bias:
            z = backend.bias_add(z, self.bias)
        z = tf.split(z, num_or_size_splits=4, axis=1)
        c, o = self._compute_carry_and_output_fused(z, prev_c)

    h = o * self.activation(c)
    return h, [h, c]
def _sample_channels(self, component_logits, locs, scales, coeffs=None,
                     seed=None):
    """Sample a single pixel-iteration and apply channel conditioning.

    Args:
      component_logits: 4D `Tensor` of logits for the Categorical distribution
        over Quantized Logistic mixture components. Dimensions are
        `[batch_size, height, width, num_logistic_mix]`.
      locs: 5D `Tensor` of location parameters for the Quantized Logistic
        mixture components. Dimensions are
        `[batch_size, height, width, num_logistic_mix, num_channels]`.
      scales: 5D `Tensor` of scale parameters for the Quantized Logistic
        mixture components. Dimensions are
        `[batch_size, height, width, num_logistic_mix, num_channels]`.
      coeffs: 5D `Tensor` of coefficients for the linear dependence among
        color channels, or `None` if there is only one channel. Dimensions
        are `[batch_size, height, width, num_logistic_mix, num_coeffs]`,
        where `num_coeffs = num_channels * (num_channels - 1) // 2`.
      seed: `int`, random seed.

    Returns:
      samples: 4D `Tensor` of sampled image data with autoregression among
        channels. Dimensions are `[batch_size, height, width, num_channels]`.
    """
    num_channels = self.event_shape[-1]

    # Pick one mixture component for the whole pixel and turn the choice into
    # a one-hot selector that broadcasts over the trailing parameter axis.
    component_ids = categorical.Categorical(
        logits=component_logits).sample(seed=seed)
    selector = tf.one_hot(indices=component_ids,
                          depth=self._num_logistic_mix)
    selector = tf.cast(selector[..., tf.newaxis], self.dtype)

    def chosen_params(params, num_pieces):
        # Zero out the unselected components, sum the mixture axis away, then
        # separate the last axis into individual tensors.
        picked = tf.reduce_sum(params * selector, axis=-2)
        return tf.split(picked, num_pieces, axis=-1)

    loc_tensors = chosen_params(locs, num_channels)
    scale_tensors = chosen_params(scales, num_channels)
    if coeffs is not None:
        num_coeffs = num_channels * (num_channels - 1) // 2
        coef_tensors = chosen_params(coeffs, num_coeffs)

    channel_samples = []
    next_coef = 0
    for channel in range(num_channels):
        loc = loc_tensors[channel]
        # Shift the location by every previously sampled channel, consuming
        # one coefficient per (earlier channel, this channel) pair.
        for earlier_sample in channel_samples:
            loc = loc + earlier_sample * coef_tensors[next_coef]
            next_coef += 1

        channel_dist = logistic.Logistic(
            loc=loc, scale=scale_tensors[channel])
        sample = tf.clip_by_value(channel_dist.sample(seed=seed), -1., 1.)
        channel_samples.append(sample)

    return tf.concat(channel_samples, axis=-1)
def bbox_overlap(boxes, gt_boxes):
    """Calculates the overlap between proposal and ground truth boxes.

    Some `gt_boxes` may have been padded. The returned `iou` tensor for these
    boxes will be -1.

    Args:
      boxes: a tensor with a shape of [batch_size, N, 4]. N is the number of
        proposals before groundtruth assignment (e.g., rpn_post_nms_topn).
        The last dimension is the pixel coordinates in
        [ymin, xmin, ymax, xmax] form.
      gt_boxes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES, 4].
        This tensor might have paddings with a negative value.

    Returns:
      iou: a tensor with a shape of [batch_size, N, MAX_NUM_INSTANCES].
    """

    def as_row(column):
        # [batch, M, 1] -> [batch, 1, M], so it broadcasts against [batch, N, 1].
        return tf.transpose(column, [0, 2, 1])

    with tf.name_scope('bbox_overlap'):
        bb_y_min, bb_x_min, bb_y_max, bb_x_max = tf.split(
            value=boxes, num_or_size_splits=4, axis=2)
        gt_y_min, gt_x_min, gt_y_max, gt_x_max = tf.split(
            value=gt_boxes, num_or_size_splits=4, axis=2)

        # Intersection rectangle, clamped to zero extent when disjoint.
        i_xmin = tf.math.maximum(bb_x_min, as_row(gt_x_min))
        i_xmax = tf.math.minimum(bb_x_max, as_row(gt_x_max))
        i_ymin = tf.math.maximum(bb_y_min, as_row(gt_y_min))
        i_ymax = tf.math.minimum(bb_y_max, as_row(gt_y_max))
        overlap_w = tf.math.maximum((i_xmax - i_xmin), 0)
        overlap_h = tf.math.maximum((i_ymax - i_ymin), 0)
        i_area = overlap_w * overlap_h

        # Union area; the small epsilon avoids a divide-by-zero.
        bb_area = (bb_y_max - bb_y_min) * (bb_x_max - bb_x_min)
        gt_area = (gt_y_max - gt_y_min) * (gt_x_max - gt_x_min)
        u_area = bb_area + as_row(gt_area) - i_area + 1e-8

        iou = i_area / u_area

        # A ground-truth box whose largest coordinate is negative is padding;
        # broadcast that flag over all proposals and write -1 there.
        gt_is_padding = tf.less(
            tf.reduce_max(gt_boxes, axis=-1, keepdims=True), 0.0)
        padding_mask = tf.logical_or(
            tf.zeros_like(bb_x_min, dtype=tf.bool), as_row(gt_is_padding))
        return tf.where(padding_mask, -tf.ones_like(iou), iou)
def _lu_reconstruct_assertions(lower_upper, perm, validate_args):
    """Returns list of assertions related to `lu_reconstruct` assumptions.

    Each requirement is checked statically when the relevant static shape
    information is available (raising immediately on violation); otherwise a
    runtime assertion is added, but only if `validate_args` is truthy.

    Args:
      lower_upper: Tensor expected to have rank >= 2 and square trailing
        (matrix) dimensions.
      perm: Tensor expected to have rank `rank(lower_upper) - 1`.
      validate_args: Whether to emit runtime assertions for conditions that
        cannot be verified statically.

    Returns:
      A (possibly empty) list of assertion ops.

    Raises:
      ValueError: if a requirement is statically known to be violated.
    """
    assertions = []
    lu_rank = tensorshape_util.rank(lower_upper.shape)

    message = 'Input `lower_upper` must have at least 2 dimensions.'
    if lu_rank is not None:
        if lu_rank < 2:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank_at_least(
                lower_upper, rank=2, message=message))

    message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
    perm_rank = tensorshape_util.rank(perm.shape)
    if lu_rank is not None and perm_rank is not None:
        if lu_rank != perm_rank + 1:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank(
                lower_upper, rank=tf.rank(perm) + 1, message=message))

    message = '`lower_upper` must be square.'
    # Squareness concerns only the two trailing (matrix) dimensions, so those
    # are the ones that must be statically known. Testing the batch dimensions
    # (`shape[:-2]`) instead would compare possibly-unknown matrix dimensions
    # and could skip the runtime assertion.
    if tensorshape_util.is_fully_defined(lower_upper.shape[-2:]):
        if lower_upper.shape[-2] != lower_upper.shape[-1]:
            raise ValueError(message)
    elif validate_args:
        m, n = tf.split(tf.shape(lower_upper)[-2:], num_or_size_splits=2)
        assertions.append(assert_util.assert_equal(m, n, message=message))

    return assertions
def compute_logits(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
    """
    Implements a language model, where each output is conditional on the
    current input and inputs processed so far.

    Args:
        inputs: int32 tensor of shape [B, T, 3]: three ID channels stacked on
            the last axis. The first holds node IDs and the second action
            IDs; the third is ignored here.
        training: Flag indicating if we are currently training (used to
            toggle dropout).

    Returns:
        tf.float32 tensor of shape [B, T, V], storing the distribution over
        output symbols for each timestep for each batch element.
    """
    # Separate the stacked channels; each piece is [B, T, 1] until squeezed.
    node_channel, action_channel, _ = tf.split(inputs, 3, axis=2)
    node_ids = tf.squeeze(node_channel, axis=2)  # [B, T]
    action_ids = tf.squeeze(action_channel, axis=2)  # [B, T]

    # Embed both ID streams and join them along the feature axis.
    node_vectors = self.nodes_embedding(node_ids)
    action_vectors = self.actions_embedding(action_ids)
    rnn_input = tf.concat([node_vectors, action_vectors], axis=2)

    rnn_output = self.gru1(rnn_input, training=training)
    return self.dense(rnn_output)
def step_fn(inputs):
    """Per-Replica StepFn.

    Tiles the images once per ensemble member, runs a single forward pass,
    updates the per-member metrics, then updates the ensemble metrics from
    the member-averaged probabilities.
    """
    images, labels = inputs
    tiled_images = tf.tile(images, [FLAGS.num_models, 1, 1, 1])
    logits = model(tiled_images, training=False)
    if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)

    all_probs = tf.nn.softmax(logits)
    per_probs = tf.split(
        all_probs, num_or_size_splits=FLAGS.num_models, axis=0)

    sparse_xent = tf.keras.losses.sparse_categorical_crossentropy
    for member in range(FLAGS.num_models):
        member_probs = per_probs[member]
        metrics['test/nll_member_{}'.format(member)].update_state(
            sparse_xent(labels, member_probs))
        metrics['test/accuracy_member_{}'.format(member)].update_state(
            labels, member_probs)

    ensemble_probs = tf.reduce_mean(per_probs, axis=0)
    negative_log_likelihood = tf.reduce_mean(
        sparse_xent(labels, ensemble_probs))
    metrics['test/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics['test/accuracy'].update_state(labels, ensemble_probs)
def ensemble_batchnorm(x, num_models=1, use_tpu=True, **kwargs):
    """A modified batch norm layer for Batch Ensemble model.

    Args:
      x: input tensor.
      num_models: number of ensemble members.
      use_tpu: whether the model is running on TPU.
      **kwargs: Keyword arguments to batch normalization layers. When a
        `name` is given, member `i` gets the name `<name>_<i>`.

    Returns:
      Output tensor for the block.
    """
    # In BatchEnsemble inference stage, the input to the model is tiled which
    # leads to dynamic shape because of the tf.split function. Such operation
    # is not supported in tf2.0 on TPU. For current workaround, we use single
    # BatchNormalization layer for all ensemble member. This is not correct in
    # math but works in practice.
    if num_models == 1 or use_tpu:
        return tf.keras.layers.BatchNormalization(**kwargs)(x)

    base_name = kwargs.get('name')
    normalized = []
    member_inputs = tf.split(x, num_models, axis=0)
    for index, member_input in enumerate(member_inputs):
        member_kwargs = dict(kwargs)
        if base_name is not None:
            member_kwargs['name'] = base_name + '_{}'.format(index)
        layer = tf.keras.layers.BatchNormalization(**member_kwargs)
        normalized.append(layer(member_input))
    return tf.concat(normalized, axis=0)
def keypoint_flip_horizontal(keypoints, flip_point, flip_permutation,
                             scope=None):
    """Flips the keypoints horizontally around the flip_point.

    This operation flips the x coordinate for each keypoint around the
    flip_point and also permutes the keypoints in a manner specified by
    flip_permutation.

    Args:
      keypoints: a tensor of shape [num_instances, num_keypoints, 2]
      flip_point: (float) scalar tensor representing the x coordinate to flip
        the keypoints around.
      flip_permutation: rank 1 int32 tensor containing the keypoint flip
        permutation. This specifies the mapping from original keypoint
        indices to the flipped keypoint indices. This is used primarily for
        keypoints that are not reflection invariant. E.g. Suppose there are 3
        keypoints representing ['head', 'right_eye', 'left_eye'], then a
        logical choice for flip_permutation might be [0, 2, 1] since we want
        to swap the 'left_eye' and 'right_eye' after a horizontal flip.
      scope: name scope; `None` selects 'FlipHorizontal'.

    Returns:
      new_keypoints: a tensor of shape [num_instances, num_keypoints, 2]
    """
    # `tf.name_scope(scope, default_name)` exists only in the TF1 API; the TF2
    # `tf.name_scope` takes a single name. Resolving the default here keeps
    # the call valid under both signatures.
    if scope is None:
        scope = 'FlipHorizontal'
    with tf.name_scope(scope):
        # Bring the keypoint axis to the front so `tf.gather` permutes
        # keypoints rather than instances.
        keypoints = tf.transpose(a=keypoints, perm=[1, 0, 2])
        keypoints = tf.gather(keypoints, flip_permutation)
        v, u = tf.split(value=keypoints, num_or_size_splits=2, axis=2)
        # Reflect the second coordinate about `flip_point`.
        u = flip_point * 2.0 - u
        new_keypoints = tf.concat([v, u], 2)
        new_keypoints = tf.transpose(a=new_keypoints, perm=[1, 0, 2])
        return new_keypoints
def _split_and_reshape_event(self, x):
    """Split the last axis of `x` into the distribution's event structure.

    The last axis is cut into one piece per component of the wrapped
    distribution's (possibly nested) event shape; each piece is cast to that
    component's dtype and reshaped to that component's event shape.

    Args:
      x: Tensor whose last axis concatenates all flattened event components.

    Returns:
      A structure matching the distribution's event structure.
    """
    dist = self._distribution
    event_tensors = dist.event_shape_tensor()

    # Each component takes prod(event_shape) entries, but never fewer than 1.
    piece_sizes = [
        ps.maximum(1, ps.reduce_prod(shape))
        for shape in tf.nest.flatten(event_tensors)
    ]
    pieces = tf.split(x, piece_sizes, axis=-1)
    x = tf.nest.pack_sequence_as(event_tensors, pieces)

    def _reshape_part(part, dtype, event_shape):
        part = tf.cast(part, dtype)
        static_rank = tf.get_static_value(ps.rank_from_shape(event_shape))
        if static_rank == 1:
            # A rank-1 event already has the right trailing shape.
            return part
        new_shape = ps.concat([ps.shape(part)[:-1], event_shape], axis=-1)
        return tf.reshape(part, ps.cast(new_shape, tf.int32))

    # Use the static event shapes when they are fully known, otherwise the
    # dynamic shape tensors.
    static_shapes = dist.event_shape
    fully_static = all(
        tensorshape_util.is_fully_defined(shape)
        for shape in tf.nest.flatten(static_shapes))
    if fully_static:
        target_shapes = static_shapes
    else:
        target_shapes = dist.event_shape_tensor()
    return tf.nest.map_structure(
        _reshape_part, x, dist.dtype, target_shapes)
def _fn(x, output_units, **condition_kwargs):
    """Fully connected MLP parameterized via `real_nvp_template`."""
    if condition_kwargs:
        raise NotImplementedError(
            'Conditioning not implemented in the default template.')

    # A rank-1 input gets a temporary leading axis for the dense layers; the
    # outputs are indexed back down to rank 1 before returning.
    added_batch_axis = tensorshape_util.rank(x.shape) == 1
    if added_batch_axis:
        x = x[tf.newaxis, ...]

    def reshape_output(tensor):
        if added_batch_axis:
            return tensor[0]
        return tensor

    for units in hidden_layers:
        x = tf1.layers.dense(
            inputs=x,
            units=units,
            activation=activation,
            *args,  # pylint: disable=keyword-arg-before-vararg
            **kwargs)

    # The final layer emits the shift alone, or shift and log-scale
    # side by side.
    final_units = (1 if shift_only else 2) * output_units
    x = tf1.layers.dense(
        inputs=x,
        units=final_units,
        activation=None,
        *args,  # pylint: disable=keyword-arg-before-vararg
        **kwargs)

    if shift_only:
        return reshape_output(x), None
    shift, log_scale = tf.split(x, 2, axis=-1)
    return reshape_output(shift), reshape_output(log_scale)
def testReverseMask(self, num_masked, fraction_masked, batch_shape):
    """Checks that a reverse-masked RealNVP leaves the trailing slice intact."""
    input_depth = 8
    x_ = np.random.normal(
        0., 1., batch_shape + (input_depth,)).astype(np.float32)
    flip_nvp = tfb.RealNVP(
        num_masked=num_masked,
        fraction_masked=fraction_masked,
        validate_args=True,
        **self._real_nvp_kwargs,
    )
    forward_x = flip_nvp.forward(tf.constant(x_))

    if num_masked is not None:
        expected_num_masked = num_masked
    else:
        expected_num_masked = np.floor(input_depth * fraction_masked)
    self.assertEqual(flip_nvp._masked_size, expected_num_masked)

    masked_width = abs(flip_nvp._masked_size)
    unmasked_width = input_depth - masked_width
    _, x2_ = np.split(x_, [unmasked_width], axis=-1)  # pylint: disable=unbalanced-tuple-unpacking

    # Check latter half is the same after passing thru reversed mask RealNVP.
    _, forward_x2 = tf.split(
        forward_x, [unmasked_width, masked_width], axis=-1)
    self.evaluate(tf1.global_variables_initializer())
    forward_x2_ = self.evaluate(forward_x2)

    self.assertAllClose(forward_x2_, x2_, rtol=1e-4, atol=0.)
def keypoint_prune_outside_window(keypoints, window, scope=None):
    """Prunes keypoints that fall outside a given window.

    This function replaces keypoints that fall outside the given window with
    nan. See also clip_to_window which clips any keypoints that fall outside
    the given window.

    Args:
      keypoints: a tensor of shape [num_instances, num_keypoints, 2]
      window: a tensor of shape [4] representing the
        [y_min, x_min, y_max, x_max] window outside of which the op should
        prune the keypoints.
      scope: name scope; a falsy value selects 'PruneOutsideWindow'.

    Returns:
      new_keypoints: a tensor of shape [num_instances, num_keypoints, 2]
    """
    with tf.name_scope(scope or 'PruneOutsideWindow'):
        y, x = tf.split(value=keypoints, num_or_size_splits=2, axis=2)
        win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)

        y_in_window = tf.logical_and(y >= win_y_min, y <= win_y_max)
        x_in_window = tf.logical_and(x >= win_x_min, x <= win_x_max)
        inside = tf.logical_and(y_in_window, x_in_window)

        # A keypoint is kept only when both of its coordinates are inside;
        # otherwise both coordinates become nan.
        pruned_y = tf.where(inside, y, np.nan * tf.ones_like(y))
        pruned_x = tf.where(inside, x, np.nan * tf.ones_like(x))
        return tf.concat([pruned_y, pruned_x], 2)
def load_weights(self, base_fn):
    """Load weights from the checkpoint file that sorts last for `base_fn`.

    Files matching `<base_fn>_*.npy` are sorted by name in descending order
    and the first one is loaded. Its flat array is cut into per-weight
    pieces, reshaped, packed back into the stored weight structure and handed
    to each layer's own `load_weights`.

    Args:
      base_fn: Filename prefix of the checkpoint files.

    Raises:
      AssertionError: if no file matches the pattern.
    """
    pattern = base_fn + "_*.npy"
    candidates = sorted(gfile.glob(pattern), reverse=True)
    assert len(candidates) > 0, "No files matching {}".format(pattern)
    chosen = candidates[0]

    # load array
    with gfile.GFile(chosen, "rb") as fin:
        serialized_weights = np.load(fin)
    print(serialized_weights.shape, self.all_weights_flat_sizes)

    flat_pieces = tf.split(serialized_weights, self.all_weights_flat_sizes)
    shaped_pieces = [
        tf.reshape(piece, shape)
        for piece, shape in zip(flat_pieces, self.all_weights_flat_shapes)
    ]
    all_weights = tf.nest.pack_sequence_as(
        self.all_weights_structure, shaped_pieces)

    # Layers first, then the loss layer, then any shared parameters.
    targets = self.layers + [self.loss_layer]
    if self.shared_params is not None:
        targets.extend(list(self.shared_params.values()))
    for target, target_weights in zip(targets, all_weights):
        target.load_weights(target_weights)
def step_fn(inputs):
    """Per-Replica StepFn.

    Evaluates all ensemble members in one forward pass on tiled images. For
    the 'clean' dataset it updates per-member, Gibbs, ensemble and diversity
    metrics; for any other dataset it updates only the corresponding
    corrupted-data ensemble metrics.
    """
    # Note that we don't use tf.tile for labels here
    images, labels = inputs
    images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])

    # get lambdas
    lambdas = log_uniform_mean(lambda_parameters)
    rep_lambdas = tf.repeat(lambdas, per_core_batch_size, axis=0)

    # eval on testsets
    logits = model([images, rep_lambdas], training=False)
    if FLAGS.use_bfloat16:
        logits = tf.cast(logits, tf.float32)
    probs = tf.nn.softmax(logits)
    per_probs = tf.split(
        probs, num_or_size_splits=FLAGS.ensemble_size, axis=0)

    is_clean = dataset_name == 'clean'
    sparse_xent = tf.keras.losses.sparse_categorical_crossentropy

    # per member performance and gibbs performance (average per member perf)
    if is_clean:
        for member in range(FLAGS.ensemble_size):
            member_probs = per_probs[member]
            metrics['test/nll_member_{}'.format(member)].update_state(
                sparse_xent(labels, member_probs))
            metrics['test/accuracy_member_{}'.format(member)].update_state(
                labels, member_probs)

        labels_tile = tf.tile(labels, [FLAGS.ensemble_size])
        gibbs_nll = tf.reduce_mean(
            sparse_xent(labels_tile, logits, from_logits=True))
        metrics['test/gibbs_nll'].update_state(gibbs_nll)
        metrics['test/gibbs_accuracy'].update_state(labels_tile, probs)

    # ensemble performance
    negative_log_likelihood = ensemble_crossentropy(
        labels, logits, FLAGS.ensemble_size)
    probs = tf.reduce_mean(per_probs, axis=0)

    if not is_clean:
        corrupt_metrics['test/nll_{}'.format(dataset_name)].update_state(
            negative_log_likelihood)
        corrupt_metrics['test/accuracy_{}'.format(
            dataset_name)].update_state(labels, probs)
        corrupt_metrics['test/ece_{}'.format(dataset_name)].update_state(
            labels, probs)
        return

    metrics['test/negative_log_likelihood'].update_state(
        negative_log_likelihood)
    metrics['test/accuracy'].update_state(labels, probs)
    metrics['test/ece'].update_state(labels, probs)

    # Pairwise diversity is reported for the clean dataset only.
    per_probs_stacked = tf.stack(per_probs, axis=0)
    diversity_results = um.average_pairwise_diversity(
        per_probs_stacked, FLAGS.ensemble_size)
    for key, value in diversity_results.items():
        metrics['test/' + key].update_state(value)
def call(self, inputs):
    """Apply a dense transform whose weights come from the posterior.

    The posterior and prior are both built from `inputs`, their divergence is
    registered as a layer loss, and the posterior (converted to a tensor)
    supplies the flattened kernel followed, when `use_bias` is set, by the
    bias.

    Args:
      inputs: Input tensor; cast to the layer dtype (or the Keras default
        float type when the layer has none).

    Returns:
      The layer output after the optional bias and activation.
    """
    compute_dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx())
    inputs = tf.cast(inputs, compute_dtype, name='inputs')

    posterior = self._posterior(inputs)
    prior = self._prior(inputs)
    self.add_loss(self._kl_divergence_fn(posterior, prior))

    weights = tf.convert_to_tensor(value=posterior)
    prev_units = self.input_spec.axes[-1]

    kernel, bias = weights, None
    if self.use_bias:
        # The flat kernel comes first, followed by `units` bias entries.
        kernel, bias = tf.split(
            weights, [prev_units * self.units, self.units], axis=-1)

    # Unflatten the last axis into a [prev_units, units] matrix.
    leading_dims = tf.shape(input=kernel)[:-1]
    kernel_shape = tf.concat([leading_dims, [prev_units, self.units]], axis=0)
    kernel = tf.reshape(kernel, shape=kernel_shape)

    outputs = tf.matmul(inputs, kernel)
    if bias is not None:
        outputs = tf.nn.bias_add(outputs, bias)
    if self.activation is not None:
        outputs = self.activation(outputs)  # pylint: disable=not-callable
    return outputs
def _inverse(self, y):
    """Map `y` to `log(y[..., :-1]) - log(y[..., -1:])`.

    Derivation of the inverse mapping:
      y[i] = exp(x[i]) / normalization   and   y[end] = 1 / normalization,
    hence
      x[i] = log(exp(x[i])) - log(y[end]) - log(normalization)
           = log(exp(x[i]) / normalization) - log(y[end])
           = log(y[i]) - log(y[end]).

    With `validate_args` set, runtime assertions check that the last
    dimension of `y` sums to 1 and that every element lies in [0, 1].
    """
    # Do this first to make sure CSE catches that it'll happen again in
    # _inverse_log_det_jacobian.
    assertions = []
    if self.validate_args:
        one = tf.ones([], y.dtype)
        tolerance = 2. * np.finfo(dtype_util.as_numpy_dtype(y.dtype)).eps
        assertions.append(assert_util.assert_near(
            tf.reduce_sum(y, axis=-1),
            one,
            tolerance,
            message='Last dimension of `y` must sum to `1`.'))
        assertions.append(assert_util.assert_less_equal(
            y, tf.ones([], y.dtype),
            message='Elements of `y` must be less than or equal to `1`.'))
        assertions.append(assert_util.assert_non_negative(
            y, message='Elements of `y` must be non-negative.'))

    with tf.control_dependencies(assertions):
        log_y = tf.math.log(y)

    # The final entry of the last axis carries the log-normalization term.
    leading, log_normalization = tf.split(
        log_y, num_or_size_splits=[-1, 1], axis=-1)
    return leading - log_normalization
def get_dist_and_mode(self, states):
    """Returns a tanh-squashed Gaussian distribution and its mode.

    The trunk output is split on axis 1 into a mean half and a log-std half.
    The log-std half is squashed with tanh and rescaled linearly into
    `[LOG_STD_MIN, LOG_STD_MAX]`.

    Args:
      states: A batch of states.

    Returns:
      A `(dist, mode)` pair: the transformed distribution, and `tanh(mu)`.
    """
    trunk_out = self.trunk(states)
    mu, raw_log_std = tf.split(trunk_out, num_or_size_splits=2, axis=1)
    mode = tf.nn.tanh(mu)

    squashed = tf.nn.tanh(raw_log_std)
    assert LOG_STD_MAX > LOG_STD_MIN
    # Map tanh's (-1, 1) range linearly onto (LOG_STD_MIN, LOG_STD_MAX).
    half_range = 0.5 * (LOG_STD_MAX - LOG_STD_MIN)
    log_std = LOG_STD_MIN + half_range * (squashed + 1)
    std = tf.exp(log_std)

    # Chain applies the last bijector first: affine, then tanh.
    squash_after_affine = tfp.bijectors.Chain([
        tfp.bijectors.Tanh(),
        tfp.bijectors.Affine(shift=mu, scale_diag=std),
    ])
    dist = ds.TransformedDistribution(
        distribution=ds.Normal(loc=0., scale=1.),
        bijector=squash_after_affine,
        event_shape=[mu.shape[-1]],
        batch_shape=[mu.shape[0]])
    return dist, mode
def box_list_scale(boxlist, y_scale, x_scale, scope=None):
    """Scale box coordinates in x and y dimensions.

    Args:
      boxlist: BoxList holding N boxes.
      y_scale: (float) scalar tensor; cast to float32 before use.
      x_scale: (float) scalar tensor; cast to float32 before use.
      scope: name scope; `None` selects 'Scale'.

    Returns:
      boxlist: BoxList holding N boxes with scaled coordinates; extra fields
        are carried over from the input via `_copy_extra_fields`.
    """
    # `tf.name_scope(scope, default_name)` exists only in the TF1 API; the TF2
    # `tf.name_scope` takes a single name. Resolving the default here keeps
    # the call valid under both signatures.
    if scope is None:
        scope = 'Scale'
    with tf.name_scope(scope):
        y_scale = tf.cast(y_scale, tf.float32)
        x_scale = tf.cast(x_scale, tf.float32)
        y_min, x_min, y_max, x_max = tf.split(
            value=boxlist.get(), num_or_size_splits=4, axis=1)
        y_min = y_scale * y_min
        y_max = y_scale * y_max
        x_min = x_scale * x_min
        x_max = x_scale * x_max
        scaled_boxlist = box_list.BoxList(
            tf.concat([y_min, x_min, y_max, x_max], 1))
        return _copy_extra_fields(scaled_boxlist, boxlist)
def compress(self, bottleneck):
    """Compresses a floating-point tensor.

    Compresses the tensor to bit strings. `bottleneck` is first quantized as
    in `quantize()`, and then compressed using the probability tables derived
    from `self.prior`.

    The quantized tensor can later be recovered by calling `decompress()`.

    The innermost `self.coding_rank` dimensions are treated as one coding
    unit, i.e. are compressed into one string each. Any additional dimensions
    to the left are treated as batch dimensions.

    Arguments:
      bottleneck: `tf.Tensor` containing the data to be compressed. Must have
        at least `self.coding_rank` dimensions, and the innermost dimensions
        must be broadcastable to `self.prior_shape`.

    Returns:
      A `tf.Tensor` having the same shape as `bottleneck` without the
      `self.coding_rank` innermost dimensions, containing a string for each
      coding unit.
    """
    full_shape = tf.shape(bottleneck)
    full_rank = tf.shape(full_shape)[0]
    # Leading dims are batch dims; the trailing `coding_rank` dims form one
    # coding unit.
    batch_shape, coding_shape = tf.split(
        full_shape, [full_rank - self.coding_rank, self.coding_rank])
    num_broadcast_dims = self.coding_rank - len(self.prior_shape)
    broadcast_shape = coding_shape[:num_broadcast_dims]

    indexes, offset = self._compute_indexes_and_offset(broadcast_shape)
    if offset is not None:
        bottleneck -= offset
    symbols = tf.cast(tf.round(bottleneck), tf.int32)
    # Flatten all batch dims into one so each row is a single coding unit.
    symbols = tf.reshape(symbols, tf.concat([[-1], coding_shape], 0))

    # Prevent tensors from bouncing back and forth between host and GPU.
    with tf.device("/cpu:0"):
        cdf = self.cdf
        cdf_length = self.cdf_length
        cdf_offset = self.cdf_offset

        def encode_unit(unit_symbols):
            return range_coding_ops.unbounded_index_range_encode(
                unit_symbols, indexes, cdf, cdf_length, cdf_offset,
                precision=self.range_coder_precision,
                overflow_width=4,
                debug_level=1)

        # TODO(jonycgn,ssjhv): Consider switching to Python control flow.
        strings = tf.map_fn(
            encode_unit, symbols, dtype=tf.string, name="compress")

    return tf.reshape(strings, batch_shape)
def testStddev(self):
    """Checks stddev/variance through shift-scale, Split and a matvec bijector."""
    base_stddev = 2.
    shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
    scale = np.array([[1, -2, 3], [2, -3, 2]], dtype=np.float32)
    expected_stddev = tf.abs(base_stddev * scale)

    base_normal = tfd.Normal(
        loc=tf.zeros_like(shift),
        scale=base_stddev * tf.ones_like(scale),
        validate_args=True)
    shift_then_scale = tfb.Chain(
        [tfb.Shift(shift=shift), tfb.Scale(scale=scale)],
        validate_args=True)
    normal = self._cls()(
        distribution=base_normal,
        bijector=shift_then_scale,
        validate_args=True)
    self.assertAllClose(expected_stddev, normal.stddev())
    self.assertAllClose(expected_stddev**2, normal.variance())

    # A Split bijector yields one stddev per split piece.
    split_normal = self._cls()(
        distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
        bijector=tfb.Split(3),
        validate_args=True)
    expected_pieces = tf.split(
        expected_stddev, num_or_size_splits=3, axis=-1)
    self.assertAllCloseNested(expected_pieces, split_normal.stddev())

    # stddev is expected to be unavailable for a multivariate transformation.
    scaled_normal = self._cls()(
        distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
        bijector=tfb.ScaleMatvecTriL([[1., 0.], [-1., 2.]]),
        validate_args=True)
    with self.assertRaisesRegex(NotImplementedError,
                                'is a multivariate transformation'):
        scaled_normal.stddev()