Ejemplo n.º 1
0
    def get_scheduled_sample_inputs(self, done_warm_start, groundtruth_items,
                                    generated_items, scheduled_sampling_func):
        """Choose which items to use under scheduled sampling.

        The choice is made by a `tf.case` whose predicates are checked in
        order: ground truth while warm start is not done, generated items
        when `self.is_training` is false, and otherwise a per-item mix
        produced by `scheduled_sampling_func`.

        Args:
          done_warm_start: whether we are done with warm start or not (passed
            to `tf.logical_not`).
          groundtruth_items: list of ground truth items.
          generated_items: list of generated items, paired by position with
            `groundtruth_items`.
          scheduled_sampling_func: callable invoked as
            `scheduled_sampling_func(groundtruth_item, generated_item)` to
            produce each mixed item.

        Returns:
          The output of the `tf.case`: the ground truth items, the generated
          items, or the mixed list.
        """
        def _use_groundtruth():
            return groundtruth_items

        def _use_generated():
            return generated_items

        def _mix_items():
            """Apply the sampling function to every (truth, generated) pair."""
            mixed = []
            with tf.variable_scope("scheduled_sampling", reuse=tf.AUTO_REUSE):
                for gt_item, gen_item in zip(groundtruth_items,
                                             generated_items):
                    mixed.append(scheduled_sampling_func(gt_item, gen_item))
            return mixed

        still_warming_up = tf.logical_not(done_warm_start)
        not_training = tf.logical_not(self.is_training)

        # Not exclusive: the first predicate that is true wins, so warm start
        # takes priority over the training check.
        return tf.case(
            [(still_warming_up, _use_groundtruth),
             (not_training, _use_generated)],
            default=_mix_items,
            strict=True)
Ejemplo n.º 2
0
def beta_schedule(schedule, global_step, final_beta, decay_start, decay_end):
    """Get KL multiplier (beta) based on the schedule.

    Beta is 0.0 while `global_step < decay_start`, `final_beta` once
    `global_step > decay_end`, and in between it is
    `max(0, final_beta - decayed_value)` where `decayed_value` comes from the
    selected schedule.

    Args:
      schedule: one of "constant", "linear" or "noisy_linear_cosine_decay".
      global_step: the current step, compared against the decay window.
      final_beta: the value beta reaches after `decay_end`.
      decay_start: step at which the annealing window begins.
      decay_end: step at which the annealing window ends.

    Returns:
      The beta value produced by `tf.case`.

    Raises:
      ValueError: if `decay_start > decay_end` or `schedule` is not one of
        the supported names.
    """
    if decay_start > decay_end:
        raise ValueError("decay_end is smaller than decay_start.")

    # Since some of the TF schedules do not support incrementing a value,
    # in all of the schedules, we anneal the beta from final_beta to zero
    # and then reverse it at the bottom.
    if schedule == "constant":
        decayed_value = 0.0
    elif schedule == "linear":
        decayed_value = tf.train.polynomial_decay(
            learning_rate=final_beta,
            global_step=global_step - decay_start,
            decay_steps=decay_end - decay_start,
            end_learning_rate=0.0)
    elif schedule == "noisy_linear_cosine_decay":
        decayed_value = tf.train.noisy_linear_cosine_decay(
            learning_rate=final_beta,
            global_step=global_step - decay_start,
            decay_steps=decay_end - decay_start)
    # TODO(mechcoder): Add log_annealing schedule.
    else:
        raise ValueError("Unknown beta schedule.")

    # Reverse the decay so the value grows towards final_beta, never below 0.
    increased_value = final_beta - decayed_value
    increased_value = tf.maximum(0.0, increased_value)

    beta = tf.case(pred_fn_pairs=[(tf.less(global_step,
                                           decay_start), lambda: 0.0),
                                  (tf.greater(global_step,
                                              decay_end), lambda: final_beta)],
                   default=lambda: increased_value)
    return beta
Ejemplo n.º 3
0
def blend(image1, image2, factor):
    """Blend image1 and image2 using 'factor'.

    Factor can be above 0.0.  A value of 0.0 means only image1 is used.
    A value of 1.0 means only image2 is used.  A value between 0.0 and
    1.0 means we linearly interpolate the pixel values between the two
    images.  A value greater than 1.0 "extrapolates" the difference
    between the two pixel values, and we clip the results so that they
    stay within the valid uint8 range (0 to 255) after conversion.

    Args:
      image1: An image Tensor of type uint8.
      image2: An image Tensor of type uint8.
      factor: A floating point value above 0.0.

    Returns:
      A blended image Tensor of type uint8.
    """

    def _blend():
        # convert_image_dtype rescales uint8 pixels to floats in [0, 1].
        image_1 = tf.image.convert_image_dtype(image1, tf.float32)
        image_2 = tf.image.convert_image_dtype(image2, tf.float32)
        output = image_1 + factor * (image_2 - image_1)
        # Interpolation (0 < factor < 1) cannot leave the input range, so
        # only the extrapolated result is clipped. The floats live in
        # [0, 1], so that is the clipping range; clipping to [0, 255] here
        # would let out-of-range values overflow in the uint8 conversion.
        output = tf.where_v2(
            tf.logical_and(tf.less(0., factor), tf.less(factor, 1.)),
            x=output, y=tf.clip_by_value(output, 0., 1.))
        return tf.image.convert_image_dtype(output, tf.uint8)

    # Exact endpoints return the corresponding input image untouched.
    pred_fn_pairs = [
        (tf.equal(factor, 0.), lambda: image1),
        (tf.equal(factor, 1.), lambda: image2),
    ]
    return tf.case(
        pred_fn_pairs, default=_blend, exclusive=True, strict=True, name='blend')
