Пример #1
0
    def compute_loss(self, embs, steps, seq_lens, global_step, training,
                     frame_labels, seq_labels):
        """Return the mean softmax cross-entropy of the per-step classifier.

        Both `embs` and the sub-sampled `frame_labels` are cut into
        `num_steps` pieces along axis 1; the pieces are concatenated along
        axis 0 and axis 1 is then squeezed out, so each piece has to be of
        size 1 on axis 1.

        Args:
          embs: Tensor handed (after the rearrangement described above) to
            `self.model['classifier']`.
          steps: Not used by this method.
          seq_lens: Not used by this method.
          global_step: Not used by this method.
          training: When truthy the step count is CONFIG.TRAIN.NUM_FRAMES,
            otherwise CONFIG.EVAL.NUM_FRAMES.
          frame_labels: Integer class labels indexed as [:, frame]. Every
            CONFIG.DATA.NUM_STEPS-th entry, starting at index
            CONFIG.DATA.NUM_STEPS - 1, is kept.
          seq_labels: Not used by this method.

        Returns:
          Mean over all rows of the label-smoothed categorical cross-entropy
          between the one-hot labels and the classifier logits.
        """
        step_count = (CONFIG.TRAIN.NUM_FRAMES if training
                      else CONFIG.EVAL.NUM_FRAMES)

        def _steps_to_rows(tensor):
            # Move every step (axis 1) onto the leading axis.
            pieces = tf.split(tensor, step_count, axis=1)
            return tf.squeeze(tf.concat(pieces, axis=0), axis=1)

        logits = self.model['classifier'](_steps_to_rows(embs))

        # Keep only the label of the last frame of each group of frames.
        stride = CONFIG.DATA.NUM_STEPS
        step_labels = _steps_to_rows(frame_labels[:, stride - 1::stride])
        one_hot_labels = tf.one_hot(step_labels, self._num_classes)

        per_row_loss = tf.keras.losses.categorical_crossentropy(
            y_true=one_hot_labels,
            y_pred=logits,
            from_logits=True,
            label_smoothing=CONFIG.CLASSIFICATION.LABEL_SMOOTHING)
        return tf.reduce_mean(per_row_loss)
Пример #2
0
 def _compute_carry_and_output(self, x, h_tm1, c_tm1):
   """Computes the carry state and output gate with Flipout perturbations.

   When `self.recurrent_kernel` is not a `RandomVariable` the parent class
   implementation is used unchanged. Otherwise each gate pre-activation is
   the input term plus the hidden state times the kernel mean, plus a
   sign-flipped product with the kernel perturbation (sample minus mean).

   Args:
     x: 4-tuple of input contributions, ordered (i, f, c, o).
     h_tm1: 4-tuple of previous hidden states, ordered (i, f, c, o).
     c_tm1: Previous carry state.

   Returns:
     Tuple `(c, o)`: the new carry state and the output gate activation.
   """
   if not isinstance(self.recurrent_kernel, random_variable.RandomVariable):
     return super(LSTMCellFlipout, self)._compute_carry_and_output(
         x, h_tm1, c_tm1)

   x_i, x_f, x_c, x_o = x
   h_i, h_f, h_c, h_o = h_tm1

   mean_kernel = self.recurrent_kernel.distribution.mean()
   noise = self.recurrent_kernel - mean_kernel

   def _per_gate(tensor):
     # One slice per gate along the last (column) axis.
     return tf.split(tensor, num_or_size_splits=4, axis=1)

   k_i, k_f, k_c, k_o = _per_gate(mean_kernel)
   p_i, p_f, p_c, p_o = _per_gate(noise)
   si_i, si_f, si_c, si_o = _per_gate(self.recurrent_sign_input)
   so_i, so_f, so_c, so_o = _per_gate(self.recurrent_sign_output)

   dot = tf.keras.backend.dot

   def _preactivation(x_g, h_g, k_g, p_g, sign_in, sign_out):
     return x_g + dot(h_g, k_g) + dot(h_g * sign_in, p_g) * sign_out

   z_i = _preactivation(x_i, h_i, k_i, p_i, si_i, so_i)
   z_f = _preactivation(x_f, h_f, k_f, p_f, si_f, so_f)
   z_c = _preactivation(x_c, h_c, k_c, p_c, si_c, so_c)
   z_o = _preactivation(x_o, h_o, k_o, p_o, si_o, so_o)

   input_gate = self.recurrent_activation(z_i)
   forget_gate = self.recurrent_activation(z_f)
   carry = forget_gate * c_tm1 + input_gate * self.activation(z_c)
   output_gate = self.recurrent_activation(z_o)
   return carry, output_gate
def intersection(boxlist1, boxlist2, scope=None):
    """Compute the intersection area of every box pair across two boxlists.

    Args:
      boxlist1: BoxList holding N boxes.
      boxlist2: BoxList holding M boxes.
      scope: name scope; 'Intersection' is used when it is falsy.

    Returns:
      a tensor with shape [N, M] whose entry (n, m) is the intersection area
      of box n from `boxlist1` and box m from `boxlist2`.
    """
    with tf.name_scope(scope or 'Intersection'):
        ymin_a, xmin_a, ymax_a, xmax_a = tf.split(
            value=boxlist1.get(), num_or_size_splits=4, axis=1)
        ymin_b, xmin_b, ymax_b, xmax_b = tf.split(
            value=boxlist2.get(), num_or_size_splits=4, axis=1)

        def _pairwise_overlap(lo_a, hi_a, lo_b, hi_b):
            # Column vector against transposed (row) vector broadcasts to
            # all pairs; negative extents mean "no overlap" and clamp to 0.
            smallest_hi = tf.minimum(hi_a, tf.transpose(a=hi_b))
            largest_lo = tf.maximum(lo_a, tf.transpose(a=lo_b))
            return tf.maximum(0.0, smallest_hi - largest_lo)

        heights = _pairwise_overlap(ymin_a, ymax_a, ymin_b, ymax_b)
        widths = _pairwise_overlap(xmin_a, xmax_a, xmin_b, xmax_b)
        return heights * widths
Пример #4
0
def denormalize_boxes(boxes, image_shape):
  """Converts boxes normalized by [height, width] to pixel coordinates.

  Args:
    boxes: a tensor whose last dimension is 4 representing the coordinates
      of boxes in ymin, xmin, ymax, xmax order.
    image_shape: a list of two integers, a two-element vector or a tensor such
      that all but the last dimensions are `broadcastable` to `boxes`. The last
      dimension is 2, which represents [height, width].

  Returns:
    denormalized_boxes: a tensor whose shape is the same as `boxes` representing
      the denormalized boxes.

  Raises:
    ValueError: If the last dimension of boxes is statically known and is
      not 4.
  """
  with tf.name_scope('denormalize_boxes'):
    # Enforce the documented contract. Without this check a last dimension
    # such as 8 would be silently split into four chunks of size 2. Only a
    # statically known dimension can be validated here; inputs without a
    # usable static shape are passed through unchanged.
    static_shape = tf.TensorShape(getattr(boxes, 'shape', None))
    if static_shape.rank:
      last_dim = tf.compat.dimension_value(static_shape[-1])
      if last_dim is not None and last_dim != 4:
        raise ValueError(
            'boxes.shape[-1] is {:d}, but must be 4.'.format(last_dim))

    if isinstance(image_shape, (list, tuple)):
      height, width = image_shape
    else:
      # Match the dtype of `boxes` so the multiplications below are valid.
      image_shape = tf.cast(image_shape, dtype=boxes.dtype)
      height, width = tf.split(image_shape, 2, axis=-1)

    ymin, xmin, ymax, xmax = tf.split(boxes, 4, axis=-1)
    ymin = ymin * height
    xmin = xmin * width
    ymax = ymax * height
    xmax = xmax * width

    denormalized_boxes = tf.concat([ymin, xmin, ymax, xmax], axis=-1)
    return denormalized_boxes
