Example #1
0
  def _build(self, x, state):
    """Applies a dropout mask that is only partially resampled per call.

    Args:
      x: Tensor to which the mask is applied.
      state: The keep mask produced by the previous call. It is reduced over
        axis 1, so it is expected to have at least two dimensions.

    Returns:
      `x` multiplied by the updated keep mask, divided by `self._keep_prob`
      and multiplied by `self._scaler`. The updated mask is also stored in
      `self._keep_mask`, and `self._time_step` is incremented.
    """
    input_shape = tf.shape(x)
    # Freshly drawn candidate mask: floor(keep_prob + uniform noise).
    fresh_mask = tf.floor(
        self._keep_prob + tf.random_uniform(input_shape, dtype=x.dtype))
    # Elementwise decision on whether to swap in the fresh mask value.
    flip = tf.less(
        tf.random_uniform(input_shape, dtype=x.dtype), self._flip_prob)

    # KLUDGE(melisgl): Clients carry the last keep mask from one batch to the
    # next, so it tends to sit beside recurrent cell state. That state usually
    # starts out as zeros and may be zeroed again (per example) during
    # training. Zeroing LSTM state is harmless; zeroing the dropout mask is
    # not. Rather than make every client handle this, a row whose mask sums to
    # zero is treated as uninitialised and receives a fresh mask. This is a
    # hack and does not help with e.g. learnt initial states.
    row_total = tf.reduce_sum(state, axis=1, keepdims=True)
    needs_init = tf.equal(row_total, 0.0)

    use_fresh = tf.logical_or(flip, needs_init)
    self._keep_mask = tf.where(use_fresh, fresh_mask, state)
    self._time_step += 1

    masked = x * self._keep_mask
    return masked / self._keep_prob * self._scaler
Example #2
0
    def smooth_L1_loss(self, y_true, y_pred):
        '''
        Smooth L1 loss between targets and predictions (see References).

        Elements whose absolute error is below 1 contribute half the squared
        error; all other elements contribute the absolute error minus 0.5.
        The per-element values are then summed over the last axis.

        Arguments:
            y_true (nD tensor): Ground truth tensor of any shape. In this context
                it is expected to be `(batch_size, #boxes, 4)`, with the last axis
                holding the box coordinates `(xmin, xmax, ymin, ymax)`.
            y_pred (nD tensor): Predictions with the same structure as `y_true`,
                in this context the predicted box coordinates.

        Returns:
            A tensor with one dimension fewer than the inputs; in this context a
            2D tensor of shape (batch, n_boxes_total).

        References:
            https://arxiv.org/abs/1504.08083
        '''
        residual = y_true - y_pred
        magnitude = tf.abs(residual)
        quadratic_branch = 0.5 * residual**2
        linear_branch = magnitude - 0.5
        within_unit = tf.less(magnitude, 1.0)
        elementwise = tf.where(within_unit, quadratic_branch, linear_branch)
        return tf.reduce_sum(elementwise, axis=-1)
  def eager_decay_rate():
    """Builds the learning rate: linear warmup, then exponential decay.

    All inputs (`learning_rate_base`, `global_step`, `warmup_steps`,
    `warmup_learning_rate`, `learning_rate_decay_steps`,
    `learning_rate_decay_factor`, `staircase`, `min_learning_rate`) come from
    the enclosing scope, which is not visible here.

    Returns:
      A tensor named 'learning_rate'. While `global_step` is below
      `warmup_steps` it is the linearly interpolated warmup rate; afterwards
      it is the decayed rate, floored at `min_learning_rate`.

    Raises:
      ValueError: If `learning_rate_base` is smaller than
        `warmup_learning_rate`.
    """
    decayed_rate = tf.train.exponential_decay(
        learning_rate_base,
        global_step - warmup_steps,
        learning_rate_decay_steps,
        learning_rate_decay_factor,
        staircase=staircase)
    # exponential_decay may hand back a callable rather than a tensor
    # (presumably under eager execution); resolve it to a value.
    if callable(decayed_rate):
      decayed_rate = decayed_rate()

    if learning_rate_base < warmup_learning_rate:
      raise ValueError('learning_rate_base must be larger or equal to '
                       'warmup_learning_rate.')

    # Linear ramp from warmup_learning_rate up to learning_rate_base.
    ramp_per_step = (learning_rate_base - warmup_learning_rate) / warmup_steps
    step_as_float = tf.cast(global_step, tf.float32)
    ramp_rate = ramp_per_step * step_as_float + warmup_learning_rate

    still_warming_up = tf.less(
        tf.cast(global_step, tf.int32), tf.constant(warmup_steps))
    floored_decay = tf.maximum(decayed_rate, min_learning_rate)
    return tf.where(
        still_warming_up, ramp_rate, floored_decay, name='learning_rate')
Example #4
0
    def train(self, features, labels):
        """Builds the training op, choosing the image size via tf.cond.

        The predicate compares entries 3 and 4 of the first row of
        `features['image_info']`. When entry 3 is smaller, gradients are built
        with `self.params['image_size']`; otherwise with that size reversed.

        Args:
            features: Mapping that contains at least the key 'image_info'.
            labels: Passed through unchanged to `self.train_op`.

        Returns:
            The op returned by `optimizer.apply_gradients`, created under a
            control dependency on the UPDATE_OPS collection.
        """
        with tf.variable_scope('', reuse=tf.AUTO_REUSE):
            step = tf.train.get_or_create_global_step()
            optimizer = self.get_optimizer(self.get_learning_rate(step))

            def grads_for(image_size):
                return self.train_op(features, labels, image_size, optimizer)

            first_info = features['image_info'][0]
            configured_size = self.params['image_size']
            grads = tf.cond(
                tf.less(first_info[3], first_info[4]),
                lambda: grads_for(configured_size),
                lambda: grads_for(configured_size[::-1]))

            # Collected after tf.cond so variables created inside the
            # branches are included.
            trainable = tf.trainable_variables()
            variables = self.remove_variables(trainable,
                                              self.params['resnet_depth'])

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                grads_and_vars = zip(grads, variables)
                return optimizer.apply_gradients(grads_and_vars,
                                                 global_step=step)
Example #5
0
    def compute_gradients(self,
                          loss,
                          var_list,
                          gate_gradients=GATE_OP,
                          aggregation_method=None,
                          colocate_gradients_with_ops=False,
                          grad_loss=None,
                          gradient_tape=None):
      """Computes per-microbatch gradients and combines them via the DP query.

      The loss is reshaped to `[num_microbatches, -1]`. For every microbatch
      the mean loss is differentiated and the gradients are handed to
      `self._dp_sum_query.accumulate_record`. The accumulated state is then
      passed through `self._dp_sum_query.get_noised_result` and each resulting
      sum is divided by the number of microbatches.

      Two code paths exist, selected by whether `loss` is callable:
        * callable `loss` (treated as eager mode): `loss()` is evaluated and
          gradients are taken with `gradient_tape`.
        * non-callable `loss` (treated as graph mode): gradients are taken
          with the parent class's `compute_gradients`.

      Args:
        loss: Either a callable returning a vector loss, or a vector loss
          tensor. Its first dimension is used as the number of microbatches
          when `self._num_microbatches` is None.
        var_list: Variables to differentiate with respect to. In the graph
          path, None means all trainable variables plus the
          TRAINABLE_RESOURCE_VARIABLES collection.
        gate_gradients: Forwarded to the parent `compute_gradients` (graph
          path only).
        aggregation_method: Forwarded to the parent `compute_gradients` (graph
          path only).
        colocate_gradients_with_ops: Forwarded to the parent
          `compute_gradients` (graph path only).
        grad_loss: Forwarded to the parent `compute_gradients` (graph path
          only).
        gradient_tape: Tape used to take gradients; required when `loss` is
          callable and must be omitted otherwise.

      Returns:
        A list of `(gradient, variable)` pairs.

      Raises:
        ValueError: If `loss` is callable but no tape was passed, or if `loss`
          is not callable but a tape was passed.
      """
      if callable(loss):
        # TF is running in Eager mode, check we received a vanilla tape.
        if not gradient_tape:
          raise ValueError('When in Eager mode, a tape needs to be passed.')

        vector_loss = loss()
        # Default: one microbatch per entry along the loss's first dimension.
        if self._num_microbatches is None:
          self._num_microbatches = tf.shape(input=vector_loss)[0]
        # NOTE(review): unlike the graph path below, var_list is not defaulted
        # when it is None here -- confirm callers always pass it in eager mode.
        sample_state = self._dp_sum_query.initial_sample_state(var_list)
        microbatches_losses = tf.reshape(vector_loss,
                                         [self._num_microbatches, -1])
        sample_params = (
            self._dp_sum_query.derive_sample_params(self._global_state))

        def process_microbatch(i, sample_state):
          """Process one microbatch (record) with privacy helper."""
          microbatch_loss = tf.reduce_mean(
              input_tensor=tf.gather(microbatches_losses, [i]))
          # NOTE(review): the tape is queried once per microbatch, and None
          # gradients are not replaced with zeros here as they are in the
          # graph path -- confirm both are intended.
          grads = gradient_tape.gradient(microbatch_loss, var_list)
          sample_state = self._dp_sum_query.accumulate_record(
              sample_params, sample_state, grads)
          return sample_state

        # Plain Python loop: every microbatch is processed in turn.
        for idx in range(self._num_microbatches):
          sample_state = process_microbatch(idx, sample_state)

        # The query returns the combined gradients and its new global state.
        grad_sums, self._global_state = (
            self._dp_sum_query.get_noised_result(
                sample_state, self._global_state))

        def normalize(v):
          # Turn the sum over microbatches into an average.
          return v / tf.cast(self._num_microbatches, tf.float32)

        final_grads = tf.nest.map_structure(normalize, grad_sums)

        grads_and_vars = list(zip(final_grads, var_list))
        return grads_and_vars

      else:
        # TF is running in graph mode, check we did not receive a gradient tape.
        if gradient_tape:
          raise ValueError('When in graph mode, a tape should not be passed.')

        # Note: it would be closer to the correct i.i.d. sampling of records if
        # we sampled each microbatch from the appropriate binomial distribution,
        # although that still wouldn't be quite correct because it would be
        # sampling from the dataset without replacement.
        if self._num_microbatches is None:
          self._num_microbatches = tf.shape(input=loss)[0]

        microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1])
        sample_params = (
            self._dp_sum_query.derive_sample_params(self._global_state))

        def process_microbatch(i, sample_state):
          """Process one microbatch (record) with privacy helper."""
          # `cls` is not defined in this method; presumably it is bound by an
          # enclosing scope outside this view -- confirm.
          grads, _ = zip(*super(cls, self).compute_gradients(
              tf.reduce_mean(input_tensor=tf.gather(
                  microbatches_losses, [i])), var_list, gate_gradients,
              aggregation_method, colocate_gradients_with_ops, grad_loss))
          # Variables with no gradient contribute zeros of matching shape.
          grads_list = [
              g if g is not None else tf.zeros_like(v)
              for (g, v) in zip(list(grads), var_list)
          ]
          sample_state = self._dp_sum_query.accumulate_record(
              sample_params, sample_state, grads_list)
          return sample_state

        # process_microbatch reads var_list when it is called, so this default
        # is in effect by the time the closure runs below.
        if var_list is None:
          var_list = (
              tf.compat.v1.trainable_variables() + tf.compat.v1.get_collection(
                  tf.compat.v1.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))

        sample_state = self._dp_sum_query.initial_sample_state(var_list)

        if self._unroll_microbatches:
          # Unrolled: one set of gradient ops is built per microbatch.
          for idx in range(self._num_microbatches):
            sample_state = process_microbatch(idx, sample_state)
        else:
          # Use of while_loop here requires that sample_state be a nested
          # structure of tensors. In general, we would prefer to allow it to be
          # an arbitrary opaque type.
          cond_fn = lambda i, _: tf.less(i, self._num_microbatches)
          body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)]  # pylint: disable=line-too-long
          idx = tf.constant(0)
          _, sample_state = tf.while_loop(
              cond=cond_fn, body=body_fn, loop_vars=[idx, sample_state])

        # The query returns the combined gradients and its new global state.
        grad_sums, self._global_state = (
            self._dp_sum_query.get_noised_result(
                sample_state, self._global_state))

        def normalize(v):
          # Turn the sum over microbatches into an average.
          return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32))

        final_grads = tf.nest.map_structure(normalize, grad_sums)

        return list(zip(final_grads, var_list))