Ejemplo n.º 4
0
        def _drop_channels(data):
            """Randomly replace all but one channel of `data["image"]` by noise.

            With the branch selected by comparing a uniform draw in [0, 1)
            against `keep_original`, the image is either kept as is or has
            every channel except one randomly chosen channel replaced by
            uniform noise in [`noise_min`, `noise_max`). `keep_original`,
            `noise_min` and `noise_max` come from the enclosing scope.

            Args:
              data: dict holding the image under the "image" key; the
                indexing below treats the image as rank 3 with channels
                last.

            Returns:
              The same dict with "image" replaced by the selected result.
            """
            image = data["image"]

            def _drop(keep_i):
                """Keep channel `keep_i`; fill every other channel with noise."""
                dims = image.get_shape().as_list()
                spatial, depth = dims[:-1], dims[-1]
                planes = []
                for idx in range(depth):
                    if idx == keep_i:
                        planes.append(image[:, :, idx:idx + 1])
                    else:
                        planes.append(
                            tf.random_uniform(spatial + [1], noise_min,
                                              noise_max))
                return tf.concat(planes, axis=2)

            def _drop_random_channel(coin_channel):
                """Dispatch on `coin_channel` (0, 1 or 2) to the kept channel."""
                channel_cases = {}
                for channel in range(3):
                    # Bind the loop value now; the callable runs later.
                    channel_cases[tf.equal(coin_channel, channel)] = (
                        lambda kept=channel: _drop(kept))
                return tf.case(channel_cases)

            coin_keep_original = tf.random.uniform(
                [], 0.0, 1.0, dtype=tf.float32)
            coin_channel = tf.random.uniform([], 0, 3, dtype=tf.int32)

            keep_pred = tf.less(coin_keep_original, keep_original)
            drop_pred = tf.greater_equal(coin_keep_original, keep_original)
            data["image"] = tf.case({
                keep_pred: lambda: image,
                drop_pred: lambda: _drop_random_channel(coin_channel),
            })
            return data
Ejemplo n.º 5
0
def _produce_posterior_estimate(posterior_dist, posterior_estimate_mode,
                                raw_var_name):
    """Create tensor representing estimate of posterior.

    The estimate is selected at run time by comparing
    `posterior_estimate_mode` against the `EstimatorModes` values: a sample,
    the mean, or (only when `posterior_dist` has a `last_sample` attribute)
    the last sample. Any other mode triggers a failing `tf.Assert`.

    Args:
      posterior_dist: An instance of `tfp.distributions.Distribution`.
          The variational posterior from which to produce an estimate of the
          variable in question.
      posterior_estimate_mode: A `Tensor` of dtype `tf.string`, which
          determines the inference mode.
      raw_var_name: The name of the variable over which inference is done;
          used to name the `tf.case` op.

    Returns:
      `z_sample`, a `Tensor` representing an estimate derived from the
          posterior distribution.
    """

    def _mode_is(mode_value, op_name):
        """Build the predicate `posterior_estimate_mode == mode_value`."""
        return tf.equal(posterior_estimate_mode,
                        tf.constant(mode_value),
                        name=op_name)

    # All three predicates are built up front, even when the last one ends
    # up unused below.
    is_sample_mode = _mode_is(EstimatorModes.sample, "equal_sample_mode")
    is_mean_mode = _mode_is(EstimatorModes.mean, "equal_mean_mode")
    is_last_sample_mode = _mode_is(EstimatorModes.last_sample,
                                   "equal_last_sample_mode")

    def _draw_sample():
        return posterior_dist.sample()

    def _take_mean():
        return posterior_dist.mean()

    def _take_last_sample():
        return posterior_dist.last_sample()

    def _fail_on_invalid_mode():
        """Default branch: assert False, then return the mean as a stand-in."""
        message = "Invalid posterior estimate mode."
        failure = tf.Assert(tf.constant(False), data=[tf.constant(message)])
        with tf.control_dependencies([failure]):
            return posterior_dist.mean()

    cases = [(is_sample_mode, _draw_sample), (is_mean_mode, _take_mean)]
    if hasattr(posterior_dist, "last_sample"):
        cases.append((is_last_sample_mode, _take_last_sample))

    return tf.case(cases,
                   exclusive=True,
                   default=_fail_on_invalid_mode,
                   name="{}_posterior_estimate".format(raw_var_name))