def matched_intersection(boxlist1, boxlist2, scope=None):
    """Compute intersection areas between corresponding boxes in two boxlists.

    Box n of `boxlist1` is intersected with box n of `boxlist2` only; no
    cross pairs are formed.

    Args:
      boxlist1: BoxList holding N boxes
      boxlist2: BoxList holding N boxes
      scope: name scope; 'MatchedIntersection' is used when it is falsy.

    Returns:
      a tensor with shape [N] holding the intersection area of each
      corresponding pair of boxes.
    """
    # The two-positional-argument form `tf.name_scope(scope, default)` only
    # exists in the TF1 API; resolving the default here works under both the
    # TF1 and TF2 signatures.
    if not scope:
        scope = 'MatchedIntersection'
    with tf.name_scope(scope):
        y_min1, x_min1, y_max1, x_max1 = tf.split(value=boxlist1.get(),
                                                  num_or_size_splits=4,
                                                  axis=1)
        y_min2, x_min2, y_max2, x_max2 = tf.split(value=boxlist2.get(),
                                                  num_or_size_splits=4,
                                                  axis=1)
        min_ymax = tf.minimum(y_max1, y_max2)
        max_ymin = tf.maximum(y_min1, y_min2)
        # A negative extent means the boxes do not overlap; clamp it to 0.
        intersect_heights = tf.maximum(0.0, min_ymax - max_ymin)
        min_xmax = tf.minimum(x_max1, x_max2)
        max_xmin = tf.maximum(x_min1, x_min2)
        intersect_widths = tf.maximum(0.0, min_xmax - max_xmin)
        # The split keeps a trailing axis of size 1; flatten it away.
        return tf.reshape(intersect_heights * intersect_widths, [-1])
Пример #6
0
 def _compute_carry_and_output(self, x, h_tm1, c_tm1):
     """Computes the carry state and output gate with rank-1 perturbations.

     The recurrent contribution of each gate is built from the previous
     hidden state, the sampled `recurrent_alpha_sample` and the gate's slice
     of `recurrent_gamma_sample`. The perturbations are added when
     `self.use_additive_perturbation` is set and multiplied otherwise.

     Args:
       x: 4-tuple of input contributions, ordered (i, f, c, o).
       h_tm1: 4-tuple of previous hidden states, ordered (i, f, c, o).
       c_tm1: Previous carry state.

     Returns:
       Tuple `(c, o)`: the new carry state and the output gate activation.
     """
     x_i, x_f, x_c, x_o = x
     h_i, h_f, h_c, h_o = h_tm1
     k_i, k_f, k_c, k_o = tf.split(
         self.recurrent_kernel, num_or_size_splits=4, axis=1)
     alpha = self.recurrent_alpha_sample
     gamma_i, gamma_f, gamma_c, gamma_o = tf.split(
         self.recurrent_gamma_sample, num_or_size_splits=4, axis=1)

     if self.use_additive_perturbation:
         def _recurrent_term(h_g, k_g, gamma_g):
             return tf.linalg.matmul(h_g + alpha, k_g) + gamma_g
     else:
         def _recurrent_term(h_g, k_g, gamma_g):
             return tf.linalg.matmul(h_g * alpha, k_g) * gamma_g

     rec_i = _recurrent_term(h_i, k_i, gamma_i)
     rec_f = _recurrent_term(h_f, k_f, gamma_f)
     rec_c = _recurrent_term(h_c, k_c, gamma_c)
     rec_o = _recurrent_term(h_o, k_o, gamma_o)

     input_gate = self.recurrent_activation(x_i + rec_i)
     forget_gate = self.recurrent_activation(x_f + rec_f)
     carry = forget_gate * c_tm1 + input_gate * self.activation(x_c + rec_c)
     output_gate = self.recurrent_activation(x_o + rec_o)
     return carry, output_gate
Пример #7
0
    def forward(self, x, name='forward'):
        """Splits `x` along `self.axis` and returns the pieces.

        Equivalent to `tf.split(x, ...)` using `self.num_splits` when
        `self.split_sizes` is None and `self.split_sizes` otherwise.

        Args:
          x: `Tensor` (or convertible value). The input to split.
          name: The name to give this op.

        Returns:
          List of `Tensor`s.

        Raises:
          TypeError: presumably raised by `_maybe_assert_dtype` when
            `self.dtype` is specified and `x.dtype` differs (helper not
            shown here).
          ValueError: presumably raised by the shape validation helpers when
            the sum of `split_sizes` does not equal the size of the `axis`
            dimension of `x` (helpers not shown here).
        """
        with self._name_and_control_scope(name):
            x = tf.convert_to_tensor(x, dtype_hint=self.dtype, name='x')
            self._maybe_assert_dtype(x)

            # Try the static shape first; fall back to runtime assertions
            # only when that was inconclusive and validation is requested.
            statically_checked = self._validate_input_shape(x.shape)
            runtime_checks = []
            if not statically_checked and self.validate_args:
                runtime_checks = self._validate_input_shape_tensor(
                    prefer_static.shape(x))

            with tf.control_dependencies(runtime_checks):
                split_spec = (self.num_splits if self.split_sizes is None
                              else self.split_sizes)
                return tf.split(x, split_spec, axis=self.axis)
Пример #8
0
    def _forward(self, x):
        """Splits `x` along `self.axis` and returns the pieces.

        Equivalent to `tf.split(x, ...)` using `self.num_splits` when
        `self.split_sizes` is None and `self.split_sizes` otherwise.

        Args:
          x: `Tensor`. The input to split.

        Returns:
          List of `Tensor`s.

        Raises:
          ValueError: presumably raised by the shape validation helpers when
            the sum of `split_sizes` does not equal the size of the `axis`
            dimension of `x` (helpers not shown here).
        """
        # Runtime assertions are only needed when the static check was
        # inconclusive and argument validation is enabled.
        if self._validate_input_shape(x.shape) or not self.validate_args:
            shape_checks = []
        else:
            shape_checks = self._validate_input_shape_tensor(
                prefer_static.shape(x))

        with tf.control_dependencies(shape_checks):
            if self.split_sizes is not None:
                return tf.split(x, self.split_sizes, axis=self.axis)
            return tf.split(x, self.num_splits, axis=self.axis)