Example #6
0
def assign_and_sample_proposals(proposed_boxes,
                                gt_boxes,
                                gt_classes,
                                gt_attributes,
                                num_samples_per_image=512,
                                mix_gt_boxes=True,
                                fg_fraction=0.25,
                                fg_iou_thresh=0.5,
                                bg_iou_thresh_hi=0.5,
                                bg_iou_thresh_lo=0.0):
    """Assigns the proposals with groundtruth classes and performs subsampling.

    Given `proposed_boxes`, `gt_boxes`, `gt_classes` and `gt_attributes`, the
    function uses the following algorithm to generate the final
    `num_samples_per_image` RoIs.
      1. Calculates the IoU between each proposal box and each gt_boxes.
      2. Assigns each proposed box with a groundtruth class and box by choosing
         the largest IoU overlap.
      3. Samples `num_samples_per_image` boxes from all proposed boxes, and
         returns box_targets, class_targets, and RoIs.

    Args:
      proposed_boxes: a tensor of shape of [batch_size, N, 4]. N is the number
        of proposals before groundtruth assignment. The last dimension is the
        box coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax]
        format.
      gt_boxes: a tensor of shape of [batch_size, MAX_NUM_INSTANCES, 4].
        The coordinates of gt_boxes are in the pixel coordinates of the scaled
        image. This tensor might have padding of values -1 indicating the
        invalid box coordinates.
      gt_classes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES].
        This tensor might have paddings with values of -1 indicating the
        invalid classes.
      gt_attributes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES,
        num_attributes]. This tensor might have paddings with values of -1
        indicating the invalid attributes.
      num_samples_per_image: an integer represents RoI minibatch size per
        image.
      mix_gt_boxes: a bool indicating whether to mix the groundtruth boxes
        before sampling proposals.
      fg_fraction: a float represents the target fraction of RoI minibatch
        that is labeled foreground (i.e., class > 0).
      fg_iou_thresh: a float represents the IoU overlap threshold for an RoI
        to be considered foreground (if >= fg_iou_thresh).
      bg_iou_thresh_hi: a float represents the IoU overlap threshold for an
        RoI to be considered background (class = 0 if overlap in [LO, HI)).
      bg_iou_thresh_lo: a float represents the IoU overlap threshold for an
        RoI to be considered background (class = 0 if overlap in [LO, HI)).

    Returns:
      sampled_rois: a tensor of shape of [batch_size, K, 4], representing the
        coordinates of the sampled RoIs, where K is the number of the sampled
        RoIs, i.e. K = num_samples_per_image.
      sampled_gt_boxes: a tensor of shape of [batch_size, K, 4], storing the
        box coordinates of the matched groundtruth boxes of the samples RoIs.
      sampled_gt_classes: a tensor of shape of [batch_size, K], storing the
        classes of the matched groundtruth boxes of the sampled RoIs.
      sampled_gt_attributes: a tensor of shape of [batch_size, K,
        num_attributes], storing the attributes of the matched groundtruth
        attributes of the sampled RoIs.
      sampled_gt_indices: a tensor of shape of [batch_size, K], storing the
        indices of the sampled groundtruth boxes in the original `gt_boxes`
        tensor, i.e. gt_boxes[sampled_gt_indices[:, i]] = sampled_gt_boxes[:, i].
    """

    with tf.name_scope('sample_proposals'):
        if mix_gt_boxes:
            boxes = tf.concat([proposed_boxes, gt_boxes], axis=1)
        else:
            boxes = proposed_boxes

        (matched_gt_boxes, matched_gt_classes, matched_gt_attributes,
         matched_gt_indices, matched_iou,
         _) = box_matching(boxes, gt_boxes, gt_classes, gt_attributes)

        # Foreground is IoU >= fg_iou_thresh, as documented above. A strict
        # comparison would leave boxes at exactly the threshold in neither
        # set when fg_iou_thresh == bg_iou_thresh_hi (the default), so they
        # could never be sampled.
        positive_match = tf.greater_equal(matched_iou, fg_iou_thresh)
        negative_match = tf.logical_and(
            tf.greater_equal(matched_iou, bg_iou_thresh_lo),
            tf.less(matched_iou, bg_iou_thresh_hi))
        # A negative IoU marks a match that must not be sampled at all.
        ignored_match = tf.less(matched_iou, 0.0)

        # re-assign negatively matched boxes to the background class.
        matched_gt_classes = tf.where(negative_match,
                                      tf.zeros_like(matched_gt_classes),
                                      matched_gt_classes)
        matched_gt_indices = tf.where(negative_match,
                                      tf.zeros_like(matched_gt_indices),
                                      matched_gt_indices)

        sample_candidates = tf.logical_and(
            tf.logical_or(positive_match, negative_match),
            tf.logical_not(ignored_match))

        sampler = (
            balanced_positive_negative_sampler.BalancedPositiveNegativeSampler(
                positive_fraction=fg_fraction, is_static=True))

        # The batch dimension must be statically known: sampling is unrolled
        # per image in Python.
        batch_size, _ = sample_candidates.get_shape().as_list()
        sampled_indicators = tf.stack([
            sampler.subsample(sample_candidates[i], num_samples_per_image,
                              positive_match[i]) for i in range(batch_size)
        ])
        # top_k over the 0/1 indicators yields the positions of the sampled
        # boxes.
        _, sampled_indices = tf.nn.top_k(tf.cast(sampled_indicators,
                                                 dtype=tf.int32),
                                         k=num_samples_per_image,
                                         sorted=True)

        # Pair every sampled index with its batch index for gather_nd.
        sampled_indices_shape = tf.shape(sampled_indices)
        batch_indices = (
            tf.expand_dims(tf.range(sampled_indices_shape[0]), axis=-1) *
            tf.ones([1, sampled_indices_shape[-1]], dtype=tf.int32))
        gather_nd_indices = tf.stack([batch_indices, sampled_indices], axis=-1)

        sampled_rois = tf.gather_nd(boxes, gather_nd_indices)
        sampled_gt_boxes = tf.gather_nd(matched_gt_boxes, gather_nd_indices)
        sampled_gt_classes = tf.gather_nd(matched_gt_classes,
                                          gather_nd_indices)
        sampled_gt_attributes = tf.gather_nd(matched_gt_attributes,
                                             gather_nd_indices)
        sampled_gt_indices = tf.gather_nd(matched_gt_indices,
                                          gather_nd_indices)

        return (sampled_rois, sampled_gt_boxes, sampled_gt_classes,
                sampled_gt_attributes, sampled_gt_indices)
Example #7
0
def main(unused_argv=None):
    """Builds the training graph and runs training via slim.learning.train.

    Reads its configuration from FLAGS (`log`, `config`, `train_path`,
    `logdir`, `total_batch_size`, `worker_replicas`, `ps_tasks`, `task`,
    `master`).

    Args:
        unused_argv: Ignored.

    Raises:
        RuntimeError: If no config name was given in FLAGS.config.
    """
    tf.logging.set_verbosity(FLAGS.log)

    if FLAGS.config is None:
        raise RuntimeError("No config name specified.")

    config = utils.get_module("wavenet." + FLAGS.config).Config(
        FLAGS.train_path)

    logdir = FLAGS.logdir
    tf.logging.info("Saving to %s" % logdir)

    with tf.Graph().as_default():
        total_batch_size = FLAGS.total_batch_size
        assert total_batch_size % FLAGS.worker_replicas == 0
        # Floor division keeps the per-worker batch size an int on Python 3
        # as well; the assert above guarantees the division is exact.
        worker_batch_size = total_batch_size // FLAGS.worker_replicas

        # Run the Reader on the CPU
        cpu_device = "/job:localhost/replica:0/task:0/cpu:0"
        if FLAGS.ps_tasks:
            cpu_device = "/job:worker/cpu:0"

        with tf.device(cpu_device):
            inputs_dict = config.get_batch(worker_batch_size)

        with tf.device(
                tf.train.replica_device_setter(ps_tasks=FLAGS.ps_tasks,
                                               merge_devices=True)):
            global_step = tf.get_variable(
                "global_step", [],
                tf.int32,
                initializer=tf.constant_initializer(0),
                trainable=False)

            # Piecewise-constant schedule built as a chain of tf.cond ops:
            # each boundary keeps the rate chosen so far while global_step is
            # below it and switches to its own value otherwise. This is only
            # correct when boundaries are visited in ascending order, hence
            # the sort. `.items()` (not `.iteritems()`) also works on
            # Python 3.
            # pylint: disable=cell-var-from-loop
            lr = tf.constant(config.learning_rate_schedule[0])
            for key, value in sorted(config.learning_rate_schedule.items()):
                lr = tf.cond(tf.less(global_step, key), lambda: lr,
                             lambda: tf.constant(value))
            # pylint: enable=cell-var-from-loop
            tf.summary.scalar("learning_rate", lr)

            # build the model graph
            outputs_dict = config.build(inputs_dict, is_training=True)
            loss = outputs_dict["loss"]
            tf.summary.scalar("train_loss", loss)

            worker_replicas = FLAGS.worker_replicas
            ema = tf.train.ExponentialMovingAverage(decay=0.9999,
                                                    num_updates=global_step)
            opt = tf.train.SyncReplicasOptimizer(
                tf.train.AdamOptimizer(lr, epsilon=1e-8),
                worker_replicas,
                total_num_replicas=worker_replicas,
                variable_averages=ema,
                variables_to_average=tf.trainable_variables())

            train_op = opt.minimize(loss,
                                    global_step=global_step,
                                    name="train",
                                    colocate_gradients_with_ops=True)

            session_config = tf.ConfigProto(allow_soft_placement=True)

            # Task 0 is the chief and uses the chief init op; every other
            # task uses the local-step init op.
            is_chief = (FLAGS.task == 0)
            local_init_op = opt.chief_init_op if is_chief else opt.local_step_init_op

            slim.learning.train(
                train_op=train_op,
                logdir=logdir,
                is_chief=is_chief,
                master=FLAGS.master,
                number_of_steps=config.num_iters,
                global_step=global_step,
                log_every_n_steps=250,
                local_init_op=local_init_op,
                save_interval_secs=300,
                sync_optimizer=opt,
                session_config=session_config,
            )
Example #8
0
    def inject_latent(self, layer, inputs, target, action):
        """Inject a deterministic latent based on the target frame.

        Outside training the latent bits are either drawn at random (when
        `hparams.full_latent_tower` is set) or predicted from `layer` by an
        LSTM. During training the bits come from a discrete bottleneck over an
        embedding of the input and target frames.

        Args:
            layer: Tensor into which the latent is injected; its last
                dimension sets the width of the projected bits.
            inputs: List of input frame tensors; concatenated with `target`
                along the last axis.
            target: Target frame tensor.
            action: Optional action tensor; when not None it is injected into
                the frame embedding.

        Returns:
            A pair `(res, pred_loss)`. `res` is `layer` with the latent
            applied. `pred_loss` is the LSTM bit-prediction loss when training
            without the full latent tower, and 0.0 in every other case.
        """
        hparams = self.hparams
        layer_shape = common_layers.shape_list(layer)
        final_filters = layer_shape[-1]
        filters = hparams.hidden_size
        kernel = (4, 4)
        activation_fn = common_layers.belu
        if hparams.activation_fn == "relu":
            activation_fn = tf.nn.relu

        def add_bits(layer, bits):
            """Project `bits` to the layer width and merge them into `layer`."""
            z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
            if not hparams.complex_addn:
                return layer + z_mul
            layer *= tf.nn.sigmoid(z_mul)
            z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
            layer += z_add
            return layer

        if not self.is_training:
            if hparams.full_latent_tower:
                # Uniformly random bits in {-1, +1}.
                rand = tf.random_uniform(layer_shape[:-1] +
                                         [hparams.bottleneck_bits])
                bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
            else:
                bits, _ = discretization.predict_bits_with_lstm(
                    layer,
                    hparams.latent_predictor_state_size,
                    hparams.bottleneck_bits,
                    temperature=hparams.latent_predictor_temperature)
                bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
            return add_bits(layer, bits), 0.0

        # Embed.
        frames = tf.concat(inputs + [target], axis=-1)
        x = tfl.dense(
            frames,
            filters,
            name="latent_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        # Add embedded action if present.
        if action is not None:
            x = common_video.inject_additional_input(x, action,
                                                     "action_enc_latent",
                                                     hparams.action_injection)

        if hparams.full_latent_tower:
            for i in range(hparams.num_compress_steps):
                with tf.variable_scope("latent_downstride%d" % i):
                    x = common_layers.make_even_size(x)
                    if i < hparams.filter_double_steps:
                        filters *= 2
                    x = common_attention.add_timing_signal_nd(x)
                    x = tfl.conv2d(x,
                                   filters,
                                   kernel,
                                   activation=activation_fn,
                                   strides=(2, 2),
                                   padding="SAME")
                    x = common_layers.layer_norm(x)
        else:
            x = common_layers.double_discriminator(x)
            x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

        bits, bits_clean = discretization.tanh_discrete_bottleneck(
            x, hparams.bottleneck_bits, hparams.bottleneck_noise,
            hparams.discretize_warmup_steps, hparams.mode)
        # Only the LSTM-predictor path produces a prediction loss. Default to
        # 0.0 (the same value the inference path returns) so that training
        # with full_latent_tower does not hit an unbound `pred_loss` below.
        pred_loss = 0.0
        if not hparams.full_latent_tower:
            _, pred_loss = discretization.predict_bits_with_lstm(
                layer,
                hparams.latent_predictor_state_size,
                hparams.bottleneck_bits,
                target_bits=bits_clean)
            # Mix bits from latent with predicted bits on forward pass as a noise.
            if hparams.latent_rnn_max_sampling > 0.0:
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    bits_pred, _ = discretization.predict_bits_with_lstm(
                        layer,
                        hparams.latent_predictor_state_size,
                        hparams.bottleneck_bits,
                        temperature=hparams.latent_predictor_temperature)
                    bits_pred = tf.expand_dims(tf.expand_dims(bits_pred,
                                                              axis=1),
                                               axis=2)
                # Be bits_pred on the forward pass but bits on the backward one.
                bits_pred = bits_clean + tf.stop_gradient(bits_pred -
                                                          bits_clean)
                # Select which bits to take from pred sampling with bit_p probability.
                which_bit = tf.random_uniform(common_layers.shape_list(bits))
                bit_p = common_layers.inverse_lin_decay(
                    hparams.latent_rnn_warmup_steps)
                bit_p *= hparams.latent_rnn_max_sampling
                bits = tf.where(which_bit < bit_p, bits_pred, bits)

        res = add_bits(layer, bits)
        # During training, sometimes skip the latent to help action-conditioning.
        res_p = common_layers.inverse_lin_decay(
            hparams.latent_rnn_warmup_steps / 2)
        res_p *= hparams.latent_use_max_probability
        res_rand = tf.random_uniform([layer_shape[0]])
        res = tf.where(res_rand < res_p, res, layer)
        return res, pred_loss
Example #9
0
 def clip_boxes(self, boxes):
   """Clip boxes to fit in an image.

   Negative entries are replaced with 0. Entries greater than
   `self._output_size[0] - 1` are replaced with `self._output_size[1] - 1`.

   NOTE(review): the upper bound is tested against element 0 of
   `_output_size` but the replacement uses element 1. The two only agree when
   both elements are equal -- confirm whether per-axis clipping was intended.

   Args:
     boxes: Tensor of box coordinates.

   Returns:
     A tensor shaped like `boxes` holding the clipped coordinates.
   """
   below_zero = tf.less(boxes, 0)
   lower_clipped = tf.where(below_zero, tf.zeros_like(boxes), boxes)

   above_limit = tf.greater(lower_clipped, self._output_size[0] - 1)
   ceiling = (self._output_size[1] - 1) * tf.ones_like(lower_clipped)
   return tf.where(above_limit, ceiling, lower_clipped)
Example #10
0
def mask(config: configure_pretraining.PretrainingConfig,
         inputs: pretrain_data.Inputs, mask_prob, proposal_distribution=1.0,
         disallow_from_mask=None, already_masked=None):
  """Dynamically masks out tokens of `inputs`.

  The optional arguments are not needed for BERT/ELECTRA; they come from early
  experiments with "strategically" choosing tokens to mask instead of picking
  them uniformly at random.

  Args:
    config: configure_pretraining.PretrainingConfig
    inputs: pretrain_data.Inputs containing input input_ids/input_mask
    mask_prob: percent of tokens to mask
    proposal_distribution: for non-uniform masking can be a [B, L] tensor
                           of scores for masking each position.
    disallow_from_mask: a boolean tensor of [B, L] of positions that should
                        not be masked out
    already_masked: a boolean tensor of [B, N] of already masked-out tokens
                    for multiple rounds of masking
  Returns: a pretrain_data.Inputs with masking added
  """
  # Max masked-out tokens per example, batch size and sequence length.
  max_predictions = config.max_predictions_per_seq
  batch_size, seq_length = modeling.get_shape_list(inputs.input_ids)

  # Positions that are eligible for masking.
  tokenizer = tokenization.FullTokenizer(
      config.vocab_file, do_lower_case=config.do_lower_case)
  vocab = tokenizer.vocab
  candidates_mask = _get_candidates_mask(inputs, vocab, disallow_from_mask)

  # Per-example number of tokens to mask: round(num_tokens * mask_prob),
  # capped at max_predictions and never below 1.
  token_count = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32)
  rounded_target = tf.cast(tf.round(token_count * mask_prob), tf.int32)
  num_to_predict = tf.maximum(1, tf.minimum(max_predictions, rounded_target))
  masked_lm_weights = tf.cast(
      tf.sequence_mask(num_to_predict, max_predictions), tf.float32)
  if already_masked is not None:
    masked_lm_weights = masked_lm_weights * (1 - already_masked)

  # Normalised sampling probability over the eligible positions.
  eligible = tf.cast(candidates_mask, tf.float32)
  sample_prob = proposal_distribution * eligible
  sample_prob = sample_prob / tf.reduce_sum(
      sample_prob, axis=-1, keepdims=True)

  # Draw the positions to mask; no gradient flows through the proposal.
  sample_logits = tf.log(tf.stop_gradient(sample_prob))
  drawn_positions = tf.random.categorical(
      sample_logits, max_predictions, dtype=tf.int32)
  # Slots beyond num_to_predict have weight 0 and are zeroed out.
  masked_lm_positions = drawn_positions * tf.cast(masked_lm_weights, tf.int32)

  # Look up the original ids at the chosen positions via flat indexing.
  row_offsets = tf.expand_dims(seq_length * tf.range(batch_size), -1)
  flat_positions = tf.reshape(masked_lm_positions + row_offsets, [-1, 1])
  flat_input_ids = tf.reshape(inputs.input_ids, [-1])
  gathered_ids = tf.gather_nd(flat_input_ids, flat_positions)
  masked_lm_ids = tf.reshape(gathered_ids, [batch_size, -1])
  masked_lm_ids = masked_lm_ids * tf.cast(masked_lm_weights, tf.int32)

  # Each chosen position is kept when a uniform draw falls below 0.85 and is
  # zeroed otherwise, before being handed to scatter_update with [MASK].
  keep_draw = tf.less(tf.random.uniform([batch_size, max_predictions]), 0.85)
  replace_with_mask_positions = masked_lm_positions * tf.cast(
      keep_draw, tf.int32)
  mask_token_ids = tf.fill([batch_size, max_predictions], vocab["[MASK]"])
  inputs_ids, _ = scatter_update(
      inputs.input_ids, mask_token_ids, replace_with_mask_positions)

  return pretrain_data.get_updated_inputs(
      inputs,
      input_ids=tf.stop_gradient(inputs_ids),
      masked_lm_positions=masked_lm_positions,
      masked_lm_ids=masked_lm_ids,
      masked_lm_weights=masked_lm_weights
  )