Ejemplo n.º 6
0
  def test_cov_update_thunks(self):
    """Checks that each global step updates exactly one covariance matrix."""
    with self._graph.as_default(), self.test_session() as sess:
      fisher_estimator = estimator.FisherEstimatorRoundRobin(
          variables=[self.weights],
          layer_collection=self.layer_collection,
          damping=0.2,
          cov_ema_decay=0.0)

      global_step = tf.train.get_or_create_global_step()
      (cov_variable_thunks, cov_update_op_thunks, _,
       _) = fisher_estimator.create_ops_and_vars_thunks()

      # Run the variable thunks before reading the factors' `cov`.
      for make_variable in cov_variable_thunks:
        make_variable()
      cov_matrices = []
      for factor in self.layer_collection.get_factors():
        cov_matrices.append(factor.cov)

      # A single op that runs the i-th update thunk when global_step == i.
      step_update_pairs = []
      for step, make_update in enumerate(cov_update_op_thunks):
        step_update_pairs.append((tf.equal(global_step, step), make_update))
      cov_update_op = tf.case(step_update_pairs)
      increment_global_step = global_step.assign_add(1)

      sess.run(tf.global_variables_initializer())
      initial_cov_values = sess.run(cov_matrices)

      # There must be one update thunk per covariance matrix.
      self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))

      # With a single covariance matrix the test would check nothing.
      assert len(cov_matrices) > 1

      for i in range(len(cov_matrices)):
        # Count the matrices that still hold their initial value.
        current_cov_values = sess.run(cov_matrices)
        num_cov_equal = sum(
            np.allclose(before, now)
            for before, now in zip(initial_cov_values, current_cov_values))

        # After i steps exactly i matrices should have changed.
        self.assertEqual(num_cov_equal, len(cov_matrices) - i)

        # Run the update selected by the current step, then advance the step.
        sess.run(cov_update_op)
        sess.run(increment_global_step)
def categorical_case(pmf, fns, rand=None):
  """Runs one of `fns`, picking `fns[i]` with probability `pmf[i]`.

  The cumulative sum of `pmf` (with a leading 0) splits the axis into one
  interval per callable; the callable whose half-open interval
  `[lower, upper)` contains `rand` is selected through `tf.case`.

  Args:
    pmf: A 1-D tensor of probabilities, the probability mass function.
    fns: A list of callables that return tensors, same length as pmf.
    rand: An optional scalar between 0.0 and 1.0, the output of an RNG.
      When None, a fresh `tf.random_uniform([])` draw is used.

  Returns:
    A tensor, the output of fns[i] with probability pmf[i].
  """
  if rand is None:
    rand = tf.random_uniform([])

  # Prepend 0 so that consecutive entries are the interval boundaries.
  cumulative = tf.pad(tf.cumsum(pmf), [(1, 0)])
  edges = [cumulative[k] for k in range(len(fns) + 1)]

  pred_fn_pairs = []
  for lower, upper, fn in zip(edges, edges[1:], fns):
    in_interval = (rand >= lower) & (rand < upper)
    pred_fn_pairs.append((in_interval, fn))
  return tf.case(pred_fn_pairs, exclusive=True)
def apply_with_random_selector(x, func, num_cases):
    """Computes func(x, sel), with sel sampled from [0...num_cases-1].

    One `tf.case` branch is built per selector value; each branch calls
    `func` with that value as a plain Python integer.

    Args:
      x: input Tensor.
      func: Python function to apply, called as `func(x, sel)`.
      num_cases: Python int32, number of cases to sample sel from.

    Returns:
      The result of func(x, sel), where func receives the value of the
      selector as a python integer, but sel is sampled dynamically.
    """
    sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
    # The default argument pins each branch to its own selector value; a bare
    # closure would see only the final loop value.
    pred_fn_pairs = [
        (tf.equal(sel, case), lambda case_value=case: func(x, case_value))
        for case in range(num_cases)
    ]
    return tf.case(pred_fn_pairs)
Ejemplo n.º 9
0
  def _resize_pp(data):
    """Resize `data["image"]`, optionally with a randomly chosen method.

    `randomize_resize_method`, `im_size` and `_resize` come from the
    enclosing scope. `tf.case` needs callables as branch values, so
    `_resize(...)` is presumably a factory returning a thunk - TODO confirm
    against its definition.

    Args:
      data: dict holding the image under the "image" key.

    Returns:
      The same dict with "image" replaced by the resized image.
    """
    im = data["image"]

    if randomize_resize_method:
      # pick random resizing method
      # Four methods are listed below and the integer maxval is exclusive,
      # so draw from [0, 4); drawing from [0, 3) would leave the AREA branch
      # unreachable.
      r = tf.random_uniform([], 0, 4, dtype=tf.int32)
      im = tf.case({
          tf.equal(r, tf.cast(0, r.dtype)):
              _resize(im, tf.image.ResizeMethod.BILINEAR, True),
          tf.equal(r, tf.cast(1, r.dtype)):
              _resize(im, tf.image.ResizeMethod.NEAREST_NEIGHBOR, True),
          tf.equal(r, tf.cast(2, r.dtype)):
              _resize(im, tf.image.ResizeMethod.BICUBIC, True),
          # NOTE: use align_corners=False for AREA resize, but True for the
          # others. See https://github.com/tensorflow/tensorflow/issues/6720
          tf.equal(r, tf.cast(3, r.dtype)):
              _resize(im, tf.image.ResizeMethod.AREA, False),
      })
    else:
      im = tf.image.resize_images(im, im_size)
    data["image"] = im
    return data