Пример #9
0
def train_mnist_model_batch_sharded(model, optimizer, mesh, num_epochs,
                                    steps_per_epoch, global_batch_size):
    """Trains `model` on MNIST with batch-sharded DTensor inputs.

    Each batch drawn from the dataset is split into one piece per local
    device of `mesh`, packed into DTensors that are sharded along the
    "batch" mesh dimension, and passed to `train_step`.

    Args:
      model: model forwarded to `train_step`.
      optimizer: optimizer forwarded to `train_step`.
      mesh: DTensor mesh providing the "batch" dimension and the local
        device count.
      num_epochs: number of epochs to run.
      steps_per_epoch: number of batches consumed per epoch.
      global_batch_size: batch size forwarded to `get_mnist_datasets`.

    Returns:
      List with one entry per epoch: the mean of the summed step losses
      divided by `steps_per_epoch`.
    """
    dataset, _ = get_mnist_datasets(NUM_CLASS, global_batch_size)

    image_layout = dtensor.Layout.batch_sharded(mesh, "batch", rank=4)
    label_layout = dtensor.Layout.batch_sharded(mesh, "batch", rank=2)
    loss_obj = losses.CategoricalCrossentropy()

    device_count = mesh.num_local_devices()
    # A single iterator is shared by all epochs, so the dataset has to
    # supply num_epochs * steps_per_epoch batches in total.
    batches = iter(dataset)

    def _run_one_step():
        images, labels = next(batches)
        sharded_images = dtensor.pack(tf.split(images, device_count),
                                      image_layout)
        sharded_labels = dtensor.pack(tf.split(labels, device_count),
                                      label_layout)
        return train_step(model, sharded_images, sharded_labels, loss_obj,
                          optimizer)

    epoch_losses = []
    for epoch in range(num_epochs):
        accumulated = 0.00
        for _ in range(steps_per_epoch):
            accumulated += _run_one_step()

        mean_loss = tf.reduce_mean(accumulated / steps_per_epoch)
        logging.info("Epoch %d, Loss: %f", epoch, mean_loss)
        epoch_losses.append(mean_loss)
    return epoch_losses
Пример #10
0
    def call(self, inputs, states, training=None):
        """Runs one step of the convolutional LSTM cell.

        Args:
          inputs: input tensor for this time step.
          states: sequence whose first two entries are the previous hidden
            state and the previous carry state.
          training: forwarded to the dropout-mask helpers.

        Returns:
          Tuple `(h, [h, c])` holding the new hidden state and the new
          state list (hidden state, carry state).
        """
        prev_hidden, prev_carry = states[0], states[1]

        # Four masks each, one per gate (i, f, c, o).
        input_masks = self.get_dropout_mask_for_cell(inputs, training,
                                                     count=4)
        recurrent_masks = self.get_recurrent_dropout_mask_for_cell(
            prev_hidden, training, count=4)

        if 0 < self.dropout < 1.:
            gate_inputs = [inputs * input_masks[g] for g in range(4)]
        else:
            gate_inputs = [inputs] * 4

        if 0 < self.recurrent_dropout < 1.:
            gate_hidden = [prev_hidden * recurrent_masks[g] for g in range(4)]
        else:
            gate_hidden = [prev_hidden] * 4

        # The kernels store the four gates side by side on axis 3.
        input_kernels = tf.split(self.kernel, 4, axis=3)
        hidden_kernels = tf.split(self.recurrent_kernel, 4, axis=3)
        if self.use_bias:
            biases = tf.split(self.bias, 4)
        else:
            biases = [None] * 4

        x_i, x_f, x_c, x_o = [
            self.input_conv(gate_in, gate_kernel, gate_bias,
                            padding=self.padding)
            for gate_in, gate_kernel, gate_bias in zip(
                gate_inputs, input_kernels, biases)
        ]
        h_i, h_f, h_c, h_o = [
            self.recurrent_conv(gate_h, gate_kernel)
            for gate_h, gate_kernel in zip(gate_hidden, hidden_kernels)
        ]

        input_gate = self.recurrent_activation(x_i + h_i)
        forget_gate = self.recurrent_activation(x_f + h_f)
        carry = (forget_gate * prev_carry +
                 input_gate * self.activation(x_c + h_c))
        output_gate = self.recurrent_activation(x_o + h_o)
        hidden = output_gate * self.activation(carry)
        return hidden, [hidden, carry]
Пример #11
0
  def call(self, inputs, states, training=None):
    """Runs one step of the LSTM cell.

    With `self.implementation == 1` every gate gets its own (optionally
    dropout-masked) input and hidden state and the result comes from
    `_compute_carry_and_output`. Otherwise a single fused matmul is used,
    only the first input dropout mask is applied, and the result comes from
    `_compute_carry_and_output_fused`.

    Args:
      inputs: input tensor for this time step.
      states: sequence whose first two entries are the previous hidden state
        and the previous carry state.
      training: forwarded to the dropout-mask helpers.

    Returns:
      Tuple `(h, [h, c])` holding the new hidden state and the new state
      list (hidden state, carry state).
    """
    prev_hidden, prev_carry = states[0], states[1]

    input_masks = self.get_dropout_mask_for_cell(inputs, training, count=4)
    recurrent_masks = self.get_recurrent_dropout_mask_for_cell(
        prev_hidden, training, count=4)

    if self.implementation != 1:
      if 0. < self.dropout < 1.:
        inputs = inputs * input_masks[0]
      fused = backend.dot(inputs, self.kernel)
      fused += backend.dot(prev_hidden, self.recurrent_kernel)
      if self.use_bias:
        fused = backend.bias_add(fused, self.bias)

      gate_preactivations = tf.split(fused, num_or_size_splits=4, axis=1)
      carry, output_gate = self._compute_carry_and_output_fused(
          gate_preactivations, prev_carry)
    else:
      if 0 < self.dropout < 1.:
        gate_inputs = [inputs * input_masks[g] for g in range(4)]
      else:
        gate_inputs = [inputs] * 4

      # The kernel stores the four gates (i, f, c, o) side by side on axis 1.
      gate_kernels = tf.split(self.kernel, num_or_size_splits=4, axis=1)
      gate_x = [
          backend.dot(gate_in, gate_kernel)
          for gate_in, gate_kernel in zip(gate_inputs, gate_kernels)
      ]
      if self.use_bias:
        gate_biases = tf.split(self.bias, num_or_size_splits=4, axis=0)
        gate_x = [
            backend.bias_add(x_g, b_g)
            for x_g, b_g in zip(gate_x, gate_biases)
        ]

      if 0 < self.recurrent_dropout < 1.:
        gate_hidden = [prev_hidden * recurrent_masks[g] for g in range(4)]
      else:
        gate_hidden = [prev_hidden] * 4

      carry, output_gate = self._compute_carry_and_output(
          tuple(gate_x), tuple(gate_hidden), prev_carry)

    hidden = output_gate * self.activation(carry)
    return hidden, [hidden, carry]
