Example #1
 def update_masks():
     """check whether all pruning conditions are met before pruning."""
     with tf.name_scope(self._spec.name):
         is_step_within_pruning_range = tf.logical_and(
             tf.greater_equal(self._global_step,
                              self._spec.begin_pruning_step),
             # If end_pruning_step is negative, keep pruning forever!
             tf.logical_or(
                 tf.less_equal(self._global_step,
                               self._spec.end_pruning_step),
                 tf.less(self._spec.end_pruning_step, 0)))
         is_pruning_step = tf.less_equal(
             tf.add(self._last_update_step,
                    self._spec.pruning_frequency), self._global_step)
         return tf.logical_and(is_step_within_pruning_range,
                               is_pruning_step)
Example #2
 def _rpn_score_loss(self, score_outputs, score_targets, normalizer=1.0):
   """Computes score loss."""
   # score_targets has three values:
   # (1) score_targets[i]=1, the anchor is a positive sample.
   # (2) score_targets[i]=0, negative.
   # (3) score_targets[i]=-1, the anchor is don't care (ignore).
   with tf.name_scope('rpn_score_loss'):
     mask = tf.logical_or(tf.equal(score_targets, 1),
                          tf.equal(score_targets, 0))
     score_targets = tf.maximum(score_targets, tf.zeros_like(score_targets))
     # RPN score loss is sum over all except ignored samples.
     score_loss = tf.losses.sigmoid_cross_entropy(
         score_targets, score_outputs, weights=mask,
         reduction=tf.losses.Reduction.SUM)
     score_loss /= normalizer
     return score_loss
Example #3
 def while_body(t, z, accept):
   """Truncated rejection sampling."""
   new_z = self.proposal.sample(num_samples)
   accept_prob = tf.squeeze(tf.exp(self.accept_fn(new_z - self.data_mean)),
                            axis=-1)
   new_accept = tf.math.less_equal(
       tf.random_uniform(shape=[num_samples], minval=0., maxval=1.),
       accept_prob)
   force_accept = tf.math.greater_equal(
       tf.to_float(t),
       tf.to_float(self.T) - 1.)
   new_accept = tf.math.logical_or(new_accept, force_accept)
   accepted = tf.logical_or(accept, new_accept)
   swap = tf.math.logical_and(tf.math.logical_not(accept), new_accept)
   z = tf.where(swap, new_z, z)
   return t + 1, z, accepted
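 A body like the one above is normally driven by `tf.while_loop`; the sketch below is illustrative only (the stopping condition, `num_samples`, and `z_dim` are assumptions, not part of the original class):

 # Minimal sketch of wiring while_body into a loop (assumed names, TF1-style).
 num_samples, z_dim = 64, 10
 t0 = tf.constant(0)
 z0 = tf.zeros([num_samples, z_dim])
 accept0 = tf.zeros([num_samples], dtype=tf.bool)
 # Stop once every sample has been accepted (the body force-accepts at t = T - 1).
 _, z_final, _ = tf.while_loop(
     cond=lambda t, z, accept: tf.logical_not(tf.reduce_all(accept)),
     body=while_body,
     loop_vars=(t0, z0, accept0))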
Example #4
def maybe_split_sequence_lengths(sequence_length, num_splits, total_length):
    """Validates and splits `sequence_length`, if necessary.
  Returned value must be used in graph for all validations to be executed.
  Args:
    sequence_length: A batch of sequence lengths, either sized `[batch_size]`
      and equal to either 0 or `total_length`, or sized
      `[batch_size, num_splits]`.
    num_splits: The scalar number of splits of the full sequences.
    total_length: The scalar total sequence length (potentially padded).
  Returns:
    sequence_length: If input shape was `[batch_size, num_splits]`, returns the
      same Tensor. Otherwise, returns a Tensor of that shape with each input
      length in the batch divided by `num_splits`.
  Raises:
    ValueError: If `sequence_length` is not shaped `[batch_size]` or
      `[batch_size, num_splits]`.
    tf.errors.InvalidArgumentError: If `sequence_length` is shaped
      `[batch_size]` and all values are not either 0 or `total_length`.
  """
    if sequence_length.shape.ndims == 1:
        if total_length % num_splits != 0:
            raise ValueError(
                '`total_length` must be evenly divisible by `num_splits`.')
        with tf.control_dependencies([
                tf.Assert(tf.reduce_all(
                    tf.logical_or(tf.equal(sequence_length, 0),
                                  tf.equal(sequence_length, total_length))),
                          data=[sequence_length])
        ]):
            sequence_length = (tf.tile(tf.expand_dims(sequence_length, axis=1),
                                       [1, num_splits]) // num_splits)
    elif sequence_length.shape.ndims == 2:
        with tf.control_dependencies([
                tf.assert_less_equal(
                    sequence_length,
                    tf.constant(total_length // num_splits, tf.int32),
                    message='Segment length cannot be more than '
                    '`total_length / num_splits`.')
        ]):
            sequence_length = tf.identity(sequence_length)
        sequence_length.set_shape([sequence_length.shape[0], num_splits])
    else:
        raise ValueError(
            'Sequence lengths must be given as a vector or a 2D Tensor whose '
            'second dimension size matches its initial hierarchical split. Got '
            'shape: %s' % sequence_length.shape.as_list())
    return sequence_length
Example #5
def _distorted_crop_window(image_shape,
                           min_object_covered=0.1,
                           aspect_ratio_range=(0.75, 1.33),
                           area_range=(0.08, 1.0),
                           max_attempts=100):
    """Computes a sampled distorted crop window from an input image shape.

  Calls into `tf.image.sample_distorted_bounding_box`, using the entire image as
  the bounding box. This can theoretically fail, in which case, we fall back to
  a deterministic center square crop.

  Args:
    image_shape: The shape of the image, expressed as a Tensor of shape [3], an
      iterable of length 3, or a tf.Shape with rank 3.
    min_object_covered: See `tf.image.sample_distorted_bounding_box`.
    aspect_ratio_range: See `tf.image.sample_distorted_bounding_box`.
    area_range: See `tf.image.sample_distorted_bounding_box`.
    max_attempts: See `tf.image.sample_distorted_bounding_box`.

  Returns:
    A Tensor of shape [6], representing the crop box in the format
    [offset_height, offset_width, offset_channel, crop_dim, crop_dim, channels].
    `offset_channel` is always 0.
  """
    with tf.name_scope('distorted_crop_window'):
        sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
            image_shape,
            bounding_boxes=tf.zeros(shape=[1, 0, 4]),
            min_object_covered=min_object_covered,
            aspect_ratio_range=aspect_ratio_range,
            area_range=area_range,
            max_attempts=max_attempts,
            use_image_if_no_bounding_boxes=True)
        bbox_begin, bbox_size, _ = sample_distorted_bounding_box
        offset_y, offset_x, _ = tf.unstack(bbox_begin)
        target_height, target_width, _ = tf.unstack(bbox_size)
        crop_window_params = [
            offset_y, offset_x, 0, target_height, target_width, image_shape[2]
        ]
        # sample_distorted_bounding_box can fail, in which case it returns the input
        # image dimensions. In case of failure, fall back to central crop.
        success = tf.logical_or(tf.not_equal(target_height, image_shape[0]),
                                tf.not_equal(target_width, image_shape[1]))
        crop_window = tf.cond(
            success, lambda: tf.stack(crop_window_params),
            lambda: _center_crop_window(image_shape, crop_frac=1.))
        return crop_window
Example #6
    def build_infer_graph(self, FLAGS, batch_data, bbox=None, name='val'):
        """Builds the graph for inference."""
        if FLAGS.guided:
            batch_data, edge = batch_data
            edge = edge[:, :, :, 0:1] / 255.
            edge = tf.cast(edge > FLAGS.edge_threshold, tf.float32)
        regular_mask = bbox2mask(FLAGS, bbox, name='mask_c')
        irregular_mask = brush_stroke_mask(FLAGS, name='mask_c')
        mask = tf.cast(
            tf.logical_or(
                tf.cast(irregular_mask, tf.bool),
                tf.cast(regular_mask, tf.bool),
            ),
            tf.float32
        )

        batch_pos = batch_data / 127.5 - 1.
        batch_incomplete = batch_pos*(1.-mask)
        if FLAGS.guided:
            edge = edge * mask
            xin = tf.concat([batch_incomplete, edge], axis=3)
        else:
            xin = batch_incomplete
        # inpaint
        x1, x2, offset_flow = self.build_inpaint_net(
            xin, mask, reuse=True,
            training=False, padding=FLAGS.padding)
        batch_predicted = x2
        # apply mask and reconstruct
        batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask)
        # global image visualization
        if FLAGS.guided:
            viz_img = [
                batch_pos,
                batch_incomplete + edge,
                batch_complete]
        else:
            viz_img = [batch_pos, batch_incomplete, batch_complete]
        if offset_flow is not None:
            viz_img.append(
                resize(offset_flow, scale=4,
                       func=tf.compat.v1.image.resize_bilinear))
        images_summary(
            tf.concat(viz_img, axis=2),
            name+'_raw_incomplete_complete', FLAGS.viz_max_out)
        return batch_complete
Example #7
def image_corruption(image, label, reso, image_corrupt_ratio_mean,
                     image_corrupt_ratio_stddev):
    """Randomly drop non-lesion pixels."""

    if image_corrupt_ratio_mean < 0.000001 and (image_corrupt_ratio_stddev <
                                                0.000001):
        return image

    # Corrupt non-lesion region according to keep_mask.
    keep_mask = _gen_rand_mask(1 - image_corrupt_ratio_mean,
                               image_corrupt_ratio_stddev, reso // 3,
                               image.shape)

    keep_mask = tf.logical_or(tf.greater(label, 1.5), keep_mask)
    image *= tf.cast(keep_mask, tf.float32)

    return image
Example #8
    def loss(self, inputs):
        """L2 loss on velocity."""
        graph = self._build_graph(inputs, is_training=True)
        network_output = self._learned_model(graph)

        # build target velocity change
        cur_velocity = inputs['velocity']
        target_velocity = inputs['target|velocity']
        target_velocity_change = target_velocity - cur_velocity
        target_normalized = self._output_normalizer(target_velocity_change)

        # build loss
        node_type = inputs['node_type'][:, 0]
        loss_mask = tf.logical_or(tf.equal(node_type, common.NodeType.NORMAL),
                                  tf.equal(node_type, common.NodeType.OUTFLOW))
        error = tf.reduce_sum((target_normalized - network_output)**2, axis=1)
        loss = tf.reduce_mean(error[loss_mask])
        return loss
Example #9
def detectMinVal(input_mat, var, threshold=1e-6, name='', debug=False):
    eigen_min = tf.reduce_min(input_mat)
    eigen_max = tf.reduce_max(input_mat)
    eigen_ratio = eigen_max / eigen_min
    input_mat_clipped = clipoutNeg(input_mat, threshold)

    if debug:
        input_mat_clipped = tf.cond(
            tf.logical_or(tf.greater(eigen_ratio, 0.),
                          tf.less(eigen_ratio,
                                  -500)), lambda: input_mat_clipped,
            lambda: tf.Print(input_mat_clipped, [
                tf.convert_to_tensor('screwed ratio ' + name +
                                     ' eigen values!!!'),
                tf.convert_to_tensor(var.name), eigen_min, eigen_max,
                eigen_ratio
            ]))

    return input_mat_clipped
Example #10
def _rollout(model, initial_state, num_steps):
    """Rolls out a model trajectory."""
    node_type = initial_state['node_type'][:, 0]
    mask = tf.logical_or(tf.equal(node_type, NodeType.NORMAL),
                         tf.equal(node_type, NodeType.OUTFLOW))

    def step_fn(step, velocity, trajectory):
        prediction = model({**initial_state, 'velocity': velocity})
        # don't update boundary nodes
        next_velocity = tf.where(mask, prediction, velocity)
        trajectory = trajectory.write(step, velocity)
        return step + 1, next_velocity, trajectory

    _, _, output = tf.while_loop(
        cond=lambda step, cur, traj: tf.less(step, num_steps),
        body=step_fn,
        loop_vars=(0, initial_state['velocity'],
                   tf.TensorArray(tf.float32, num_steps)),
        parallel_iterations=1)
    return output.stack()
Example #11
def _sequence_correct(labels: decode_utils.LabelsDict,
                      predictions: decode_utils.PredictionsDict):
    """Computes a per-example sequence accuracy."""
    target_decode_steps = decode_utils.decode_steps_from_labels(
        labels, trim_start_symbol=True)

    predicted_decode_steps = decode_utils.decode_steps_from_predictions(
        predictions)

    decode_utils.assert_shapes_match(target_decode_steps,
                                     predicted_decode_steps)

    equal_tokens = decode_utils.compare_decode_steps(target_decode_steps,
                                                     predicted_decode_steps)

    target_len = labels["target_len"] - 1
    loss_mask = tf.sequence_mask(lengths=tf.to_int32(target_len),
                                 maxlen=tf.to_int32(tf.shape(equal_tokens)[1]))
    equal_tokens = tf.logical_or(equal_tokens, tf.logical_not(loss_mask))
    all_equal = tf.cast(tf.reduce_all(equal_tokens, 1), tf.float32)
    return all_equal
Example #12
def zero_out_clipped_grads(grad, x, clip_min, clip_max):
    """
  Helper function to erase entries in the gradient where the update would be
  clipped.
  :param grad: The gradient
  :param x: The current input
  :param clip_min: Minimum input component value
  :param clip_max: Maximum input component value
  """
    signed_grad = tf.sign(grad)

    # Find input components that lie at the boundary of the input range, and
    # where the gradient points in the wrong direction.
    clip_low = tf.logical_and(tf.less_equal(x, tf.cast(clip_min, x.dtype)),
                              tf.less(signed_grad, 0))
    clip_high = tf.logical_and(tf.greater_equal(x, tf.cast(clip_max, x.dtype)),
                               tf.greater(signed_grad, 0))
    clip = tf.logical_or(clip_low, clip_high)
    grad = tf.where(clip, tf.zeros_like(grad), grad)

    return grad
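A rough usage sketch for the helper above: inside an iterative gradient-sign attack step, zeroing these entries keeps inputs that are already pinned at the clip boundary from accumulating further updates. Everything below (the placeholder, the stand-in loss, and `step_size`) is illustrative, not part of the original library:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 784])   # assumed input, values in [0, 1]
loss = tf.reduce_sum(tf.square(x))                   # stand-in loss for illustration
step_size = 0.01

grad, = tf.gradients(loss, x)
grad = zero_out_clipped_grads(grad, x, clip_min=0., clip_max=1.)
x_adv = tf.clip_by_value(x + step_size * tf.sign(grad), 0., 1.)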
Example #13
def compare_generating_steps(target_decode_steps, predicted_decode_steps):
    """Compare generating steps only but ignoring target copying steps.

  Args:
    target_decode_steps: Target DecodeSteps; each tensor is expected to have
      shape [batch_size, output_length].
    predicted_decode_steps: Predicted DecodeSteps; each tensor is expected to
      have shape [batch_size, output_length].

  Returns:
    A tensor of bools indicating whether generating steps are equal.
    Copy Steps will have value True.
  """
    # Set all copying steps to True, since we only care about generating steps.
    return tf.logical_or(
        tf.not_equal(target_decode_steps.action_types,
                     constants.GENERATE_ACTION),
        tf.logical_and(
            tf.equal(target_decode_steps.action_types,
                     predicted_decode_steps.action_types),
            tf.equal(target_decode_steps.action_ids,
                     predicted_decode_steps.action_ids)))
Example #14
File: krylov.py Project: cthl/sqgn
    def _update(self, rs, ps):
        ops = []

        # Compute the coefficient alpha.
        pTHp = tf.zeros(shape=[], dtype=ps[0].dtype)
        for p, Hz in zip(ps, self._hessians):
            # Recall that p has already been assigned to z, and hence Hz = Hp.
            pTHp += tf.reduce_sum(p * Hz)

        # Compute the coefficient for the update.
        alpha = self._rTr / pTHp

        # Create a tensor that computes the norm of the iterate after the update
        # without actually modifying it.
        norm_dw_new = tf.zeros(shape=[], dtype=self._norm_dw.dtype)
        for dw, p in zip(self._dws, ps):
            dw_new = dw + alpha * p
            norm_dw_new += tf.reduce_sum(dw_new * dw_new)
        norm_dw_new = tf.sqrt(norm_dw_new)

        # Determine if we should follow the direction p until it intersects with the
        # boundary of the trust region.
        # This is the case if either p is a direction of indefiniteness or if dw + p
        # would be outside the trust region.
        follow_to_boundary = tf.logical_or(pTHp <= 0.0,
                                           norm_dw_new > self._radius_placeh)
        self._follow_to_boundary = tf.Variable(False)
        ops.append(tf.assign(self._follow_to_boundary, follow_to_boundary))

        # If we follow p up to the boundary, we do not update dw here.
        # Instead, we determine the final update dw in the 'solve' method.
        alpha_or_zero = tf.cond(follow_to_boundary, lambda: 0.0, lambda: alpha)

        # Update the solution and residual.
        for dw, r, p, Hz in zip(self._dws, rs, ps, self._hessians):
            ops.append(tf.assign_add(dw, alpha_or_zero * p))
            ops.append(tf.assign_sub(r, alpha_or_zero * Hz))

        return tf.group(ops)
Example #15
        def maybe_update_alpha():
            """Operator to update alpha.

      Checks if global_step is between begin_compression_step and
      end_compression_step.
      """
            with tf.compat.v1.name_scope(self._spec.name):
                # prune if current step is more than begin_compression_step and
                # less than end_compression_step (unless it's negative)
                is_step_within_compression_range = tf.logical_and(
                    tf.greater_equal(tf.cast(self._global_step, tf.int32),
                                     self._spec.begin_compression_step),
                    tf.logical_or(
                        tf.less_equal(tf.cast(self._global_step, tf.int32),
                                      self._spec.end_compression_step),
                        tf.less(self._spec.end_compression_step, 0)))
                is_compression_step = tf.less_equal(
                    tf.add(self._last_alpha_update_step,
                           self._spec.compression_frequency),
                    tf.cast(self._global_step, tf.int32))
                return tf.logical_and(is_step_within_compression_range,
                                      is_compression_step)
Example #16
def _online_sample_masks(inputs,
                         tgt_len,
                         num_predict,
                         boundary=None,
                         stride=1):
    """Sample target positions to predict."""
    tf.logging.info("Online sample with strategy: `%s`.",
                    FLAGS.sample_strategy)
    if FLAGS.sample_strategy == "single_token":
        return _single_token_mask(inputs, tgt_len, num_predict)
    else:
        if FLAGS.sample_strategy == "whole_word":
            assert boundary is not None, "whole word sampling requires `boundary`"
            is_target, target_mask = _whole_word_mask(inputs, tgt_len,
                                                      num_predict, boundary)
        elif FLAGS.sample_strategy == "token_span":
            is_target, target_mask = _token_span_mask(inputs,
                                                      tgt_len,
                                                      num_predict,
                                                      stride=stride)
        elif FLAGS.sample_strategy == "word_span":
            assert boundary is not None, "word span sampling requires `boundary`"
            is_target, target_mask = _word_span_mask(inputs,
                                                     tgt_len,
                                                     num_predict,
                                                     boundary,
                                                     stride=stride)
        else:
            raise NotImplementedError

        # Fill in single tokens if not full
        cur_num_masked = tf.reduce_sum(tf.cast(is_target, tf.int64))
        extra_mask, extra_tgt_mask = _single_token_mask(
            inputs, tgt_len, num_predict - cur_num_masked, is_target)
        return tf.logical_or(is_target,
                             extra_mask), target_mask + extra_tgt_mask
Example #17
  def _build(self, x, state):
    prev_keep_mask = state
    shape = tf.shape(x)
    noise = tf.random_uniform(shape, dtype=x.dtype)
    other_mask = tf.floor(self._keep_prob + noise)
    choice_noise = tf.random_uniform(shape, dtype=x.dtype)
    choice = tf.less(choice_noise, self._flip_prob)
    # KLUDGE(melisgl): The client has to pass the last keep_mask from
    # a batch to the next so the mask may end up next to some
    # recurrent cell state. This state is often zero at the beginning
    # and may be periodically zeroed (per example) during training.
    # While zeroing LSTM state is okay, zeroing the dropout mask is
    # not. So instead of forcing every client to deal with this common
    # (?) case, if an all zero mask is detected, then regenerate a
    # fresh mask. This is of course a major hack and won't help with
    # learnt initial states, for example.
    sum_ = tf.reduce_sum(prev_keep_mask, 1, keepdims=True)
    is_initializing = tf.equal(sum_, 0.0)

    self._keep_mask = tf.where(tf.logical_or(choice, is_initializing),
                               other_mask,
                               prev_keep_mask)
    self._time_step += 1
    return x * self._keep_mask / self._keep_prob * self._scaler
Example #18
def _define_collect(batch_env,
                    ppo_hparams,
                    scope,
                    frame_stack_size,
                    eval_phase,
                    sampling_temp,
                    force_beginning_resets,
                    distributional_size=1):
    """Collect trajectories.

  Args:
    batch_env: Batch environment.
    ppo_hparams: PPO hparams, defined in tensor2tensor.models.research.rl.
    scope: var scope.
    frame_stack_size: Number of last observations to feed into the policy.
    eval_phase: TODO(koz4k): Write docstring.
    sampling_temp: Sampling temperature for the policy.
    force_beginning_resets: Whether to reset at the beginning of each episode.
    distributional_size: optional, number of buckets in distributional RL.

  Returns:
    Returns memory (observations, rewards, dones, actions,
    pdfs, values_functions)
    containing a rollout of environment from nested wrapped structure.
  """
    epoch_length = ppo_hparams.epoch_length

    to_initialize = []
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        num_agents = batch_env.batch_size

        to_initialize.append(batch_env)
        wrappers = [(StackWrapper, {
            "history": frame_stack_size
        }), (_MemoryWrapper, {})]
        rollout_metadata = None
        speculum = None
        for w in wrappers:
            tf.logging.info("Applying wrapper %s(%s) to env %s." %
                            (str(w[0]), str(w[1]), str(batch_env)))
            batch_env = w[0](batch_env, **w[1])
            to_initialize.append(batch_env)

        rollout_metadata = _rollout_metadata(batch_env, distributional_size)
        speculum = batch_env.speculum

        def initialization_lambda(sess):
            for batch_env in to_initialize:
                batch_env.initialize(sess)

        memory = [
            tf.get_variable(  # pylint: disable=g-complex-comprehension
                "collect_memory_%d_%s" % (epoch_length, name),
                shape=[epoch_length] + shape,
                dtype=dtype,
                initializer=tf.zeros_initializer(),
                trainable=False) for (shape, dtype, name) in rollout_metadata
        ]

        cumulative_rewards = tf.get_variable("cumulative_rewards",
                                             len(batch_env),
                                             trainable=False)

        eval_phase_t = tf.convert_to_tensor(eval_phase)
        should_reset_var = tf.Variable(True, trainable=False)
        zeros_tensor = tf.zeros(len(batch_env))

    force_beginning_resets = tf.convert_to_tensor(force_beginning_resets)

    def reset_ops_group():
        return tf.group(batch_env.reset(tf.range(len(batch_env))),
                        tf.assign(cumulative_rewards, zeros_tensor))

    reset_op = tf.cond(
        tf.logical_or(should_reset_var.read_value(), force_beginning_resets),
        reset_ops_group, tf.no_op)

    with tf.control_dependencies([reset_op]):
        reset_once_op = tf.assign(should_reset_var, False)

    with tf.control_dependencies([reset_once_op]):

        def step(index, scores_sum, scores_num):
            """Single step."""
            index %= epoch_length  # Only needed in eval runs.
            # Note - the only way to ensure making a copy of tensor is to run simple
            # operation. We are waiting for tf.copy:
            # https://github.com/tensorflow/tensorflow/issues/11186
            obs_copy = batch_env.observ + 0
            value_fun_shape = (num_agents, )
            if distributional_size > 1:
                value_fun_shape = (num_agents, distributional_size)

            def env_step(arg1, arg2, arg3):  # pylint: disable=unused-argument
                """Step of the environment."""

                (logits, value_function) = get_policy(obs_copy, ppo_hparams,
                                                      batch_env.action_space,
                                                      distributional_size)
                action = common_layers.sample_with_temperature(
                    logits, sampling_temp)
                action = tf.cast(action, tf.int32)
                action = tf.reshape(action, shape=(num_agents, ))

                reward, done = batch_env.simulate(action)

                pdf = tfp.distributions.Categorical(logits=logits).prob(action)
                pdf = tf.reshape(pdf, shape=(num_agents, ))
                value_function = tf.reshape(value_function,
                                            shape=value_fun_shape)
                done = tf.reshape(done, shape=(num_agents, ))

                with tf.control_dependencies([reward, done]):
                    return tf.identity(pdf), tf.identity(value_function), \
                           tf.identity(done)

            # TODO(piotrmilos): while_body is executed at most once,
            # thus should be replaced with tf.cond
            pdf, value_function, top_level_done = tf.while_loop(
                lambda _1, _2, _3: tf.equal(speculum.size(), 0),
                env_step,
                [
                    tf.constant(0.0, shape=(num_agents, )),
                    tf.constant(0.0, shape=value_fun_shape),
                    tf.constant(False, shape=(num_agents, ))
                ],
                parallel_iterations=1,
                back_prop=False,
            )

            with tf.control_dependencies([pdf, value_function]):
                obs, reward, done, action = speculum.dequeue()
                to_save = [obs, reward, done, action, pdf, value_function]
                save_ops = [
                    tf.scatter_update(memory_slot, index, value)
                    for memory_slot, value in zip(memory, to_save)
                ]
                cumulate_rewards_op = cumulative_rewards.assign_add(reward)

                agent_indices_to_reset = tf.where(top_level_done)[:, 0]
            with tf.control_dependencies([cumulate_rewards_op]):
                # TODO(piotrmilos): possibly we need cumulative_rewards.read_value()
                scores_sum_delta = tf.reduce_sum(
                    tf.gather(cumulative_rewards.read_value(),
                              agent_indices_to_reset))
                scores_num_delta = tf.count_nonzero(done, dtype=tf.int32)
            with tf.control_dependencies(save_ops +
                                         [scores_sum_delta, scores_num_delta]):
                reset_env_op = batch_env.reset(agent_indices_to_reset)
                reset_cumulative_rewards_op = tf.scatter_update(
                    cumulative_rewards, agent_indices_to_reset,
                    tf.gather(zeros_tensor, agent_indices_to_reset))
            with tf.control_dependencies(
                [reset_env_op, reset_cumulative_rewards_op]):
                return [
                    index + 1, scores_sum + scores_sum_delta,
                    scores_num + scores_num_delta
                ]

        def stop_condition(i, _, resets):
            return tf.cond(eval_phase_t, lambda: resets < num_agents,
                           lambda: i < epoch_length)

        init = [tf.constant(0), tf.constant(0.0), tf.constant(0)]
        index, scores_sum, scores_num = tf.while_loop(stop_condition,
                                                      step,
                                                      init,
                                                      parallel_iterations=1,
                                                      back_prop=False)

    # We handle force_beginning_resets differently. We assume that all envs are
    # reset at the end of an episode (though it happens at the beginning of the
    # next one).
    scores_num = tf.cond(force_beginning_resets,
                         lambda: scores_num + len(batch_env),
                         lambda: scores_num)

    with tf.control_dependencies([scores_sum]):
        scores_sum = tf.cond(
            force_beginning_resets, lambda: scores_sum + tf.reduce_sum(
                cumulative_rewards.read_value()), lambda: scores_sum)

    mean_score = tf.cond(tf.greater(scores_num, 0),
                         lambda: scores_sum / tf.cast(scores_num, tf.float32),
                         lambda: 0.)
    printing = tf.Print(0, [mean_score, scores_sum, scores_num],
                        "mean_score: ")
    with tf.control_dependencies([index, printing]):
        memory = [mem.read_value() for mem in memory]
        # When generating real data together with PPO training we must use single
        # agent. For PPO to work we reshape the history, as if it was generated
        # by real_ppo_effective_num_agents.
        if ppo_hparams.effective_num_agents is not None and not eval_phase:
            new_memory = []
            effective_num_agents = ppo_hparams.effective_num_agents
            assert epoch_length % ppo_hparams.effective_num_agents == 0, (
                "The rollout of ppo_hparams.epoch_length will be distributed amongst"
                "effective_num_agents of agents")
            new_epoch_length = int(epoch_length / effective_num_agents)
            for mem, info in zip(memory, rollout_metadata):
                shape, _, name = info
                new_shape = [effective_num_agents, new_epoch_length
                             ] + shape[1:]
                perm = list(range(len(shape) + 1))
                perm[0] = 1
                perm[1] = 0
                mem = tf.transpose(mem, perm=perm)
                mem = tf.reshape(mem, shape=new_shape)
                mem = tf.transpose(mem,
                                   perm=perm,
                                   name="collect_memory_%d_%s" %
                                   (new_epoch_length, name))
                new_memory.append(mem)
            memory = new_memory

        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            mean_score_summary = tf.cond(
                tf.greater(scores_num, 0),
                lambda: tf.summary.scalar("mean_score_this_iter", mean_score),
                str)
            summaries = tf.summary.merge([
                mean_score_summary,
                tf.summary.scalar("episodes_finished_this_iter", scores_num)
            ])
            return memory, summaries, initialization_lambda
Example #19
def maybe_gen_fake_data_based_on_real_data(image, label, reso,
                                           min_fake_lesion_ratio,
                                           gen_fake_probability):
    """Remove real lesion and synthesize lesion."""
    # TODO(lehou): Replace magic numbers with flag variables.
    gen_prob_indicator = tf.random_uniform(shape=[],
                                           minval=0.0,
                                           maxval=1.0,
                                           dtype=tf.float32)

    background_mask = tf.less(label, 0.5)
    lesion_mask = tf.greater(label, 1.5)
    liver_mask = tf.logical_not(tf.logical_or(background_mask, lesion_mask))

    liver_intensity = tf.boolean_mask(image, liver_mask)
    lesion_intensity = tf.boolean_mask(image, lesion_mask)

    intensity_diff = tf.reduce_mean(liver_intensity) - (
        tf.reduce_mean(lesion_intensity))
    intensity_diff *= 1.15
    intensity_diff = tf.cond(tf.is_nan(intensity_diff), lambda: 0.0,
                             lambda: intensity_diff)

    lesion_liver_ratio = 0.0
    lesion_liver_ratio += tf.random.normal(shape=[], mean=0.01, stddev=0.01)
    lesion_liver_ratio += tf.random.normal(shape=[], mean=0.0, stddev=0.05)
    lesion_liver_ratio = tf.clip_by_value(lesion_liver_ratio,
                                          min_fake_lesion_ratio,
                                          min_fake_lesion_ratio + 0.20)

    fake_lesion_mask = tf.logical_and(
        _gen_rand_mask(ratio_mean=lesion_liver_ratio,
                       ratio_stddev=0.0,
                       scale=reso // 32,
                       shape=label.shape,
                       smoothness=reso // 32), tf.logical_not(background_mask))
    liver_mask = tf.logical_not(
        tf.logical_or(background_mask, fake_lesion_mask))

    # Blur the masks
    lesion_mask_blur = tf.squeeze(
        tf.nn.conv3d(tf.expand_dims(
            tf.expand_dims(tf.cast(lesion_mask, tf.float32), -1), 0),
                     filter=tf.ones([reso // 32] * 3 + [1, 1], tf.float32) /
                     (reso // 32)**3,
                     strides=[1, 1, 1, 1, 1],
                     padding='SAME'))
    fake_lesion_mask_blur = tf.squeeze(
        tf.nn.conv3d(tf.expand_dims(
            tf.expand_dims(tf.cast(fake_lesion_mask, tf.float32), -1), 0),
                     filter=tf.ones([reso // 32] * 3 + [1, 1], tf.float32) /
                     (reso // 32)**3,
                     strides=[1, 1, 1, 1, 1],
                     padding='SAME'))

    # Remove real lesion and add fake lesion.
    # If the intensity difference is too small (maybe no liver or lesion region labeled),
    # do not generate fake data.
    gen_prob_indicator = tf.cond(tf.greater(intensity_diff, 0.0001),
                                 lambda: gen_prob_indicator, lambda: 0.0)
    # pylint: disable=g-long-lambda
    image = tf.cond(
        tf.greater(gen_prob_indicator, 1 - gen_fake_probability),
        lambda: image + intensity_diff * lesion_mask_blur \
                      - intensity_diff * fake_lesion_mask_blur,
        lambda: image)
    label = tf.cond(
        tf.greater(gen_prob_indicator, 1 - gen_fake_probability),
        lambda: tf.cast(background_mask, tf.float32) * 0 + \
            tf.cast(liver_mask, tf.float32) * 1 + \
            tf.cast(fake_lesion_mask, tf.float32) * 2,
        lambda: label)
    # pylint: enable=g-long-lambda

    return image, label
Example #20
 def has_nan(self):
     return tf.logical_or(tf.math.is_nan(self.x), tf.math.is_nan(self.y))
Example #21
 def logical_or(self, x, y):
     return tf.logical_or(x, y)
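 Since this wrapper just forwards to the TF op, here is a minimal standalone illustration of `tf.logical_or` itself (element-wise OR with broadcasting; runnable as-is in TF 2.x eager mode, values are illustrative):

 import tensorflow as tf

 a = tf.constant([True, False, True, False])
 b = tf.constant([False, False, True, True])
 print(tf.logical_or(a, b))        # [ True False  True  True]

 # Broadcasting: combine a per-row flag with a full mask.
 row_flag = tf.constant([[True], [False]])            # shape [2, 1]
 mask = tf.constant([[False, True], [False, True]])   # shape [2, 2]
 print(tf.logical_or(row_flag, mask))                 # [[ True  True] [False  True]]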
Example #22
 def parse1_func(filename):
     # read data
     dtype = tf.float32
     image = tf.read_file(filename)
     image = tf.image.decode_image(image, channels=channels)
     shape = tf.shape(image)
     height = shape[-3]
     width = shape[-2]
     # pre down-scale for high resolution image
     dscale = 1
     if is_training and config.pre_down:
         '''
         if (width >= 3072 and height >= 1536) or (width >= 1536 and height >= 3072):
             dscale = 3
         elif (width >= 1024 and height >= 512) or (width >= 512 and height >= 1024):
             dscale = 2
         '''
         def c_t(const1, const2, true_fn, false_fn):
             return tf.cond(tf.logical_or(
                 tf.logical_and(
                     tf.greater_equal(width, const1), tf.greater_equal(height, const2)
                 ),
                 tf.logical_and(
                     tf.greater_equal(width, const2), tf.greater_equal(height, const1)
                 )
             ), true_fn, false_fn)
         dscale = c_t(3072, 1536, lambda: 3,
             lambda: c_t(1024, 512, lambda: 2, lambda: 1)
         )
     elif is_testing and config.pre_down:
         '''
         if (width >= 3072 and height >= 3072):
             dscale = 4
         elif (width >= 2048 and height >= 2048):
             dscale = 3
         elif (width >= 1024 and height >= 1024):
             dscale = 2
         '''
         def c_t(const1, true_fn, false_fn):
             return tf.cond(tf.logical_and(
                 tf.greater_equal(width, const1), tf.greater_equal(height, const1)
             ), true_fn, false_fn)
         dscale = c_t(3072, lambda: 4,
             lambda: c_t(2048, lambda: 3,
                 lambda: c_t(1024, lambda: 2, lambda: 1)
             )
         )
     # padding
     cropped_height = patch_height * dscale
     cropped_width = patch_width * dscale
     '''
     if cropped_height > height or cropped_width > width:
         pad_height = cropped_height - height
         pad_width = cropped_width - width
         if pad_height > 0:
             pad_height = [pad_height // 2, pad_height - pad_height // 2]
             height = cropped_height
         else:
             pad_height = [0, 0]
         if pad_width > 0:
             pad_width = [pad_width // 2, pad_width - pad_width // 2]
             width = cropped_width
         else:
             pad_width = [0, 0]
         block = tf.pad(image, [pad_height, pad_width, [0, 0]], mode='REFLECT')
     else:
         block = image
     '''
     cond_height = tf.greater(cropped_height, height)
     cond_width = tf.greater(cropped_width, width)
     def c_f1():
         def _1():
             ph = cropped_height - height
             return [ph // 2, ph - ph // 2]
         pad_height = tf.cond(cond_height, _1, lambda: [0, 0])
         def _2():
             pw = cropped_width - width
             return [pw // 2, pw - pw // 2]
         pad_width = tf.cond(cond_width, _2, lambda: [0, 0])
         return tf.pad(image, [pad_height, pad_width, [0, 0]], mode='REFLECT')
     block = tf.cond(tf.logical_or(cond_height, cond_width), c_f1, lambda: image)
     height = tf.maximum(cropped_height, height)
     width = tf.maximum(cropped_width, width)
     # cropping
     if is_training:
         block = tf.random_crop(block, [cropped_height, cropped_width, channels])
         block = tf.image.random_flip_up_down(block)
         block = tf.image.random_flip_left_right(block)
     elif is_testing:
         offset_height = (height - cropped_height) // 2
         offset_width = (width - cropped_width) // 2
         block = tf.image.crop_to_bounding_box(block, offset_height, offset_width,
                                               cropped_height, cropped_width)
     # convert dtype
     block = tf.image.convert_image_dtype(block, dtype, saturate=False)
     # random color augmentation
     if is_training and config.color_augmentation > 0:
         block = tf.image.random_saturation(block, 1 - config.color_augmentation, 1 + config.color_augmentation)
         block = tf.image.random_brightness(block, config.color_augmentation)
         block = tf.image.random_contrast(block, 1 - config.color_augmentation, 1 + config.color_augmentation)
     # data format conversion
     block.set_shape([None, None, channels])
     if data_format == 'NCHW':
         block = tf.transpose(block, (2, 0, 1))
     # return
     return block
Example #23
def prepare_encoder_input(features,
                          hparams,
                          embed_scope=None,
                          embed_token_fn=common_embed.embed_tokens):
    """Prepares the input for the screen encoder.

  Args:
    features: the feature dict.
    hparams: the hyperparameter.
    embed_scope: the embedding variable scope.
    embed_token_fn: the function for embedding tokens.
  Returns:
    object_embedding: a Tensor of shape
        [batch_size, num_steps, max_object_count, embed_depth]
    object_mask: a binary tensor of shape
        [batch_size, num_steps, max_object_count]
    nonpadding_bias: a Tensor of shape
        [batch_size, num_steps, max_object_count]
  """
    with tf.control_dependencies(
        [tf.assert_equal(tf.rank(features["obj_text"]), 4)]):
        if hparams.get("synthetic_screen_noise", 0.) > 0.:
            num_objects = tf.shape(features["obj_text"])[2]
            # [batch, length, num_objects]
            target_obj_mask = tf.cast(
                tf.one_hot(features["objects"], depth=num_objects), tf.bool)
            num_tokens = tf.shape(features["obj_text"])[-1]
            target_obj_mask = tf.tile(tf.expand_dims(target_obj_mask, 3),
                                      [1, 1, 1, num_tokens])
            # Randomly keep tokens
            keep_mask = tf.greater_equal(
                tf.random_uniform(shape=tf.shape(features["obj_text"])),
                hparams.synthetic_screen_noise)
            # Keep paddings
            keep_mask = tf.logical_or(tf.equal(features["obj_text"], 0),
                                      keep_mask)
            # Keep targets
            target_obj_mask = tf.logical_or(target_obj_mask, keep_mask)
            features["obj_text"] = tf.where(
                target_obj_mask, features["obj_text"],
                tf.random_uniform(shape=tf.shape(features["obj_text"]),
                                  maxval=50000,
                                  dtype=tf.int32))
        text_embeddings, _ = embed_token_fn(features["obj_text"],
                                            hparams.task_vocab_size,
                                            hparams.hidden_size,
                                            hparams,
                                            embed_scope=embed_scope)
        with tf.variable_scope("obj_text_embed", reuse=tf.AUTO_REUSE):
            if hparams.obj_text_aggregation == "max":
                embed_bias = tf.cast(tf.less(features["obj_text"], 2),
                                     tf.float32) * -1e7
                with tf.control_dependencies(
                    [tf.assert_equal(tf.rank(embed_bias), 4)]):
                    text_embeddings = tf.reduce_max(
                        text_embeddings + tf.expand_dims(embed_bias, 4), -2)
                    no_txt_embed = tf.get_variable(name="no_txt_embed",
                                                   shape=[hparams.hidden_size])
                    shape = common_layers.shape_list(text_embeddings)
                    no_txt_embed = tf.tile(
                        tf.reshape(no_txt_embed,
                                   [1, 1, 1, hparams.hidden_size]),
                        [shape[0], shape[1], shape[2], 1])
                    text_embeddings = tf.maximum(text_embeddings, no_txt_embed)
            elif hparams.obj_text_aggregation == "sum":
                # [batch, step, #max_obj, #max_token]  0 for padded tokens
                real_objects = tf.cast(
                    tf.greater_equal(features["obj_text"], 2), tf.float32)
                # [batch, step, #max_obj, hidden]   0s for padded objects
                text_embeddings = tf.reduce_sum(
                    text_embeddings * tf.expand_dims(real_objects, 4), -2)
            elif hparams.obj_text_aggregation == "mean":
                shape_list = common_layers.shape_list(text_embeddings)
                embeddings = tf.reshape(text_embeddings, [-1] + shape_list[3:])
                emb_sum = tf.reduce_sum(tf.abs(embeddings), axis=-1)
                non_paddings = tf.not_equal(emb_sum, 0.0)
                embeddings = common_embed.average_bag_of_embeds(
                    embeddings,
                    non_paddings,
                    use_bigrams=True,
                    bigram_embed_scope=embed_scope,
                    append_start_end=True)
                text_embeddings = tf.reshape(
                    embeddings, shape_list[:3] + [hparams.hidden_size])
            else:
                raise ValueError("Unrecognized token aggregation %s" %
                                 (hparams.obj_text_aggregation))
    with tf.control_dependencies([
            tf.assert_equal(tf.rank(features["obj_type"]), 3),
            tf.assert_equal(tf.rank(features["obj_clickable"]), 3)
    ]):
        with tf.variable_scope("encode_object_attr", reuse=tf.AUTO_REUSE):
            type_embedding = tf.nn.embedding_lookup(params=tf.get_variable(
                name="embed_type_w",
                shape=[hparams.get("num_types", 100), hparams.hidden_size]),
                                                    ids=tf.maximum(
                                                        features["obj_type"],
                                                        0))
            clickable_embedding = tf.nn.embedding_lookup(
                params=tf.get_variable(name="embed_clickable_w",
                                       shape=[2, hparams.hidden_size]),
                ids=features["obj_clickable"])
    with tf.control_dependencies(
        [tf.assert_equal(tf.rank(features["obj_screen_pos"]), 4)]):

        def _create_embed(feature_name, vocab_size, depth):
            """Embed a position feature."""
            pos_embedding_list = []
            with tf.variable_scope("encode_object_" + feature_name,
                                   reuse=tf.AUTO_REUSE):
                num_featues = common_layers.shape_list(
                    features[feature_name])[-1]
                for i in range(num_featues):
                    pos_embedding_list.append(
                        tf.nn.embedding_lookup(
                            params=tf.get_variable(name=feature_name +
                                                   "_embed_w_%d" % i,
                                                   shape=[vocab_size, depth]),
                            ids=features[feature_name][:, :, :, i]))
                pos_embedding = tf.add_n(pos_embedding_list)
                return pos_embedding

        pos_embedding = _create_embed("obj_screen_pos", hparams.max_pixel_pos,
                                      hparams.hidden_size)
    if "all" == hparams.screen_embedding_feature or (
            "dom" in hparams.screen_embedding_feature):
        dom_embedding = _create_embed("obj_dom_pos", hparams.max_dom_pos,
                                      hparams.hidden_size)
    object_embed = tf.zeros_like(text_embeddings, dtype=tf.float32)
    if hparams.screen_embedding_feature == "all":
        object_embed = (text_embeddings + type_embedding + pos_embedding +
                        dom_embedding)
    elif "text" in hparams.screen_embedding_feature:
        object_embed += text_embeddings
    elif "type" in hparams.screen_embedding_feature:
        object_embed += type_embedding
    elif "pos" in hparams.screen_embedding_feature:
        object_embed += pos_embedding
    elif "dom" in hparams.screen_embedding_feature:
        object_embed += dom_embedding
    elif "click" in hparams.screen_embedding_feature:
        object_embed += clickable_embedding
    object_mask = tf.cast(tf.not_equal(features["obj_type"], -1), tf.float32)
    object_embed = object_embed * tf.expand_dims(object_mask, 3)
    att_bias = (1. - object_mask) * common_attention.large_compatible_negative(
        object_embed.dtype)
    return object_embed, object_mask, att_bias
Example #24
def trilerp_gather(vol, inds, bad_inds=None):
  """Trilinear interpolation dense gather from volume at query inds."""

  inds_b = inds[Ellipsis, 0]
  inds_x = inds[Ellipsis, 1]
  inds_y = inds[Ellipsis, 2]
  inds_z = inds[Ellipsis, 3]

  inds_x_0 = tf.floor(inds_x)
  inds_x_1 = inds_x_0 + 1
  inds_y_0 = tf.floor(inds_y)
  inds_y_1 = inds_y_0 + 1
  inds_z_0 = tf.floor(inds_z)
  inds_z_1 = inds_z_0 + 1

  # store invalid indices to implement correct out-of-bounds conditions
  invalid_x = tf.logical_or(
      tf.less(inds_x_0, 0.0),
      tf.greater(inds_x_1, tf.to_float(tf.shape(vol)[2] - 1)))
  invalid_y = tf.logical_or(
      tf.less(inds_y_0, 0.0),
      tf.greater(inds_y_1, tf.to_float(tf.shape(vol)[1] - 1)))
  invalid_z = tf.logical_or(
      tf.less(inds_z_0, 0.0),
      tf.greater(inds_z_1, tf.to_float(tf.shape(vol)[3] - 1)))
  if bad_inds is not None:
    invalid_inds = tf.logical_or(
        tf.logical_or(tf.logical_or(invalid_x, invalid_y), invalid_z), bad_inds)
  else:
    invalid_inds = tf.logical_or(tf.logical_or(invalid_x, invalid_y), invalid_z)

  inds_x_0 = tf.clip_by_value(inds_x_0, 0.0, tf.to_float(tf.shape(vol)[2] - 2))
  inds_x_1 = tf.clip_by_value(inds_x_1, 0.0, tf.to_float(tf.shape(vol)[2] - 1))
  inds_y_0 = tf.clip_by_value(inds_y_0, 0.0, tf.to_float(tf.shape(vol)[1] - 2))
  inds_y_1 = tf.clip_by_value(inds_y_1, 0.0, tf.to_float(tf.shape(vol)[1] - 1))
  inds_z_0 = tf.clip_by_value(inds_z_0, 0.0, tf.to_float(tf.shape(vol)[3] - 2))
  inds_z_1 = tf.clip_by_value(inds_z_1, 0.0, tf.to_float(tf.shape(vol)[3] - 1))

  # compute interp weights
  w_x_0 = 1.0 - (inds_x - inds_x_0)
  w_x_1 = 1.0 - w_x_0
  w_y_0 = 1.0 - (inds_y - inds_y_0)
  w_y_1 = 1.0 - w_y_0
  w_z_0 = 1.0 - (inds_z - inds_z_0)
  w_z_1 = 1.0 - w_z_0

  w_0_0_0 = w_y_0 * w_x_0 * w_z_0
  w_1_0_0 = w_y_1 * w_x_0 * w_z_0
  w_0_1_0 = w_y_0 * w_x_1 * w_z_0
  w_0_0_1 = w_y_0 * w_x_0 * w_z_1
  w_1_1_0 = w_y_1 * w_x_1 * w_z_0
  w_0_1_1 = w_y_0 * w_x_1 * w_z_1
  w_1_0_1 = w_y_1 * w_x_0 * w_z_1
  w_1_1_1 = w_y_1 * w_x_1 * w_z_1

  # gather for interp
  inds_0_0_0 = tf.to_int32(
      tf.stack([inds_b, inds_y_0, inds_x_0, inds_z_0], axis=-1))
  inds_1_0_0 = tf.to_int32(
      tf.stack([inds_b, inds_y_1, inds_x_0, inds_z_0], axis=-1))
  inds_0_1_0 = tf.to_int32(
      tf.stack([inds_b, inds_y_0, inds_x_1, inds_z_0], axis=-1))
  inds_0_0_1 = tf.to_int32(
      tf.stack([inds_b, inds_y_0, inds_x_0, inds_z_1], axis=-1))
  inds_1_1_0 = tf.to_int32(
      tf.stack([inds_b, inds_y_1, inds_x_1, inds_z_0], axis=-1))
  inds_0_1_1 = tf.to_int32(
      tf.stack([inds_b, inds_y_0, inds_x_1, inds_z_1], axis=-1))
  inds_1_0_1 = tf.to_int32(
      tf.stack([inds_b, inds_y_1, inds_x_0, inds_z_1], axis=-1))
  inds_1_1_1 = tf.to_int32(
      tf.stack([inds_b, inds_y_1, inds_x_1, inds_z_1], axis=-1))

  vol_0_0_0 = tf.gather_nd(vol, inds_0_0_0) * w_0_0_0[Ellipsis, tf.newaxis]
  vol_1_0_0 = tf.gather_nd(vol, inds_1_0_0) * w_1_0_0[Ellipsis, tf.newaxis]
  vol_0_1_0 = tf.gather_nd(vol, inds_0_1_0) * w_0_1_0[Ellipsis, tf.newaxis]
  vol_0_0_1 = tf.gather_nd(vol, inds_0_0_1) * w_0_0_1[Ellipsis, tf.newaxis]
  vol_1_1_0 = tf.gather_nd(vol, inds_1_1_0) * w_1_1_0[Ellipsis, tf.newaxis]
  vol_0_1_1 = tf.gather_nd(vol, inds_0_1_1) * w_0_1_1[Ellipsis, tf.newaxis]
  vol_1_0_1 = tf.gather_nd(vol, inds_1_0_1) * w_1_0_1[Ellipsis, tf.newaxis]
  vol_1_1_1 = tf.gather_nd(vol, inds_1_1_1) * w_1_1_1[Ellipsis, tf.newaxis]

  out_vol = vol_0_0_0 + vol_1_0_0 + vol_0_1_0 + vol_0_0_1 + \
            vol_1_1_0 + vol_0_1_1 + vol_1_0_1 + vol_1_1_1

  # boundary conditions for invalid indices
  invalid_inds = tf.tile(invalid_inds[:, :, :, :, tf.newaxis],
                         [1, 1, 1, 1, tf.shape(vol)[4]])
  out_vol = tf.where(invalid_inds, tf.zeros_like(out_vol), out_vol)

  return out_vol
Example #25
def bilerp_gather(img, inds):
  """Bilinear interpolation dense gather from image at query inds."""

  inds_b, _, _, = tf.meshgrid(
      tf.range(tf.shape(img)[0]),
      tf.range(tf.shape(img)[1]),
      tf.range(tf.shape(img)[2]),
      indexing='ij')

  inds_b = tf.to_float(inds_b)
  inds_x = inds[Ellipsis, 0]
  inds_y = inds[Ellipsis, 1]

  inds_x_0 = tf.floor(inds_x)
  inds_x_1 = inds_x_0 + 1
  inds_y_0 = tf.floor(inds_y)
  inds_y_1 = inds_y_0 + 1

  # store invalid indices to implement correct out-of-bounds conditions
  invalid_x = tf.logical_or(
      tf.less(inds_x_0, 0.0),
      tf.greater(inds_x_1, tf.to_float(tf.shape(img)[2] - 1)))
  invalid_y = tf.logical_or(
      tf.less(inds_y_0, 0.0),
      tf.greater(inds_y_1, tf.to_float(tf.shape(img)[1] - 1)))
  invalid_inds = tf.logical_or(invalid_x, invalid_y)

  inds_x_0 = tf.clip_by_value(inds_x_0, 0.0, tf.to_float(tf.shape(img)[2] - 2))
  inds_x_1 = tf.clip_by_value(inds_x_1, 0.0, tf.to_float(tf.shape(img)[2] - 1))
  inds_y_0 = tf.clip_by_value(inds_y_0, 0.0, tf.to_float(tf.shape(img)[1] - 2))
  inds_y_1 = tf.clip_by_value(inds_y_1, 0.0, tf.to_float(tf.shape(img)[1] - 1))

  # compute interp weights
  w_x_0 = 1.0 - (inds_x - inds_x_0)
  w_x_1 = 1.0 - w_x_0
  w_y_0 = 1.0 - (inds_y - inds_y_0)
  w_y_1 = 1.0 - w_y_0

  w_0_0 = w_y_0 * w_x_0
  w_1_0 = w_y_1 * w_x_0
  w_0_1 = w_y_0 * w_x_1
  w_1_1 = w_y_1 * w_x_1

  # gather for interp
  inds_0_0 = tf.to_int32(tf.stack([inds_b, inds_y_0, inds_x_0], axis=-1))
  inds_1_0 = tf.to_int32(tf.stack([inds_b, inds_y_1, inds_x_0], axis=-1))
  inds_0_1 = tf.to_int32(tf.stack([inds_b, inds_y_0, inds_x_1], axis=-1))
  inds_1_1 = tf.to_int32(tf.stack([inds_b, inds_y_1, inds_x_1], axis=-1))

  img_0_0 = tf.gather_nd(img, inds_0_0) * w_0_0[Ellipsis, tf.newaxis]
  img_1_0 = tf.gather_nd(img, inds_1_0) * w_1_0[Ellipsis, tf.newaxis]
  img_0_1 = tf.gather_nd(img, inds_0_1) * w_0_1[Ellipsis, tf.newaxis]
  img_1_1 = tf.gather_nd(img, inds_1_1) * w_1_1[Ellipsis, tf.newaxis]

  out_img = img_0_0 + img_1_0 + img_0_1 + img_1_1

  # boundary conditions for invalid indices
  invalid_inds = tf.tile(invalid_inds[:, :, :, tf.newaxis],
                         [1, 1, 1, tf.shape(img)[3]])

  out_img = tf.where(invalid_inds, tf.zeros_like(out_img), out_img)

  return out_img
Example #26
    def parser(value):
        """Parse an Imagenet record from value."""
        keys_to_features = {
            'image/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
            'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
            'image/class/label':
            tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
            'image/class/text':
            tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/object/bbox/xmin':
            tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymin':
            tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/xmax':
            tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymax':
            tf.VarLenFeature(dtype=tf.float32),
            'image/object/class/label':
            tf.VarLenFeature(dtype=tf.int64),
        }

        parsed = tf.parse_single_example(value, keys_to_features)
        encoded_image = tf.reshape(parsed['image/encoded'],
                                   shape=[],
                                   name='encoded_image')
        image_format = parsed['image/format']
        xmin = tf.expand_dims(parsed['image/object/bbox/xmin'].values, 0)
        ymin = tf.expand_dims(parsed['image/object/bbox/ymin'].values, 0)
        xmax = tf.expand_dims(parsed['image/object/bbox/xmax'].values, 0)
        ymax = tf.expand_dims(parsed['image/object/bbox/ymax'].values, 0)

        # Note that we impose an ordering of (y, x) just to make life difficult.
        bbox = tf.concat([ymin, xmin, ymax, xmax], 0)

        # Force the variable number of bounding boxes into the shape
        # [1, num_boxes, coords].
        bbox = tf.expand_dims(bbox, 0)
        bbox = tf.transpose(bbox, [0, 2, 1])

        def decode_png():
            return tf.image.decode_png(encoded_image, 3)

        def decode_jpg():
            return tf.image.decode_jpeg(encoded_image, 3)

        # If image format is PNG, use decode_png, default to jpg.
        pred_fn_pairs = {
            tf.logical_or(tf.equal(image_format, 'png'),
                          tf.equal(image_format, 'PNG')):
            decode_png
        }

        image = tf.case(pred_fn_pairs, default=decode_jpg, exclusive=True)
        image.set_shape([None, None, 3])

        image = preprocess(image, bbox)

        label = tf.cast(tf.reshape(parsed['image/class/label'], shape=[]),
                        dtype=tf.int32,
                        name='cast_label')
        label = tf.reshape(label, [1])
        return tf.cast(image, tf.float32), label
Example #27
def should_log(params):
    """Returns a Boolean `tf.Tensor` dictating whether we should log values."""
    global_step = tf.train.get_or_create_global_step()
    first_run = tf.equal(global_step, 1)
    log_every = tf.equal(tf.floormod(global_step, params.log_every), 0)
    return tf.logical_or(first_run, log_every)
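A hedged usage sketch for the predicate above: gate a logging op with `tf.cond` so it only fires on the first step or every `params.log_every` steps (the `params` object and the loss tensor are assumptions, TF1-style):

def maybe_log_loss(params, loss):
    """Illustrative only: prints the loss when should_log(params) is True."""
    global_step = tf.train.get_or_create_global_step()
    def _log():
        return tf.Print(loss, [global_step, loss], message='step, loss: ')
    return tf.cond(should_log(params), _log, lambda: loss)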
Example #28
def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
    """
    Does top-p sampling. If ignore_ids is set, then we will zero out those logits.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                        like padding maybe
    :param p: topp threshold to use, either a float or a [batch_size] vector
    :return: [batch_size, num_samples] samples

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    with tf.variable_scope('top_p_sample'):
        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        probs = tf.nn.softmax(logits if ignore_ids is None else logits -
                              tf.cast(ignore_ids[None], tf.float32) * 1e10,
                              axis=-1)

        if isinstance(p, float) and p > 0.999999:
            # Don't do top-p sampling in this case
            print("Top-p sampling DISABLED", flush=True)
            return {
                'probs':
                probs,
                'sample':
                tf.random.categorical(
                    logits=logits if ignore_ids is None else logits -
                    tf.cast(ignore_ids[None], tf.float32) * 1e10,
                    num_samples=num_samples,
                    dtype=tf.int32),
            }

        # [batch_size, vocab_perm]
        indices = tf.argsort(probs, direction='DESCENDING')
        cumulative_probabilities = tf.math.cumsum(tf.batch_gather(
            probs, indices),
                                                  axis=-1,
                                                  exclusive=False)

        # Find the cutoff point for the top-p probability mass, taking care not to cut off everything.
        # The resulting mask is [batch_size, vocab_perm].
        p_expanded = p if isinstance(p, float) else p[:, None]
        exclude_mask = tf.logical_not(
            tf.logical_or(cumulative_probabilities < p_expanded,
                          tf.range(vocab_size)[None] < 1))

        # OPTION A - sample in the sorted space, then unsort.
        logits_to_use = tf.batch_gather(
            logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use,
                                            num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

        # OPTION B - unsort first - Indices need to go back to 0 -> N-1 -- then sample
        # unperm_indices = tf.argsort(indices, direction='ASCENDING')
        # include_mask_unperm = tf.batch_gather(include_mask, unperm_indices)
        # logits_to_use = logits - (1 - tf.cast(include_mask_unperm, tf.float32)) * 1e10
        # sample = tf.random.categorical(logits=logits_to_use, num_samples=num_samples, dtype=tf.int32)

    return {
        'probs': probs,
        'sample': sample,
    }
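
A minimal NumPy sketch of the same cutoff logic for a single distribution (illustrative only; the TF version above operates on batches and samples in the sorted space before unsorting):

import numpy as np

def top_p_filter(probs, p=0.9):
    """Zeros out everything outside the top-p nucleus, then renormalizes."""
    order = np.argsort(-probs)                        # sort descending
    cumulative = np.cumsum(probs[order])
    # Always keep the single most likely token so at least one survives.
    keep_sorted = (cumulative < p) | (np.arange(len(probs)) < 1)
    keep = np.zeros_like(probs, dtype=bool)
    keep[order] = keep_sorted
    filtered = np.where(keep, probs, 0.0)
    return filtered / filtered.sum()

print(top_p_filter(np.array([0.5, 0.3, 0.15, 0.05])))  # -> [0.625, 0.375, 0., 0.]
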
Example #29
0
def assign_and_sample_proposals(proposed_boxes,
                                gt_boxes,
                                gt_classes,
                                gt_attributes,
                                num_samples_per_image=512,
                                mix_gt_boxes=True,
                                fg_fraction=0.25,
                                fg_iou_thresh=0.5,
                                bg_iou_thresh_hi=0.5,
                                bg_iou_thresh_lo=0.0):
    """Assigns the proposals with groundtruth classes and performs subsmpling.

  Given `proposed_boxes`, `gt_boxes`, `gt_classes` and `gt_attributes`, the
  function uses the following algorithm to generate the final
  `num_samples_per_image` RoIs.
    1. Calculates the IoU between each proposed box and each groundtruth box.
    2. Assigns each proposed box with a groundtruth class and box by choosing
       the largest IoU overlap.
    3. Samples `num_samples_per_image` boxes from all proposed boxes, and
       returns box_targets, class_targets, and RoIs.

  Args:
    proposed_boxes: a tensor of shape of [batch_size, N, 4]. N is the number
      of proposals before groundtruth assignment. The last dimension is the
      box coordinates w.r.t. the scaled images in [ymin, xmin, ymax, xmax]
      format.
    gt_boxes: a tensor of shape of [batch_size, MAX_NUM_INSTANCES, 4].
      The coordinates of gt_boxes are in the pixel coordinates of the scaled
      image. This tensor might have padding of values -1 indicating the invalid
      box coordinates.
    gt_classes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES]. This
      tensor might have paddings with values of -1 indicating the invalid
      classes.
    gt_attributes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES,
      num_attributes]. This tensor might have paddings with values of -1
      indicating the invalid attributes.
    num_samples_per_image: an integer representing the RoI minibatch size per
      image.
    mix_gt_boxes: a bool indicating whether to mix the groundtruth boxes before
      sampling proposals.
    fg_fraction: a float representing the target fraction of the RoI minibatch
      that is labeled foreground (i.e., class > 0).
    fg_iou_thresh: a float representing the IoU overlap threshold for an RoI to
      be considered foreground (if >= fg_iou_thresh).
    bg_iou_thresh_hi: a float representing the upper IoU overlap threshold for
      an RoI to be considered background (class = 0 if overlap in [LO, HI)).
    bg_iou_thresh_lo: a float representing the lower IoU overlap threshold for
      an RoI to be considered background (class = 0 if overlap in [LO, HI)).

  Returns:
    sampled_rois: a tensor of shape of [batch_size, K, 4], representing the
      coordinates of the sampled RoIs, where K is the number of the sampled
      RoIs, i.e. K = num_samples_per_image.
    sampled_gt_boxes: a tensor of shape of [batch_size, K, 4], storing the
      box coordinates of the matched groundtruth boxes of the sampled RoIs.
    sampled_gt_classes: a tensor of shape of [batch_size, K], storing the
      classes of the matched groundtruth boxes of the sampled RoIs.
    sampled_gt_attributes: a tensor of shape of [batch_size, K,
      num_attributes], storing the attributes of the matched groundtruth
      attributes of the sampled RoIs.
    sampled_gt_indices: a tensor of shape of [batch_size, K], storing the
      indices of the sampled groundtruth boxes in the original `gt_boxes`
      tensor, i.e. gt_boxes[sampled_gt_indices[:, i]] = sampled_gt_boxes[:, i].
  """

    with tf.name_scope('sample_proposals'):
        if mix_gt_boxes:
            boxes = tf.concat([proposed_boxes, gt_boxes], axis=1)
        else:
            boxes = proposed_boxes

        (matched_gt_boxes, matched_gt_classes, matched_gt_attributes,
         matched_gt_indices, matched_iou,
         _) = box_matching(boxes, gt_boxes, gt_classes, gt_attributes)

        positive_match = tf.greater(matched_iou, fg_iou_thresh)
        negative_match = tf.logical_and(
            tf.greater_equal(matched_iou, bg_iou_thresh_lo),
            tf.less(matched_iou, bg_iou_thresh_hi))
        ignored_match = tf.less(matched_iou, 0.0)

        # re-assign negatively matched boxes to the background class.
        matched_gt_classes = tf.where(negative_match,
                                      tf.zeros_like(matched_gt_classes),
                                      matched_gt_classes)
        matched_gt_indices = tf.where(negative_match,
                                      tf.zeros_like(matched_gt_indices),
                                      matched_gt_indices)

        sample_candidates = tf.logical_and(
            tf.logical_or(positive_match, negative_match),
            tf.logical_not(ignored_match))

        sampler = (
            balanced_positive_negative_sampler.BalancedPositiveNegativeSampler(
                positive_fraction=fg_fraction, is_static=True))

        batch_size, _ = sample_candidates.get_shape().as_list()
        sampled_indicators = []
        for i in range(batch_size):
            sampled_indicator = sampler.subsample(sample_candidates[i],
                                                  num_samples_per_image,
                                                  positive_match[i])
            sampled_indicators.append(sampled_indicator)
        sampled_indicators = tf.stack(sampled_indicators)
        _, sampled_indices = tf.nn.top_k(tf.cast(sampled_indicators,
                                                 dtype=tf.int32),
                                         k=num_samples_per_image,
                                         sorted=True)

        sampled_indices_shape = tf.shape(sampled_indices)
        batch_indices = (
            tf.expand_dims(tf.range(sampled_indices_shape[0]), axis=-1) *
            tf.ones([1, sampled_indices_shape[-1]], dtype=tf.int32))
        gather_nd_indices = tf.stack([batch_indices, sampled_indices], axis=-1)

        sampled_rois = tf.gather_nd(boxes, gather_nd_indices)
        sampled_gt_boxes = tf.gather_nd(matched_gt_boxes, gather_nd_indices)
        sampled_gt_classes = tf.gather_nd(matched_gt_classes,
                                          gather_nd_indices)
        sampled_gt_attributes = tf.gather_nd(matched_gt_attributes,
                                             gather_nd_indices)
        sampled_gt_indices = tf.gather_nd(matched_gt_indices,
                                          gather_nd_indices)

        return (sampled_rois, sampled_gt_boxes, sampled_gt_classes,
                sampled_gt_attributes, sampled_gt_indices)
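
A minimal NumPy sketch of the matching thresholds above, applied to a toy vector of matched IoUs (padded groundtruth rows produce a negative IoU and are ignored):

import numpy as np

matched_iou = np.array([0.8, 0.3, 0.55, 0.1, -1.0])
fg_iou_thresh, bg_iou_thresh_lo, bg_iou_thresh_hi = 0.5, 0.0, 0.5

positive_match = matched_iou > fg_iou_thresh
negative_match = (matched_iou >= bg_iou_thresh_lo) & (matched_iou < bg_iou_thresh_hi)
ignored_match = matched_iou < 0.0
sample_candidates = (positive_match | negative_match) & ~ignored_match

print(positive_match)     # [ True False  True False False]
print(negative_match)     # [False  True False  True False]
print(sample_candidates)  # [ True  True  True  True False]
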
Example #30
0
def get_retrieval_examples(serialized_example, mask_rate, bert_hub_module_path,
                           query_seq_len, block_seq_len):
    """Make retrieval examples."""
    feature_spec = dict(title_ids=tf.FixedLenSequenceFeature([], tf.int64,
                                                             True),
                        token_ids=tf.FixedLenSequenceFeature([], tf.int64,
                                                             True),
                        sentence_starts=tf.FixedLenSequenceFeature([],
                                                                   tf.int64,
                                                                   True))
    features = tf.parse_single_example(serialized_example, feature_spec)
    features = {k: tf.cast(v, tf.int32) for k, v in features.items()}

    title_ids = features["title_ids"]
    token_ids = features["token_ids"]
    sentence_starts = features["sentence_starts"]
    sentence_ends = tf.concat([sentence_starts[1:], [tf.size(token_ids)]], 0)

    tokenizer = bert_utils.get_tokenizer(bert_hub_module_path)
    cls_id, sep_id = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]"])

    # Randomly choose a sentence and pretend that it is a query.
    query_index = tf.random.uniform(shape=[],
                                    minval=0,
                                    maxval=tf.size(sentence_starts),
                                    dtype=tf.int32)
    query_start = sentence_starts[query_index]
    query_end = sentence_ends[query_index]

    query_ids = token_ids[query_start:query_end]

    mask_query = tf.less(tf.random.uniform([]), mask_rate)

    def _apply_mask():
        return tf.concat([token_ids[:query_start], token_ids[query_end:]], 0)

    block_ids = tf.cond(pred=mask_query,
                        true_fn=_apply_mask,
                        false_fn=lambda: token_ids)

    query_ids, query_mask = bert_utils.pad_or_truncate(
        token_ids=query_ids,
        sequence_length=query_seq_len,
        cls_id=cls_id,
        sep_id=sep_id)
    block_ids, block_mask, block_segment_ids = bert_utils.pad_or_truncate_pair(
        token_ids_a=title_ids,
        token_ids_b=block_ids,
        sequence_length=block_seq_len,
        cls_id=cls_id,
        sep_id=sep_id)

    # Masked examples for single-sentence blocks don't make sense (removing the
    # query sentence would leave an empty block), so flag them to be dropped.
    keep_example = tf.logical_or(tf.logical_not(mask_query),
                                 tf.greater(tf.size(sentence_starts), 1))

    return dict(keep_example=keep_example,
                mask_query=mask_query,
                query_ids=query_ids,
                query_mask=query_mask,
                block_ids=block_ids,
                block_mask=block_mask,
                block_segment_ids=block_segment_ids)
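
A minimal usage sketch, assuming a hypothetical TFRecord file pattern and hub module path, that the bert_utils helpers referenced above are importable, and illustrative sequence lengths:

import functools
import tensorflow as tf

dataset = tf.data.TFRecordDataset(tf.gfile.Glob('/path/to/blocks-*'))
dataset = dataset.map(
    functools.partial(get_retrieval_examples,
                      mask_rate=0.5,
                      bert_hub_module_path='/path/to/bert_hub_module',
                      query_seq_len=64,
                      block_seq_len=288))
# Masked single-sentence blocks are flagged above, so drop them here.
dataset = dataset.filter(lambda ex: ex['keep_example'])
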