Ejemplo n.º 10
0
def input_tensors_to_model_input(input_tensors,
                                 hparams,
                                 is_training,
                                 num_classes=constants.MIDI_PITCHES):
    """Processes an InputTensor into FeatureTensors and LabelTensors.

    The spectrogram and every label tensor are reshaped to 2-D, then padded
    or trimmed along the first axis to `final_length` rows.

    Args:
      input_tensors: object exposing `length`, `labels`, `label_weights`,
        `onsets`, `offsets`, `velocities`, `spec`, and (when not training)
        `note_sequence` and `sequence_id`.
      hparams: hyperparameters; `truncated_length_secs` and
        `max_expected_train_example_len` are read here.
      is_training: Python truth value selecting the training code path.
      num_classes: width of each label row.

    Returns:
      A `(FeatureTensors, LabelTensors)` pair.
    """

    def _fit_rows(tensor, delta):
        """Pad (delta < 0) or trim (delta > 0) `tensor` by |delta| rows."""

        def _pad():
            return tf.pad(tensor, tf.stack([(0, -delta), (0, 0)]))

        def _trim():
            return tensor[0:-delta]

        return tf.case([(delta < 0, _pad), (delta > 0, _trim)],
                       default=lambda: tensor)

    label_row_shape = (-1, num_classes)
    length = tf.cast(input_tensors.length, tf.int32)
    labels = tf.reshape(input_tensors.labels, label_row_shape)
    label_weights = tf.reshape(input_tensors.label_weights, label_row_shape)
    onsets = tf.reshape(input_tensors.onsets, label_row_shape)
    offsets = tf.reshape(input_tensors.offsets, label_row_shape)
    velocities = tf.reshape(input_tensors.velocities, label_row_shape)
    spec = tf.reshape(input_tensors.spec, (-1, hparams_frame_size(hparams)))

    # Cap the length at the configured truncation, when one is configured.
    hparams_truncated_length = tf.cast(
        hparams.truncated_length_secs * hparams_frames_per_second(hparams),
        tf.int32)
    truncated_length = length
    if hparams.truncated_length_secs:
        truncated_length = tf.reduce_min([hparams_truncated_length, length])

    truncated_note_sequence = (
        tf.constant(0) if is_training else truncate_note_sequence_op(
            input_tensors.note_sequence, truncated_length, hparams))

    # If max_expected_train_example_len is set, all training examples are
    # padded to that constant length (a fixed shape that can work on TPUs);
    # otherwise the length is min(truncated length, example length).
    if not (hparams.max_expected_train_example_len and is_training):
        final_length = truncated_length
    elif hparams.truncated_length_secs:
        assert_op = tf.assert_equal(hparams.max_expected_train_example_len,
                                    hparams_truncated_length)
        # NOTE(review): only a Python value is assigned under this
        # dependency, so assert_op is presumably not attached to any later
        # op - confirm whether that is intended.
        with tf.control_dependencies([assert_op]):
            final_length = hparams.max_expected_train_example_len
    else:
        final_length = hparams.max_expected_train_example_len

    spec_delta = tf.shape(spec)[0] - final_length
    spec = _fit_rows(spec, spec_delta)
    # All label tensors share the delta computed from `labels`.
    labels_delta = tf.shape(labels)[0] - final_length
    labels = _fit_rows(labels, labels_delta)
    label_weights = _fit_rows(label_weights, labels_delta)
    onsets = _fit_rows(onsets, labels_delta)
    offsets = _fit_rows(offsets, labels_delta)
    velocities = _fit_rows(velocities, labels_delta)

    final_spec = tf.reshape(
        spec, (final_length, hparams_frame_size(hparams), 1))
    sequence_id = (
        tf.constant(0) if is_training else input_tensors.sequence_id)
    features = FeatureTensors(spec=final_spec,
                              length=truncated_length,
                              sequence_id=sequence_id)

    final_label_shape = (final_length, num_classes)
    label_tensors = LabelTensors(
        labels=tf.reshape(labels, final_label_shape),
        label_weights=tf.reshape(label_weights, final_label_shape),
        onsets=tf.reshape(onsets, final_label_shape),
        offsets=tf.reshape(offsets, final_label_shape),
        velocities=tf.reshape(velocities, final_label_shape),
        note_sequence=truncated_note_sequence)

    return features, label_tensors