Пример #12
0
    def _sample_channels(self,
                         component_logits,
                         locs,
                         scales,
                         coeffs=None,
                         seed=None):
        """Sample a single pixel-iteration and apply channel conditioning.

        One mixture component is drawn per pixel and its parameters are
        selected with a one-hot mask. Channels are then sampled one after
        another, with each channel's location shifted by the previously
        sampled channels times the matching coefficients.

        Args:
          component_logits: `Tensor` of logits for the Categorical
            distribution over Quantized Logistic mixture components.
            Dimensions are `[batch_size, height, width, num_logistic_mix]`.
          locs: `Tensor` of location parameters for the Quantized Logistic
            mixture components. Dimensions are `[batch_size, height, width,
            num_logistic_mix, num_channels]`.
          scales: `Tensor` of scale parameters for the Quantized Logistic
            mixture components. Dimensions are `[batch_size, height, width,
            num_logistic_mix, num_channels]`.
          coeffs: `Tensor` of coefficients for the linear dependence among
            color channels, or `None` if there is only one channel.
            Dimensions are `[batch_size, height, width, num_logistic_mix,
            num_coeffs]`, where
            `num_coeffs = num_channels * (num_channels - 1) // 2`.
          seed: `int`, random seed, passed to every `sample` call.

        Returns:
          samples: `Tensor` of sampled image data with autoregression among
            channels. Dimensions are `[batch_size, height, width,
            num_channels]`. Each channel is clipped to [-1, 1].
        """
        channel_count = self.event_shape[-1]

        # Draw the mixture component once for the entire pixel.
        mixture = categorical.Categorical(logits=component_logits)
        component_mask = tf.one_hot(indices=mixture.sample(seed=seed),
                                    depth=self._num_logistic_mix)
        component_mask = tf.cast(component_mask[..., tf.newaxis], self.dtype)

        def _select_and_split(params, piece_count):
            # Keep only the chosen component, then split the last axis.
            selected = tf.reduce_sum(params * component_mask, axis=-2)
            return tf.split(selected, piece_count, axis=-1)

        loc_tensors = _select_and_split(locs, channel_count)
        scale_tensors = _select_and_split(scales, channel_count)

        if coeffs is not None:
            coef_tensors = _select_and_split(
                coeffs, channel_count * (channel_count - 1) // 2)

        sampled_channels = []
        next_coef = 0
        for loc, scale in zip(loc_tensors, scale_tensors):
            # Condition this channel on every channel sampled before it.
            for earlier in sampled_channels:
                loc = loc + earlier * coef_tensors[next_coef]
                next_coef += 1

            draw = logistic.Logistic(loc=loc, scale=scale).sample(seed=seed)
            sampled_channels.append(tf.clip_by_value(draw, -1., 1.))

        return tf.concat(sampled_channels, axis=-1)
def bbox_overlap(boxes, gt_boxes):
    """Calculates the IoU between proposal and ground truth boxes.

    Some `gt_boxes` may have been padded. The returned `iou` entries for
    these boxes are -1.

    Args:
      boxes: a tensor with a shape of [batch_size, N, 4]. N is the number of
        proposals before groundtruth assignment (e.g., rpn_post_nms_topn).
        The last dimension is the pixel coordinates in
        [ymin, xmin, ymax, xmax] form.
      gt_boxes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES, 4].
        This tensor might have paddings with a negative value.

    Returns:
      iou: a tensor with a shape of [batch_size, N, MAX_NUM_INSTANCES].
    """
    with tf.name_scope('bbox_overlap'):
        box_ymin, box_xmin, box_ymax, box_xmax = tf.split(
            value=boxes, num_or_size_splits=4, axis=2)
        gt_ymin, gt_xmin, gt_ymax, gt_xmax = tf.split(
            value=gt_boxes, num_or_size_splits=4, axis=2)

        def _along_last_axis(tensor):
            # [batch, M, 1] -> [batch, 1, M] so it broadcasts against
            # the [batch, N, 1] proposal tensors.
            return tf.transpose(tensor, [0, 2, 1])

        # Intersection rectangle of every (proposal, ground truth) pair.
        left = tf.math.maximum(box_xmin, _along_last_axis(gt_xmin))
        right = tf.math.minimum(box_xmax, _along_last_axis(gt_xmax))
        top = tf.math.maximum(box_ymin, _along_last_axis(gt_ymin))
        bottom = tf.math.minimum(box_ymax, _along_last_axis(gt_ymax))
        overlap_width = tf.math.maximum((right - left), 0)
        overlap_height = tf.math.maximum((bottom - top), 0)
        intersection_area = overlap_width * overlap_height

        # Union area; the epsilon guards against division by zero.
        box_area = (box_ymax - box_ymin) * (box_xmax - box_xmin)
        gt_area = (gt_ymax - gt_ymin) * (gt_xmax - gt_xmin)
        union_area = (box_area + _along_last_axis(gt_area) -
                      intersection_area + 1e-8)

        iou = intersection_area / union_area

        # A ground truth box whose largest coordinate is negative is
        # padding; mark all of its IoU entries with -1.
        gt_is_padding = tf.less(
            tf.reduce_max(gt_boxes, axis=-1, keepdims=True), 0.0)
        padding_mask = tf.logical_or(
            tf.zeros_like(box_xmin, dtype=tf.bool),
            _along_last_axis(gt_is_padding))
        return tf.where(padding_mask, -tf.ones_like(iou), iou)
Пример #14
0
def _lu_reconstruct_assertions(lower_upper, perm, validate_args):
    """Returns list of assertions related to `lu_reconstruct` assumptions.

    Every condition is checked statically when the required shape
    information is available, raising `ValueError` immediately on
    violation. When it is not available and `validate_args` is truthy, a
    runtime assertion op is appended to the returned list instead.

    Args:
      lower_upper: tensor holding the LU factors; must have rank >= 2 and
        square innermost matrices.
      perm: permutation tensor; its rank must be `rank(lower_upper) - 1`.
      validate_args: whether to emit runtime assertions for conditions that
        cannot be verified statically.

    Returns:
      List of assertion ops (possibly empty).

    Raises:
      ValueError: if a condition is statically known to be violated.
    """
    assertions = []

    message = 'Input `lower_upper` must have at least 2 dimensions.'
    lower_upper_rank = tensorshape_util.rank(lower_upper.shape)
    if lower_upper_rank is not None:
        if lower_upper_rank < 2:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank_at_least(lower_upper,
                                             rank=2,
                                             message=message))

    message = '`rank(lower_upper)` must equal `rank(perm) + 1`'
    perm_rank = tensorshape_util.rank(perm.shape)
    if lower_upper_rank is not None and perm_rank is not None:
        if lower_upper_rank != perm_rank + 1:
            raise ValueError(message)
    elif validate_args:
        assertions.append(
            assert_util.assert_rank(lower_upper,
                                    rank=tf.rank(perm) + 1,
                                    message=message))

    message = '`lower_upper` must be square.'
    # The static comparison needs the two innermost (matrix) dimensions to
    # be known, i.e. shape[-2:]. Checking the batch dimensions (shape[:-2])
    # instead would compare possibly-unknown matrix dimensions and either
    # raise spuriously or skip the runtime assertion.
    if tensorshape_util.is_fully_defined(lower_upper.shape[-2:]):
        if lower_upper.shape[-2] != lower_upper.shape[-1]:
            raise ValueError(message)
    elif validate_args:
        m, n = tf.split(tf.shape(lower_upper)[-2:], num_or_size_splits=2)
        assertions.append(assert_util.assert_equal(m, n, message=message))

    return assertions
    def compute_logits(self, inputs: tf.Tensor, training: bool) -> tf.Tensor:
        """
        Implements a language model, where each output is conditional on the
        current input and the inputs processed so far.

        Args:
            inputs: integer tensor of shape [B, T, 3] stacking, on the last
                axis, the node IDs, the action IDs and a third channel (the
                father IDs according to the original author's note) that is
                ignored here.
            training: Flag indicating if we are currently training (passed
                on to the GRU layer).

        Returns:
            Output of `self.dense` applied to the GRU output; presumably a
            float tensor of shape [B, T, V] with one score per output symbol
            (V is not visible here).
        """
        # Separate the stacked channels; the third one is discarded.
        node_channel, action_channel, _ = tf.split(inputs, 3, axis=2)
        node_ids = tf.squeeze(node_channel, axis=2)  # [B, T]
        action_ids = tf.squeeze(action_channel, axis=2)  # [B, T]

        # Embed both ID streams and join them feature-wise.
        node_vectors = self.nodes_embedding(node_ids)
        action_vectors = self.actions_embedding(action_ids)
        rnn_input = tf.concat([node_vectors, action_vectors], axis=2)

        rnn_output = self.gru1(rnn_input, training=training)
        return self.dense(rnn_output)