Example #11
0
def resize_and_crop_image_v2(image,
                             short_side,
                             long_side,
                             padded_size,
                             aug_scale_min=1.0,
                             aug_scale_max=1.0,
                             seed=1,
                             method=tf.image.ResizeMethod.BILINEAR):
    """Resizes the input image to output size (Faster R-CNN style).

    Resize and pad images given the specified short / long side length and the
    stride size.

    Here are the preprocessing steps.
    1. For a given image, keep its aspect ratio and first try to rescale the
       short side of the original image to `short_side`.
    2. If the scaled image after 1 has a long side that exceeds `long_side`,
       keep the aspect ratio and rescale the long side of the image to
       `long_side`.
    3. Pad the rescaled image to the padded_size.

    When `aug_scale_min` or `aug_scale_max` differs from 1.0, the size from
    steps 1-2 is additionally multiplied by a random scale, and a window of
    the un-jittered size is cropped at a random offset before padding.

    Args:
      image: a `Tensor` of shape [height, width, 3] representing an image.
      short_side: a scalar `Tensor` or `int` representing the desired short
        side to be rescaled to.
      long_side: a scalar `Tensor` or `int` representing the desired long side
        to be rescaled to.
      padded_size: a `Tensor` or `int` list/tuple of two elements representing
        [height, width] of the padded output image size. Padding will be
        applied after scaling the image to the desired_size.
      aug_scale_min: a `float` with range between [0, 1.0] representing minimum
        random scale applied to desired_size for training scale jittering.
      aug_scale_max: a `float` with range between [1.0, inf] representing
        maximum random scale applied to desired_size for training scale
        jittering.
      seed: seed for random scale jittering.
      method: function to resize input image to scaled image.

    Returns:
      output_image: `Tensor` of shape [height, width, 3] where [height, width]
        equals to `padded_size`.
      image_info: a 2D `Tensor` that encodes the information of the image and
        the applied preprocessing. It is in the format of
        [[original_height, original_width], [scaled_height, scaled_width],
         [y_scale, x_scale], [y_offset, x_offset]], where [scaled_height,
        scaled_width] is the actual scaled image size, and [y_scale, x_scale]
        is the scaling factor, which is the ratio of
        scaled dimension / original dimension.
    """
    with tf.name_scope('resize_and_crop_image_v2'):
        image_size = tf.cast(tf.shape(image)[0:2], tf.float32)

        scale_using_short_side = (short_side /
                                  tf.minimum(image_size[0], image_size[1]))
        scale_using_long_side = (long_side /
                                 tf.maximum(image_size[0], image_size[1]))

        # Prefer the short-side scale; fall back to the long-side scale when
        # the result would exceed `long_side`.
        scaled_size = tf.round(image_size * scale_using_short_side)
        scaled_size = tf.where(
            tf.greater(tf.maximum(scaled_size[0], scaled_size[1]), long_side),
            tf.round(image_size * scale_using_long_side), scaled_size)
        desired_size = scaled_size

        random_jittering = (aug_scale_min != 1.0 or aug_scale_max != 1.0)

        if random_jittering:
            random_scale = tf.random_uniform([],
                                             aug_scale_min,
                                             aug_scale_max,
                                             seed=seed)
            scaled_size = tf.round(random_scale * scaled_size)

        # Computes 2D image_scale.
        image_scale = scaled_size / image_size

        # Selects non-zero random offset (x, y) if scaled image is larger than
        # desired_size.
        if random_jittering:
            max_offset = scaled_size - desired_size
            max_offset = tf.where(tf.less(max_offset, 0),
                                  tf.zeros_like(max_offset), max_offset)
            offset = max_offset * tf.random_uniform([
                2,
            ], 0, 1, seed=seed)
            offset = tf.cast(offset, tf.int32)
        else:
            offset = tf.zeros((2, ), tf.int32)

        scaled_image = tf.image.resize_images(image,
                                              tf.cast(scaled_size, tf.int32),
                                              method=method)

        if random_jittering:
            # `desired_size` is float32 while `offset` is int32; the crop
            # extent must be cast to int32 before it can be added to the
            # offset and used as a slice bound.
            crop_size = tf.cast(desired_size, tf.int32)
            scaled_image = scaled_image[offset[0]:offset[0] + crop_size[0],
                                        offset[1]:offset[1] +
                                        crop_size[1], :]

        output_image = tf.image.pad_to_bounding_box(scaled_image, 0, 0,
                                                    padded_size[0],
                                                    padded_size[1])

        image_info = tf.stack([
            image_size, scaled_size, image_scale,
            tf.cast(offset, tf.float32)
        ])
        return output_image, image_info