Ejemplo n.º 11
0
    def parser(value):
        """Parse an Imagenet record from value.

        Decodes the image (PNG when the stored format is 'png' or 'PNG',
        JPEG otherwise), runs `preprocess` from the enclosing scope on it
        together with the bounding boxes, and extracts the class label.

        Args:
          value: a serialized example accepted by `tf.parse_single_example`.

        Returns:
          A tuple `(image, label)`: the preprocessed image cast to float32
          and the int32 label reshaped to `[1]`.
        """
        keys_to_features = {
            'image/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
            'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
            'image/class/label':
            tf.FixedLenFeature([], dtype=tf.int64, default_value=-1),
            'image/class/text':
            tf.FixedLenFeature([], dtype=tf.string, default_value=''),
            'image/object/bbox/xmin':
            tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymin':
            tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/xmax':
            tf.VarLenFeature(dtype=tf.float32),
            'image/object/bbox/ymax':
            tf.VarLenFeature(dtype=tf.float32),
            'image/object/class/label':
            tf.VarLenFeature(dtype=tf.int64),
        }
        parsed = tf.parse_single_example(value, keys_to_features)

        def _coordinate_row(feature_key):
            """Return one sparse bbox coordinate as a [1, num_boxes] row."""
            return tf.expand_dims(parsed[feature_key].values, 0)

        encoded_image = tf.reshape(parsed['image/encoded'],
                                   shape=[],
                                   name='encoded_image')
        image_format = parsed['image/format']
        xmin = _coordinate_row('image/object/bbox/xmin')
        ymin = _coordinate_row('image/object/bbox/ymin')
        xmax = _coordinate_row('image/object/bbox/xmax')
        ymax = _coordinate_row('image/object/bbox/ymax')

        # Coordinates are stacked in (y, x) order, then rearranged so the
        # variable number of boxes ends up as [1, num_boxes, coords].
        stacked = tf.concat([ymin, xmin, ymax, xmax], 0)
        bbox = tf.transpose(tf.expand_dims(stacked, 0), [0, 2, 1])

        # PNG records get the PNG decoder; everything else is read as JPEG.
        is_png = tf.logical_or(tf.equal(image_format, 'png'),
                               tf.equal(image_format, 'PNG'))
        image = tf.case(
            {is_png: lambda: tf.image.decode_png(encoded_image, 3)},
            default=lambda: tf.image.decode_jpeg(encoded_image, 3),
            exclusive=True)
        image.set_shape([None, None, 3])

        image = preprocess(image, bbox)

        scalar_label = tf.cast(tf.reshape(parsed['image/class/label'],
                                          shape=[]),
                               dtype=tf.int32,
                               name='cast_label')
        label = tf.reshape(scalar_label, [1])
        return tf.cast(image, tf.float32), label