Пример #16
0
        def step_fn(inputs):
            """Per-replica evaluation step for the ensemble.

            The image batch is tiled `FLAGS.num_models` times along axis 0,
            run through `model` once, and the resulting probabilities are
            split back into one chunk per ensemble member. Per-member NLL
            and accuracy metrics are updated, followed by the metrics of the
            member-averaged prediction.

            Args:
              inputs: tuple `(images, labels)`.
            """
            images, labels = inputs
            tiled_images = tf.tile(images, [FLAGS.num_models, 1, 1, 1])
            logits = model(tiled_images, training=False)
            if FLAGS.use_bfloat16:
                logits = tf.cast(logits, tf.float32)

            member_probs_list = tf.split(tf.nn.softmax(logits),
                                         num_or_size_splits=FLAGS.num_models,
                                         axis=0)
            for member, member_probs in enumerate(member_probs_list):
                nll = tf.keras.losses.sparse_categorical_crossentropy(
                    labels, member_probs)
                metrics['test/nll_member_{}'.format(member)].update_state(nll)
                metrics['test/accuracy_member_{}'.format(
                    member)].update_state(labels, member_probs)

            # Ensemble prediction: average the member probabilities.
            ensemble_probs = tf.reduce_mean(member_probs_list, axis=0)
            ensemble_nll = tf.reduce_mean(
                tf.keras.losses.sparse_categorical_crossentropy(
                    labels, ensemble_probs))
            metrics['test/negative_log_likelihood'].update_state(ensemble_nll)
            metrics['test/accuracy'].update_state(labels, ensemble_probs)
Пример #17
0
def ensemble_batchnorm(x, num_models=1, use_tpu=True, **kwargs):
    """A modified batch norm layer for Batch Ensemble model.

    Args:
      x: input tensor.
      num_models: number of ensemble members.
      use_tpu: whether the model is running on TPU.
      **kwargs: Keyword arguments to batch normalization layers. When a
        `name` is given, member i's layer is named `<name>_<i>`.

    Returns:
      Output tensor for the block.
    """
    # In the BatchEnsemble inference stage the model input is tiled, which
    # leads to a dynamic shape because of tf.split. That is not supported in
    # tf2.0 on TPU, so as a workaround a single BatchNormalization layer is
    # shared by all ensemble members there. This is not mathematically
    # correct but works in practice.
    if num_models == 1 or use_tpu:
        return tf.keras.layers.BatchNormalization(**kwargs)(x)

    base_name = kwargs.get('name')
    normalized_members = []
    for member, member_input in enumerate(tf.split(x, num_models, axis=0)):
        member_kwargs = dict(kwargs)
        if base_name is not None:
            member_kwargs['name'] = base_name + '_{}'.format(member)
        layer = tf.keras.layers.BatchNormalization(**member_kwargs)
        normalized_members.append(layer(member_input))
    return tf.concat(normalized_members, axis=0)
def keypoint_flip_horizontal(keypoints,
                             flip_point,
                             flip_permutation,
                             scope=None):
    """Flips the keypoints horizontally around the flip_point.

    This operation flips the x coordinate for each keypoint around the
    flip_point and also permutes the keypoints in a manner specified by
    flip_permutation.

    Args:
      keypoints: a tensor of shape [num_instances, num_keypoints, 2]
      flip_point:  (float) scalar tensor representing the x coordinate to
        flip the keypoints around.
      flip_permutation: rank 1 int32 tensor containing the keypoint flip
        permutation. This specifies the mapping from original keypoint
        indices to the flipped keypoint indices. This is used primarily for
        keypoints that are not reflection invariant. E.g. Suppose there are
        3 keypoints representing ['head', 'right_eye', 'left_eye'], then a
        logical choice for flip_permutation might be [0, 2, 1] since we want
        to swap the 'left_eye' and 'right_eye' after a horizontal flip.
      scope: name scope; 'FlipHorizontal' is used when it is falsy.

    Returns:
      new_keypoints: a tensor of shape [num_instances, num_keypoints, 2]
    """
    # The two-positional-argument form `tf.name_scope(scope, default)` only
    # exists in the TF1 API; resolving the default here works under both the
    # TF1 and TF2 signatures.
    if not scope:
        scope = 'FlipHorizontal'
    with tf.name_scope(scope):
        # Bring the keypoint axis to the front so that tf.gather permutes
        # keypoints rather than instances.
        keypoints = tf.transpose(a=keypoints, perm=[1, 0, 2])
        keypoints = tf.gather(keypoints, flip_permutation)
        # v is the first coordinate (kept as is), u the second (mirrored).
        v, u = tf.split(value=keypoints, num_or_size_splits=2, axis=2)
        u = flip_point * 2.0 - u
        new_keypoints = tf.concat([v, u], 2)
        # Restore the [num_instances, num_keypoints, 2] layout.
        new_keypoints = tf.transpose(a=new_keypoints, perm=[1, 0, 2])
        return new_keypoints
