Example #1
    def create_id3_embedding(self, videos):
        """Embeds the given videos using the Inflated 3D Convolution network.

      Downloads the graph of the I3D from tf.hub and adds it to the graph on the
      first call.

      Args:
        videos: <float32>[batch_size, num_frames, height=224, width=224, depth=3].
          Expected range is [-1, 1].

      Returns:
        embedding: <float32>[batch_size, embedding_size]. embedding_size depends
                   on the model used.

      Raises:
        ValueError: when a provided embedding_layer is not supported.
      """

        batch_size = 16
        module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1"

        # Making sure that we import the graph separately for
        # each different input video tensor.
        module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(
            videos.name).replace(":", "_")

        assert_ops = [
            tf.Assert(
                tf.reduce_max(videos) <= 1.001,
                ["max value in frame is > 1", videos]),
            tf.Assert(
                tf.reduce_min(videos) >= -1.001,
                ["min value in frame is < -1", videos]),
            tf.assert_equal(tf.shape(videos)[0],
                            batch_size,
                            ["invalid frame batch size: ",
                             tf.shape(videos)],
                            summarize=6),
        ]
        with tf.control_dependencies(assert_ops):
            videos = tf.identity(videos)

        module_scope = "%s_apply_default/" % module_name

        # To check whether the module has already been loaded into the graph, we look
        # for a given tensor name. If this tensor name exists, we assume the function
        # has been called before and the graph was imported. Otherwise we import it.
        # Note: in theory, the tensor could exist, but have wrong shapes.
        # This will happen if create_id3_embedding is called with a frames_placeholder
        # of the wrong size/batch size, because even though that will throw a tf.Assert
        # at graph-execution time, it will insert the tensor (with wrong shape) into
        # the graph. This is why we need the following assert.
        video_batch_size = int(videos.shape[0])
        assert video_batch_size in [batch_size, -1, None], "Invalid batch size"
        tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
        if not _is_in_graph(tensor_name):
            # i3d_model = hub.Module(module_spec, name=module_name)
            self.model(videos)

        # gets the kinetics-i3d-400-logits layer
        tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
        tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)

        return tensor
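The `_is_in_graph` helper referenced above is not shown in this excerpt. A minimal sketch of what it could look like, assuming a TF1-style graph-mode setup, is:

import tensorflow.compat.v1 as tf  # assumed TF1-style graph mode, as in the example above


def _is_in_graph(tensor_name):
    """Returns True if a tensor with this name already exists in the default graph."""
    try:
        tf.get_default_graph().get_tensor_by_name(tensor_name)
    except KeyError:
        return False
    return True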
Example #2
  def _parse_train_data(self, data):
    """Parses data for training.

    Args:
      data: the decoded tensor dictionary from TfExampleDecoder.

    Returns:
      image: image tensor that is preprocessed to have normalized values and
        dimension [output_size[0], output_size[1], 3]
      labels: a dictionary of tensors used for training. The following describes
        {key: value} pairs in the dictionary.
        image_info: a 2D `Tensor` that encodes the information of the image and
          the applied preprocessing. It is in the format of
          [[original_height, original_width], [scaled_height, scaled_width],
          [y_scale, x_scale], [y_offset, x_offset]].
        anchor_boxes: ordered dictionary with keys
          [min_level, min_level+1, ..., max_level]. The values are tensors with
          shape [height_l, width_l, 4] representing anchor boxes at each level.
        rpn_score_targets: ordered dictionary with keys
          [min_level, min_level+1, ..., max_level]. The values are tensors with
          shape [height_l, width_l, anchors_per_location]. The height_l and
          width_l represent the dimension of class logits at l-th level.
        rpn_box_targets: ordered dictionary with keys
          [min_level, min_level+1, ..., max_level]. The values are tensors with
          shape [height_l, width_l, anchors_per_location * 4]. The height_l and
          width_l represent the dimension of bounding box regression output at
          l-th level.
        gt_boxes: groundtruth bounding box annotations. The box is represented
           in [y1, x1, y2, x2] format. The coordinates are w.r.t. the scaled
           image that is fed to the network. The tensor is padded with -1 to
           the fixed dimension [self._max_num_instances, 4].
        gt_classes: groundtruth classes annotations. The tensor is padded
          with -1 to the fixed dimension [self._max_num_instances].
        gt_attributes: groundtruth attributes annotations. The tensor is padded
          with -1 to the fixed dimension [self._max_num_instances,
          self._num_attributes].
        gt_masks: groundtruth masks cropped by the bounding box and
          resized to a fixed size determined by mask_crop_size.
    """
    classes = data['groundtruth_classes']
    attributes = tf.cast(data['groundtruth_attributes'], dtype=tf.float32)
    attributes = tf.reshape(attributes, [-1, self._num_attributes])
    boxes = data['groundtruth_boxes']
    if self._include_mask:
      masks = data['groundtruth_instance_masks']

    is_crowds = data['groundtruth_is_crowd']
    # Skips annotations with `is_crowd` = True.
    if self._skip_crowd_during_training and self._is_training:
      num_groundtruths = tf.shape(classes)[0]
      with tf.control_dependencies([num_groundtruths, is_crowds]):
        indices = tf.cond(
            tf.greater(tf.size(is_crowds), 0),
            lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
            lambda: tf.cast(tf.range(num_groundtruths), tf.int64))
      classes = tf.gather(classes, indices)
      attributes = tf.gather(attributes, indices)
      boxes = tf.gather(boxes, indices)
      if self._include_mask:
        masks = tf.gather(masks, indices)

    # Gets original image and its size.
    image = data['image']
    image_shape = tf.shape(image)[0:2]

    # Normalizes image with mean and std pixel values.
    image = input_utils.normalize_image(image)

    # Flips image randomly during training.
    if self._aug_rand_hflip:
      if self._include_mask:
        image, boxes, masks = input_utils.random_horizontal_flip(
            image, boxes, masks)
      else:
        image, boxes = input_utils.random_horizontal_flip(
            image, boxes)

    # Converts boxes from normalized coordinates to pixel coordinates.
    # Now the coordinates of boxes are w.r.t. the original image.
    boxes = box_utils.denormalize_boxes(boxes, image_shape)

    # Resizes and crops image.
    image, image_info = input_utils.resize_and_crop_image(
        image,
        self._output_size,
        padded_size=input_utils.compute_padded_size(
            self._output_size, 2 ** self._max_level),
        aug_scale_min=self._aug_scale_min,
        aug_scale_max=self._aug_scale_max)
    image_height, image_width, _ = image.get_shape().as_list()

    # Resizes and crops boxes.
    # Now the coordinates of boxes are w.r.t the scaled image.
    image_scale = image_info[2, :]
    offset = image_info[3, :]
    boxes = input_utils.resize_and_crop_boxes(
        boxes, image_scale, image_info[1, :], offset)

    # Filters out ground truth boxes that are all zeros.
    indices = box_utils.get_non_empty_box_indices(boxes)
    boxes = tf.gather(boxes, indices)
    classes = tf.gather(classes, indices)
    attributes = tf.gather(attributes, indices)
    if self._include_mask:
      masks = tf.gather(masks, indices)
      # Transfer boxes to the original image space and do normalization.
      cropped_boxes = boxes + tf.tile(tf.expand_dims(offset, axis=0), [1, 2])
      cropped_boxes /= tf.tile(tf.expand_dims(image_scale, axis=0), [1, 2])
      cropped_boxes = box_utils.normalize_boxes(cropped_boxes, image_shape)
      num_masks = tf.shape(masks)[0]
      masks = tf.image.crop_and_resize(
          tf.expand_dims(masks, axis=-1),
          cropped_boxes,
          box_indices=tf.range(num_masks, dtype=tf.int32),
          crop_size=[self._mask_crop_size, self._mask_crop_size],
          method='bilinear')
      masks = tf.squeeze(masks, axis=-1)

    # Assigns anchor targets.
    # Note that after the target assignment, box targets are absolute pixel
    # offsets w.r.t. the scaled image.
    input_anchor = anchor.Anchor(
        self._min_level,
        self._max_level,
        self._num_scales,
        self._aspect_ratios,
        self._anchor_size,
        (image_height, image_width))
    anchor_labeler = anchor.RpnAnchorLabeler(
        input_anchor,
        self._rpn_match_threshold,
        self._rpn_unmatched_threshold,
        self._rpn_batch_size_per_im,
        self._rpn_fg_fraction)
    rpn_score_targets, rpn_box_targets = anchor_labeler.label_anchors(
        boxes, tf.cast(tf.expand_dims(classes, axis=-1), dtype=tf.float32))

    # If bfloat16 is used, casts input image to tf.bfloat16.
    if self._use_bfloat16:
      image = tf.cast(image, dtype=tf.bfloat16)

    # Packs labels for model_fn outputs.
    labels = {
        'anchor_boxes': input_anchor.multilevel_boxes,
        'image_info': image_info,
        'rpn_score_targets': rpn_score_targets,
        'rpn_box_targets': rpn_box_targets,
    }
    labels['gt_boxes'] = input_utils.clip_or_pad_to_fixed_size(
        boxes, self._max_num_instances, -1)
    labels['gt_classes'] = input_utils.clip_or_pad_to_fixed_size(
        classes, self._max_num_instances, -1)
    labels['gt_attributes'] = input_utils.clip_or_pad_to_fixed_size(
        attributes, self._max_num_instances, -1)
    if self._include_mask:
      labels['gt_masks'] = input_utils.clip_or_pad_to_fixed_size(
          masks, self._max_num_instances, -1)

    return image, labels
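A minimal, self-contained sketch of the crowd-filtering step above, using hypothetical constant inputs instead of decoded TFRecord data (assuming a TF1-style session):

import tensorflow.compat.v1 as tf  # assumed TF1-style graph mode

classes = tf.constant([3, 7, 9], dtype=tf.int64)
is_crowds = tf.constant([False, True, False])

num_groundtruths = tf.shape(classes)[0]
with tf.control_dependencies([num_groundtruths, is_crowds]):
    # Keep only non-crowd annotations; fall back to all indices if empty.
    indices = tf.cond(
        tf.greater(tf.size(is_crowds), 0),
        lambda: tf.where(tf.logical_not(is_crowds))[:, 0],
        lambda: tf.cast(tf.range(num_groundtruths), tf.int64))
filtered_classes = tf.gather(classes, indices)

with tf.Session() as sess:
    print(sess.run(filtered_classes))  # [3 9]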
Example #3
def vae_model_fn(features, labels, mode, params):
    # Build mtf_features & seq length dict for getting number of microbatches
    # We need to pack inputs into a dict to pass into serialize_training_step

    H = W = params["dataset"]["image_size"]  # TODO: check equal
    mode_str = mode_to_str(mode)
    batch_size = params[f"{mode_str}_batch_size"]
    n_channels = params.get("input_channels", 3)
    model = DiscreteVAE(num_tokens=params["num_tokens"],
                        dim=params["n_embd"],
                        hidden_dim=params["hidden_dim"],
                        input_channels=n_channels,
                        convblocks=params.get("convblocks", [(3, 64), (3, 128),
                                                             (3, 256)]),
                        recompute_grad=params.get("recompute_grad", False),
                        use_bf16=params.get("use_bf16", False),
                        stack_factor=params.get("stack_factor", 1),
                        dimensions=H)

    if mode == tf.estimator.ModeKeys.PREDICT:
        raise NotImplementedError

    train_gumbel = params.get("train_gumbel_hard", True)
    eval_gumbel = params.get("eval_gumbel_hard", True)

    # We're not predicting, so we better be training or evaluating
    assert (mode == tf.estimator.ModeKeys.TRAIN
            or mode == tf.estimator.ModeKeys.EVAL)

    gumbel = train_gumbel if mode == tf.estimator.ModeKeys.TRAIN else eval_gumbel

    if params.get("temp_anneal_steps", None):
        warmup_frac = tf.cast(tf.train.get_global_step(),
                              tf.float32) / params["temp_anneal_steps"]
        warmup_frac = tf.minimum(warmup_frac, tf.constant(1.0))
        temp = params["temp_start"] - warmup_frac * (params["temp_start"] -
                                                     params["temp"])
    else:
        temp = params.get("temp", 1.0)

    # TODO: add back in microbatching
    if params.get("use_bf16", False):
        with tf.tpu.bfloat16_scope():
            with tf.variable_scope("vae"):
                loss, reconstruction = model.forward(features,
                                                     return_recon_loss=True,
                                                     temperature=temp,
                                                     hard_gumbel=gumbel)
                loss = tf.cast(loss, tf.float32)
                reconstruction = tf.cast(reconstruction, tf.float32)
    else:
        with tf.variable_scope("vae"):
            loss, reconstruction = model.forward(features,
                                                 return_recon_loss=True,
                                                 temperature=temp,
                                                 hard_gumbel=gumbel)

    optimizer = tf.train.AdamOptimizer(learning_rate=params["lr"])
    optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    global_step = tf.train.get_or_create_global_step()
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step)

    def host_call_fn(gs, loss, input, reconstruction):
        gs = gs[0]
        loss = tf.math.reduce_mean(loss)
        denormalize = lambda x: (x + 1) / 2

        with tf2.summary.create_file_writer(params['model_path']).as_default():
            tf2.summary.scalar('loss', loss, step=gs)
            tf2.summary.image('input_image', denormalize(input), step=gs)
            tf2.summary.image('reconstruction_image',
                              denormalize(reconstruction),
                              step=gs)

            return tf.summary.all_v2_summary_ops()

    def metric_fn(gs, loss, input, reconstruction):
        gs = gs[0]
        loss = tf.math.reduce_mean(loss)
        denormalize = lambda x: (x + 1) / 2

        with tf2.summary.create_file_writer(params['model_path']).as_default():
            loss_op = tf.metrics.mean(loss)

            with tf2.summary.record_if(loss_op[0] < tf.constant(1e-9)):
                tf2.summary.image('eval/input_image',
                                  denormalize(input),
                                  step=gs)
                tf2.summary.image('eval/reconstruction_image',
                                  denormalize(reconstruction),
                                  step=gs)

            with tf.control_dependencies(tf.summary.all_v2_summary_ops()):
                dummy_op = tf.no_op()

            return {"_loss": loss_op, "zzz_dummy": (tf.constant(0), dummy_op)}

    # To log the loss, current learning rate, and epoch for Tensorboard, the
    # summary op needs to be run on the host CPU via host_call. host_call
    # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
    # dimension. These Tensors are implicitly concatenated to
    # [params['batch_size']].
    gs_t = tf.reshape(global_step, [1])
    loss_t = tf.reshape(loss, [1])

    host_call = (host_call_fn, [gs_t, loss_t, features, reconstruction])
    metric = (metric_fn, [gs_t, loss_t, features, reconstruction])

    return tpu_estimator.TPUEstimatorSpec(
        mode,
        loss=loss,
        host_call=host_call if mode == tf.estimator.ModeKeys.TRAIN else None,
        train_op=train_op,
        eval_metrics=metric)
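The control dependency on UPDATE_OPS before `optimizer.minimize` above is the standard TF1 recipe for making sure collection-registered updates (e.g. batch-norm moving averages) run with every training step. A minimal sketch of that pattern in isolation, assuming a TF1-style setup and a hypothetical toy loss:

import tensorflow.compat.v1 as tf  # assumed TF1-style graph mode

x = tf.random.normal([8, 4])
hidden = tf.layers.dense(x, 2)
# Batch norm registers its moving-average updates in the UPDATE_OPS collection.
out = tf.layers.batch_normalization(hidden, training=True)
loss = tf.reduce_mean(tf.square(out))

global_step = tf.train.get_or_create_global_step()
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
# Without this dependency the moving averages would never be refreshed.
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)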
Example #4
    def custom_getter(getter, name, *args, **kwargs):
        """The custom getter that will be returned."""
        if not kwargs.get("trainable", True):
            return getter(name, *args, **kwargs)
        if kwargs["dtype"] not in _OK_DTYPES_FOR_BBB:
            raise ValueError("Disallowed data type {}.".format(
                kwargs["dtype"]))

        var_scope = tf.get_variable_scope()
        if var_scope.reuse and not fresh_noise_per_connection:
            # Re-use the sampling noise by returning the very same posterior sample
            # if configured to do so.
            the_match = [
                x for x in get_variable_metadata()
                if x.raw_variable_name == name
            ]
            if not the_match:
                raise ValueError(
                    "Internal error. No metadata for variable {}".format(name))
            if len(the_match) > 1:
                raise ValueError(
                    "Multiple matches for variable {}. Matches: {}".format(
                        name, [x.raw_variable_name for x in the_match]))

            return the_match[0].posterior_estimate

        raw_variable_shape = kwargs["shape"]

        def construct_subgraph():
            """Constructs subgraph used to reparameterize the variable in question."""
            posterior = posterior_builder(getter, name=name, *args, **kwargs)
            prior = prior_builder(getter, name=name, *args, **kwargs)

            # If the user does not return an extra dictionary of prior variables,
            # then fill in an empty dictionary.
            if isinstance(posterior, collections.Sequence):
                posterior_dist, posterior_vars = posterior
            else:
                posterior_dist, posterior_vars = posterior, {}

            if isinstance(prior, collections.Sequence):
                prior_dist, prior_vars = prior
            else:
                prior_dist, prior_vars = prior, {}

            if posterior_dist.reparameterization_type != _OK_PZATION_TYPE:
                raise ValueError(
                    "Distribution {} incompatible with Bayes by Backprop.".
                    format(posterior_dist.__class__.__name__))

            posterior_estimator = _produce_posterior_estimate(
                posterior_dist, sampling_mode_tensor, name)
            kl_cost = kl_builder(posterior_dist, prior_dist,
                                 posterior_estimator)
            variable_metadata = _VariableMetadata(
                raw_variable_name=name,
                raw_variable_shape=raw_variable_shape,
                scope_name=var_scope.name,
                posterior=posterior_dist,
                posterior_estimate=posterior_estimator,
                prior=prior_dist,
                kl_cost=kl_cost,
                prior_vars=prior_vars,
                posterior_vars=posterior_vars)
            return posterior_estimator, variable_metadata

        # Entering the `tf.control_dependencies(None)` context is crucial to
        # provide compatibility with `tf.while_loop` and thus RNNs. The main thing
        # it does is to make the `kl_cost` fetchable by causing these ops to be
        # created outside the context of any tf.while_loop. Note also that it causes
        # an RNN core's weights to be sampled just once when unrolled over a
        # sequence, rather than at every timestep.
        control_deps = [] if keep_control_dependencies else None
        with tf.control_dependencies(control_deps):
            posterior_estimator, var_metadata = construct_subgraph()

        # Only add these ops to a collection once per unique variable.
        # This is to ensure that KL costs are not tallied up more than once.
        var_with_name = _all_var_metadata_registry[tf.get_default_graph()].get(
            name)
        if var_with_name is None:
            _all_var_metadata_registry[
                tf.get_default_graph()][name] = var_metadata

        return posterior_estimator
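The `tf.control_dependencies(None)` trick described in the comment above can be seen in isolation: passing None clears the enclosing control-dependency (and while_loop) context for ops created inside. A minimal sketch, assuming TF1 graph mode:

import tensorflow.compat.v1 as tf  # assumed TF1-style graph mode

gate = tf.no_op(name="gate")
x = tf.constant(1.0)  # created outside any control context
with tf.control_dependencies([gate]):
    inside = tf.identity(x, name="inside")        # gets a control edge on `gate`
    with tf.control_dependencies(None):           # clears the enclosing context
        outside = tf.identity(x, name="outside")  # no control edge on `gate`

print([op.name for op in inside.op.control_inputs])   # ['gate']
print([op.name for op in outside.op.control_inputs])  # []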
Example #5
    def __call__(self, x):
        # Constrained sequence
        cs_scores = np.array([[10.0, 12.0, 6.0, 4.0], [13.0, 12.0, 11.0,
                                                       10.0]])
        cs_input = np.array([cs_scores, cs_scores, cs_scores],
                            dtype=np.float32)
        cs_transition_weights = np.array(
            [[-1.0, 1.0, -2.0, 2.0, 0.0], [3.0, -3.0, 4.0, -4.0, 0.0],
             [5.0, 1.0, 10.0, 1.0, 1.0], [-7.0, 7.0, -8.0, 8.0, 0.0],
             [0.0, 1.0, 2.0, 3.0, 0.0]],
            dtype=np.float32)
        cs_allowed_transitions = np.array([[True, True, True, True, True],
                                           [True, True, True, True, True],
                                           [True, False, True, False, False],
                                           [True, True, True, True, True],
                                           [True, False, True, True, True]])
        constrained_sequence = text.viterbi_constrained_sequence(
            cs_input, [2, 2, 2],
            allowed_transitions=cs_allowed_transitions,
            transition_weights=cs_transition_weights,
            use_log_space=True,
            use_start_and_end_states=True)
        # Max Spanning Tree
        mst_num_nodes = tf.constant([4, 3], tf.int32)
        mst_scores = tf.constant(
            [[[0, 0, 0, 0], [1, 0, 0, 0], [1, 2, 0, 0], [1, 2, 3, 4]],
             [[4, 3, 2, 9], [0, 0, 2, 9], [0, 0, 0, 9], [9, 9, 9, 9]]],
            tf.int32)  # pyformat: disable
        (max_spanning_tree,
         _) = text.max_spanning_tree(mst_num_nodes, mst_scores)
        # Normalize
        normalized = text.case_fold_utf8(['A String'])
        normalized = text.normalize_utf8(normalized)
        # Regex split
        regex_split = text.regex_split(input=['Yo dawg!'],
                                       delim_regex_pattern=r'\s')
        # Rouge-L
        rl_hypotheses = tf.ragged.constant(
            [['captain', 'of', 'the', 'delta', 'flight'],
             ['the', '1990', 'transcript']])
        rl_references = tf.ragged.constant(
            [['delta', 'air', 'lines', 'flight'],
             ['this', 'concludes', 'the', 'transcript']])
        (rouge_l, _, _) = text.metrics.rouge_l(rl_hypotheses, rl_references)
        # Sentence breaking
        sb_token_word = [['Welcome', 'to', 'the', 'U.S.', '!', 'Harry'],
                         ['Wu', 'Tang', 'Clan', ';', 'ain\'t', 'nothing']]
        sb_token_properties = [[0, 0, 0, 256, 0, 0], [0, 0, 0, 0, 0, 0]]
        sb_token_starts = []
        sb_token_ends = []
        for sentence in sb_token_word:
            sentence_string = ''
            sentence_start = []
            sentence_end = []
            for word in sentence:
                sentence_start.append(len(sentence_string))
                sentence_string += word
                sentence_end.append(len(sentence_string))
                sentence_string += ' '
            sb_token_starts.append(sentence_start)
            sb_token_ends.append(sentence_end)
        sb_token_starts = tf.constant(sb_token_starts, dtype=tf.int64)
        sb_token_ends = tf.constant(sb_token_ends, dtype=tf.int64)
        sb_token_properties = tf.ragged.constant(sb_token_properties,
                                                 dtype=tf.int64)
        (sentence_breaking, _, _,
         _) = text.sentence_fragments(sb_token_word, sb_token_starts,
                                      sb_token_ends, sb_token_properties)
        # Sentencepiece tokenizer
        sp_model_file = (
            'third_party/tensorflow_text/python/ops/test_data/test_oss_model.model'
        )
        sp_model = open(sp_model_file, 'rb').read()
        sp_tokenizer = text.SentencepieceTokenizer(sp_model)
        sentencepiece = sp_tokenizer.tokenize(['A sentence of things.'])
        sentencepiece = sp_tokenizer.detokenize(sentencepiece)
        (sentencepiece, _,
         _) = sp_tokenizer.tokenize_with_offsets(sentencepiece)
        sentencepiece_size = sp_tokenizer.vocab_size()
        sentencepiece_id = sp_tokenizer.id_to_string(1)
        # Split merge tokenizer - not in this version
        sm_tokenizer = text.SplitMergeTokenizer()
        split_merge = sm_tokenizer.tokenize(b'IloveFlume!',
                                            [0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0])
        # Unicode script tokenizer
        us_tokenizer = text.UnicodeScriptTokenizer()
        unicode_script = us_tokenizer.tokenize(['a string'])
        # Whitespace tokenizer
        ws_tokenizer = text.WhitespaceTokenizer()
        whitespace = ws_tokenizer.tokenize(['a string'])
        # Wordpiece tokenizer
        wp_initializer = tf.lookup.KeyValueTensorInitializer(
            ['i'], [1], key_dtype=tf.string, value_dtype=tf.int64)
        self.wp_vocab_table = tf.lookup.StaticHashTable(wp_initializer,
                                                        default_value=-1)
        wp_tokenizer = text.WordpieceTokenizer(self.wp_vocab_table)
        wordpiece = wp_tokenizer.tokenize(['i am'])
        # Wordshape
        wordshapes = text.wordshape([u'a-b', u'a\u2010b'.encode('utf-8')],
                                    text.WordShape.HAS_PUNCTUATION_DASH)

        with tf.control_dependencies([
                constrained_sequence, max_spanning_tree, normalized,
                regex_split, rouge_l, sentence_breaking, sentencepiece,
                sentencepiece_id, sentencepiece_size, split_merge,
                unicode_script, whitespace, wordpiece, wordshapes
        ]):
            y = tf.add(x, [1])
        return {'y': y}
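The closing block above uses tf.control_dependencies purely to force a collection of otherwise-unused ops to execute before the output is produced. A minimal sketch of that pattern with hypothetical side ops, assuming TF1 graph mode:

import tensorflow.compat.v1 as tf  # assumed TF1-style graph mode

x = tf.constant([1])
side_print = tf.print("side op ran")  # would otherwise never execute
side_check = tf.assert_positive(x)    # likewise only runs if something depends on it

with tf.control_dependencies([side_print, side_check]):
    y = tf.add(x, [1])  # fetching y now also runs the side ops

with tf.Session() as sess:
    print(sess.run(y))  # [2]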
Example #6
def GetEmbeddingLookupList(signals_list,
                           embedding_vars,
                           sparse_ids,
                           sparse_weights=None,
                           combiners='sqrtn',
                           partition_strategies='mod'):
    """Get a list of embedding lookup tensors.

  Args:
    signals_list: A list of strings, representing names of features.
    embedding_vars: Dict mapping feature names to full embedding variables.
    sparse_ids: Dict mapping feature names to SparseTensors of their ids.
    sparse_weights: Either None, or a dict mapping feature names to
      SparseTensors of their weights (which can also be None).
    combiners: Either a common combiner type for all features ('mean', 'sqrtn' or
      'sum') or a dict mapping each feature name to a combiner type.
    partition_strategies: Either a common partition_strategy for all features
      ('mod' or 'div') or a dict mapping feature_names to partition_strategies.

  Returns:
    embedding_lookup_list: A list of embedding lookup tensors used for bag of
      words attribution, aligned with signals_list.
  """
    assert isinstance(embedding_vars, dict) and isinstance(sparse_ids, dict)
    assert sparse_weights is None or isinstance(sparse_weights, dict)
    assert combiners in ('mean', 'sqrtn', 'sum') or isinstance(combiners, dict)
    assert (partition_strategies in ('mod', 'div')
            or isinstance(partition_strategies, dict))
    embedding_lookup_list = []
    for signal in signals_list:
        combiner = combiners[signal] if isinstance(combiners,
                                                   dict) else combiners
        partition_strategy = (partition_strategies[signal] if isinstance(
            partition_strategies, dict) else partition_strategies)

        # Batch dimension should be 1 for attribution.
        with tf.control_dependencies(
            [tf.assert_equal(tf.shape(sparse_ids[signal])[0], 1)]):
            embedding_lookup = tf.nn.embedding_lookup(
                params=embedding_vars[signal],
                ids=tf.sparse_tensor_to_dense(sparse_ids[signal]),
                partition_strategy=partition_strategy)
        if sparse_weights is None or sparse_weights[signal] is None:
            num_vals = tf.size(sparse_ids[signal].values)
            if combiner == 'mean':
                embedding_weights = tf.fill([1, num_vals],
                                            1.0 / tf.to_float(num_vals))
            elif combiner == 'sqrtn':
                embedding_weights = tf.fill([1, num_vals], 1.0 /
                                            tf.sqrt(tf.to_float(num_vals)))
            else:
                embedding_weights = tf.ones([1, num_vals], dtype=tf.float32)
        else:
            # Batch dimension should be 1 for attribution.
            with tf.control_dependencies(
                [tf.assert_equal(tf.shape(sparse_weights[signal])[0], 1)]):
                dense_weights = tf.sparse_tensor_to_dense(
                    sparse_weights[signal])
            if combiner == 'mean':
                embedding_weights = dense_weights / tf.reduce_sum(
                    dense_weights)
            elif combiner == 'sqrtn':
                embedding_weights = (
                    dense_weights /
                    tf.sqrt(tf.reduce_sum(tf.pow(dense_weights, 2))))
            else:
                embedding_weights = dense_weights
        embedding_lookup *= tf.expand_dims(embedding_weights, -1)
        embedding_lookup_list.append(embedding_lookup)
    return embedding_lookup_list
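A minimal sketch of the batch-size-1 guard used above, with a hypothetical dense id tensor in place of the SparseTensor inputs (assuming TF1 graph mode):

import tensorflow.compat.v1 as tf  # assumed TF1-style graph mode

embedding_var = tf.constant([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
ids = tf.constant([[1, 2]])  # batch dimension must be 1 for attribution

with tf.control_dependencies([tf.assert_equal(tf.shape(ids)[0], 1)]):
    embedding_lookup = tf.nn.embedding_lookup(params=embedding_var, ids=ids)

with tf.Session() as sess:
    print(sess.run(embedding_lookup))  # [[[1. 1.] [2. 2.]]]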
Example #7
def model_fn_ALEXNET(features,
                     activation='relu',
                     kernel_initializer=tf.keras.initializers.TruncatedNormal(
                         mean=0, stddev=0.1),
                     bias_initializer='zeros'):

    # input: [None, 227, 227, 3]
    # conv1: f 96, k (11,11), s (4,4), VALID, relu --> [None, 55, 55, 96]
    with tf.control_dependencies(
            [tf.debugging.assert_equal(features.get_shape()[1:],
                                       [227, 227, 3])]):
        conv1 = Conv2D(filters=96,
                       kernel_size=(11, 11),
                       strides=(4, 4),
                       padding='valid',
                       activation=activation,
                       use_bias=True,
                       kernel_initializer=kernel_initializer,
                       bias_initializer=bias_initializer)(features)

    # pool1: k (3,3), s (2,2), VALID               --> [None, 27, 27, 96]
    with tf.control_dependencies(
            [tf.debugging.assert_equal(conv1.get_shape()[1:], [55, 55, 96])]):
        pool1 = MaxPool2D(pool_size=(3, 3), strides=(2, 2),
                          padding='valid')(conv1)

    # conv2: f 256, k (5,5), s (1,1), SAME, relu   --> [None, 27, 27, 256]
    with tf.control_dependencies(
            [tf.debugging.assert_equal(pool1.get_shape()[1:], [27, 27, 96])]):
        conv2 = Conv2D(filters=256,
                       kernel_size=(5, 5),
                       strides=(1, 1),
                       padding='same',
                       activation=activation,
                       use_bias=True,
                       kernel_initializer=kernel_initializer,
                       bias_initializer=bias_initializer)(pool1)

    # pool2: k (3,3), s (2,2), VALID               --> [None, 13, 13, 256]
    with tf.control_dependencies(
            [tf.debugging.assert_equal(conv2.get_shape()[1:], [27, 27, 256])]):
        pool2 = MaxPool2D(pool_size=(3, 3), strides=(2, 2),
                          padding='valid')(conv2)

    # conv3: f 384, k (3,3), s(1,1), SAME, relu    --> [None, 13, 13, 384]
    with tf.control_dependencies(
            [tf.debugging.assert_equal(pool2.get_shape()[1:],
                                       [13, 13, 256])]):
        conv3 = Conv2D(filters=384,
                       kernel_size=(3, 3),
                       strides=(1, 1),
                       padding='same',
                       activation=activation,
                       use_bias=True,
                       kernel_initializer=kernel_initializer,
                       bias_initializer=bias_initializer)(pool2)

    # conv4: f 384, k (3,3), s(1,1), SAME, relu    --> [None, 13, 13, 384]
    with tf.control_dependencies(
            [tf.debugging.assert_equal(conv3.get_shape()[1:],
                                       [13, 13, 384])]):
        conv4 = Conv2D(filters=384,
                       kernel_size=(3, 3),
                       strides=(1, 1),
                       padding='same',
                       activation=activation,
                       use_bias=True,
                       kernel_initializer=kernel_initializer,
                       bias_initializer=bias_initializer)(conv3)

    # conv5: f 256, k (3,3), s(1,1), SAME, relu    --> [None, 13, 13, 256]
    with tf.control_dependencies(
            [tf.debugging.assert_equal(conv4.get_shape()[1:],
                                       [13, 13, 384])]):
        conv5 = Conv2D(filters=256,
                       kernel_size=(3, 3),
                       strides=(1, 1),
                       padding='same',
                       activation=activation,
                       use_bias=True,
                       kernel_initializer=kernel_initializer,
                       bias_initializer=bias_initializer)(conv4)

    # pool5: k (3,3), s (2,2)                      --> [None,  6,  6, 256]
    with tf.control_dependencies(
            [tf.debugging.assert_equal(conv5.get_shape()[1:], [13, 13, 256])]):
        pool5 = MaxPool2D(pool_size=(3, 3), strides=(2, 2),
                          padding='valid')(conv5)

    # flatten --> [None, 9216]
    flatten = Flatten()(pool5)

    # fc6: f 4096, relu --> [None, 4096]
    with tf.control_dependencies(
            [tf.debugging.assert_equal(flatten.get_shape()[1:], [9216])]):
        fc6 = Dense(units=4096,
                    activation=activation,
                    use_bias=True,
                    kernel_initializer=kernel_initializer,
                    bias_initializer=bias_initializer)(flatten)

    # drop7: p 0.5      --> [None, 4096]
    drop7 = Dropout(rate=0.5)(fc6)

    # fc7: f 4096, relu --> [None, 4096]
    with tf.control_dependencies(
            [tf.debugging.assert_equal(fc6.get_shape()[1:], [4096])]):
        fc7 = Dense(units=4096,
                    activation=activation,
                    use_bias=True,
                    kernel_initializer=kernel_initializer,
                    bias_initializer=bias_initializer)(drop7)

    # drop8: p 0.5      --> [None, 4096]
    drop8 = Dropout(rate=0.5)(fc7)

    return drop8
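The asserts above compare statically known shapes; when the batch or spatial dimensions are only known at run time, the same guard can be written against tf.shape instead. A minimal sketch with a hypothetical placeholder input, assuming TF1 graph mode:

import tensorflow.compat.v1 as tf  # assumed TF1-style graph mode

images = tf.placeholder(tf.float32, [None, 227, 227, 3])
# Check the dynamic (run-time) shape rather than the static one.
with tf.control_dependencies(
        [tf.debugging.assert_equal(tf.shape(images)[1:], [227, 227, 3])]):
    images = tf.identity(images)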
Example #8
  def _compute_inner_update_scinol(self, var, grad, state):
    update_ops = []

    betting_domain = tf.cast(
        state.get_hyper(BETTING_DOMAIN), var.dtype.base_dtype)

    reward = state.get_slot(var, INNER_REWARD)
    betting_fraction = state.get_slot(var, OUTER_BETTING_FRACTION)
    sum_grad_squared = state.get_slot(var, INNER_SUM_GRAD_SQUARED)
    sum_grad = state.get_slot(var, INNER_SUM_GRAD)
    inner_maximum_gradient = state.get_slot(var, INNER_MAXIMUM_GRADIENT)

    # clip inner gradient to respect previous inner_maximum_gradient value
    # This introduces at most an additive constant overhead in the regret
    # since the inner betting fraction lies in a bounded domain.
    clipped_grad = tf.clip_by_value(grad, -inner_maximum_gradient,
                                    inner_maximum_gradient)

    with tf.control_dependencies([clipped_grad]):
      inner_maximum_gradient_updated = self._assign(
          inner_maximum_gradient,
          tf.maximum(inner_maximum_gradient, tf.abs(grad)))
      update_ops.append(inner_maximum_gradient_updated)

    clipped_old_betting_fraction = tf.clip_by_value(betting_fraction,
                                                    -betting_domain,
                                                    betting_domain)

    # Process grad to respect truncation to [-betting_domain, betting_domain]
    truncated_grad = tf.where(
        tf.greater_equal(
            clipped_grad * (betting_fraction - clipped_old_betting_fraction),
            0.0), clipped_grad, tf.zeros(tf.shape(clipped_grad)))

    reward_delta = -betting_fraction * truncated_grad
    reward_updated = self._assign_add(reward, reward_delta)
    update_ops.append(reward_updated)

    sum_grad_squared_updated = self._assign_add(sum_grad_squared,
                                                tf.square(truncated_grad))
    update_ops.append(sum_grad_squared_updated)

    sum_grad_updated = self._assign_add(sum_grad, truncated_grad)
    update_ops.append(sum_grad_updated)

    # The second term in this minimum, self.eta / inner_maximum_gradient_updated,
    # is a hack to force the betting fraction to not be too big at first.
    scaling = tf.minimum(
        tf.rsqrt(sum_grad_squared_updated +
                 tf.square(inner_maximum_gradient_updated)),
        self.eta / inner_maximum_gradient_updated)
    theta = -sum_grad_updated * scaling

    # The rescale_inner flag is a hack that rescales epsilon_v by the
    # maximum inner gradient.
    if self.rescale_inner:
      epsilon_scaling = inner_maximum_gradient_updated
    else:
      epsilon_scaling = 1.0

    inner_betting_fraction = tf.sign(theta) * tf.minimum(tf.abs(theta),
                                                         1.0) * scaling / 2.0
    new_betting_fraction = inner_betting_fraction * (
        reward_updated + epsilon_scaling * self.epsilon_v)

    betting_fraction_updated = self._assign(betting_fraction,
                                            new_betting_fraction)
    update_ops.append(betting_fraction_updated)

    clipped_betting_fraction = tf.clip_by_value(betting_fraction_updated,
                                                -betting_domain, betting_domain)

    if self.output_summaries:
      mean_unclipped_betting_fraction_summary = tf.reduce_mean(
          tf.abs(betting_fraction_updated))
      max_unclipped_betting_fraction_summary = tf.reduce_max(
          tf.abs(betting_fraction_updated))

      mean_clipped_betting_fraction_summary = tf.reduce_mean(
          tf.abs(clipped_betting_fraction))
      max_clipped_betting_fraction_summary = tf.reduce_max(
          tf.abs(clipped_betting_fraction))

      max_abs_gradient = tf.reduce_max(tf.abs(grad))
      max_truncated_grad = tf.reduce_max(tf.abs(truncated_grad))

      tf.summary.scalar(self._name + "/mean_unclipped_bet/" + var.name,
                        mean_unclipped_betting_fraction_summary)
      tf.summary.scalar(self._name + "/max_unclipped_bet/" + var.name,
                        max_unclipped_betting_fraction_summary)
      tf.summary.scalar(self._name + "/mean_clipped_bet/" + var.name,
                        mean_clipped_betting_fraction_summary)
      tf.summary.scalar(self._name + "/max_clipped_bet/" + var.name,
                        max_clipped_betting_fraction_summary)

      tf.summary.scalar(self._name + "/max_abs_inner_grad/" + var.name,
                        max_abs_gradient)
      tf.summary.scalar(
          self._name + "/max_abs_truncated_inner_grad/" + var.name,
          max_truncated_grad)
    return clipped_betting_fraction, tf.group(*update_ops)
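The dependency on clipped_grad above is a read-before-write guard: the clip must read the old maximum before the maximum is overwritten. A minimal sketch of the same ordering with a hypothetical variable, assuming TF1 graph mode:

import tensorflow.compat.v1 as tf  # assumed TF1-style graph mode

max_grad = tf.get_variable("max_grad", initializer=1.0)
grad = tf.constant(3.0)

# Read using the old maximum ...
clipped_grad = tf.clip_by_value(grad, -max_grad, max_grad)
# ... and only then allow the maximum to be updated.
with tf.control_dependencies([clipped_grad]):
    new_max = tf.assign(max_grad, tf.maximum(max_grad, tf.abs(grad)))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([clipped_grad, new_max]))  # [1.0, 3.0]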
Example #9
def model_fn(features, labels, mode, params):
    """Constructs a spectrogram_lstm model with summaries.

  Args:
    features: Dictionary {name: Tensor} of model inputs.
    labels: Any training-only inputs.
    mode: Build mode, one of tf.estimator.ModeKeys.
    params: Dictionary of Model hyperparameters.

  Returns:
    EstimatorSpec describing the model.
  """
    del labels

    hparams = params['hparams']

    mixture_waveforms = features['receiver_audio']

    batch_size = signal_util.static_or_dynamic_dim_size(mixture_waveforms, 0)

    # Create mixtures of mixtures (MoMs) on-the-fly by splitting batch in half.
    if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL:
        mixture_waveforms_1mix = mixture_waveforms
        # Build MoMs by splitting batch in half.
        with tf.control_dependencies(
            [tf.compat.v1.assert_equal(tf.mod(batch_size, 2), 0)]):
            mixture_waveforms = tf.reshape(mixture_waveforms,
                                           (batch_size // 2, 2, -1))

        # Create the MoMs by summing up single mixtures.
        mix_of_mix_waveforms = tf.reduce_sum(mixture_waveforms,
                                             axis=1,
                                             keepdims=True)

    else:
        # Inference mode, mixture_waveforms is just an input placeholder.
        mix_of_mix_waveforms = mixture_waveforms

    # In eval mode, separate both MoMs and single mixtures.
    if mode == tf.estimator.ModeKeys.EVAL:
        input_waveforms = tf.concat(
            [mix_of_mix_waveforms, mixture_waveforms_1mix], axis=0)
    else:
        input_waveforms = mix_of_mix_waveforms

    # Separate the input waveforms.
    separated_waveforms = separate_waveforms(input_waveforms, hparams)

    # In eval mode, split into separated from MoMs and from single mixtures.
    if mode == tf.estimator.ModeKeys.EVAL:
        # Separated sources from single mixtures.
        separated_waveforms_1mix = separated_waveforms[batch_size // 2:, :, :]
        # Separated sources from MoMs.
        separated_waveforms = separated_waveforms[:batch_size // 2, :, :]

    predictions = {'separated_waveforms': separated_waveforms}
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Get reference sources.
    source_waveforms = features['source_images'][:, :, 0]
    max_sources = signal_util.static_or_dynamic_dim_size(source_waveforms, 1)
    source_waveforms_1mix = tf.concat(
        [source_waveforms, tf.zeros_like(source_waveforms)], axis=1)
    if batch_size > 1:
        source_waveforms = tf.reshape(source_waveforms,
                                      (batch_size // 2, 2 * max_sources, -1))
    else:
        source_waveforms = tf.concat(
            [source_waveforms,
             tf.zeros_like(source_waveforms)], axis=1)

    # MixIT loss.
    loss, _ = mixit.apply(log_mse_loss, mixture_waveforms, separated_waveforms)
    loss = tf.identity(tf.reduce_mean(loss), name='loss_mixit')
    tf.losses.add_loss(loss)

    # Build the optimizer.
    loss = tf.losses.get_total_loss()
    learning_rate = tf.train.exponential_decay(
        hparams.lr,
        tf.train.get_or_create_global_step(),
        decay_steps=hparams.lr_decay_steps,
        decay_rate=hparams.lr_decay_rate)
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
    if params.get('use_tpu', False):
        optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    # Build the train_op.
    train_op = optimizer.minimize(
        loss, global_step=tf.compat.v1.train.get_or_create_global_step())

    # Permute separated to match references for summaries.
    unique_signal_types = list(set(hparams.signal_types))
    loss_fns = {
        signal_type: log_mse_loss
        for signal_type in unique_signal_types
    }
    _, separated_waveforms = groupwise.apply(loss_fns, hparams.signal_types,
                                             source_waveforms,
                                             separated_waveforms,
                                             unique_signal_types)
    if mode == tf.estimator.ModeKeys.EVAL:
        # Also align sources separated from single mixtures.
        _, separated_waveforms_1mix = groupwise.apply(
            loss_fns, hparams.signal_types, source_waveforms_1mix,
            separated_waveforms_1mix, unique_signal_types)

    # In eval mode, evaluate separated from single mixtures, instead of from MoMs.
    if mode == tf.estimator.ModeKeys.EVAL:
        separated_waveforms = separated_waveforms_1mix
        source_waveforms = source_waveforms_1mix
        mix_of_mix_waveforms = mixture_waveforms_1mix

    # Compute spectrograms to be used in summaries.
    transformer = signal_transformer.SignalTransformer(
        sample_rate=hparams.sr,
        window_time_seconds=hparams.ws,
        hop_time_seconds=hparams.hs)
    source_spectrograms = transformer.forward(source_waveforms)
    mixture_spectrograms = transformer.forward(mix_of_mix_waveforms)
    separated_spectrograms = transformer.forward(separated_waveforms)

    summary_dict = {}

    # Audio summaries.
    summary_dict['audio'] = summaries.compute_audio_summaries(
        signal_names=hparams.signal_names,
        separated_waveforms=separated_waveforms,
        source_waveforms=source_waveforms,
        mixture_waveforms=mix_of_mix_waveforms)

    # Spectrogram image summaries.
    summary_dict['images'] = summaries.compute_spectrogram_summaries(
        signal_names=hparams.signal_names,
        separated_spectrograms=separated_spectrograms,
        source_spectrograms=source_spectrograms,
        mixture_spectrograms=mixture_spectrograms)

    scalars = {}
    weights = {}
    # Only compute scalar summaries for nonzero reference sources.
    source_is_nonzero = _weights_for_nonzero_refs(source_waveforms)

    # Metrics for single-source examples.
    weights_1src = tf.logical_and(
        source_is_nonzero, _weights_for_num_sources(source_waveforms, 1))
    scalars_1src, weights_1src = summaries.scalar_snr_metrics_weighted(
        hparams.signal_names, separated_waveforms, source_waveforms,
        mix_of_mix_waveforms, weights_1src)
    scalars.update({
        name + '_1src_ref_nonzero': value
        for name, value in scalars_1src.items()
    })
    weights.update({
        name + '_1src_ref_nonzero': value
        for name, value in weights_1src.items()
    })

    # Metrics for multi-source examples.
    max_sources = len(hparams.signal_names)
    if max_sources > 1:
        weights_multisource = _weights_for_num_sources(source_waveforms, 2)
        for num_sources in range(3, max_sources + 1):
            weights_multisource = tf.logical_or(
                weights_multisource,
                _weights_for_num_sources(source_waveforms, num_sources))
        weights_multisource = tf.logical_and(source_is_nonzero,
                                             weights_multisource)
        scalars_msrc, weights_msrc = summaries.scalar_snr_metrics_weighted(
            hparams.signal_names, separated_waveforms, source_waveforms,
            mix_of_mix_waveforms, weights_multisource)
        scalars.update({
            name + '_min2src_ref_nonzero': value
            for name, value in scalars_msrc.items()
        })
        weights.update({
            name + '_min2src_ref_nonzero': value
            for name, value in weights_msrc.items()
        })

    summary_dict['scalars'] = scalars
    summary_util.create_summaries(sample_rate=hparams.sr, **summary_dict)
    metrics = {
        name: tf.metrics.mean(s, weights=weights.get(name, None))
        for name, s in scalars.items()
    }

    logging_hook = tf.train.LoggingTensorHook({'loss': loss}, every_n_secs=10)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions=predictions,
                                      loss=loss,
                                      eval_metric_ops=metrics,
                                      train_op=train_op,
                                      training_hooks=[logging_hook])
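A minimal sketch of the even-batch guard used above when building mixtures of mixtures, with a hypothetical waveform tensor (assuming TF1 graph mode):

import tensorflow.compat.v1 as tf  # assumed TF1-style graph mode

mixture_waveforms = tf.random.normal([4, 16000])  # [batch, samples]
batch_size = tf.shape(mixture_waveforms)[0]

# Pairing up consecutive examples only works for an even batch size.
with tf.control_dependencies([tf.assert_equal(tf.math.mod(batch_size, 2), 0)]):
    pairs = tf.reshape(mixture_waveforms, (batch_size // 2, 2, -1))

# Sum each pair into a mixture of mixtures: [batch/2, 1, samples].
mix_of_mix_waveforms = tf.reduce_sum(pairs, axis=1, keepdims=True)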
Example #10
def embedding_postprocessor(input_tensor,
                            use_token_type=False,
                            token_type_ids=None,
                            token_type_vocab_size=16,
                            token_type_embedding_name="token_type_embeddings",
                            use_position_embeddings=True,
                            position_embedding_name="position_embeddings",
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1):
    """Performs various post-processing on a word embedding tensor.

    Args:
      input_tensor: float Tensor of shape [batch_size, seq_length,
        embedding_size].
      use_token_type: bool. Whether to add embeddings for `token_type_ids`.
      token_type_ids: (optional) int32 Tensor of shape [batch_size, seq_length].
        Must be specified if `use_token_type` is True.
      token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
      token_type_embedding_name: string. The name of the embedding table variable
        for token type ids.
      use_position_embeddings: bool. Whether to add position embeddings for the
        position of each token in the sequence.
      position_embedding_name: string. The name of the embedding table variable
        for positional embeddings.
      initializer_range: float. Range of the weight initialization.
      max_position_embeddings: int. Maximum sequence length that might ever be
        used with this model. This can be longer than the sequence length of
        input_tensor, but cannot be shorter.
      dropout_prob: float. Dropout probability applied to the final output tensor.

    Returns:
      float tensor with same shape as `input_tensor`.

    Raises:
      ValueError: One of the tensor shapes or input values is invalid.
    """
    input_shape = get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    width = input_shape[2]

    output = input_tensor

    if use_token_type:
        if token_type_ids is None:
            raise ValueError("`token_type_ids` must be specified if"
                             "`use_token_type` is True.")
        token_type_table = tf.get_variable(
            name=token_type_embedding_name,
            shape=[token_type_vocab_size, width],
            initializer=create_initializer(initializer_range))
        # This vocab will be small so we always do one-hot here, since it is always
        # faster for a small vocabulary.
        flat_token_type_ids = tf.reshape(token_type_ids, [-1])
        one_hot_ids = tf.one_hot(flat_token_type_ids,
                                 depth=token_type_vocab_size)
        token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
        token_type_embeddings = tf.reshape(token_type_embeddings,
                                           [batch_size, seq_length, width])
        output += token_type_embeddings

    if use_position_embeddings:
        assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
        with tf.control_dependencies([assert_op]):
            full_position_embeddings = tf.get_variable(
                name=position_embedding_name,
                shape=[max_position_embeddings, width],
                initializer=create_initializer(initializer_range))
            # Since the position embedding table is a learned variable, we create it
            # using a (long) sequence length `max_position_embeddings`. The actual
            # sequence length might be shorter than this, for faster training of
            # tasks that do not have long sequences.
            #
            # So `full_position_embeddings` is effectively an embedding table
            # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
            # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
            # perform a slice.
            position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                           [seq_length, -1])
            num_dims = len(output.shape.as_list())

            # Only the last two dimensions are relevant (`seq_length` and `width`), so
            # we broadcast among the first dimensions, which is typically just
            # the batch size.
            position_broadcast_shape = [1] * (num_dims - 2) + [
                seq_length, width
            ]
            position_embeddings = tf.reshape(position_embeddings,
                                             position_broadcast_shape)
            output += position_embeddings

    output = layer_norm_and_dropout(output, dropout_prob)
    return output
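A minimal sketch of the sequence-length guard above, with hypothetical sizes (assuming TF1 graph mode):

import tensorflow.compat.v1 as tf  # assumed TF1-style graph mode

seq_length = 128
max_position_embeddings = 512
width = 768

assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)
with tf.control_dependencies([assert_op]):
    full_position_embeddings = tf.get_variable(
        "position_embeddings", shape=[max_position_embeddings, width],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    # Only the first `seq_length` rows are needed for this batch.
    position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                   [seq_length, -1])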
Example #11
    def step_fn(self, params, model):
        """A single step for supervised learning."""
        (train_images, train_labels, valid_images,
         valid_labels) = tf.raw_ops.InfeedDequeueTuple(
             dtypes=params.train_dtypes, shapes=params.train_shapes)

        if train_labels.dtype == tf.int32:
            train_labels = tf.one_hot(train_labels,
                                      depth=params.num_classes,
                                      dtype=tf.float32)
        if valid_labels.dtype == tf.int32:
            valid_labels = tf.one_hot(valid_labels,
                                      depth=params.num_classes,
                                      dtype=tf.float32)
        global_step = tf.train.get_or_create_global_step()

        num_replicas = tf.cast(params.num_replicas, tf.float32)

        with tf.variable_scope(MODEL_SCOPE):
            train_logits = model(train_images, training=True)

        with tf.variable_scope(SCORE_SCOPE):
            score_logits = model(train_images,
                                 training=False,
                                 return_scores=True)
            score_m = tf.tpu.cross_replica_sum(tf.reduce_sum(score_logits))
            score_m = tf.stop_gradient(score_m) / float(params.num_replicas)
            score_e = tf.exp(score_logits - score_m)
            score_z = tf.tpu.cross_replica_sum(tf.reduce_sum(score_e))
            score_probs = score_e / score_z

        # train the main model
        cross_entropy = tf.losses.softmax_cross_entropy(
            onehot_labels=train_labels,
            logits=train_logits,
            label_smoothing=params.label_smoothing,
            reduction=tf.losses.Reduction.NONE)
        cross_entropy = tf.reduce_sum(cross_entropy *
                                      tf.stop_gradient(score_probs))

        l2_reg_rate = tf.cast(params.weight_decay / params.num_replicas,
                              tf.float32)
        weight_dec = common_utils.get_l2_loss(excluded_keywords=[SCORE_SCOPE])
        total_loss = cross_entropy + weight_dec * l2_reg_rate

        model_variables = [
            v for v in tf.trainable_variables() if MODEL_SCOPE in v.name
        ]
        train_gradients = tf.gradients(total_loss, model_variables)
        train_gradients = [
            tf.tpu.cross_replica_sum(g) for g in train_gradients
        ]
        train_gradients, grad_norm = tf.clip_by_global_norm(
            train_gradients, params.grad_bound)

        learning_rate, optimizer = common_utils.get_optimizer(params)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        train_op = tf.cond(
            tf.math.is_finite(grad_norm), lambda: optimizer.
            apply_gradients(zip(train_gradients, model_variables),
                            global_step=global_step), tf.no_op)
        with tf.control_dependencies(update_ops + [train_op]):
            ema_train_op = common_utils.setup_ema(
                params, f'{MODEL_SCOPE}/{model.name}')

        with tf.control_dependencies([ema_train_op]):
            with tf.variable_scope(MODEL_SCOPE, reuse=True):
                valid_logits = model(valid_images, training=False)
                valid_cross_entropy = tf.losses.softmax_cross_entropy(
                    onehot_labels=valid_labels,
                    logits=valid_logits,
                    reduction=tf.losses.Reduction.MEAN) / float(
                        params.num_replicas)
                valid_gradients = tf.gradients(valid_cross_entropy,
                                               model_variables)
                valid_gradients = [
                    tf.tpu.cross_replica_sum(g) for g in valid_gradients
                ]

            dot_product = tf.add_n([
                tf.reduce_sum(g_t * g_v)
                for g_t, g_v in zip(train_gradients, valid_gradients)
            ])
            dot_product = tf.stop_gradient(dot_product)
            dot_product_avg = tf.get_variable(name='dot_product_avg',
                                              shape=[],
                                              trainable=False)
            dot_product_update = tf.assign_sub(
                dot_product_avg, 0.01 * (dot_product_avg - dot_product))
            with tf.control_dependencies([dot_product_update]):
                dot_product = tf.identity(dot_product - dot_product_avg)

        # trains the scorer.
        score_entropy = tf.reduce_sum(-score_probs * tf.math.log(score_probs))
        score_entropy = tf.tpu.cross_replica_sum(score_entropy) / float(
            valid_images.shape[0].value)
        score_variables = [
            v for v in tf.trainable_variables() if SCORE_SCOPE in v.name
        ]
        score_gradients = tf.gradients(dot_product * score_entropy,
                                       score_variables)
        score_gradients = [
            tf.tpu.cross_replica_sum(g) for g in score_gradients
        ]
        score_optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=params.scorer_lr, use_locking=True)
        score_train_op = tf.cond(
            global_step < params.scorer_wait_steps, tf.no_op,
            lambda: score_optimizer.apply_gradients(
                zip(score_gradients, score_variables)))

        with tf.control_dependencies([score_train_op]):
            logs = collections.OrderedDict()
            logs['global_step'] = tf.cast(global_step, tf.float32)

            logs['model/total'] = total_loss
            logs['model/weight_decay'] = weight_dec / num_replicas
            logs['model/cross_entropy'] = cross_entropy
            logs['model/lr'] = tf.identity(learning_rate) / num_replicas
            logs['model/grad_norm'] = grad_norm / num_replicas

            logs['score/dot_product'] = dot_product / num_replicas
            logs['score/dot_product_avg'] = dot_product_avg / num_replicas
            logs['score/entropy'] = score_entropy
            logs['score/p_min'] = tf.reduce_min(score_probs) / num_replicas
            logs['score/p_max'] = tf.reduce_max(score_probs) / num_replicas

            tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
            self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
            outfeed_enqueue_op = tf.cond(
                common_utils.should_log(params),
                lambda: tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors),
                tf.no_op)
        return outfeed_enqueue_op
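The `dot_product_avg` update above is an exponential moving average with decay 0.99, and subtracting it centers the train/valid gradient dot product before it scales the scorer's entropy gradient. A minimal NumPy sketch of that baseline trick (the toy values are illustrative assumptions):

import numpy as np

def centered_with_ema_baseline(values, decay=0.99):
    """Returns x_t - avg_t, where avg_t follows avg -= (1 - decay) * (avg - x_t)."""
    avg = 0.0
    centered = []
    for x in values:
        avg -= (1.0 - decay) * (avg - x)   # same update as the tf.assign_sub above
        centered.append(x - avg)
    return np.array(centered)

print(centered_with_ema_baseline([1.0, 2.0, 3.0]))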
Example #12
0
def model_fn(features, labels, mode, params):
    """The model_fn to be used with TPUEstimator.

    Args:
      features: `Tensor` of batched images.
      labels: `Tensor` of one hot labels for the data samples
      mode: one of `tf.estimator.ModeKeys.{TRAIN,EVAL,PREDICT}`
      params: `dict` of parameters passed to the model from the TPUEstimator,
          `params['batch_size']` is always provided and should be used as the
          effective batch size.

    Returns:
      A `TPUEstimatorSpec` for the model
    """
    if isinstance(features, dict):
        features = features['feature']

    # In most cases, the default data format NCHW instead of NHWC should be
    # used for a significant performance boost on GPU. NHWC should be used
    # only if the network needs to be run on CPU since the pooling operations
    # are only supported on NHWC. TPU uses XLA compiler to figure out best layout.
    if FLAGS.data_format == 'channels_first':
        assert not FLAGS.transpose_input  # channels_first only for GPU
        features = tf.transpose(features, [0, 3, 1, 2])
        stats_shape = [3, 1, 1]
    else:
        stats_shape = [1, 1, 3]

    if FLAGS.transpose_input and mode != tf.estimator.ModeKeys.PREDICT:
        features = tf.transpose(features, [3, 0, 1, 2])  # HWCN to NHWC

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    has_moving_average_decay = (FLAGS.moving_average_decay > 0)
    # This is essential if using a Keras-derived model.
    tf.keras.backend.set_learning_phase(is_training)
    logging.info('Using open-source implementation.')
    override_params = {}
    if FLAGS.batch_norm_momentum is not None:
        override_params['batch_norm_momentum'] = FLAGS.batch_norm_momentum
    if FLAGS.batch_norm_epsilon is not None:
        override_params['batch_norm_epsilon'] = FLAGS.batch_norm_epsilon
    if FLAGS.dropout_rate is not None:
        override_params['dropout_rate'] = FLAGS.dropout_rate
    if FLAGS.survival_prob is not None:
        override_params['survival_prob'] = FLAGS.survival_prob
    if FLAGS.data_format:
        override_params['data_format'] = FLAGS.data_format
    if FLAGS.num_label_classes:
        override_params['num_classes'] = FLAGS.num_label_classes
    if FLAGS.depth_coefficient:
        override_params['depth_coefficient'] = FLAGS.depth_coefficient
    if FLAGS.width_coefficient:
        override_params['width_coefficient'] = FLAGS.width_coefficient

    def normalize_features(features, mean_rgb, stddev_rgb):
        """Normalize the image given the means and stddevs."""
        features -= tf.constant(mean_rgb,
                                shape=stats_shape,
                                dtype=features.dtype)
        features /= tf.constant(stddev_rgb,
                                shape=stats_shape,
                                dtype=features.dtype)
        return features

    def build_model():
        """Build model using the model_name given through the command line."""
        model_builder = model_builder_factory.get_model_builder(
            FLAGS.model_name)
        normalized_features = normalize_features(features,
                                                 model_builder.MEAN_RGB,
                                                 model_builder.STDDEV_RGB)
        logits, _ = model_builder.build_model(normalized_features,
                                              model_name=FLAGS.model_name,
                                              training=is_training,
                                              override_params=override_params,
                                              model_dir=FLAGS.model_dir)
        return logits

    if params['use_bfloat16']:
        with tf.tpu.bfloat16_scope():
            logits = tf.cast(build_model(), tf.float32)
    else:
        logits = build_model()

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'classes': tf.argmax(logits, axis=1),
            'probabilities': tf.nn.softmax(logits, name='softmax_tensor')
        }
        return tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                'classify': tf.estimator.export.PredictOutput(predictions)
            })

    # If necessary, in the model_fn, use params['batch_size'] instead of the
    # batch-size flags (--train_batch_size or --eval_batch_size).
    batch_size = params['batch_size']  # pylint: disable=unused-variable

    # Calculate loss, which includes softmax cross entropy and L2 regularization.
    cross_entropy = tf.losses.softmax_cross_entropy(  # tf.losses.softmax_cross_entropy, not tf.nn.softmax_cross_entropy_with_logits
        onehot_labels=labels,
        logits=logits,
        label_smoothing=FLAGS.label_smoothing)

    # Add weight decay to the loss for non-batch-normalization variables.
    loss = cross_entropy + FLAGS.weight_decay * tf.add_n([
        tf.nn.l2_loss(v) for v in tf.trainable_variables()
        if 'batch_normalization' not in v.name
    ])

    global_step = tf.train.get_global_step()
    if has_moving_average_decay:
        ema = tf.train.ExponentialMovingAverage(
            decay=FLAGS.moving_average_decay, num_updates=global_step)
        ema_vars = utils.get_ema_vars()

    host_call = None
    restore_vars_dict = None
    if is_training:
        # Compute the current epoch and associated learning rate from global_step.
        current_epoch = (tf.cast(global_step, tf.float32) /
                         params['steps_per_epoch'])

        scaled_lr = FLAGS.base_learning_rate * (FLAGS.train_batch_size / 256.0)
        logging.info('base_learning_rate = %f', FLAGS.base_learning_rate)
        learning_rate = utils.build_learning_rate(scaled_lr, global_step,
                                                  params['steps_per_epoch'])
        optimizer = utils.build_optimizer(learning_rate, optimizer_name="sgd")
        if FLAGS.use_tpu:
            # When using TPU, wrap the optimizer with CrossShardOptimizer which
            # handles synchronization details between different TPU cores. To the
            # user, this should look like regular synchronous training.
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)

        # Batch normalization requires UPDATE_OPS to be added as a dependency to
        # the train operation.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step)

        if has_moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

        if not FLAGS.skip_host_call:

            def host_call_fn(gs, lr, ce):
                """Training host call. Creates scalar summaries for training metrics.

                This function is executed on the CPU and should not directly reference
                any Tensors in the rest of the `model_fn`. To pass Tensors from the
                model to the `metric_fn`, provide as part of the `host_call`. See
                https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
                for more information.

                Arguments should match the list of `Tensor` objects passed as the second
                element in the tuple passed to `host_call`.

                Args:
                  gs: `Tensor` with shape `[batch]` for the global_step.
                  lr: `Tensor` with shape `[batch]` for the learning_rate.
                  ce: `Tensor` with shape `[batch]` for the current_epoch.

                Returns:
                  List of summary ops to run on the CPU host.
                """
                gs = gs[0]
                # Host call fns are executed FLAGS.iterations_per_loop times after
                # one TPU loop is finished. Setting max_queue to the number of
                # iterations makes the summary writer flush the data to storage
                # only once per loop.
                with tf2.summary.create_file_writer(
                        FLAGS.model_dir,
                        max_queue=FLAGS.iterations_per_loop).as_default():
                    with tf2.summary.record_if(True):
                        tf2.summary.scalar('learning_rate', lr[0], step=gs)
                        tf2.summary.scalar('current_epoch', ce[0], step=gs)

                        return tf.summary.all_v2_summary_ops()

            # To log the loss, current learning rate, and epoch for Tensorboard, the
            # summary op needs to be run on the host CPU via host_call. host_call
            # expects [batch_size, ...] Tensors, thus reshape to introduce a batch
            # dimension. These Tensors are implicitly concatenated to
            # [params['batch_size']].
            gs_t = tf.reshape(global_step, [1])
            lr_t = tf.reshape(learning_rate, [1])
            ce_t = tf.reshape(current_epoch, [1])

            host_call = (host_call_fn, [gs_t, lr_t, ce_t])

    else:
        train_op = None
        if has_moving_average_decay:
            # Load moving average variables for eval.
            restore_vars_dict = ema.variables_to_restore(ema_vars)

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(labels, logits):
            """Evaluation metric function. Evaluates accuracy.

            This function is executed on the CPU and should not directly reference
            any Tensors in the rest of the `model_fn`. To pass Tensors from the model
            to the `metric_fn`, provide as part of the `eval_metrics`. See
            https://www.tensorflow.org/api_docs/python/tf/estimator/tpu/TPUEstimatorSpec
            for more information.

            Arguments should match the list of `Tensor` objects passed as the second
            element in the tuple passed to `eval_metrics`.

            Args:
              labels: `Tensor` with shape `[batch, num_classes]`.
              logits: `Tensor` with shape `[batch, num_classes]`.

            Returns:
              A dict of the metrics to return from evaluation.
            """

            # TODO: change the metric here.
            labels = tf.argmax(labels, axis=1)
            predictions = tf.argmax(logits, axis=1)
            accuracy = tf.metrics.accuracy(labels, predictions)
            auc = tf2.keras.metrics.AUC(name='auc')  # AUC
            auc.update_state(labels, predictions)
            return {'accuracy': accuracy, 'auc': auc}

        eval_metrics = (metric_fn, [labels, logits])

    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    logging.info('number of trainable parameters: %d', num_params)

    def _scaffold_fn():
        saver = tf.train.Saver(restore_vars_dict)
        return tf.train.Scaffold(saver=saver)

    if has_moving_average_decay and not is_training:
        # Only apply scaffold for eval jobs.
        scaffold_fn = _scaffold_fn
    else:
        scaffold_fn = None

    return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                             loss=loss,
                                             train_op=train_op,
                                             host_call=host_call,
                                             eval_metrics=eval_metrics,
                                             scaffold_fn=scaffold_fn)
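The `_scaffold_fn` above restores exponential-moving-average weights only for eval jobs. A minimal sketch of that pattern, using the same TF 1.x API as the example (the variable `w` is an assumption for illustration):

import tensorflow as tf  # TF 1.x

w = tf.get_variable('w', shape=[3], initializer=tf.zeros_initializer())
ema = tf.train.ExponentialMovingAverage(decay=0.999)
maintain_op = ema.apply([w])  # run this after every training step

# At eval time, map 'w/ExponentialMovingAverage' -> w so that the Saver loads
# the shadow (averaged) values into the live variables.
restore_vars_dict = ema.variables_to_restore([w])
saver = tf.train.Saver(restore_vars_dict)
scaffold = tf.train.Scaffold(saver=saver)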
Example #13
0
 def mean_variance_with_update():
     with tf.control_dependencies([ema.apply([batch_mean, batch_variance])]):
         return (tf.identity(batch_mean),
                 tf.identity(batch_variance))
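This helper is the classic moving-statistics trick for batch normalization: during training it updates the EMA shadow variables and returns the batch statistics; at inference the averages are used instead. A minimal sketch of the surrounding context, with `is_training` and the placeholder shapes as assumptions:

import tensorflow as tf  # TF 1.x

x = tf.placeholder(tf.float32, [None, 64])
is_training = tf.placeholder(tf.bool, [])
batch_mean, batch_variance = tf.nn.moments(x, axes=[0])
ema = tf.train.ExponentialMovingAverage(decay=0.99)

def mean_variance_with_update():
    # Update the moving averages, then return the batch statistics.
    with tf.control_dependencies([ema.apply([batch_mean, batch_variance])]):
        return tf.identity(batch_mean), tf.identity(batch_variance)

mean, variance = tf.cond(
    is_training,
    mean_variance_with_update,
    lambda: (ema.average(batch_mean), ema.average(batch_variance)))
y = tf.nn.batch_normalization(x, mean, variance, None, None, 1e-5)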
Example #14
0
def CustomCropImages(images, input_shape, target_shape, target_locations):
    """Crop a list of images at with a custom crop location and size.

  Args:
    images: List of tensors of shape [batch_size, h, w, c].
    input_shape: Shape [h, w, c] of the input images.
    target_shape: Shape [h, w] of the cropped output.
    target_locations: List of crop center coordinates tensors of shape [b, 2].
  Returns:
    crops: List of cropped tensors of shape [batch_size] + target_shape + [3].
  """
    if len(input_shape) != 3:
        raise ValueError(
            'The input shape has to be of the form (height, width, channels) '
            'but has len {}'.format(len(input_shape)))
    if len(target_shape) != 2:
        raise ValueError(
            'The target shape has to be of the form (height, width) '
            'but has len {}'.format(len(target_shape)))
    if len(images) != len(target_locations):
        raise ValueError(
            'There should be one target location per image. Found {} '
            'images for {} locations'.format(len(images),
                                             len(target_locations)))
    if input_shape[0] == target_shape[0] and input_shape[1] == target_shape[1]:
        return [image for image in images]
    if input_shape[0] < target_shape[0] or input_shape[1] < target_shape[1]:
        raise ValueError(
            'The target shape {} is larger than the input image size '
            '{}'.format(target_shape, input_shape[:2]))
    assert_ops = []
    for image, target_location in zip(images, target_locations):
        # Assert all images have the same shape.
        assert_ops.append(
            tf.assert_equal(
                input_shape[:2],
                tf.shape(image)[1:3],
                message=('All images must have the same width and height '
                         'for CustomCropImages.')))

    with tf.control_dependencies(assert_ops):
        crops = []
        for image, target_location in zip(images, target_locations):
            # If the bounding box falls outside the image boundaries, move it back inside.
            x_coordinates = tf.slice(target_location, [0, 1],
                                     [tf.shape(target_location)[0], 1])
            y_coordinates = tf.slice(target_location, [0, 0],
                                     [tf.shape(target_location)[0], 1])

            x_coordinates = tf.math.maximum(
                tf.cast(x_coordinates, tf.float32),
                tf.cast(target_shape[1] // 2, tf.float32))
            y_coordinates = tf.math.maximum(
                tf.cast(y_coordinates, tf.float32),
                tf.cast(target_shape[0] // 2, tf.float32))
            x_coordinates = tf.math.minimum(
                tf.cast(x_coordinates, tf.float32),
                tf.cast(tf.shape(image)[2] - target_shape[1] // 2, tf.float32))
            y_coordinates = tf.math.minimum(
                tf.cast(y_coordinates, tf.float32),
                tf.cast(tf.shape(image)[1] - target_shape[0] // 2, tf.float32))

            target_location = tf.concat([x_coordinates, y_coordinates], 1)
            crops.append(
                tf.image.extract_glimpse(image,
                                         target_shape,
                                         tf.cast(target_location, tf.float32),
                                         centered=False,
                                         normalized=False))
    return crops
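A hypothetical usage sketch for CustomCropImages as defined above; the batch size, image size, and crop centers below are illustrative assumptions:

import tensorflow as tf  # TF 1.x

images = [tf.zeros([4, 64, 64, 3]), tf.zeros([4, 64, 64, 3])]
# One (y, x) crop center per image in each batch.
locations = [tf.fill([4, 2], 32.0), tf.fill([4, 2], 16.0)]
crops = CustomCropImages(images,
                         input_shape=[64, 64, 3],
                         target_shape=[24, 24],
                         target_locations=locations)
# Each element of `crops` has shape [4, 24, 24, 3].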
Example #15
0
def forward_pass(opts, transformer, iterations_per_step, is_training, outfeed, dense_queue, infeed):
    def make_counter():
        with tf.variable_scope("counter", reuse=tf.AUTO_REUSE, use_resource=True):
            itr_counter = tf.get_variable("iterations", [], tf.int32, trainable=False)
            increment_counter = tf.assign_add(itr_counter, 1)
            mod_itrs = tf.math.floormod(increment_counter, iterations_per_step)
            last_itr = tf.equal(mod_itrs, 0, name="last_update_itr")

            # Add accumulation counter if pipelined
            if opts.pipeline:
                grad_counter = internal_ops.get_current_iteration_counter()
                last_grad_itr = tf.equal(grad_counter, opts.gradient_accumulation_count-1, name="last_grad_itr")

                last_itr = tf.logical_and(last_itr, last_grad_itr, name="last_itr")

        return last_itr

    def make_src_mask(last_itr, source):
        with tf.variable_scope("transformer", reuse=tf.AUTO_REUSE, use_resource=True):
            transformer.compute_dense_grad = last_itr
            autoregressive_mask = tf.constant(np.triu(np.ones([S, S], dtype=np.bool), k=1))
            source_mask = autoregressive_mask
            source_mask = tf.cast(source_mask, opts.dtype) * -10000
        return source_mask

    def loss_and_metrics(logits, source):
        with tf.variable_scope("metrics", reuse=tf.AUTO_REUSE, use_resource=True):
            # Implement autoregressive loss through teacher forcing
            # The first few tokens have no hope of being correct
            # so we exclude the first "offset" tokens from the loss
            offset = opts.autoregression_offset
            logits = tf.cast(logits[:, offset:-1], tf.float32)  # logits always full precision
            target = source[:, offset + 1:]
            predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)

            # Accuracy
            acc, acc_op = tf.metrics.accuracy(target, predictions, name="token_accuracy")

            # Unweighted cross-entropy for tracking progress
            nll_loss = tf.losses.sparse_softmax_cross_entropy(labels=target, logits=logits)
            nll_loss = tf.reduce_mean(nll_loss)
            perplexity = tf.exp(nll_loss)

            # Training loss (weighted cross-entropy)
            # the weight of the loss on each token is normalized by the number of
            # times that token appears in the sequence.
            # For instance, if there are 10 padding tokens, the loss from each will have a weight of 1/10.
            nll_weights = tf.expand_dims(target, -1)
            nll_weights = tf.equal(nll_weights, tf.transpose(nll_weights, perm=[0, 2, 1]))
            nll_weights = tf.cast(nll_weights, tf.float32)
            nll_weights = 1.0 / tf.reduce_sum(nll_weights, -1)
            training_loss = tf.losses.sparse_softmax_cross_entropy(
                labels=target, logits=logits, weights=nll_weights)
            training_loss = tf.reduce_mean(training_loss)
        return {
            "training_loss": training_loss,
            "token_accuracy": acc,
            "acc_op": acc_op,
            "nll_loss": nll_loss,
            "perplexity": perplexity,
            "predictions": predictions,
            "target": target
        }

    def make_lr_schedule(global_step):
        with tf.variable_scope("training", reuse=tf.AUTO_REUSE, use_resource=True):
            # The learning rate schedule needs to be part of the graph so the lr can
            # change between different batches within the same IO step
            schedule = tf_utils.BertSchedule(opts, opts.dtype)
            lr = schedule(global_step)
        return lr

    def make_optimizer(lr, last_itr):
        with tf.variable_scope("training", reuse=tf.AUTO_REUSE, use_resource=True):
            optimizer_class, optimizer_kwargs = build_optimizer(opts.optimizer, opts.optimizer_arg)
            optimizer_class = optimizers.SparseOptimizer(optimizer_class)
            optimizer_class = global_step_update_opt.GlobalStepUpdateOptimizer(optimizer_class)
            if opts.loss_scale != 1:
                optimizer_class = scaling_opt.LossScalingOptimizer(optimizer_class)
                optimizer_kwargs['loss_scale'] = opts.loss_scale
                optimizer_kwargs['unscale_grad_pre_acc'] = opts.unscale_grad_pre_acc
            if opts.grad_acculation_mode == 'Avg':
                optimizer_class = scaling_opt.GradScalingOptimizer(optimizer_class)
                optimizer_kwargs['grad_scale'] = 1 / opts.gradient_accumulation_count
                optimizer_kwargs['scale_grad_pre_acc'] = opts.scale_grad_pre_acc
            if opts.grad_norm_clip:
                optimizer_class = grad_clip_opt.GradientClippingOptimizer(optimizer_class)
                optimizer_kwargs['norm_clip_threshold'] = opts.grad_norm_clip
            if opts.slots_fp_type is not None and tf.as_dtype(opts.slots_fp_type) != opts.dtype:
                optimizer_class = fp_slot_opt.SelectableSlotFPFormatOptimizer(optimizer_class)
                optimizer_kwargs['slots_dtype'] = opts.slots_fp_type
                optimizer_kwargs['force_fp32_weight_update'] = opts.force_fp32_weight_update
            optimizer = optimizer_class(learning_rate=lr, **optimizer_kwargs,
                                        sparse_layers=transformer.sparse_layers.values(),
                                        dense_gradient_condition=enable_dense_grad and last_itr,
                                        prune_and_grow_outfeed=dense_queue)
        return optimizer

    def make_pipeline_opt(outputs):
        optimizer = make_optimizer(outputs["learning_rate"], outputs["last_itr"])
        return pipelining_ops.OptimizerFunctionOutput(optimizer, outputs["training_loss"])

    def make_outfeed(lr, global_step, metrics, itr_counter):
        acc_op = metrics['acc_op']

        if is_training:
            with tf.control_dependencies([acc_op]):
                output_dict = {
                    **metrics,
                    "learning_rate": lr,
                    "global_step": tf.cast(global_step, tf.int32),
                    "iteration_counter": itr_counter}
                output = outfeed.enqueue(output_dict)
        else:
            # At inference time stream back the loss and accuracy
            with tf.control_dependencies([acc_op]):
                output = outfeed.enqueue(metrics)
        return output

    # Batch size and sequence length
    S = transformer.source_sequence_length
    enable_dense_grad = opts.prune_ratio is not None and opts.prune_ratio > 0

    if not opts.pipeline:
        # This autoregressive model is self-labeling and needs only 1 input
        source = infeed
        last_itr = make_counter()
        source_mask = make_src_mask(last_itr, source)
        # Build the encoder
        logits = transformer.language_model(source=source, source_mask=source_mask,
                                            add_projection_layer=True, last_itr=last_itr,
                                            enable_dense_grad=enable_dense_grad,
                                            sparse_embeddings=opts.sparse_embeddings)
        metrics = loss_and_metrics(logits, source)
        if is_training:
            global_step = tf.cast(tf.train.get_or_create_global_step(), tf.int32)
            lr = make_lr_schedule(global_step)
            optimizer = make_optimizer(lr, last_itr)
            train_op = optimizer.minimize(metrics['training_loss'], global_step=global_step)
        else:
            lr, global_step = None, None
            train_op = tf.no_op()

        with tf.control_dependencies([train_op]):
            with tf.variable_scope("counter", reuse=tf.AUTO_REUSE, use_resource=True):
                itr_counter = tf.get_variable("iterations", [], tf.int32, trainable=False)
            output = make_outfeed(lr, global_step, metrics, itr_counter)
        return output
    else:
        def first_stage(global_step, source, input_stage_func):
            last_itr = make_counter()
            source_mask = make_src_mask(last_itr, source)
            return input_stage_func(source, source_mask, last_itr, global_step)

        def last_stage(encoder_out, source_mask, *args, **kwargs):
            last_itr = args[0]
            global_step = args[1]
            source = args[2]
            output_stage_func = kwargs['output_stage_func']
            logits, *_ = output_stage_func(encoder_out, source_mask, *args)
            metrics = loss_and_metrics(logits, source)
            if is_training:
                metrics.update({
                        "learning_rate": make_lr_schedule(global_step),
                        "last_itr": last_itr,
                        "global_step": tf.convert_to_tensor(global_step)
                })
                return metrics
            else:
                metrics['last_itr'] = last_itr
                return metrics

        stages, device_mapping, stage_options = transformer.language_model_stages(enable_dense_grad=enable_dense_grad,
                                                                                  sparse_embeddings=opts.sparse_embeddings)
        stages[0] = partial(first_stage, input_stage_func=stages[0])
        stages[-1] = partial(last_stage, output_stage_func=stages[-1])

        pipeline_op = pipelining_ops.pipeline(
            computational_stages=stages,
            gradient_accumulation_count=opts.gradient_accumulation_count,
            gradient_accumulation_dtype=opts.gradient_accumulation_dtype,
            repeat_count=iterations_per_step,
            inputs=[tf.cast(tf.train.get_or_create_global_step(), tf.int32)],
            infeed_queue=infeed,
            outfeed_queue=outfeed,
            optimizer_function=make_pipeline_opt if is_training else None,
            device_mapping=device_mapping,
            offload_activations=opts.offload_activations,
            offload_gradient_accumulation_buffers=opts.offload_gradient_accumulation_buffers,
            offload_weight_update_variables=opts.offload_weight_update_variables,
            forward_propagation_stages_poplar_options=stage_options,
            backward_propagation_stages_poplar_options=stage_options,
            name="Pipeline")

        return pipeline_op
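The per-token weights built in loss_and_metrics above give each position weight 1 / (number of times its token appears in the sequence), so frequent tokens such as padding do not dominate the training loss. A small NumPy sketch of the same computation (the toy target sequence is an assumption):

import numpy as np

target = np.array([[5, 7, 7, 0, 0, 0, 0, 0]])                # [batch, seq]
counts = (target[:, :, None] == target[:, None, :]).sum(-1)  # occurrences per position
weights = 1.0 / counts
print(weights)  # [[1.  0.5 0.5 0.2 0.2 0.2 0.2 0.2]]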
Example #16
0
def embed(input_ids,
          vocab_size,
          embedding_size,
          position_offset=0,
          initializer_range=0.02,
          max_position_embeddings=512,
          use_one_hot_embeddings=True):
    """reur and position embeddings
    :param input_ids: int Tensor of shape [batch_size, seq_length].
    :param vocab_size: number of words in vocab
    :param embedding_size: dimensionality of the embedding
    :param position_offset: aka number of cached tokens.
    :param initializer_range: float. Range of the weight initialization.
    :param max_position_embeddings: int. Maximum sequence length.
    :param use_one_hot_embeddings: probably want this to be true
    :return: [batch_size, seq_length, embedding_size] embedded tensor
    """
    (batch_size, seq_length) = get_shape_list(input_ids, expected_rank=2)

    embedding_table = tf.get_variable(
        name='word_embed',
        shape=[vocab_size, embedding_size],
        initializer=create_initializer(initializer_range),
    )

    assert_op = tf.assert_less_equal(tf.reduce_max(input_ids), vocab_size - 1)
    with tf.control_dependencies([assert_op]):
        if use_one_hot_embeddings:
            flat_input_ids = tf.reshape(input_ids, [-1])
            one_hot_input_ids = tf.one_hot(flat_input_ids, depth=vocab_size)
            output_flat = tf.matmul(one_hot_input_ids, embedding_table)
        else:
            output_flat = tf.nn.embedding_lookup(embedding_table, input_ids)

        embedded_input = tf.reshape(output_flat,
                                    [batch_size, seq_length, embedding_size])

    assert_op = tf.assert_less_equal(seq_length, max_position_embeddings)

    with tf.control_dependencies([assert_op]):
        full_position_embeddings = tf.get_variable(
            name='pos_embed',
            shape=[max_position_embeddings, embedding_size],
            initializer=create_initializer(initializer_range),
        )
        # Since the position embedding table is a learned variable, we create it
        # using a (long) sequence length `max_position_embeddings`. The actual
        # sequence length might be shorter than this, for faster training of
        # tasks that do not have long sequences.
        #
        # So `full_position_embeddings` is effectively an embedding table
        # for position [0, 1, 2, ..., max_position_embeddings-1], and the current
        # sequence has positions [0, 1, 2, ... seq_length-1], so we can just
        # perform a slice.
        if position_offset == 0:
            embedded_input += tf.slice(full_position_embeddings, [0, 0],
                                       [seq_length, -1])[None]
        else:
            # Slicing the table at a dynamic position_offset is awkward here, so use a one-hot matmul instead.
            flat_pos_ids = (tf.range(seq_length, dtype=tf.int32) +
                            position_offset)
            one_hot_pos_ids = tf.one_hot(flat_pos_ids,
                                         depth=max_position_embeddings)

            # [seq_length, max_position_embeddings] @ [max_position_embeddings, embedding_size]
            seq_embeds = tf.matmul(one_hot_pos_ids, full_position_embeddings)
            embedded_input += seq_embeds[None]

            # embedded_input += tf.slice(full_position_embeddings[position_offset:], [0, 0], [seq_length, -1])[None]

    return layer_norm(embedded_input, name='embed_norm'), embedding_table
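The one-hot matmul used for a non-zero position_offset above is just a slice of the position table expressed as a matrix product. A small TF 1.x sketch of the equivalence (the toy table and offset are assumptions):

import numpy as np
import tensorflow as tf  # TF 1.x

table = tf.constant(np.arange(12, dtype=np.float32).reshape(6, 2))  # [6, 2]
seq_length, position_offset = 3, 2
pos_ids = tf.range(seq_length) + position_offset        # [2, 3, 4]
one_hot = tf.one_hot(pos_ids, depth=6)                  # [3, 6]
via_matmul = tf.matmul(one_hot, table)                  # [3, 2]
via_gather = tf.gather(table, pos_ids)                  # same values

with tf.Session() as sess:
    a, b = sess.run([via_matmul, via_gather])
    assert np.allclose(a, b)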
Example #17
0
def AddIntegratedGradientsOps(graph,
                              attribution_tensors,
                              output_tensor,
                              num_evals,
                              attribution_dims_map,
                              zero_baseline_tensors=None,
                              new_output_scope='attribution',
                              baseline_scope='baseline',
                              tensors_to_keep=None):
    """Modify graph to create ops for computing integrated gradients.

  Function to modify a tensorflow graph by adding ops for attributing the change
  in value of a given output tensor, to different input 'attribution_tensors'
  (see arxiv.org/abs/1703.01365).

  The first dimension of each attribution_tensor and output_tensor is assumed
  to be the batch dimension. That is, if we create multiple input values for the
  attribution tensors, we should be able to concatenate them along the first
  dimension, and the resulting output tensor should have corresponding values
  for different values of its first dimension.

  The attribution works by interpolating between a given input, and a given
  baseline, to create multiple (num_evals) interpolated inputs. At each
  interpolated input, we compute the gradient of the output tensor with respect
  to each attribution tensor. The gradients for each attribution tensor are
  averaged over all interpolated inputs, to get an attribution score for it.

  Example Usage: attribution_feed_dict = AddIntegratedGradientsOps(...)
  Then to get attribution for a given input (specified by input_feed_dict,
  relative to a baseline given by baseline_feed_dict):
  combined_feed_dict = attribution_feed_dict['create_combined_feed_dict'](
      input_feed_dict, baseline_feed_dict)
  with graph.as_default(), sess.as_default():
    attributions = sess.run(
        attribution_feed_dict['mean_grads'], combined_feed_dict)
  for tensor, attribution in zip(attribution_tensors, attributions):
    print('Attribution for %s: %s' % (tensor.op.name, attribution))

  Warning: This function is not compatible with tf.cond. If there is a tf.cond
  in the graph path between the attribution tensors and the output tensor, the
  attribution ops may not work.
  # TODO(manasrj): Make attribution ops compatible with tf.cond.

  Args:
    graph: The tf.Graph to add attribution ops to.
    attribution_tensors: Tensors for which to compute attribution scores. The
      tensors must satisfy two properties: (1) The output tensor must
      be computable given values for attribution tensors. (2) Each
      attribution tensor must be computationally independent of the
      others, i.e., it should not be the case that one of the
      attribution tensor's value is completely determined by the
      values of the other attribution tensors. Properties (1) and (2) ensure
      the attribution tensors form an input-output cut in the computation
      graph.
    output_tensor: Tensor for whose value we are performing attribution.
    num_evals: Integer scalar. Number of interpolated points at which to
      evaluate gradients. Higher values of this parameter increase computation
      time, but also increase accuracy of attributions.
    attribution_dims_map: Dict mapping attribution tensors to lists of integers.
      For each attribution_tensor, we compute a separate gradient value for each
      slice along the dims in the list. For example, if we have a rank 3
      attribution tensor T that consists of embeddings lookups, with the first
      dimension being the batch dimension, and the second dimension being the
      sparse ids, then setting attribution_dims_map[T] = [1] will give us a
      separate gradient for each sparse id. If an attribution_tensor has no
      entry in attribution_dims_map, then the list defaults to [].
    zero_baseline_tensors: Set of attribution tensors. For each tensor T in this
      set, we compute gradients with respect to T for all interpolated values of
      T between the value computed from the input feed, and zero. For each
      tensor U not in zero_baseline_tensors, we compute gradients for
      interpolated values between the one derived from the input feed, and the
      one derived from the baseline feed.
    new_output_scope: String. New ops needed for computing the output tensor at
      different interpolated values are created under this scope name.
    baseline_scope: String. New ops needed for computing attribution tensor
      interpolated values are created under this scope name.
    tensors_to_keep: Set of tensors. By default, tensors in the graph between
      the output_tensor and attribution tensors are copied to a different part
      of the graph, and evaluated separately for each interpolation. If we want
      a value to be fixed (only computed for the main input instead of each
      interpolation), it should be put in tensors_to_keep.

  Returns:
    attribution_hooks: Dict with the following keys (among others):
      mean_grads: List of attribution scores (aligned with attribution_tensors).
      create_combined_feed_dict: A Function that takes an input feed dict, and
        optionally, a baseline feed dict, and creates a combined feed dict to
        pass to sess.run to get attributions.
  """
    ops_to_tensors = lambda ops: [op.outputs[0] for op in ops]
    attribution_hooks = {}
    if tensors_to_keep is None:
        tensors_to_keep = []
    else:
        tensors_to_keep = list(tensors_to_keep)
    if zero_baseline_tensors is None:
        zero_baseline_tensors = []
    with graph.as_default():
        # Compute parts of graph and check correctness.
        all_ops = graph.get_operations()
        constant_ops = contrib_graph_editor.select.select_ops(
            all_ops, positive_filter=lambda x: x.type == 'Const')
        placeholder_ops = contrib_graph_editor.select.select_ops(
            all_ops, positive_filter=lambda x: x.type == 'Placeholder')
        var_read_ops = contrib_graph_editor.select.select_ops('/read$',
                                                              graph=graph)
        attr_ops = [t.op for t in attribution_tensors]
        required_ops = set(
            contrib_graph_editor.select.get_backward_walk_ops(
                output_tensor.op,
                stop_at_ts=(tensors_to_keep + list(attribution_tensors) +
                            ops_to_tensors(var_read_ops) +
                            ops_to_tensors(placeholder_ops))))

        # Check that attribution tensors are sufficient to compute output_tensor.
        forward_ops = set(
            contrib_graph_editor.select.get_forward_walk_ops(attr_ops +
                                                             var_read_ops +
                                                             constant_ops))
        assert required_ops.issubset(forward_ops)

        required_sgv = contrib_graph_editor.subgraph.make_view(required_ops)
        attribution_subgraph, attribution_transform_info = (
            contrib_graph_editor.transform.copy_with_input_replacements(
                required_sgv, {}, graph, new_output_scope))
        attribution_hooks['attribution_subgraph'] = attribution_subgraph
        attribution_hooks[
            'attribution_transform_info'] = attribution_transform_info

        # Copy feed to attribution part of graph so we can have one part for
        # baseline and one for input.
        backward_ops = contrib_graph_editor.select.get_backward_walk_ops(
            attr_ops, stop_at_ts=ops_to_tensors(var_read_ops))
        backward_sgv = contrib_graph_editor.subgraph.make_view(backward_ops)
        _, baseline_transform_info = (
            contrib_graph_editor.transform.copy_with_input_replacements(
                backward_sgv, {}, graph, baseline_scope))
        attribution_hooks['baseline_transform_info'] = baseline_transform_info

        # Function to compute combined feed dict. The default setting of
        # baseline_transform_info is to get around python's late binding.
        def CreateCombinedFeedDict(
                input_feed_dict,
                baseline_feed_dict=None,
                baseline_transform_info=baseline_transform_info):
            """Combine baseline and input feed dicts into a common feed dict."""
            combined_feed_dict = input_feed_dict.copy()
            if baseline_feed_dict is None:
                baseline_feed_dict = input_feed_dict
            for key, feed_value in baseline_feed_dict.items():
                if isinstance(key, tf.Tensor):
                    combined_feed_dict[baseline_transform_info.transformed(
                        key)] = (feed_value)
                elif isinstance(key, six.text_type):
                    if six.PY2:
                        tensor = graph.get_tensor_by_name(key.decode())
                    else:
                        tensor = graph.get_tensor_by_name(key)
                    combined_feed_dict[baseline_transform_info.transformed(
                        tensor)] = (feed_value)
                elif isinstance(key, tf.SparseTensor):
                    sparse_transformed_tensor = tf.SparseTensor(
                        baseline_transform_info.transformed(key.indices),
                        baseline_transform_info.transformed(key.values),
                        baseline_transform_info.transformed(key.dense_shape))
                    combined_feed_dict[sparse_transformed_tensor] = feed_value
                else:
                    raise ValueError('Invalid key type %s in Feed Dict.' %
                                     type(key))
            return combined_feed_dict

        attribution_hooks['create_combined_feed_dict'] = CreateCombinedFeedDict

        # Create new tensors with the multipliers to insert after previous ones.
        attribution_hooks['multipliers'] = []
        attribution_hooks['weighted_attribution_tensors'] = []
        for attribution_tensor in attribution_tensors:
            with tf.control_dependencies(
                [tf.assert_equal(tf.shape(attribution_tensor)[0], 1)]):
                attribution_dims = (attribution_dims_map[attribution_tensor]
                                    if attribution_tensor
                                    in attribution_dims_map else [])
                vocab_size = len(attribution_tensor.get_shape())
                attribution_dim_cond = tf.sparse_to_indicator(
                    tf.SparseTensor(
                        tf.expand_dims(
                            tf.range(len(attribution_dims), dtype=tf.int64),
                            1), attribution_dims, [vocab_size]), vocab_size)
                base_multiplier_shape = tf.concat([
                    tf.expand_dims(num_evals, 0),
                    tf.ones_like(tf.shape(attribution_tensor))[1:]
                ], 0)
                tile_dims = tf.where(
                    attribution_dim_cond, tf.shape(attribution_tensor),
                    tf.ones_like(tf.shape(attribution_tensor)))
                pre_tile_multiplier = tf.reshape(
                    tf.range(tf.to_float(num_evals)) /
                    tf.to_float(num_evals - 1), base_multiplier_shape)
                multiplier = tf.tile(pre_tile_multiplier, tile_dims)
                if attribution_tensor in zero_baseline_tensors:
                    weighted_attribution_tensor = multiplier * attribution_tensor
                else:
                    base_attribution_tensor = baseline_transform_info.transformed(
                        attribution_tensor)
                    weighted_attribution_tensor = (
                        multiplier * attribution_tensor +
                        (1 - multiplier) * base_attribution_tensor)
                attribution_hooks['weighted_attribution_tensors'].append(
                    weighted_attribution_tensor)
                attribution_hooks['multipliers'].append(multiplier)

        contrib_graph_editor.reroute_ts(
            attribution_hooks['weighted_attribution_tensors'],
            attribution_tensors,
            can_modify=attribution_subgraph.ops)
        g = tf.gradients(attribution_transform_info.transformed(output_tensor),
                         attribution_hooks['multipliers'])
        attribution_hooks['mean_grads'] = [
            tf.reduce_mean(grad, 0) for grad in g
        ]
    return attribution_hooks
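For contrast with the graph-editing approach above, the core integrated-gradients computation can be written directly for a simple differentiable function: interpolate between a baseline and the input, average the gradients, and scale by (input - baseline). A minimal TF 1.x sketch; the function f, the input, and num_evals are illustrative assumptions:

import numpy as np
import tensorflow as tf  # TF 1.x

def f(x):
    return tf.reduce_sum(tf.nn.sigmoid(x), axis=-1)

x_input = tf.constant([[1.0, -2.0, 3.0]])
baseline = tf.zeros_like(x_input)
num_evals = 64

alphas = tf.reshape(tf.linspace(0.0, 1.0, num_evals), [num_evals, 1])
interpolated = baseline + alphas * (x_input - baseline)   # [num_evals, 3]
grads = tf.gradients(f(interpolated), interpolated)[0]
avg_grads = tf.reduce_mean(grads, axis=0, keepdims=True)
attributions = (x_input - baseline) * avg_grads           # [1, 3]

with tf.Session() as sess:
    print(sess.run(attributions))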
Example #18
0
 def _body(step):
     run_op = train_class.step_fn(params, model)
     with tf.control_dependencies([run_op]):
         return step + 1
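A body like this is typically driven by tf.while_loop, which re-runs the step op until a step budget is reached. A minimal sketch, with a counter variable standing in for train_class.step_fn (an assumption for illustration):

import tensorflow as tf  # TF 1.x

counter = tf.get_variable('steps_run', [], tf.int32,
                          initializer=tf.zeros_initializer(), trainable=False)
num_steps = 100

def _cond(step):
    return step < num_steps

def _body(step):
    run_op = tf.assign_add(counter, 1)   # stands in for train_class.step_fn(...)
    with tf.control_dependencies([run_op]):
        return step + 1

final_step = tf.while_loop(_cond, _body, [tf.constant(0)])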
Example #19
0
def model_fn(features, labels, mode, params):

    kernel_initializer = variance_scaling_initializer(
        distribution='truncated_normal')
    #kernel_initializer = tf.keras.initializers.TruncatedNormal(mean=0, stddev=0.1)
    bias_initializer = 'zeros'

    feat = tf.reshape(input_layer(features, params['feature_columns']),
                      [-1, 227, 227, 3])

    net = model_fn_ALEXNET(feat,
                           activation='relu',
                           kernel_initializer=kernel_initializer,
                           bias_initializer=bias_initializer)

    # logits: output is [None, CLASSES]
    logits = Dense(units=params['n_classes'],
                   activation=None,
                   use_bias=True,
                   kernel_initializer=kernel_initializer,
                   bias_initializer=bias_initializer)(net)

    # predictions
    predicted_classes = tf.argmax(logits, 1)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'class_ids': predicted_classes[:, tf.newaxis],
            'probabilities': tf.keras.layers.Softmax(axis=1)(logits),
            'logits': logits
        }
        return tf.estimator.EstimatorSpec(mode, predictions=predictions)

    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                              logits=logits)
    loss = tf.reduce_mean(xentropy)

    accuracy = tf.metrics.accuracy(labels, predicted_classes, name='acc_op')

    with tf.name_scope('metrics'):
        tf.summary.scalar('accuracy', accuracy[1])

    metrics = {
        'metrics/accuracy': accuracy,
    }

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode,
                                          loss=loss,
                                          eval_metric_ops=metrics)

    assert mode == tf.estimator.ModeKeys.TRAIN
    optimizer = params['optimizer']

    # get operations related to batch normalization
    # see: https://stackoverflow.com/questions/45299522/batch-normalization-in-a-custom-estimator-in-tensorflow
    # see: https://github.com/tensorflow/tensorflow/issues/16455
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss,
                                      global_step=tf.train.get_global_step())

    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
Example #20
0
 def mean_var_with_update():
     ema_apply_op = ema.apply([batch_mean, batch_var])
     with tf.control_dependencies([ema_apply_op]):
         return tf.identity(batch_mean), tf.identity(batch_var)
Example #21
0
  def model_fn(features, labels, mode, params=None):
    """Build model and optimizer."""
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Check training mode.
    if FLAGS.train_mode == 'pretrain':
      num_transforms = 2
      if FLAGS.fine_tune_after_block > -1:
        raise ValueError('Does not support layer freezing during pretraining, '
                         'should set fine_tune_after_block<=-1 for safety.')
    elif FLAGS.train_mode == 'finetune':
      num_transforms = 1
    else:
      raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))

    # Split channels, and optionally apply extra batched augmentation.
    features_list = tf.split(
        features, num_or_size_splits=num_transforms, axis=-1)
    if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain':
      features_list = data_util.batch_random_blur(
          features_list, FLAGS.image_size, FLAGS.image_size)
    features = tf.concat(features_list, 0)  # (num_transforms * bsz, h, w, c)

    # Base network forward pass.
    with tf.variable_scope('base_model'):
      if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block >= 4:
        # Finetuning just the supervised (linear) head will not update BN stats.
        model_train_mode = False
      else:
        # Pretraining, or finetuning anything else, will update BN stats.
        model_train_mode = is_training
      hiddens = model(features, is_training=model_train_mode)

    # Add head and loss.
    if FLAGS.train_mode == 'pretrain':
      tpu_context = params['context'] if 'context' in params else None
      hiddens_proj = model_util.projection_head(hiddens, is_training)
      contrast_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
          hiddens_proj,
          hidden_norm=FLAGS.hidden_norm,
          temperature=FLAGS.temperature,
          tpu_context=tpu_context if is_training else None,
          loss_type=FLAGS.loss_type,
          flags=FLAGS)
      logits_sup = tf.zeros([params['batch_size'], num_classes])
      gradients_penalty = FLAGS.gradient_penalty_weight * obj_lib.add_gradients_penalty(features, model, model_train_mode)
    else:
      contrast_loss = tf.zeros([])
      logits_con = tf.zeros([params['batch_size'], 10])
      labels_con = tf.zeros([params['batch_size'], 10])
      hiddens = model_util.projection_head(hiddens, is_training)
      logits_sup = model_util.supervised_head(
          hiddens, num_classes, is_training)
      obj_lib.add_supervised_loss(
          labels=labels['labels'],
          logits=logits_sup,
          weights=labels['mask'])

    # Add weight decay to loss, for non-LARS optimizers.
    model_util.add_weight_decay(adjust_per_optimizer=True)
    loss = tf.losses.get_total_loss()

    if FLAGS.train_mode == 'pretrain':
      variables_to_train = tf.trainable_variables()
    else:
      collection_prefix = 'trainable_variables_inblock_'
      variables_to_train = []
      for j in range(FLAGS.fine_tune_after_block + 1, 6):
        variables_to_train += tf.get_collection(collection_prefix + str(j))
      assert variables_to_train, 'variables_to_train shouldn\'t be empty!'

    tf.logging.info('===============Variables to train (begin)===============')
    tf.logging.info(variables_to_train)
    tf.logging.info('================Variables to train (end)================')

    learning_rate = model_util.learning_rate_schedule(
        FLAGS.learning_rate, num_train_examples)

    if is_training:
      if FLAGS.train_summary_steps > 0:
        # Compute stats for the summary.
        prob_con = tf.nn.softmax(logits_con)
        entropy_con = - tf.reduce_mean(
            tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1))

        summary_writer = tf2.summary.create_file_writer(FLAGS.model_dir)
        with tf.control_dependencies([summary_writer.init()]):
          with summary_writer.as_default():
            should_record = tf.math.equal(
                tf.math.floormod(tf.train.get_global_step(),
                                 FLAGS.train_summary_steps), 0)
            with tf2.summary.record_if(should_record):
              contrast_acc = tf.equal(
                  tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1))
              contrast_acc = tf.reduce_mean(tf.cast(contrast_acc, tf.float32))
              label_acc = tf.equal(
                  tf.argmax(labels['labels'], 1), tf.argmax(logits_sup, axis=1))
              label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))
              tf2.summary.scalar(
                  'train_contrast_loss',
                  contrast_loss,
                  step=tf.train.get_global_step())
              tf2.summary.scalar(
                  'train_contrast_acc',
                  contrast_acc,
                  step=tf.train.get_global_step())
              tf2.summary.scalar(
                  'train_label_accuracy',
                  label_acc,
                  step=tf.train.get_global_step())
              tf2.summary.scalar(
                  'contrast_entropy',
                  entropy_con,
                  step=tf.train.get_global_step())
              tf2.summary.scalar(
                  'learning_rate', learning_rate,
                  step=tf.train.get_global_step())

      optimizer = model_util.get_optimizer(learning_rate)
      control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      if FLAGS.train_summary_steps > 0:
        control_deps.extend(tf.summary.all_v2_summary_ops())
      with tf.control_dependencies(control_deps):
        train_op = optimizer.minimize(
            loss, global_step=tf.train.get_or_create_global_step(),
            var_list=variables_to_train)

      if FLAGS.checkpoint:
        def scaffold_fn():
          """Scaffold function to restore non-logits vars from checkpoint."""
          tf.train.init_from_checkpoint(
              FLAGS.checkpoint,
              {v.op.name: v.op.name
               for v in tf.global_variables(FLAGS.variable_schema)})

          if FLAGS.zero_init_logits_layer:
            # Init op that initializes output layer parameters to zeros.
            output_layer_parameters = [
                var for var in tf.trainable_variables() if var.name.startswith(
                    'head_supervised')]
            tf.logging.info('Initializing output layer parameters %s to zero',
                            [x.op.name for x in output_layer_parameters])
            with tf.control_dependencies([tf.global_variables_initializer()]):
              init_op = tf.group([
                  tf.assign(x, tf.zeros_like(x))
                  for x in output_layer_parameters])
            return tf.train.Scaffold(init_op=init_op)
          else:
            return tf.train.Scaffold()
      else:
        scaffold_fn = None

      return tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode, train_op=train_op, loss=loss, scaffold_fn=scaffold_fn)
    else:

      def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask,
                    **kws):
        """Inner metric function."""
        metrics = {k: tf.metrics.mean(v, weights=mask)
                   for k, v in kws.items()}
        metrics['label_top_1_accuracy'] = tf.metrics.accuracy(
            tf.argmax(labels_sup, 1), tf.argmax(logits_sup, axis=1),
            weights=mask)
        metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k(
            tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask)
        metrics['contrastive_top_1_accuracy'] = tf.metrics.accuracy(
            tf.argmax(labels_con, 1), tf.argmax(logits_con, axis=1),
            weights=mask)
        metrics['contrastive_top_5_accuracy'] = tf.metrics.recall_at_k(
            tf.argmax(labels_con, 1), logits_con, k=5, weights=mask)
        return metrics

      metrics = {
          'logits_sup': logits_sup,
          'labels_sup': labels['labels'],
          'logits_con': logits_con,
          'labels_con': labels_con,
          'mask': labels['mask'],
          'contrast_loss': tf.fill((params['batch_size'],), contrast_loss),
          'regularization_loss': tf.fill((params['batch_size'],),
                                         tf.losses.get_regularization_loss()),
      }

      return tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(metric_fn, metrics),
          scaffold_fn=None)
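The summary handling above (tf2.summary writers used from a TF 1.x graph, with the train op depending on the summary ops so they actually execute) is a common pattern in TPU code of this era. A minimal sketch of just that mechanism; the loss variable, log directory, and the every-100-steps gate are assumptions:

import tensorflow as tf            # TF 1.x
import tensorflow.compat.v2 as tf2

loss = tf.get_variable('loss', [], initializer=tf.ones_initializer())
global_step = tf.train.get_or_create_global_step()

summary_writer = tf2.summary.create_file_writer('/tmp/summaries')
with tf.control_dependencies([summary_writer.init()]):
    with summary_writer.as_default():
        should_record = tf.math.equal(tf.math.floormod(global_step, 100), 0)
        with tf2.summary.record_if(should_record):
            tf2.summary.scalar('loss', loss, step=global_step)

# Make the train op depend on the v2 summary ops so they run with each step.
deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
deps.extend(tf.summary.all_v2_summary_ops())
with tf.control_dependencies(deps):
    train_op = tf.assign_add(global_step, 1)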
Example #22
0
 def init_learner_state(self):
   learner_init_op = tf.initialize_variables(
       self.learner.learner.get_variables(tf.GraphKeys.GLOBAL_VARIABLES))
   local_inits = tf.get_collection(tf.GraphKeys.LOCAL_INIT_OP)
   with tf.control_dependencies(local_inits + [learner_init_op]):
     return self.learner.assign_state(self.learner.initial_state())
Example #23
0
 def default_case_branch_raising_error():
     err_msg = "Invalid posterior estimate mode."
     raise_err = tf.Assert(tf.constant(False), data=[tf.constant(err_msg)])
     with tf.control_dependencies([raise_err]):
         return posterior_dist.mean()
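A branch like this is used as the default of a tf.case (or tf.cond) so that an unrecognized mode fails at graph-execution time instead of silently falling through to some value. A hypothetical sketch; the mode placeholder and the value branches are assumptions:

import tensorflow as tf  # TF 1.x

mode = tf.placeholder(tf.string, [])
value = tf.constant(3.0)

def default_branch_raising_error():
    err_msg = "Invalid posterior estimate mode."
    raise_err = tf.Assert(tf.constant(False), data=[tf.constant(err_msg)])
    with tf.control_dependencies([raise_err]):
        return tf.identity(value)

out = tf.case(
    {tf.equal(mode, 'mean'): lambda: value,
     tf.equal(mode, 'sample'): lambda: value * 2.0},
    default=default_branch_raising_error)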
Example #24
0
  def train_op(self, ds_state):
    """Train with ES + Grads."""

    perturbs = ds_state.perturbation
    rp_grads = ds_state.grads
    meta_loss = ds_state.meta_loss
    antith_meta_loss = ds_state.antith_meta_loss

    # convert the [bs] shaped tensors to something like [bs, 1, 1, ...].
    broadcast_loss = [
        tf.reshape(meta_loss, [-1] + [1] * (len(p.shape.as_list()) - 1))
        for p in perturbs
    ]
    broadcast_antith_loss = [
        tf.reshape(antith_meta_loss, [-1] + [1] * (len(p.shape.as_list()) - 1))
        for p in perturbs
    ]

    # ES gradient:
    # f(x+s) * d/ds(log(p(s))) = f(x+s) * s / (std**2)
    # for antith:
    # (f(x+s) - f(x-s))*s/(2 * std**2)
    es_grads = []
    for pos_loss, neg_loss, perturb in py_utils.eqzip(broadcast_loss,
                                                      broadcast_antith_loss,
                                                      perturbs):
      # this is the same as having 2 samples.
      es_grads.append(
          (pos_loss - neg_loss) * perturb / (self.custom_getter.std**2))

    def mean_and_var(g):
      mean = tf.reduce_mean(g, axis=0, keep_dims=True)
      square_sum = tf.reduce_sum(tf.square((g - mean)), axis=0)
      var = square_sum / (g.shape.as_list()[0] - 1)
      return tf.squeeze(mean, 0), var + 1e-8

    def combine(es, rp):
      """Do inverse variance rescaling."""
      mean_es, var_es = mean_and_var(es)
      mean_rp, var_rp = mean_and_var(rp)

      es_var_inv = 1. / var_es
      rp_var_inv = 1. / var_rp

      den = es_var_inv + rp_var_inv
      combine_g = (mean_es * es_var_inv + mean_rp * rp_var_inv) / den

      weight_es = es_var_inv / den

      return combine_g, weight_es

    combine_grads, _ = zip(
        *[combine(es, rp) for es, rp in py_utils.eqzip(es_grads, rp_grads)])

    grads_vars = py_utils.eqzip(combine_grads,
                                self.learner.theta_mod.get_variables())

    grads_vars = common.clip_grads_vars(grads_vars, self.gradient_clip_by_value)
    grads_vars = common.assert_grads_vars_not_nan(grads_vars)

    self._did_use_getter_on_all_variables()

    with tf.device(self.remote_device):
      train_op = self.meta_opt.apply_gradients(grads_vars)

    with tf.control_dependencies([train_op]):
      op = common.assert_post_update_not_nan(grads_vars)
      return tf.group(train_op, op, name="train_op")
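combine() above merges the evolution-strategies and reparameterization gradient estimates with inverse-variance weights, so whichever estimator is less noisy gets more weight. A NumPy sketch of the same rule on toy samples (the sample counts and noise levels are assumptions):

import numpy as np

def combine(es, rp):
    mean_es, var_es = es.mean(0), es.var(0, ddof=1) + 1e-8
    mean_rp, var_rp = rp.mean(0), rp.var(0, ddof=1) + 1e-8
    es_inv, rp_inv = 1.0 / var_es, 1.0 / var_rp
    return (mean_es * es_inv + mean_rp * rp_inv) / (es_inv + rp_inv)

es_samples = np.random.normal(1.0, 5.0, size=(32, 3))   # noisy ES gradients
rp_samples = np.random.normal(1.0, 0.5, size=(32, 3))   # low-variance RP gradients
print(combine(es_samples, rp_samples))  # dominated by the lower-variance estimate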
Example #25
0
  def build_model(self, input_image, input_images_d_steps=None):
    """Build model and losses and train_ops.

    Args:
      input_image: A single (B, H, W, C) image, in [0, 255]
      input_images_d_steps: If training a discriminator, this is expected to
        be a (B*N, H, W, C) stack of images, where N=number of sub batches.
        See build_input.

    Returns:
      output_image and bpp if self.evaluation else None.
    """
    if input_images_d_steps is None:
      input_images_d_steps = []
    else:
      input_images_d_steps.set_shape(
          self.input_spec["input_images_d_steps"].shape)
      input_images_d_steps = tf.split(input_images_d_steps, self.num_steps_disc)

    if self.evaluation and input_images_d_steps:
      raise ValueError("Only need input_image for eval! {}".format(
          input_images_d_steps))

    input_image.set_shape(self.input_spec["input_image"].shape)

    self.build_transforms()

    if self.training:
      self._lpips_loss = LPIPSLoss(self._lpips_weight_path)
      self._lpips_loss_weight = self._config.loss_config.lpips_weight

    if self._setup_discriminator:
      self.build_discriminator()

    # Global step needs to be created for train, val and eval.
    global_step = tf.train.get_or_create_global_step()

    # Compute output graph.
    nodes_gen, bpp_pair = self._compute_compression_graph(input_image)

    if self.evaluation:
      tf.logging.info("Evaluation mode: build_model done.")
      reconstruction = tf.clip_by_value(nodes_gen.reconstruction, 0, 255.)
      return reconstruction, bpp_pair.total_qbpp

    nodes_disc = []  # list of Nodes, one for every sub-batch of disc
    for i, sub_batch in enumerate(input_images_d_steps):
      with tf.name_scope("sub_batch_disc_{}".format(i)):
        nodes, _ = self._compute_compression_graph(sub_batch,
                                                   create_summaries=False)
        nodes_disc.append(nodes)

    if self._auto_encoder_ckpt_path:
      self._prepare_auto_encoder_restore()

    # The following is inspired by compare_gan/gans/modular_gan.py:
    # Say we want to train the discriminator for D steps for every single step
    # of generator training (the unroll_graph=True option there). The features
    # given to the model_fn are split into D + 1 sub-batches. The code then
    # creates D train_ops for the discriminator, each feeding a different
    # sub-batch of features into the discriminator.
    # The train_op for the generator then depends on all these D train_ops
    # and uses the last (D+1-th) sub-batch.
    # Note that the graph is only created once.
    # (A standalone sketch of this chaining pattern follows this example.)

    d_train_ops = []
    if self._setup_discriminator:
      tf.logging.info("Unrolling graph for discriminator")
      self._global_step_disc = tf.get_variable(
          "global_step_disc", [], dtype=global_step.dtype, trainable=False)
      with tf.name_scope("steps"):
        tf.summary.scalar("global_step", global_step)
        tf.summary.scalar("global_step_disc", self._global_step_disc)

      # Create optimizer once, and then call minimize on it multiple times
      # within self._train_discriminator.
      disc_optimizer = self._make_discriminator_optimizer(
          self._global_step_disc)
      for i, nodes in enumerate(nodes_disc):
        with tf.name_scope("train_disc_{}".format(i + 1)):
          with tf.control_dependencies(d_train_ops):
            d_train_ops.append(
                self._train_discriminator(
                    nodes, disc_optimizer, create_summaries=(i == 0)))

    # Depend on `d_train_ops`, which ensures all `self._num_steps_disc` steps of
    # the discriminator will run before the generator training op.
    with tf.control_dependencies(d_train_ops):
      train_op = self._train_generator(nodes_gen, bpp_pair, global_step)

    if self.training:
      self._train_op = train_op
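A standalone sketch of the chaining pattern described in the comment above: each discriminator step depends on the previous ones, and the generator step depends on all of them. The names disc_loss_fn, gen_loss_fn, sub_batches, disc_opt and gen_opt are illustrative stand-ins, not part of this class:

import tensorflow as tf  # TF1.x graph-mode API assumed

def build_unrolled_train_op(disc_loss_fn, gen_loss_fn, sub_batches,
                            disc_opt, gen_opt):
    """Creates D chained discriminator steps followed by one generator step."""
    d_train_ops = []
    for i, sub_batch in enumerate(sub_batches[:-1]):
        with tf.name_scope("train_disc_{}".format(i + 1)):
            # Depending on the ops created so far serializes the D updates.
            with tf.control_dependencies(d_train_ops):
                d_train_ops.append(disc_opt.minimize(disc_loss_fn(sub_batch)))

    # The generator step runs only after every discriminator step has run,
    # and it consumes the last (D+1-th) sub-batch.
    with tf.control_dependencies(d_train_ops):
        return gen_opt.minimize(gen_loss_fn(sub_batches[-1]))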
Example #26
0
 def mean_var():
     with tf.control_dependencies([ema_apply_op]):
         return tf.identity(batch_mean), tf.identity(batch_var)
Example #27
0
    def call(self, inputs, prev_state):
        """Evaluates one timestep of the current neural stack cell.

    See section 3.4 of Grefenstette et al., 2015.

    Args:
      inputs: The inputs to the neural stack cell should be a tf.float32 tensor
        with shape [batch_size, embedding_size]
      prev_state: The NeuralStackState from the previous timestep.

    Returns:
      A tuple of the output of the stack as well as the new NeuralStackState.
    """
        batch_size = tf.shape(inputs)[0]

        # Call the controller and get controller interface values.
        with tf.control_dependencies([prev_state.read_strengths]):
            controller_output = self.call_controller(
                inputs, prev_state.read_values, prev_state.controller_state,
                batch_size)

        # Always write input values to memory regardless of push strength.
        # See Equation-1 in Grefenstette et al., 2015.
        new_memory_values = prev_state.memory_values + tf.reduce_sum(
            tf.expand_dims(controller_output.write_values, axis=2) *
            prev_state.write_strengths,
            axis=1)

        # Attenuate the read strengths of existing memory values depending on the
        # current pop strength.
        # See Equation-2 in Grefenstette et al., 2015.
        new_read_strengths = prev_state.read_strengths
        for h in range(self._num_read_heads - 1, -1, -1):
            new_read_strengths = tf.nn.relu(new_read_strengths - tf.nn.relu(
                tf.slice(controller_output.pop_strengths, [0, h, 0, 0],
                         [-1, 1, -1, -1]) -
                tf.expand_dims(tf.reduce_sum(
                    new_read_strengths * self.get_read_mask(h), axis=2),
                               axis=3)))

        # Combine all write heads and their associated push values into a single set
        # of read weights.
        new_read_strengths += tf.reduce_sum(controller_output.push_strengths *
                                            prev_state.write_strengths,
                                            axis=1,
                                            keep_dims=True)

        # Calculate the "top" value of the stack by looking at read strengths.
        # See Equation-3 in Grefenstette et al., 2015.
        new_read_values = tf.reduce_sum(
            tf.minimum(
                new_read_strengths,
                tf.nn.relu(1 - tf.expand_dims(tf.reduce_sum(
                    new_read_strengths * tf.concat([
                        self.get_read_mask(h)
                        for h in range(self._num_read_heads)
                    ],
                                                   axis=1),
                    axis=2),
                                              axis=3))) *
            tf.expand_dims(new_memory_values, axis=1),
            axis=2)

        # Temporarily split write strengths apart so they can be shifted in
        # different directions.
        write_strengths_by_head = tf.split(prev_state.write_strengths,
                                           self._num_write_heads,
                                           axis=1)
        # Shift the write strengths for each write head in the direction indicated
        # by get_write_head_offset().
        new_write_strengths = tf.concat([
            tf.roll(
                write_strength, shift=self.get_write_head_offset(h), axis=2)
            for h, write_strength in enumerate(write_strengths_by_head)
        ],
                                        axis=1)

        return (controller_output.outputs,
                NeuralStackState(controller_state=controller_output.state,
                                 read_values=new_read_values,
                                 memory_values=new_memory_values,
                                 read_strengths=new_read_strengths,
                                 write_strengths=new_write_strengths))
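The comments above refer to Equations 1-3 of Grefenstette et al. (2015). Below is a small unbatched, single-head NumPy sketch of one such stack update; the function and its argument layout are assumptions made for illustration, not the batched multi-head version implemented by this cell:

import numpy as np

def neural_stack_step(memory, strengths, value, push, pop, t):
    """One unbatched, single-head neural stack update (Eqs. 1-3).

    memory: [max_len, embed] stored values; strengths: [max_len] read strengths;
    value: [embed] vector written at position t; push, pop: scalars in [0, 1].
    """
    memory = memory.copy()
    strengths = strengths.copy()

    # Eq. 1: always write the new value at position t.
    memory[t] = value

    # Eq. 2: attenuate existing read strengths by the pop strength, shielding
    # each entry by the total previous strength sitting above it.
    prev = strengths.copy()
    for i in range(t):
        above = prev[i + 1:t].sum()
        strengths[i] = max(0.0, prev[i] - max(0.0, pop - above))
    strengths[t] = push

    # Eq. 3: read a soft "top of stack" value.
    read = np.zeros_like(memory[t])
    for i in range(t + 1):
        above = strengths[i + 1:t + 1].sum()
        read += min(strengths[i], max(0.0, 1.0 - above)) * memory[i]
    return memory, strengths, read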
Example #28
0
def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
  """Model definition entry.

  Args:
    features: the input image tensor with shape [batch_size, height, width, 3].
      The height and width are fixed and equal.
    labels: the input labels in a dictionary. The labels include class targets
      and box targets which are dense label maps. The labels are generated from
      get_input_fn function in data/dataloader.py
    mode: the mode of TPUEstimator including TRAIN and EVAL.
    params: a dictionary defining the model hyperparameters. The default
      settings are in the default_hparams function in this file.
    model: the model that outputs class logits and box regression outputs.
    variable_filter_fn: the filter function that takes trainable_variables and
      returns the variable list after applying the filter rule.

  Returns:
    tpu_spec: the TPUEstimatorSpec to run training, evaluation, or prediction.

  Raises:
    RuntimeError: if both ckpt and backbone_ckpt are set.
  """
  is_tpu = params['strategy'] == 'tpu'
  if params['img_summary_steps']:
    utils.image('input_image', features, is_tpu)
  training_hooks = []
  params['is_training_bn'] = (mode == tf.estimator.ModeKeys.TRAIN)

  if params['use_keras_model']:

    def model_fn(inputs):
      model = efficientdet_keras.EfficientDetNet(
          config=hparams_config.Config(params))
      cls_out_list, box_out_list = model(inputs, params['is_training_bn'])
      cls_outputs, box_outputs = {}, {}
      for i in range(params['min_level'], params['max_level'] + 1):
        cls_outputs[i] = cls_out_list[i - params['min_level']]
        box_outputs[i] = box_out_list[i - params['min_level']]
      return cls_outputs, box_outputs
  else:
    model_fn = functools.partial(model, config=hparams_config.Config(params))

  precision = utils.get_precision(params['strategy'], params['mixed_precision'])
  cls_outputs, box_outputs = utils.build_model_with_precision(
      precision, model_fn, features, params['is_training_bn'])

  levels = cls_outputs.keys()
  for level in levels:
    cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
    box_outputs[level] = tf.cast(box_outputs[level], tf.float32)

  # Set up training loss and learning rate.
  update_learning_rate_schedule_parameters(params)
  global_step = tf.train.get_or_create_global_step()
  learning_rate = learning_rate_schedule(params, global_step)

  # cls_loss and box_loss are for logging. only total_loss is optimized.
  det_loss, cls_loss, box_loss = detection_loss(
      cls_outputs, box_outputs, labels, params)
  reg_l2loss = reg_l2_loss(params['weight_decay'])
  total_loss = det_loss + reg_l2loss

  if mode == tf.estimator.ModeKeys.TRAIN:
    utils.scalar('lrn_rate', learning_rate, is_tpu)
    utils.scalar('trainloss/cls_loss', cls_loss, is_tpu)
    utils.scalar('trainloss/box_loss', box_loss, is_tpu)
    utils.scalar('trainloss/det_loss', det_loss, is_tpu)
    utils.scalar('trainloss/reg_l2_loss', reg_l2loss, is_tpu)
    utils.scalar('trainloss/loss', total_loss, is_tpu)
    train_epochs = tf.cast(global_step, tf.float32) / params['steps_per_epoch']
    utils.scalar('train_epochs', train_epochs, is_tpu)

  moving_average_decay = params['moving_average_decay']
  if moving_average_decay:
    ema = tf.train.ExponentialMovingAverage(
        decay=moving_average_decay, num_updates=global_step)
    ema_vars = utils.get_ema_vars()

  if mode == tf.estimator.ModeKeys.TRAIN:
    if params['optimizer'].lower() == 'sgd':
      optimizer = tf.train.MomentumOptimizer(
          learning_rate, momentum=params['momentum'])
    elif params['optimizer'].lower() == 'adam':
      optimizer = tf.train.AdamOptimizer(learning_rate)
    else:
      raise ValueError('optimizers should be adam or sgd')

    if is_tpu:
      optimizer = tf.tpu.CrossShardOptimizer(optimizer)
    elif params['mixed_precision']:
      optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer)

    # Batch norm requires update_ops to be added as a train_op dependency.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    var_list = tf.trainable_variables()
    if variable_filter_fn:
      var_list = variable_filter_fn(var_list)

    if params.get('clip_gradients_norm', None):
      logging.info('clip gradients norm by %f', params['clip_gradients_norm'])
      grads_and_vars = optimizer.compute_gradients(total_loss, var_list)
      with tf.name_scope('clip'):
        grads = [gv[0] for gv in grads_and_vars]
        tvars = [gv[1] for gv in grads_and_vars]
        # First clip each variable's norm, then clip global norm.
        clip_norm = abs(params['clip_gradients_norm'])
        clipped_grads = [
            tf.clip_by_norm(g, clip_norm) if g is not None else None
            for g in grads
        ]
        clipped_grads, _ = tf.clip_by_global_norm(clipped_grads, clip_norm)
        utils.scalar('gradient_norm', tf.linalg.global_norm(clipped_grads),
                     is_tpu)
        grads_and_vars = list(zip(clipped_grads, tvars))

      with tf.control_dependencies(update_ops):
        train_op = optimizer.apply_gradients(grads_and_vars, global_step)
    else:
      with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(
            total_loss, global_step, var_list=var_list)

    if moving_average_decay:
      with tf.control_dependencies([train_op]):
        train_op = ema.apply(ema_vars)

  else:
    train_op = None

  eval_metrics = None
  if mode == tf.estimator.ModeKeys.EVAL:

    def metric_fn(**kwargs):
      """Returns a dictionary that has the evaluation metrics."""
      if params['nms_configs'].get('pyfunc', True):
        detections_bs = []
        for index in range(kwargs['boxes'].shape[0]):
          nms_configs = params['nms_configs']
          detections = tf.numpy_function(
              functools.partial(nms_np.per_class_nms, nms_configs=nms_configs),
              [
                  kwargs['boxes'][index],
                  kwargs['scores'][index],
                  kwargs['classes'][index],
                  tf.slice(kwargs['image_ids'], [index], [1]),
                  tf.slice(kwargs['image_scales'], [index], [1]),
                  params['num_classes'],
                  nms_configs['max_output_size'],
              ], tf.float32)
          detections_bs.append(detections)
        detections_bs = postprocess.transform_detections(
            tf.stack(detections_bs))
      else:
        # These two branches should be equivalent, but currently they are not.
        # TODO(tanmingxing): enable the non_pyfun path after bug fix.
        nms_boxes, nms_scores, nms_classes, _ = postprocess.per_class_nms(
            params, kwargs['boxes'], kwargs['scores'], kwargs['classes'],
            kwargs['image_scales'])
        img_ids = tf.cast(
            tf.expand_dims(kwargs['image_ids'], -1), nms_scores.dtype)
        detections_bs = [
            img_ids * tf.ones_like(nms_scores),
            nms_boxes[:, :, 1],
            nms_boxes[:, :, 0],
            nms_boxes[:, :, 3] - nms_boxes[:, :, 1],
            nms_boxes[:, :, 2] - nms_boxes[:, :, 0],
            nms_scores,
            nms_classes,
        ]
        detections_bs = tf.stack(detections_bs, axis=-1, name='detections')

      if params.get('testdev_dir', None):
        logging.info('Eval testdev_dir %s', params['testdev_dir'])
        eval_metric = coco_metric.EvaluationMetric(
            testdev_dir=params['testdev_dir'])
        coco_metrics = eval_metric.estimator_metric_fn(detections_bs,
                                                       tf.zeros([1]))
      else:
        logging.info('Eval val with groundtruths %s.', params['val_json_file'])
        eval_metric = coco_metric.EvaluationMetric(
            filename=params['val_json_file'], label_map=params['label_map'])
        coco_metrics = eval_metric.estimator_metric_fn(
            detections_bs, kwargs['groundtruth_data'])

      # Add metrics to output.
      cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
      box_loss = tf.metrics.mean(kwargs['box_loss_repeat'])
      output_metrics = {
          'cls_loss': cls_loss,
          'box_loss': box_loss,
      }
      output_metrics.update(coco_metrics)
      return output_metrics

    cls_loss_repeat = tf.reshape(
        tf.tile(tf.expand_dims(cls_loss, 0), [
            params['batch_size'],
        ]), [params['batch_size'], 1])
    box_loss_repeat = tf.reshape(
        tf.tile(tf.expand_dims(box_loss, 0), [
            params['batch_size'],
        ]), [params['batch_size'], 1])

    cls_outputs = postprocess.to_list(cls_outputs)
    box_outputs = postprocess.to_list(box_outputs)
    params['nms_configs']['max_nms_inputs'] = anchors.MAX_DETECTION_POINTS
    boxes, scores, classes = postprocess.pre_nms(params, cls_outputs,
                                                 box_outputs)
    metric_fn_inputs = {
        'cls_loss_repeat': cls_loss_repeat,
        'box_loss_repeat': box_loss_repeat,
        'image_ids': labels['source_ids'],
        'groundtruth_data': labels['groundtruth_data'],
        'image_scales': labels['image_scales'],
        'boxes': boxes,
        'scores': scores,
        'classes': classes,
    }
    eval_metrics = (metric_fn, metric_fn_inputs)

  checkpoint = params.get('ckpt') or params.get('backbone_ckpt')

  if checkpoint and mode == tf.estimator.ModeKeys.TRAIN:
    # Initialize the model from an EfficientDet or backbone checkpoint.
    if params.get('ckpt') and params.get('backbone_ckpt'):
      raise RuntimeError(
          '--backbone_ckpt and --checkpoint are mutually exclusive')

    if params.get('backbone_ckpt'):
      var_scope = params['backbone_name'] + '/'
      if params['ckpt_var_scope'] is None:
        # Use backbone name as default checkpoint scope.
        ckpt_scope = params['backbone_name'] + '/'
      else:
        ckpt_scope = params['ckpt_var_scope'] + '/'
    else:
      # Load every var in the given checkpoint
      var_scope = ckpt_scope = '/'

    def scaffold_fn():
      """Loads pretrained model through scaffold function."""
      logging.info('restore variables from %s', checkpoint)

      var_map = utils.get_ckpt_var_map(
          ckpt_path=checkpoint,
          ckpt_scope=ckpt_scope,
          var_scope=var_scope,
          skip_mismatch=params['skip_mismatch'])

      tf.train.init_from_checkpoint(checkpoint, var_map)
      return tf.train.Scaffold()
  elif mode == tf.estimator.ModeKeys.EVAL and moving_average_decay:

    def scaffold_fn():
      """Load moving average variables for eval."""
      logging.info('Load EMA vars with ema_decay=%f', moving_average_decay)
      restore_vars_dict = ema.variables_to_restore(ema_vars)
      saver = tf.train.Saver(restore_vars_dict)
      return tf.train.Scaffold(saver=saver)
  else:
    scaffold_fn = None

  if is_tpu:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        eval_metrics=eval_metrics,
        host_call=utils.get_tpu_host_call(global_step, params),
        scaffold_fn=scaffold_fn,
        training_hooks=training_hooks)
  else:
    # Profile every 1K steps.
    if params.get('profile', False):
      profile_hook = tf.estimator.ProfilerHook(
          save_steps=1000, output_dir=params['model_dir'], show_memory=True)
      training_hooks.append(profile_hook)

      # Report memory allocation on OOM; note that this slows down execution.
      class OomReportingHook(tf.estimator.SessionRunHook):

        def before_run(self, run_context):
          return tf.estimator.SessionRunArgs(
              fetches=[],
              options=tf.RunOptions(report_tensor_allocations_upon_oom=True))

      training_hooks.append(OomReportingHook())

    logging_hook = tf.estimator.LoggingTensorHook(
        {
            'step': global_step,
            'det_loss': det_loss,
            'cls_loss': cls_loss,
            'box_loss': box_loss,
        },
        every_n_iter=params.get('iterations_per_loop', 100),
    )
    training_hooks.append(logging_hook)

    eval_metric_ops = (
        eval_metrics[0](**eval_metrics[1]) if eval_metrics else None)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        scaffold=scaffold_fn() if scaffold_fn else None,
        training_hooks=training_hooks)
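The training branch above chains gradient clipping, the batch-norm update ops, the optimizer step, and the EMA refresh through control dependencies. A condensed sketch of just that chain; loss, clip_norm, ema_decay and the optimizer choice are illustrative assumptions:

import tensorflow as tf  # TF1.x graph-mode API assumed

def make_train_op(loss, learning_rate, clip_norm=10.0, ema_decay=0.9998):
    """Sketch of the train-op dependency chain used above."""
    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)

    grads_and_vars = optimizer.compute_gradients(loss, tf.trainable_variables())
    grads = [g for g, _ in grads_and_vars]
    tvars = [v for _, v in grads_and_vars]
    # First clip each gradient's norm, then clip the global norm.
    grads = [tf.clip_by_norm(g, clip_norm) if g is not None else None
             for g in grads]
    grads, _ = tf.clip_by_global_norm(grads, clip_norm)

    # Batch norm's moving statistics live in UPDATE_OPS, so the optimizer step
    # must depend on them.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.apply_gradients(zip(grads, tvars), global_step)

    # The weight EMA is refreshed only after the optimizer step has run.
    ema = tf.train.ExponentialMovingAverage(decay=ema_decay,
                                            num_updates=global_step)
    with tf.control_dependencies([train_op]):
        train_op = ema.apply(tf.trainable_variables())
    return train_op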
Example #29
0
def load_examples():
    if a.input_dir is None or not os.path.exists(a.input_dir):
        raise Exception("input_dir does not exist")

    input_paths = glob.glob(os.path.join(a.input_dir, "*", "*.jpg"))
    decode = tf.image.decode_jpeg
    if len(input_paths) == 0:
        input_paths = glob.glob(os.path.join(a.input_dir, "*", "*.png"))
        decode = tf.image.decode_png

    print("len = ", len(input_paths))

    if len(input_paths) == 0:
        raise Exception("input_dir contains no image files")

    def get_name(path):
        name, _ = os.path.splitext(os.path.basename(path))
        return name

    # if the image names are numbers, sort by the value rather than asciibetically
    # having sorted inputs means that the outputs are sorted in test mode
    if all(get_name(path).isdigit() for path in input_paths):
        input_paths = sorted(input_paths, key=lambda path: int(get_name(path)))
    else:
        input_paths = sorted(input_paths)

    with tf.name_scope("load_images"):
        path_queue = tf.train.string_input_producer(input_paths,
                                                    shuffle=a.mode == "train")
        reader = tf.WholeFileReader()
        paths, contents = reader.read(path_queue)
        raw_input = decode(contents)
        raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32)

        assertion = tf.assert_equal(tf.shape(raw_input)[2],
                                    1,
                                    message="image does not have 1 channels")
        with tf.control_dependencies([assertion]):
            raw_input = tf.identity(raw_input)

        raw_input.set_shape([IMAGE_HEIGHT, IMAGE_WIDTH * 2, 1])

        # break apart image pair and move to range [-1, 1]
        width = tf.shape(raw_input)[1]  # [height, width, channels]
        a_images = preprocess(raw_input[:, :width // 2, :])
        b_images = preprocess(raw_input[:, width // 2:, :])

    if a.which_direction == "AtoB":
        inputs, targets = [a_images, b_images]
    elif a.which_direction == "BtoA":
        inputs, targets = [b_images, a_images]
    else:
        raise Exception("invalid direction")

    # synchronize seed for image operations so that we do the same operations to both
    # input and output images
    seed = random.randint(0, 2**31 - 1)

    def transform(image):
        r = image
        if a.flip:
            r = tf.image.random_flip_left_right(r, seed=seed)

        if a.mode == "train":
            # crop image
            h = r.get_shape().as_list()[0]
            w = r.get_shape().as_list()[1]
            h_offset = tf.cast(tf.floor(
                tf.random_uniform([1], 0, h - CROP_SIZE + 1, seed=seed)),
                               dtype=tf.int32)
            w_offset = tf.cast(tf.floor(
                tf.random_uniform([1], 0, w - CROP_SIZE + 1, seed=seed)),
                               dtype=tf.int32)
            r = tf.image.crop_to_bounding_box(r, h_offset[0], w_offset[0],
                                              CROP_SIZE, CROP_SIZE)
        return r

    with tf.name_scope("input_images"):
        input_images = transform(inputs)

    with tf.name_scope("target_images"):
        target_images = transform(targets)

    paths_batch, inputs_batch, targets_batch = tf.train.batch(
        [paths, input_images, target_images], batch_size=a.batch_size)
    steps_per_epoch = int(math.ceil(len(input_paths) / a.batch_size))

    print(inputs_batch.get_shape().as_list())

    return Examples(
        paths=paths_batch,
        inputs=inputs_batch,
        targets=targets_batch,
        count=len(input_paths),
        steps_per_epoch=steps_per_epoch,
    )
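The shared-seed comment in load_examples relies on two random ops created with the same op-level seed making the same decision run after run, so input and target stay aligned. A tiny sketch of that pairing; the image tensors and names are illustrative assumptions:

import random
import tensorflow as tf  # TF1.x graph-mode API assumed

seed = random.randint(0, 2 ** 31 - 1)

image_a = tf.zeros([256, 256, 1])  # stand-in input image
image_b = tf.ones([256, 256, 1])   # stand-in target image

# Both ops receive the same op-level seed, so on every session.run they make
# the same flip-or-not decision, keeping the pair aligned.
flipped_a = tf.image.random_flip_left_right(image_a, seed=seed)
flipped_b = tf.image.random_flip_left_right(image_b, seed=seed)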
Example #30
0
 def mean_var_update():
     # Update the exponential moving averages of the batch statistics first,
     # then hand back the batch statistics themselves.
     update_mean = _tf.assign(shadow_mean,
                              _tf.multiply(shadow_mean, decay) + _tf.multiply(batch_mean, 1. - decay))
     update_var = _tf.assign(shadow_var,
                             _tf.multiply(shadow_var, decay) + _tf.multiply(batch_var, 1. - decay))
     with _tf.control_dependencies([update_mean, update_var]):
         return _tf.identity(batch_mean), _tf.identity(batch_var)
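Fragments like this one (and Example #26 above) are typically the training branch of a batch-normalization moment update selected with tf.cond. A minimal sketch of that surrounding pattern under assumed variable names:

import tensorflow as tf  # TF1.x graph-mode API assumed

def batch_norm_moments(x, shadow_mean, shadow_var, is_training, decay=0.99):
    """Returns the moments to normalize with, refreshing the moving averages when training."""
    batch_mean, batch_var = tf.nn.moments(x, axes=[0])

    def mean_var_update():
        # Refresh the moving statistics, then hand back the batch statistics.
        with tf.control_dependencies([
            tf.assign(shadow_mean, shadow_mean * decay + batch_mean * (1. - decay)),
            tf.assign(shadow_var, shadow_var * decay + batch_var * (1. - decay))]):
            return tf.identity(batch_mean), tf.identity(batch_var)

    # Training uses (and updates) the batch statistics; inference uses the moving ones.
    return tf.cond(is_training, mean_var_update,
                   lambda: (tf.identity(shadow_mean), tf.identity(shadow_var)))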