def input_tensors_to_model_input(input_tensors,
                                 hparams,
                                 is_training,
                                 num_classes=constants.MIDI_PITCHES):
    """Processes an InputTensor into FeatureTensors and LabelTensors.

    The spectrogram and every label tensor are reshaped to 2-D, then padded
    or trimmed along the first axis to `final_length` rows. When
    `hparams.drum_data_map` is set, the label fields are additionally passed
    through `drum_mappings`.

    Args:
      input_tensors: object exposing `length`, `labels`, `label_weights`,
        `onsets`, `offsets`, `velocities`, `spec`, and (when not training)
        `note_sequence` and `sequence_id`.
      hparams: hyperparameters; `truncated_length_secs`,
        `max_expected_train_example_len` and `drum_data_map` are read here.
      is_training: Python truth value selecting the training code path.
      num_classes: width of each label row.

    Returns:
      A `(FeatureTensors, LabelTensors)` pair.
    """

    def _fit_rows(tensor, delta):
        """Pad (delta < 0) or trim (delta > 0) `tensor` by |delta| rows."""

        def _pad():
            return tf.pad(tensor, tf.stack([(0, -delta), (0, 0)]))

        def _trim():
            return tensor[0:-delta]

        return tf.case([(delta < 0, _pad), (delta > 0, _trim)],
                       default=lambda: tensor)

    label_row_shape = (-1, num_classes)
    length = tf.cast(input_tensors.length, tf.int32)
    labels = tf.reshape(input_tensors.labels, label_row_shape)
    label_weights = tf.reshape(input_tensors.label_weights, label_row_shape)
    onsets = tf.reshape(input_tensors.onsets, label_row_shape)
    offsets = tf.reshape(input_tensors.offsets, label_row_shape)
    velocities = tf.reshape(input_tensors.velocities, label_row_shape)
    spec = tf.reshape(input_tensors.spec, (-1, hparams_frame_size(hparams)))

    # Cap the length at the configured truncation, when one is configured.
    hparams_truncated_length = tf.cast(
        hparams.truncated_length_secs * hparams_frames_per_second(hparams),
        tf.int32)
    truncated_length = length
    if hparams.truncated_length_secs:
        truncated_length = tf.reduce_min([hparams_truncated_length, length])

    truncated_note_sequence = (
        tf.constant(0) if is_training else truncate_note_sequence_op(
            input_tensors.note_sequence, truncated_length, hparams))

    # If max_expected_train_example_len is set, all training examples are
    # padded to that constant length (a fixed shape that can work on TPUs);
    # otherwise the length is min(truncated length, example length).
    if not (hparams.max_expected_train_example_len and is_training):
        final_length = truncated_length
    elif hparams.truncated_length_secs:
        assert_op = tf.assert_equal(hparams.max_expected_train_example_len,
                                    hparams_truncated_length)
        # NOTE(review): only a Python value is assigned under this
        # dependency, so assert_op is presumably not attached to any later
        # op - confirm whether that is intended.
        with tf.control_dependencies([assert_op]):
            final_length = hparams.max_expected_train_example_len
    else:
        final_length = hparams.max_expected_train_example_len

    spec_delta = tf.shape(spec)[0] - final_length
    spec = _fit_rows(spec, spec_delta)
    # All label tensors share the delta computed from `labels`.
    labels_delta = tf.shape(labels)[0] - final_length
    labels = _fit_rows(labels, labels_delta)
    label_weights = _fit_rows(label_weights, labels_delta)
    onsets = _fit_rows(onsets, labels_delta)
    offsets = _fit_rows(offsets, labels_delta)
    velocities = _fit_rows(velocities, labels_delta)

    final_spec = tf.reshape(
        spec, (final_length, hparams_frame_size(hparams), 1))
    sequence_id = (
        tf.constant(0) if is_training else input_tensors.sequence_id)
    features = FeatureTensors(spec=final_spec,
                              length=truncated_length,
                              sequence_id=sequence_id)

    final_label_shape = (final_length, num_classes)
    label_tensors = LabelTensors(
        labels=tf.reshape(labels, final_label_shape),
        label_weights=tf.reshape(label_weights, final_label_shape),
        onsets=tf.reshape(onsets, final_label_shape),
        offsets=tf.reshape(offsets, final_label_shape),
        velocities=tf.reshape(velocities, final_label_shape),
        note_sequence=truncated_note_sequence)

    if hparams.drum_data_map:
        fields = label_tensors._asdict()
        # Each pianoroll field and the reduce mode used when mapping it.
        reduce_mode_by_field = (
            ('labels', 'any'),
            ('onsets', 'any'),
            ('offsets', 'any'),
            ('label_weights', 'max'),
            ('velocities', 'max'),
        )
        for field, reduce_mode in reduce_mode_by_field:
            fields[field] = drum_mappings.map_pianoroll(
                fields[field],
                mapping_name=hparams.drum_data_map,
                reduce_mode=reduce_mode,
                min_pitch=constants.MIN_MIDI_PITCH)

        # Only a string note sequence is remapped (the training path stores
        # an integer constant placeholder instead).
        note_sequence = fields['note_sequence']
        if note_sequence.dtype == tf.string:
            mapped_sequence = tf.py_func(
                functools.partial(drum_mappings.map_sequences,
                                  mapping_name=hparams.drum_data_map),
                [note_sequence],
                tf.string,
                name='get_drum_sequences',
                stateful=False)
            mapped_sequence.set_shape(())
            fields['note_sequence'] = mapped_sequence
        label_tensors = LabelTensors(**fields)

    return features, label_tensors