Пример #19
0
  def _split_and_reshape_event(self, x):
    """Splits a flat event into the distribution's event structure.

    The last axis of `x` is cut into one piece per component of the wrapped
    distribution's event shape (each piece sized as the product of that
    component's event shape, but at least 1). Every piece is cast to the
    component dtype and, unless the component's event rank is statically 1,
    reshaped so its trailing dimensions equal the component event shape.

    Args:
      x: tensor whose last axis concatenates all flattened components.

    Returns:
      Structure matching `self._distribution.event_shape_tensor()` with the
      cast and reshaped pieces.
    """
    dist = self._distribution
    event_tensors = dist.event_shape_tensor()
    piece_sizes = [
        ps.maximum(1, ps.reduce_prod(component_shape))
        for component_shape in tf.nest.flatten(event_tensors)
    ]
    pieces = tf.nest.pack_sequence_as(
        event_tensors, tf.split(x, piece_sizes, axis=-1))

    def _restore_component(piece, dtype, component_shape):
      piece = tf.cast(piece, dtype)
      # A rank-1 event is already laid out correctly on the last axis.
      if tf.get_static_value(ps.rank_from_shape(component_shape)) == 1:
        return piece
      target_shape = ps.concat([ps.shape(piece)[:-1], component_shape],
                               axis=-1)
      return tf.reshape(piece, ps.cast(target_shape, tf.int32))

    # Prefer static event shapes when all of them are fully known.
    static_shapes = dist.event_shape
    shapes_are_static = all(
        tensorshape_util.is_fully_defined(component_shape)
        for component_shape in tf.nest.flatten(static_shapes))
    if shapes_are_static:
      component_shapes = static_shapes
    else:
      component_shapes = dist.event_shape_tensor()
    return tf.nest.map_structure(_restore_component, pieces, dist.dtype,
                                 component_shapes)
Пример #20
0
        def _fn(x, output_units, **condition_kwargs):
            """Fully connected MLP parameterized via `real_nvp_template`.

            Args:
              x: input tensor; a rank-1 input gets a temporary leading batch
                axis that is removed again from the outputs.
              output_units: number of units per returned quantity.
              **condition_kwargs: must be empty; conditioning is not
                supported by this template.

            Returns:
              Tuple `(shift, log_scale)`; `log_scale` is None when
              `shift_only` is set.

            Raises:
              NotImplementedError: if any `condition_kwargs` are given.
            """
            if condition_kwargs:
                raise NotImplementedError(
                    'Conditioning not implemented in the default template.')

            added_batch_axis = tensorshape_util.rank(x.shape) == 1
            if added_batch_axis:
                x = x[tf.newaxis, ...]

            def reshape_output(tensor):
                # Undo the temporary batch axis, if one was added.
                return tensor[0] if added_batch_axis else tensor

            for units in hidden_layers:
                x = tf1.layers.dense(
                    inputs=x,
                    units=units,
                    activation=activation,
                    *args,  # pylint: disable=keyword-arg-before-vararg
                    **kwargs)

            # One block of units for the shift, plus one for the log-scale
            # unless only the shift is requested.
            final_units = (1 if shift_only else 2) * output_units
            x = tf1.layers.dense(
                inputs=x,
                units=final_units,
                activation=None,
                *args,  # pylint: disable=keyword-arg-before-vararg
                **kwargs)

            if shift_only:
                return reshape_output(x), None
            shift, log_scale = tf.split(x, 2, axis=-1)
            return reshape_output(shift), reshape_output(log_scale)
Пример #21
0
    def testReverseMask(self, num_masked, fraction_masked, batch_shape):
        """Checks that the trailing masked block passes through unchanged.

        Builds a `RealNVP` from `num_masked` / `fraction_masked`, verifies
        its `_masked_size`, and asserts that the last `abs(_masked_size)`
        entries of the forward output equal the same entries of the input.

        Args:
          num_masked: `num_masked` argument for `RealNVP`, or None.
          fraction_masked: `fraction_masked` argument for `RealNVP`; used to
            derive the expected size when `num_masked` is None.
          batch_shape: tuple prepended to the event size of the random input.
        """
        input_depth = 8
        x_np = np.random.normal(
            0., 1., batch_shape + (input_depth, )).astype(np.float32)
        bijector = tfb.RealNVP(
            num_masked=num_masked,
            fraction_masked=fraction_masked,
            validate_args=True,
            **self._real_nvp_kwargs,
        )
        forward_x = bijector.forward(tf.constant(x_np))

        if num_masked is not None:
            expected_masked_size = num_masked
        else:
            expected_masked_size = np.floor(input_depth * fraction_masked)
        self.assertEqual(bijector._masked_size, expected_masked_size)

        masked_width = abs(bijector._masked_size)
        leading_width = input_depth - masked_width

        # pylint: disable=unbalanced-tuple-unpacking
        _, expected_tail = np.split(x_np, [leading_width], axis=-1)
        # pylint: enable=unbalanced-tuple-unpacking

        # Check latter part is the same after passing thru reversed mask
        # RealNVP.
        _, forward_tail = tf.split(
            forward_x, [leading_width, masked_width], axis=-1)
        self.evaluate(tf1.global_variables_initializer())
        forward_tail_np = self.evaluate(forward_tail)

        self.assertAllClose(forward_tail_np, expected_tail,
                            rtol=1e-4, atol=0.)
Пример #22
0
def keypoint_prune_outside_window(keypoints, window, scope=None):
    """Replaces keypoints lying outside `window` with NaN.

    A keypoint is kept only when both of its coordinates fall inside the
    window (bounds inclusive); otherwise both coordinates become NaN. See also
    clip_to_window, which clips such keypoints instead of pruning them.

    Args:
      keypoints: a tensor of shape [num_instances, num_keypoints, 2]
      window: a tensor of shape [4] representing the
        [y_min, x_min, y_max, x_max] window outside of which the op should
        prune the keypoints.
      scope: name scope; 'PruneOutsideWindow' is used when falsy.

    Returns:
      new_keypoints: a tensor of shape [num_instances, num_keypoints, 2]
    """
    with tf.name_scope(scope or 'PruneOutsideWindow'):
        y, x = tf.split(value=keypoints, num_or_size_splits=2, axis=2)
        win_y_min, win_x_min, win_y_max, win_x_max = tf.unstack(window)

        inside_y = tf.logical_and(y >= win_y_min, y <= win_y_max)
        inside_x = tf.logical_and(x >= win_x_min, x <= win_x_max)
        inside = tf.logical_and(inside_y, inside_x)

        # Same mask for both coordinates, so a keypoint is pruned as a whole.
        pruned = []
        for coord in (y, x):
            pruned.append(
                tf.where(inside, coord, np.nan * tf.ones_like(coord)))

        return tf.concat(pruned, 2)
