def _prepare_decoder_input(area_encoding,
                           decoder_nonpadding,
                           features,
                           hparams,
                           embed_scope=None):
    """Prepare the input for the action decoding.

  Args:
    area_encoding: the encoder output in shape of [batch_size, area_len, depth].
    decoder_nonpadding: the nonpadding mask for the decoding seq.
    features: a dictionary of tensors in the shape of [batch_size, seq_length].
    hparams: the hyperparameters.
    embed_scope: the embedding scope.
  Returns:
    decoder_input: decoder input in shape of
        [batch_size, num_steps, latent_depth]
    decoder_self_attention_bias: decoder attention bias.
  """
    with tf.variable_scope("prepare_decoder_input", reuse=tf.AUTO_REUSE):
        shape = common_layers.shape_list(features["task"])
        batch_size = shape[0]
        encoder_input_length = shape[1]
        depth = common_layers.shape_list(area_encoding)[-1]
        if hparams.span_aggregation == "sum":
            verb_embeds = span_embedding(encoder_input_length, area_encoding,
                                         features["verb_refs"], hparams)
            object_embeds = span_embedding(encoder_input_length, area_encoding,
                                           features["obj_refs"], hparams)
            input_embeds = span_embedding(encoder_input_length, area_encoding,
                                          features["input_refs"], hparams)
            non_input_embeds = tf.tile(
                tf.expand_dims(
                    tf.expand_dims(
                        tf.get_variable(name="non_input_embeds",
                                        shape=[depth]), 0), 0),
                [batch_size,
                 tf.shape(features["input_refs"])[1], 1])
            input_embeds = tf.where(
                tf.tile(
                    tf.expand_dims(
                        tf.equal(features["input_refs"][:, :, 1],
                                 features["input_refs"][:, :, 0]), 2),
                    [1, 1, tf.shape(input_embeds)[-1]]), non_input_embeds,
                input_embeds)
        elif hparams.span_aggregation == "mean":
            area_encoding = area_encoding[:, :encoder_input_length, :]
            verb_embeds = span_average_embed(area_encoding,
                                             features["verb_refs"],
                                             embed_scope)
            object_embeds = span_average_embed(area_encoding,
                                               features["obj_refs"],
                                               embed_scope)
            input_embeds = span_average_embed(area_encoding,
                                              features["input_refs"],
                                              embed_scope)
        else:
            raise ValueError("Unrecognized span aggregation method %s" %
                             (hparams.span_aggregation))
        embeds = verb_embeds + object_embeds + input_embeds
        embeds = tf.multiply(tf.expand_dims(decoder_nonpadding, 2), embeds)
        start_embed = tf.tile(
            tf.expand_dims(
                tf.expand_dims(
                    tf.get_variable(name="start_step_embed", shape=[depth]),
                    0), 0), [batch_size, 1, 1])
        embeds = tf.concat([start_embed, embeds], axis=1)
        embeds = embeds[:, :-1, :]
        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(
                common_layers.shape_list(features["verb_refs"])[1]))
        if hparams.pos == "timing":
            decoder_input = common_attention.add_timing_signal_1d(embeds)
        elif hparams.pos == "emb":
            decoder_input = common_attention.add_positional_embedding(
                embeds, hparams.max_length, "targets_positional_embedding",
                None)
        else:
            decoder_input = embeds
        return decoder_input, decoder_self_attention_bias
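A minimal, self-contained sketch of the tf.where/tf.tile masking pattern used above (hypothetical shapes, same TF1-style `tf` as the snippet): wherever a span is empty (start == end), a single learned vector is broadcast over batch and time and swapped in for the span embedding.

demo_batch, demo_steps, demo_depth = 2, 5, 8
demo_refs = tf.zeros([demo_batch, demo_steps, 2], dtype=tf.int32)      # [start, end) spans
demo_embeds = tf.random_uniform([demo_batch, demo_steps, demo_depth])

no_input_vec = tf.get_variable("demo_non_input_embed", shape=[demo_depth])
no_input_embeds = tf.tile(
    tf.expand_dims(tf.expand_dims(no_input_vec, 0), 0),
    [demo_batch, demo_steps, 1])

is_empty = tf.equal(demo_refs[:, :, 1], demo_refs[:, :, 0])            # [batch, steps]
empty_mask = tf.tile(tf.expand_dims(is_empty, 2), [1, 1, demo_depth])  # [batch, steps, depth]
masked = tf.where(empty_mask, no_input_embeds, demo_embeds)            # [batch, steps, depth]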
Example #2
File: data_util.py  Project: markWJJ/FROST
def to_grayscale(image, keep_channels=True):
    image = tf.image.rgb_to_grayscale(image)
    if keep_channels:
        image = tf.tile(image, [1, 1, 3])
    return image
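Hedged usage sketch for to_grayscale on a single HWC image (hypothetical size). Note that tf.image.rgb_to_grayscale also accepts a batched NHWC tensor, but then the tile multiples would need a leading 1.

rgb = tf.random_uniform([224, 224, 3])          # hypothetical RGB image
gray3 = to_grayscale(rgb)                       # [224, 224, 3], all three channels identical
gray1 = to_grayscale(rgb, keep_channels=False)  # [224, 224, 1]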
Example #3
def expand_tile(value, size):
    """Add a new axis of given size."""
    value = tf.convert_to_tensor(value, name='value')
    ndims = value.shape.ndims
    return tf.tile(tf.expand_dims(value, axis=0), [size] + [1]*ndims)
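Quick usage sketch: expand_tile prepends a new leading axis of the given size and repeats the value along it.

v = tf.constant([1, 2, 3])        # shape [3]
rows = expand_tile(v, 4)          # shape [4, 3]; every row equals [1, 2, 3]
scalars = expand_tile(7, 4)       # shape [4]; also works for scalars (ndims == 0)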
Example #4
    def log_prob(self, x, b_enc=None, b_dec=None):
        """Gets the log probability and conditionals for observations.

    Args:
      x: A batch of observations to compute the log probability of, sized
          `[batch_size, num_dims]`.
      b_enc: External encoder bias terms (`b` in [1]), sized
          `[batch_size, num_hidden]`, or None if the internal bias term should
          be used.
      b_dec: External decoder bias terms (`c` in [1]), sized
         `[batch_size, num_dims]`, or None if the internal bias term should be
         used.

    Returns:
       log_prob: The log probabilities of each observation in the batch, sized
           `[batch_size]`.
       cond_probs: The conditional probabilities at each index for every batch,
           sized `[batch_size, num_dims]`.
    """
        batch_size = tf.shape(x)[0]

        b_enc = b_enc if b_enc is not None else self.b_enc
        b_dec = b_dec if b_dec is not None else self.b_dec

        # Broadcast if needed.
        if b_enc.shape[0] == 1 != batch_size:
            b_enc = tf.tile(b_enc, [batch_size, 1])
        if b_dec.shape[0] == 1 != batch_size:
            b_dec = tf.tile(b_dec, [batch_size, 1])

        # Initial condition before the loop.
        a_0 = b_enc
        log_p_0 = tf.zeros([batch_size, 1])
        cond_p_0 = []

        x_arr = tf.unstack(
            tf.reshape(tf.transpose(x), [self.num_dims, batch_size, 1]))
        w_enc_arr = tf.unstack(self.w_enc)
        w_dec_arr = tf.unstack(self.w_dec_t)
        b_dec_arr = tf.unstack(
            tf.reshape(tf.transpose(b_dec), [self.num_dims, batch_size, 1]))

        def loop_body(i, a, log_p, cond_p):
            """Accumulate hidden state, log_p, and cond_p for index i."""
            # Get variables for time step.
            w_enc_i = w_enc_arr[i]
            w_dec_i = w_dec_arr[i]
            b_dec_i = b_dec_arr[i]
            v_i = x_arr[i]

            cond_p_i, _ = self._cond_prob(a, w_dec_i, b_dec_i)

            # Get log probability for this value. Log space avoids numerical issues.
            log_p_i = v_i * _safe_log(cond_p_i) + (
                1 - v_i) * _safe_log(1 - cond_p_i)

            # Accumulate log probability.
            log_p_new = log_p + log_p_i

            # Save conditional probabilities.
            cond_p_new = cond_p + [cond_p_i]

            # Encode value and add to hidden units.
            a_new = a + tf.matmul(v_i, w_enc_i)

            return a_new, log_p_new, cond_p_new

        # Build the actual loop
        a, log_p, cond_p = a_0, log_p_0, cond_p_0
        for i in range(self.num_dims):
            a, log_p, cond_p = loop_body(i, a, log_p, cond_p)

        return (tf.squeeze(log_p, axis=[1]),
                tf.transpose(tf.squeeze(tf.stack(cond_p), [2])))
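_safe_log is referenced above but not shown; a common definition (an assumption here, not necessarily the exact helper used) clips its argument away from zero before tf.log so that log(0) never produces -inf or NaN gradients.

def _safe_log(tensor, eps=1e-7):
    # Assumed implementation: clamp into [eps, 1.0] before taking the log.
    return tf.log(tf.clip_by_value(tensor, eps, 1.0))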
Example #5
File: RBF.py  Project: changqj/GAMES102
    def kernel(self, _x, a, b):  # used at training time
        x1 = tf.tile(_x, [1, self.hidden_size])  # repeat x horizontally hidden_size times
        x2 = tf.reshape(x1, [-1, self.hidden_size, self.feature])
        dist = tf.reduce_sum((a * x2 + b)**2, 2)
        return tf.exp(-dist / 2)
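A hedged sketch of the tile/reshape trick above with concrete, hypothetical sizes (batch=2, feature=3, hidden_size=4): the input row is repeated once per hidden unit, so each unit gets its own affine transform before the squared distance is reduced over features.

x = tf.ones([2, 3])                 # [batch, feature]
x1 = tf.tile(x, [1, 4])             # [2, 12]: the 3-vector repeated 4 times per row
x2 = tf.reshape(x1, [-1, 4, 3])     # [2, 4, 3]: one copy of the input per hidden unit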
Example #6
    def _create_initial_state(self, initial_ids, initial_cache):
        """Return initial state dictionary and its shape invariants.

    Args:
      initial_ids: initial ids to pass into the symbols_to_logits_fn.
        int tensor with shape [batch_size, 1]
      initial_cache: dictionary storing values to be passed into the
        symbols_to_logits_fn.

    Returns:
        state and shape invariant dictionaries with keys from _StateKeys
    """
        for key, value in initial_cache.items():
            for inner_value in tf.nest.flatten(value):
                if inner_value.dtype != self.dtype:
                    raise TypeError(
                        "initial_cache element for key '%s' has dtype %s that does not "
                        "match SequenceBeamSearch's dtype of %s. Value: %s" %
                        (key, inner_value.dtype.name, self.dtype.name, inner_value))

        # Current loop index (starts at 0)
        cur_index = tf.constant(0)

        # Create alive sequence with shape [batch_size, beam_size, 1]
        alive_seq = _expand_to_beam_size(initial_ids, self.beam_size)
        alive_seq = tf.expand_dims(alive_seq, axis=2)
        if self.padded_decode:
            alive_seq = tf.tile(alive_seq, [1, 1, self.max_decode_length + 1])

        # Create tensor for storing initial log probabilities.
        # Assume initial_ids are prob 1.0
        initial_log_probs = tf.constant([[0.] + [-float("inf")] *
                                         (self.beam_size - 1)],
                                        dtype=self.dtype)
        alive_log_probs = tf.tile(initial_log_probs, [self.batch_size, 1])

        # Expand all values stored in the dictionary to the beam size, so that each
        # beam has a separate cache.
        alive_cache = tf.nest.map_structure(
            lambda t: _expand_to_beam_size(t, self.beam_size), initial_cache)

        # Initialize tensor storing finished sequences with filler values.
        finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32)

        # Set scores of the initial finished seqs to negative infinity.
        finished_scores = tf.ones([self.batch_size, self.beam_size],
                                  dtype=self.dtype) * -inf(self.dtype)

        # Initialize finished flags with all False values.
        finished_flags = tf.zeros([self.batch_size, self.beam_size], tf.bool)

        # Create state dictionary
        state = {
            _StateKeys.CUR_INDEX: cur_index,
            _StateKeys.ALIVE_SEQ: alive_seq,
            _StateKeys.ALIVE_LOG_PROBS: alive_log_probs,
            _StateKeys.ALIVE_CACHE: alive_cache,
            _StateKeys.FINISHED_SEQ: finished_seq,
            _StateKeys.FINISHED_SCORES: finished_scores,
            _StateKeys.FINISHED_FLAGS: finished_flags
        }

        # Create state invariants for each value in the state dictionary. Each
        # dimension must be a constant or None. A None dimension means either:
        #   1) the dimension's value is a tensor that remains the same but may
        #      depend on the input sequence to the model (e.g. batch size).
        #   2) the dimension may have different values on different iterations.
        if self.padded_decode:
            state_shape_invariants = {
                _StateKeys.CUR_INDEX:
                tf.TensorShape([]),
                _StateKeys.ALIVE_SEQ:
                tf.TensorShape([
                    self.batch_size, self.beam_size, self.max_decode_length + 1
                ]),
                _StateKeys.ALIVE_LOG_PROBS:
                tf.TensorShape([self.batch_size, self.beam_size]),
                _StateKeys.ALIVE_CACHE:
                tf.nest.map_structure(_get_shape, alive_cache),
                _StateKeys.FINISHED_SEQ:
                tf.TensorShape([
                    self.batch_size, self.beam_size, self.max_decode_length + 1
                ]),
                _StateKeys.FINISHED_SCORES:
                tf.TensorShape([self.batch_size, self.beam_size]),
                _StateKeys.FINISHED_FLAGS:
                tf.TensorShape([self.batch_size, self.beam_size])
            }
        else:
            state_shape_invariants = {
                _StateKeys.CUR_INDEX:
                tf.TensorShape([]),
                _StateKeys.ALIVE_SEQ:
                tf.TensorShape([None, self.beam_size, None]),
                _StateKeys.ALIVE_LOG_PROBS:
                tf.TensorShape([None, self.beam_size]),
                _StateKeys.ALIVE_CACHE:
                tf.nest.map_structure(_get_shape_keep_last_dim, alive_cache),
                _StateKeys.FINISHED_SEQ:
                tf.TensorShape([None, self.beam_size, None]),
                _StateKeys.FINISHED_SCORES:
                tf.TensorShape([None, self.beam_size]),
                _StateKeys.FINISHED_FLAGS:
                tf.TensorShape([None, self.beam_size])
            }

        return state, state_shape_invariants
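_expand_to_beam_size is referenced above but not shown; in the TF official-models beam search it is essentially the following (a sketch, not a guaranteed match): insert a beam axis after the batch axis and tf.tile along it.

def _expand_to_beam_size(tensor, beam_size):
    """Tiles a tensor beam_size times along a new axis 1."""
    tensor = tf.expand_dims(tensor, axis=1)
    tile_dims = [1] * tensor.shape.ndims
    tile_dims[1] = beam_size
    return tf.tile(tensor, tile_dims)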
Example #7
SOS = tf.get_variable("SOS", shape=[1, embd_dim],
                      dtype=tf.float32,
                      trainable=True,
                      initializer=tf.glorot_uniform_initializer())

# SOS represents the starting marker.
# It tells the decoder that it is about to decode the first word of the output.
# I have set SOS as a trainable parameter.

Wc = tf.get_variable("Wc", shape=[2*hidden_size, embd_dim],
                     dtype=tf.float32,
                     trainable=True,
                     initializer=tf.glorot_uniform_initializer())


SOS = tf.tile(SOS, [N, 1])  # now SOS shape: [N, embd_dim]
inp = SOS
hidden = final_encoded_state
cell = tf.zeros([N, 2*hidden_size], dtype=tf.float32)
decoder_outputs = tf.TensorArray(size=max_summary_len, dtype=tf.float32)
outputs = tf.TensorArray(size=max_summary_len, dtype=tf.int32)

embd_summary_t = tf.transpose(embd_summary, [1, 0, 2])

# encoder_context_vector shape is [32, 600].
# It is an element-wise multiplication (not matmul): each score (e.g. 0.3) scales
# the hidden state of one source word before the sum over source positions.
for i in range(max_summary_len):

    attention_scores = align(encoder_states, hidden)
    # [32, 178, 600] * [32, 178, 1] -> reduce_sum over axis 1 -> [32, 600]
    encoder_context_vector = tf.reduce_sum(encoder_states*attention_scores, axis=1)
    inp = dropout(inp, rate=0.3, training=tf_train)
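A sketch of the broadcasting in the attention step above, using the hypothetical sizes from the comments (batch=32, source length=178, 2*hidden_size=600): the per-position scalar scores broadcast against the encoder states before the sum over source positions.

enc_states = tf.zeros([32, 178, 600])                  # encoder hidden states
scores = tf.zeros([32, 178, 1])                        # one attention weight per source position
context = tf.reduce_sum(enc_states * scores, axis=1)   # [32, 600]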
Example #8
def box_matching(boxes, gt_boxes, gt_classes):
    """Match boxes to groundtruth boxes.

  Given the proposal boxes and the groundtruth boxes and classes, perform the
  groundtruth matching by taking the argmax of the IoU between boxes and
  groundtruth boxes.

  Args:
    boxes: a tensor of shape of [batch_size, N, 4] representing the box
      coordinates to be matched to groundtruth boxes.
    gt_boxes: a tensor of shape of [batch_size, MAX_INSTANCES, 4] representing
      the groundtruth box coordinates. It is padded with -1s to indicate the
      invalid boxes.
    gt_classes: [batch_size, MAX_INSTANCES] representing the groundtruth box
      classes. It is padded with -1s to indicate the invalid classes.

  Returns:
    matched_gt_boxes: a tensor of shape of [batch_size, N, 4], representing
      the matched groundtruth box coordinates for each input box. If the box
      does not overlap with any groundtruth boxes, the matched boxes of it
      will be set to all 0s.
    matched_gt_classes: a tensor of shape of [batch_size, N], representing
      the matched groundtruth classes for each input box. If the box does not
      overlap with any groundtruth boxes, the matched box classes of it will
      be set to 0, which corresponds to the background class.
    matched_gt_indices: a tensor of shape of [batch_size, N], representing
      the indices of the matched groundtruth boxes in the original gt_boxes
      tensor. If the box does not overlap with any groundtruth boxes, the
      index of the matched groundtruth will be set to -1.
    matched_iou: a tensor of shape of [batch_size, N], representing the IoU
      between the box and its matched groundtruth box. The matched IoU is the
      maximum IoU of the box and all the groundtruth boxes.
    iou: a tensor of shape of [batch_size, N, K], representing the IoU matrix
      between boxes and the groundtruth boxes. The IoU between a box and the
      invalid groundtruth boxes whose coordinates are [-1, -1, -1, -1] is -1.
  """
    # Compute IoU between boxes and gt_boxes.
    # iou <- [batch_size, N, K]
    iou = box_utils.bbox_overlap(boxes, gt_boxes)

    # max_iou <- [batch_size, N]
    # 0.0 -> no match to gt, or -1.0 match to no gt
    matched_iou = tf.reduce_max(iou, axis=-1)

    # background_box_mask <- bool, [batch_size, N]
    background_box_mask = tf.less_equal(matched_iou, 0.0)

    argmax_iou_indices = tf.argmax(iou, axis=-1, output_type=tf.int32)

    argmax_iou_indices_shape = tf.shape(argmax_iou_indices)
    batch_indices = (
        tf.expand_dims(tf.range(argmax_iou_indices_shape[0]), axis=-1) *
        tf.ones([1, argmax_iou_indices_shape[-1]], dtype=tf.int32))
    gather_nd_indices = tf.stack([batch_indices, argmax_iou_indices], axis=-1)

    matched_gt_boxes = tf.gather_nd(gt_boxes, gather_nd_indices)
    matched_gt_boxes = tf.where(
        tf.tile(tf.expand_dims(background_box_mask, axis=-1), [1, 1, 4]),
        tf.zeros_like(matched_gt_boxes, dtype=tf.float32), matched_gt_boxes)

    matched_gt_classes = tf.gather_nd(gt_classes, gather_nd_indices)
    matched_gt_classes = tf.where(background_box_mask,
                                  tf.zeros_like(matched_gt_classes),
                                  matched_gt_classes)

    matched_gt_indices = tf.where(background_box_mask,
                                  -tf.ones_like(argmax_iou_indices),
                                  argmax_iou_indices)

    return (matched_gt_boxes, matched_gt_classes, matched_gt_indices,
            matched_iou, iou)
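A small sketch (hypothetical batch=2, N=3) of the batch-index construction used above: each per-row argmax is paired with its batch index so tf.gather_nd can pick one groundtruth entry per proposal.

argmax = tf.constant([[2, 0, 1], [1, 1, 0]])                                   # [batch, N]
batch_idx = tf.expand_dims(tf.range(2), axis=-1) * tf.ones([1, 3], tf.int32)   # [[0,0,0],[1,1,1]]
gather_idx = tf.stack([batch_idx, argmax], axis=-1)                            # [batch, N, 2]
# tf.gather_nd(gt_boxes, gather_idx) then selects gt_boxes[b, argmax[b, n]] for every (b, n).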
Example #9
def compute_floor_offsets_with_indices(y_source,
                                       x_source,
                                       y_target=None,
                                       x_target=None):
    """Computes offsets from floored source(floored) to target coordinates.

  This function computes the offsets from source coordinates ("floored" as if
  they were put on the grids) to target coordinates. Note that the input
  coordinates should be the "absolute" coordinates in terms of the output image
  dimensions as opposed to the normalized coordinates (i.e. values in [0, 1]).
  If the input y and x source have the second dimension (representing the
  neighboring pixels), then the offsets are computed from each of the
  neighboring pixels to their corresponding target (first dimension).

  Args:
    y_source: A tensor with shape [num_points] (or [num_points, num_neighbors])
      representing the absolute y-coordinates (in the output image space) of the
      source points.
    x_source: A tensor with shape [num_points] (or [num_points, num_neighbors])
      representing the absolute x-coordinates (in the output image space) of the
      source points.
    y_target: A tensor with shape [num_points] representing the absolute
      y-coordinates (in the output image space) of the target points. If not
      provided, then y_source is used as the targets.
    x_target: A tensor with shape [num_points] representing the absolute
      x-coordinates (in the output image space) of the target points. If not
      provided, then x_source is used as the targets.

  Returns:
    A tuple of two tensors:
      offsets: A tensor with shape [num_points, 2] (or
        [num_points, num_neighbors, 2]) representing the offsets of each input
        point.
      indices: A tensor with shape [num_points, 2] (or
        [num_points, num_neighbors, 2]) representing the indices of where the
        offsets should be retrieved in the output image dimension space.

  Raises:
    ValueError: source and target shapes have unexpected values.
  """
    y_source_floored = tf.floor(y_source)
    x_source_floored = tf.floor(x_source)

    source_shape = shape_utils.combined_static_and_dynamic_shape(y_source)
    if y_target is None and x_target is None:
        y_target = y_source
        x_target = x_source
    else:
        target_shape = shape_utils.combined_static_and_dynamic_shape(y_target)
        if len(source_shape) == 2 and len(target_shape) == 1:
            _, num_neighbors = source_shape
            y_target = tf.tile(tf.expand_dims(y_target, -1),
                               multiples=[1, num_neighbors])
            x_target = tf.tile(tf.expand_dims(x_target, -1),
                               multiples=[1, num_neighbors])
        elif source_shape != target_shape:
            raise ValueError('Inconsistent source and target shape.')

    y_offset = y_target - y_source_floored
    x_offset = x_target - x_source_floored

    y_source_indices = tf.cast(y_source_floored, tf.int32)
    x_source_indices = tf.cast(x_source_floored, tf.int32)

    indices = tf.stack([y_source_indices, x_source_indices], axis=-1)
    offsets = tf.stack([y_offset, x_offset], axis=-1)
    return offsets, indices
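Hedged usage sketch (assumes the shape_utils dependency is importable): three keypoints in absolute output-image coordinates; the offsets are the fractional parts relative to the floored grid cells, and the indices are those cells.

y = tf.constant([10.6, 3.2, 7.0])
x = tf.constant([4.1, 8.9, 2.5])
offsets, indices = compute_floor_offsets_with_indices(y, x)
# offsets ~ [[0.6, 0.1], [0.2, 0.9], [0.0, 0.5]], indices = [[10, 4], [3, 8], [7, 2]]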
Example #10
def generate_detections_per_image_op(cls_outputs,
                                     box_outputs,
                                     anchor_boxes,
                                     image_id,
                                     image_info,
                                     num_detections=100,
                                     pre_nms_num_detections=1000,
                                     nms_threshold=0.3,
                                     bbox_reg_weights=(10., 10., 5., 5.)):
    """Generates detections with model outputs and anchors.

  Args:
    cls_outputs: a Tensor with shape [N, num_classes], which stacks class
      logit outputs on all feature levels. The N is the number of total anchors
      on all levels. The num_classes is the number of classes predicted by the
      model. Note that the cls_outputs should be the output of softmax().
    box_outputs: a Tensor with shape [N, num_classes*4], which stacks
      box regression outputs on all feature levels. The N is the number of total
      anchors on all levels.
    anchor_boxes: a Tensor with shape [N, 4], which stacks anchors on all
      feature levels. The N is the number of total anchors on all levels.
    image_id: an integer number to specify the image id.
    image_info: a tensor of shape [5] which encodes the input image's [height,
      width, scale, original_height, original_width]
    num_detections: Number of detections after NMS.
    pre_nms_num_detections: Number of candidates before NMS.
    nms_threshold: a float number to specify the threshold of NMS.
    bbox_reg_weights: a list of 4 float scalars, which are default weights on
      (dx, dy, dw, dh) for normalizing bbox regression targets.
  Returns:
    detections: detection results in a tensor with each row representing
      [image_id, ymin, xmin, ymax, xmax, score, class]
  """
    num_boxes, num_classes = cls_outputs.get_shape().as_list()

    # Removes background class scores.
    cls_outputs = cls_outputs[:, 1:num_classes]
    top_k_scores, top_k_indices_with_classes = tf.nn.top_k(
        tf.reshape(cls_outputs, [-1]), k=pre_nms_num_detections, sorted=True)
    classes = tf.mod(top_k_indices_with_classes, num_classes - 1)
    top_k_indices = tf.floordiv(top_k_indices_with_classes, num_classes - 1)

    anchor_boxes = tf.gather(anchor_boxes, top_k_indices)
    box_outputs = tf.reshape(box_outputs,
                             [num_boxes, num_classes, 4])[:, 1:num_classes, :]
    box_outputs = tf.gather_nd(box_outputs,
                               tf.stack([top_k_indices, classes], axis=1))

    # Applies bounding box regression to anchors.
    boxes = box_utils.batch_decode_box_outputs_op(
        tf.expand_dims(anchor_boxes, axis=0),
        tf.expand_dims(box_outputs, axis=0), bbox_reg_weights)[0]
    boxes = box_utils.clip_boxes(tf.expand_dims(boxes, axis=0),
                                 tf.expand_dims(image_info[:2], axis=0))[0]

    classes = tf.tile(tf.reshape(classes, [1, pre_nms_num_detections]),
                      [num_classes - 1, 1])
    scores = tf.tile(tf.reshape(top_k_scores, [1, pre_nms_num_detections]),
                     [num_classes - 1, 1])
    boxes = tf.tile(tf.reshape(boxes, [1, pre_nms_num_detections, 4]),
                    [num_classes - 1, 1, 1])

    class_bitmask = tf.tile(
        tf.reshape(tf.range(num_classes - 1), [num_classes - 1, 1]),
        [1, pre_nms_num_detections])
    scores = tf.where(tf.equal(classes, class_bitmask), scores,
                      tf.zeros_like(scores))
    scores = tf.where(tf.greater(scores, 0.05), scores, tf.zeros_like(scores))
    # Reshape classes to be compatible with the top_k function.
    classes = tf.reshape(classes, [num_classes - 1, pre_nms_num_detections, 1])
    scores, sorted_tensors = box_utils.top_k(scores,
                                             k=pre_nms_num_detections,
                                             tensors=[boxes, classes])
    boxes = sorted_tensors[0]
    classes = tf.reshape(sorted_tensors[1],
                         [num_classes - 1, pre_nms_num_detections])

    idx, num_valid = non_max_suppression.non_max_suppression_padded(
        scores,
        boxes,
        max_output_size=num_detections,
        iou_threshold=nms_threshold,
        level=0)

    post_nms_boxes = non_max_suppression.gather_boxes_by_indices(
        boxes, num_detections, idx, num_valid)
    post_nms_scores = non_max_suppression.gather_scores_by_indices(
        scores, num_detections, idx, num_valid)

    # Sorts all results.
    sorted_scores, sorted_indices = tf.nn.top_k(tf.to_float(
        tf.reshape(post_nms_scores, [-1])),
                                                k=num_detections,
                                                sorted=True)
    post_nms_boxes = tf.gather(tf.reshape(post_nms_boxes, [-1, 4]),
                               sorted_indices)
    classes = tf.batch_gather(classes, idx)
    post_nms_classes = tf.gather(tf.reshape(classes, [-1]), sorted_indices) + 1

    if isinstance(image_id, int):
        image_id = tf.constant(image_id)
    image_id = tf.reshape(image_id, [])
    detections_result = tf.stack([
        tf.to_float(tf.fill(tf.shape(sorted_scores), image_id)),
        post_nms_boxes[:, 0],
        post_nms_boxes[:, 1],
        post_nms_boxes[:, 2],
        post_nms_boxes[:, 3],
        sorted_scores,
        tf.to_float(post_nms_classes),
    ],
                                 axis=1)
    return detections_result
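A sketch of the per-class score masking above with hypothetical sizes (3 foreground classes, 4 candidates): each bitmask row selects the candidates whose predicted class matches that row, zeroing everything else before per-class NMS.

num_fg, k = 3, 4
cand_classes = tf.constant([[0, 2, 1, 0]] * num_fg)                   # class id of each candidate, per row
bitmask = tf.tile(tf.reshape(tf.range(num_fg), [num_fg, 1]), [1, k])  # row r is all r's
cand_scores = tf.constant([[0.9, 0.8, 0.7, 0.6]] * num_fg)
masked_scores = tf.where(tf.equal(cand_classes, bitmask), cand_scores,
                         tf.zeros_like(cand_scores))
# Row 0 keeps candidates 0 and 3 (class 0), row 1 keeps candidate 2, row 2 keeps candidate 1.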
Example #11
def beam_search(
    symbols_to_logits_fn,
    initial_ids,
    beam_size,
    decode_length,
    vocab_size,
    alpha,
    eos_id,
    batch_size,
    minimum_score=None,
):
    """Beam search with length penalties.

  Requires a function that can take the currently decoded symbols and return
  the logits for the next symbol. The implementation is inspired by
  https://arxiv.org/abs/1609.08144.

  Args:
    symbols_to_logits_fn: Interface to the model, to provide logits.
        Should take [batch_size, decoded_ids] and return [batch_size, vocab_size]
    initial_ids: Ids to start off the decoding, this will be the first thing
        handed to symbols_to_logits_fn (after expanding to beam size)
        [batch_size]
    beam_size: Size of the beam.
    decode_length: Number of steps to decode for.
    vocab_size: Size of the vocab, must equal the size of the logits returned by
        symbols_to_logits_fn
    alpha: alpha for length penalty.
    eos_id: ID for end of sentence.
    batch_size: Integer specifying batch size. Should match first dimension of
        initial_ids.
    minimum_score: Minimum score to continue exploring beam. Used to optimize
        performance. Ignored if None or 0. Should be in range [0-1].
  Returns:
    Tuple of
    (decoded beams [batch_size, beam_size, decode_length]
     decoding probabilities [batch_size, beam_size])
  """
    def _inner_loop(i, alive_seq, alive_log_probs, finished_seq,
                    finished_scores, finished_flags):
        """Inner beam seach loop.

    There are three groups of tensors, alive, finished, and topk.
    The alive group contains information about the current alive sequences
    The topk group contains information about alive + topk current decoded words
    the finished group contains information about finished sentences, that is,
    the ones that have decoded to <EOS>. These are what we return.
    The general beam search algorithm is as follows:
    While we haven't terminated (pls look at termination condition)
      1. Grow the current alive to get beam*2 topk sequences
      2. Among the topk, keep the top beam_size ones that haven't reached EOS
      into alive
      3. Among the topk, keep the top beam_size ones have reached EOS into
      finished
    Repeat
    To make things simple with using fixed size tensors, we will end
    up inserting unfinished sequences into finished in the beginning. To stop
    that we add -ve INF to the score of the unfinished sequence so that when a
    true finished sequence does appear, it will have a higher score than all the
    unfinished ones.

    Args:
      i: loop index
      alive_seq: Topk sequences decoded so far.
          Shape is [batch_size, beam_size, decode_length + 1]
      alive_log_probs: probabilities of the beams. [batch_size, beam_size]
      finished_seq: Current finished sequences.
        [batch_size, beam_size, decode_length + 1]
      finished_scores: scores for each of these sequences.
        [batch_size, beam_size]
      finished_flags: finished bools for each of these sequences.
        [batch_size, beam_size]

    Returns:
      Tuple of
        (Incremented loop index
         New alive sequences,
         Log probs of the alive sequences,
         New finished sequences,
         Scores of the new finished sequences,
         Flags indicating which sequences in finished have reached EOS)
    """

        # Each inner loop, we carry out three steps:
        # 1. Get the current topk items.
        # 2. Extract the ones that have finished and haven't finished
        # 3. Recompute the contents of finished based on scores.
        topk_seq, topk_log_probs, topk_scores, topk_finished = _grow_topk(
            i,
            alive_seq,
            alive_log_probs,
            batch_size,
            beam_size,
            symbols_to_logits_fn,
            alpha,
            vocab_size,
            eos_id,
            decode_length,
        )
        alive_seq, alive_log_probs, _ = _grow_alive(topk_seq, topk_scores,
                                                    topk_log_probs,
                                                    topk_finished, batch_size,
                                                    beam_size)
        finished_seq, finished_scores, finished_flags = _grow_finished(
            finished_seq,
            finished_scores,
            finished_flags,
            topk_seq,
            topk_scores,
            topk_finished,
            batch_size,
            beam_size,
        )

        return (
            i + 1,
            alive_seq,
            alive_log_probs,
            finished_seq,
            finished_scores,
            finished_flags,
        )

    def _loop_cond(
        i,
        unused_alive_seq,
        alive_log_probs,
        unused_finished_seq,
        finished_scores,
        finished_in_finished,
    ):
        """Checking termination condition.

    We terminate when we have decoded up to decode_length or the lowest-scoring
    item in finished has a greater score than the highest-probability item in
    alive divided by the max length penalty. Optionally we also terminate if
    all alive scores are below the lower bound.

    Args:
      i: loop index
      alive_log_probs: probabilities of the beams. [batch_size, beam_size]
      finished_scores: scores for each of these sequences.
        [batch_size, beam_size]
      finished_in_finished: finished bools for each of these sequences.
        [batch_size, beam_size]

    Returns:
      True to continue the loop, False to stop.
    """
        max_length_penalty = tf.pow(((5.0 + tf.to_float(decode_length)) / 6.0),
                                    alpha)
        # The best possible score of the most likely alive sequence.
        lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty

        # Now compute the lowest score of a finished sequence in finished.
        # If the sequence isn't finished, we multiply its score by 0; since
        # scores are all negative, taking the min will give us the score of the
        # lowest finished item.
        lowest_score_of_finished_in_finished = tf.reduce_min(
            finished_scores * tf.to_float(finished_in_finished), axis=1)
        # If none of the sequences have finished, then the min will be 0 and
        # we have to replace it by -ve INF if it is. The score of any seq in alive
        # will be much higher than -ve INF and the termination condition will not
        # be met.
        lowest_score_of_finished_in_finished = _apply_negative_infinity_mask(
            lowest_score_of_finished_in_finished,
            tf.logical_not(tf.reduce_any(finished_in_finished, 1)),
        )

        # Will terminate beam search early if bound_is_met is True.
        bound_is_met = tf.reduce_all(
            tf.greater(lowest_score_of_finished_in_finished,
                       lower_bound_alive_scores))

        # Check if all alive scores are below minimum.
        if minimum_score:
            minimum_score_log = tf.log(minimum_score)
            bound_is_met = tf.logical_or(
                bound_is_met,
                tf.reduce_all(
                    tf.less(lower_bound_alive_scores, minimum_score_log)),
            )

        return tf.logical_and(tf.less(i, decode_length),
                              tf.logical_not(bound_is_met))

    # Assume initial_ids are prob 1.0
    initial_log_probs = tf.constant([[0.0] + [-float("inf")] * (beam_size - 1)
                                     ])
    # Expand size to [batch_size, beam_size].
    alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])

    # Expand size to [batch_size, beam_size, decode_length + 1]
    alive_seq = tf.expand_dims(initial_ids, 1)  # [batch_size, 1]
    alive_seq = tf.tile(alive_seq, [1, beam_size])  # [batch_size, beam_size]
    alive_seq = _one_hot_tensor_3d(alive_seq, 0, decode_length + 1)

    # Finished will keep track of all the sequences that have finished so far
    # Finished log probs will be negative infinity in the beginning
    # finished_flags will keep track of booleans
    finished_seq = tf.zeros(tf.shape(alive_seq), tf.int32)
    # Setting the scores of the initial to negative infinity.
    finished_scores = tf.ones([batch_size, beam_size], dtype=tf.float32) * -INF
    finished_flags = tf.zeros([batch_size, beam_size], tf.bool)

    initial_variables = [
        tf.constant(0),  # []
        alive_seq,  # [batch_size, beam_size, decode_length + 1]
        alive_log_probs,  # [batch_size, beam_size]
        finished_seq,  # [batch_size, beam_size, decode_length + 1]
        finished_scores,  # [batch_size, beam_size]
        finished_flags,  # [batch_size, beam_size]
    ]

    # Execute while loop.
    (
        _,
        alive_seq,
        alive_log_probs,
        finished_seq,
        finished_scores,
        finished_flags,
    ) = tf.while_loop(
        _loop_cond,
        _inner_loop,
        initial_variables,
        parallel_iterations=1,
        back_prop=False,
    )

    # Accounting for corner case: It's possible that no sequence in alive for a
    # particular batch item ever reached EOS. In that case, we should just copy
    # the contents of alive for that batch item. tf.reduce_any(finished_flags, 1)
    # if 0, means that no sequence for that batch index had reached EOS. We need
    # to do the same for the scores as well.
    finished_seq = tf.where(tf.reduce_any(finished_flags, 1), finished_seq,
                            alive_seq)
    finished_scores = tf.where(tf.reduce_any(finished_flags, 1),
                               finished_scores, alive_log_probs)
    finished_seq = tf.cast(finished_seq, tf.int64)
    return finished_seq, finished_scores
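The termination check above uses the GNMT-style length penalty from https://arxiv.org/abs/1609.08144, penalty = ((5 + length) / 6) ** alpha; a finished score is its log-probability divided by this penalty. A worked value with an assumed alpha = 0.6:

decode_length, alpha = 20, 0.6
max_length_penalty = ((5.0 + decode_length) / 6.0) ** alpha   # ~2.35
# lower_bound_alive_scores = best alive log-prob / max_length_penalty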
Example #12
def tf_rotation_resampling(voxel_array,
                           transformation_matrix,
                           params,
                           Scale_matrix=None,
                           size=64,
                           new_size=128):
    """
    Batch transformation and resampling function
    :param voxel_array: batch of voxels. Shape = [batch_size, height, width, depth, features]
    :param transformation_matrix: Rotation matrix in homogeneous coordinates. Shape = [batch_size, 4, 4]
    :param size: original size of the voxel array
    :param new_size: size of the resampled array
    :return: transformed voxel array
    """
    batch_size = tf.shape(voxel_array)[0]
    n_channels = voxel_array.get_shape()[4].value
    target = tf.zeros([batch_size, new_size, new_size, new_size])
    #Aligning the centroid of the object (voxel grid) to origin for rotation,
    #then move the centroid back to the original position of the grid centroid
    T = tf.constant([[1, 0, 0, -size * 0.5], [0, 1, 0, -size * 0.5],
                     [0, 0, 1, -size * 0.5], [0, 0, 0, 1]])
    # add one more dimension to T and then tile
    T = tf.tile(tf.reshape(T, (1, 4, 4)), [batch_size, 1, 1])

    # However, since the rotated grid might be out of bound for the original grid size,
    # move the rotated grid to a new bigger grid
    T_new_inv = tf.constant([[1, 0, 0, new_size * 0.5],
                             [0, 1, 0, new_size * 0.5],
                             [0, 0, 1, new_size * 0.5], [0, 0, 0, 1]])
    T_new_inv = tf.tile(tf.reshape(T_new_inv, (1, 4, 4)), [batch_size, 1, 1])

    # Add the actual shifting in the x, y and z dimensions according to the input params
    x_shift = tf.reshape(params[:, 3], (batch_size, 1, 1))
    y_shift = tf.reshape(params[:, 4], (batch_size, 1, 1))
    z_shift = tf.reshape(params[:, 5], (batch_size, 1, 1))
    # ========================================================
    # Because tensorflow does not allow tensor item replacement
    # A new matrix needs to be created from scratch by concatenating different vectors into rows and stacking them up
    ones = tf.ones_like(x_shift)
    zeros = tf.zeros_like(x_shift)

    T_translate = tf.concat([
        tf.concat([ones, zeros, zeros, x_shift], axis=2),
        tf.concat([zeros, ones, zeros, y_shift], axis=2),
        tf.concat([zeros, zeros, ones, z_shift], axis=2),
        tf.concat([zeros, zeros, zeros, ones], axis=2)
    ],
                            axis=1)
    total_M = tf.matmul(
        tf.matmul(tf.matmul(tf.matmul(T_new_inv, T_translate), Scale_matrix),
                  transformation_matrix), T)

    try:
        total_M = tf.matrix_inverse(total_M)

        total_M = total_M[:, 0:3, :]  # Ignore the homogeneous coordinate so the results are 3D vectors. shape: (batch * 3 * 4)
        grid = tf_voxel_meshgrid(new_size,
                                 new_size,
                                 new_size,
                                 homogeneous=True)
        # Here a new_size^3 grid is created, but the T matrix only translates it by size * 0.5, which does not align the grid to the origin.
        # shape: (4 * new_size^3), where 4 is 3 coordinates + 1 homogeneous component
        grid = tf.tile(
            tf.reshape(grid, (1, tf.to_int32(
                grid.get_shape()[0]), tf.to_int32(grid.get_shape()[1]))),
            [batch_size, 1, 1])
        grid_transform = tf.matmul(
            total_M, grid
        )  # (batch * 3 * 4) matmul (batch * 4 * new_size^3) is (3 * 4) matmul (4 * new_size^3) along batch
        x_s_flat = tf.reshape(grid_transform[:, 0, :], [-1])
        y_s_flat = tf.reshape(grid_transform[:, 1, :], [-1])
        z_s_flat = tf.reshape(grid_transform[:, 2, :], [-1])
        input_transformed = tf_interpolate(
            voxel_array, x_s_flat, y_s_flat, z_s_flat,
            [batch_size, new_size, new_size, new_size, n_channels])
        target = tf.reshape(
            input_transformed,
            [batch_size, new_size, new_size, new_size, n_channels])

        return target, grid_transform
    except tf.errors.InvalidArgumentError:
        return None
Example #13
def _model_fn(features, labels, mode, params, model, variable_filter_fn=None):
    """Model definition entry.

  Args:
    features: the input image tensor with shape [batch_size, height, width, 3].
      The height and width are fixed and equal.
    labels: the input labels in a dictionary. The labels include class targets
      and box targets which are dense label maps. The labels are generated from
      get_input_fn function in data/dataloader.py
    mode: the mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
    params: the dictionary defines hyperparameters of model. The default
      settings are in default_hparams function in this file.
    model: the model outputs class logits and box regression outputs.
    variable_filter_fn: the filter function that takes trainable_variables and
      returns the variable list after applying the filter rule.

  Returns:
    tpu_spec: the TPUEstimatorSpec to run training, evaluation, or prediction.

  Raises:
    RuntimeError: if both ckpt and backbone_ckpt are set.
  """
    # Convert params (dict) to Config for easier access.
    training_hooks = None
    if params['data_format'] == 'channels_first':
        features = tf.transpose(features, [0, 3, 1, 2])

    def _model_outputs(inputs):
        return model(inputs, config=hparams_config.Config(params))

    cls_outputs, box_outputs = utils.build_model_with_precision(
        params['precision'], _model_outputs, features)

    levels = cls_outputs.keys()
    for level in levels:
        cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
        box_outputs[level] = tf.cast(box_outputs[level], tf.float32)

    # First check if it is in PREDICT mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            'image': features,
        }
        for level in levels:
            predictions['cls_outputs_%d' % level] = cls_outputs[level]
            predictions['box_outputs_%d' % level] = box_outputs[level]
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Set up training loss and learning rate.
    update_learning_rate_schedule_parameters(params)
    global_step = tf.train.get_or_create_global_step()
    learning_rate = learning_rate_schedule(params, global_step)

    # cls_loss and box_loss are for logging. only total_loss is optimized.
    det_loss, cls_loss, box_loss, box_iou_loss = detection_loss(
        cls_outputs, box_outputs, labels, params)
    reg_l2loss = reg_l2_loss(params['weight_decay'])
    total_loss = det_loss + reg_l2loss

    if mode == tf.estimator.ModeKeys.TRAIN:
        utils.scalar('lrn_rate', learning_rate)
        utils.scalar('trainloss/cls_loss', cls_loss)
        utils.scalar('trainloss/box_loss', box_loss)
        utils.scalar('trainloss/box_iou_loss', box_iou_loss)
        utils.scalar('trainloss/det_loss', det_loss)
        utils.scalar('trainloss/reg_l2_loss', reg_l2loss)
        utils.scalar('trainloss/loss', total_loss)

    moving_average_decay = params['moving_average_decay']
    if moving_average_decay:
        ema = tf.train.ExponentialMovingAverage(decay=moving_average_decay,
                                                num_updates=global_step)
        ema_vars = utils.get_ema_vars()
    if params['strategy'] == 'horovod':
        import horovod.tensorflow as hvd  # pylint: disable=g-import-not-at-top
        learning_rate = learning_rate * hvd.size()
    if mode == tf.estimator.ModeKeys.TRAIN:
        if params['optimizer'].lower() == 'sgd':
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   momentum=params['momentum'])
        elif params['optimizer'].lower() == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate)
        else:
            raise ValueError('optimizers should be adam or sgd')
        if params['strategy'] == 'tpu':
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)
        elif params['strategy'] == 'horovod':
            optimizer = hvd.DistributedOptimizer(optimizer)
            training_hooks = [hvd.BroadcastGlobalVariablesHook(0)]

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        var_list = tf.trainable_variables()
        if variable_filter_fn:
            var_list = variable_filter_fn(var_list)

        if params.get('clip_gradients_norm', 0) > 0:
            logging.info('clip gradients norm by %f',
                         params['clip_gradients_norm'])
            grads_and_vars = optimizer.compute_gradients(
                total_loss,
                var_list,
                aggregation_method=tf.AggregationMethod.
                EXPERIMENTAL_ACCUMULATE_N)
            with tf.name_scope('clip'):
                grads = [gv[0] for gv in grads_and_vars]
                tvars = [gv[1] for gv in grads_and_vars]
                clipped_grads, gnorm = tf.clip_by_global_norm(
                    grads, params['clip_gradients_norm'])
                utils.scalar('gnorm', gnorm)
                grads_and_vars = list(zip(clipped_grads, tvars))

            with tf.control_dependencies(update_ops):
                train_op = optimizer.apply_gradients(grads_and_vars,
                                                     global_step)
        else:
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(
                    total_loss,
                    global_step,
                    var_list=var_list,
                    aggregation_method=tf.AggregationMethod.
                    EXPERIMENTAL_ACCUMULATE_N)

        if moving_average_decay:
            with tf.control_dependencies([train_op]):
                train_op = ema.apply(ema_vars)

    else:
        train_op = None

    eval_metrics = None
    if mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(**kwargs):
            """Returns a dictionary that has the evaluation metrics."""
            batch_size = params['batch_size']
            if params['strategy'] == 'tpu':
                batch_size = params['batch_size'] * params['num_shards']
            eval_anchors = anchors.Anchors(params['min_level'],
                                           params['max_level'],
                                           params['num_scales'],
                                           params['aspect_ratios'],
                                           params['anchor_scale'],
                                           params['image_size'])
            anchor_labeler = anchors.AnchorLabeler(eval_anchors,
                                                   params['num_classes'])
            cls_loss = tf.metrics.mean(kwargs['cls_loss_repeat'])
            box_loss = tf.metrics.mean(kwargs['box_loss_repeat'])

            if params.get('testdev_dir', None):
                logging.info('Eval testdev_dir %s', params['testdev_dir'])
                coco_metrics = coco_metric_fn(
                    batch_size,
                    anchor_labeler,
                    params['val_json_file'],
                    testdev_dir=params['testdev_dir'],
                    disable_pyfun=params.get('disable_pyfun', None),
                    **kwargs)
            else:
                logging.info('Eval val with groundtruths %s.',
                             params['val_json_file'])
                coco_metrics = coco_metric_fn(batch_size, anchor_labeler,
                                              params['val_json_file'],
                                              **kwargs)

            # Add metrics to output.
            output_metrics = {
                'cls_loss': cls_loss,
                'box_loss': box_loss,
            }
            output_metrics.update(coco_metrics)
            return output_metrics

        cls_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(cls_loss, 0), [
                params['batch_size'],
            ]), [params['batch_size'], 1])
        box_loss_repeat = tf.reshape(
            tf.tile(tf.expand_dims(box_loss, 0), [
                params['batch_size'],
            ]), [params['batch_size'], 1])
        metric_fn_inputs = {
            'cls_loss_repeat': cls_loss_repeat,
            'box_loss_repeat': box_loss_repeat,
            'source_ids': labels['source_ids'],
            'groundtruth_data': labels['groundtruth_data'],
            'image_scales': labels['image_scales'],
        }
        add_metric_fn_inputs(params, cls_outputs, box_outputs,
                             metric_fn_inputs)
        eval_metrics = (metric_fn, metric_fn_inputs)

    checkpoint = params.get('ckpt') or params.get('backbone_ckpt')

    if checkpoint and mode == tf.estimator.ModeKeys.TRAIN:
        # Initialize the model from an EfficientDet or backbone checkpoint.
        if params.get('ckpt') and params.get('backbone_ckpt'):
            raise RuntimeError(
                '--backbone_ckpt and --checkpoint are mutually exclusive')

        if params.get('backbone_ckpt'):
            var_scope = params['backbone_name'] + '/'
            if params['ckpt_var_scope'] is None:
                # Use backbone name as default checkpoint scope.
                ckpt_scope = params['backbone_name'] + '/'
            else:
                ckpt_scope = params['ckpt_var_scope'] + '/'
        else:
            # Load every var in the given checkpoint
            var_scope = ckpt_scope = '/'

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            logging.info('restore variables from %s', checkpoint)

            var_map = utils.get_ckpt_var_map(ckpt_path=checkpoint,
                                             ckpt_scope=ckpt_scope,
                                             var_scope=var_scope,
                                             var_exclude_expr=params.get(
                                                 'var_exclude_expr', None))

            tf.train.init_from_checkpoint(checkpoint, var_map)

            return tf.train.Scaffold()
    elif mode == tf.estimator.ModeKeys.EVAL and moving_average_decay:

        def scaffold_fn():
            """Load moving average variables for eval."""
            logging.info('Load EMA vars with ema_decay=%f',
                         moving_average_decay)
            restore_vars_dict = ema.variables_to_restore(ema_vars)
            saver = tf.train.Saver(restore_vars_dict)
            return tf.train.Scaffold(saver=saver)
    else:
        scaffold_fn = None

    return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                             loss=total_loss,
                                             train_op=train_op,
                                             eval_metrics=eval_metrics,
                                             host_call=utils.get_tpu_host_call(
                                                 global_step, params),
                                             scaffold_fn=scaffold_fn,
                                             training_hooks=training_hooks)
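The tf.tile/tf.reshape dance around cls_loss and box_loss above exists because tensors handed to a TPUEstimator metric_fn need a leading batch dimension; a scalar loss is repeated per example and averaged back inside metric_fn. Minimal sketch with a hypothetical scalar:

demo_batch_size = 8
demo_loss = tf.constant(0.42)
demo_loss_repeat = tf.reshape(
    tf.tile(tf.expand_dims(demo_loss, 0), [demo_batch_size]),
    [demo_batch_size, 1])
# Inside metric_fn, tf.metrics.mean(demo_loss_repeat) recovers the scalar 0.42.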
Example #14
def add_metric_fn_inputs(params,
                         cls_outputs,
                         box_outputs,
                         metric_fn_inputs,
                         max_detection_points=anchors.MAX_DETECTION_POINTS):
    """Selects top-k predictions and adds the selected to metric_fn_inputs.

  Args:
    params: a parameter dictionary that includes `min_level`, `max_level`,
      `batch_size`, and `num_classes`.
    cls_outputs: an OrderedDict with keys representing levels and values
      representing logits in [batch_size, height, width, num_anchors].
    box_outputs: an OrderedDict with keys representing levels and values
      representing box regression targets in
      [batch_size, height, width, num_anchors * 4].
    metric_fn_inputs: a dictionary that will hold the top-k selections.
    max_detection_points: an integer specifying the maximum detection points to
      keep before NMS. Keep all anchors if max_detection_points <= 0.
  """
    batch_size = params['batch_size']
    num_classes = params['num_classes']
    cls_outputs_all = []
    box_outputs_all = []
    # Concatenates class and box of all levels into one tensor.
    for level in range(params['min_level'], params['max_level'] + 1):
        if params['data_format'] == 'channels_first':
            cls_outputs[level] = tf.transpose(cls_outputs[level], [0, 2, 3, 1])
            box_outputs[level] = tf.transpose(box_outputs[level], [0, 2, 3, 1])

        cls_outputs_all.append(
            tf.reshape(cls_outputs[level], [batch_size, -1, num_classes]))
        box_outputs_all.append(
            tf.reshape(box_outputs[level], [batch_size, -1, 4]))
    cls_outputs_all = tf.concat(cls_outputs_all, 1)
    box_outputs_all = tf.concat(box_outputs_all, 1)

    if max_detection_points > 0:
        # Prune anchors and detections to only keep max_detection_points.
        # Due to some issues, top_k is currently slow in graph mode.
        cls_outputs_all_reshape = tf.reshape(cls_outputs_all, [batch_size, -1])
        _, cls_topk_indices = tf.math.top_k(cls_outputs_all_reshape,
                                            k=max_detection_points,
                                            sorted=False)
        indices = cls_topk_indices // num_classes
        classes = cls_topk_indices % num_classes
        cls_indices = tf.stack([indices, classes], axis=2)
        cls_outputs_all_after_topk = tf.gather_nd(cls_outputs_all,
                                                  cls_indices,
                                                  batch_dims=1)
        box_outputs_all_after_topk = tf.gather_nd(box_outputs_all,
                                                  tf.expand_dims(indices, 2),
                                                  batch_dims=1)
    else:
        # Keep all anchors, but for each anchor, just keep the max probability
        # across classes.
        cls_outputs_idx = tf.math.argmax(cls_outputs_all,
                                         axis=-1,
                                         output_type=tf.int32)
        num_anchors = cls_outputs_all.shape[1]

        classes = cls_outputs_idx
        indices = tf.tile(tf.expand_dims(tf.range(num_anchors), axis=0),
                          [batch_size, 1])
        cls_outputs_all_after_topk = tf.reduce_max(cls_outputs_all, -1)
        box_outputs_all_after_topk = box_outputs_all

    metric_fn_inputs['cls_outputs_all'] = cls_outputs_all_after_topk
    metric_fn_inputs['box_outputs_all'] = box_outputs_all_after_topk
    metric_fn_inputs['indices_all'] = indices
    metric_fn_inputs['classes_all'] = classes
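A small sketch of the flat top-k index decomposition above (hypothetical: 2 anchors, 3 classes): after flattening [batch, anchors, classes] to [batch, anchors*classes], integer division by num_classes recovers the anchor and the remainder recovers the class.

num_classes = 3
cls_topk_indices = tf.constant([[5, 1]])       # flat indices into a 2x3 score grid
anchor_ids = cls_topk_indices // num_classes   # [[1, 0]]
class_ids = cls_topk_indices % num_classes     # [[2, 1]]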
Example #15
File: model.py  Project: vanton/magenta
def build_genie_model(feat_dict,
                      cfg,
                      batch_size,
                      seq_len,
                      is_training=True,
                      seq_varlens=None,
                      dtype=tf.float32):
    """Builds a Piano Genie model.

  Args:
    feat_dict: Dictionary containing input tensors.
    cfg: Configuration object.
    batch_size: Number of items in batch.
    seq_len: Length of each batch item.
    is_training: Set to False for evaluation.
    seq_varlens: If not None, a tensor with the batch sequence lengths.
    dtype: Model weight type.

  Returns:
    A dict containing tensors for relevant model config.
  """
    out_dict = {}

    # Parse features
    pitches = util.demidify(feat_dict["midi_pitches"])
    velocities = feat_dict["velocities"]
    pitches_scalar = ((tf.cast(pitches, tf.float32) / 87.) * 2.) - 1.

    # Create sequence lens
    if is_training and cfg.train_randomize_seq_len:
        seq_lens = tf.random_uniform([batch_size],
                                     minval=cfg.train_seq_len_min,
                                     maxval=seq_len + 1,
                                     dtype=tf.int32)
        stp_varlen_mask = tf.sequence_mask(seq_lens,
                                           maxlen=seq_len,
                                           dtype=tf.float32)
    elif seq_varlens is not None:
        seq_lens = seq_varlens
        stp_varlen_mask = tf.sequence_mask(seq_varlens,
                                           maxlen=seq_len,
                                           dtype=tf.float32)
    else:
        seq_lens = tf.ones([batch_size], dtype=tf.int32) * seq_len
        stp_varlen_mask = None

    # Encode
    if (cfg.stp_emb_unconstrained or cfg.stp_emb_vq or cfg.stp_emb_iq
            or cfg.seq_emb_unconstrained or cfg.seq_emb_vae
            or cfg.lor_emb_unconstrained):
        # Build encoder features
        enc_feats = []
        if cfg.enc_pitch_scalar:
            enc_feats.append(tf.expand_dims(pitches_scalar, axis=-1))
        else:
            enc_feats.append(tf.one_hot(pitches, 88))
        if "delta_times_int" in cfg.enc_aux_feats:
            enc_feats.append(
                tf.one_hot(feat_dict["delta_times_int"],
                           cfg.data_max_discrete_times + 1))
        if "velocities" in cfg.enc_aux_feats:
            enc_feats.append(
                tf.one_hot(velocities, cfg.data_max_discrete_velocities + 1))
        enc_feats = tf.concat(enc_feats, axis=2)

        with tf.variable_scope("encoder"):
            enc_stp, enc_seq = simple_lstm_encoder(
                enc_feats,
                seq_lens,
                rnn_celltype=cfg.rnn_celltype,
                rnn_nlayers=cfg.rnn_nlayers,
                rnn_nunits=cfg.rnn_nunits,
                rnn_bidirectional=cfg.enc_rnn_bidirectional,
                dtype=dtype)

    latents = []

    # Step embeddings (single vector per timestep)
    if cfg.stp_emb_unconstrained:
        with tf.variable_scope("stp_emb_unconstrained"):
            stp_emb_unconstrained = tf.layers.dense(
                enc_stp, cfg.stp_emb_unconstrained_embedding_dim)

        out_dict["stp_emb_unconstrained"] = stp_emb_unconstrained
        latents.append(stp_emb_unconstrained)

    # Quantized step embeddings with VQ-VAE
    if cfg.stp_emb_vq:
        import sonnet as snt  # pylint:disable=g-import-not-at-top,import-outside-toplevel
        with tf.variable_scope("stp_emb_vq"):
            with tf.variable_scope("pre_vq"):
                # pre_vq_encoding is tf.float32 of [batch_size, seq_len, embedding_dim]
                pre_vq_encoding = tf.layers.dense(enc_stp,
                                                  cfg.stp_emb_vq_embedding_dim)

            with tf.variable_scope("quantizer"):
                assert stp_varlen_mask is None
                vq_vae = snt.nets.VectorQuantizer(
                    embedding_dim=cfg.stp_emb_vq_embedding_dim,
                    num_embeddings=cfg.stp_emb_vq_codebook_size,
                    commitment_cost=cfg.stp_emb_vq_commitment_cost)
                vq_vae_output = vq_vae(pre_vq_encoding,
                                       is_training=is_training)

                stp_emb_vq_quantized = vq_vae_output["quantize"]
                stp_emb_vq_discrete = tf.reshape(
                    tf.argmax(vq_vae_output["encodings"],
                              axis=1,
                              output_type=tf.int32), [batch_size, seq_len])
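                # Transpose so that each row of the codebook is one embedding
                # vector, as expected by tf.nn.embedding_lookup below.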
                stp_emb_vq_codebook = tf.transpose(vq_vae.embeddings)

        out_dict["stp_emb_vq_quantized"] = stp_emb_vq_quantized
        out_dict["stp_emb_vq_discrete"] = stp_emb_vq_discrete
        out_dict["stp_emb_vq_loss"] = vq_vae_output["loss"]
        out_dict["stp_emb_vq_codebook"] = stp_emb_vq_codebook
        out_dict["stp_emb_vq_codebook_ppl"] = vq_vae_output["perplexity"]
        latents.append(stp_emb_vq_quantized)

        # This tensor retrieves continuous embeddings from codebook. It should
        # *never* be used during training.
        out_dict["stp_emb_vq_quantized_lookup"] = tf.nn.embedding_lookup(
            stp_emb_vq_codebook, stp_emb_vq_discrete)

    # Integer-quantized step embeddings with straight-through
    if cfg.stp_emb_iq:
        with tf.variable_scope("stp_emb_iq"):
            with tf.variable_scope("pre_iq"):
                # pre_iq_encoding is tf.float32 of [batch_size, seq_len]
                pre_iq_encoding = tf.layers.dense(enc_stp, 1)[:, :, 0]

            def iqst(x, n):
                """Integer quantization with straight-through estimator."""
                eps = 1e-7
                s = float(n - 1)
                xp = tf.clip_by_value((x + 1) / 2.0, -eps, 1 + eps)
                xpp = tf.round(s * xp)
                xppp = 2 * (xpp / s) - 1
                return xpp, x + tf.stop_gradient(xppp - x)

            with tf.variable_scope("quantizer"):
                # Pass rounded vals to decoder w/ straight-through estimator
                stp_emb_iq_discrete_f, stp_emb_iq_discrete_rescaled = iqst(
                    pre_iq_encoding, cfg.stp_emb_iq_nbins)
                stp_emb_iq_discrete = tf.cast(stp_emb_iq_discrete_f + 1e-4,
                                              tf.int32)
                stp_emb_iq_discrete_f = tf.cast(stp_emb_iq_discrete,
                                                tf.float32)
                stp_emb_iq_quantized = tf.expand_dims(
                    stp_emb_iq_discrete_rescaled, axis=2)

                # Determine which elements round to valid indices
                stp_emb_iq_inrange = tf.logical_and(
                    tf.greater_equal(pre_iq_encoding, -1),
                    tf.less_equal(pre_iq_encoding, 1))
                stp_emb_iq_inrange_mask = tf.cast(stp_emb_iq_inrange,
                                                  tf.float32)
                stp_emb_iq_valid_p = weighted_avg(stp_emb_iq_inrange_mask,
                                                  stp_varlen_mask)

                # Regularize to encourage encoder to output in range
                stp_emb_iq_range_penalty = weighted_avg(
                    tf.square(tf.maximum(tf.abs(pre_iq_encoding) - 1, 0)),
                    stp_varlen_mask)

                # Regularize to correlate latent finite differences to input
                stp_emb_iq_dlatents = (pre_iq_encoding[:, 1:] -
                                       pre_iq_encoding[:, :-1])
                if cfg.stp_emb_iq_contour_dy_scalar:
                    stp_emb_iq_dnotes = (pitches_scalar[:, 1:] -
                                         pitches_scalar[:, :-1])
                else:
                    stp_emb_iq_dnotes = tf.cast(
                        pitches[:, 1:] - pitches[:, :-1], tf.float32)
                if cfg.stp_emb_iq_contour_exp == 1:
                    power_func = tf.identity
                elif cfg.stp_emb_iq_contour_exp == 2:
                    power_func = tf.square
                else:
                    raise NotImplementedError()
                if cfg.stp_emb_iq_contour_comp == "product":
                    comp_func = tf.multiply
                elif cfg.stp_emb_iq_contour_comp == "quotient":
                    comp_func = lambda x, y: tf.divide(x, y + 1e-6)
                else:
                    raise NotImplementedError()

                stp_emb_iq_contour_penalty = weighted_avg(
                    power_func(
                        tf.maximum(
                            cfg.stp_emb_iq_contour_margin -
                            comp_func(stp_emb_iq_dnotes, stp_emb_iq_dlatents),
                            0)),
                    None if stp_varlen_mask is None else stp_varlen_mask[:, 1:])

                # Regularize to maintain note consistency
                stp_emb_iq_note_held = tf.cast(
                    tf.equal(pitches[:, 1:] - pitches[:, :-1], 0), tf.float32)
                if cfg.stp_emb_iq_deviate_exp == 1:
                    power_func = tf.abs
                elif cfg.stp_emb_iq_deviate_exp == 2:
                    power_func = tf.square

                if stp_varlen_mask is None:
                    mask = stp_emb_iq_note_held
                else:
                    mask = stp_varlen_mask[:, 1:] * stp_emb_iq_note_held
                stp_emb_iq_deviate_penalty = weighted_avg(
                    power_func(stp_emb_iq_dlatents), mask)

                # Calculate perplexity of discrete encoder posterior
                if stp_varlen_mask is None:
                    mask = stp_emb_iq_inrange_mask
                else:
                    mask = stp_varlen_mask * stp_emb_iq_inrange_mask
                stp_emb_iq_discrete_oh = tf.one_hot(stp_emb_iq_discrete,
                                                    cfg.stp_emb_iq_nbins)
                stp_emb_iq_avg_probs = weighted_avg(stp_emb_iq_discrete_oh,
                                                    mask,
                                                    axis=[0, 1],
                                                    expand_mask=True)
                stp_emb_iq_discrete_ppl = tf.exp(
                    -tf.reduce_sum(stp_emb_iq_avg_probs *
                                   tf.log(stp_emb_iq_avg_probs + 1e-10)))

        out_dict["stp_emb_iq_quantized"] = stp_emb_iq_quantized
        out_dict["stp_emb_iq_discrete"] = stp_emb_iq_discrete
        out_dict["stp_emb_iq_valid_p"] = stp_emb_iq_valid_p
        out_dict["stp_emb_iq_range_penalty"] = stp_emb_iq_range_penalty
        out_dict["stp_emb_iq_contour_penalty"] = stp_emb_iq_contour_penalty
        out_dict["stp_emb_iq_deviate_penalty"] = stp_emb_iq_deviate_penalty
        out_dict["stp_emb_iq_discrete_ppl"] = stp_emb_iq_discrete_ppl
        latents.append(stp_emb_iq_quantized)

        # This tensor converts discrete values to continuous.
        # It should *never* be used during training.
        out_dict["stp_emb_iq_quantized_lookup"] = tf.expand_dims(
            2. * (stp_emb_iq_discrete_f / (cfg.stp_emb_iq_nbins - 1.)) - 1.,
            axis=2)

    # Sequence embedding (single vector per sequence)
    if cfg.seq_emb_unconstrained:
        with tf.variable_scope("seq_emb_unconstrained"):
            seq_emb_unconstrained = tf.layers.dense(
                enc_seq, cfg.seq_emb_unconstrained_embedding_dim)

        out_dict["seq_emb_unconstrained"] = seq_emb_unconstrained

        seq_emb_unconstrained = tf.stack([seq_emb_unconstrained] * seq_len,
                                         axis=1)
        latents.append(seq_emb_unconstrained)

    # Sequence embeddings (variational w/ reparameterization trick)
    if cfg.seq_emb_vae:
        with tf.variable_scope("seq_emb_vae"):
            seq_emb_vae = tf.layers.dense(enc_seq,
                                          cfg.seq_emb_vae_embedding_dim * 2)

            # Reparameterization trick: sample z = mean + stddev * eps with
            # eps ~ N(0, 1), so gradients flow through mean and stddev.
            mean = seq_emb_vae[:, :cfg.seq_emb_vae_embedding_dim]
            stddev = 1e-6 + tf.nn.softplus(
                seq_emb_vae[:, cfg.seq_emb_vae_embedding_dim:])
            seq_emb_vae = mean + stddev * tf.random_normal(
                tf.shape(mean), 0, 1, dtype=dtype)

            # Closed-form KL divergence between N(mean, stddev^2) and N(0, 1).
            kl = tf.reduce_mean(
                0.5 * tf.reduce_sum(tf.square(mean) + tf.square(stddev) -
                                    tf.log(1e-8 + tf.square(stddev)) - 1,
                                    axis=1))

        out_dict["seq_emb_vae"] = seq_emb_vae
        out_dict["seq_emb_vae_kl"] = kl

        seq_emb_vae = tf.stack([seq_emb_vae] * seq_len, axis=1)
        latents.append(seq_emb_vae)

    # Low-rate embeddings
    if cfg.lor_emb_unconstrained:
        assert seq_len % cfg.lor_emb_n == 0

        with tf.variable_scope("lor_emb_unconstrained"):
            # Downsample step embeddings
            rnn_embedding_dim = int(enc_stp.get_shape()[-1])
            enc_lor = tf.reshape(enc_stp, [
                batch_size, seq_len // cfg.lor_emb_n,
                cfg.lor_emb_n * rnn_embedding_dim
            ])
            lor_emb_unconstrained = tf.layers.dense(
                enc_lor, cfg.lor_emb_unconstrained_embedding_dim)

            out_dict["lor_emb_unconstrained"] = lor_emb_unconstrained

            # Upsample lo-rate embeddings for decoding
            lor_emb_unconstrained = tf.expand_dims(lor_emb_unconstrained,
                                                   axis=2)
            lor_emb_unconstrained = tf.tile(lor_emb_unconstrained,
                                            [1, 1, cfg.lor_emb_n, 1])
            lor_emb_unconstrained = tf.reshape(
                lor_emb_unconstrained,
                [batch_size, seq_len, cfg.lor_emb_unconstrained_embedding_dim])

            latents.append(lor_emb_unconstrained)

    # Build decoder features
    dec_feats = latents

    if cfg.dec_autoregressive:
        # Retrieve pitch numbers
        curr_pitches = pitches
        last_pitches = curr_pitches[:, :-1]
        last_pitches = tf.pad(last_pitches, [[0, 0], [1, 0]],
                              constant_values=-1)  # Prepend <SOS> token
        out_dict["dec_last_pitches"] = last_pitches
        dec_feats.append(tf.one_hot(last_pitches + 1, 89))

        if cfg.dec_pred_velocity:
            curr_velocities = velocities
            last_velocities = curr_velocities[:, :-1]
            last_velocities = tf.pad(last_velocities, [[0, 0], [1, 0]])
            dec_feats.append(
                tf.one_hot(last_velocities,
                           cfg.data_max_discrete_velocities + 1))

    if "delta_times_int" in cfg.dec_aux_feats:
        dec_feats.append(
            tf.one_hot(feat_dict["delta_times_int"],
                       cfg.data_max_discrete_times + 1))
    if "velocities" in cfg.dec_aux_feats:
        assert not cfg.dec_pred_velocity
        dec_feats.append(
            tf.one_hot(feat_dict["velocities"],
                       cfg.data_max_discrete_velocities + 1))

    assert dec_feats
    dec_feats = tf.concat(dec_feats, axis=2)

    # Decode
    with tf.variable_scope("decoder"):
        dec_stp, dec_initial_state, dec_final_state = simple_lstm_decoder(
            dec_feats,
            seq_lens,
            batch_size,
            rnn_celltype=cfg.rnn_celltype,
            rnn_nlayers=cfg.rnn_nlayers,
            rnn_nunits=cfg.rnn_nunits)

        with tf.variable_scope("pitches"):
            dec_recons_logits = tf.layers.dense(dec_stp, 88)

        dec_recons_loss = weighted_avg(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=dec_recons_logits, labels=pitches), stp_varlen_mask)

        out_dict["dec_initial_state"] = dec_initial_state
        out_dict["dec_final_state"] = dec_final_state
        out_dict["dec_recons_logits"] = dec_recons_logits
        out_dict["dec_recons_scores"] = tf.nn.softmax(dec_recons_logits,
                                                      axis=-1)
        out_dict["dec_recons_preds"] = tf.argmax(dec_recons_logits,
                                                 output_type=tf.int32,
                                                 axis=-1)
        out_dict["dec_recons_midi_preds"] = util.remidify(
            out_dict["dec_recons_preds"])
        out_dict["dec_recons_loss"] = dec_recons_loss

        if cfg.dec_pred_velocity:
            with tf.variable_scope("velocities"):
                dec_recons_velocity_logits = tf.layers.dense(
                    dec_stp, cfg.data_max_discrete_velocities + 1)

            dec_recons_velocity_loss = weighted_avg(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    logits=dec_recons_velocity_logits, labels=velocities),
                stp_varlen_mask)

            out_dict["dec_recons_velocity_logits"] = dec_recons_velocity_logits
            out_dict["dec_recons_velocity_loss"] = dec_recons_velocity_loss

    # Stats
    if cfg.stp_emb_vq or cfg.stp_emb_iq:
        discrete = out_dict[
            "stp_emb_vq_discrete" if cfg.stp_emb_vq else "stp_emb_iq_discrete"]
        dx = pitches[:, 1:] - pitches[:, :-1]
        dy = discrete[:, 1:] - discrete[:, :-1]
        contour_violation = tf.reduce_mean(
            tf.cast(tf.less(dx * dy, 0), tf.float32))

        dx_hold = tf.equal(dx, 0)
        deviate_violation = weighted_avg(
            tf.cast(tf.not_equal(dy, 0), tf.float32),
            tf.cast(dx_hold, tf.float32))

        out_dict["contour_violation"] = contour_violation
        out_dict["deviate_violation"] = deviate_violation

    return out_dict
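A stand-alone sketch (an illustration, not part of the original project's API) of the integer quantization with straight-through estimator that stp_emb_iq relies on: the forward pass snaps values in [-1, 1] onto n bins, while the backward pass behaves like the identity.

import tensorflow as tf

def iqst_sketch(x, n):
    """Round x in [-1, 1] onto n bins; gradients flow through unchanged."""
    s = float(n - 1)
    xp = (x + 1) / 2.0            # map [-1, 1] -> [0, 1]
    xpp = tf.round(s * xp)        # integer bin in {0, ..., n - 1}
    xppp = 2 * (xpp / s) - 1      # map the bin back to [-1, 1]
    # Forward pass uses the quantized value, backward pass uses the identity.
    return xpp, x + tf.stop_gradient(xppp - x)

bins, rescaled = iqst_sketch(tf.constant([-1.0, -0.1, 0.4, 1.0]), n=8)
# bins -> [0., 3., 5., 7.]; rescaled stays in [-1, 1] but is differentiable.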
Example #16
def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
                                  max_number_of_boxes):
    """Extracts groundtruth data from detection_model and prepares it for eval.

  Args:
    detection_model: A `DetectionModel` object.
    class_agnostic: Whether the detections are class_agnostic.
    max_number_of_boxes: Max number of groundtruth boxes.

  Returns:
    A tuple of:
    groundtruth: Dictionary with the following fields:
      'groundtruth_boxes': [batch_size, num_boxes, 4] float32 tensor of boxes,
        in normalized coordinates.
      'groundtruth_classes': [batch_size, num_boxes] int64 tensor of 1-indexed
        classes.
      'groundtruth_masks': 4D float32 tensor of instance masks (if provided in
        groundtruth)
      'groundtruth_is_crowd': [batch_size, num_boxes] bool tensor indicating
        is_crowd annotations (if provided in groundtruth).
      'groundtruth_area': [batch_size, num_boxes] float32 tensor indicating
        the area (in the original absolute coordinates) of annotations (if
        provided in groundtruth).
      'num_groundtruth_boxes': [batch_size] tensor containing the maximum
        number of groundtruth boxes per image.
      'groundtruth_keypoints': [batch_size, num_boxes, num_keypoints, 2] float32
        tensor of keypoints (if provided in groundtruth).
      'groundtruth_dp_num_points_list': [batch_size, num_boxes] int32 tensor
        with the number of DensePose points for each instance (if provided in
        groundtruth).
      'groundtruth_dp_part_ids_list': [batch_size, num_boxes,
        max_sampled_points] int32 tensor with the part ids for each DensePose
        sampled point (if provided in groundtruth).
      'groundtruth_dp_surface_coords_list': [batch_size, num_boxes,
        max_sampled_points, 4] containing the DensePose surface coordinates for
        each sampled point (if provided in groundtruth).
      'groundtruth_track_ids_list': [batch_size, num_boxes] int32 tensor
        with track ID for each instance (if provided in groundtruth).
      'groundtruth_group_of': [batch_size, num_boxes] bool tensor indicating
        group_of annotations (if provided in groundtruth).
      'groundtruth_labeled_classes': [batch_size, num_classes] int64
        tensor of 1-indexed classes.
    class_agnostic: Boolean indicating whether detections are class agnostic.
  """
    input_data_fields = fields.InputDataFields()
    groundtruth_boxes = tf.stack(
        detection_model.groundtruth_lists(fields.BoxListFields.boxes))
    groundtruth_boxes_shape = tf.shape(groundtruth_boxes)
    # For class-agnostic models, groundtruth one-hot encodings collapse to all
    # ones.
    if class_agnostic:
        groundtruth_classes_one_hot = tf.ones(
            [groundtruth_boxes_shape[0], groundtruth_boxes_shape[1], 1])
    else:
        groundtruth_classes_one_hot = tf.stack(
            detection_model.groundtruth_lists(fields.BoxListFields.classes))
    label_id_offset = 1  # Applying label id offset (b/63711816)
    groundtruth_classes = (tf.argmax(groundtruth_classes_one_hot, axis=2) +
                           label_id_offset)
    groundtruth = {
        input_data_fields.groundtruth_boxes: groundtruth_boxes,
        input_data_fields.groundtruth_classes: groundtruth_classes
    }
    if detection_model.groundtruth_has_field(fields.BoxListFields.masks):
        groundtruth[input_data_fields.groundtruth_instance_masks] = tf.stack(
            detection_model.groundtruth_lists(fields.BoxListFields.masks))

    if detection_model.groundtruth_has_field(fields.BoxListFields.is_crowd):
        groundtruth[input_data_fields.groundtruth_is_crowd] = tf.stack(
            detection_model.groundtruth_lists(fields.BoxListFields.is_crowd))

    if detection_model.groundtruth_has_field(
            input_data_fields.groundtruth_area):
        groundtruth[input_data_fields.groundtruth_area] = tf.stack(
            detection_model.groundtruth_lists(
                input_data_fields.groundtruth_area))

    if detection_model.groundtruth_has_field(fields.BoxListFields.keypoints):
        groundtruth[input_data_fields.groundtruth_keypoints] = tf.stack(
            detection_model.groundtruth_lists(fields.BoxListFields.keypoints))

    if detection_model.groundtruth_has_field(
            fields.BoxListFields.keypoint_visibilities):
        groundtruth[
            input_data_fields.groundtruth_keypoint_visibilities] = tf.stack(
                detection_model.groundtruth_lists(
                    fields.BoxListFields.keypoint_visibilities))

    if detection_model.groundtruth_has_field(fields.BoxListFields.group_of):
        groundtruth[input_data_fields.groundtruth_group_of] = tf.stack(
            detection_model.groundtruth_lists(fields.BoxListFields.group_of))

    if detection_model.groundtruth_has_field(
            fields.InputDataFields.groundtruth_labeled_classes):
        labeled_classes_list = detection_model.groundtruth_lists(
            fields.InputDataFields.groundtruth_labeled_classes)
        labeled_classes = [
            tf.where(x)[:, 0] + label_id_offset for x in labeled_classes_list
        ]
        if len(labeled_classes) > 1:
            num_classes = labeled_classes_list[0].shape[0]
            padded_labeled_classes = []
            for x in labeled_classes:
                padding = num_classes - tf.shape(x)[0]
                padded_labeled_classes.append(tf.pad(x, [[0, padding]]))
            groundtruth[
                input_data_fields.groundtruth_labeled_classes] = tf.stack(
                    padded_labeled_classes)
        else:
            groundtruth[
                input_data_fields.groundtruth_labeled_classes] = tf.stack(
                    labeled_classes)

    if detection_model.groundtruth_has_field(
            fields.BoxListFields.densepose_num_points):
        groundtruth[input_data_fields.groundtruth_dp_num_points] = tf.stack(
            detection_model.groundtruth_lists(
                fields.BoxListFields.densepose_num_points))
    if detection_model.groundtruth_has_field(
            fields.BoxListFields.densepose_part_ids):
        groundtruth[input_data_fields.groundtruth_dp_part_ids] = tf.stack(
            detection_model.groundtruth_lists(
                fields.BoxListFields.densepose_part_ids))
    if detection_model.groundtruth_has_field(
            fields.BoxListFields.densepose_surface_coords):
        groundtruth[
            input_data_fields.groundtruth_dp_surface_coords] = tf.stack(
                detection_model.groundtruth_lists(
                    fields.BoxListFields.densepose_surface_coords))

    if detection_model.groundtruth_has_field(fields.BoxListFields.track_ids):
        groundtruth[input_data_fields.groundtruth_track_ids] = tf.stack(
            detection_model.groundtruth_lists(fields.BoxListFields.track_ids))

    groundtruth[input_data_fields.num_groundtruth_boxes] = (tf.tile(
        [max_number_of_boxes], multiples=[groundtruth_boxes_shape[0]]))
    return groundtruth
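A toy sketch (illustrative only) of the labeled-classes padding above: each per-image vector of 1-indexed class ids is right-padded with zeros to num_classes entries so that tf.stack can build a dense batch tensor.

import tensorflow as tf

num_classes = 5
labeled_classes = [tf.constant([1, 3], dtype=tf.int64),
                   tf.constant([2], dtype=tf.int64)]
padded_labeled_classes = []
for x in labeled_classes:
    padding = num_classes - tf.shape(x)[0]
    padded_labeled_classes.append(tf.pad(x, [[0, padding]]))  # zero-pad on the right
stacked = tf.stack(padded_labeled_classes)  # shape [2, 5]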
Example #17
  def call(self,
           logits,
           annotation_begins,
           annotation_ends,
           annotation_labels,
           block_ids,
           num_replicas=None,
           eps=0):
    """Calls the layer.

    Args:
      logits: <float32>[batch_size, main_seq_len, 2] Logits per position.
      annotation_begins: <int32>[batch_size, main_seq_len] Positions of
        beginnings of answer spans.
      annotation_ends: <int32>[batch_size, main_seq_len] Positions of endings of
        answer spans.
      annotation_labels: <int32>[batch_size, main_seq_len] Labels of answer
        spans. A label of 0 marks a placeholder span (included only for padding
        purposes) that should be ignored.
      block_ids: <int32>[batch_size] Block IDs of every sample in the batch.
      num_replicas: Number of replicas to gather summaries from. If None
        (default) then cross-replica summaries are not used.
      eps: <float> Small constant for numerical stability.

    Returns:
        total_loss: <float>
    """
    seq_length = tf.shape(logits)[1]

    # (1) Aggregate block_ids across global batch. Compute cross block mask.
    all_block_ids = block_ids
    if num_replicas:
      all_block_ids = tpu_utils.cross_replica_concat(
          tensor=all_block_ids,
          num_replicas=num_replicas,
          name='block_ids_concat')

    # [batch_size, global_batch_size]
    cross_blocks_eq_mask = tf.cast(
        tf.equal(
            tf.expand_dims(block_ids, 1), tf.expand_dims(all_block_ids, 0)),
        tf.float32)

    # (2) Apply softmax over all positions in the (global) batch
    # across the blocks with the same `block_id`.

    # [batch_size, seq_len, 2]
    probs = cross_batch_softmax(logits, cross_blocks_eq_mask, num_replicas)

    # (3) Prepare one-hot labels based on annotation begins and ends

    # [batch_size, seq_len, 1]
    annotation_begins_one_hot = _one_hot_multi(
        annotation_begins,
        annotation_labels > 0,
        seq_length,
    )
    # [batch_size, seq_len, 1]
    annotation_ends_one_hot = _one_hot_multi(
        annotation_ends,
        annotation_labels > 0,
        seq_length,
    )
    # [batch_size, seq_len, 2]
    one_hot_labels = tf.concat(
        [annotation_begins_one_hot, annotation_ends_one_hot], 2)

    # (4) Compute the probability of the current begin / end positions across
    # the blocks with the same `block_id`.

    # [batch_size, 2]
    correct_probs = tf.reduce_sum(probs * one_hot_labels, axis=1)
    if num_replicas:
      # [global_batch_size, 2]
      correct_probs = tpu_utils.cross_replica_concat(
          tensor=correct_probs,
          num_replicas=num_replicas,
          name='correct_probs_concat')

    # [batch_size, 2]
    correct_probs = tf.matmul(cross_blocks_eq_mask, correct_probs)

    # (5) Compute log probability. We allow cases where there are no correct
    # labels, not only for the current sample but for the whole document
    # across the whole batch. In that case the probability of the correct label
    # would be 0 and the loss would be infinite, so we simply do not compute a
    # loss for these documents.

    # [batch_size, 1]
    num_annotations_per_sample = tf.reduce_sum(
        annotation_labels, 1, keepdims=True)
    if num_replicas:
      # [global_batch_size, 1]
      num_annotations_per_sample = tpu_utils.cross_replica_concat(
          tensor=num_annotations_per_sample,
          num_replicas=num_replicas,
          name='num_annotations_per_sample_concat')

    # [batch_size, 1]
    num_annotations_per_doc = tf.matmul(
        cross_blocks_eq_mask, tf.cast(num_annotations_per_sample, tf.float32))
    # [batch_size, 2]
    doc_with_annotations_mask = tf.stop_gradient(
        tf.cast(tf.tile(num_annotations_per_doc > 0, [1, 2]), tf.float32))
    doc_without_annotations_mask = tf.stop_gradient(1 -
                                                    doc_with_annotations_mask)
    log_correct_probs = tf.log(
        correct_probs + eps +
        doc_without_annotations_mask) * doc_with_annotations_mask

    # (6) Divide by the number of blocks per block_id
    # If there are K blocks with the same block_id, then on step (4) we'll
    # compute loss for this document K times. So we need to divide it back by K.

    # [batch_size, 2]
    log_correct_probs /= tf.reduce_sum(cross_blocks_eq_mask, 1, keepdims=True)

    # (7) Sum over blocks and begin/end predictions

    loss = -tf.reduce_sum(log_correct_probs)
    return loss
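A small sketch (illustration only, with made-up block ids) of the cross-block mask built in step (1): samples that share a block_id, i.e. belong to the same document, get a 1 in the mask, and the row sums of this mask are exactly the per-document block counts divided out in step (6).

import tensorflow as tf

block_ids = tf.constant([7, 7, 9])  # 3 samples; the first two share a document
cross_blocks_eq_mask = tf.cast(
    tf.equal(tf.expand_dims(block_ids, 1), tf.expand_dims(block_ids, 0)),
    tf.float32)
# cross_blocks_eq_mask == [[1, 1, 0],
#                          [1, 1, 0],
#                          [0, 0, 1]]
blocks_per_doc = tf.reduce_sum(cross_blocks_eq_mask, 1, keepdims=True)  # [[2], [2], [1]]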
Example #18
def compute_embedding_contrastive_loss(
        inf_embedding,
        con_embedding,
        positives=None,
        contrastive_loss_mode='both_directions'):
    """Compute triplet loss between inference and condition_embeddings.

  Expects embeddings to be L2-normalized.

  Args:
    inf_embedding: A rank 3 tensor: [num_tasks, num_inf_episodes, K].
    con_embedding: A rank 3 tensor: [num_tasks, num_con_episodes, K].
    positives: (Optional). A rank 1 bool tensor: [num_tasks]. If provided,
      instead of assigning positives to just the 1st task in the batch, it uses
      the positives given. Positives should be defined as if the 1st task were
      the anchor. When not provided, the 1st con_embedding is the positive and
      every other con_embedding is a negative.
    contrastive_loss_mode: Which contrastive loss function to use.

  Returns:
    The contrastive loss computed using the task zero inf_embedding and
    each of the `num_tasks` con_embeddings.
  """
    if len(inf_embedding.shape) != 3:
        raise ValueError('Unexpected inf_embedding shape: {}.'.format(
            inf_embedding.shape))
    if len(con_embedding.shape) != 3:
        raise ValueError('Unexpected con_embedding shape: {}.'.format(
            con_embedding.shape))
    avg_inf_embedding = tf.reduce_mean(inf_embedding, axis=1)
    avg_con_embedding = tf.reduce_mean(con_embedding, axis=1)
    anchor = avg_inf_embedding[0:1]
    if positives is not None:
        labels = positives
    else:
        labels = tf.math.equal(tf.range(tf.shape(avg_con_embedding)[0]), 0)
    # Unlike the TEC paper, use the standard contrastive loss.
    # This uses L2 distance in embedding space.
    if contrastive_loss_mode == 'default':
        # anchor_inf --> con embeddings
        embed_loss = slim_losses.metric_learning.contrastive_loss(
            labels, anchor, avg_con_embedding)
    elif contrastive_loss_mode == 'both_directions':
        # anchor_inf --> con embeddings and anchor_con --> inf embeddings.
        # Since data is paired, we know we can reuse the labels.
        # Seems to perform best.
        embed_loss1 = slim_losses.metric_learning.contrastive_loss(
            labels, anchor, avg_con_embedding)
        anchor_cond = avg_con_embedding[0:1]
        embed_loss2 = slim_losses.metric_learning.contrastive_loss(
            labels, anchor_cond, avg_inf_embedding)
        embed_loss = embed_loss1 + embed_loss2
    elif contrastive_loss_mode == 'reverse_direction':
        # anchor_con --> inf embeddings.
        anchor_cond = avg_con_embedding[0:1]
        embed_loss = slim_losses.metric_learning.contrastive_loss(
            labels, anchor_cond, avg_inf_embedding)
    elif contrastive_loss_mode == 'cross_entropy':
        # softmax(temperature * z^T c), does both directions by default.
        #
        # This should be similar to the InfoNCE contrastive loss, but is slightly
        # different because entries in the same batch may be for the same task.
        #
        # Performance untested.
        temperature = 2
        anchor_cond = avg_con_embedding[0:1]
        cosine_sim = tf.reduce_sum(anchor * avg_con_embedding, axis=1)
        loss1 = tf.keras.losses.binary_crossentropy(labels,
                                                    temperature * cosine_sim,
                                                    from_logits=True)
        cosine_sim_2 = tf.reduce_sum(anchor_cond * avg_inf_embedding, axis=1)
        loss2 = tf.keras.losses.binary_crossentropy(labels,
                                                    temperature * cosine_sim_2,
                                                    from_logits=True)
        embed_loss = loss1 + loss2
    elif contrastive_loss_mode == 'triplet':
        if positives is None:
            # Triplet loss requires a different labeling scheme than the other losses.
            # Assume unique-task pairing scheme [0, 1, 2, ..., N, 0, 1, 2, ..., N].
            positives = tf.range(avg_inf_embedding.shape[0], dtype=tf.int32)
        labels = tf.tile(positives, [2])
        embeds = tf.concat([avg_inf_embedding, avg_con_embedding], axis=0)
        embed_loss = slim_losses.metric_learning.triplet_semihard_loss(
            labels, embeds, margin=3.0)
    else:
        raise ValueError('Did not understand contrastive_loss_mode')
    return embed_loss
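A hedged sketch of the 'cross_entropy' variant above, under the assumptions of unit-norm embeddings and arbitrary toy shapes: cosine similarities between the task-0 anchor and every condition embedding are scored with a sigmoid cross-entropy against the default positive labels.

import tensorflow as tf

num_tasks, emb_dim, temperature = 3, 8, 2.0
anchor = tf.math.l2_normalize(tf.random.normal([1, emb_dim]), axis=1)       # task-0 inference embedding
con = tf.math.l2_normalize(tf.random.normal([num_tasks, emb_dim]), axis=1)  # condition embeddings
labels = tf.cast(tf.math.equal(tf.range(num_tasks), 0), tf.float32)         # only task 0 is positive
cosine_sim = tf.reduce_sum(anchor * con, axis=1)                            # [num_tasks]
loss = tf.keras.losses.binary_crossentropy(labels,
                                            temperature * cosine_sim,
                                            from_logits=True)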
Example #19
    def __init__(self, dataset, parameters):

        self.verbose = False
        self.feature_vector_length = parameters['Feature_vector_length']

        # Placeholders for input, output and dropout
        self.input_token_indices = tf.placeholder(tf.int32, [None], name="input_token_indices")
        self.input_label_indices_vector = tf.placeholder(tf.float32, [None, dataset.number_of_classes], name="input_label_indices_vector")
        self.input_label_indices_flat = tf.placeholder(tf.int32, [None], name="input_label_indices_flat")
        self.input_token_character_indices = tf.placeholder(tf.int32, [None, None], name="input_token_indices")
        self.input_token_lengths = tf.placeholder(tf.int32, [None], name="input_token_lengths")
        self.dropout_keep_prob = tf.placeholder(tf.float32, name="dropout_keep_prob")
        
        self.input_features = tf.placeholder(tf.float32, [None, self.feature_vector_length], name="features")
        
        
        self.vocabulary_size = dataset.vocabulary_size

        # Internal parameters
        #initializer = tf.contrib.layers.xavier_initializer()
        initializer = tf.glorot_uniform_initializer()

        if parameters['use_character_lstm']:
            with tf.variable_scope("character_embedding"):
                self.character_embedding_weights = tf.get_variable(
                    "character_embedding_weights",
                    shape=[dataset.alphabet_size, parameters['character_embedding_dimension']],
                    initializer=initializer)
                embedded_characters = tf.nn.embedding_lookup(self.character_embedding_weights, self.input_token_character_indices, name='embedded_characters')
                if self.verbose: print("embedded_characters: {0}".format(embedded_characters))
               # utils_tf.variable_summaries(self.character_embedding_weights)

            # Character LSTM layer
            with tf.variable_scope('character_lstm') as vs:
                if parameters['Use_LSTM']:
                    character_lstm_output = bidirectional_LSTM(
                        embedded_characters, parameters['character_lstm_hidden_state_dimension'],
                        initializer, sequence_length=self.input_token_lengths,
                        output_sequence=False)
                else:
                    character_lstm_output = bidirectional_GRU(
                        embedded_characters, parameters['character_lstm_hidden_state_dimension'],
                        initializer, sequence_length=self.input_token_lengths,
                        output_sequence=False)
                self.character_lstm_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name)
            # Attention, not implemented      

            #  with tf.variable_scope('attention') as scope:
             #    word_level_output = task_specific_attention(character_lstm_output,dataset.token_lengths,scope=scope)
                # print (w)

             # sentence_inputs = tf.reshape(word_level_output, [self.document_size, self.sentence_size, self.word_output_size])

        # Token embedding layer
        with tf.variable_scope("token_embedding"):
            self.token_embedding_weights = tf.get_variable(
                "token_embedding_weights",
                shape=[dataset.vocabulary_size, parameters['token_embedding_dimension']],
                initializer=initializer,
                trainable=not parameters['freeze_token_embeddings'])
            embedded_tokens = tf.nn.embedding_lookup(self.token_embedding_weights, self.input_token_indices)
          #  utils_tf.variable_summaries(self.token_embedding_weights)

        # Concatenate character LSTM outputs and token embeddings
        if parameters['use_character_lstm']:
            with tf.variable_scope("concatenate_token_and_character_vectors"):
                if self.verbose: print('embedded_tokens: {0}'.format(embedded_tokens))
                token_lstm_input = tf.concat([character_lstm_output, embedded_tokens], axis=1, name='token_lstm_input')
                if self.verbose: print("token_lstm_input: {0}".format(token_lstm_input))
        else:
            token_lstm_input = embedded_tokens
            
        if parameters['use_features_before_final_lstm']:
            with tf.variable_scope("features_argumentation_pre_LSTM"):
                token_lstm_input = tf.concat([token_lstm_input, self.input_features], 1)
                if self.verbose: print("token_lstm_input: {0}".format(token_lstm_input))
            

        # Add dropout
        with tf.variable_scope("dropout"):
            token_lstm_input_drop = tf.nn.dropout(token_lstm_input, self.dropout_keep_prob, name='token_lstm_input_drop')
            if self.verbose: print("token_lstm_input_drop: {0}".format(token_lstm_input_drop))
            # https://www.tensorflow.org/api_guides/python/contrib.rnn
            # Prepare data shape to match `rnn` function requirements
            # Current data input shape: (batch_size, n_steps, n_input)
            # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)
            token_lstm_input_drop_expanded = tf.expand_dims(token_lstm_input_drop, axis=0, name='token_lstm_input_drop_expanded')
            if self.verbose: print("token_lstm_input_drop_expanded: {0}".format(token_lstm_input_drop_expanded))
            
        #if parameters['use_features_before_final_lstm']:
        #   with tf.variable_scope("features_argumentation_pre_LSTM"):
        #       token_lstm_input_drop_expanded=tf.concat([token_lstm_input_drop_expanded, self.input_features], 1)
        #       print (token_lstm_input_drop_expanded)

        # Token LSTM layer
        with tf.variable_scope('token_lstm') as vs:
            if parameters['Use_LSTM']:
                token_lstm_output = bidirectional_LSTM(token_lstm_input_drop_expanded, parameters['token_lstm_hidden_state_dimension'], initializer, output_sequence=True)
            else:
                token_lstm_output = bidirectional_GRU(token_lstm_input_drop_expanded, parameters['token_lstm_hidden_state_dimension'], initializer, output_sequence=True)
            token_lstm_output_squeezed = tf.squeeze(token_lstm_output, axis=0)
            self.token_lstm_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name)

        # Needed only if Bidirectional LSTM is used for token level
        with tf.variable_scope("feedforward_after_lstm") as vs:
            W = tf.get_variable(
                "W",
                shape=[2 * parameters['token_lstm_hidden_state_dimension'], parameters['token_lstm_hidden_state_dimension']],
                initializer=initializer)
            b = tf.Variable(tf.constant(0.0, shape=[parameters['token_lstm_hidden_state_dimension']]), name="bias")
            outputs = tf.nn.xw_plus_b(token_lstm_output_squeezed, W, b, name="output_before_tanh")
            outputs = tf.nn.tanh(outputs, name="output_after_tanh")
            self.token_lstm_variables += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name)

        with tf.variable_scope("feedforward_before_crf") as vs:
            W = tf.get_variable(
                "W",
                shape=[parameters['token_lstm_hidden_state_dimension'], dataset.number_of_classes],
                initializer=initializer)
            b = tf.Variable(tf.constant(0.0, shape=[dataset.number_of_classes]), name="bias")
            scores = tf.nn.xw_plus_b(outputs, W, b, name="scores")
            self.unary_scores = scores
            self.predictions = tf.argmax(self.unary_scores, 1, name="predictions")
            #utils_tf.variable_summaries(W)
           # utils_tf.variable_summaries(b)
            self.feedforward_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name)

        # CRF layer
        parameters['use_crf'] = False # for now
        if parameters['use_crf']:
            print ("CRF IS IN USE")
            with tf.variable_scope("crf") as vs:
                # Add start and end tokens
                small_score = -1000.0
                large_score = 0.0
                sequence_length = tf.shape(self.unary_scores)[0]
                unary_scores_with_start_and_end = tf.concat([self.unary_scores, tf.tile( tf.constant(small_score, shape=[1, 2]) , [sequence_length, 1])], 1)
                start_unary_scores = [[small_score] * dataset.number_of_classes + [large_score, small_score]]
                end_unary_scores = [[small_score] * dataset.number_of_classes + [small_score, large_score]]
                self.unary_scores = tf.concat([start_unary_scores, unary_scores_with_start_and_end, end_unary_scores], 0)
                start_index = dataset.number_of_classes
                end_index = dataset.number_of_classes + 1
                input_label_indices_flat_with_start_and_end = tf.concat([ tf.constant(start_index, shape=[1]), self.input_label_indices_flat, tf.constant(end_index, shape=[1]) ], 0)

                # Apply CRF layer
                sequence_length = tf.shape(self.unary_scores)[0]
                sequence_lengths = tf.expand_dims(sequence_length, axis=0, name='sequence_lengths')
                unary_scores_expanded = tf.expand_dims(self.unary_scores, axis=0, name='unary_scores_expanded')
                input_label_indices_flat_batch = tf.expand_dims(input_label_indices_flat_with_start_and_end, axis=0, name='input_label_indices_flat_batch')
                if self.verbose: print('unary_scores_expanded: {0}'.format(unary_scores_expanded))
                if self.verbose: print('input_label_indices_flat_batch: {0}'.format(input_label_indices_flat_batch))
                if self.verbose: print("sequence_lengths: {0}".format(sequence_lengths))
                # https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/crf
                # Compute the log-likelihood of the gold sequences and keep the transition params for inference at test time.
                self.transition_parameters=tf.get_variable(
                    "transitions",
                    shape=[dataset.number_of_classes+2, dataset.number_of_classes+2],
                    initializer=initializer)
                #utils_tf.variable_summaries(self.transition_parameters)
                log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(
                    unary_scores_expanded, input_label_indices_flat_batch, sequence_lengths, transition_params=self.transition_parameters)
                self.loss =  tf.reduce_mean(-log_likelihood, name='cross_entropy_mean_loss')
                self.accuracy = tf.constant(1)

                self.crf_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name) # LATER FOR RESTORE

        # Do not use CRF layer
        else:
            with tf.variable_scope("crf") as vs:
                self.transition_parameters = tf.get_variable(
                    "transitions",
                    shape=[dataset.number_of_classes+2, dataset.number_of_classes+2],
                    initializer=initializer)
               # utils_tf.variable_summaries(self.transition_parameters)
                self.crf_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=vs.name)

            # Calculate mean cross-entropy loss
            with tf.variable_scope("loss"):
                losses = tf.nn.softmax_cross_entropy_with_logits(logits=self.unary_scores, labels=self.input_label_indices_vector, name='softmax')
                self.loss =  tf.reduce_mean(losses, name='cross_entropy_mean_loss')
            with tf.variable_scope("accuracy"):
                correct_predictions = tf.equal(self.predictions, tf.argmax(self.input_label_indices_vector, 1))
                self.accuracy = tf.reduce_mean(tf.cast(correct_predictions, 'float'), name='accuracy')

        self.define_training_procedure(parameters)
        self.summary_op = tf.summary.merge_all()
        self.saver = tf.train.Saver(max_to_keep=100)
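A toy sketch (illustrative only) of the start/end augmentation performed in the CRF branch above: two extra columns for the synthetic start/end classes are appended to the per-token unary scores with very low values, and explicit start and end rows are stacked around the sequence.

import tensorflow as tf

number_of_classes = 3
small_score, large_score = -1000.0, 0.0
unary_scores = tf.random.normal([4, number_of_classes])  # [seq_len, num_classes]
seq_len = tf.shape(unary_scores)[0]
with_extra_cols = tf.concat(
    [unary_scores,
     tf.tile(tf.constant(small_score, shape=[1, 2]), [seq_len, 1])], 1)
start_row = [[small_score] * number_of_classes + [large_score, small_score]]
end_row = [[small_score] * number_of_classes + [small_score, large_score]]
augmented = tf.concat([start_row, with_extra_cols, end_row], 0)  # [seq_len + 2, num_classes + 2]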
def mel_perf_transformer_encode(encoder_function,
                                perf_inputs,
                                mel_inputs,
                                target_space,
                                hparams,
                                attention_weights=None,
                                features=None,
                                losses=None,
                                prepare_encoder_fn=None,
                                **kwargs):
    """Encode transformer inputs. Used for melody & performance autoencoder.

  Performance is mean-aggregated across time and combined with melody in a
  variety of different ways.

  Args:
    encoder_function: the encoder function
    perf_inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim]
      which will be flattened along the two spatial dimensions.
    mel_inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim]
      which will be flattened along the two spatial dimensions.
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.
    attention_weights: weight to store attention to.
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.
    losses: optional list onto which to append extra training losses.
    prepare_encoder_fn: optional, alternative to transformer_prepare_encoder.
    **kwargs: additional arguments to pass to encoder_function

  Returns:
    Tuple of:
        encoder_output: Encoder representation.
            [batch_size, input_length, hidden_dim]
        encoder_decoder_attention_bias: Bias and mask weights for
            encoder-decoder attention. [batch_size, input_length]
  """
    perf_inputs = common_layers.flatten4d3d(perf_inputs)
    mel_inputs = common_layers.flatten4d3d(mel_inputs)

    if not prepare_encoder_fn:
        prepare_encoder_fn = transformer_prepare_encoder
    perf_encoder_input, perf_self_attention_bias, perf_encdec_attention_bias = (
        prepare_encoder_fn(perf_inputs,
                           target_space,
                           hparams,
                           features=features,
                           reuse_target_embedding=tf.AUTO_REUSE))

    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
        value=hparams.layer_prepostprocess_dropout,
        hparams=hparams)

    perf_encoder_input = tf.nn.dropout(
        perf_encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

    perf_attn_bias_for_padding = None
    # Otherwise the encoder will just use encoder_self_attention_bias.
    if hparams.unidirectional_encoder:
        perf_attn_bias_for_padding = perf_encdec_attention_bias

    # do the same thing for melody
    mel_encoder_input, mel_self_attention_bias, mel_encdec_attention_bias = (
        prepare_encoder_fn(mel_inputs,
                           target_space,
                           hparams,
                           features=features,
                           reuse_target_embedding=tf.AUTO_REUSE))

    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
        value=hparams.layer_prepostprocess_dropout,
        hparams=hparams)

    mel_encoder_input = tf.nn.dropout(
        mel_encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

    mel_attn_bias_for_padding = None
    # Otherwise the encoder will just use encoder_self_attention_bias.
    if hparams.unidirectional_encoder:
        mel_attn_bias_for_padding = mel_encdec_attention_bias

    # use the proper encoder function for perf/melody
    perf_encoder_output = encoder_function(
        perf_encoder_input,
        perf_self_attention_bias,
        hparams,
        name="perf_encoder",
        nonpadding=features_to_nonpadding(features, "inputs"),
        save_weights_to=attention_weights,
        make_image_summary=not common_layers.is_xla_compiled(),
        losses=losses,
        attn_bias_for_padding=perf_attn_bias_for_padding,
        **kwargs)
    # same thing for melody
    mel_encoder_output = encoder_function(
        mel_encoder_input,
        mel_self_attention_bias,
        hparams,
        name="mel_encoder",
        nonpadding=features_to_nonpadding(features, "inputs"),
        save_weights_to=attention_weights,
        make_image_summary=not common_layers.is_xla_compiled(),
        losses=losses,
        attn_bias_for_padding=mel_attn_bias_for_padding,
        **kwargs)

    # concatenate the global mean vector/bias term with the full melody encoding
    perf_mean_vector = tf.math.reduce_mean(perf_encoder_output,
                                           axis=1,
                                           keep_dims=True)

    # different methods of aggregating over the performance + melody vectors!
    if hparams.aggregation == "sum":
        # add both mean performance and melody vectors together
        perf_mean_bias = tf.math.reduce_mean(perf_encdec_attention_bias,
                                             axis=-1,
                                             keep_dims=True)
        encoder_output = mel_encoder_output + perf_mean_vector
        encoder_decoder_attention_bias = mel_encdec_attention_bias + perf_mean_bias
    elif hparams.aggregation == "concat":
        # concatenate melody with mean-aggregated performance embedding
        stop_token = tf.zeros((1, 1, 384))
        encoder_output = tf.concat(
            [mel_encoder_output, stop_token, perf_mean_vector], axis=1)
        perf_mean_bias = tf.math.reduce_mean(perf_encdec_attention_bias,
                                             axis=-1,
                                             keep_dims=True)
        stop_bias = tf.zeros((1, 1, 1, 1))
        encoder_decoder_attention_bias = tf.concat(
            [mel_encdec_attention_bias, stop_bias, perf_mean_bias], axis=-1)
    elif hparams.aggregation == "tile":
        # tile performance embedding across each dimension of melody embedding!
        dynamic_val = tf.shape(mel_encoder_output)[1]
        shp = tf.convert_to_tensor([1, dynamic_val, 1], dtype=tf.int32)
        tiled_mean = tf.tile(perf_mean_vector, shp)

        encoder_output = tf.concat([mel_encoder_output, tiled_mean], axis=-1)
        encoder_decoder_attention_bias = mel_encdec_attention_bias
    else:
        raise NotImplementedError(
            "aggregation method must be in [sum, concat, tile].")

    return encoder_output, encoder_decoder_attention_bias
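An illustrative sketch (the shapes are made up) of the "tile" aggregation above: the per-sequence mean of the performance encoding is repeated at every melody timestep and concatenated along the channel dimension.

import tensorflow as tf

mel_encoder_output = tf.random.normal([2, 5, 4])    # [batch, mel_len, depth]
perf_encoder_output = tf.random.normal([2, 7, 4])   # [batch, perf_len, depth]
perf_mean_vector = tf.reduce_mean(perf_encoder_output, axis=1, keepdims=True)    # [2, 1, 4]
tiled_mean = tf.tile(perf_mean_vector, [1, tf.shape(mel_encoder_output)[1], 1])  # [2, 5, 4]
encoder_output = tf.concat([mel_encoder_output, tiled_mean], axis=-1)            # [2, 5, 8]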
Example #21
    def sample(self, b_enc=None, b_dec=None, n=None, temperature=None):
        """Generate samples for the batch from the NADE.

    Args:
      b_enc: External encoder bias terms (`b` in [1]), sized
          `[batch_size, num_hidden]`, or None if the internal bias term should
          be used.
      b_dec: External decoder bias terms (`c` in [1]), sized
          `[batch_size, num_dims]`, or None if the internal bias term should
          be used.
      n: The number of samples to generate, or None, if the batch size of
          `b_enc` should be used.
      temperature: The amount to divide the logits by before sampling
          each Bernoulli, or None if a threshold of 0.5 should be used instead
          of sampling.

    Returns:
      sample: The generated samples, sized `[batch_size, num_dims]`.
      log_prob: The log probabilities of each observation in the batch, sized
          `[batch_size]`.
    """
        b_enc = b_enc if b_enc is not None else self.b_enc
        b_dec = b_dec if b_dec is not None else self.b_dec

        batch_size = n or tf.shape(b_enc)[0]

        # Broadcast if needed.
        if b_enc.shape[0] == 1 != batch_size:
            b_enc = tf.tile(b_enc, [batch_size, 1])
        if b_dec.shape[0] == 1 != batch_size:
            b_dec = tf.tile(b_dec, [batch_size, 1])

        a_0 = b_enc
        sample_0 = []
        log_p_0 = tf.zeros([batch_size, 1])

        w_enc_arr = tf.unstack(self.w_enc)
        w_dec_arr = tf.unstack(self.w_dec_t)
        b_dec_arr = tf.unstack(
            tf.reshape(tf.transpose(b_dec), [self.num_dims, batch_size, 1]))

        def loop_body(i, a, sample, log_p):
            """Accumulate hidden state, sample, and log probability for index i."""
            # Get weights and bias for time step.
            w_enc_i = w_enc_arr[i]
            w_dec_i = w_dec_arr[i]
            b_dec_i = b_dec_arr[i]

            cond_p_i, cond_l_i = self._cond_prob(a, w_dec_i, b_dec_i)

            if temperature is None:
                v_i = tf.to_float(tf.greater_equal(cond_p_i, 0.5))
            else:
                bernoulli = tfp.distributions.Bernoulli(logits=cond_l_i /
                                                        temperature,
                                                        dtype=tf.float32)
                v_i = bernoulli.sample()

            # Accumulate sampled values.
            sample_new = sample + [v_i]

            # Get log probability for this value. Log space avoids numerical issues.
            log_p_i = v_i * _safe_log(cond_p_i) + (
                1 - v_i) * _safe_log(1 - cond_p_i)

            # Accumulate log probability.
            log_p_new = log_p + log_p_i

            # Encode value and add to hidden units.
            a_new = a + tf.matmul(v_i, w_enc_i)

            return a_new, sample_new, log_p_new

        a, sample, log_p = a_0, sample_0, log_p_0
        for i in range(self.num_dims):
            a, sample, log_p = loop_body(i, a, sample, log_p)

        return (tf.transpose(tf.squeeze(tf.stack(sample), [2])),
                tf.squeeze(log_p, squeeze_dims=[1]))
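A stand-alone sketch (illustrative; assumes tensorflow_probability is importable) of the per-dimension sampling rule in loop_body: threshold the conditional probability at 0.5 when temperature is None, otherwise sample a Bernoulli on temperature-scaled logits.

import tensorflow as tf
import tensorflow_probability as tfp

cond_logits = tf.constant([[-1.2, 0.3, 2.0]])
cond_probs = tf.sigmoid(cond_logits)

deterministic_v = tf.cast(tf.greater_equal(cond_probs, 0.5), tf.float32)  # temperature is None
temperature = 0.7
sampled_v = tfp.distributions.Bernoulli(
    logits=cond_logits / temperature, dtype=tf.float32).sample()          # temperature given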
Example #22
    def get_prediction_module(self, bert_model, features, is_training,
                              percent_done):
        final_hidden = bert_model.get_sequence_output()

        final_hidden_shape = modeling.get_shape_list(final_hidden,
                                                     expected_rank=3)
        batch_size = final_hidden_shape[0]
        seq_length = final_hidden_shape[1]

        answer_mask = tf.cast(features["input_mask"], tf.float32)
        answer_mask *= tf.cast(features["segment_ids"], tf.float32)
        answer_mask += tf.one_hot(0, seq_length)

        start_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1)

        start_top_log_probs = tf.zeros([batch_size, self.config.beam_size])
        start_top_index = tf.zeros([batch_size, self.config.beam_size],
                                   tf.int32)
        end_top_log_probs = tf.zeros(
            [batch_size, self.config.beam_size, self.config.beam_size])
        end_top_index = tf.zeros(
            [batch_size, self.config.beam_size, self.config.beam_size],
            tf.int32)
        if self.config.joint_prediction:
            start_logits += 1000.0 * (answer_mask - 1)
            start_log_probs = tf.nn.log_softmax(start_logits)
            start_top_log_probs, start_top_index = tf.nn.top_k(
                start_log_probs, k=self.config.beam_size)

            if not is_training:
                # batch, beam, length, hidden
                end_features = tf.tile(tf.expand_dims(final_hidden, 1),
                                       [1, self.config.beam_size, 1, 1])
                # batch, beam, length
                start_index = tf.one_hot(start_top_index,
                                         depth=seq_length,
                                         axis=-1,
                                         dtype=tf.float32)
                # batch, beam, hidden
                start_features = tf.reduce_sum(
                    tf.expand_dims(final_hidden, 1) *
                    tf.expand_dims(start_index, -1),
                    axis=-2)
                # batch, beam, length, hidden
                start_features = tf.tile(tf.expand_dims(start_features, 2),
                                         [1, 1, seq_length, 1])
            else:
                start_index = tf.one_hot(features[self.name +
                                                  "_start_positions"],
                                         depth=seq_length,
                                         axis=-1,
                                         dtype=tf.float32)
                start_features = tf.reduce_sum(
                    tf.expand_dims(start_index, -1) * final_hidden, axis=1)
                start_features = tf.tile(tf.expand_dims(start_features, 1),
                                         [1, seq_length, 1])
                end_features = final_hidden

            final_repr = tf.concat([start_features, end_features], -1)
            final_repr = tf.layers.dense(final_repr,
                                         512,
                                         activation=modeling.gelu,
                                         name="qa_hidden")
            # batch, beam, length (batch, length when training)
            end_logits = tf.squeeze(tf.layers.dense(final_repr, 1),
                                    -1,
                                    name="qa_logits")
            if is_training:
                end_logits += 1000.0 * (answer_mask - 1)
            else:
                end_logits += tf.expand_dims(1000.0 * (answer_mask - 1), 1)

            if not is_training:
                end_log_probs = tf.nn.log_softmax(end_logits)
                end_top_log_probs, end_top_index = tf.nn.top_k(
                    end_log_probs, k=self.config.beam_size)
                end_logits = tf.zeros([batch_size, seq_length])
        else:
            end_logits = tf.squeeze(tf.layers.dense(final_hidden, 1), -1)
            start_logits += 1000.0 * (answer_mask - 1)
            end_logits += 1000.0 * (answer_mask - 1)

        def compute_loss(logits, positions):
            one_hot_positions = tf.one_hot(positions,
                                           depth=seq_length,
                                           dtype=tf.float32)
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            loss = -tf.reduce_sum(one_hot_positions * log_probs, axis=-1)
            return loss

        start_positions = features[self.name + "_start_positions"]
        end_positions = features[self.name + "_end_positions"]

        start_loss = compute_loss(start_logits, start_positions)
        end_loss = compute_loss(end_logits, end_positions)

        losses = (start_loss + end_loss) / 2.0

        answerable_logit = tf.zeros([batch_size])
        if self.config.answerable_classifier:
            final_repr = final_hidden[:, 0]
            if self.config.answerable_uses_start_logits:
                start_p = tf.nn.softmax(start_logits)
                start_feature = tf.reduce_sum(tf.expand_dims(start_p, -1) *
                                              final_hidden,
                                              axis=1)
                final_repr = tf.concat([final_repr, start_feature], -1)
                final_repr = tf.layers.dense(final_repr,
                                             512,
                                             activation=modeling.gelu)
            answerable_logit = tf.squeeze(tf.layers.dense(final_repr, 1), -1)
            answerable_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.cast(features[self.name + "_is_impossible"],
                               tf.float32),
                logits=answerable_logit)
            losses += answerable_loss * self.config.answerable_weight

        return losses, dict(
            loss=losses,
            start_logits=start_logits,
            end_logits=end_logits,
            answerable_logit=answerable_logit,
            start_positions=features[self.name + "_start_positions"],
            end_positions=features[self.name + "_end_positions"],
            start_top_log_probs=start_top_log_probs,
            start_top_index=start_top_index,
            end_top_log_probs=end_top_log_probs,
            end_top_index=end_top_index,
            eid=features[self.name + "_eid"],
        )
Example #23
        x_fake, ll_fake = cond_gen(data_train.query[1], z_outer)

        test_fake_gen = cond_gen(data_test.query[1], mu_test)[0]

    disc_loss, train_disc = get_disc_loss(cmd_args, x_true, x_fake, score_func,
                                          z_outer, neg_kl_outer)
    gen_loss, train_gen = get_gen_loss(cmd_args, x_fake, ll_fake, score_func,
                                       z_outer)

    # for plot
    ph_x = tf.placeholder(tf.float32, shape=(None, 1))
    ph_y = tf.placeholder(tf.float32, shape=(None, 1))
    ph_x_plot_cond = tf.placeholder(tf.float32, shape=(1, None, 2))
    z_plot, _, _, _ = posterior(ph_x_plot_cond)
    z_plot = tf.tile(z_plot, [tf.shape(ph_x)[0], 1])
    x_plot = tf.concat([ph_x, ph_y], axis=-1)
    x_plot = tf.expand_dims(x_plot, 1)
    score_plot = score_func(x_plot, z_plot)
    x_plot_cond = tf.concat([data_plot.query[1], data_plot.target_y], axis=-1)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        model_dir = os.path.join(cmd_args.save_dir, 'model')
        if cmd_args.epoch_load >= 0:
            model_path = os.path.join(model_dir,
                                      'model-%d.ckpt' % cmd_args.epoch_load)
            saver.restore(sess, model_path)
Example #24
    def _log_prob(self, data, num_samples=1):
        """Compute a lower bound on the log likelihood."""
        # Due to memory issues, we need to use num_samples=1 here
        num_samples, proposal_num_samples = 1, num_samples
        batch_size = tf.shape(data)[0]
        # Sample from the proposal and compute the weights of the "unseen" samples.
        # We share these across the batch dimension.
        # [num_samples, K, data_size]
        proposal_samples = self.proposal.sample(num_samples * (self.K - 1))
        if not self.reparameterize_proposal_samples:
            proposal_samples = tf.stop_gradient(proposal_samples)

        # [num_samples, K]
        log_energy_proposal = tf.reshape(
            self.energy_fn(tf.reshape(proposal_samples, [-1] + self.data_dim)),
            [num_samples, self.K - 1])
        tf.summary.histogram("log_energy_proposal", log_energy_proposal)
        tf.summary.scalar("min_log_energy_proposal",
                          tf.reduce_min(log_energy_proposal))
        tf.summary.scalar("max_log_energy_proposal",
                          tf.reduce_max(log_energy_proposal))
        # [num_samples]
        proposal_lse = tf.reduce_logsumexp(log_energy_proposal, axis=1)

        # [batch_size, num_samples]
        tiled_proposal_lse = tf.tile(proposal_lse[tf.newaxis, :],
                                     [batch_size, 1])

        # Compute the weights of the observed data.
        # [batch_size, 1]
        log_energy_data = tf.reshape(self.energy_fn(data), [batch_size])
        tf.summary.histogram("log_energy_data", log_energy_data)
        tf.summary.scalar("min_log_energy_data",
                          tf.reduce_min(log_energy_data))
        tf.summary.scalar("max_log_energy_data",
                          tf.reduce_max(log_energy_data))

        # [batch_size, num_samples]
        tiled_log_energy_data = tf.tile(log_energy_data[:, tf.newaxis],
                                        [1, num_samples])

        # Add the weights of the proposal samples to the true data weights.
        # [batch_size, num_samples]
        # pylint: disable=invalid-name
        Z_hat = tf.reduce_logsumexp(tf.stack(
            [tiled_log_energy_data, tiled_proposal_lse], axis=-1),
                                    axis=-1)
        Z_hat -= tf.log(tf.to_float(self.K))
        # Perform the log-sum-exp reduction for IWAE
        # [batch_size]
        Z_hat = tf.reduce_logsumexp(Z_hat, axis=1) - tf.log(
            tf.to_float(num_samples))
        # pylint: enable=invalid-name

        try:
            # Try giving the proposal lower bound num_samples if it can use it.
            proposal_lp = self.proposal.log_prob(
                data, num_samples=proposal_num_samples)
        except TypeError:
            proposal_lp = self.proposal.log_prob(data)
        lower_bound = proposal_lp + log_energy_data - Z_hat
        return lower_bound
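As a rough check of the arithmetic above, a NumPy sketch of how Z_hat combines the energy of the observed example with the energies of the K - 1 proposal samples: a log-sum-exp over the K terms, normalized by log K. All numbers are made up.

import numpy as np

def logsumexp(x):
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))

K = 4
log_energy_data = 1.2                              # energy of the observed example
log_energy_proposal = np.array([0.3, -0.5, 2.0])   # energies of the K - 1 proposal samples
proposal_lse = logsumexp(log_energy_proposal)
# log of the average exp-energy over the K points (observed data + proposal samples)
Z_hat = logsumexp(np.array([log_energy_data, proposal_lse])) - np.log(K)
print(Z_hat)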
Example #25
    def crop_proposal():
        rand_vec = lambda minval, maxval: tf.random_uniform(shape=(
            ssd_constants.NUM_CROP_PASSES, 1),
                                                            minval=minval,
                                                            maxval=maxval,
                                                            dtype=tf.float32)

        width, height = rand_vec(0.3, 1), rand_vec(0.3, 1)
        left, top = rand_vec(0, 1 - width), rand_vec(0, 1 - height)

        right = left + width
        bottom = top + height

        ltrb = tf.concat([left, top, right, bottom], axis=1)

        min_iou = tf.random_shuffle(ssd_constants.CROP_MIN_IOU_CHOICES)[0]
        ious = calc_iou_tensor(ltrb, boxes)

        # discard any bboxes whose center is not in the cropped image
        xc, yc = [
            tf.tile(0.5 * (boxes[:, i + 0] + boxes[:, i + 2])[tf.newaxis, :],
                    (ssd_constants.NUM_CROP_PASSES, 1)) for i in range(2)
        ]

        masks = tf.reduce_all(tf.stack([
            tf.greater(xc, tf.tile(left, (1, num_boxes))),
            tf.less(xc, tf.tile(right, (1, num_boxes))),
            tf.greater(yc, tf.tile(top, (1, num_boxes))),
            tf.less(yc, tf.tile(bottom, (1, num_boxes))),
        ],
                                       axis=2),
                              axis=2)

        # Check whether a crop is valid.
        valid_aspect = tf.logical_and(tf.less(height / width, 2),
                                      tf.less(width / height, 2))
        valid_ious = tf.reduce_all(tf.greater(ious, min_iou),
                                   axis=1,
                                   keepdims=True)
        valid_masks = tf.reduce_any(masks, axis=1, keepdims=True)

        valid_all = tf.cast(
            tf.reduce_all(tf.concat([valid_aspect, valid_ious, valid_masks],
                                    axis=1),
                          axis=1), tf.int32)

        # One indexed, as zero is needed for the case of no matches.
        index = tf.range(1, 1 + ssd_constants.NUM_CROP_PASSES, dtype=tf.int32)

        # Either one-hot, or zeros if there is no valid crop.
        selection = tf.equal(tf.reduce_max(index * valid_all), index)

        use_crop = tf.reduce_any(selection)
        output_ltrb = tf.reduce_sum(tf.multiply(
            ltrb,
            tf.tile(tf.cast(selection, tf.float32)[:, tf.newaxis], (1, 4))),
                                    axis=0)
        output_masks = tf.reduce_any(tf.logical_and(
            masks, tf.tile(selection[:, tf.newaxis], (1, num_boxes))),
                                     axis=0)

        return use_crop, output_ltrb, output_masks
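A small NumPy sketch of the selection trick above: multiplying a one-indexed range by the validity mask and comparing against its maximum yields a one-hot choice of the last valid crop, or all zeros when no crop is valid.

import numpy as np

valid_all = np.array([0, 1, 0, 1, 0])        # validity of each of NUM_CROP_PASSES crops
index = np.arange(1, len(valid_all) + 1)     # one-indexed so "no valid crop" maps to 0
selection = np.equal(np.max(index * valid_all), index)
print(selection)          # [False False False  True False] -> last valid crop wins
print(np.any(selection))  # use_crop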
Example #26
  def decode(self, tf_example_string_tensor):
    """Decodes serialized tensorflow example and returns a tensor dictionary.

    Args:
      tf_example_string_tensor: a string tensor holding a serialized tensorflow
        example proto.

    Returns:
      A dictionary of the following tensors.
      fields.InputDataFields.image - 3D uint8 tensor of shape [None, None, 3]
        containing image.
      fields.InputDataFields.original_image_spatial_shape - 1D int32 tensor of
        shape [2] containing shape of the image.
      fields.InputDataFields.source_id - string tensor containing original
        image id.
      fields.InputDataFields.key - string tensor with unique sha256 hash key.
      fields.InputDataFields.filename - string tensor with original dataset
        filename.
      fields.InputDataFields.groundtruth_boxes - 2D float32 tensor of shape
        [None, 4] containing box corners.
      fields.InputDataFields.groundtruth_classes - 1D int64 tensor of shape
        [None] containing classes for the boxes.
      fields.InputDataFields.groundtruth_weights - 1D float32 tensor of
        shape [None] indicating the weights of groundtruth boxes.
      fields.InputDataFields.groundtruth_area - 1D float32 tensor of shape
        [None] containing object mask area in squared pixels.
      fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
        [None] indicating if the boxes enclose a crowd.

    Optional:
      fields.InputDataFields.groundtruth_image_confidences - 1D float tensor of
        shape [None] indicating if a class is present in the image (1.0) or
        a class is not present in the image (0.0).
      fields.InputDataFields.image_additional_channels - 3D uint8 tensor of
        shape [None, None, num_additional_channels]. 1st dim is height; 2nd dim
        is width; 3rd dim is the number of additional channels.
      fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
        [None] indicating if the boxes represent `difficult` instances.
      fields.InputDataFields.groundtruth_group_of - 1D bool tensor of shape
        [None] indicating if the boxes represent `group_of` instances.
      fields.InputDataFields.groundtruth_keypoints - 3D float32 tensor of
        shape [None, num_keypoints, 2] containing keypoints, where the
        coordinates of the keypoints are ordered (y, x).
      fields.InputDataFields.groundtruth_keypoint_visibilities - 2D bool
        tensor of shape [None, num_keypoints] containing keypoint visibilities.
      fields.InputDataFields.groundtruth_instance_masks - 3D float32 tensor of
        shape [None, None, None] containing instance masks.
      fields.InputDataFields.groundtruth_instance_mask_weights - 1D float32
        tensor of shape [None] containing weights. These are typically values
        in {0.0, 1.0} which indicate whether to consider the mask related to an
        object.
      fields.InputDataFields.groundtruth_image_classes - 1D int64 of shape
        [None] containing classes for the boxes.
      fields.InputDataFields.multiclass_scores - 1D float32 tensor of shape
        [None * num_classes] containing flattened multiclass scores for
        groundtruth boxes.
      fields.InputDataFields.context_features - 1D float32 tensor of shape
        [context_feature_length * num_context_features]
      fields.InputDataFields.context_feature_length - int32 tensor specifying
        the length of each feature in context_features
    """
    serialized_example = tf.reshape(tf_example_string_tensor, shape=[])
    decoder = slim_example_decoder.TFExampleDecoder(self.keys_to_features,
                                                    self.items_to_handlers)
    keys = decoder.list_items()
    tensors = decoder.decode(serialized_example, items=keys)
    tensor_dict = dict(zip(keys, tensors))
    is_crowd = fields.InputDataFields.groundtruth_is_crowd
    tensor_dict[is_crowd] = tf.cast(tensor_dict[is_crowd], dtype=tf.bool)
    tensor_dict[fields.InputDataFields.image].set_shape([None, None, 3])
    tensor_dict[fields.InputDataFields.original_image_spatial_shape] = tf.shape(
        tensor_dict[fields.InputDataFields.image])[:2]

    if fields.InputDataFields.image_additional_channels in tensor_dict:
      channels = tensor_dict[fields.InputDataFields.image_additional_channels]
      channels = tf.squeeze(channels, axis=3)
      channels = tf.transpose(channels, perm=[1, 2, 0])
      tensor_dict[fields.InputDataFields.image_additional_channels] = channels

    def default_groundtruth_weights():
      return tf.ones(
          [tf.shape(tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]],
          dtype=tf.float32)

    tensor_dict[fields.InputDataFields.groundtruth_weights] = tf.cond(
        tf.greater(
            tf.shape(
                tensor_dict[fields.InputDataFields.groundtruth_weights])[0],
            0), lambda: tensor_dict[fields.InputDataFields.groundtruth_weights],
        default_groundtruth_weights)

    if fields.InputDataFields.groundtruth_instance_masks in tensor_dict:
      gt_instance_masks = tensor_dict[
          fields.InputDataFields.groundtruth_instance_masks]
      num_gt_instance_masks = tf.shape(gt_instance_masks)[0]
      gt_instance_mask_weights = tensor_dict[
          fields.InputDataFields.groundtruth_instance_mask_weights]
      num_gt_instance_mask_weights = tf.shape(gt_instance_mask_weights)[0]
      def default_groundtruth_instance_mask_weights():
        return tf.ones([num_gt_instance_masks], dtype=tf.float32)

      tensor_dict[fields.InputDataFields.groundtruth_instance_mask_weights] = (
          tf.cond(tf.greater(num_gt_instance_mask_weights, 0),
                  lambda: gt_instance_mask_weights,
                  default_groundtruth_instance_mask_weights))

    if fields.InputDataFields.groundtruth_keypoints in tensor_dict:
      # Set all keypoints that are not labeled to NaN.
      gt_kpt_fld = fields.InputDataFields.groundtruth_keypoints
      gt_kpt_vis_fld = fields.InputDataFields.groundtruth_keypoint_visibilities
      visibilities_tiled = tf.tile(
          tf.expand_dims(tensor_dict[gt_kpt_vis_fld], -1),
          [1, 1, 2])
      tensor_dict[gt_kpt_fld] = tf.where(
          visibilities_tiled,
          tensor_dict[gt_kpt_fld],
          np.nan * tf.ones_like(tensor_dict[gt_kpt_fld]))

    if self._expand_hierarchy_labels:
      input_fields = fields.InputDataFields
      image_classes, image_confidences = self._expand_image_label_hierarchy(
          tensor_dict[input_fields.groundtruth_image_classes],
          tensor_dict[input_fields.groundtruth_image_confidences])
      tensor_dict[input_fields.groundtruth_image_classes] = image_classes
      tensor_dict[input_fields.groundtruth_image_confidences] = (
          image_confidences)

      box_fields = [
          fields.InputDataFields.groundtruth_group_of,
          fields.InputDataFields.groundtruth_is_crowd,
          fields.InputDataFields.groundtruth_difficult,
          fields.InputDataFields.groundtruth_area,
          fields.InputDataFields.groundtruth_boxes,
          fields.InputDataFields.groundtruth_weights,
      ]

      def expand_field(field_name):
        return self._expansion_box_field_labels(
            tensor_dict[input_fields.groundtruth_classes],
            tensor_dict[field_name])

      # pylint: disable=cell-var-from-loop
      for field in box_fields:
        if field in tensor_dict:
          tensor_dict[field] = tf.cond(
              tf.size(tensor_dict[field]) > 0, lambda: expand_field(field),
              lambda: tensor_dict[field])
      # pylint: enable=cell-var-from-loop

      tensor_dict[input_fields.groundtruth_classes] = (
          self._expansion_box_field_labels(
              tensor_dict[input_fields.groundtruth_classes],
              tensor_dict[input_fields.groundtruth_classes], True))

    if fields.InputDataFields.groundtruth_group_of in tensor_dict:
      group_of = fields.InputDataFields.groundtruth_group_of
      tensor_dict[group_of] = tf.cast(tensor_dict[group_of], dtype=tf.bool)

    if fields.InputDataFields.groundtruth_dp_num_points in tensor_dict:
      tensor_dict[fields.InputDataFields.groundtruth_dp_num_points] = tf.cast(
          tensor_dict[fields.InputDataFields.groundtruth_dp_num_points],
          dtype=tf.int32)
      tensor_dict[fields.InputDataFields.groundtruth_dp_part_ids] = tf.cast(
          tensor_dict[fields.InputDataFields.groundtruth_dp_part_ids],
          dtype=tf.int32)

    if fields.InputDataFields.groundtruth_track_ids in tensor_dict:
      tensor_dict[fields.InputDataFields.groundtruth_track_ids] = tf.cast(
          tensor_dict[fields.InputDataFields.groundtruth_track_ids],
          dtype=tf.int32)

    return tensor_dict
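The tf.cond calls above substitute all-ones defaults when a record carries no groundtruth weights; a plain-Python sketch of that fallback (names and values here are only illustrative).

import numpy as np

def weights_or_default(decoded_weights, num_boxes):
    # mirror of the tf.cond: keep decoded weights if any were present, else use ones
    if decoded_weights.shape[0] > 0:
        return decoded_weights
    return np.ones([num_boxes], dtype=np.float32)

print(weights_or_default(np.array([], dtype=np.float32), num_boxes=3))          # [1. 1. 1.]
print(weights_or_default(np.array([0.5, 1.0], dtype=np.float32), num_boxes=2))  # [0.5 1. ]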
Example #27
def model_fn(features, labels, mode, params):
    """Model function."""
    del labels

    # ==============================
    # Input features
    # ==============================
    # [batch_size, query_seq_len]
    query_inputs = features["query_inputs"]

    # [batch_size, num_candidates, candidate_seq_len]
    candidate_inputs = features["candidate_inputs"]

    # [batch_size, num_candidates, query_seq_len + candidate_seq_len]
    joint_inputs = features["joint_inputs"]

    # [batch_size, num_masks]
    mlm_targets = features["mlm_targets"]
    mlm_positions = features["mlm_positions"]
    mlm_mask = features["mlm_mask"]

    # ==============================
    # Create modules.
    # ==============================
    bert_module = hub.Module(
        spec=params["bert_hub_module_handle"],
        name="bert",
        tags={"train"} if mode == tf_estimator.ModeKeys.TRAIN else {},
        trainable=True)
    hub.register_module_for_export(bert_module, "bert")

    embedder_module = hub.Module(
        spec=params["embedder_hub_module_handle"],
        name="embedder",
        tags={"train"} if mode == tf_estimator.ModeKeys.TRAIN else {},
        trainable=True)
    hub.register_module_for_export(embedder_module, "embedder")

    if params["share_embedders"]:
        query_embedder_module = embedder_module
    else:
        query_embedder_module = hub.Module(
            spec=params["embedder_hub_module_handle"],
            name="embedder",
            tags={"train"} if mode == tf_estimator.ModeKeys.TRAIN else {},
            trainable=True)
        hub.register_module_for_export(embedder_module, "query_embedder")

    # ==============================
    # Retrieve.
    # ==============================
    # [batch_size, projected_size]
    query_emb = query_embedder_module(inputs=dict(
        input_ids=query_inputs.token_ids,
        input_mask=query_inputs.mask,
        segment_ids=query_inputs.segment_ids),
                                      signature="projected")

    # [batch_size * num_candidates, candidate_seq_len]
    flat_candidate_inputs, unflatten = flatten_bert_inputs(candidate_inputs)

    # [batch_size * num_candidates, projected_size]
    flat_candidate_emb = embedder_module(inputs=dict(
        input_ids=flat_candidate_inputs.token_ids,
        input_mask=flat_candidate_inputs.mask,
        segment_ids=flat_candidate_inputs.segment_ids),
                                         signature="projected")

    # [batch_size, num_candidates, projected_size]
    unflattened_candidate_emb = unflatten(flat_candidate_emb)

    # [batch_size, num_candidates]
    retrieval_score = tf.einsum("BD,BND->BN", query_emb,
                                unflattened_candidate_emb)

    # ==============================
    # Read.
    # ==============================
    # [batch_size * num_candidates, query_seq_len + candidate_seq_len]
    flat_joint_inputs, unflatten = flatten_bert_inputs(joint_inputs)

    # [batch_size * num_candidates, num_masks]
    flat_mlm_positions, _ = tensor_utils.flatten(
        tf.tile(tf.expand_dims(mlm_positions, 1),
                [1, params["num_candidates"], 1]))

    batch_size, num_masks = tensor_utils.shape(mlm_targets)

    # [batch_size * num_candidates, query_seq_len + candidates_seq_len]
    flat_joint_bert_outputs = bert_module(inputs=dict(
        input_ids=flat_joint_inputs.token_ids,
        input_mask=flat_joint_inputs.mask,
        segment_ids=flat_joint_inputs.segment_ids,
        mlm_positions=flat_mlm_positions),
                                          signature="mlm",
                                          as_dict=True)

    # [batch_size, num_candidates]
    candidate_score = retrieval_score

    # [batch_size, num_candidates]
    candidate_log_probs = tf.math.log_softmax(candidate_score)

    # ==============================
    # Compute marginal log-likelihood.
    # ==============================
    # [batch_size * num_candidates, num_masks]
    flat_mlm_logits = flat_joint_bert_outputs["mlm_logits"]

    # [batch_size, num_candidates, num_masks, vocab_size]
    mlm_logits = tf.reshape(
        flat_mlm_logits, [batch_size, params["num_candidates"], num_masks, -1])
    mlm_log_probs = tf.math.log_softmax(mlm_logits)

    # [batch_size, num_candidates, num_masks]
    tiled_mlm_targets = tf.tile(tf.expand_dims(mlm_targets, 1),
                                [1, params["num_candidates"], 1])

    # [batch_size, num_candidates, num_masks, 1]
    tiled_mlm_targets = tf.expand_dims(tiled_mlm_targets, -1)

    # [batch_size, num_candidates, num_masks, 1]
    gold_log_probs = tf.batch_gather(mlm_log_probs, tiled_mlm_targets)

    # [batch_size, num_candidates, num_masks]
    gold_log_probs = tf.squeeze(gold_log_probs, -1)

    # [batch_size, num_candidates, num_masks]
    joint_gold_log_probs = (tf.expand_dims(candidate_log_probs, -1) +
                            gold_log_probs)

    # [batch_size, num_masks]
    marginal_gold_log_probs = tf.reduce_logsumexp(joint_gold_log_probs, 1)

    # [batch_size, num_masks]
    float_mlm_mask = tf.cast(mlm_mask, tf.float32)

    # []
    loss = -tf.div_no_nan(
        tf.reduce_sum(marginal_gold_log_probs * float_mlm_mask),
        tf.reduce_sum(float_mlm_mask))

    # ==============================
    # Optimization
    # ==============================
    num_warmup_steps = min(10000, max(100,
                                      int(params["num_train_steps"] / 10)))
    train_op = optimization.create_optimizer(
        loss=loss,
        init_lr=params["learning_rate"],
        num_train_steps=params["num_train_steps"],
        num_warmup_steps=num_warmup_steps,
        use_tpu=params["use_tpu"])

    # ==============================
    # Evaluation
    # ==============================
    eval_metric_ops = None if params["use_tpu"] else dict()
    if mode != tf_estimator.ModeKeys.PREDICT:
        # [batch_size, num_masks]
        retrieval_utility = marginal_gold_log_probs - gold_log_probs[:, 0]
        retrieval_utility *= tf.cast(features["mlm_mask"], tf.float32)

        # []
        retrieval_utility = tf.div_no_nan(tf.reduce_sum(retrieval_utility),
                                          tf.reduce_sum(float_mlm_mask))
        add_mean_metric("retrieval_utility", retrieval_utility,
                        eval_metric_ops)

        has_timestamp = tf.cast(tf.greater(features["export_timestamp"], 0),
                                tf.float64)
        off_policy_delay_secs = (
            tf.timestamp() - tf.cast(features["export_timestamp"], tf.float64))
        off_policy_delay_mins = off_policy_delay_secs / 60.0
        off_policy_delay_mins *= tf.cast(has_timestamp, tf.float64)

        add_mean_metric("off_policy_delay_mins", off_policy_delay_mins,
                        eval_metric_ops)

    # Create empty predictions to avoid errors when running in prediction mode.
    predictions = dict()

    if params["use_tpu"]:
        return tf_estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                 loss=loss,
                                                 train_op=train_op,
                                                 predictions=predictions)
    else:
        if eval_metric_ops is not None:
            # Make sure the eval metrics are updated during training so that we get
            # quick feedback from tensorboard summaries when debugging locally.
            with tf.control_dependencies(
                [u for _, u in eval_metric_ops.values()]):
                loss = tf.identity(loss)
        return tf_estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops,
                                          predictions=predictions)
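A compact NumPy sketch of the marginalization above: per-mask gold log-probabilities are added to the candidate log-probabilities and reduced with a log-sum-exp over candidates, giving the log-probability of each masked token marginalized over retrieved candidates (numbers are arbitrary).

import numpy as np

def logsumexp(x, axis):
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis) + np.log(np.sum(np.exp(x - m), axis=axis))

candidate_log_probs = np.log(np.array([[0.5, 0.3, 0.2]]))   # [batch, num_candidates]
gold_log_probs = np.log(np.array([[[0.9, 0.1],              # [batch, num_candidates, num_masks]
                                    [0.2, 0.6],
                                    [0.5, 0.5]]]))
joint = candidate_log_probs[:, :, None] + gold_log_probs
marginal = logsumexp(joint, axis=1)                          # [batch, num_masks]
print(marginal)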
Example #28
    def testCreate_Dual_Approx(self):
        num_layers = 3
        batch_size = 2
        action_max = 1.0
        action_tensor_center = tf.tile(
            tf.convert_to_tensor(
                np.array([1.0, 2.0, 3.0,
                          4.0]).reshape([1, 4]).astype(np.float32)), [2, 1])
        W_T_list = [
            tf.convert_to_tensor(
                np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],
                          [10.0, 11.0, 12.0]]).astype(np.float32)),
            tf.convert_to_tensor(
                np.array([[-1.0], [-3.0], [-5.0]]).astype(np.float32))
        ]
        b_T_list = [
            tf.tile(
                tf.convert_to_tensor(
                    np.array([1.0, -0.5,
                              0.1]).reshape([1, 3]).astype(np.float32)),
                [2, 1]),
            tf.tile(
                tf.convert_to_tensor(
                    np.array([2.0]).reshape([1, 1]).astype(np.float32)),
                [2, 1])
        ]
        (neg_J_tilde, l_list, u_list, D_list, Nu_list, gamma_list, psi, l_ip1,
         u_ip1,
         Nu_hat_1) = dual_method.create_dual_approx(num_layers,
                                                    batch_size,
                                                    action_max,
                                                    W_T_list,
                                                    b_T_list,
                                                    action_tensor_center,
                                                    return_full_info=True)

        self.assertIsInstance(neg_J_tilde, tf.Tensor)
        self.assertEqual((2, 1), neg_J_tilde.shape)

        self.assertIsInstance(l_list, list)
        self.assertEqual(num_layers - 1, len(l_list))
        for itr, ele in enumerate(l_list):
            self.assertIsInstance(ele, tf.Tensor)
            if itr == 0:
                self.assertEqual((2, 4), ele.shape)
            elif itr == 1:
                self.assertEqual((2, 3), ele.shape)

        self.assertIsInstance(u_list, list)
        self.assertEqual(num_layers - 1, len(u_list))
        for itr, ele in enumerate(u_list):
            self.assertIsInstance(ele, tf.Tensor)
            if itr == 0:
                self.assertEqual((2, 4), ele.shape)
            elif itr == 1:
                self.assertEqual((2, 3), ele.shape)

        self.assertIsInstance(D_list, list)
        self.assertEqual(num_layers - 1, len(D_list))
        for itr, ele in enumerate(D_list):
            self.assertIsInstance(ele, tf.Tensor)
            if itr == 0:
                self.assertEqual((2, 4), ele.shape)
            elif itr == 1:
                self.assertEqual((2, 3), ele.shape)

        self.assertIsInstance(Nu_list, list)
        self.assertEqual(num_layers - 1, len(Nu_list))
        for itr, ele in enumerate(Nu_list):
            self.assertIsInstance(ele, tf.Tensor)
            if itr == 0:
                self.assertEqual((2, 3, 1), ele.shape)
            elif itr == 1:
                self.assertEqual((2, 3, 1), ele.shape)

        self.assertIsInstance(gamma_list, list)
        self.assertEqual(num_layers - 1, len(gamma_list))
        for ele in gamma_list:
            self.assertIsInstance(ele, tf.Tensor)
            self.assertEqual((2, 1), ele.shape)

        self.assertIsInstance(psi, tf.Tensor)
        self.assertEqual((2, 1), psi.shape)
        self.assertIsInstance(l_ip1, tf.Tensor)
        self.assertEqual((2, 1), l_ip1.shape)
        self.assertIsInstance(u_ip1, tf.Tensor)
        self.assertEqual((2, 1), u_ip1.shape)
        self.assertIsInstance(Nu_hat_1, tf.Tensor)
        self.assertEqual((2, 4, 1), Nu_hat_1.shape)

        neg_J_tilde_np = self.sess.run(neg_J_tilde)
        l_list_np = self.sess.run(l_list)
        u_list_np = self.sess.run(u_list)
        D_list_np = self.sess.run(D_list)
        Nu_list_np = self.sess.run(Nu_list)
        gamma_list_np = self.sess.run(gamma_list)
        psi_np = self.sess.run(psi)
        l_ip1_np = self.sess.run(l_ip1)
        u_ip1_np = self.sess.run(u_ip1)
        Nu_hat_1_np = self.sess.run(Nu_hat_1)

        self.assertArrayNear(np.array([[-508.], [-508.]]).flatten(),
                             neg_J_tilde_np.flatten(),
                             err=1e-4)
        for itr, ele in enumerate(l_list_np):
            if itr == 0:
                print(ele)
                self.assertArrayNear(np.array([[0., 0., 0., 0.],
                                               [0., 0., 0., 0.]]).flatten(),
                                     ele.flatten(),
                                     err=1e-4)
            elif itr == 1:
                self.assertArrayNear(np.array([[49., 53.5, 60.1],
                                               [49., 53.5, 60.1]]).flatten(),
                                     ele.flatten(),
                                     err=1e-4)
        for itr, ele in enumerate(u_list_np):
            if itr == 0:
                self.assertArrayNear(np.array([[0., 0., 0., 0.],
                                               [0., 0., 0., 0.]]).flatten(),
                                     ele.flatten(),
                                     err=1e-4)
            elif itr == 1:
                self.assertArrayNear(np.array([[93., 105.5, 120.1],
                                               [93., 105.5, 120.1]]).flatten(),
                                     ele.flatten(),
                                     err=1e-4)
        for itr, ele in enumerate(D_list_np):
            if itr == 0:
                self.assertArrayNear(np.array([[0., 0., 0., 0.],
                                               [0., 0., 0., 0.]]).flatten(),
                                     ele.flatten(),
                                     err=1e-4)
            elif itr == 1:
                self.assertArrayNear(np.array([[1., 1., 1.], [1., 1.,
                                                              1.]]).flatten(),
                                     ele.flatten(),
                                     err=1e-4)
        for itr, ele in enumerate(Nu_list_np):
            if itr == 0:
                self.assertArrayNear(np.array([[[0.], [0.], [0.]],
                                               [[0.], [0.], [0.]]]).flatten(),
                                     ele.flatten(),
                                     err=1e-4)
            elif itr == 1:
                self.assertArrayNear(np.array([[[-1.], [-3.], [-5.]],
                                               [[-1.], [-3.],
                                                [-5.]]]).flatten(),
                                     ele.flatten(),
                                     err=1e-4)

        for itr, ele in enumerate(gamma_list_np):
            if itr == 0:
                self.assertArrayNear(np.array([[0.], [0.]]).flatten(),
                                     ele.flatten(),
                                     err=1e-4)
            elif itr == 1:
                self.assertArrayNear(np.array([[2.], [2.]]).flatten(),
                                     ele.flatten(),
                                     err=1e-4)

        self.assertArrayNear(np.array([[-758.], [-758.]]).flatten(),
                             psi_np.flatten(),
                             err=1e-4)
        self.assertArrayNear(np.array([[0.], [0.]]).flatten(),
                             l_ip1_np.flatten(),
                             err=1e-4)
        self.assertArrayNear(np.array([[-0.], [-0.]]).flatten(),
                             u_ip1_np.flatten(),
                             err=1e-4)
        self.assertArrayNear(np.array([[[-22.], [-49.], [-76.], [-103.]],
                                       [[-22.], [-49.], [-76.],
                                        [-103.]]]).flatten(),
                             Nu_hat_1_np.flatten(),
                             err=1e-4)
Example #29
def rnnt_loss(log_probs,
              labels,
              input_lengths=None,
              label_lengths=None,
              blank_index=0,
              debug=False,
              with_alignment=False):
    """
  Computes the batched forward pass of the RNN-T model.
  B: batch, T: time, U:target/labels, V: vocabulary

  :param tf.Tensor log_probs: (B, T, U+1, V) log-probabilities
  :param tf.Tensor labels: (B, U) -> [V] labels
  :param tf.Tensor input_lengths: (B,) length of input frames
  :param tf.Tensor label_lengths: (B,) length of labels
  :param int blank_index: index of the blank symbol in the vocabulary
  :param bool debug: enable verbose logging
  :param bool with_alignment: whether to generate the alignments or not.
  :return:
  with_alignment=True -> (costs, alignments)
                =False -> costs
  """
    """Pure TF implementation of the RNN-T loss."""
    shape = tf.shape(log_probs)
    n_batch = shape[0]  # B
    max_time = shape[1]  # T
    max_target = shape[2]  # U

    log_probs_tr = tf.transpose(log_probs,
                                [0, 2, 1, 3])  # (B, T, U, V) -> (B, U, T, V)
    log_probs_shifted = tf_shift_logprobs(log_probs_tr,
                                          axis=1)  # (B, U+T+1, U, V)

    num_diagonals = max_time + max_target

    labels = py_print_iteration_info("labels", labels, 0, debug=debug)

    log_probs_ta = tf.TensorArray(
        dtype=tf.float32,
        clear_after_read=False,
        size=num_diagonals,
        dynamic_size=False,
        infer_shape=False,
        element_shape=(None, None, None),  # (B, U, V)
        name="log_probs_shifted",
    )
    # (B, U+T+1, U, V) -> [(B, U, V)] * (U+T+1)
    log_probs_ta = log_probs_ta.unstack(
        tf.transpose(log_probs_shifted, [2, 0, 1, 3]))

    init_alpha_ta = tf.TensorArray(
        dtype=tf.float32,
        clear_after_read=False,
        size=num_diagonals,
        dynamic_size=False,
        infer_shape=False,
        element_shape=(
            None,
            None,
        ),  # (B, n)
        name="alpha_diagonals",
    )
    init_alpha_ta = init_alpha_ta.write(1, tf.zeros((n_batch, 1)))

    if with_alignment:
        # backtrack matrix, for each diagonal and target -> [from_u, symbol/blank]
        init_backtrack_ta = tf.TensorArray(dtype=tf.int32,
                                           size=num_diagonals - 1,
                                           dynamic_size=False,
                                           infer_shape=False,
                                           element_shape=(None, None, 2),
                                           name="backtrack")
    else:
        init_backtrack_ta = None

    def cond(n, *_):
        """We run the loop until all elements are covered by diagonals.
    """
        return tf.less(n, num_diagonals)

    def body_forward(n, alpha_ta, *args):
        """body of the while_loop, loops over the diagonals of the alpha-tensor."""
        # alpha(t-1,u) + logprobs(t-1, u)
        # alpha_blank      + lp_blank

        lp_diagonal = log_probs_ta.read(n - 2)[:, :n - 1, :]  # (B, U|n, V)
        lp_diagonal = py_print_iteration_info("lp_diagonal",
                                              lp_diagonal,
                                              n,
                                              debug=debug)

        prev_diagonal = alpha_ta.read(n - 1)[:, :n]  # (B, n-1)
        prev_diagonal = py_print_iteration_info("prev_diagonal",
                                                prev_diagonal,
                                                n,
                                                debug=debug)

        alpha_blank = prev_diagonal  # (B, N)
        alpha_blank = tf.concat(
            [alpha_blank,
             tf.tile([[tf.constant(NEG_INF)]], [n_batch, 1])],
            axis=1)
        alpha_blank = py_print_iteration_info("alpha(blank)",
                                              alpha_blank,
                                              n,
                                              debug=debug)

        # (B, U, V) -> (B, U)
        lp_blank = lp_diagonal[:, :, blank_index]  # (B, U)
        lp_blank = tf.concat(
            [lp_blank,
             tf.tile([[tf.constant(NEG_INF)]], [n_batch, 1])],
            axis=1)
        lp_blank = py_print_iteration_info("lp(blank)",
                                           lp_blank,
                                           n,
                                           debug=debug)

        # (B,N-1) ; (B,1) ->  (B, N)
        alpha_y = prev_diagonal
        alpha_y = tf.concat(
            [tf.tile([[tf.constant(NEG_INF)]], [n_batch, 1]), alpha_y], axis=1)
        alpha_y = py_print_iteration_info("alpha(y)", alpha_y, n, debug=debug)

        labels_max_len = tf.minimum(max_target - 1, n - 1)
        labels_shifted = labels[:, :labels_max_len]  # (B, U-1|n-1)
        labels_shifted = py_print_iteration_info("labels_shifted",
                                                 labels_shifted,
                                                 n,
                                                 debug=debug)
        batchs, rows = tf.meshgrid(tf.range(n_batch),
                                   tf.range(labels_max_len),
                                   indexing='ij')
        lp_y_indices = tf.stack([batchs, rows, labels_shifted],
                                axis=-1)  # (B, U-1|n-1, 3)
        lp_y_indices = py_print_iteration_info("lp_y_indices",
                                               lp_y_indices,
                                               n,
                                               debug=debug)
        lp_y = tf.gather_nd(lp_diagonal[:, :, :], lp_y_indices)  # (B, U)
        # (B, U) ; (B, 1) -> (B, U+1)
        lp_y = tf.concat(
            [tf.tile([[tf.constant(NEG_INF)]], [n_batch, 1]), lp_y], axis=1)
        lp_y = py_print_iteration_info("lp(y)", lp_y, n, debug=debug)

        cut_off = max_target
        alpha_y = tf.cond(tf.greater(n, max_target),
                          lambda: alpha_y[:, :cut_off], lambda: alpha_y)
        lp_blank = tf.cond(tf.greater(n, max_target),
                           lambda: lp_blank[:, :cut_off], lambda: lp_blank)
        alpha_blank = tf.cond(tf.greater(n, max_target),
                              lambda: alpha_blank[:, :cut_off],
                              lambda: alpha_blank)

        # all should have shape (B, n)
        blank = alpha_blank + lp_blank
        y = alpha_y + lp_y
        red_op = tf.stack([blank, y], axis=0)  # (2, B, N)
        red_op = py_print_iteration_info("red-op", red_op, n, debug=debug)
        new_diagonal = tf.math.reduce_logsumexp(red_op, axis=0)  # (B, N)

        new_diagonal = new_diagonal[:, :n]
        new_diagonal = py_print_iteration_info("new_diagonal",
                                               new_diagonal,
                                               n,
                                               debug=debug)

        if with_alignment:
            backtrack_ta = args[0]
            argmax_idx = tf.argmax([blank, y], axis=0)
            max_len_diag = tf.minimum(n, max_target)
            u_ranged = tf.tile(tf.range(max_len_diag)[None],
                               [n_batch, 1])  # (B, n|U)
            blank_tiled = tf.tile([[blank_index]], [n_batch, 1])

            stack_blank_sel = tf.stack(
                [u_ranged, tf.tile(blank_tiled, [1, max_len_diag])], axis=-1)

            b, r = tf.meshgrid(tf.range(n_batch),
                               tf.maximum(0,
                                          tf.range(max_len_diag) - 1),
                               indexing='ij')
            labels_indices = tf.stack([b, r], axis=-1)
            labels_emit_sel = tf.gather_nd(
                labels, labels_indices)  # (B, n)  labels[u-1]
            stack_emit_sel = tf.stack([u_ranged - 1, labels_emit_sel], axis=-1)
            best_sel = tf.where(
                tf.tile(tf.equal(argmax_idx, 0)[..., None], [1, 1, 2]),
                stack_blank_sel,  # blank
                stack_emit_sel  # emit
            )
            backtrack_ta = backtrack_ta.write(n - 1, best_sel)
        else:
            backtrack_ta = None
        return [n + 1, alpha_ta.write(n, new_diagonal)
                ] + ([backtrack_ta] if with_alignment else [])

    init_n = tf.constant(2)
    if with_alignment:
        final_n, alpha_out_ta, backtrack_out_ta = tf.while_loop(
            cond,
            body_forward,
            [init_n, init_alpha_ta, init_backtrack_ta],
            parallel_iterations=1,  # iterative computation
            name="rnnt")
    else:
        final_n, alpha_out_ta = tf.while_loop(
            cond,
            body_forward,
            [init_n, init_alpha_ta],
            parallel_iterations=1,  # iterative computation
            name="rnnt")
        backtrack_out_ta = None

    # p(y|x) = alpha(T,U) * blank(T,U)  (--> in log-space)

    # (B,): batch index -> diagonal index
    diag_idxs = input_lengths + label_lengths  # (B,)

    # (B,): batch index -> index within diagonal
    within_diag_idx = label_lengths

    res_ta = tf.TensorArray(
        dtype=tf.float32,
        clear_after_read=True,
        size=n_batch,
        dynamic_size=False,
        infer_shape=False,
        element_shape=(),
        name="alpha_diagonals",
    )

    def ta_read_body(i, res_loop_ta):
        """Reads from the alpha-diagonals TensorArray. We need this because of the inconsistent shapes in the TA."""
        ta_item = alpha_out_ta.read(diag_idxs[i])[i]
        return i + 1, res_loop_ta.write(i, ta_item[within_diag_idx[i]])

    final_i, a_ta = tf.while_loop(lambda i, _: tf.less(i, n_batch),
                                  ta_read_body,
                                  (tf.constant(0, tf.int32), res_ta))
    indices = tf.stack(
        [
            tf.range(n_batch),
            input_lengths - 1,  # noqa T-1
            label_lengths,  # U-1
            tf.tile([blank_index], [n_batch]),
        ],
        axis=-1)  # (B, 4)
    ll_tf = a_ta.stack() + tf.gather_nd(log_probs, indices)

    if with_alignment:
        assert backtrack_out_ta is not None
        alignments = backtrack_alignment_tf(backtrack_out_ta, input_lengths,
                                            label_lengths, blank_index)
        return ll_tf, alignments
    else:
        return ll_tf
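A hedged usage sketch for the function above, assuming a TF1-style graph-mode environment and that the helpers it relies on (NEG_INF, tf_shift_logprobs, py_print_iteration_info, and backtrack_alignment_tf for the alignment path) are defined alongside it; shapes and values are illustrative.

import numpy as np
import tensorflow as tf

B, T, U, V = 2, 5, 3, 8  # batch, time, target length, vocabulary size
logits = np.random.randn(B, T, U + 1, V).astype(np.float32)
log_probs = tf.nn.log_softmax(tf.constant(logits), axis=-1)            # (B, T, U+1, V)
labels = tf.constant(np.random.randint(1, V, size=(B, U)), dtype=tf.int32)

costs = rnnt_loss(log_probs, labels,
                  input_lengths=tf.constant([T, T - 1], dtype=tf.int32),
                  label_lengths=tf.constant([U, U - 1], dtype=tf.int32),
                  blank_index=0)
with tf.Session() as sess:
    print(sess.run(costs))  # one log-likelihood per batch entry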
Example #30
def neural_voxel_renderer_plus(voxels,
                               rerendering,
                               light_pos,
                               size=4,
                               norm2d='batchnorm',
                               norm3d='batchnorm'):
  """Neural Voxel Renderer + keras model."""
  with tf.name_scope('Network/'):

    voxels = layers.Input(tensor=voxels)
    rerendering = layers.Input(tensor=rerendering)
    light_pos = layers.Input(tensor=light_pos)

    nf_2d = 512

    with tf.name_scope('VoxelProcessing'):
      vol0_a = layer_utils.conv_block_3d(voxels,
                                         nfilters=16,
                                         size=size,
                                         strides=2,
                                         normalization=norm3d)  # 64x64x64x16
      vol0_b = layer_utils.conv_block_3d(vol0_a,
                                         nfilters=16,
                                         size=size,
                                         strides=1,
                                         normalization=norm3d)  # 64x64x64x16
      vol1_a = layer_utils.conv_block_3d(vol0_b,
                                         nfilters=16,
                                         size=size,
                                         strides=2,
                                         normalization=norm3d)  # 32x32x32x16
      vol1_b = layer_utils.conv_block_3d(vol1_a,
                                         nfilters=32,
                                         size=size,
                                         strides=1,
                                         normalization=norm3d)  # 32x32x32x32
      vol1_c = layer_utils.conv_block_3d(vol1_b,
                                         nfilters=32,
                                         size=size,
                                         strides=1,
                                         normalization=norm3d)  # 32x32x32x32
      shortcut = vol1_c
      vol_a1 = layer_utils.residual_block_3d(vol1_c,
                                             32,
                                             strides=(1, 1, 1),
                                             normalization=norm3d)  # 32x
      vol_a2 = layer_utils.residual_block_3d(vol_a1,
                                             32,
                                             strides=(1, 1, 1),
                                             normalization=norm3d)  # 32x
      vol_a3 = layer_utils.residual_block_3d(vol_a2,
                                             32,
                                             strides=(1, 1, 1),
                                             normalization=norm3d)  # 32x
      vol_a4 = layer_utils.residual_block_3d(vol_a3,
                                             32,
                                             strides=(1, 1, 1),
                                             normalization=norm3d)  # 32x
      vol_a5 = layer_utils.residual_block_3d(vol_a4,
                                             32,
                                             strides=(1, 1, 1),
                                             normalization=norm3d)  # 32x
      encoded_vol = layers.add([shortcut, vol_a5])
      encoded_vol = layers.Reshape([32, 32, 32*32])(encoded_vol)
      encoded_vol = layers.Conv2D(nf_2d,
                                  kernel_size=1,
                                  strides=(1, 1),
                                  padding='same',
                                  kernel_initializer=initializer)(encoded_vol)
      latent_projection = layers.LeakyReLU()(encoded_vol)  # 32x32x512

    with tf.name_scope('ProjectionProcessing'):
      shortcut = latent_projection  # 32x32xnf_2d
      e1 = layer_utils.residual_block_2d(latent_projection,
                                         nfilters=nf_2d,
                                         strides=(1, 1),
                                         normalization=norm2d)  # 32x32xnf_2d
      e2 = layer_utils.residual_block_2d(e1,
                                         nfilters=nf_2d,
                                         strides=(1, 1),
                                         normalization=norm2d)  # 32x32xnf_2d
      e3 = layer_utils.residual_block_2d(e2,
                                         nfilters=nf_2d,
                                         strides=(1, 1),
                                         normalization=norm2d)  # 32x32xnf_2d
      e4 = layer_utils.residual_block_2d(e3,
                                         nfilters=nf_2d,
                                         strides=(1, 1),
                                         normalization=norm2d)  # 32x32xnf_2d
      e5 = layer_utils.residual_block_2d(e4,
                                         nfilters=nf_2d,
                                         strides=(1, 1),
                                         normalization=norm2d)  # 32x32xnf_2d
      encoded_proj = layers.add([shortcut, e5])  # 32x32xnf_2d

    with tf.name_scope('LightProcessing'):
      fc_light = layers.Dense(64, kernel_initializer=initializer)(light_pos)
      light_code = layers.Dense(64, kernel_initializer=initializer)(fc_light)
      light_code = \
        layers.Lambda(lambda v: tf.tile(v[0], [1, 32*32]))([light_code])
      light_code = layers.Reshape((32, 32, 64))(light_code)  # 32x32x64

    with tf.name_scope('Merger'):
      latent_code_final = layers.concatenate([encoded_proj, light_code])
      latent_code_final = layer_utils.conv_block_2d(latent_code_final,
                                                    nfilters=nf_2d,
                                                    size=size,
                                                    strides=1,
                                                    normalization=norm3d)
      shortcut = latent_code_final
      m1 = layer_utils.residual_block_2d(latent_code_final,
                                         nfilters=nf_2d,
                                         strides=(1, 1),
                                         normalization=norm2d)  # 32x32xnf_2d
      m2 = layer_utils.residual_block_2d(m1,
                                         nfilters=nf_2d,
                                         strides=(1, 1),
                                         normalization=norm2d)  # 32x32xnf_2d
      m3 = layer_utils.residual_block_2d(m2,
                                         nfilters=nf_2d,
                                         strides=(1, 1),
                                         normalization=norm2d)  # 32x32xnf_2d
      m4 = layer_utils.residual_block_2d(m3,
                                         nfilters=nf_2d,
                                         strides=(1, 1),
                                         normalization=norm2d)  # 32x32xnf_2d
      m5 = layer_utils.residual_block_2d(m4,
                                         nfilters=nf_2d,
                                         strides=(1, 1),
                                         normalization=norm2d)  # 32x32xnf_2d

      latent_code_final2 = layers.add([shortcut, m5])  # 32x32xnf_2d

    with tf.name_scope('Decoder'):
      d7 = layer_utils.conv_t_block_2d(latent_code_final2,
                                       nfilters=128,
                                       size=size,
                                       strides=2,
                                       normalization=norm2d)  # 64x64x128
      d7 = layer_utils.conv_block_2d(d7,
                                     nfilters=128,
                                     size=size,
                                     strides=1,
                                     normalization=norm2d)  # 64x64x128
      d8 = layer_utils.conv_t_block_2d(d7,
                                       nfilters=64,
                                       size=size,
                                       strides=2,
                                       normalization=norm2d)  # 128x128x64
      d8 = layer_utils.conv_block_2d(d8,
                                     nfilters=64,
                                     size=size,
                                     strides=1,
                                     normalization=norm2d)  # 128x128x64
      d9 = layer_utils.conv_t_block_2d(d8,
                                       nfilters=32,
                                       size=size,
                                       strides=2,
                                       normalization=norm2d)  # 256x256x32
      d9 = layer_utils.conv_block_2d(d9,
                                     nfilters=32,
                                     size=size,
                                     strides=1,
                                     normalization=norm2d)  # 256x256x32
      rendered_image = layers.Conv2D(32,
                                     size,
                                     strides=1,
                                     padding='same',
                                     kernel_initializer=initializer,
                                     use_bias=False)(d9)  # 256x256x32

    with tf.name_scope('ImageProcessingNetwork'):
      ec1 = layer_utils.conv_block_2d(rerendering,
                                      nfilters=32,
                                      size=size,
                                      strides=1,
                                      normalization=norm2d)  # 256x
      ec2 = layer_utils.conv_block_2d(ec1,
                                      nfilters=32,
                                      size=size,
                                      strides=1,
                                      normalization=norm2d)  # 256x

    with tf.name_scope('NeuralRerenderingNetwork'):
      latent_img = layers.add([rendered_image, ec2])
      target_code = unet_3x_with_res_in_mid(latent_img, 32, norm2d=norm2d)
      out0 = layer_utils.conv_block_2d(target_code,
                                       nfilters=32,
                                       size=size,
                                       strides=1,
                                       normalization=norm2d)  # 256x
      predicted_image = layers.Conv2D(3,
                                      size,
                                      strides=1,
                                      padding='same',
                                      kernel_initializer=initializer,
                                      use_bias=False)(out0)  # 256x256x3

    return tf.keras.Model(inputs=[voxels, rerendering, light_pos],
                          outputs=[predicted_image])
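A small NumPy sketch of the light-code broadcasting in the LightProcessing block above: assuming the light code is a per-example 64-dim vector, tiling it 32*32 times and reshaping produces a 32x32x64 map in which every spatial position carries the same vector, ready to be concatenated with the projected features.

import numpy as np

light_code = np.random.randn(1, 64).astype(np.float32)   # (batch, 64)
tiled = np.tile(light_code, [1, 32 * 32])                 # (batch, 32*32*64)
light_map = tiled.reshape([-1, 32, 32, 64])               # (batch, 32, 32, 64)
print(light_map.shape)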