Ejemplo n.º 13
0
    def _enas_layer(self, layer_id, prev_layers, start_idx, out_filters,
                    is_training):
        """Build one layer as described by `self.sample_arc`.

        With `self.whole_channels` a single branch is selected by the value
        at `self.sample_arc[start_idx]`; otherwise every branch is built on
        a slice of channels and the results are merged by a masked 1x1
        convolution. For `layer_id > 0`, previous layers are added as skip
        connections according to the following entries of `sample_arc`.

        Args:
          layer_id: current layer
          prev_layers: cache of previous layers. for skip connections
          start_idx: where to start looking at. technically, we can infer
            this from layer_id, but why bother...
          out_filters: number of output filters passed to every branch.
          is_training: for batch_norm

        Returns:
          The output tensor of the layer.
        """

        inputs = prev_layers[-1]
        if self.whole_channels:
            if self.data_format == "NHWC":
                inp_h = inputs.get_shape()[1].value
                inp_w = inputs.get_shape()[2].value
            elif self.data_format == "NCHW":
                inp_h = inputs.get_shape()[2].value
                inp_w = inputs.get_shape()[3].value

            count = self.sample_arc[start_idx]
            branches = {}
            # Each branch output is bound as a lambda default argument.
            # tf.case only calls the lambdas after every branch has been
            # built, so a bare `lambda: y` would make all of them return
            # the last `y`.
            with tf.variable_scope("branch_0"):
                y = self._conv_branch(inputs,
                                      3,
                                      is_training,
                                      out_filters,
                                      out_filters,
                                      start_idx=0)
                branches[tf.equal(count, 0)] = lambda y=y: y
            with tf.variable_scope("branch_1"):
                y = self._conv_branch(inputs,
                                      3,
                                      is_training,
                                      out_filters,
                                      out_filters,
                                      start_idx=0,
                                      separable=True)
                branches[tf.equal(count, 1)] = lambda y=y: y
            with tf.variable_scope("branch_2"):
                y = self._conv_branch(inputs,
                                      5,
                                      is_training,
                                      out_filters,
                                      out_filters,
                                      start_idx=0)
                branches[tf.equal(count, 2)] = lambda y=y: y
            with tf.variable_scope("branch_3"):
                y = self._conv_branch(inputs,
                                      5,
                                      is_training,
                                      out_filters,
                                      out_filters,
                                      start_idx=0,
                                      separable=True)
                branches[tf.equal(count, 3)] = lambda y=y: y
            if self.num_branches >= 5:
                with tf.variable_scope("branch_4"):
                    y = self._pool_branch(inputs,
                                          is_training,
                                          out_filters,
                                          "avg",
                                          start_idx=0)
                branches[tf.equal(count, 4)] = lambda y=y: y
            if self.num_branches >= 6:
                with tf.variable_scope("branch_5"):
                    y = self._pool_branch(inputs,
                                          is_training,
                                          out_filters,
                                          "max",
                                          start_idx=0)
                branches[tf.equal(count, 5)] = lambda y=y: y

            # The fallback zeros follow the layout declared via set_shape
            # below, so the default branch agrees with the real ones.
            if self.data_format == "NHWC":
                default_shape = [self.batch_size, inp_h, inp_w, out_filters]
            else:
                default_shape = [self.batch_size, out_filters, inp_h, inp_w]
            out = tf.case(
                branches,
                default=lambda: tf.constant(
                    0, tf.float32, shape=default_shape),
                exclusive=True)

            if self.data_format == "NHWC":
                out.set_shape([None, inp_h, inp_w, out_filters])
            elif self.data_format == "NCHW":
                out.set_shape([None, out_filters, inp_h, inp_w])
        else:
            # Two entries per branch: (start index, channel count).
            count = self.sample_arc[start_idx:start_idx +
                                    2 * self.num_branches]
            branches = []
            with tf.variable_scope("branch_0"):
                branches.append(
                    self._conv_branch(inputs,
                                      3,
                                      is_training,
                                      count[1],
                                      out_filters,
                                      start_idx=count[0]))
            with tf.variable_scope("branch_1"):
                branches.append(
                    self._conv_branch(inputs,
                                      3,
                                      is_training,
                                      count[3],
                                      out_filters,
                                      start_idx=count[2],
                                      separable=True))
            with tf.variable_scope("branch_2"):
                branches.append(
                    self._conv_branch(inputs,
                                      5,
                                      is_training,
                                      count[5],
                                      out_filters,
                                      start_idx=count[4]))
            with tf.variable_scope("branch_3"):
                branches.append(
                    self._conv_branch(inputs,
                                      5,
                                      is_training,
                                      count[7],
                                      out_filters,
                                      start_idx=count[6],
                                      separable=True))
            if self.num_branches >= 5:
                with tf.variable_scope("branch_4"):
                    branches.append(
                        self._pool_branch(inputs,
                                          is_training,
                                          count[9],
                                          "avg",
                                          start_idx=count[8]))
            if self.num_branches >= 6:
                with tf.variable_scope("branch_5"):
                    branches.append(
                        self._pool_branch(inputs,
                                          is_training,
                                          count[11],
                                          "max",
                                          start_idx=count[10]))

            with tf.variable_scope("final_conv"):
                w = create_weight(
                    "w", [self.num_branches * out_filters, out_filters])
                # Keep only the weight rows whose index falls inside the
                # channel range selected for each branch.
                w_mask = tf.constant(
                    [False] * (self.num_branches * out_filters), tf.bool)
                new_range = tf.range(0,
                                     self.num_branches * out_filters,
                                     dtype=tf.int32)
                for i in range(self.num_branches):
                    start = out_filters * i + count[2 * i]
                    new_mask = tf.logical_and(
                        start <= new_range,
                        new_range < start + count[2 * i + 1])
                    w_mask = tf.logical_or(w_mask, new_mask)
                w = tf.boolean_mask(w, w_mask)
                w = tf.reshape(w, [1, 1, -1, out_filters])

                inp = prev_layers[-1]
                if self.data_format == "NHWC":
                    branches = tf.concat(branches, axis=3)
                elif self.data_format == "NCHW":
                    branches = tf.concat(branches, axis=1)
                    N = tf.shape(inp)[0]
                    H = inp.get_shape()[2].value
                    W = inp.get_shape()[3].value
                    branches = tf.reshape(branches, [N, -1, H, W])
                out = tf.nn.conv2d(branches,
                                   w, [1, 1, 1, 1],
                                   "SAME",
                                   data_format=self.data_format)
                out = batch_norm(out,
                                 is_training,
                                 data_format=self.data_format)
                out = tf.nn.relu(out)

        if layer_id > 0:
            if self.whole_channels:
                skip_start = start_idx + 1
            else:
                skip_start = start_idx + 2 * self.num_branches
            skip = self.sample_arc[skip_start:skip_start + layer_id]
            with tf.variable_scope("skip"):
                res_layers = []
                for i in range(layer_id):
                    # tf.cond calls both lambdas right away, inside this
                    # iteration, so they see the current `i`.
                    res_layers.append(
                        tf.cond(tf.equal(skip[i], 1), lambda: prev_layers[i],
                                lambda: tf.zeros_like(prev_layers[i])))
                res_layers.append(out)
                out = tf.add_n(res_layers)
                out = batch_norm(out,
                                 is_training,
                                 data_format=self.data_format)

        return out