Пример #23
0
    def load_weights(self, base_fn):
        """Loads weights from the checkpoint file for `base_fn` that sorts last.

        Globs for `<base_fn>_*.npy`, takes the match that comes first in
        reverse-sorted (string) order, splits the flat array it contains by
        `self.all_weights_flat_sizes`, reshapes each piece to the matching
        entry of `self.all_weights_flat_shapes`, repacks the pieces into
        `self.all_weights_structure`, and hands one entry to each layer's
        `load_weights`.

        Args:
          base_fn: filename prefix of the checkpoint files.

        Raises:
          AssertionError: if no file matches the pattern.
        """
        pattern = base_fn + "_*.npy"
        candidates = sorted(gfile.glob(pattern), reverse=True)
        assert candidates, "No files matching {}".format(pattern)
        chosen = candidates[0]

        # Read the serialized flat weight vector.
        with gfile.GFile(chosen, "rb") as fin:
            flat_blob = np.load(fin)

        print(flat_blob.shape, self.all_weights_flat_sizes)
        pieces = tf.split(flat_blob, self.all_weights_flat_sizes)
        reshaped = []
        for piece, shape in zip(pieces, self.all_weights_flat_shapes):
            reshaped.append(tf.reshape(piece, shape))

        nested = tf.nest.pack_sequence_as(self.all_weights_structure,
                                          reshaped)

        # Layers first, then the loss layer, then any shared parameters.
        targets = self.layers + [self.loss_layer]
        if self.shared_params is not None:
            targets.extend(self.shared_params.values())
        for layer, layer_weights in zip(targets, nested):
            layer.load_weights(layer_weights)
        def step_fn(inputs):
            """Per-replica evaluation step that updates the test metrics.

            Tiles the images once per ensemble member, runs the model with the
            repeated lambdas, and updates either `metrics` (when
            `dataset_name == 'clean'`) or `corrupt_metrics` (otherwise).

            Args:
              inputs: an `(images, labels)` pair.
            """
            # Only the images are tiled here; labels stay untiled.
            images, labels = inputs
            images = tf.tile(images, [FLAGS.ensemble_size, 1, 1, 1])
            is_clean = dataset_name == 'clean'

            # Lambdas, repeated along the batch axis.
            lambdas = log_uniform_mean(lambda_parameters)
            rep_lambdas = tf.repeat(lambdas, per_core_batch_size, axis=0)

            # Forward pass on the test set.
            logits = model([images, rep_lambdas], training=False)
            if FLAGS.use_bfloat16:
                logits = tf.cast(logits, tf.float32)
            all_probs = tf.nn.softmax(logits)
            per_member_probs = tf.split(
                all_probs, num_or_size_splits=FLAGS.ensemble_size, axis=0)

            # Per-member metrics and Gibbs metrics (average per-member perf).
            if is_clean:
                for idx, member_probs in enumerate(per_member_probs):
                    member_loss = (
                        tf.keras.losses.sparse_categorical_crossentropy(
                            labels, member_probs))
                    metrics['test/nll_member_{}'.format(idx)].update_state(
                        member_loss)
                    metrics['test/accuracy_member_{}'.format(
                        idx)].update_state(labels, member_probs)

                tiled_labels = tf.tile(labels, [FLAGS.ensemble_size])
                metrics['test/gibbs_nll'].update_state(
                    tf.reduce_mean(
                        tf.keras.losses.sparse_categorical_crossentropy(
                            tiled_labels, logits, from_logits=True)))
                metrics['test/gibbs_accuracy'].update_state(
                    tiled_labels, all_probs)

            # Ensemble metrics, using the mean of the member probabilities.
            negative_log_likelihood = ensemble_crossentropy(
                labels, logits, FLAGS.ensemble_size)
            ensemble_probs = tf.reduce_mean(per_member_probs, axis=0)
            if is_clean:
                metrics['test/negative_log_likelihood'].update_state(
                    negative_log_likelihood)
                metrics['test/accuracy'].update_state(labels, ensemble_probs)
                metrics['test/ece'].update_state(labels, ensemble_probs)

                # Pairwise diversity between members, clean data only.
                stacked_probs = tf.stack(per_member_probs, axis=0)
                diversity_results = um.average_pairwise_diversity(
                    stacked_probs, FLAGS.ensemble_size)
                for key, value in diversity_results.items():
                    metrics['test/' + key].update_state(value)
            else:
                corrupt_metrics['test/nll_{}'.format(
                    dataset_name)].update_state(negative_log_likelihood)
                corrupt_metrics['test/accuracy_{}'.format(
                    dataset_name)].update_state(labels, ensemble_probs)
                corrupt_metrics['test/ece_{}'.format(
                    dataset_name)].update_state(labels, ensemble_probs)
Пример #25
0
    def call(self, inputs):
        """Applies the layer to `inputs` and registers the KL loss.

        Evaluates the posterior and prior on the cast inputs, adds
        `self._kl_divergence_fn(posterior, prior)` as a layer loss, converts
        the posterior to a tensor of flat weights, and uses those weights as
        the kernel (and bias, when `self.use_bias`) of a dense transform.

        Args:
          inputs: input tensor; cast to the layer dtype (or the Keras default
            float type when the layer dtype is unset).

        Returns:
          `inputs @ kernel`, plus the bias and passed through
          `self.activation` when those are enabled.
        """
        compute_dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx())
        inputs = tf.cast(inputs, compute_dtype, name='inputs')

        posterior = self._posterior(inputs)
        prior = self._prior(inputs)
        self.add_loss(self._kl_divergence_fn(posterior, prior))

        flat_weights = tf.convert_to_tensor(value=posterior)
        in_units = self.input_spec.axes[-1]
        if self.use_bias:
            # The flat vector holds the kernel entries first, then the bias.
            kernel, bias = tf.split(
                flat_weights, [in_units * self.units, self.units], axis=-1)
        else:
            kernel, bias = flat_weights, None

        # Keep any leading dimensions and unflatten the last one.
        leading_dims = tf.shape(input=kernel)[:-1]
        kernel_shape = tf.concat([leading_dims, [in_units, self.units]],
                                 axis=0)
        kernel = tf.reshape(kernel, shape=kernel_shape)
        outputs = tf.matmul(inputs, kernel)

        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, bias)

        if self.activation is not None:
            outputs = self.activation(outputs)  # pylint: disable=not-callable

        return outputs
  def _inverse(self, y):
    """Returns `log(y[..., :-1]) - log(y[..., -1:])`.

    When `self.validate_args` is set, the result is gated on runtime checks
    that the last dimension of `y` sums to one (within twice the dtype's
    machine epsilon) and that every element lies in `[0, 1]`.
    """
    # Derivation of the inverse mapping. Since
    #   y[i] = exp(x[i]) / normalization
    # and
    #   y[end] = 1 / normalization,
    # it follows that
    #   x[i] = log(exp(x[i])) - log(y[end]) - log(normalization)
    #        = log(exp(x[i])/normalization) - log(y[end])
    #        = log(y[i]) - log(y[end])

    # NOTE (carried over from the original): do this first to make sure CSE
    # catches that it'll happen again in _inverse_log_det_jacobian.

    checks = []
    if self.validate_args:
      tolerance = 2. * np.finfo(dtype_util.as_numpy_dtype(y.dtype)).eps
      checks = [
          assert_util.assert_near(
              tf.reduce_sum(y, axis=-1),
              tf.ones([], y.dtype),
              tolerance,
              message='Last dimension of `y` must sum to `1`.'),
          assert_util.assert_less_equal(
              y, tf.ones([], y.dtype),
              message='Elements of `y` must be less than or equal to `1`.'),
          assert_util.assert_non_negative(
              y, message='Elements of `y` must be non-negative.'),
      ]

    with tf.control_dependencies(checks):
      log_y = tf.math.log(y)
      # Everything but the last entry, and the last entry on its own.
      leading, last = tf.split(log_y, num_or_size_splits=[-1, 1], axis=-1)
    return leading - last