Example #12
0
def main():
    """Build the energy-model graph, restore/broadcast weights, then train and test.

    Configuration is read from the module-level ``FLAGS``; Horovod (``hvd``)
    supplies rank and size information.  The steps, in order, are:

    1. Create the log directory and TensorBoard logger on rank 0 only.
    2. Create the dataset objects, input placeholders and model selected by
       ``FLAGS.dataset``.
    3. For each of the ``FLAGS.num_gpus`` input splits, build the
       positive-energy pass, a ``tf.while_loop`` that repeatedly perturbs the
       noise input and moves it along the energy gradient (or an HMC
       proposal), and, when ``FLAGS.train`` is set, the loss and its gradients.
    4. Create the session, optionally restore a checkpoint on rank 0,
       broadcast variables from rank 0, then call ``train`` (only when
       ``FLAGS.train``) followed by ``test`` (always).
    """
    print("Local rank: ", hvd.local_rank(), hvd.size())

    logdir = osp.join(FLAGS.logdir, FLAGS.exp)
    # Only rank 0 creates the log directory and a logger; every other rank
    # passes logger=None to train/test.
    if hvd.rank() == 0:
        if not osp.exists(logdir):
            os.makedirs(logdir)
        logger = TensorBoardOutputFormat(logdir)
    else:
        logger = None

    # Select dataset objects, placeholders and model by FLAGS.dataset.
    # NOTE(review): a value matching none of the branches leaves X, X_NOISE,
    # LABEL_POS and model undefined, so the function fails later with a
    # NameError rather than a clear message.
    LABEL = None
    print("Loading data...")
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(augment=FLAGS.augment, rescale=FLAGS.rescale)
        test_dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3

        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        if FLAGS.large_model:
            model = ResNet32Large(num_channels=channel_num,
                                  num_filters=128,
                                  train=True)
        elif FLAGS.larger_model:
            model = ResNet32Larger(num_channels=channel_num, num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_channels=channel_num, num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)

    elif FLAGS.dataset == 'imagenet':
        dataset = Imagenet(train=True)
        test_dataset = Imagenet(train=False)
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet32Wider(num_channels=channel_num, num_filters=256)

    elif FLAGS.dataset == 'imagenetfull':
        # No `dataset`/`test_dataset` is created here; this case is served by
        # TFImagenetLoader when the data loader is built below.
        channel_num = 3
        X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 1000), dtype=tf.float32)

        model = ResNet128(num_channels=channel_num, num_filters=64)

    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(rescale=FLAGS.rescale)
        test_dataset = dataset
        channel_num = 1
        X_NOISE = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 28, 28), dtype=tf.float32)
        LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
        LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)

        model = MnistNet(num_channels=channel_num,
                         num_filters=FLAGS.num_filters)

    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites(cond_shape=FLAGS.cond_shape,
                           cond_size=FLAGS.cond_size,
                           cond_pos=FLAGS.cond_pos,
                           cond_rot=FLAGS.cond_rot)
        test_dataset = dataset
        channel_num = 1

        X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
        X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)

        # The label width depends on which conditioning flag is set; the
        # first matching flag in this chain wins.
        if FLAGS.dpos_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.dsize_only:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.drot_only:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_size:
            LABEL = tf.placeholder(shape=(None, 1), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 1), dtype=tf.float32)
        elif FLAGS.cond_shape:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)
        elif FLAGS.cond_pos:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        elif FLAGS.cond_rot:
            LABEL = tf.placeholder(shape=(None, 2), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
        else:
            LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
            LABEL_POS = tf.placeholder(shape=(None, 3), dtype=tf.float32)

        model = DspritesNet(num_channels=channel_num,
                            num_filters=FLAGS.num_filters,
                            cond_size=FLAGS.cond_size,
                            cond_shape=FLAGS.cond_shape,
                            cond_pos=FLAGS.cond_pos,
                            cond_rot=FLAGS.cond_rot)

    print("Done loading...")

    if FLAGS.dataset == "imagenetfull":
        # In the case of full imagenet, use custom_tensorflow dataloader
        data_loader = TFImagenetLoader('train',
                                       FLAGS.batch_size,
                                       hvd.rank(),
                                       hvd.size(),
                                       rescale=FLAGS.rescale)
    else:
        data_loader = DataLoader(dataset,
                                 batch_size=FLAGS.batch_size,
                                 num_workers=FLAGS.data_workers,
                                 drop_last=True,
                                 shuffle=True)

    batch_size = FLAGS.batch_size

    weights = [model.construct_weights('context_0')]

    # `(None)` is plain None (not a 1-tuple), so Y's shape is left unspecified.
    Y = tf.placeholder(shape=(None), dtype=tf.int32)

    # Variables to run in training
    # Each input is split into FLAGS.num_gpus parts along tf.split's default
    # axis (0); loop index j below selects one part.
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    # Copy of the label splits taken before the loop may overwrite entries of
    # LABEL_SPLIT; it is not read again in this function.
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_gen_grads = []
    x_mod_list = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)

    for j in range(FLAGS.num_gpus):

        if FLAGS.model_cclass:
            # Evaluate every image under each of 10 one-hot labels, then draw
            # one label per image from the resulting energies.
            # NOTE(review): label_tensor has FLAGS.batch_size * 10 rows while
            # x_split has ind_batch_size * 10 rows; the two only agree when
            # FLAGS.num_gpus == 1. The hard-coded 10 and (32, 32, 3) also tie
            # this branch to 10-class 32x32 RGB inputs - confirm intent.
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            label_tensor = tf.Variable(tf.convert_to_tensor(np.reshape(
                np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
                (FLAGS.batch_size * 10, 10)),
                                                            dtype=tf.float32),
                                       trainable=False,
                                       dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)),
                (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(x_split,
                                       weights[0],
                                       label=label_tensor,
                                       stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(energy_pos_full,
                                                       axis=1,
                                                       keepdims=True)
            # Perturb the negated energies with -log(-log(u)) noise and take
            # the per-row argmax to pick one class per example.  The
            # logsumexp term is constant within a row, so it does not change
            # which column wins.
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            label_tensor = tf.argmax(-energy_pos_full -
                                     tf.log(-tf.log(uniform)) -
                                     energy_partition_est,
                                     axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            label = tf.Print(label, [label_tensor, energy_pos_full])
            # The sampled labels replace the fed labels for this split in
            # everything built below.
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            # Energy of the real data under its positive labels.
            energy_pos = [
                model.forward(X_SPLIT[j],
                              weights[0],
                              label=LABEL_POS_SPLIT[j],
                              stop_at_grad=False)
            ]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        # The sampler starts from the noise input of this split.
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        # Energy of the starting point; exposed below as 'energy_start'.
        energy_negs.extend([
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True)
        ])
        eps_begin = tf.zeros(1)

        # Loop state is (step counter, current sample); the loop runs while
        # the counter is below FLAGS.num_steps.
        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

        def langevin_step(counter, x_mod):
            """Run one sampler step and return the updated (counter, x_mod).

            Adds Gaussian noise to x_mod, then either takes a gradient step
            on the temperature-scaled energy or, when FLAGS.hmc is set, uses
            the result of ``hmc``; the sample is finally clipped to
            [0, FLAGS.rescale].
            """
            x_mod = x_mod + tf.random_normal(
                tf.shape(x_mod),
                mean=0.0,
                stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat([
                model.forward(x_mod,
                              weights[0],
                              label=LABEL_SPLIT[j],
                              reuse=True,
                              stop_at_grad=False,
                              stop_batch=True)
            ],
                                                    axis=0)

            # label_grad, energy_start and energy_noise_old are computed but
            # not used afterwards in this function.
            x_grad, label_grad = tf.gradients(FLAGS.temperature * energy_noise,
                                              [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            # Optional gradient clipping: 'l2' clips by norm, 'li' clips each
            # element to [-proj_norm, proj_norm]; anything else aborts.
            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm,
                                              FLAGS.proj_norm)
                else:
                    print("Other types of projection are not supported!!!")
                    assert False

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            # Keep the sample inside the valid value range.
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        # Gradient of the energy at the final sample; only exposed through
        # target_vars['x_grad'] / ['x_grad_first'] below.
        energy_eval = model.forward(x_mod,
                                    weights[0],
                                    label=LABEL_SPLIT[j],
                                    stop_at_grad=False,
                                    reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        # Energy of the final sample with the sample itself held constant
        # (tf.stop_gradient), used as the negative term of the loss.
        energy_negs.append(
            model.forward(tf.stop_gradient(x_mod),
                          weights[0],
                          label=LABEL_SPLIT[j],
                          stop_at_grad=False,
                          reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        # Mean absolute difference between the samples and the real inputs.
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        # Extra energy term on the (non-stopped) samples; added to the loss
        # unless FLAGS.zero_kl is set.
        # NOTE(review): this passes the full LABEL placeholder rather than
        # LABEL_SPLIT[j]; the two coincide only when FLAGS.num_gpus == 1 and
        # model_cclass is off - confirm.
        loss_energy = model.forward(x_mod,
                                    weights[0],
                                    reuse=True,
                                    label=LABEL,
                                    stop_grad=True)

        print("Finished processing loop construction ...")

        # Re-created on every iteration, so after the loop it holds the
        # tensors of the last split only.
        target_vars = {}

        # Entropy of the label distribution, computed from the first split.
        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(
                label_prob * tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:

            # NOTE(review): an objective other than the three below leaves
            # loss_ml undefined and raises NameError at loss_total.
            if FLAGS.objective == 'logsumexp':
                # Negative term weighted by exp(-temp * energy) coefficients
                # that are treated as constants (tf.stop_gradient).
                pos_term = temp * energy_pos
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                # Mean positive energy minus mean negative energy.
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            # Penalty on the squared magnitude of both energies.
            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            # Drop variables for which no gradient was produced.
            gvs = [(k, v) for (k, v) in gvs if k is not None]

            # Despite the messages, gradients are only collected here; they
            # are averaged and applied once, after the loop.
            print("Applying gradients...")

            tower_grads.append(gvs)

            print("Finished applying gradients.")

            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin

    # Combine the per-split gradients and build the single training op.
    if FLAGS.train:
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    # With several Horovod processes, restrict each one to the GPU whose
    # index equals its local rank.
    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)

    saver = loader = tf.train.Saver(max_to_keep=30,
                                    keep_checkpoint_every_n_hours=6)

    # Count trainable parameters as the product of each variable's static
    # dimensions, summed over all trainable variables.
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    # Only rank 0 restores a checkpoint (when resuming, or whenever not
    # training); the broadcast below then copies rank 0's variables to the
    # other ranks.
    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    sess.run(hvd.broadcast_global_variables(0))
    print("Initializing variables...")

    # These two messages are printed after the broadcast above has already
    # run.
    print("Start broadcast")
    print("End broadcast")

    if FLAGS.train:
        print("Training phase")
        train(target_vars, saver, sess, logger, data_loader, resume_itr,
              logdir)
    # Testing runs unconditionally, after training when FLAGS.train is set.
    print("Testing phase")
    test(target_vars, saver, sess, logger, data_loader)
Example #13
0
def random_image_rotation(image, masks, boxes, keypoints, max_angle=45, probability=0.9):
    """
    With the given probability, rotate, rescale and re-centre a whole example
    around one randomly chosen box; otherwise return the inputs unchanged.

    When the augmentation fires:
    1. A random box is picked and everything is rotated about its center.
    2. The picture is rescaled (the factor comes from `get_random_scaling`).
    3. The picture is shifted so that the box center lands on the image center.

    Coordinates are absolute:
    1. Boxes live in [0, height] x [0, width].
    2. Keypoints live in [0, height - 1] x [0, width - 1].

    Arguments:
        image: a float tensor with shape [height, width, 3].
        masks: a float tensor with shape [height / DOWNSAMPLE, width / DOWNSAMPLE, 2].
        boxes: a float tensor with shape [num_persons, 4].
        keypoints: an int tensor with shape [num_persons, 17, 3].
        max_angle: an integer.
        probability: a float number, the chance of applying the augmentation.
    Returns:
        image: a float tensor with shape [height, width, 3].
        masks: a float tensor with shape [height / DOWNSAMPLE, width / DOWNSAMPLE, 2].
        boxes: a float tensor with shape [num_remaining_boxes, 4],
            where num_remaining_boxes <= num_persons.
        keypoints: an int tensor with shape [num_remaining_boxes, 17, 3].
    """
    def _rotate_about_random_box(img, msk, bxs, kps):
        # Image size as floats, and the image center as a [1, 2] row vector.
        size = tf.to_float(tf.shape(img))
        height, width = size[0], size[1]
        center = tf.reshape(0.5 * tf.stack([height, width]), [1, 2])

        pivot, pivot_box_width = get_random_box_center(bxs, height, width)
        linear_map = get_random_rotation(max_angle, pivot, width)
        zoom = get_random_scaling(pivot, pivot_box_width, width)

        # Fold the scaling into the rotation matrix.
        linear_map = linear_map * zoom

        # For `points` of shape [n, 2] the intended chain is:
        #   1. points - pivot        (move the origin to the box center)
        #   2. ... * linear_map      (rotate and scale about that origin)
        #   3. ... + pivot           (move the origin back)
        #   4. ... - center_shift    (bring the box center to the image center)
        # which collapses to
        #   points * linear_map + shift,
        # with shift = center - pivot * linear_map.
        shift = center - tf.matmul(pivot, linear_map)

        # Geometry first; some boxes/keypoints may fall outside the image ...
        moved_boxes = transform_boxes(bxs, linear_map, shift)
        moved_keypoints = transform_keypoints(kps, linear_map, shift)
        # ... so `correct` is applied to bring them back inside.
        moved_boxes, moved_keypoints = correct(
            moved_boxes, moved_keypoints, height, width)

        # Pixels are resampled with the inverse mapping.
        inverse = get_inverse_transform(linear_map, shift)
        warped_image = contrib.image.transform(
            img, inverse, interpolation='BILINEAR')

        # Masks are DOWNSAMPLE times smaller than the image, so the two
        # translation entries of the 8-vector are divided by DOWNSAMPLE.
        divisors = tf.stack([1, 1, DOWNSAMPLE, 1, 1, DOWNSAMPLE, 1, 1])
        inverse_for_masks = inverse / tf.to_float(divisors)

        # Masks are binary, hence nearest-neighbour interpolation.
        warped_masks = contrib.image.transform(
            msk, inverse_for_masks, interpolation='NEAREST')

        return warped_image, warped_masks, moved_boxes, moved_keypoints

    should_rotate = tf.less(tf.random_uniform([]), probability)
    augmented = tf.cond(
        should_rotate,
        lambda: _rotate_about_random_box(image, masks, boxes, keypoints),
        lambda: (image, masks, boxes, keypoints)
    )
    out_image, out_masks, out_boxes, out_keypoints = augmented
    return out_image, out_masks, out_boxes, out_keypoints
Example #14
0
def trilerp_gather(vol, inds, bad_inds=None):
  """Gathers trilinearly interpolated samples of `vol` at the points in `inds`.

  Args:
    vol: volume tensor indexed as [batch, y, x, z, channel].
    inds: query tensor whose last axis holds (batch, x, y, z) as floats; the
      four leading axes give the layout of the result.
    bad_inds: optional boolean tensor, same layout as `inds[..., 0]`, marking
      extra queries whose output must be zeroed.

  Returns:
    Tensor of interpolated values with a trailing channel axis. Queries whose
    2x2x2 neighbourhood leaves the volume (or that are flagged by `bad_inds`)
    yield zeros.
  """
  dims = tf.shape(vol)

  def _axis(coord, size):
    """Returns (out_of_range, (low, high) indices, (low, high) weights)."""
    low = tf.floor(coord)
    high = low + 1
    last = tf.to_float(size - 1)
    # Remember which queries step outside before the indices are clamped.
    out_of_range = tf.logical_or(tf.less(low, 0.0), tf.greater(high, last))
    low = tf.clip_by_value(low, 0.0, tf.to_float(size - 2))
    high = tf.clip_by_value(high, 0.0, last)
    weight_low = 1.0 - (coord - low)
    weight_high = 1.0 - weight_low
    return out_of_range, (low, high), (weight_low, weight_high)

  batch = inds[Ellipsis, 0]
  bad_x, x_idx, x_w = _axis(inds[Ellipsis, 1], dims[2])
  bad_y, y_idx, y_w = _axis(inds[Ellipsis, 2], dims[1])
  bad_z, z_idx, z_w = _axis(inds[Ellipsis, 3], dims[3])

  invalid = tf.logical_or(tf.logical_or(bad_x, bad_y), bad_z)
  if bad_inds is not None:
    invalid = tf.logical_or(invalid, bad_inds)

  # Cube corners as (y, x, z) low/high selectors, visited in a fixed order so
  # the partial sums are accumulated in the same sequence every time.
  corners = (
      (0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1),
      (1, 1, 0), (0, 1, 1), (1, 0, 1), (1, 1, 1),
  )

  out_vol = None
  for sel_y, sel_x, sel_z in corners:
    weight = y_w[sel_y] * x_w[sel_x] * z_w[sel_z]
    corner = tf.to_int32(
        tf.stack([batch, y_idx[sel_y], x_idx[sel_x], z_idx[sel_z]], axis=-1))
    contribution = tf.gather_nd(vol, corner) * weight[Ellipsis, tf.newaxis]
    out_vol = contribution if out_vol is None else out_vol + contribution

  # Expand the mask over channels and blank every invalid query.
  invalid = tf.tile(invalid[:, :, :, :, tf.newaxis], [1, 1, 1, 1, dims[4]])
  return tf.where(invalid, tf.zeros_like(out_vol), out_vol)
Example #15
0
def bilerp_gather(img, inds):
  """Gathers bilinearly interpolated samples of `img` at the points in `inds`.

  Args:
    img: image tensor indexed as [batch, y, x, channel].
    inds: query tensor whose last axis holds (x, y) as floats. The batch index
      of every query is taken from a grid built over the first three axes of
      `img`, so the leading axes of `inds` are expected to match them.

  Returns:
    Tensor of interpolated values with a trailing channel axis. Queries whose
    2x2 neighbourhood leaves the image yield zeros.
  """
  dims = tf.shape(img)

  # Batch index for every (batch, row, column) position, as floats so it can
  # be stacked with the floating-point coordinates.
  batch_grid = tf.meshgrid(
      tf.range(dims[0]), tf.range(dims[1]), tf.range(dims[2]),
      indexing='ij')[0]
  batch = tf.to_float(batch_grid)

  def _axis(coord, size):
    """Returns (out_of_range, (low, high) indices, (low, high) weights)."""
    low = tf.floor(coord)
    high = low + 1
    last = tf.to_float(size - 1)
    # Remember which queries step outside before the indices are clamped.
    out_of_range = tf.logical_or(tf.less(low, 0.0), tf.greater(high, last))
    low = tf.clip_by_value(low, 0.0, tf.to_float(size - 2))
    high = tf.clip_by_value(high, 0.0, last)
    weight_low = 1.0 - (coord - low)
    weight_high = 1.0 - weight_low
    return out_of_range, (low, high), (weight_low, weight_high)

  bad_x, x_idx, x_w = _axis(inds[Ellipsis, 0], dims[2])
  bad_y, y_idx, y_w = _axis(inds[Ellipsis, 1], dims[1])
  invalid = tf.logical_or(bad_x, bad_y)

  # Square corners as (y, x) low/high selectors, visited in a fixed order so
  # the partial sums are accumulated in the same sequence every time.
  corners = ((0, 0), (1, 0), (0, 1), (1, 1))

  out_img = None
  for sel_y, sel_x in corners:
    weight = y_w[sel_y] * x_w[sel_x]
    corner = tf.to_int32(
        tf.stack([batch, y_idx[sel_y], x_idx[sel_x]], axis=-1))
    contribution = tf.gather_nd(img, corner) * weight[Ellipsis, tf.newaxis]
    out_img = contribution if out_img is None else out_img + contribution

  # Expand the mask over channels and blank every invalid query.
  invalid = tf.tile(invalid[:, :, :, tf.newaxis], [1, 1, 1, dims[3]])
  return tf.where(invalid, tf.zeros_like(out_img), out_img)
Example #16
0
def _generate_detections_per_image(boxes,
                                   scores,
                                   max_total_size=100,
                                   nms_iou_threshold=0.3,
                                   score_threshold=0.05,
                                   pre_nms_num_boxes=5000):
    """Run per-class NMS for one image and keep the best detections overall.

    Args:
      boxes: tensor of shape [N, num_classes, 4] or [N, 1, 4] with the box
        predictions for all N anchors. When the second dimension is smaller
        than the number of score classes, its last entry is reused for the
        remaining classes.
      scores: tensor of shape [N, num_classes] with one score per anchor and
        class.
      max_total_size: maximum number of boxes kept, both per class by NMS and
        in the final merged result.
      nms_iou_threshold: IOU threshold passed to NMS.
      score_threshold: score threshold passed to NMS.
      pre_nms_num_boxes: number of top-scoring candidates per class that are
        handed to NMS.

    Returns:
      nmsed_boxes: tensor of shape [max_total_size, 4] with the selected boxes.
      nmsed_scores: tensor of shape [max_total_size] with their scores in
        descending order; padding entries carry a score of -1.
      nmsed_classes: tensor of shape [max_total_size] with the class index of
        each selected box.
      valid_detections: scalar int32 tensor counting the entries whose score
        is greater than -1.
    """
    box_class_count = boxes.get_shape().as_list()[1]
    score_class_count = scores.get_shape().as_list()[1]

    # One (boxes, scores, classes) triple per class, each padded to
    # max_total_size rows.
    per_class = []
    for class_id in range(score_class_count):
        class_boxes = boxes[:, min(box_class_count - 1, class_id)]
        class_scores = scores[:, class_id]

        # Keep only the strongest candidates before the (costly) NMS step.
        candidate_count = tf.minimum(
            tf.shape(class_scores)[-1], pre_nms_num_boxes)
        class_scores, top_indices = tf.nn.top_k(class_scores,
                                                k=candidate_count)
        class_boxes = tf.gather(class_boxes, top_indices)

        kept_indices, kept_count = tf.image.non_max_suppression_padded(
            tf.cast(class_boxes, tf.float32),
            tf.cast(class_scores, tf.float32),
            max_total_size,
            iou_threshold=nms_iou_threshold,
            score_threshold=score_threshold,
            pad_to_max_output_size=True,
            name='nms_detections_' + str(class_id))
        kept_boxes = tf.gather(class_boxes, kept_indices)
        kept_scores = tf.gather(class_scores, kept_indices)

        # Padding rows get a score of -1 so they sort last and can be counted
        # out afterwards.
        is_real = tf.less(tf.range(max_total_size), [kept_count])
        kept_scores = tf.where(is_real, kept_scores,
                               -tf.ones_like(kept_scores))

        per_class.append(
            (kept_boxes, kept_scores, tf.fill([max_total_size], class_id)))

    # Merge all classes, then keep the max_total_size best-scoring rows.
    all_boxes = tf.concat([entry[0] for entry in per_class], axis=0)
    all_scores = tf.concat([entry[1] for entry in per_class], axis=0)
    all_classes = tf.concat([entry[2] for entry in per_class], axis=0)

    nmsed_scores, order = tf.nn.top_k(all_scores,
                                      k=max_total_size,
                                      sorted=True)
    nmsed_boxes = tf.gather(all_boxes, order)
    nmsed_classes = tf.gather(all_classes, order)

    is_valid = tf.greater(nmsed_scores, -1)
    valid_detections = tf.reduce_sum(tf.cast(is_valid, tf.int32))
    return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
def generate_detections_per_image_tpu(cls_outputs,
                                      box_outputs,
                                      anchor_boxes,
                                      image_info,
                                      pre_nms_num_detections=1000,
                                      post_nms_num_detections=100,
                                      nms_threshold=0.3,
                                      bbox_reg_weights=(10., 10., 5., 5.)):
    """Generate the final detections per image given the model outputs.

    The top `pre_nms_num_detections` (anchor, foreground class) pairs are
    selected by score, their boxes are decoded and clipped to the image, NMS is
    run independently for every foreground class, and the per-class results
    are merged and sorted by score.

    Args:
      cls_outputs: a tensor with shape [N, num_classes], which stacks class
        outputs on all feature levels. The N is the number of total anchors
        on all levels. The num_classes is the number of classes predicted by
        the model. Note that the cls_outputs should be the output of
        softmax(); column 0 is treated as the background class and dropped.
      box_outputs: a tensor with shape [N, num_classes*4], which stacks box
        regression outputs on all feature levels. The N is the number of total
        anchors on all levels.
      anchor_boxes: a tensor with shape [N, 4], which stacks anchors on all
        feature levels. The N is the number of total anchors on all levels.
      image_info: a tensor of shape [5] which encodes the input image's
        [height, width, scale, original_height, original_width]. Only the
        first two entries are read here (for clipping).
      pre_nms_num_detections: an integer that specifies the number of
        candidates before NMS.
      post_nms_num_detections: an integer that specifies the number of
        candidates after NMS.
      nms_threshold: a float number to specify the IOU threshold of NMS.
      bbox_reg_weights: a list of 4 float scalars, which are default weights
        on (dx, dy, dw, dh) for normalizing bbox regression targets.

    Returns:
      A tuple `(num_valid_boxes, post_nms_boxes, box_classes, post_nms_scores)`:
        num_valid_boxes: float scalar, the number of returned entries whose
          score is greater than 0.
        post_nms_boxes: the `post_nms_num_detections` selected boxes.
        box_classes: float class ids of the selected boxes; ids are 1-based
          because the background class occupies id 0.
        post_nms_scores: float scores sorted in descending order; padding
          entries have score 0.
    """
    num_boxes, num_classes = cls_outputs.get_shape().as_list()
    # Column 0 is the background class, so only `num_classes - 1` classes can
    # ever be detected.
    num_foreground_classes = num_classes - 1

    # Remove background class scores.
    cls_outputs = cls_outputs[:, 1:num_classes]
    # Flatten to a single [N * num_foreground_classes] score vector so one
    # top_k picks the best (anchor, class) pairs across all classes.
    top_k_scores, top_k_indices_with_classes = tf.nn.top_k(
        tf.reshape(cls_outputs, [-1]), k=pre_nms_num_detections, sorted=False)
    # Recover the (anchor index, 0-based foreground class) of every candidate
    # from its position in the flattened vector.
    classes = tf.mod(top_k_indices_with_classes, num_foreground_classes)
    top_k_indices = tf.floordiv(top_k_indices_with_classes,
                                num_foreground_classes)

    anchor_boxes = tf.gather(anchor_boxes, top_k_indices)
    box_outputs = tf.reshape(box_outputs,
                             [num_boxes, num_classes, 4])[:, 1:num_classes, :]
    class_indices = classes
    box_outputs = tf.gather_nd(
        box_outputs, tf.stack([top_k_indices, class_indices], axis=1))

    # apply bounding box regression to anchors
    boxes = box_utils.decode_boxes(box_outputs, anchor_boxes, bbox_reg_weights)
    boxes = box_utils.clip_boxes(boxes, image_info[0], image_info[1])

    list_of_all_boxes = []
    list_of_all_scores = []
    list_of_all_classes = []
    # Skip background class: `classes` only takes values in
    # [0, num_foreground_classes), so iterating up to `num_classes` would add
    # one NMS op that can never match a candidate and would emit padding
    # labelled with the non-existent class id `num_classes`.
    for class_i in range(num_foreground_classes):
        # Compute bitmask for the given classes.
        class_i_bitmask = tf.cast(tf.equal(classes, class_i),
                                  top_k_scores.dtype)
        # This works because score is in [0, 1].
        class_i_scores = top_k_scores * class_i_bitmask
        # The TPU and CPU have different behaviors for
        # tf.image.non_max_suppression_padded (b/116754376).
        (class_i_post_nms_indices,
         class_i_nms_num_valid) = tf.image.non_max_suppression_padded(
             tf.to_float(boxes),
             tf.to_float(class_i_scores),
             post_nms_num_detections,
             iou_threshold=nms_threshold,
             score_threshold=0.05,
             pad_to_max_output_size=True,
             name='nms_detections_' + str(class_i))
        class_i_post_nms_boxes = tf.gather(boxes, class_i_post_nms_indices)
        class_i_post_nms_scores = tf.gather(class_i_scores,
                                            class_i_post_nms_indices)
        # Zero the scores of the padded (invalid) slots so they sort last.
        mask = tf.less(tf.range(post_nms_num_detections),
                       [class_i_nms_num_valid])
        class_i_post_nms_scores = tf.where(
            mask, class_i_post_nms_scores,
            tf.zeros_like(class_i_post_nms_scores))
        # Shift back to 1-based ids: id 0 is the removed background class.
        class_i_classes = tf.fill(tf.shape(class_i_post_nms_scores),
                                  class_i + 1)
        list_of_all_boxes.append(class_i_post_nms_boxes)
        list_of_all_scores.append(class_i_post_nms_scores)
        list_of_all_classes.append(class_i_classes)

    post_nms_boxes = tf.concat(list_of_all_boxes, axis=0)
    post_nms_scores = tf.concat(list_of_all_scores, axis=0)
    post_nms_classes = tf.concat(list_of_all_classes, axis=0)

    # sort all results.
    post_nms_scores, sorted_indices = tf.nn.top_k(tf.to_float(post_nms_scores),
                                                  k=post_nms_num_detections,
                                                  sorted=True)
    post_nms_boxes = tf.gather(post_nms_boxes, sorted_indices)
    post_nms_classes = tf.gather(post_nms_classes, sorted_indices)

    # Padding entries were given score 0 above, so a positive score marks a
    # real detection.
    valid_mask = tf.where(tf.greater(post_nms_scores, 0),
                          tf.ones_like(post_nms_scores),
                          tf.zeros_like(post_nms_scores))
    num_valid_boxes = tf.reduce_sum(valid_mask, axis=-1)
    box_classes = tf.to_float(post_nms_classes)
    return num_valid_boxes, post_nms_boxes, box_classes, post_nms_scores
Example #18
0
def prepare_encoder_input(features,
                          hparams,
                          embed_scope=None,
                          embed_token_fn=common_embed.embed_tokens):
    """Prepares the input for the screen encoder.

    Embeds the text tokens, type, clickability and position of every screen
    object and combines the embeddings selected by
    `hparams.screen_embedding_feature` into one embedding per object.

    Args:
      features: the feature dict. Reads "obj_text" (asserted rank 4),
        "obj_type" and "obj_clickable" (asserted rank 3) and "obj_screen_pos"
        (asserted rank 4); additionally "objects" when
        `synthetic_screen_noise` > 0 and "obj_dom_pos" when DOM features are
        requested. NOTE: `features["obj_text"]` is overwritten in the dict
        when `synthetic_screen_noise` > 0.
      hparams: the hyperparameter.
      embed_scope: the embedding variable scope.
      embed_token_fn: the function for embedding tokens.
    Returns:
      object_embed: a Tensor of shape
          [batch_size, num_steps, max_object_count, embed_depth]
      object_mask: a float32 tensor of shape
          [batch_size, num_steps, max_object_count], 0. where
          `features["obj_type"]` is -1 and 1. elsewhere.
      att_bias: a Tensor of shape
          [batch_size, num_steps, max_object_count] holding a large negative
          value where `object_mask` is 0 and 0 elsewhere.
    Raises:
      ValueError: if `hparams.obj_text_aggregation` is not one of "max",
        "sum" or "mean".
    """
    with tf.control_dependencies(
        [tf.assert_equal(tf.rank(features["obj_text"]), 4)]):
        # Optional data augmentation: replace a random subset of text tokens
        # with random token ids.
        if hparams.get("synthetic_screen_noise", 0.) > 0.:
            num_objects = tf.shape(features["obj_text"])[2]
            # [batch, length, num_objects]
            target_obj_mask = tf.cast(
                tf.one_hot(features["objects"], depth=num_objects), tf.bool)
            num_tokens = tf.shape(features["obj_text"])[-1]
            # Broadcast the per-object mask over the token dimension.
            target_obj_mask = tf.tile(tf.expand_dims(target_obj_mask, 3),
                                      [1, 1, 1, num_tokens])
            # Randomly keep tokens
            keep_mask = tf.greater_equal(
                tf.random_uniform(shape=tf.shape(features["obj_text"])),
                hparams.synthetic_screen_noise)
            # Keep paddings
            keep_mask = tf.logical_or(tf.equal(features["obj_text"], 0),
                                      keep_mask)
            # Keep targets
            target_obj_mask = tf.logical_or(target_obj_mask, keep_mask)
            # NOTE(review): replacement ids are drawn from [0, 50000) rather
            # than from hparams.task_vocab_size -- confirm this matches the
            # vocabulary in use.
            features["obj_text"] = tf.where(
                target_obj_mask, features["obj_text"],
                tf.random_uniform(shape=tf.shape(features["obj_text"]),
                                  maxval=50000,
                                  dtype=tf.int32))
        text_embeddings, _ = embed_token_fn(features["obj_text"],
                                            hparams.task_vocab_size,
                                            hparams.hidden_size,
                                            hparams,
                                            embed_scope=embed_scope)
        # Collapse the token dimension so each object gets one text embedding.
        with tf.variable_scope("obj_text_embed", reuse=tf.AUTO_REUSE):
            if hparams.obj_text_aggregation == "max":
                # Push token ids < 2 far below any real value so they never
                # win the max.
                embed_bias = tf.cast(tf.less(features["obj_text"], 2),
                                     tf.float32) * -1e7
                with tf.control_dependencies(
                    [tf.assert_equal(tf.rank(embed_bias), 4)]):
                    text_embeddings = tf.reduce_max(
                        text_embeddings + tf.expand_dims(embed_bias, 4), -2)
                    # Learned floor: the element-wise maximum below replaces
                    # any component smaller than this variable (in particular
                    # the -1e7 left behind when every token was biased out).
                    no_txt_embed = tf.get_variable(name="no_txt_embed",
                                                   shape=[hparams.hidden_size])
                    shape = common_layers.shape_list(text_embeddings)
                    no_txt_embed = tf.tile(
                        tf.reshape(no_txt_embed,
                                   [1, 1, 1, hparams.hidden_size]),
                        [shape[0], shape[1], shape[2], 1])
                    text_embeddings = tf.maximum(text_embeddings, no_txt_embed)
            elif hparams.obj_text_aggregation == "sum":
                # [batch, step, #max_obj, #max_token]  0 for padded tokens
                real_objects = tf.cast(
                    tf.greater_equal(features["obj_text"], 2), tf.float32)
                # [batch, step, #max_obj, hidden]   0s for padded objects
                text_embeddings = tf.reduce_sum(
                    text_embeddings * tf.expand_dims(real_objects, 4), -2)
            elif hparams.obj_text_aggregation == "mean":
                shape_list = common_layers.shape_list(text_embeddings)
                # Merge the three leading dimensions so every object is one
                # row of tokens.
                embeddings = tf.reshape(text_embeddings, [-1] + shape_list[3:])
                # A token counts as padding when its embedding is all zeros.
                emb_sum = tf.reduce_sum(tf.abs(embeddings), axis=-1)
                non_paddings = tf.not_equal(emb_sum, 0.0)
                embeddings = common_embed.average_bag_of_embeds(
                    embeddings,
                    non_paddings,
                    use_bigrams=True,
                    bigram_embed_scope=embed_scope,
                    append_start_end=True)
                # Restore the three leading dimensions.
                text_embeddings = tf.reshape(
                    embeddings, shape_list[:3] + [hparams.hidden_size])
            else:
                raise ValueError("Unrecognized token aggregation %s" %
                                 (hparams.obj_text_aggregation))
    with tf.control_dependencies([
            tf.assert_equal(tf.rank(features["obj_type"]), 3),
            tf.assert_equal(tf.rank(features["obj_clickable"]), 3)
    ]):
        with tf.variable_scope("encode_object_attr", reuse=tf.AUTO_REUSE):
            # Negative type ids (-1 marks a padded object, see object_mask
            # below) are clamped to 0 so the lookup stays in range.
            type_embedding = tf.nn.embedding_lookup(params=tf.get_variable(
                name="embed_type_w",
                shape=[hparams.get("num_types", 100), hparams.hidden_size]),
                                                    ids=tf.maximum(
                                                        features["obj_type"],
                                                        0))
            clickable_embedding = tf.nn.embedding_lookup(
                params=tf.get_variable(name="embed_clickable_w",
                                       shape=[2, hparams.hidden_size]),
                ids=features["obj_clickable"])
    with tf.control_dependencies(
        [tf.assert_equal(tf.rank(features["obj_screen_pos"]), 4)]):

        def _create_embed(feature_name, vocab_size, depth):
            """Embed a position feature.

            Each slice along the last axis of `features[feature_name]` gets
            its own [vocab_size, depth] embedding table; the looked-up
            embeddings are summed.
            """
            pos_embedding_list = []
            with tf.variable_scope("encode_object_" + feature_name,
                                   reuse=tf.AUTO_REUSE):
                num_featues = common_layers.shape_list(
                    features[feature_name])[-1]
                for i in range(num_featues):
                    pos_embedding_list.append(
                        tf.nn.embedding_lookup(
                            params=tf.get_variable(name=feature_name +
                                                   "_embed_w_%d" % i,
                                                   shape=[vocab_size, depth]),
                            ids=features[feature_name][:, :, :, i]))
                pos_embedding = tf.add_n(pos_embedding_list)
                return pos_embedding

        pos_embedding = _create_embed("obj_screen_pos", hparams.max_pixel_pos,
                                      hparams.hidden_size)
    # The DOM position embedding is only built when it will be used; this
    # call runs outside the control-dependency block above.
    if "all" == hparams.screen_embedding_feature or (
            "dom" in hparams.screen_embedding_feature):
        dom_embedding = _create_embed("obj_dom_pos", hparams.max_dom_pos,
                                      hparams.hidden_size)
    object_embed = tf.zeros_like(text_embeddings, dtype=tf.float32)
    # NOTE(review): "all" sums text, type, pos and dom (clickable is not
    # included). Otherwise this is an elif chain, so only the FIRST matching
    # feature name is added even if `screen_embedding_feature` names several
    # -- confirm this is intended.
    if hparams.screen_embedding_feature == "all":
        object_embed = (text_embeddings + type_embedding + pos_embedding +
                        dom_embedding)
    elif "text" in hparams.screen_embedding_feature:
        object_embed += text_embeddings
    elif "type" in hparams.screen_embedding_feature:
        object_embed += type_embedding
    elif "pos" in hparams.screen_embedding_feature:
        object_embed += pos_embedding
    elif "dom" in hparams.screen_embedding_feature:
        object_embed += dom_embedding
    elif "click" in hparams.screen_embedding_feature:
        object_embed += clickable_embedding
    # Objects with type -1 are treated as padding: zero their embedding and
    # give them a large negative attention bias.
    object_mask = tf.cast(tf.not_equal(features["obj_type"], -1), tf.float32)
    object_embed = object_embed * tf.expand_dims(object_mask, 3)
    att_bias = (1. - object_mask) * common_attention.large_compatible_negative(
        object_embed.dtype)
    return object_embed, object_mask, att_bias
Example #19
0
 def loop_cond(idx, _):
     """Loop predicate: true while `idx` is below `inner_steps`."""
     step_limit = tf.constant(inner_steps, dtype=tf.int32)
     return tf.less(idx, step_limit)
Example #20
0
 def sample_predicate(i, unused_sample_ta):
     """Loop predicate: true while `i` is below `self.data_dim`."""
     dim_limit = self.data_dim
     return tf.less(i, dim_limit)
def resize_crop_pad(image,
                    desired_output_size,
                    stride,
                    aug_scale_min=1.0,
                    aug_scale_max=1.0,
                    boxes=None,
                    classes=None,
                    masks=None,
                    crop_mask_size=112):
  """Resize, crop and pad an image with its boxes and masks (RetinaNet style).

  Processing steps:
  1. Rescale the image, keeping its aspect ratio, to the largest size that
     fits inside `desired_output_size`.
  2. When scale jittering is enabled (`aug_scale_min` or `aug_scale_max`
     differs from 1.0), multiply that size by a random factor and take a
     random crop of at most the size computed in step 1.
  3. Zero-pad at the bottom/right so that height and width become the
     smallest multiples of `stride` that are larger than or equal to the
     desired output dimensions.

  Args:
    image: an image tensor of shape [original_height, original_width, 3].
    desired_output_size: a tuple of two integers (height, width) indicating
      the desired output image size. Note that the actual output size could be
      different from this because of the stride padding.
    stride: the stride of the backbone network. Each of the output image sides
      is a multiple of this.
    aug_scale_min: a `float` with range between [0, 1.0] representing minimum
      random scale applied to desired_size for training scale jittering.
    aug_scale_max: a `float` with range between [1.0, inf] representing maximum
      random scale applied to desired_size for training scale jittering.
    boxes: (Optional) a tensor of shape [num_boxes, 4] representing the box
      corners in normalized coordinates. Requires `classes`.
    classes: (Optional) a tensor of shape [num_boxes] representing the box
      classes.
    masks: (Optional) a tensor of shape [num_boxes, image_height, image_width]
      representing the instance masks which have the same shape as the input
      image.
    crop_mask_size: an integer indicating the size of the cropped mask.

  Returns:
    image: the processed image tensor after being resized and padded.
    image_info: a tensor of shape [5] holding [scaled_height, scaled_width,
      1 / scale, original_height, original_width], where the scaled size is
      the aspect-preserving fit computed in step 1.
    boxes: None or the processed box tensor after being resized and padded.
      After the processing, boxes will be in the absolute coordinates w.r.t.
      the scaled image.
    classes: None or the processed class tensor after boxes being resized and
      filtered.
    masks: None or the processed mask tensor after being resized.
  """
  assert boxes is None or classes is not None

  def _crop_instance_masks(instance_masks, crop_boxes, instance_count):
    """Crops every mask to its own box and resizes it to crop_mask_size."""
    cropped = tf.image.crop_and_resize(
        image=tf.expand_dims(instance_masks, axis=-1),
        boxes=crop_boxes,
        box_indices=tf.range(instance_count, dtype=tf.int32),
        crop_size=[crop_mask_size, crop_mask_size],
        method='bilinear')
    return tf.squeeze(cropped, axis=-1)

  def _round_up_to_stride(length):
    """Returns the smallest multiple of `stride` that is >= `length`."""
    return int(math.ceil(length * 1.0 / stride) * stride)

  image_shape = tf.shape(image)
  input_height = tf.cast(image_shape[0], dtype=tf.float32)
  input_width = tf.cast(image_shape[1], dtype=tf.float32)
  desired_height, desired_width = desired_output_size

  # The smaller of the two per-axis ratios is the largest scale at which the
  # whole image still fits inside the desired rectangle.
  height_ratio = desired_height / input_height
  width_ratio = desired_width / input_width
  scale = tf.minimum(height_ratio, width_ratio)
  fitted_height = scale * input_height
  fitted_width = scale * input_width
  fitted_size = tf.stack([fitted_height, fitted_width], axis=0)

  use_jitter = aug_scale_min != 1.0 or aug_scale_max != 1.0

  if use_jitter:
    jitter_factor = tf.random_uniform([], aug_scale_min, aug_scale_max)
    scale = jitter_factor * scale
    scaled_size = tf.round(jitter_factor * fitted_size)
  else:
    scaled_size = fitted_size
  scaled_size_int = tf.cast(scaled_size, dtype=tf.int32)
  fitted_size_int = tf.cast(fitted_size, dtype=tf.int32)

  image = tf.image.resize_images(
      image, scaled_size_int, method=tf.image.ResizeMethod.BILINEAR)

  if boxes is not None:
    normalized_boxes = boxes
    # Normalized -> absolute coordinates of the scaled image; the
    # (height, width) pair is repeated to cover all four box columns.
    per_column_size = tf.tile(tf.expand_dims(scaled_size, axis=0), [1, 2])
    boxes = boxes * per_column_size

    if masks is not None and not use_jitter:
      masks = _crop_instance_masks(
          masks, normalized_boxes, tf.shape(boxes)[0])

  if use_jitter:
    # Room left for a random crop; none when the jittered image is smaller
    # than the fitted size.
    slack = scaled_size - fitted_size
    slack_is_negative = tf.less(slack, 0)
    max_offset = tf.where(slack_is_negative, tf.zeros_like(slack), slack)
    random_offset = max_offset * tf.random_uniform((2,), 0, 1)
    offset = tf.cast(random_offset, dtype=tf.int32)

    image = image[
        offset[0]:offset[0] + fitted_size_int[0],
        offset[1]:offset[1] + fitted_size_int[1],
        :]

    if boxes is not None:
      offset_row = tf.expand_dims(offset, axis=0)
      box_offsets = tf.cast(tf.tile(offset_row, [1, 2]), dtype=tf.float32)
      boxes = boxes - box_offsets
      boxes = box_utils.clip_boxes(
          boxes, fitted_size_int[0], fitted_size_int[1])
      # Drop boxes that the crop reduced to zero extent along either axis.
      has_height = tf.greater(boxes[:, 2] - boxes[:, 0], 0)
      has_width = tf.greater(boxes[:, 3] - boxes[:, 1], 0)
      indices = tf.where(tf.logical_and(has_height, has_width))[:, 0]
      boxes = tf.gather(boxes, indices)
      classes = tf.gather(classes, indices)
      if masks is not None:
        masks = tf.gather(masks, indices)

        # Map the surviving boxes back to normalized coordinates of the
        # un-cropped image, which is the frame the instance masks live in.
        crop_boxes = boxes + box_offsets
        crop_boxes = crop_boxes / tf.tile(
            tf.expand_dims(scaled_size, axis=0), [1, 2])

        masks = _crop_instance_masks(masks, crop_boxes, tf.shape(boxes)[0])

  # Pad image such that its height and width are multiples of the stride.
  padded_height = _round_up_to_stride(desired_height)
  padded_width = _round_up_to_stride(desired_width)
  image = tf.image.pad_to_bounding_box(
      image, 0, 0, padded_height, padded_width)
  image.set_shape([padded_height, padded_width, 3])

  # fitted_size is reported as the actual image size; pixels beyond it come
  # from padding.
  image_info = tf.stack([
      fitted_size[0],
      fitted_size[1],
      1.0 / scale,
      input_height,
      input_width])

  return image, image_info, boxes, classes, masks
Example #22
0
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """Auto-encoder using a Transformer decoder and a prior over latent sequences.

    Args:
      inputs: Tensor of shape [batch, length, 1, hparams.hidden_size] or None.
      targets: Tensor of shape [batch, ..., channels]. Ellipses may be 1 or 2
        dimensions denoting sequence length.
      target_space: int. Used for encoding inputs under a target space id.
      hparams: HParams.
      cache: Tensor of shape [batch, length] or None.
      predict_mask: Tensor masking whether to use gold targets or predictions.

    Returns:
      decoder_output: Tensor of shape [batch, ..., hparams.hidden_size]
        presenting pre-logit activations. After a transformation (`top` in
        `T2TModel`), it is used with targets to compute the "training"
        (reconstruction) loss.
      losses: dict of str to Tensors. There are three loss terms: "extra",
        "extra_loss", and "latent_pred". The first is hard-coded to 0. Outside
        PREDICT mode the latter two are Tensors of shape [batch]; in PREDICT
        mode they are 0.
      cache: Tensor of shape [batch, length], either the same as cache, or
        newly computed if the cache input is None.
    """
    original_targets_shape = common_layers.shape_list(targets)
    batch_size = original_targets_shape[0]
    # Rank-4 targets take the 2-D (image) compression path; anything else is
    # handled as a 1-D sequence.
    if len(original_targets_shape) == 4:
        compress_fn = compress_encoder_2d
        decompress_fn = decompress_decoder_2d
    else:
        compress_fn = compress_encoder_1d
        decompress_fn = decompress_decoder_1d

    ed_attention_bias = None
    if inputs is not None:
        inputs, ed_attention_bias = transformer_text_encoder(
            inputs, target_space, hparams, name="input_encoder")

    losses = {"extra": 0., "extra_loss": 0., "latent_pred": 0.}
    # Computed for every mode: the PREDICT branch below also feeds it to
    # `bottleneck_layer` (to obtain `embed_fn`). It used to be assigned only
    # in the non-PREDICT branch, which raised a NameError at prediction time.
    targets_compressed = compress_fn(targets, hparams, name="compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        if hparams.mode == tf.estimator.ModeKeys.TRAIN:
            scale = common_layers.inverse_exp_decay(hparams.startup_steps)
        else:
            scale = 1.0
        # Per-example 0/1 gate: each example keeps its auxiliary losses with
        # probability `scale`.
        scale = tf.to_float(tf.less(tf.random_uniform([batch_size]), scale))

        latents_dense, latents_discrete, extra_loss, _ = bottleneck_layer(
            targets_compressed, hparams)
        extra_loss = scale * tf.reduce_mean(extra_loss)

        _, latents_pred_loss = latent_prediction_model(inputs,
                                                       ed_attention_bias,
                                                       latents_discrete,
                                                       latents_dense,
                                                       hparams,
                                                       name="latent_pred")
        # The latent prediction loss is switched on only after
        # `mask_startup_steps` global steps.
        latent_time = tf.less(hparams.mask_startup_steps,
                              tf.to_int32(tf.train.get_global_step()))
        latents_pred_loss = scale * tf.reduce_mean(latents_pred_loss)
        latents_pred_loss *= tf.to_float(latent_time)

        # Apply dropout noise for each data point and time step.
        latents_dense_shape = common_layers.shape_list(latents_dense)
        latents_dense = tf.nn.dropout(
            latents_dense,
            keep_prob=1 - hparams.latent_dropout,
            noise_shape=[latents_dense_shape[0], latents_dense_shape[1], 1])

        # TODO(trandustin): Can we combine extra and extra_loss?
        losses = {
            "extra": 0.,
            "extra_loss": extra_loss,
            "latent_pred": latents_pred_loss
        }
    else:
        # Set the latent length, which is num_latents times the number of latent
        # pixels. The number of latent pixels is determined by a compression factor
        # on the number of image pixels.
        # Floor division keeps this an int under Python 3 as well; a float is
        # not a valid dimension for `tf.zeros` below.
        latent_len = (
            (hparams.img_len * hparams.img_len * hparams.num_latents) //
            (2**hparams.num_compress_steps))
        _, _, _, embed_fn = bottleneck_layer(targets_compressed, hparams)
        latents_dense = tf.zeros(
            [batch_size, latent_len, 1, hparams.hidden_size])
        if cache is None:
            cache = ae_latent_sample_beam(latents_dense, inputs,
                                          ed_attention_bias, embed_fn, hparams)
        cache_one_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits)
        latents_dense = embed_fn(cache_one_hot, hparams.hidden_size)

    if len(original_targets_shape) == 4:
        compressed_img_len = (hparams.img_len //
                              2**(hparams.num_compress_steps // 2))
        latents_dense = tf.reshape(latents_dense, [
            batch_size, compressed_img_len, compressed_img_len,
            hparams.num_latents * hparams.hidden_size
        ])

    latents_dense = decompress_fn(latents_dense, hparams, name="decompress")
    latents_dense = tf.reshape(
        latents_dense,
        [-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

    if hparams.use_gold_targets:
        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
            masking = predict_mask
        else:
            masking = common_layers.inverse_exp_decay(
                hparams.mask_startup_steps)
        targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)
        # Per position: use the gold target where the random draw exceeds
        # `masking`, otherwise keep the decompressed latent.
        mask = tf.less(
            masking, tf.random_uniform(common_layers.shape_list(targets)[:-1]))
        mask = tf.expand_dims(tf.to_float(mask), 2)
        latents_dense = mask * targets + (1.0 - mask) * latents_dense

    latents_dense = tf.reshape(latents_dense, original_targets_shape)
    if hparams.decode_autoregressive:
        decoder_output = transformer_image_decoder(latents_dense,
                                                   inputs,
                                                   ed_attention_bias,
                                                   hparams,
                                                   name="decoder")
    else:
        decoder_output = latents_dense
    return decoder_output, losses, cache
Example #23
0
 def continue_optimization(t, mean, var, best_val, best_sol):
     """Loop predicate: iterate while under the cap and variance is large.

     True when `t` is below `self.max_iters` and the largest entry of `var`
     still exceeds `self.epsilon`.
     """
     under_iteration_cap = tf.less(t, self.max_iters)
     variance_above_epsilon = tf.reduce_max(var) > self.epsilon
     return tf.logical_and(under_iteration_cap, variance_above_epsilon)
Example #24
0
 def cond(i, old_adv_x, old_loss):
     """Loop predicate: true while `i` is below `self.num_batches`."""
     # The remaining loop variables play no part in the stopping test.
     del old_adv_x, old_loss
     batch_limit = self.num_batches
     return tf.less(i, batch_limit)
Example #25
0
 def is_nonzero_chunk(example):
     """Return true when the targets of `example` are not all 0s."""
     # The sum of absolute values is 0 only if every target is 0.
     target_magnitude = tf.reduce_sum(tf.abs(example["targets"]))
     return tf.less(0, target_magnitude)
Example #26
0
 def lt(self, x, y):
     """Return the element-wise truth value of `x < y` via `tf.less`."""
     is_less = tf.less(x, y)
     return is_less
Example #27
0
 def _condition(step, *unused_args):
     """Loop predicate: true while `step` is below `num_steps`."""
     keep_going = tf.less(step, num_steps)
     return keep_going
Example #28
0
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """AE Transformer, main step used for training.

    Optionally encodes `inputs`; when `hparams.do_ae` is set, compresses the
    targets, passes them through `hparams.bottleneck`, decompresses the
    resulting latents and (optionally) mixes them with the targets by random
    masking; finally runs `decode_transformer` on the result.

    Args:
      inputs: Source tensor or None. When not None it is flattened with
        `common_layers.flatten4d3d` and passed to `encode`; when None the
        encoder is skipped and the batch size is read from `targets`.
      targets: Target tensor; reshaped here to
        `[batch_size, -1, 1, hparams.hidden_size]`.
      target_space: Forwarded unchanged to `encode`.
      hparams: Hyperparameter object. Fields read here include `mode`,
        `do_ae`, `do_mask`, `do_refine`, `task`, `bottleneck`,
        `bottleneck_kind` and `num_compress_steps`.
      cache: Discrete latents used in PREDICT mode with a non-dense/non-vae
        bottleneck. When None in that branch, it is filled by
        `ae_latent_sample`.
      predict_mask: Value used as the masking level in PREDICT mode or when
        `hparams.use_predict_mask` is set.

    Returns:
      A tuple `(res, losses, cache, data_dim, latent_dim)` where `res` is the
      decoder output cropped along axis 1 to the original target length,
      `losses` is a dict with keys "extra", "latent_pred" and
      "neg_q_entropy", `cache` is the (possibly newly sampled) latent cache,
      `data_dim` is the size of axis 1 of `res`, and `latent_dim` is the size
      of axis 1 of the compressed targets.

    NOTE(review): `original_targets_shape` and `targets_c` are only assigned
    inside the `hparams.do_ae` branch but are read unconditionally near the
    end of the function, so calling this with `do_ae` false looks like it
    would raise a NameError -- confirm whether that configuration is
    supported.
    """
    # Summaries break with the do_refine cond, turn them off in that case.
    global _DO_SUMMARIES
    if hparams.do_refine:
        _DO_SUMMARIES = False

    # Prepare.
    if inputs is not None:
        batch_size = common_layers.shape_list(inputs)[0]
    else:
        batch_size = common_layers.shape_list(targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

    # Encoder.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed = encode(inputs, target_space, hparams, "input_enc")
        inputs_ex, ed_ex = inputs, ed
    else:
        ed, inputs_ex, ed_ex = None, None, None

    # Autoencoding.
    # Default loss terms; overwritten below when the corresponding branch runs.
    losses = {
        "extra": tf.constant(0.0),
        "latent_pred": tf.constant(0.0),
        "neg_q_entropy": tf.constant(0.0)
    }
    if hparams.do_ae:
        # flatten here
        original_targets = targets
        original_targets_shape = tf.shape(original_targets)
        if hparams.task == "image":
            # NOTE(review): the return value is discarded here, so unless the
            # helper has side effects this call does not change `targets` --
            # confirm whether an assignment was intended.
            cia.maybe_reshape_4d_to_3d(targets)
        # Choose the tensor whose length `targets` is padded to below.
        if hparams.task == "translate":
            if inputs is not None:
                max_targets_len_from_inputs = tf.concat([inputs, inputs],
                                                        axis=1)
            else:
                max_targets_len_from_inputs = targets
        else:
            assert hparams.task == "image"
            max_targets_len_from_inputs = targets
        if hparams.word_shuffle:
            tf.logging.info("Using word shuffle with rate = {}".format(
                hparams.word_shuffle))
            # Perturb each position index with uniform noise in
            # [0, 1 + word_shuffle) and sort, which yields a permutation of
            # the positions along axis 1.
            targets_idx = tf.range(start=0,
                                   limit=common_layers.shape_list(targets)[1],
                                   delta=1)
            targets_idx = tf.to_float(targets_idx)
            noise = tf.random_uniform(
                shape=common_layers.shape_list(targets_idx),
                minval=0,
                maxval=1 + hparams.word_shuffle)
            targets_idx += noise
            permutation = contrib.framework().argsort(targets_idx)
            targets_permuted = tf.gather(targets, indices=permutation, axis=1)
            targets = targets_permuted
        # Pad so the length is divisible by 2**num_compress_steps.
        targets, _ = common_layers.pad_to_same_length(
            targets,
            max_targets_len_from_inputs,
            final_length_divisible_by=2**hparams.num_compress_steps)
        # Add positional information
        targets_shape = common_layers.shape_list(targets)
        # Drop axis 2 for the positional embedding, then restore the shape.
        targets = tf.reshape(
            targets, [targets_shape[0], targets_shape[1], targets_shape[3]])
        targets = common_attention.add_positional_embedding(
            targets, hparams.max_length, name="targets_position")
        targets = tf.reshape(targets, shape=targets_shape)
        if hparams.word_dropout:
            # Zero out every element whose uniform draw does not exceed
            # `word_dropout`.
            mask = tf.random_uniform(shape=common_layers.shape_list(targets),
                                     minval=0.0,
                                     maxval=1.0)
            targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                                     tf.zeros_like(targets))
        else:
            targets_noisy = targets

        targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            # Compress and bottleneck.
            latents_dense, latents_discrete, extra_loss, embed, neg_q_entropy = (
                hparams.bottleneck(inputs=targets_c,
                                   filter_size=hparams.compress_filter_size,
                                   mode=hparams.mode,
                                   name="vc"))
            if _DO_SUMMARIES:
                tf.summary.histogram(
                    "b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
            # `pc` is the per-example probability of using the bottleneck
            # output instead of the raw compressed targets; it is forced to
            # 1.0 outside TRAIN mode.
            pc = common_layers.inverse_exp_decay(hparams.startup_steps)
            pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
            cond = tf.less(tf.random_uniform([batch_size]), pc)
            latents_dense = tf.where(cond, latents_dense, targets_c)
            # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
            losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
            # Extra loss predicting latent code from input. Discrete only.
            if hparams.bottleneck_kind not in ["dense", "vae"]:
                latents_pred = decode_transformer(inputs_ex,
                                                  ed_ex,
                                                  embed(latents_discrete),
                                                  hparams,
                                                  "extra",
                                                  task="translate")
                _, latent_pred_loss = ae_latent_softmax(
                    latents_pred, tf.stop_gradient(latents_discrete), hparams)

                # Scale by latent dimension for summary so we can compare across
                # batches.
                if _DO_SUMMARIES:
                    tf.summary.scalar("latent_pred_loss_mean",
                                      tf.reduce_mean(latent_pred_loss))
                if hparams.sum_over_latents:
                    latent_pred_loss = tf.reduce_sum(latent_pred_loss, [1, 2])

                losses["latent_pred"] = tf.reduce_mean(
                    latent_pred_loss * tf.to_float(cond)) * hparams.prior_scale
                losses["neg_q_entropy"] = neg_q_entropy * hparams.entropy_scale
            else:
                # Dense/VAE bottleneck: predict the compressed targets from
                # the inputs and penalize the squared difference.
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                losses["latent_pred"] = tf.reduce_mean(
                    tf.squared_difference(inputs_c, targets_c)) * 20

                def bn_inputs():
                    # Re-apply the "vc" bottleneck to `inputs_c` with
                    # variable reuse enabled on the current scope.
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=True):
                        bn, _, _, _, _ = hparams.bottleneck(
                            inputs=inputs_c,
                            filter_size=hparams.compress_filter_size,
                            mode=hparams.mode,
                            name="vc")
                    return bn

                inputs_c = bn_inputs()
                # `ptc` is the per-example probability of keeping the
                # target-derived latents rather than the input-derived ones;
                # forced to 1.0 outside TRAIN mode.
                ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
                ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
                latents_dense = tf.where(
                    tf.less(tf.random_uniform([batch_size]), ptc),
                    latents_dense, inputs_c)
        else:
            # PREDICT mode: latents come from the inputs (dense/vae) or from
            # sampled/cached discrete codes.
            if hparams.bottleneck_kind in ["dense", "vae"]:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                latents_dense, _, _, _, _ = hparams.bottleneck(
                    inputs=inputs_c,
                    filter_size=hparams.compress_filter_size,
                    mode=hparams.mode,
                    name="vc")
            else:
                latent_len = common_layers.shape_list(targets_c)[1]
                # Only the `embed` function is used from this bottleneck call.
                _, _, _, embed, _ = hparams.bottleneck(
                    inputs=targets_c,
                    filter_size=hparams.compress_filter_size,
                    name="vc")
                latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
                if cache is None:
                    cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex,
                                             embed, 16, hparams)
                latents_dense = embed(cache)
        # Postprocess.
        d = latents_dense
        d_shape = common_layers.shape_list(d)
        # Drop axis 2 for the positional embedding, then restore the shape.
        d = tf.reshape(d, [d_shape[0], d_shape[1], d_shape[3]])
        d = common_attention.add_positional_embedding(d,
                                                      hparams.max_length,
                                                      name="latents_position")
        d = tf.reshape(d, shape=d_shape)

        # decompressing the dense latents
        for i in range(hparams.num_compress_steps):
            # Layer names count down from num_compress_steps - 1 to 0.
            j = hparams.num_compress_steps - i - 1
            d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
            if inputs is not None and hparams.do_attend_decompress:
                d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
            d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

        # Masking.
        if hparams.do_mask:
            masking = common_layers.inverse_lin_decay(
                hparams.mask_startup_steps)
            masking *= common_layers.inverse_exp_decay(
                hparams.mask_startup_steps // 4)  # Not much at start.
            if not hparams.do_refine:
                masking -= tf.random_uniform([]) * hparams.unmasked_percentage
            # Clamp the masking level to [0, 1].
            masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
            if hparams.use_predict_mask:
                masking = predict_mask
            if hparams.mode == tf.estimator.ModeKeys.PREDICT:
                masking = predict_mask
            # mask == 1 where the uniform draw exceeds `masking`; those
            # positions keep the real targets, the rest take `d`.
            mask = tf.less(
                masking,
                tf.random_uniform(common_layers.shape_list(targets)[:-1]))
            mask = tf.expand_dims(tf.to_float(mask), 3)

            # targets is always [batch, length, 1, depth]
            targets = mask * targets + (1.0 - mask) * d
            # reshape back to 4d here
            if hparams.task == "image":
                targets = tf.reshape(targets, original_targets_shape)
        else:
            targets = d

    res = decode_transformer(inputs,
                             ed,
                             targets,
                             hparams,
                             "decoder",
                             causal=hparams.causal)
    if hparams.do_ae:
        if hparams.do_mask and hparams.do_refine:

            def refine_res():
                # return residual_conv(res, 1, (5, 1), hparams, "refine")
                r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                              "refine_enc")
                return tf.expand_dims(r, axis=2)

            # Per-example sum of the mask; examples whose sum is below 0.1
            # get the refined output instead of `res`.
            masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
            all_masked = tf.less(masked_batches, 0.1)
            res = tf.where(all_masked, refine_res(), res)
        # We'll start training the extra model of latents after mask_startup_steps.
        nonlatent_steps = hparams.mask_startup_steps
        latent_time = tf.less(nonlatent_steps,
                              tf.to_int32(tf.train.get_global_step()))
        losses["latent_pred"] *= tf.to_float(latent_time)

    # res was generated from padded targets, which means it has some extra
    # elements. These can cause shape problems when computing loss with respect to
    # the original (unpadded) targets. So we remove their extra elements here.
    # NOTE(review): `original_targets_shape` (and `targets_c` below) are only
    # bound when `hparams.do_ae` is true -- confirm do_ae=False is unsupported.
    res = res[:, :original_targets_shape[1], :, :]

    data_dim = common_layers.shape_list(res)[1]
    latent_dim = common_layers.shape_list(targets_c)[1]
    return res, losses, cache, data_dim, latent_dim
Example #29
0
    def _update_mask(self, weights, threshold, gradients):  # pylint: disable=unused-argument
        """Builds a top-k pruning mask for a weight tensor.

        Every weight gets a score: its absolute value for the 'weight' prune
        option, or its absolute value multiplied by `gradients` for the
        gradient-based options. The `k = round(size * (1 - sparsity))`
        highest-scoring entries get mask value 1 and all others get 0. Scores
        are randomly permuted before sorting so that ties in score are meant
        to be broken at random.

        Args:
          weights: The weight tensor that needs to be masked.
          threshold: Not used by this implementation.
          gradients: Tensor multiplied with `abs(weights)` to form the score
            for the 'first_order_gradient' and 'second_order_gradient'
            options. Must not be None for those options.

        Returns:
          new_threshold: Always a float32 constant 0.
          new_mask: A float32 tensor with the shape of `weights` holding 1 for
            kept entries and 0 for pruned entries.

        Raises:
          ValueError: if sparsity is not defined, if a gradient-based option is
            selected while `gradients` is None, or if the prune option is not
            recognized.
        """
        if self._sparsity is None:
            raise ValueError('Sparsity variable undefined')

        sparsity = self._get_sparsity(weights.op.name)
        with tf.name_scope(weights.op.name + '_pruning_ops'):
            prune_option = self._spec.prune_option
            tf.logging.info('Applying option %s pruning', prune_option)

            # Validate the option up front, then compute the salience scores.
            if prune_option not in ('weight', 'first_order_gradient',
                                    'second_order_gradient'):
                raise ValueError('undefined option')
            if prune_option == 'weight':
                scores = tf.abs(weights)
            else:
                if gradients is None:
                    raise ValueError('gradient tensor cannot be None.')
                # gradient variable stores absolute value already
                scores = tf.multiply(tf.abs(weights), gradients)

            # Number of entries that stay unpruned.
            total = tf.cast(tf.size(scores), tf.float32)
            keep_fraction = 1 - sparsity
            num_kept = tf.cast(tf.round(total * keep_fraction), tf.int32)

            # Random permutation of the flat positions, as scatter indices.
            permutation = tf.random_shuffle(tf.range(tf.size(scores)))
            permutation = tf.reshape(permutation, [-1, 1])

            # Flatten the scores and move each one to its permuted position.
            flat_scores = tf.reshape(scores, [-1])
            flat_shape = tf.shape(flat_scores)
            shuffled_scores = tf.scatter_nd(permutation, flat_scores,
                                            flat_shape)

            # Full descending sort; only the resulting order is needed.
            num_scores = tf.size(shuffled_scores)
            _, ranked_positions = tf.nn.top_k(shuffled_scores, k=num_scores)

            # Mask expressed by rank: the first `num_kept` ranks are ones.
            ranks = tf.range(tf.size(shuffled_scores))
            keep_by_rank = tf.cast(tf.less(ranks, num_kept), tf.float32)

            # Place each rank's mask value at the position that holds it.
            ranked_positions = tf.reshape(ranked_positions, [-1, 1])
            shuffled_mask = tf.scatter_nd(ranked_positions, keep_by_rank,
                                          tf.shape(keep_by_rank))

            # Undo the permutation and restore the shape of `weights`.
            unshuffled_mask = tf.gather_nd(shuffled_mask, permutation)
            new_mask = tf.reshape(unshuffled_mask, tf.shape(weights))
        return tf.constant(0, tf.float32), new_mask
Example #30
0
def maybe_gen_fake_data_based_on_real_data(image, label, reso,
                                           min_fake_lesion_ratio,
                                           gen_fake_probability):
    """Randomly replace the real lesions of a volume with synthetic ones.

    Voxels are classified from `label`: values below 0.5 are background,
    values above 1.5 are lesion, and everything else is liver. A random fake
    lesion mask is generated with `_gen_rand_mask` and restricted to
    non-background voxels. When a uniform draw exceeds
    `1 - gen_fake_probability` (and the liver/lesion intensity difference is
    above 0.0001), the blurred real lesion is brightened away and the blurred
    fake lesion is darkened in, and the label is rebuilt as 0 (background),
    1 (liver) and 2 (fake lesion). Otherwise `image` and `label` pass through
    unchanged.

    Args:
      image: Intensity volume to modify.
      label: Label volume used to derive the background/liver/lesion masks.
      reso: Resolution value; `reso // 32` sets the blur kernel size and the
        scale/smoothness handed to `_gen_rand_mask`.
      min_fake_lesion_ratio: Lower clip bound of the fake lesion ratio; the
        upper bound is this value plus 0.20.
      gen_fake_probability: The swap is applied when the uniform draw is
        greater than `1 - gen_fake_probability`.

    Returns:
      The `(image, label)` pair, either modified or passed through.
    """

    # TODO(lehou): Replace magic numbers with flag variables.

    def _box_blur(binary_mask):
        """Averages a boolean mask with a cubic all-ones kernel of side reso // 32."""
        volume = tf.cast(binary_mask, tf.float32)
        volume = tf.expand_dims(volume, -1)
        volume = tf.expand_dims(volume, 0)
        kernel = tf.ones([reso // 32] * 3 + [1, 1],
                         tf.float32) / (reso // 32)**3
        blurred = tf.nn.conv3d(volume,
                               filter=kernel,
                               strides=[1, 1, 1, 1, 1],
                               padding='SAME')
        return tf.squeeze(blurred)

    uniform_draw = tf.random_uniform(shape=[],
                                     minval=0.0,
                                     maxval=1.0,
                                     dtype=tf.float32)

    # Partition the voxels by label value.
    is_background = tf.less(label, 0.5)
    is_lesion = tf.greater(label, 1.5)
    is_liver = tf.logical_not(tf.logical_or(is_background, is_lesion))

    liver_voxels = tf.boolean_mask(image, is_liver)
    lesion_voxels = tf.boolean_mask(image, is_lesion)

    # Mean liver-vs-lesion intensity gap, scaled by 1.15; NaN becomes 0.
    liver_mean = tf.reduce_mean(liver_voxels)
    lesion_mean = tf.reduce_mean(lesion_voxels)
    raw_contrast = (liver_mean - lesion_mean) * 1.15
    contrast = tf.cond(tf.is_nan(raw_contrast), lambda: 0.0,
                       lambda: raw_contrast)

    # Sum of two normal draws, clipped to the allowed ratio range.
    fake_ratio = 0.0 + tf.random.normal(shape=[], mean=0.01, stddev=0.01)
    fake_ratio = fake_ratio + tf.random.normal(
        shape=[], mean=0.0, stddev=0.05)
    fake_ratio = tf.clip_by_value(fake_ratio, min_fake_lesion_ratio,
                                  min_fake_lesion_ratio + 0.20)

    # Random blob mask, limited to non-background voxels.
    random_blob = _gen_rand_mask(ratio_mean=fake_ratio,
                                 ratio_stddev=0.0,
                                 scale=reso // 32,
                                 shape=label.shape,
                                 smoothness=reso // 32)
    not_background = tf.logical_not(is_background)
    is_fake_lesion = tf.logical_and(random_blob, not_background)
    is_liver_after_fake = tf.logical_not(
        tf.logical_or(is_background, is_fake_lesion))

    # Blur the masks
    real_lesion_blur = _box_blur(is_lesion)
    fake_lesion_blur = _box_blur(is_fake_lesion)

    # Remove real lesion and add fake lesion.
    # If the intensity difference is too small (maybe no liver or lesion
    # region labeled), do not generate fake data.
    gated_draw = tf.cond(tf.greater(contrast, 0.0001), lambda: uniform_draw,
                         lambda: 0.0)

    def _image_with_fake_lesion():
        return (image + contrast * real_lesion_blur -
                contrast * fake_lesion_blur)

    def _label_with_fake_lesion():
        return (tf.cast(is_background, tf.float32) * 0 +
                tf.cast(is_liver_after_fake, tf.float32) * 1 +
                tf.cast(is_fake_lesion, tf.float32) * 2)

    new_image = tf.cond(tf.greater(gated_draw, 1 - gen_fake_probability),
                        _image_with_fake_lesion, lambda: image)
    new_label = tf.cond(tf.greater(gated_draw, 1 - gen_fake_probability),
                        _label_with_fake_lesion, lambda: label)

    return new_image, new_label