Ejemplo n.º 14
0
 def _drop_random_channel(coin_channel):
     """Dispatch to `_drop` for the channel index held in `coin_channel`.

     Calls `_drop(0)`, `_drop(1)` or `_drop(2)` depending on whether
     `coin_channel` equals 0, 1 or 2. No `default` branch is supplied to
     `tf.case`.

     Args:
       coin_channel: value compared against 0, 1 and 2 with `tf.equal`;
         presumably a scalar integer tensor -- confirm against the caller.

     Returns:
       The output of the `_drop(...)` branch selected by `tf.case`.
     """
     # Use an ordered list of (predicate, fn) pairs rather than a dict keyed by
     # tensors: tf.case does not guarantee the test order of a dict unless
     # exclusive=True, and tensors are not reliably usable as dict keys.
     pred_fn_pairs = [
         (tf.equal(coin_channel, 0), lambda: _drop(0)),
         (tf.equal(coin_channel, 1), lambda: _drop(1)),
         (tf.equal(coin_channel, 2), lambda: _drop(2)),
     ]
     return tf.case(pred_fn_pairs)
Ejemplo n.º 15
0
  def test_inv_update_thunks(self):
    """Ensures inverse update ops run once per global_step.

    Builds a `tf.case` op that selects the i-th inverse update thunk when
    `global_step == i`, then verifies that at the start of iteration `i`
    exactly `len(inv_matrices) - i` inverse matrices still equal their
    initial values.
    """
    with self._graph.as_default(), self.test_session() as sess:
      fisher_estimator = estimator.FisherEstimatorRoundRobin(
          variables=[self.weights],
          layer_collection=self.layer_collection,
          damping=0.2,
          cov_ema_decay=0.0)

      # Construct op that updates one inverse per global step.
      global_step = tf.train.get_or_create_global_step()
      (cov_variable_thunks, _, inv_variable_thunks,
       inv_update_op_thunks) = fisher_estimator.create_ops_and_vars_thunks()
      # Invoke the variable thunks before the inverse matrices are collected
      # and before the global initializer op is built below.
      for thunk in cov_variable_thunks:
        thunk()
      for thunk in inv_variable_thunks:
        thunk()
      inv_matrices = [
          matrix
          for fisher_factor in self.layer_collection.get_factors()
          for matrix in fisher_factor._matpower_by_exp_and_damping.values()
      ]
      # The i-th thunk is selected only while global_step == i.
      inv_update_op = tf.case([(tf.equal(global_step, i), thunk)
                               for i, thunk in enumerate(inv_update_op_thunks)])
      increment_global_step = global_step.assign_add(1)

      sess.run(tf.global_variables_initializer())
      initial_inv_values = sess.run(inv_matrices)

      # Ensure there's one update per inverse matrix. This is true as long as
      # there's no fan-in/fan-out or parameter re-use.
      self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))

      # The test is a no-op if there is only 1 inverse matrix. Use a unittest
      # assertion rather than a bare `assert` so the precondition is still
      # enforced when Python runs with -O.
      self.assertGreater(len(inv_matrices), 1)

      # Assign each covariance matrix a value other than the identity. This
      # ensures that the inverse matrices are updated to something different as
      # well.
      sess.run([
          fisher_factor._cov.add_to_average(
              2 * tf.eye(int(fisher_factor._cov_shape[0])))
          for fisher_factor in self.layer_collection.get_factors()
      ])

      for i in range(len(inv_matrices)):
        # Compare new and old inverse values
        new_inv_values = sess.run(inv_matrices)
        is_inv_equal = [
            np.allclose(initial_inv_value, new_inv_value)
            for (initial_inv_value,
                 new_inv_value) in zip(initial_inv_values, new_inv_values)
        ]
        num_inv_equal = sum(is_inv_equal)

        # Ensure exactly one inverse matrix changes per step.
        self.assertEqual(num_inv_equal, len(inv_matrices) - i)

        # Run the inverse update selected by the current global step, then
        # advance the step so the next iteration selects the next thunk.
        sess.run(inv_update_op)
        sess.run(increment_global_step)