Пример #27
0
    def get_dist_and_mode(self, states):
        """Builds the tanh-squashed distribution and its mode for `states`.

        The trunk output is split in two along axis 1 into a mean and a raw
        log-std. The raw log-std is squashed with tanh and rescaled linearly
        into `[LOG_STD_MIN, LOG_STD_MAX]`.

        Args:
          states: A batch of states.

        Returns:
          A `(dist, mode)` tuple: `dist` is a standard normal transformed by
          an affine map (shift `mu`, diagonal scale `std`) followed by tanh,
          and `mode` is `tanh(mu)`.
        """
        trunk_out = self.trunk(states)
        mu, raw_log_std = tf.split(trunk_out, num_or_size_splits=2, axis=1)
        mode = tf.nn.tanh(mu)

        squashed = tf.nn.tanh(raw_log_std)
        assert LOG_STD_MAX > LOG_STD_MIN
        # Map tanh's (-1, 1) range onto (LOG_STD_MIN, LOG_STD_MAX).
        half_span = 0.5 * (LOG_STD_MAX - LOG_STD_MIN)
        log_std = LOG_STD_MIN + half_span * (squashed + 1)

        std = tf.exp(log_std)

        base = ds.Normal(loc=0., scale=1.)
        # Chain lists bijectors outermost-first: Affine is applied, then Tanh.
        squash = tfp.bijectors.Chain([
            tfp.bijectors.Tanh(),
            tfp.bijectors.Affine(shift=mu, scale_diag=std),
        ])
        dist = ds.TransformedDistribution(
            distribution=base,
            bijector=squash,
            event_shape=[mu.shape[-1]],
            batch_shape=[mu.shape[0]])
        return dist, mode
def box_list_scale(boxlist, y_scale, x_scale, scope=None):
    """Scales box coordinates along the y and x dimensions.

    The scale factors are cast to float32 before being applied. Extra fields
    of the input are carried over through `_copy_extra_fields`.

    Args:
      boxlist: BoxList holding N boxes
      y_scale: (float) scalar tensor
      x_scale: (float) scalar tensor
      scope: name scope.

    Returns:
      boxlist: BoxList holding N boxes
    """
    with tf.name_scope(scope, 'Scale'):
        y_factor = tf.cast(y_scale, tf.float32)
        x_factor = tf.cast(x_scale, tf.float32)
        y_min, x_min, y_max, x_max = tf.split(
            value=boxlist.get(), num_or_size_splits=4, axis=1)

        scaled_y_min = y_factor * y_min
        scaled_y_max = y_factor * y_max
        scaled_x_min = x_factor * x_min
        scaled_x_max = x_factor * x_max

        # Reassemble in the [y_min, x_min, y_max, x_max] column order.
        corners = [scaled_y_min, scaled_x_min, scaled_y_max, scaled_x_max]
        scaled_boxlist = box_list.BoxList(tf.concat(corners, 1))
        return _copy_extra_fields(scaled_boxlist, boxlist)
Пример #29
0
    def compress(self, bottleneck):
        """Compresses a floating-point tensor into bit strings.

        `bottleneck` is first quantized as in `quantize()`, then encoded using
        the probability tables derived from `self.prior`. The quantized tensor
        can be recovered later with `decompress()`.

        The innermost `self.coding_rank` dimensions form one coding unit and
        are compressed into a single string each. Any dimensions to their left
        are treated as batch dimensions.

        Arguments:
          bottleneck: `tf.Tensor` containing the data to be compressed. Must
            have at least `self.coding_rank` dimensions, and the innermost
            dimensions must be broadcastable to `self.prior_shape`.

        Returns:
          A `tf.Tensor` having the same shape as `bottleneck` without the
          `self.coding_rank` innermost dimensions, containing a string for
          each coding unit.
        """
        full_shape = tf.shape(bottleneck)
        rank = tf.shape(full_shape)[0]
        split_sizes = [rank - self.coding_rank, self.coding_rank]
        batch_shape, coding_shape = tf.split(full_shape, split_sizes)
        num_broadcast_dims = self.coding_rank - len(self.prior_shape)
        broadcast_shape = coding_shape[:num_broadcast_dims]

        indexes, offset = self._compute_indexes_and_offset(broadcast_shape)
        if offset is not None:
            bottleneck -= offset
        symbols = tf.cast(tf.round(bottleneck), tf.int32)
        # Collapse all batch dimensions into one so map_fn can iterate units.
        per_unit_shape = tf.concat([[-1], coding_shape], 0)
        symbols = tf.reshape(symbols, per_unit_shape)

        # Prevent tensors from bouncing back and forth between host and GPU.
        with tf.device("/cpu:0"):
            cdf = self.cdf
            cdf_length = self.cdf_length
            cdf_offset = self.cdf_offset

            def encode_unit(unit_symbols):
                return range_coding_ops.unbounded_index_range_encode(
                    unit_symbols,
                    indexes,
                    cdf,
                    cdf_length,
                    cdf_offset,
                    precision=self.range_coder_precision,
                    overflow_width=4,
                    debug_level=1)

            # TODO(jonycgn,ssjhv): Consider switching to Python control flow.
            encoded = tf.map_fn(encode_unit,
                                symbols,
                                dtype=tf.string,
                                name="compress")

        # Restore the original batch dimensions.
        return tf.reshape(encoded, batch_shape)
    def testStddev(self):
        """Checks `stddev`/`variance` through shift, scale and split bijectors.

        A normal with standard deviation `base_scale` pushed through
        `Shift` and `Scale` is expected to report `|base_scale * scale|` as
        its standard deviation. The same is expected piecewise after a
        `Split(3)`. A `ScaleMatvecTriL` bijector is expected to make
        `stddev()` raise `NotImplementedError`.
        """
        base_scale = 2.
        shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32)
        scale = np.array([[1, -2, 3], [2, -3, 2]], dtype=np.float32)
        want_stddev = tf.abs(base_scale * scale)

        # Elementwise affine transform of a scalar-event normal.
        normal = self._cls()(
            distribution=tfd.Normal(
                loc=tf.zeros_like(shift),
                scale=base_scale * tf.ones_like(scale),
                validate_args=True),
            bijector=tfb.Chain(
                [tfb.Shift(shift=shift), tfb.Scale(scale=scale)],
                validate_args=True),
            validate_args=True)
        self.assertAllClose(want_stddev, normal.stddev())
        self.assertAllClose(want_stddev**2, normal.variance())

        # Splitting the event into three parts splits the stddev likewise.
        split_normal = self._cls()(
            distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
            bijector=tfb.Split(3),
            validate_args=True)
        want_parts = tf.split(want_stddev, num_or_size_splits=3, axis=-1)
        self.assertAllCloseNested(want_parts, split_normal.stddev())

        # A multivariate bijector has no stddev implementation.
        matvec_normal = self._cls()(
            distribution=tfd.Independent(normal, reinterpreted_batch_ndims=1),
            bijector=tfb.ScaleMatvecTriL([[1., 0.], [-1., 2.]]),
            validate_args=True)
        with self.assertRaisesRegex(NotImplementedError,
                                    'is a multivariate transformation'):
            matvec_normal.stddev()