Example #1
def random_pad_to_aspect_ratio(image,
                               aspect_ratio=1.0,
                               min_padded_size_ratio=(1.0, 1.0),
                               max_padded_size_ratio=(2.0, 2.0),
                               seed=None):
    """Randomly zero pads an image to the specified aspect ratio.

    Pads the image so that the resulting image will have the specified aspect
    ratio without scaling less than the min_padded_size_ratio or more than the
    max_padded_size_ratio. If the min_padded_size_ratio or max_padded_size_ratio
    is lower than what is possible to maintain the aspect ratio, then this method
    will use the least padding to achieve the specified aspect ratio.

    Args:
      image: rank 3 float32 tensor containing 1 image -> [height, width, channels]
             with pixel values varying between [0, 1].
      aspect_ratio: aspect ratio of the final image.
      min_padded_size_ratio: min ratio of padded image height and width to the
                             input image's height and width.
      max_padded_size_ratio: max ratio of padded image height and width to the
                             input image's height and width.
      seed: random seed.

    Returns:
      image: image with the same rank as the input image.

    Raises:
      ValueError: If image is not a 3D tensor.
    """
    if len(image.get_shape()) != 3:
        raise ValueError('Image should be 3D tensor')

    with tf.name_scope('RandomPadToAspectRatio', values=[image]):
        image_shape = tf.shape(image)
        image_height = tf.to_float(image_shape[0])
        image_width = tf.to_float(image_shape[1])
        image_aspect_ratio = image_width / image_height
        new_aspect_ratio = tf.constant(aspect_ratio, dtype=tf.float32)
        target_height = tf.cond(image_aspect_ratio <= new_aspect_ratio,
                                lambda: image_height,
                                lambda: image_width / new_aspect_ratio)
        target_width = tf.cond(image_aspect_ratio >= new_aspect_ratio,
                               lambda: image_width,
                               lambda: image_height * new_aspect_ratio)

        min_height = tf.maximum(min_padded_size_ratio[0] * image_height,
                                target_height)
        min_width = tf.maximum(min_padded_size_ratio[1] * image_width,
                               target_width)
        max_height = tf.maximum(max_padded_size_ratio[0] * image_height,
                                target_height)
        max_width = tf.maximum(max_padded_size_ratio[1] * image_width,
                               target_width)

        max_scale = tf.minimum(max_height / target_height,
                               max_width / target_width)
        min_scale = tf.minimum(
            max_scale,
            tf.maximum(min_height / target_height, min_width / target_width))

        scale = tf.random_uniform([], min_scale, max_scale, seed=seed)

        target_height = tf.round(scale * target_height)
        target_width = tf.round(scale * target_width)

        new_image = tf.image.pad_to_bounding_box(image, 0, 0,
                                                 tf.to_int32(target_height),
                                                 tf.to_int32(target_width))

        return new_image
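
# A minimal usage sketch (hypothetical shapes; assumes the TF1-style graph mode
# used throughout these examples):
example_image = tf.random_uniform([300, 400, 3], minval=0.0, maxval=1.0)
example_padded = random_pad_to_aspect_ratio(example_image,
                                            aspect_ratio=1.0,
                                            max_padded_size_ratio=(1.5, 1.5),
                                            seed=42)
# The original content stays in the top-left corner; zeros are padded on the
# bottom/right so the result moves toward the requested aspect ratio.
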
def scale_values(im):
    # Rescales intensities from [lo, hi] to [0, 255] and casts to uint8.
    # Note: `lo` and `hi` are expected to come from the enclosing scope of this
    # fragment; they are not defined in the snippet itself.
    scale = 255.0 / (hi - lo)
    offset = -lo * scale
    im = tf.to_float(im) * scale + offset
    im = tf.clip_by_value(im, 0.0, 255.0)
    return tf.cast(im, tf.uint8)
def build_model_graph(features, labels, is_training, params):
    """Builds the forward model graph."""
    use_batched_nms = (not params['use_tpu'] and params['use_batched_nms'])
    is_gpu_inference = (not is_training and use_batched_nms)
    model_outputs = {}

    if is_training and params['transpose_input']:
        if (params['backbone'].startswith('resnet')
                and params['conv0_space_to_depth_block_size'] > 0):
            features['images'] = tf.transpose(features['images'], [2, 0, 1, 3])
        else:
            features['images'] = tf.transpose(features['images'], [3, 0, 1, 2])

    batch_size, image_height, image_width, _ = (
        features['images'].get_shape().as_list())

    conv0_space_to_depth_block_size = 0
    if (is_training and (params['backbone'].startswith('resnet')
                         and params['conv0_space_to_depth_block_size'] > 0)):
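        # The input pipeline is expected to have already applied space-to-depth to
        # the images, so the tensor's spatial dims are the true image size divided
        # by the block size; scale them back up before building anchors.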
        conv0_space_to_depth_block_size = params[
            'conv0_space_to_depth_block_size']
        image_height *= conv0_space_to_depth_block_size
        image_width *= conv0_space_to_depth_block_size

    if 'source_ids' not in features:
        features['source_ids'] = -1 * tf.ones([batch_size], dtype=tf.float32)

    all_anchors = anchors.Anchors(params['min_level'], params['max_level'],
                                  params['num_scales'],
                                  params['aspect_ratios'],
                                  params['anchor_scale'],
                                  (image_height, image_width))

    if 'resnet' in params['backbone']:
        with tf.variable_scope(params['backbone']):
            resnet_fn = resnet.resnet_v1(
                params['backbone'],
                conv0_kernel_size=params['conv0_kernel_size'],
                conv0_space_to_depth_block_size=conv0_space_to_depth_block_size,
                num_batch_norm_group=params['num_batch_norm_group'])
            backbone_feats = resnet_fn(
                features['images'], (params['is_training_bn'] and is_training))
    elif 'mnasnet' in params['backbone']:
        with tf.variable_scope(params['backbone']):
            _, endpoints = mnasnet_models.build_mnasnet_base(
                features['images'],
                params['backbone'],
                training=(params['is_training_bn'] and is_training),
                override_params={'use_keras': False})

            backbone_feats = {
                2: endpoints['reduction_2'],
                3: endpoints['reduction_3'],
                4: endpoints['reduction_4'],
                5: endpoints['reduction_5'],
            }
    else:
        raise ValueError('Not a valid backbone option: %s' %
                         params['backbone'])

    fpn_feats = fpn.fpn(backbone_feats, params['min_level'],
                        params['max_level'])
    model_outputs.update({
        'fpn_features': fpn_feats,
    })

    rpn_score_outputs, rpn_box_outputs = heads.rpn_head(
        fpn_feats, params['min_level'], params['max_level'],
        len(params['aspect_ratios']) * params['num_scales'])

    if is_training:
        rpn_pre_nms_topn = params['rpn_pre_nms_topn']
        rpn_post_nms_topn = params['rpn_post_nms_topn']
    else:
        rpn_pre_nms_topn = params['test_rpn_pre_nms_topn']
        rpn_post_nms_topn = params['test_rpn_post_nms_topn']

    rpn_box_scores, rpn_box_rois = roi_ops.multilevel_propose_rois(
        rpn_score_outputs,
        rpn_box_outputs,
        all_anchors,
        features['image_info'],
        rpn_pre_nms_topn,
        rpn_post_nms_topn,
        params['rpn_nms_threshold'],
        params['rpn_min_size'],
        bbox_reg_weights=None,
        use_batched_nms=use_batched_nms)
    rpn_box_rois = tf.to_float(rpn_box_rois)
    if is_training:
        rpn_box_rois = tf.stop_gradient(rpn_box_rois)
        rpn_box_scores = tf.stop_gradient(rpn_box_scores)

    if is_training:
        # Sampling
        box_targets, class_targets, rpn_box_rois, proposal_to_label_map = (
            training_ops.proposal_label_op(
                rpn_box_rois,
                labels['gt_boxes'],
                labels['gt_classes'],
                batch_size_per_im=params['batch_size_per_im'],
                fg_fraction=params['fg_fraction'],
                fg_thresh=params['fg_thresh'],
                bg_thresh_hi=params['bg_thresh_hi'],
                bg_thresh_lo=params['bg_thresh_lo']))

    # Performs multi-level RoIAlign.
    box_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
        fpn_feats,
        rpn_box_rois,
        output_size=7,
        is_gpu_inference=is_gpu_inference)

    class_outputs, box_outputs, _ = heads.box_head(
        box_roi_features,
        num_classes=params['num_classes'],
        mlp_head_dim=params['fast_rcnn_mlp_head_dim'])

    if not is_training:
        if is_gpu_inference:
            generate_detections_fn = postprocess_ops.generate_detections_gpu
        else:
            generate_detections_fn = postprocess_ops.generate_detections_tpu

        detections = generate_detections_fn(
            class_outputs, box_outputs, rpn_box_rois, features['image_info'],
            params['test_rpn_post_nms_topn'],
            params['test_detections_per_image'], params['test_nms'],
            params['bbox_reg_weights'])

        model_outputs.update({
            'num_detections': detections[0],
            'detection_boxes': detections[1],
            'detection_classes': detections[2],
            'detection_scores': detections[3],
        })
    else:
        encoded_box_targets = training_ops.encode_box_targets(
            rpn_box_rois, box_targets, class_targets,
            params['bbox_reg_weights'])
        model_outputs.update({
            'rpn_score_outputs': rpn_score_outputs,
            'rpn_box_outputs': rpn_box_outputs,
            'class_outputs': class_outputs,
            'box_outputs': box_outputs,
            'class_targets': class_targets,
            'box_targets': encoded_box_targets,
            'box_rois': rpn_box_rois,
        })

    # Faster-RCNN mode.
    if not params['include_mask']:
        # Print #parameters and #FLOPs in model.
        compute_model_statistics(batch_size, is_training=is_training)

        return model_outputs

    # Mask sampling
    if not is_training:
        selected_box_rois = model_outputs['detection_boxes']
        class_indices = model_outputs['detection_classes']
        # If using GPU for inference, delay the cast until the Gather ops show up,
        # since GPU inference handles floating point better.
        # TODO(laigd): revisit this when newer versions of the GPU libraries are
        # released.
        if not is_gpu_inference:
            class_indices = tf.to_int32(class_indices)
    else:
        (selected_class_targets, selected_box_targets, selected_box_rois,
         proposal_to_label_map) = (training_ops.select_fg_for_masks(
             class_targets,
             box_targets,
             rpn_box_rois,
             proposal_to_label_map,
             max_num_fg=int(params['batch_size_per_im'] *
                            params['fg_fraction'])))
        class_indices = tf.to_int32(selected_class_targets)

    mask_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
        fpn_feats,
        selected_box_rois,
        output_size=14,
        is_gpu_inference=is_gpu_inference)
    mask_outputs = heads.mask_head(mask_roi_features,
                                   class_indices,
                                   num_classes=params['num_classes'],
                                   mrcnn_resolution=params['mrcnn_resolution'],
                                   is_gpu_inference=is_gpu_inference)

    if is_training:
        mask_targets = training_ops.get_mask_targets(
            selected_box_rois, proposal_to_label_map, selected_box_targets,
            labels['cropped_gt_masks'], params['mrcnn_resolution'])
        model_outputs.update({
            'mask_outputs': mask_outputs,
            'mask_targets': mask_targets,
            'selected_class_targets': selected_class_targets,
        })
    else:
        model_outputs.update({
            'detection_masks': tf.nn.sigmoid(mask_outputs),
        })

    if params['num_attributes']:
        attribute_outputs = heads.attributes_head(
            roi_features=mask_roi_features,
            num_attributes=params['num_attributes'],
            mlp_head_dim=params['fast_rcnn_mlp_head_dim'],
        )

        if is_training:
            attribute_targets = tf.gather(
                labels['gt_attributes'], proposal_to_label_map,
                batch_dims=1)  # [batch, K, num_attributes]

            model_outputs.update({
                'attribute_outputs': attribute_outputs,
                'attribute_targets': attribute_targets,
            })
        else:
            model_outputs['detection_attributes'] = tf.nn.sigmoid(
                attribute_outputs)

    # Print #parameters and #FLOPs in model.
    compute_model_statistics(batch_size, is_training=is_training)

    return model_outputs
Example #4
def _apply_negative_infinity_mask(tensor, mask):
    """Where mask is true, add a large negative value."""
    tensor += tf.to_float(mask) * -INF
    tensor = tf.maximum(tensor, -INF)
    return tensor
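
# Rough illustration of the helper above (hypothetical values; INF is assumed to
# be a large positive module-level constant, e.g. 1e9):
INF = 1e9
example_logits = tf.constant([[2.0, 1.0, 0.5]])
example_mask = tf.constant([[False, True, False]])  # True marks positions to drop
masked_logits = _apply_negative_infinity_mask(example_logits, example_mask)
example_probs = tf.nn.softmax(masked_logits)  # masked position gets ~0 probability
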
Example #5
    def inject_latent(self, layer, inputs, target, action):
        """Inject a deterministic latent based on the target frame."""
        hparams = self.hparams
        final_filters = common_layers.shape_list(layer)[-1]
        filters = hparams.hidden_size
        kernel = (4, 4)
        layer_shape = common_layers.shape_list(layer)
        activation_fn = common_layers.belu
        if hparams.activation_fn == "relu":
            activation_fn = tf.nn.relu

        def add_bits(layer, bits):
            z_mul = tfl.dense(bits, final_filters, name="unbottleneck_mul")
            if not hparams.complex_addn:
                return layer + z_mul
            layer *= tf.nn.sigmoid(z_mul)
            z_add = tfl.dense(bits, final_filters, name="unbottleneck_add")
            layer += z_add
            return layer

        if not self.is_training:
            if hparams.full_latent_tower:
                rand = tf.random_uniform(layer_shape[:-1] +
                                         [hparams.bottleneck_bits])
                bits = 2.0 * tf.to_float(tf.less(0.5, rand)) - 1.0
            else:
                bits, _ = discretization.predict_bits_with_lstm(
                    layer,
                    hparams.latent_predictor_state_size,
                    hparams.bottleneck_bits,
                    temperature=hparams.latent_predictor_temperature)
                bits = tf.expand_dims(tf.expand_dims(bits, axis=1), axis=2)
            return add_bits(layer, bits), 0.0

        # Embed.
        frames = tf.concat(inputs + [target], axis=-1)
        x = tfl.dense(
            frames,
            filters,
            name="latent_embed",
            bias_initializer=tf.random_normal_initializer(stddev=0.01))
        x = common_attention.add_timing_signal_nd(x)

        # Add embedded action if present.
        if action is not None:
            x = common_video.inject_additional_input(x, action,
                                                     "action_enc_latent",
                                                     hparams.action_injection)

        if hparams.full_latent_tower:
            for i in range(hparams.num_compress_steps):
                with tf.variable_scope("latent_downstride%d" % i):
                    x = common_layers.make_even_size(x)
                    if i < hparams.filter_double_steps:
                        filters *= 2
                    x = common_attention.add_timing_signal_nd(x)
                    x = tfl.conv2d(x,
                                   filters,
                                   kernel,
                                   activation=activation_fn,
                                   strides=(2, 2),
                                   padding="SAME")
                    x = common_layers.layer_norm(x)
        else:
            x = common_layers.double_discriminator(x)
            x = tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)

        bits, bits_clean = discretization.tanh_discrete_bottleneck(
            x, hparams.bottleneck_bits, hparams.bottleneck_noise,
            hparams.discretize_warmup_steps, hparams.mode)
        # Default the prediction loss so the return value below is defined even
        # when the full latent tower is used.
        pred_loss = 0.0
        if not hparams.full_latent_tower:
            _, pred_loss = discretization.predict_bits_with_lstm(
                layer,
                hparams.latent_predictor_state_size,
                hparams.bottleneck_bits,
                target_bits=bits_clean)
            # Mix bits from the latent with predicted bits on the forward pass, as noise.
            if hparams.latent_rnn_max_sampling > 0.0:
                with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                    bits_pred, _ = discretization.predict_bits_with_lstm(
                        layer,
                        hparams.latent_predictor_state_size,
                        hparams.bottleneck_bits,
                        temperature=hparams.latent_predictor_temperature)
                    bits_pred = tf.expand_dims(tf.expand_dims(bits_pred,
                                                              axis=1),
                                               axis=2)
                # Straight-through estimator: the forward value is bits_pred, but
                # gradients flow back to bits_clean.
                bits_pred = bits_clean + tf.stop_gradient(bits_pred -
                                                          bits_clean)
                # Choose which bits come from the predictor; each bit is sampled
                # with probability bit_p.
                which_bit = tf.random_uniform(common_layers.shape_list(bits))
                bit_p = common_layers.inverse_lin_decay(
                    hparams.latent_rnn_warmup_steps)
                bit_p *= hparams.latent_rnn_max_sampling
                bits = tf.where(which_bit < bit_p, bits_pred, bits)

        res = add_bits(layer, bits)
        # During training, sometimes skip the latent to help action-conditioning.
        res_p = common_layers.inverse_lin_decay(
            hparams.latent_rnn_warmup_steps / 2)
        res_p *= hparams.latent_use_max_probability
        res_rand = tf.random_uniform([layer_shape[0]])
        res = tf.where(res_rand < res_p, res, layer)
        return res, pred_loss
Example #6
def normalize(data):
    data['image'] = tf.to_float(data['image']) / 255.
    return data
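
# A small sketch of how normalize() is typically applied, assuming a tf.data
# pipeline whose elements are dicts with an 'image' key (hypothetical data):
example_dataset = tf.data.Dataset.from_tensor_slices(
    {'image': tf.zeros([8, 32, 32, 3], dtype=tf.uint8)})
example_dataset = example_dataset.map(normalize)  # images become float32 in [0, 1]
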
Example #7
def generate_detections_per_image_op(cls_outputs,
                                     box_outputs,
                                     anchor_boxes,
                                     image_id,
                                     image_info,
                                     num_detections=100,
                                     pre_nms_num_detections=1000,
                                     nms_threshold=0.3,
                                     bbox_reg_weights=(10., 10., 5., 5.)):
    """Generates detections with model outputs and anchors.

    Args:
      cls_outputs: a Tensor with shape [N, num_classes], which stacks class
        logit outputs on all feature levels. The N is the number of total anchors
        on all levels. The num_classes is the number of classes predicted by the
        model. Note that the cls_outputs should be the output of softmax().
      box_outputs: a Tensor with shape [N, num_classes*4], which stacks
        box regression outputs on all feature levels. The N is the number of total
        anchors on all levels.
      anchor_boxes: a Tensor with shape [N, 4], which stacks anchors on all
        feature levels. The N is the number of total anchors on all levels.
      image_id: an integer number to specify the image id.
      image_info: a tensor of shape [5] which encodes the input image's [height,
        width, scale, original_height, original_width]
      num_detections: Number of detections after NMS.
      pre_nms_num_detections: Number of candidates before NMS.
      nms_threshold: a float number to specify the threshold of NMS.
      bbox_reg_weights: a list of 4 float scalars, which are default weights on
        (dx, dy, dw, dh) for normalizing bbox regression targets.
    Returns:
      detections: detection results in a tensor with each row representing
        [image_id, ymin, xmin, ymax, xmax, score, class]
    """
    num_boxes, num_classes = cls_outputs.get_shape().as_list()

    # Removes background class scores.
    cls_outputs = cls_outputs[:, 1:num_classes]
    top_k_scores, top_k_indices_with_classes = tf.nn.top_k(
        tf.reshape(cls_outputs, [-1]), k=pre_nms_num_detections, sorted=True)
    classes = tf.mod(top_k_indices_with_classes, num_classes - 1)
    top_k_indices = tf.floordiv(top_k_indices_with_classes, num_classes - 1)

    anchor_boxes = tf.gather(anchor_boxes, top_k_indices)
    box_outputs = tf.reshape(box_outputs,
                             [num_boxes, num_classes, 4])[:, 1:num_classes, :]
    box_outputs = tf.gather_nd(box_outputs,
                               tf.stack([top_k_indices, classes], axis=1))

    # Applies bounding box regression to anchors.
    boxes = box_utils.batch_decode_box_outputs_op(
        tf.expand_dims(anchor_boxes, axis=0),
        tf.expand_dims(box_outputs, axis=0), bbox_reg_weights)[0]
    boxes = box_utils.clip_boxes(tf.expand_dims(boxes, axis=0),
                                 tf.expand_dims(image_info[:2], axis=0))[0]

    classes = tf.tile(tf.reshape(classes, [1, pre_nms_num_detections]),
                      [num_classes - 1, 1])
    scores = tf.tile(tf.reshape(top_k_scores, [1, pre_nms_num_detections]),
                     [num_classes - 1, 1])
    boxes = tf.tile(tf.reshape(boxes, [1, pre_nms_num_detections, 4]),
                    [num_classes - 1, 1, 1])

    class_bitmask = tf.tile(
        tf.reshape(tf.range(num_classes - 1), [num_classes - 1, 1]),
        [1, pre_nms_num_detections])
    scores = tf.where(tf.equal(classes, class_bitmask), scores,
                      tf.zeros_like(scores))
    scores = tf.where(tf.greater(scores, 0.05), scores, tf.zeros_like(scores))
    # Reshape classes to be compatible with the top_k function.
    classes = tf.reshape(classes, [num_classes - 1, pre_nms_num_detections, 1])
    scores, sorted_tensors = box_utils.top_k(scores,
                                             k=pre_nms_num_detections,
                                             tensors=[boxes, classes])
    boxes = sorted_tensors[0]
    classes = tf.reshape(sorted_tensors[1],
                         [num_classes - 1, pre_nms_num_detections])

    idx, num_valid = non_max_suppression.non_max_suppression_padded(
        scores,
        boxes,
        max_output_size=num_detections,
        iou_threshold=nms_threshold,
        level=0)

    post_nms_boxes = non_max_suppression.gather_boxes_by_indices(
        boxes, num_detections, idx, num_valid)
    post_nms_scores = non_max_suppression.gather_scores_by_indices(
        scores, num_detections, idx, num_valid)

    # Sorts all results.
    sorted_scores, sorted_indices = tf.nn.top_k(tf.to_float(
        tf.reshape(post_nms_scores, [-1])),
                                                k=num_detections,
                                                sorted=True)
    post_nms_boxes = tf.gather(tf.reshape(post_nms_boxes, [-1, 4]),
                               sorted_indices)
    classes = tf.batch_gather(classes, idx)
    post_nms_classes = tf.gather(tf.reshape(classes, [-1]), sorted_indices) + 1

    if isinstance(image_id, int):
        image_id = tf.constant(image_id)
    image_id = tf.reshape(image_id, [])
    detections_result = tf.stack([
        tf.to_float(tf.fill(tf.shape(sorted_scores), image_id)),
        post_nms_boxes[:, 0],
        post_nms_boxes[:, 1],
        post_nms_boxes[:, 2],
        post_nms_boxes[:, 3],
        sorted_scores,
        tf.to_float(post_nms_classes),
    ],
                                 axis=1)
    return detections_result
Example #8
def imagenet_inputs(batch_size,
                    image_size,
                    num_readers=1,
                    num_preprocess_threads=4):
    """Loads a batch of imagenet inputs.

    Used as a replacement for inception.image_processing.inputs in
    tensorflow/models in order to get around the use of hard-coded flags in the
    image_processing module.

    Args:
      batch_size: int, batch size.
      image_size: int. The images will be resized bilinearly to shape
          [image_size, image_size].
      num_readers: int, number of parallel readers.
      num_preprocess_threads: int, number of preprocessing threads per tower.
          Must be a multiple of 4.

    Returns:
      4-D tensor of images of shape [batch_size, image_size, image_size, 3], with
      values in [0, 1].

    Raises:
      IOError: If ImageNet data files cannot be found.
      ValueError: If `num_preprocess_threads` is not a multiple of 4 or
          `num_readers` is less than 1.
    """
    imagenet = imagenet_data.ImagenetData('train')

    with tf.name_scope('batch_processing'):
        data_files = imagenet.data_files()
        if data_files is None:
            raise IOError('No ImageNet data files found')

        # Create filename_queue.
        filename_queue = tf.train.string_input_producer(data_files,
                                                        shuffle=True,
                                                        capacity=16)

        if num_preprocess_threads % 4:
            raise ValueError('Please make num_preprocess_threads a multiple '
                             'of 4 (%d %% 4 != 0).' % num_preprocess_threads)

        if num_readers < 1:
            raise ValueError('Please make num_readers at least 1')

        # Approximate number of examples per shard.
        examples_per_shard = 1024
        # Size the random shuffle queue to balance between good global
        # mixing (more examples) and memory use (fewer examples).
        # 1 image uses 299*299*3*4 bytes = 1MB
        # The default input_queue_memory_factor is 16 implying a shuffling queue
        # size: examples_per_shard * 16 * 1MB = 17.6GB
        input_queue_memory_factor = 16
        min_queue_examples = examples_per_shard * input_queue_memory_factor
        examples_queue = tf.RandomShuffleQueue(
            capacity=min_queue_examples + 3 * batch_size,
            min_after_dequeue=min_queue_examples,
            dtypes=[tf.string])

        # Create multiple readers to populate the queue of examples.
        enqueue_ops = []
        for _ in range(num_readers):
            reader = imagenet.reader()
            _, value = reader.read(filename_queue)
            enqueue_ops.append(examples_queue.enqueue([value]))

        tf.train.queue_runner.add_queue_runner(
            tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops))
        example_serialized = examples_queue.dequeue()

        images_and_labels = []
        for _ in range(num_preprocess_threads):
            # Parse a serialized Example proto to extract the image and metadata.
            image_buffer, label_index, _, _ = _parse_example_proto(
                example_serialized)
            image = tf.image.decode_jpeg(image_buffer, channels=3)

            # pylint: disable=protected-access
            image = _aspect_preserving_resize(image, image_size + 2)
            image = _central_crop([image], image_size, image_size)[0]
            # pylint: enable=protected-access
            image.set_shape([image_size, image_size, 3])
            image = tf.to_float(image) / 255.0

            images_and_labels.append([image, label_index])

        images, label_index_batch = tf.train.batch_join(
            images_and_labels,
            batch_size=batch_size,
            capacity=2 * num_preprocess_threads * batch_size)

        images = tf.reshape(images,
                            shape=[batch_size, image_size, image_size, 3])

        # Display the training images in the visualizer.
        tf.summary.image('images', images)

        return images, tf.reshape(label_index_batch, [batch_size])
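
# A hypothetical usage sketch; assumes the ImageNet TFRecord shards referenced
# by imagenet_data.ImagenetData('train') are available on disk:
example_images, example_labels = imagenet_inputs(batch_size=32, image_size=224)
with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    image_batch = sess.run(example_images)  # [32, 224, 224, 3], values in [0, 1]
    coord.request_stop()
    coord.join(threads)
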
    def _build_model(self):
        """
        Builds the Tensorflow graph.
        """

        # Placeholders for our input
        # Our input is 4 grayscale frames of shape 84, 84 each
        self.X_pl = tf.placeholder(shape=[None, 84, 84, 4],
                                   dtype=tf.uint8,
                                   name="X")
        # The TD target value
        self.y_pl = tf.placeholder(shape=[None], dtype=tf.float32, name="y")
        # Integer id of which action was selected
        self.actions_pl = tf.placeholder(shape=[None],
                                         dtype=tf.int32,
                                         name="actions")

        X = tf.to_float(self.X_pl) / 255.0
        batch_size = tf.shape(self.X_pl)[0]

        # Three convolutional layers
        conv1 = tf.contrib.layers.conv2d(X, 32, 8, 4, activation_fn=tf.nn.relu)
        conv2 = tf.contrib.layers.conv2d(conv1,
                                         64,
                                         4,
                                         2,
                                         activation_fn=tf.nn.relu)
        conv3 = tf.contrib.layers.conv2d(conv2,
                                         64,
                                         3,
                                         1,
                                         activation_fn=tf.nn.relu)

        # Fully connected layers
        flattened = tf.contrib.layers.flatten(conv3)
        fc1 = tf.contrib.layers.fully_connected(flattened, 512)
        self.predictions = tf.contrib.layers.fully_connected(
            fc1, len(VALID_ACTIONS))

        # Get the predictions for the chosen actions only
        gather_indices = tf.range(batch_size) * tf.shape(
            self.predictions)[1] + self.actions_pl
        self.action_predictions = tf.gather(tf.reshape(self.predictions, [-1]),
                                            gather_indices)

        # Calculate the loss
        self.losses = tf.squared_difference(self.y_pl, self.action_predictions)
        self.loss = tf.reduce_mean(self.losses)

        # Optimizer Parameters from original paper
        self.optimizer = tf.train.RMSPropOptimizer(0.00025, 0.99, 0.0, 1e-6)
        self.train_op = self.optimizer.minimize(
            self.loss, global_step=tf.contrib.framework.get_global_step())

        # Summaries for Tensorboard
        self.summaries = tf.summary.merge([
            tf.summary.scalar("loss", self.loss),
            tf.summary.histogram("loss_hist", self.losses),
            tf.summary.histogram("q_values_hist", self.predictions),
            tf.summary.scalar("max_q_value", tf.reduce_max(self.predictions))
        ])
Example #10
def style_image_inputs(style_dataset_file,
                       batch_size=None,
                       image_size=None,
                       square_crop=False,
                       shuffle=True):
    """Loads a batch of random style image given the path of tfrecord dataset.

    Args:
      style_dataset_file: str, path to the tfrecord dataset of style files.
          The dataset is produced via the create_style_dataset.py script and is
          made of Example protobufs with the following features:
          * 'image_raw': byte encoding of the JPEG string of the style image.
          * 'label': integer identifier of the style image in [0, N - 1], where
                N is the number of examples in the dataset.
          * 'vgg_16/<LAYER_NAME>': Gram matrix at layer <LAYER_NAME> of the VGG-16
                network (<LAYER_NAME> in {conv,pool}{1,2,3,4,5}) for the style
                image.
      batch_size: int. If provided, batches style images. Defaults to None.
      image_size: int. The images will be resized bilinearly so that the smallest
          side has size image_size. Defaults to None.
      square_crop: bool. If True, square-crops to [image_size, image_size].
          Defaults to False.
      shuffle: bool, whether to shuffle style files at random. Defaults to True.

    Returns:
      If batch_size is defined, a 4-D tensor of shape [batch_size, ?, ?, 3] with
      values in [0, 1] for the style image, and 1-D tensor for the style label.

    Raises:
      ValueError: if center cropping is requested but no image size is provided,
          or if batch size is specified but center-cropping is not requested.
    """
    vgg_layers = [
        'vgg_16/conv1', 'vgg_16/pool1', 'vgg_16/conv2', 'vgg_16/pool2',
        'vgg_16/conv3', 'vgg_16/pool3', 'vgg_16/conv4', 'vgg_16/pool4',
        'vgg_16/conv5', 'vgg_16/pool5'
    ]

    if square_crop and image_size is None:
        raise ValueError('center-cropping requires specifying the image size.')
    if batch_size is not None and not square_crop:
        raise ValueError('batching requires center-cropping.')

    with tf.name_scope('style_image_processing'):
        filename_queue = tf.train.string_input_producer([style_dataset_file],
                                                        shuffle=False,
                                                        capacity=1,
                                                        name='filename_queue')
        if shuffle:
            examples_queue = tf.RandomShuffleQueue(
                capacity=64,
                min_after_dequeue=32,
                dtypes=[tf.string],
                name='random_examples_queue')
        else:
            examples_queue = tf.FIFOQueue(capacity=64,
                                          dtypes=[tf.string],
                                          name='fifo_examples_queue')
        reader = tf.TFRecordReader()
        _, value = reader.read(filename_queue)
        enqueue_ops = [examples_queue.enqueue([value])]
        tf.train.queue_runner.add_queue_runner(
            tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops))
        example_serialized = examples_queue.dequeue()
        features = tf.parse_single_example(
            example_serialized,
            features={
                'label': tf.FixedLenFeature([], tf.int64),
                'image_raw': tf.FixedLenFeature([], tf.string),
                'vgg_16/conv1': tf.FixedLenFeature([64, 64], tf.float32),
                'vgg_16/pool1': tf.FixedLenFeature([64, 64], tf.float32),
                'vgg_16/conv2': tf.FixedLenFeature([128, 128], tf.float32),
                'vgg_16/pool2': tf.FixedLenFeature([128, 128], tf.float32),
                'vgg_16/conv3': tf.FixedLenFeature([256, 256], tf.float32),
                'vgg_16/pool3': tf.FixedLenFeature([256, 256], tf.float32),
                'vgg_16/conv4': tf.FixedLenFeature([512, 512], tf.float32),
                'vgg_16/pool4': tf.FixedLenFeature([512, 512], tf.float32),
                'vgg_16/conv5': tf.FixedLenFeature([512, 512], tf.float32),
                'vgg_16/pool5': tf.FixedLenFeature([512, 512], tf.float32)
            })
        image = tf.image.decode_jpeg(features['image_raw'])
        label = features['label']
        gram_matrices = [features[vgg_layer] for vgg_layer in vgg_layers]
        image.set_shape([None, None, 3])

        if image_size:
            if square_crop:
                image = _aspect_preserving_resize(image, image_size + 2)
                image = _central_crop([image], image_size, image_size)[0]
                image.set_shape([image_size, image_size, 3])
            else:
                image = _aspect_preserving_resize(image, image_size)

        image = tf.to_float(image) / 255.0

        if batch_size is None:
            image = tf.expand_dims(image, 0)
        else:
            image_label_gram_matrices = tf.train.batch([image, label] +
                                                       gram_matrices,
                                                       batch_size=batch_size)
            image, label = image_label_gram_matrices[:2]
            gram_matrices = image_label_gram_matrices[2:]

        gram_matrices = dict(
            (vgg_layer, gram_matrix)
            for vgg_layer, gram_matrix in zip(vgg_layers, gram_matrices))
        return image, label, gram_matrices
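
# A hypothetical usage sketch; '/path/to/style.tfrecord' is a placeholder for a
# dataset produced by the create_style_dataset.py script mentioned above:
style_image, style_label, style_grams = style_image_inputs(
    '/path/to/style.tfrecord', batch_size=4, image_size=256, square_crop=True)
# style_grams maps each VGG-16 layer name to a batch of precomputed Gram matrices.
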
Example #11
def arbitrary_style_image_inputs(style_dataset_file,
                                 batch_size=None,
                                 image_size=None,
                                 center_crop=True,
                                 shuffle=True,
                                 augment_style_images=False,
                                 random_style_image_size=False,
                                 min_rand_image_size=128,
                                 max_rand_image_size=300):
    """Loads a batch of random style image given the path of tfrecord dataset.

    This method does not return pre-computed Gram matrices for the images like
    style_image_inputs does, but it can provide data augmentation. If
    augment_style_images is True, then style images will be randomly
    modified (e.g. changes in brightness, hue or saturation) for data
    augmentation. If random_style_image_size is True, then all images
    in one batch will be resized to a random size.
    Args:
      style_dataset_file: str, path to the tfrecord dataset of style files.
      batch_size: int. If provided, batches style images. Defaults to None.
      image_size: int. The images will be resized bilinearly so that the smallest
          side has size image_size. Defaults to None.
      center_crop: bool. If True, center-crops to [image_size, image_size].
          Defaults to True.
      shuffle: bool, whether to shuffle style files at random. Defaults to True.
      augment_style_images: bool. Whether to augment style images or not.
      random_style_image_size: bool. If this value is True, then all the style
          images in one batch will be resized to a random size between
          min_rand_image_size and max_rand_image_size.
      min_rand_image_size: int. If random_style_image_size is True, this value
          specifies the minimum image size.
      max_rand_image_size: int. If random_style_image_size is True, this value
          specifies the maximum image size.

    Returns:
      4-D tensor of shape [1, ?, ?, 3] with values in [0, 1] for the style
      image (with random changes for data augmentation if
      augment_style_images is set to True), and a 0-D tensor for the style
      label, 4-D tensor of shape [1, ?, ?, 3] with values in [0, 1] for the style
      image without random changes for data augmentation.

    Raises:
      ValueError: if center cropping is requested but no image size is provided,
          or if batch size is specified but center-cropping or
          augment-style-images is not requested,
          or if both augment-style-images and center-cropping are requested.
    """
    if center_crop and image_size is None:
        raise ValueError('center-cropping requires specifying the image size.')
    if center_crop and augment_style_images:
        raise ValueError(
            'When augment_style_images is true images will be randomly cropped.'
        )
    if batch_size is not None and not center_crop and not augment_style_images:
        raise ValueError(
            'batching requires same image sizes (Set center-cropping or '
            'augment_style_images to true)')

    with tf.name_scope('style_image_processing'):
        # Force all input processing onto CPU in order to reserve the GPU for the
        # forward inference and back-propagation.
        with tf.device('/cpu:0'):
            filename_queue = tf.train.string_input_producer(
                [style_dataset_file],
                shuffle=False,
                capacity=1,
                name='filename_queue')
            if shuffle:
                examples_queue = tf.RandomShuffleQueue(
                    capacity=64,
                    min_after_dequeue=32,
                    dtypes=[tf.string],
                    name='random_examples_queue')
            else:
                examples_queue = tf.FIFOQueue(capacity=64,
                                              dtypes=[tf.string],
                                              name='fifo_examples_queue')
            reader = tf.TFRecordReader()
            _, value = reader.read(filename_queue)
            enqueue_ops = [examples_queue.enqueue([value])]
            tf.train.queue_runner.add_queue_runner(
                tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops))
            example_serialized = examples_queue.dequeue()
            features = tf.parse_single_example(
                example_serialized,
                features={
                    'label': tf.FixedLenFeature([], tf.int64),
                    'image_raw': tf.FixedLenFeature([], tf.string)
                })
            image = tf.image.decode_jpeg(features['image_raw'])
            image.set_shape([None, None, 3])
            label = features['label']

            if image_size is not None:
                image_channels = image.shape[2].value
                if augment_style_images:
                    image_orig = image
                    image = tf.image.random_brightness(image, max_delta=0.8)
                    image = tf.image.random_saturation(image,
                                                       lower=0.5,
                                                       upper=1.5)
                    image = tf.image.random_hue(image, max_delta=0.2)
                    image = tf.image.random_flip_left_right(image)
                    image = tf.image.random_flip_up_down(image)
                    random_larger_image_size = tf.random_uniform(
                        [],
                        minval=image_size + 2,
                        maxval=image_size + 200,
                        dtype=tf.int32)
                    image = _aspect_preserving_resize(
                        image, random_larger_image_size)
                    image = tf.random_crop(
                        image, size=[image_size, image_size, image_channels])
                    image.set_shape([image_size, image_size, image_channels])

                    image_orig = _aspect_preserving_resize(
                        image_orig, image_size + 2)
                    image_orig = _central_crop([image_orig], image_size,
                                               image_size)[0]
                    image_orig.set_shape([image_size, image_size, 3])
                elif center_crop:
                    image = _aspect_preserving_resize(image, image_size + 2)
                    image = _central_crop([image], image_size, image_size)[0]
                    image.set_shape([image_size, image_size, image_channels])
                    image_orig = image
                else:
                    image = _aspect_preserving_resize(image, image_size)
                    image_orig = image

            image = tf.to_float(image) / 255.0
            image_orig = tf.to_float(image_orig) / 255.0

            if batch_size is None:
                image = tf.expand_dims(image, 0)
            else:
                [image, image_orig,
                 label] = tf.train.batch([image, image_orig, label],
                                         batch_size=batch_size)

            if random_style_image_size:
                # Selects a random size for the style images and resizes all the images
                # in the batch to that size.
                image = _aspect_preserving_resize(
                    image,
                    tf.random_uniform([],
                                      minval=min_rand_image_size,
                                      maxval=max_rand_image_size,
                                      dtype=tf.int32))

            return image, label, image_orig
    def body(self,
             features,
             decode_step=None,
             cache=None,
             decoding_stats=None,
             add_summary=True):
        encoder_output = None
        extra_losses = []
        padding_bias = None
        if not self.hparams.fast_decode:
            decode_step = None
        if "inputs" in features:
            inputs = features["inputs"]
            # remove the last two dimensions that are always 1.
            inputs = tf.reshape(
                inputs,
                utils.shape_list(inputs)[:2] + [self.hidden_size])
            # Padding bias only used for seq2seq models.
            padding_bias = utils.embedding_to_padding(inputs)
            # Mask random positions
            shape = utils.shape_list(inputs)
            if self.hparams.input_dropout:
                inputs = tf.where(
                    tf.random.uniform(shape) < self.hparams.input_dropout,
                    tf.zeros_like(inputs), inputs)
            if self.hparams.add_timing_signal:
                inputs += utils.get_timing_signal_1d(self.hparams.max_length,
                                                     self.hidden_size)
            if cache is not None and -1 in cache:
                encoder_output = cache[-1]
            else:
                encoder_output = utils.transformer_encoder_layers(
                    inputs=inputs,
                    num_layers=self.num_encoder_layers,
                    hparams=self.hparams,
                    losses=extra_losses,
                    name="encoder",
                    token_bias=features.get("token_bias_inputs"),
                    padding_bias=padding_bias)
            if cache is not None and -1 not in cache:
                cache[-1] = encoder_output
        targets = tf.to_int32(features["targets"])
        # remove the last two dimensions that are always 1.
        targets = tf.reshape(targets, utils.shape_list(targets)[:2])
        # Clamp targets to max_target_length
        targets = targets[:, :self.hparams.max_target_length]
        if self.is_decode:
            targets = self.process_partial_targets_decoding(targets)
        decoder_input = self.prepare_decoder(targets)

        decoder_output = utils.transformer_decoder_layers(
            inputs=decoder_input,
            num_layers=self.num_decoder_layers,
            hparams=self.hparams,
            encoder_output=encoder_output,
            decode_step=decode_step,
            losses=extra_losses,
            cache=cache,
            name="decoder",
            decoding_stats=decoding_stats,
            token_bias_inputs=features.get("token_bias_inputs"),
            token_bias_targets=features.get("token_bias_targets"),
            padding_bias=padding_bias)
        logits = self.produce_output(decoder_output)

        # Return logits as-is in decoding mode
        if self.is_decode:
            return logits

        # Add cross entropy loss
        one_hot_targets = tf.one_hot(tf.cast(targets, dtype=tf.int32),
                                     self.vocab_size)
        x_entropy = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=one_hot_targets, logits=logits)
        weights = tf.to_float(tf.not_equal(targets, 0))
        loss = tf.reduce_sum(x_entropy * weights) / tf.reduce_sum(weights)
        if add_summary:
            tf.summary.scalar("losses/weight", tf.reduce_sum(weights))
            tf.summary.scalar("losses/x_entropy",
                              tf.reduce_sum(x_entropy * weights))

        loss_dict = {"training": loss}
        if extra_losses:
            loss_dict["extra_loss"] = tf.add_n(extra_losses)
        # hack for T2T metrics
        logits = tf.reshape(
            logits,
            utils.shape_list(logits)[:2] + [1, 1] +
            utils.shape_list(logits)[-1:])
        return logits, loss_dict
Example #13
def trilerp_gather(vol, inds, bad_inds=None):
    """Trilinear interpolation dense gather from volume at query inds."""

    inds_b = inds[Ellipsis, 0]
    inds_x = inds[Ellipsis, 1]
    inds_y = inds[Ellipsis, 2]
    inds_z = inds[Ellipsis, 3]

    inds_x_0 = tf.floor(inds_x)
    inds_x_1 = inds_x_0 + 1
    inds_y_0 = tf.floor(inds_y)
    inds_y_1 = inds_y_0 + 1
    inds_z_0 = tf.floor(inds_z)
    inds_z_1 = inds_z_0 + 1

    # store invalid indices to implement correct out-of-bounds conditions
    invalid_x = tf.logical_or(
        tf.less(inds_x_0, 0.0),
        tf.greater(inds_x_1, tf.to_float(tf.shape(vol)[2] - 1)))
    invalid_y = tf.logical_or(
        tf.less(inds_y_0, 0.0),
        tf.greater(inds_y_1, tf.to_float(tf.shape(vol)[1] - 1)))
    invalid_z = tf.logical_or(
        tf.less(inds_z_0, 0.0),
        tf.greater(inds_z_1, tf.to_float(tf.shape(vol)[3] - 1)))
    if bad_inds is not None:
        invalid_inds = tf.logical_or(
            tf.logical_or(tf.logical_or(invalid_x, invalid_y), invalid_z),
            bad_inds)
    else:
        invalid_inds = tf.logical_or(tf.logical_or(invalid_x, invalid_y),
                                     invalid_z)

    inds_x_0 = tf.clip_by_value(inds_x_0, 0.0,
                                tf.to_float(tf.shape(vol)[2] - 2))
    inds_x_1 = tf.clip_by_value(inds_x_1, 0.0,
                                tf.to_float(tf.shape(vol)[2] - 1))
    inds_y_0 = tf.clip_by_value(inds_y_0, 0.0,
                                tf.to_float(tf.shape(vol)[1] - 2))
    inds_y_1 = tf.clip_by_value(inds_y_1, 0.0,
                                tf.to_float(tf.shape(vol)[1] - 1))
    inds_z_0 = tf.clip_by_value(inds_z_0, 0.0,
                                tf.to_float(tf.shape(vol)[3] - 2))
    inds_z_1 = tf.clip_by_value(inds_z_1, 0.0,
                                tf.to_float(tf.shape(vol)[3] - 1))

    # compute interp weights
    w_x_0 = 1.0 - (inds_x - inds_x_0)
    w_x_1 = 1.0 - w_x_0
    w_y_0 = 1.0 - (inds_y - inds_y_0)
    w_y_1 = 1.0 - w_y_0
    w_z_0 = 1.0 - (inds_z - inds_z_0)
    w_z_1 = 1.0 - w_z_0

    w_0_0_0 = w_y_0 * w_x_0 * w_z_0
    w_1_0_0 = w_y_1 * w_x_0 * w_z_0
    w_0_1_0 = w_y_0 * w_x_1 * w_z_0
    w_0_0_1 = w_y_0 * w_x_0 * w_z_1
    w_1_1_0 = w_y_1 * w_x_1 * w_z_0
    w_0_1_1 = w_y_0 * w_x_1 * w_z_1
    w_1_0_1 = w_y_1 * w_x_0 * w_z_1
    w_1_1_1 = w_y_1 * w_x_1 * w_z_1

    # gather for interp
    inds_0_0_0 = tf.to_int32(
        tf.stack([inds_b, inds_y_0, inds_x_0, inds_z_0], axis=-1))
    inds_1_0_0 = tf.to_int32(
        tf.stack([inds_b, inds_y_1, inds_x_0, inds_z_0], axis=-1))
    inds_0_1_0 = tf.to_int32(
        tf.stack([inds_b, inds_y_0, inds_x_1, inds_z_0], axis=-1))
    inds_0_0_1 = tf.to_int32(
        tf.stack([inds_b, inds_y_0, inds_x_0, inds_z_1], axis=-1))
    inds_1_1_0 = tf.to_int32(
        tf.stack([inds_b, inds_y_1, inds_x_1, inds_z_0], axis=-1))
    inds_0_1_1 = tf.to_int32(
        tf.stack([inds_b, inds_y_0, inds_x_1, inds_z_1], axis=-1))
    inds_1_0_1 = tf.to_int32(
        tf.stack([inds_b, inds_y_1, inds_x_0, inds_z_1], axis=-1))
    inds_1_1_1 = tf.to_int32(
        tf.stack([inds_b, inds_y_1, inds_x_1, inds_z_1], axis=-1))

    vol_0_0_0 = tf.gather_nd(vol, inds_0_0_0) * w_0_0_0[Ellipsis, tf.newaxis]
    vol_1_0_0 = tf.gather_nd(vol, inds_1_0_0) * w_1_0_0[Ellipsis, tf.newaxis]
    vol_0_1_0 = tf.gather_nd(vol, inds_0_1_0) * w_0_1_0[Ellipsis, tf.newaxis]
    vol_0_0_1 = tf.gather_nd(vol, inds_0_0_1) * w_0_0_1[Ellipsis, tf.newaxis]
    vol_1_1_0 = tf.gather_nd(vol, inds_1_1_0) * w_1_1_0[Ellipsis, tf.newaxis]
    vol_0_1_1 = tf.gather_nd(vol, inds_0_1_1) * w_0_1_1[Ellipsis, tf.newaxis]
    vol_1_0_1 = tf.gather_nd(vol, inds_1_0_1) * w_1_0_1[Ellipsis, tf.newaxis]
    vol_1_1_1 = tf.gather_nd(vol, inds_1_1_1) * w_1_1_1[Ellipsis, tf.newaxis]

    out_vol = vol_0_0_0 + vol_1_0_0 + vol_0_1_0 + vol_0_0_1 + \
              vol_1_1_0 + vol_0_1_1 + vol_1_0_1 + vol_1_1_1

    # boundary conditions for invalid indices
    invalid_inds = tf.tile(invalid_inds[:, :, :, :, tf.newaxis],
                           [1, 1, 1, 1, tf.shape(vol)[4]])
    out_vol = tf.where(invalid_inds, tf.zeros_like(out_vol), out_vol)

    return out_vol
Example #14
def bilerp_gather(img, inds):
    """Bilinear interpolation dense gather from image at query inds."""

    inds_b, _, _, = tf.meshgrid(tf.range(tf.shape(img)[0]),
                                tf.range(tf.shape(img)[1]),
                                tf.range(tf.shape(img)[2]),
                                indexing='ij')

    inds_b = tf.to_float(inds_b)
    inds_x = inds[Ellipsis, 0]
    inds_y = inds[Ellipsis, 1]

    inds_x_0 = tf.floor(inds_x)
    inds_x_1 = inds_x_0 + 1
    inds_y_0 = tf.floor(inds_y)
    inds_y_1 = inds_y_0 + 1

    # store invalid indices to implement correct out-of-bounds conditions
    invalid_x = tf.logical_or(
        tf.less(inds_x_0, 0.0),
        tf.greater(inds_x_1, tf.to_float(tf.shape(img)[2] - 1)))
    invalid_y = tf.logical_or(
        tf.less(inds_y_0, 0.0),
        tf.greater(inds_y_1, tf.to_float(tf.shape(img)[1] - 1)))
    invalid_inds = tf.logical_or(invalid_x, invalid_y)

    inds_x_0 = tf.clip_by_value(inds_x_0, 0.0,
                                tf.to_float(tf.shape(img)[2] - 2))
    inds_x_1 = tf.clip_by_value(inds_x_1, 0.0,
                                tf.to_float(tf.shape(img)[2] - 1))
    inds_y_0 = tf.clip_by_value(inds_y_0, 0.0,
                                tf.to_float(tf.shape(img)[1] - 2))
    inds_y_1 = tf.clip_by_value(inds_y_1, 0.0,
                                tf.to_float(tf.shape(img)[1] - 1))

    # compute interp weights
    w_x_0 = 1.0 - (inds_x - inds_x_0)
    w_x_1 = 1.0 - w_x_0
    w_y_0 = 1.0 - (inds_y - inds_y_0)
    w_y_1 = 1.0 - w_y_0

    w_0_0 = w_y_0 * w_x_0
    w_1_0 = w_y_1 * w_x_0
    w_0_1 = w_y_0 * w_x_1
    w_1_1 = w_y_1 * w_x_1

    # gather for interp
    inds_0_0 = tf.to_int32(tf.stack([inds_b, inds_y_0, inds_x_0], axis=-1))
    inds_1_0 = tf.to_int32(tf.stack([inds_b, inds_y_1, inds_x_0], axis=-1))
    inds_0_1 = tf.to_int32(tf.stack([inds_b, inds_y_0, inds_x_1], axis=-1))
    inds_1_1 = tf.to_int32(tf.stack([inds_b, inds_y_1, inds_x_1], axis=-1))

    img_0_0 = tf.gather_nd(img, inds_0_0) * w_0_0[Ellipsis, tf.newaxis]
    img_1_0 = tf.gather_nd(img, inds_1_0) * w_1_0[Ellipsis, tf.newaxis]
    img_0_1 = tf.gather_nd(img, inds_0_1) * w_0_1[Ellipsis, tf.newaxis]
    img_1_1 = tf.gather_nd(img, inds_1_1) * w_1_1[Ellipsis, tf.newaxis]

    out_img = img_0_0 + img_1_0 + img_0_1 + img_1_1

    # boundary conditions for invalid indices
    invalid_inds = tf.tile(invalid_inds[:, :, :, tf.newaxis],
                           [1, 1, 1, tf.shape(img)[3]])

    out_img = tf.where(invalid_inds, tf.zeros_like(out_img), out_img)

    return out_img
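
# A hypothetical usage sketch: resample an image at query points shifted by half
# a pixel (inds[..., 0] holds x coordinates, inds[..., 1] holds y coordinates):
example_img = tf.random_uniform([2, 8, 8, 3])
_, example_ys, example_xs = tf.meshgrid(tf.range(2), tf.range(8), tf.range(8),
                                        indexing='ij')
example_query = tf.stack([tf.to_float(example_xs) + 0.5,
                          tf.to_float(example_ys) + 0.5], axis=-1)
example_shifted = bilerp_gather(example_img, example_query)  # same shape as img
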
def encode_decode_task(features, hparams, train, attention_weights=None):
    """Model core graph for the one-shot action.

  Args:
    features: a dictionary that contains "inputs", a tensor of shape
        [batch_size, num_tokens]; "verb_id_seq", of shape
        [batch_size, num_actions]; and "object_spans" and "param_span" tensors
        of shape [batch_size, num_actions, 2]. 0 is used for padding or
        non-existent values.
    hparams: the general hyperparameters for the model.
    train: the train mode.
    attention_weights: the dict to keep attention weights for analysis.
  Returns:
    decoder_output: the decoder output used to predict action tuples.
    decoder_nonpadding: a float mask marking non-padding decoding steps.
    areas: the area encodings of the task.
    scope: the embedding scope.
  """
    del train
    input_embeddings, scope = common_embed.embed_tokens(
        features["task"], hparams.task_vocab_size, hparams.hidden_size,
        hparams)
    with tf.variable_scope("encode_decode", reuse=tf.AUTO_REUSE):
        encoder_nonpadding = tf.minimum(tf.to_float(features["task"]), 1.0)
        input_embeddings = tf.multiply(tf.expand_dims(encoder_nonpadding, 2),
                                       input_embeddings)
        encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
            transformer.transformer_prepare_encoder(input_embeddings,
                                                    None,
                                                    hparams,
                                                    features=None))
        encoder_input = tf.nn.dropout(encoder_input,
                                      keep_prob=1.0 -
                                      hparams.layer_prepostprocess_dropout)
        if hparams.instruction_encoder == "transformer":
            encoder_output = transformer.transformer_encoder(
                encoder_input,
                self_attention_bias,
                hparams,
                save_weights_to=attention_weights,
                make_image_summary=not common_layers.is_xla_compiled())
        else:
            raise ValueError("Unsupported instruction encoder %s" %
                             (hparams.instruction_encoder))
        span_rep = hparams.get("span_rep", "area")
        area_encodings, area_starts, area_ends = area_utils.compute_sum_image(
            encoder_output, max_area_width=hparams.max_span)
        current_shape = tf.shape(area_encodings)
        if span_rep == "area":
            area_encodings, _, _ = area_utils.compute_sum_image(
                encoder_output, max_area_width=hparams.max_span)
        elif span_rep == "basic":
            area_encodings = area_utils.compute_alternative_span_rep(
                encoder_output,
                input_embeddings,
                max_area_width=hparams.max_span,
                hidden_size=hparams.hidden_size,
                advanced=False)
        elif span_rep == "coref":
            area_encodings = area_utils.compute_alternative_span_rep(
                encoder_output,
                input_embeddings,
                max_area_width=hparams.max_span,
                hidden_size=hparams.hidden_size,
                advanced=True)
        else:
            raise ValueError("xyz")
        areas = {}
        areas["encodings"] = area_encodings
        areas["starts"] = area_starts
        areas["ends"] = area_ends
        with tf.control_dependencies([
                tf.print("encoder_output", tf.shape(encoder_output)),
                tf.assert_equal(current_shape,
                                tf.shape(area_encodings),
                                summarize=100)
        ]):
            paddings = tf.cast(tf.less(self_attention_bias, -1), tf.int32)
        padding_sum, _, _ = area_utils.compute_sum_image(
            tf.expand_dims(tf.squeeze(paddings, [1, 2]), 2),
            max_area_width=hparams.max_span)
        num_areas = common_layers.shape_list(area_encodings)[1]
        area_paddings = tf.reshape(tf.minimum(tf.to_float(padding_sum), 1.0),
                                   [-1, num_areas])
        areas["bias"] = area_paddings
        decoder_nonpadding = tf.to_float(
            tf.greater(features["verb_refs"][:, :, 1],
                       features["verb_refs"][:, :, 0]))
        if hparams.instruction_encoder == "lstm":
            hparams_decoder = copy.copy(hparams)
            hparams_decoder.set_hparam("pos", "none")
        else:
            hparams_decoder = hparams
        decoder_input, decoder_self_attention_bias = _prepare_decoder_input(
            area_encodings,
            decoder_nonpadding,
            features,
            hparams_decoder,
            embed_scope=scope)
        decoder_input = tf.nn.dropout(decoder_input,
                                      keep_prob=1.0 -
                                      hparams.layer_prepostprocess_dropout)
        if hparams.instruction_decoder == "transformer":
            decoder_output = transformer.transformer_decoder(
                decoder_input=decoder_input,
                encoder_output=encoder_output,
                decoder_self_attention_bias=decoder_self_attention_bias,
                encoder_decoder_attention_bias=encoder_decoder_attention_bias,
                hparams=hparams_decoder)
        else:
            raise ValueError("Unsupported instruction encoder %s" %
                             (hparams.instruction_encoder))
        return decoder_output, decoder_nonpadding, areas, scope
    def __init__(self, dim_h, tag_size, pos_size, chunk_size, vocab_size,
                 embeddings, args):

        dim_emb = 100
        beta1, beta2 = 0.9, 0.999
        dim_d = 2 * dim_h  # value of d

        self.dropout = tf.placeholder(tf.float32, name='dropout')
        self.learning_rate = tf.placeholder(tf.float32, name='learning_rate')
        self.batch_len = tf.placeholder(tf.int32, name='batch_len')
        self.batch_size = tf.placeholder(tf.int32, name='batch_size')
        self.enc_inputs = tf.placeholder(tf.int32, [None, None],
                                         name='enc_inputs')  # size * len
        self.enc_inputs_reverse = tf.placeholder(tf.int32, [None, None],
                                                 name='enc_inputs_reverse')
        self.next_enc_inputs = tf.placeholder(
            tf.int32, [None, None], name='next_enc_inputs')  # size * len
        self.next_enc_inputs_reverse = tf.placeholder(
            tf.int32, [None, None], name='next_enc_inputs_reverse')
        self.inputs_pos = tf.placeholder(tf.int32, [None, None],
                                         name='inputs_pos')
        self.inputs_pos_reverse = tf.placeholder(tf.int32, [None, None],
                                                 name='inputs_pos_reverse')
        self.inputs_chunk = tf.placeholder(tf.int32, [None, None],
                                           name='inputs_chunk')
        self.inputs_chunk_reverse = tf.placeholder(tf.int32, [None, None],
                                                   name='inputs_chunk_reverse')
        self.inputs_case = tf.placeholder(tf.int32, [None, None],
                                          name='inputs_case')
        self.inputs_case_reverse = tf.placeholder(tf.int32, [None, None],
                                                  name='inputs_case_reverse')
        self.inputs_num = tf.placeholder(tf.int32, [None, None],
                                         name='inputs_num')
        self.inputs_num_reverse = tf.placeholder(tf.int32, [None, None],
                                                 name='inputs_num_reverse')

        self.enc_inputs_char = tf.placeholder(
            tf.int32, [None, None, None], name='enc_inputs_char')  # size * len

        self.weights = tf.placeholder(tf.float32, [None, None], name='weights')
        self.tlm_weights = tf.placeholder(tf.float32, [None, None],
                                          name='tlm_weights')
        self.targets = tf.placeholder(tf.int32, [None, None], name='targets')
        self.targets_reverse = tf.placeholder(tf.int32, [None, None],
                                              name='targets_reverse')
        self.tlm_targets = tf.placeholder(tf.int32, [None, None],
                                          name='tlm_targets')
        self.tlm_targets_reverse = tf.placeholder(tf.int32, [None, None],
                                                  name='tlm_targets_reverse')
        self.tlm_targets_pos = tf.placeholder(tf.int32, [None, None],
                                              name='tlm_targets_pos')
        self.tlm_targets_pos_reverse = tf.placeholder(
            tf.int32, [None, None], name='tlm_targets_pos_reverse')

        self.perturb = tf.placeholder(tf.float32, [None, None], name='perturb')

        embedding_global = embeddings[0]
        embedding_char_global = embeddings[1]
        batch_size = args.batch_size

        embedding_model = tf.get_variable('embedding',
                                          initializer=embedding_global.astype(
                                              np.float32))

        embedding_model_char = tf.get_variable(
            'embedding_char',
            initializer=embedding_char_global.astype(np.float32))

        def delta(v):
            return tf.norm(v, ord=1)

        inputs = tf.nn.embedding_lookup(embedding_model, self.enc_inputs)
        inputs = tf.cast(inputs, tf.float32)

        next_inputs = tf.nn.embedding_lookup(embedding_model,
                                             self.next_enc_inputs)
        next_inputs = tf.cast(next_inputs, tf.float32)
        # but use self.next_enc_inputs as targets in LM

        inputs_reverse = tf.nn.embedding_lookup(embedding_model,
                                                self.enc_inputs_reverse)
        inputs_reverse = tf.cast(inputs_reverse, tf.float32)

        next_inputs_reverse = tf.nn.embedding_lookup(
            embedding_model, self.next_enc_inputs_reverse)
        next_inputs_reverse = tf.cast(next_inputs_reverse, tf.float32)

        inputs_char = tf.nn.embedding_lookup(embedding_model_char,
                                             self.enc_inputs_char)
        inputs_char = tf.cast(inputs_char, tf.float32)
        ''' Implementing the TLM
        - Tag embeddings are L-dimensional one-hot vectors (why not just random
          initialization?)
        - GRU (the paper uses an LSTM) language model over the tag sequences
        '''

        with tf.variable_scope('tlm_projection'):
            proj_tlm_W = tf.get_variable(
                'tlm_W', [dim_h, pos_size + tag_size],
                dtype=tf.float32)  # tag_size+vocab_size
            proj_tlm_b = tf.get_variable(
                'tlm_b', [pos_size + tag_size],
                dtype=tf.float32)  # tag_size+vocab_size
            proj_tlm_W_reverse = tf.get_variable(
                'tlm_W_reverse', [dim_h, pos_size + tag_size],
                dtype=tf.float32)  # tag_size+vocab_size
            proj_tlm_b_reverse = tf.get_variable(
                'tlm_b_reverse', [pos_size + tag_size],
                dtype=tf.float32)  # tag_size+vocab_size

        y_onehot_tlm = tf.one_hot(self.targets, depth=tag_size) + self.perturb
        y_onehot_tlm_reverse = tf.one_hot(self.targets_reverse,
                                          depth=tag_size) + self.perturb

        inputs_pos_onehot = tf.one_hot(self.inputs_pos, depth=pos_size)
        inputs_pos_onehot_reverse = tf.one_hot(self.inputs_pos_reverse,
                                               depth=pos_size)
        inputs_chunk_onehot = tf.one_hot(self.inputs_chunk, depth=chunk_size)
        inputs_chunk_onehot_reverse = tf.one_hot(self.inputs_chunk_reverse,
                                                 depth=chunk_size)
        inputs_case_onehot = tf.one_hot(self.inputs_case,
                                        depth=2)  # changed from 4 to 2
        inputs_case_onehot_reverse = tf.one_hot(self.inputs_case_reverse,
                                                depth=2)
        inputs_num_onehot = tf.one_hot(self.inputs_num, depth=2)
        inputs_num_onehot_reverse = tf.one_hot(self.inputs_num_reverse,
                                               depth=2)

        # self.output_0_shape = tf.shape(inputs)
        # self.output_1_shape = tf.shape(y_onehot_tlm)
        # self.output_2_shape = tf.shape(inputs_pos_onehot)
        # self.output_3_shape = tf.shape(inputs_chunk_onehot)

        #         with tf.variable_scope('tlm'):
        #             cell_gru = create_cell(dim_h, self.dropout) # lstm actually
        #             # initial_state_gru = cell_gru.zero_state(batch_size, dtype=tf.float32)
        #             outputs_tlm, _ = tf.nn.dynamic_rnn(cell_gru,
        #                                                tf.concat([inputs,y_onehot_tlm,inputs_pos_onehot], axis=-1), # [inputs,y_onehot_tlm]
        #                                                dtype=tf.float32, scope='tlm')
        #             outputs_tlm = tf.nn.dropout(outputs_tlm, self.dropout)
        #             outputs_tlm = tf.reshape(outputs_tlm, [-1, dim_h])

        #             self.logits_tlm_tmp = tf.matmul(outputs_tlm, proj_tlm_W) + proj_tlm_b
        #             self.logits_tlm = self.logits_tlm_tmp[:,pos_size:] # FIX!!!!!!!!!
        #             # self.logits_nextword = self.logits_tlm_tmp[:,:vocab_size]
        #             self.logits_pos = self.logits_tlm_tmp[:,:pos_size]

        #             self.probs_tlm = tf.nn.softmax(self.logits_tlm)
        #             # self.probs_nextword = tf.nn.softmax(self.logits_nextword)
        #             self.probs_pos = tf.nn.softmax(self.logits_pos)

        #             loss_pretrain_tlm = tf.nn.sparse_softmax_cross_entropy_with_logits(
        #                labels=tf.reshape(self.tlm_targets, [-1]),
        #                logits=self.logits_tlm)
        #             loss_pretrain_tlm *= tf.reshape(self.tlm_weights, [-1])
        # #             loss_pretrain_nextword = tf.nn.sparse_softmax_cross_entropy_with_logits(
        # #                labels=tf.reshape(self.next_enc_inputs, [-1]),
        # #                logits=self.logits_nextword)
        # #             loss_pretrain_nextword *= tf.reshape(self.weights, [-1])
        #             loss_pretrain_pos = tf.nn.sparse_softmax_cross_entropy_with_logits(
        #                labels=tf.reshape(self.tlm_targets_pos, [-1]),
        #                logits=self.logits_pos)
        #             loss_pretrain_pos *= tf.reshape(self.tlm_weights, [-1])

        #         with tf.variable_scope('tlm_reverse'):
        #             cell_gru_reverse = create_cell(dim_h, self.dropout)
        #             outputs_tlm_reverse, _ = tf.nn.dynamic_rnn(cell_gru_reverse,
        #                                                tf.concat([inputs_reverse,y_onehot_tlm_reverse,inputs_pos_onehot_reverse], axis=-1), # [inputs,y_onehot_tlm]
        #                                                dtype=tf.float32, scope='tlm_reverse')
        #             outputs_tlm_reverse = tf.nn.dropout(outputs_tlm_reverse, self.dropout)
        #             outputs_tlm_reverse = tf.reshape(outputs_tlm_reverse, [-1, dim_h])

        #             self.logits_tlm_tmp_reverse = tf.matmul(outputs_tlm_reverse, proj_tlm_W_reverse) + proj_tlm_b_reverse
        #             self.logits_tlm_reverse = self.logits_tlm_tmp_reverse[:,pos_size:] # FIX!!!!!!!!!
        #             # self.logits_nextword_reverse = self.logits_tlm_tmp_reverse[:,:vocab_size]
        #             self.logits_pos_reverse = self.logits_tlm_tmp_reverse[:,:pos_size]

        #             self.probs_tlm_reverse = tf.nn.softmax(self.logits_tlm_reverse)
        #             # self.probs_nextword_reverse = tf.nn.softmax(self.logits_nextword_reverse)
        #             self.probs_pos_reverse = tf.nn.softmax(self.logits_pos_reverse)

        #             loss_pretrain_tlm_reverse = tf.nn.sparse_softmax_cross_entropy_with_logits(
        #                labels=tf.reshape(self.tlm_targets_reverse, [-1]),
        #                logits=self.logits_tlm_reverse)
        #             loss_pretrain_tlm_reverse *= tf.reshape(self.tlm_weights, [-1])
        # #             loss_pretrain_nextword_reverse = tf.nn.sparse_softmax_cross_entropy_with_logits(
        # #                labels=tf.reshape(self.next_enc_inputs_reverse, [-1]),
        # #                logits=self.logits_nextword_reverse)
        # #             loss_pretrain_nextword_reverse *= tf.reshape(self.weights, [-1])
        #             loss_pretrain_pos_reverse = tf.nn.sparse_softmax_cross_entropy_with_logits(
        #                labels=tf.reshape(self.tlm_targets_pos_reverse, [-1]),
        #                logits=self.logits_pos_reverse)
        #             loss_pretrain_pos_reverse *= tf.reshape(self.tlm_weights, [-1])

        #         #     #self.tlm_tot_loss_0 = tf.reduce_sum(loss_pretrain_nextword)
        #         #     self.tlm_tot_loss_1 = tf.reduce_sum(loss_pretrain_tlm)
        #         #     self.tlm_tot_loss_2 = tf.reduce_sum(loss_pretrain_pos)
        #         #     self.tlm_tot_loss = self.tlm_tot_loss_1 + self.tlm_tot_loss_2 #+ self.tlm_tot_loss_2 #
        #         #     self.tlm_sent_loss_1 = self.tlm_tot_loss_1 / tf.to_float(self.batch_size)
        #         #     self.tlm_sent_loss_2 = self.tlm_tot_loss_2 / tf.to_float(self.batch_size)
        #         #     self.tlm_sent_loss = self.tlm_tot_loss / tf.to_float(self.batch_size)

        #         #     #self.tlm_tot_loss_0_reverse = tf.reduce_sum(loss_pretrain_nextword_reverse)
        #         #     self.tlm_tot_loss_1_reverse = tf.reduce_sum(loss_pretrain_tlm_reverse)
        #         #     self.tlm_tot_loss_2_reverse = tf.reduce_sum(loss_pretrain_pos_reverse)
        #         #     self.tlm_tot_loss_reverse = self.tlm_tot_loss_1_reverse + self.tlm_tot_loss_2_reverse #+ self.tlm_tot_loss_2_reverse
        #         #     self.tlm_sent_loss_1_reverse = self.tlm_tot_loss_1_reverse / tf.to_float(self.batch_size)
        #         #     self.tlm_sent_loss_2_reverse = self.tlm_tot_loss_2_reverse / tf.to_float(self.batch_size)
        #         #     self.tlm_sent_loss_reverse = self.tlm_tot_loss_reverse / tf.to_float(self.batch_size)

        #         # self.tlm_train_loss_1 = self.tlm_sent_loss_1+self.tlm_sent_loss_1_reverse
        #         # self.tlm_train_loss_2 = self.tlm_sent_loss_2+self.tlm_sent_loss_2_reverse

        #         # tlm_param = retrive_var(['tlm_projection','tlm','tlm_reverse'])
        #         # self.optimizer_tlm_1 = tf.train.AdamOptimizer(self.learning_rate,
        #         #     beta1, beta2).minimize(self.tlm_train_loss_1, var_list=tlm_param)
        #         # self.optimizer_tlm_2 = tf.train.AdamOptimizer(self.learning_rate,
        #         #     beta1, beta2).minimize(self.tlm_train_loss_2, var_list=tlm_param)
        ''' Implementing A_phi
        - An RNN that returns a vector at each position of x
        - We can interpret this vector as prob distn over output labels at that position
        - We first try an architecture of BiLSTM for A_phi
        '''

        with tf.variable_scope('phi_projection'):
            proj_W = tf.get_variable('W', [2 * dim_h, tag_size],
                                     dtype=tf.float32)  # 2 because of BiLSTM
            proj_b = tf.get_variable('b', [tag_size], dtype=tf.float32)

        with tf.variable_scope('phi'):
            cell_fw = create_cell(dim_h, self.dropout)
            cell_bw = create_cell(dim_h, self.dropout)
            initial_state_fw = cell_fw.zero_state(batch_size, dtype=tf.float32)
            initial_state_bw = cell_bw.zero_state(batch_size, dtype=tf.float32)

            logits_cnn = cnn(inputs_char[0], 'phi')  # batch size 1
            logits_cnn = tf.cast(logits_cnn, tf.float32)
            logits_cnn = tf.expand_dims(logits_cnn, 0)
            self.shape0 = tf.shape(logits_cnn)  # [20,64]

            outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw,
                cell_bw,
                tf.concat([
                    inputs, logits_cnn, inputs_pos_onehot, inputs_chunk_onehot,
                    inputs_case_onehot, inputs_num_onehot
                ],
                          axis=-1),  #inputs_pos_onehot
                initial_state_fw=initial_state_fw,
                initial_state_bw=initial_state_bw,
                dtype=tf.float32,
                scope='phi')

            outputs = tf.concat(outputs, axis=-1)
            outputs = tf.nn.dropout(outputs, self.dropout)
            outputs = tf.reshape(outputs, [-1, 2 * dim_h])
            outputs = tf.cast(outputs, tf.float32)

            self.shape1 = tf.shape(outputs)  # [20,256]

        # affine transformation to get logits
        self.phi_logits = tf.matmul(
            tf.concat([outputs], axis=-1),
            proj_W) + proj_b  # shape is (batch_size*batch_length, 28)
        self.phi_probs = tf.nn.softmax(
            self.phi_logits)  # changed from sigmoid to softmax
        # Some positions are padding, so their logits should not count; the
        # `weights` mask applied below handles this.

        phi_probs_for_input = tf.reshape(
            self.phi_probs, [self.batch_size, self.batch_len, tag_size])
        phi_probs_for_input_reverse = tf.reshape(
            self.phi_probs[::-1, :],
            [self.batch_size, self.batch_len, tag_size])
        ''' Implementing energy function '''

        with tf.variable_scope('energy_function'):
            energy_U = tf.get_variable('energy_U', [tag_size, dim_d + 50],
                                       dtype=tf.float32)
            energy_W = tf.get_variable('energy_W', [tag_size, tag_size],
                                       dtype=tf.float32)

        # with tf.variable_scope('energy_feature_proj'):
        #     energy_proj_W = tf.get_variable('energy_proj_W', [2*dim_h, dim_d], dtype=tf.float32) # 2 because of BiLSTM
        #     energy_proj_b = tf.get_variable('energy_proj_b', [dim_d], dtype=tf.float32)

        with tf.variable_scope('energy_feature'):
            cell_fw = create_cell(dim_h, self.dropout)
            cell_bw = create_cell(dim_h, self.dropout)
            initial_state_fw = cell_fw.zero_state(batch_size, dtype=tf.float32)
            initial_state_bw = cell_bw.zero_state(batch_size, dtype=tf.float32)
            outputs, _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw,
                cell_bw,
                inputs,
                initial_state_fw=initial_state_fw,
                initial_state_bw=initial_state_bw,
                dtype=tf.float32,
                scope='energy_feature')

            outputs = tf.concat(outputs, axis=-1)
            outputs = tf.nn.dropout(outputs, self.dropout)
            outputs = tf.reshape(outputs, [-1, 2 * dim_h])
            outputs = tf.cast(outputs, tf.float32)

        with tf.variable_scope('energy_feature_pos'):
            cell_fw_pos = create_cell(25, self.dropout)
            cell_bw_pos = create_cell(25, self.dropout)
            initial_state_fw_pos = cell_fw_pos.zero_state(batch_size,
                                                          dtype=tf.float32)
            initial_state_bw_pos = cell_bw_pos.zero_state(batch_size,
                                                          dtype=tf.float32)
            outputs_pos, _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw_pos,
                cell_bw_pos,
                inputs_pos_onehot,
                initial_state_fw=initial_state_fw_pos,
                initial_state_bw=initial_state_bw_pos,
                dtype=tf.float32,
                scope='energy_feature_pos')

            outputs_pos = tf.concat(outputs_pos, axis=-1)
            outputs_pos = tf.nn.dropout(outputs_pos, self.dropout)
            outputs_pos = tf.reshape(outputs_pos, [-1, 2 * 25])
            outputs_pos = tf.cast(outputs_pos, tf.float32)

        # shape is (batch_size(2)*batch_length, 100)
        energy_feature_vec = tf.concat(
            [outputs, outputs_pos],
            axis=-1)  #tf.matmul(outputs, energy_proj_W) + energy_proj_b

        # concat with pos feature vec
        # fix energy_U etc dimension

        def energy_result(self, x, y, y_unscale_logits, x_nextword_onehot,
                          x_nextword_onehot_reverse, nextpos_onehot,
                          nextpos_onehot_reverse):
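            # Added note (not in the original): the energy combines a unary
            # term, sum_t y_t . (U f_t), pairing each position's tag vector
            # with its feature vector, and a pairwise term,
            # sum_t (y_{t-1} W) . y_t, scoring adjacent tag pairs; the function
            # returns the negative of their sum.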

            M0 = tf.matmul(energy_U, tf.transpose(energy_feature_vec))
            tmp0 = tf.multiply(y, tf.transpose(M0))  # elt-wise
            energy_first_part = tf.reduce_sum(tmp0)

            #y_prime = tf.manip.roll(y, shift=1, axis=0)
            #y_prime = tf.concat([[tf.zeros([tag_size])], y_prime[1:]], axis=0) # check y has 28 as last dim

            y_prime = y[:-1]
            tmp1 = tf.multiply(tf.matmul(y_prime, energy_W),
                               y[1:])  # first y is tricky
            energy_second_part = tf.reduce_sum(tmp1)
            old_return = -(energy_first_part + energy_second_part)

            return old_return

        # should be the same function as above, but written here for convenience
        def energy_result_gold(self, x, y, y_unscale_logits, x_nextword_onehot,
                               x_nextword_onehot_reverse, nextpos_onehot,
                               nextpos_onehot_reverse):

            # note that energy_feature_vec will be looped around twice with batch_size 2
            M0 = tf.matmul(energy_U, tf.transpose(energy_feature_vec))
            tmp0 = tf.multiply(y, tf.transpose(M0))  # elt-wise
            energy_first_part = tf.reduce_sum(tmp0)

            #y_prime = tf.manip.roll(y, shift=1, axis=0)
            #y_prime = tf.concat([[tf.zeros([tag_size])], y_prime[1:]], axis=0) # check y has 28 as last dim

            y_prime = y[:-1]
            tmp1 = tf.multiply(tf.matmul(y_prime, energy_W),
                               y[1:])  # first y is tricky
            energy_second_part = tf.reduce_sum(tmp1)
            old_return = -(energy_first_part + energy_second_part)

            return old_return

        ''' Implementing phi and theta '''

        y_onehot = tf.one_hot(self.targets, depth=tag_size)
        y_onehot = tf.reshape(y_onehot, [-1, tag_size])
        tmp_delta_0 = tf.reduce_sum(self.phi_probs - y_onehot, axis=-1)
        tmp_delta_0 *= tf.reshape(self.weights, [-1])

        x_nextword_onehot = tf.one_hot(self.next_enc_inputs, depth=vocab_size)
        x_nextword_onehot = tf.reshape(x_nextword_onehot, [-1, vocab_size])

        x_nextword_onehot_reverse = tf.one_hot(self.next_enc_inputs_reverse,
                                               depth=vocab_size)
        x_nextword_onehot_reverse = tf.reshape(x_nextword_onehot_reverse,
                                               [-1, vocab_size])

        nextpos_onehot = tf.one_hot(self.tlm_targets_pos, depth=pos_size)
        nextpos_onehot = tf.reshape(nextpos_onehot, [-1, pos_size])

        nextpos_onehot_reverse = tf.one_hot(self.tlm_targets_pos_reverse,
                                            depth=pos_size)
        nextpos_onehot_reverse = tf.reshape(nextpos_onehot_reverse,
                                            [-1, pos_size])

        extra_reg_term = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.reshape(self.targets, [-1]), logits=self.phi_logits)
        extra_reg_term *= tf.reshape(self.weights, [-1])
        extra_reg_term = tf.reduce_sum(extra_reg_term) / tf.to_float(
            self.batch_size)

        # self.loss_phi *= tf.reshape(self.weights, [-1])
        # something like this
        loss_phi = delta(tmp_delta_0) - energy_result(
            self, inputs, self.phi_probs, self.phi_logits, x_nextword_onehot,
            x_nextword_onehot_reverse, nextpos_onehot, nextpos_onehot_reverse
        )  #+ energy_result_gold(self, inputs, y_onehot, y_onehot, x_nextword_onehot, x_nextword_onehot_reverse, nextpos_onehot, nextpos_onehot_reverse)
        loss_phi = -loss_phi
        self.loss_phi = extra_reg_term  #loss_phi + 0.5 * extra_reg_term #tf.maximum(loss_phi, 0.0) + 0.5 * extra_reg_term

        lambda_new = 1.0
        new_theta_term = lambda_new * (- energy_result(self, inputs, self.phi_probs, self.phi_logits, x_nextword_onehot, x_nextword_onehot_reverse, nextpos_onehot, nextpos_onehot_reverse) \
                                       + energy_result_gold(self, inputs, y_onehot, y_onehot, x_nextword_onehot, x_nextword_onehot_reverse, nextpos_onehot, nextpos_onehot_reverse))
        new_theta_term = tf.maximum(new_theta_term, -1.0)

        loss_theta = delta(tmp_delta_0) - energy_result(self, inputs, self.phi_probs, self.phi_logits, x_nextword_onehot, x_nextword_onehot_reverse, nextpos_onehot, nextpos_onehot_reverse) \
            + energy_result_gold(self, inputs, y_onehot, y_onehot, x_nextword_onehot, x_nextword_onehot_reverse, nextpos_onehot, nextpos_onehot_reverse)
        # + 0.0001 * retrive_var_regularize(['energy_function','energy_feature_proj','energy_feature']) # regularization
        loss_theta = tf.maximum(loss_theta, -1.0)
        self.loss_theta = loss_theta + new_theta_term
        ''' Optimization '''

        phi = retrive_var(['phi_projection', 'phi', 'embedding_char'])
        theta = retrive_var(
            ['energy_function', 'energy_feature',
             'energy_feature_pos'])  #,'tlm_projection','tlm','tlm_reverse'])
        self.optimizer_phi = tf.train.AdamOptimizer(
            self.learning_rate, beta1, beta2).minimize(self.loss_phi,
                                                       var_list=phi)
        self.optimizer_theta = tf.train.AdamOptimizer(
            self.learning_rate, beta1, beta2).minimize(self.loss_theta,
                                                       var_list=theta)

        psi = retrive_var(['phi_projection', 'phi'])
        self.loss_psi = energy_result(self, inputs, self.phi_probs,
                                      self.phi_logits, x_nextword_onehot,
                                      x_nextword_onehot_reverse,
                                      nextpos_onehot, nextpos_onehot_reverse)
        self.optimizer_psi = tf.train.AdamOptimizer(
            self.learning_rate, beta1, beta2).minimize(self.loss_psi,
                                                       var_list=psi)

        self.saver = tf.train.Saver()
Exemple #17
0
def main(argv):
    tf.logging.set_verbosity(tf.logging.INFO)
    trainer_lib.set_random_seed(FLAGS.random_seed)
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
    t2t_trainer.maybe_log_registry_and_exit()

    if FLAGS.cloud_mlengine:
        cloud_mlengine.launch()
        return

    if FLAGS.generate_data:
        t2t_trainer.generate_data()

    if cloud_mlengine.job_dir():
        FLAGS.output_dir = cloud_mlengine.job_dir()

    if argv:
        t2t_trainer.set_hparams_from_args(argv[1:])

    if FLAGS.surrogate_attack:
        tf.logging.warn("Performing surrogate model attack.")
        sur_hparams = create_surrogate_hparams()
        trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

    hparams = t2t_trainer.create_hparams()
    trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

    attack_params = create_attack_params()
    attack_params.add_hparam(attack_params.epsilon_name, 0.0)

    if FLAGS.surrogate_attack:
        sur_config = create_surrogate_run_config(sur_hparams)
    config = t2t_trainer.create_run_config(hparams)
    params = {
        "batch_size": hparams.batch_size,
        "use_tpu": FLAGS.use_tpu,
    }

    # add "_rev" as a hack to avoid image standardization
    problem = registry.problem(FLAGS.problem + "_rev")

    inputs, labels, features = prepare_data(problem, hparams, params, config)

    sess = tf.Session()

    if FLAGS.surrogate_attack:
        sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
            FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
        sur_ch_model = adv_attack_utils.T2TAttackModel(sur_model_fn,
                                                       features,
                                                       params,
                                                       sur_config,
                                                       scope="surrogate")
        # Dummy call to construct graph
        sur_ch_model.get_probs(inputs)

        checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
        tf.train.init_from_checkpoint(
            tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
        sess.run(tf.global_variables_initializer())

    other_vars = set(tf.global_variables())

    model_fn = t2t_model.T2TModel.make_estimator_model_fn(FLAGS.model, hparams)
    ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params,
                                               config)

    acc_mask = None
    probs = ch_model.get_probs(inputs)
    if FLAGS.ignore_incorrect:
        preds = tf.argmax(probs, -1, output_type=labels.dtype)
        preds = tf.reshape(preds, labels.shape)
        acc_mask = tf.to_float(tf.equal(labels, preds))
    one_hot_labels = tf.one_hot(labels, probs.shape[-1])

    if FLAGS.surrogate_attack:
        attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
    else:
        attack = create_attack(attack_params.attack)(ch_model, sess=sess)

    new_vars = set(tf.global_variables()) - other_vars

    # Restore weights
    saver = tf.train.Saver(new_vars)
    checkpoint_path = os.path.expanduser(FLAGS.output_dir)
    saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

    # reuse variables
    tf.get_variable_scope().reuse_variables()

    def compute_accuracy(x, l, mask):
        """Compute model accuracy."""
        preds = ch_model.get_probs(x)
        preds = tf.squeeze(preds)
        preds = tf.argmax(preds, -1, output_type=l.dtype)

        _, acc_update_op = tf.metrics.accuracy(l, preds, weights=mask)

        if FLAGS.surrogate_attack:
            preds = sur_ch_model.get_probs(x)
            preds = tf.squeeze(preds)
            preds = tf.argmax(preds, -1, output_type=l.dtype)
            acc_update_op = tf.tuple(
                (acc_update_op, tf.metrics.accuracy(l, preds,
                                                    weights=mask)[1]))

        sess.run(tf.initialize_local_variables())
        for i in range(FLAGS.eval_steps):
            tf.logging.info("\tEvaluating batch [%d / %d]" %
                            (i + 1, FLAGS.eval_steps))
            acc = sess.run(acc_update_op)
        if FLAGS.surrogate_attack:
            tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
        else:
            tf.logging.info("\tFinal acc: %.4f" % acc)
        return acc

    epsilon_acc_pairs = []
    for epsilon in attack_params.attack_epsilons:
        tf.logging.info("Attacking @ eps=%.4f" % epsilon)
        attack_params.set_hparam(attack_params.epsilon_name, epsilon)
        adv_x = attack.generate(inputs,
                                y=one_hot_labels,
                                **attack_params.values())
        acc = compute_accuracy(adv_x, labels, acc_mask)
        epsilon_acc_pairs.append((epsilon, acc))

    for epsilon, acc in epsilon_acc_pairs:
        if FLAGS.surrogate_attack:
            tf.logging.info("Accuracy @ eps=%.4f: (%.4f, %.4f)" %
                            (epsilon, acc[0], acc[1]))
        else:
            tf.logging.info("Accuracy @ eps=%.4f: %.4f" % (epsilon, acc))
Exemple #18
0
def _preprocess_zero_mean_unit_range(inputs, dtype=tf.float32):
    """Map image values from [0, 255] to [-1, 1]."""
    preprocessed_inputs = (2.0 / 255.0) * tf.to_float(inputs) - 1.0
    return tf.cast(preprocessed_inputs, dtype=dtype)
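
A minimal usage sketch (not part of the original snippet; it assumes a TF1 session and the function defined above) to check the [0, 255] -> [-1, 1] mapping on a few pixel values:

import tensorflow as tf

sample = tf.constant([[0, 128, 255]], dtype=tf.uint8)
scaled = _preprocess_zero_mean_unit_range(sample)
with tf.Session() as sess:
    print(sess.run(scaled))  # approximately [[-1.0, 0.0039, 1.0]]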
Exemple #19
0
_h0 = tf.nn.sigmoid(tf.matmul(v0, W) + hb)  # Hidden layer activation probabilities
h0 = tf.nn.relu(tf.sign(_h0 - tf.random_uniform(tf.shape(_h0))))  # Sample hidden states (Gibbs sampling)
# Phase 2: Reconstruction
_v1 = tf.nn.sigmoid(tf.matmul(h0, tf.transpose(W)) + vb)  # Visible layer reconstruction probabilities
v1 = tf.nn.relu(tf.sign(_v1 - tf.random_uniform(tf.shape(_v1))))  # Sample reconstructed visible states
h1 = tf.nn.sigmoid(tf.matmul(v1, W) + hb)  # Hidden probabilities from the reconstruction
# Set RBM training parameters

# Learning rate
alpha = 1.0
# Create the gradients
w_pos_grad = tf.matmul(tf.transpose(v0), h0)
w_neg_grad = tf.matmul(tf.transpose(v1), h1)

# Calculate the Contrastive Divergence to maximize
CD = (w_pos_grad - w_neg_grad) / tf.to_float(tf.shape(v0)[0])

# Create methods to update the weights and biases
update_w = W + alpha * CD
update_vb = vb + alpha * tf.reduce_mean(v0 - v1, 0)
update_hb = hb + alpha * tf.reduce_mean(h0 - h1, 0)
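# Added note (not part of the original snippet): the three lines above are the
# CD-1 update rule, delta_W = alpha * (v0^T h0 - v1^T h1) / batch_size, with the
# analogous mean-difference updates for the visible and hidden biases. In this
# style of RBM tutorial the updated values are typically evaluated in a session
# and fed back in as the current parameters on the next step.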
# Set the error function; here we use the mean squared reconstruction error
err = v0 - v1
err_sum = tf.reduce_mean(err * err)
#""" Initialize our Variables with Zeroes using Numpy Library """

# Current weight
cur_w = np.zeros([visibleUnits, hiddenUnits])
# Current visible unit biases
cur_vb = np.zeros([visibleUnits], np.float32 )
# Current hidden unit biases
Exemple #20
0
    def build():
        """Builds the Tensorflow graph."""
        inputs, labels, lengths = None, None, None

        if mode in ('train', 'eval'):
            if isinstance(no_event_label, numbers.Number):
                label_shape = []
            else:
                label_shape = [len(no_event_label)]
            inputs, labels, lengths = magenta.common.get_padded_batch(
                sequence_example_file_paths,
                hparams.batch_size,
                input_size,
                label_shape=label_shape,
                shuffle=mode == 'train')

        elif mode == 'generate':
            inputs = tf.placeholder(tf.float32,
                                    [hparams.batch_size, None, input_size])

        if isinstance(encoder_decoder,
                      magenta.music.OneHotIndexEventSequenceEncoderDecoder):
            expanded_inputs = tf.one_hot(
                tf.cast(tf.squeeze(inputs, axis=-1), tf.int64),
                encoder_decoder.input_depth)
        else:
            expanded_inputs = inputs

        dropout_keep_prob = 1.0 if mode == 'generate' else hparams.dropout_keep_prob

        if hparams.use_cudnn:
            outputs, initial_state, final_state = make_cudnn(
                expanded_inputs,
                hparams.rnn_layer_sizes,
                hparams.batch_size,
                mode,
                dropout_keep_prob=dropout_keep_prob,
                residual_connections=hparams.residual_connections)

        else:
            cell = make_rnn_cell(
                hparams.rnn_layer_sizes,
                dropout_keep_prob=dropout_keep_prob,
                attn_length=hparams.attn_length,
                residual_connections=hparams.residual_connections)

            initial_state = cell.zero_state(hparams.batch_size, tf.float32)

            outputs, final_state = tf.nn.dynamic_rnn(
                cell,
                inputs,
                sequence_length=lengths,
                initial_state=initial_state,
                swap_memory=True)

        outputs_flat = magenta.common.flatten_maybe_padded_sequences(
            outputs, lengths)
        if isinstance(num_classes, numbers.Number):
            num_logits = num_classes
        else:
            num_logits = sum(num_classes)
        logits_flat = contrib_layers.linear(outputs_flat, num_logits)

        if mode in ('train', 'eval'):
            labels_flat = magenta.common.flatten_maybe_padded_sequences(
                labels, lengths)

            if isinstance(num_classes, numbers.Number):
                softmax_cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=labels_flat, logits=logits_flat)
                predictions_flat = tf.argmax(logits_flat, axis=1)
            else:
                logits_offsets = np.cumsum([0] + num_classes)
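                # For example, num_classes = [3, 5] gives
                # logits_offsets = [0, 3, 8], so the per-label logit slices
                # below are logits_flat[:, 0:3] and logits_flat[:, 3:8].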
                softmax_cross_entropy = []
                predictions = []
                for i in range(len(num_classes)):
                    softmax_cross_entropy.append(
                        tf.nn.sparse_softmax_cross_entropy_with_logits(
                            labels=labels_flat[:, i],
                            logits=logits_flat[:, logits_offsets[i]:
                                               logits_offsets[i + 1]]))
                    predictions.append(
                        tf.argmax(
                            logits_flat[:,
                                        logits_offsets[i]:logits_offsets[i +
                                                                         1]],
                            axis=1))
                predictions_flat = tf.stack(predictions, 1)

            correct_predictions = tf.to_float(
                tf.equal(labels_flat, predictions_flat))
            event_positions = tf.to_float(
                tf.not_equal(labels_flat, no_event_label))
            no_event_positions = tf.to_float(
                tf.equal(labels_flat, no_event_label))

            # Compute the total number of time steps across all sequences in the
            # batch. For some models this will be different from the number of RNN
            # steps.
            def batch_labels_to_num_steps(batch_labels, lengths):
                num_steps = 0
                for labels, length in zip(batch_labels, lengths):
                    num_steps += encoder_decoder.labels_to_num_steps(
                        labels[:length])
                return np.float32(num_steps)

            num_steps = tf.py_func(batch_labels_to_num_steps,
                                   [labels, lengths], tf.float32)

            if mode == 'train':
                loss = tf.reduce_mean(softmax_cross_entropy)
                perplexity = tf.exp(loss)
                accuracy = tf.reduce_mean(correct_predictions)
                event_accuracy = (
                    tf.reduce_sum(correct_predictions * event_positions) /
                    tf.reduce_sum(event_positions))
                no_event_accuracy = (
                    tf.reduce_sum(correct_predictions * no_event_positions) /
                    tf.reduce_sum(no_event_positions))

                loss_per_step = tf.reduce_sum(
                    softmax_cross_entropy) / num_steps
                perplexity_per_step = tf.exp(loss_per_step)

                optimizer = tf.train.AdamOptimizer(
                    learning_rate=hparams.learning_rate)

                train_op = contrib_slim.learning.create_train_op(
                    loss, optimizer, clip_gradient_norm=hparams.clip_norm)
                tf.add_to_collection('train_op', train_op)

                vars_to_summarize = {
                    'loss': loss,
                    'metrics/perplexity': perplexity,
                    'metrics/accuracy': accuracy,
                    'metrics/event_accuracy': event_accuracy,
                    'metrics/no_event_accuracy': no_event_accuracy,
                    'metrics/loss_per_step': loss_per_step,
                    'metrics/perplexity_per_step': perplexity_per_step,
                }
            elif mode == 'eval':
                vars_to_summarize, update_ops = contrib_metrics.aggregate_metric_map(
                    {
                        'loss':
                        tf.metrics.mean(softmax_cross_entropy),
                        'metrics/accuracy':
                        tf.metrics.accuracy(labels_flat, predictions_flat),
                        'metrics/per_class_accuracy':
                        tf.metrics.mean_per_class_accuracy(
                            labels_flat, predictions_flat, num_classes),
                        'metrics/event_accuracy':
                        tf.metrics.recall(event_positions,
                                          correct_predictions),
                        'metrics/no_event_accuracy':
                        tf.metrics.recall(no_event_positions,
                                          correct_predictions),
                        'metrics/loss_per_step':
                        tf.metrics.mean(tf.reduce_sum(softmax_cross_entropy) /
                                        num_steps,
                                        weights=num_steps),
                    })
                for updates_op in update_ops.values():
                    tf.add_to_collection('eval_ops', updates_op)

                # Perplexity is just exp(loss) and doesn't need its own update op.
                vars_to_summarize['metrics/perplexity'] = tf.exp(
                    vars_to_summarize['loss'])
                vars_to_summarize['metrics/perplexity_per_step'] = tf.exp(
                    vars_to_summarize['metrics/loss_per_step'])

            for var_name, var_value in six.iteritems(vars_to_summarize):
                tf.summary.scalar(var_name, var_value)
                tf.add_to_collection(var_name, var_value)

        elif mode == 'generate':
            temperature = tf.placeholder(tf.float32, [])
            if isinstance(num_classes, numbers.Number):
                softmax_flat = tf.nn.softmax(
                    tf.div(logits_flat, tf.fill([num_classes], temperature)))
                softmax = tf.reshape(softmax_flat,
                                     [hparams.batch_size, -1, num_classes])
            else:
                logits_offsets = np.cumsum([0] + num_classes)
                softmax = []
                for i in range(len(num_classes)):
                    sm = tf.nn.softmax(
                        tf.div(
                            logits_flat[:,
                                        logits_offsets[i]:logits_offsets[i +
                                                                         1]],
                            tf.fill([num_classes[i]], temperature)))
                    sm = tf.reshape(sm,
                                    [hparams.batch_size, -1, num_classes[i]])
                    softmax.append(sm)

            tf.add_to_collection('inputs', inputs)
            tf.add_to_collection('temperature', temperature)
            tf.add_to_collection('softmax', softmax)
            # Flatten state tuples for metagraph compatibility.
            for state in tf_nest.flatten(initial_state):
                tf.add_to_collection('initial_state', state)
            for state in tf_nest.flatten(final_state):
                tf.add_to_collection('final_state', state)
Exemple #21
0
def _grow_topk(
    i,
    alive_seq,
    alive_log_probs,
    batch_size,
    beam_size,
    symbols_to_logits_fn,
    alpha,
    vocab_size,
    eos_id,
    decode_length,
):
    r"""Inner beam seach loop.

  This function takes the current alive sequences, and grows them to topk
  sequences where k = 2*beam. We use 2*beam because, we could have beam_size
  number of sequences that might hit <EOS> and there will be no alive
  sequences to continue. With 2*beam_size, this will not happen. This relies
  on the assumption the vocab size is > beam size. If this is true, we'll
  have at least beam_size non <EOS> extensions if we extract the next top
  2*beam words.
  Length penalty is given by = (5+len(decode)/6) ^ -\alpha. Pls refer to
  https://arxiv.org/abs/1609.08144.

  Args:
    i: loop index
    alive_seq: Topk sequences decoded so far.
        Shape is [batch_size, beam_size, decode_length + 1].
    alive_log_probs: probabilities of these sequences. [batch_size, beam_size]
    batch_size: Integer specifying batch size.
    beam_size: Integer specifying beam size.
    symbols_to_logits_fn: Interface to the model, to provide logits.
        Should take [batch_size, decoded_ids] and return [batch_size, vocab_size]
    alpha: alpha for length penalty.
    vocab_size: Size of the vocab, must equal the size of the logits returned by
        symbols_to_logits_fn
    eos_id: ID for end of sentence.
    decode_length: Maximum length for decoded sequence.
  Returns:
    Tuple of
      (Topk sequences extended by the next word,
       The log probs of these sequences,
       The scores with length penalty of these sequences,
       Flags indicating which of these sequences have finished decoding)
  """
    # Get the logits for all the possible next symbols
    flat_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1])

    # (batch_size * beam_size, decoded_length)
    flat_logits = symbols_to_logits_fn(i, flat_ids)
    logits = tf.reshape(flat_logits, (batch_size, beam_size, -1))

    # Convert logits to normalized log probs
    candidate_log_probs = _log_prob_from_logits(logits)

    # Multiply the probabilities by the current probabilities of the beam.
    # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1)
    log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)

    length_penalty = tf.pow(((5.0 + tf.to_float(i + 1)) / 6.0), alpha)
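    # For example, with alpha = 0.6 and i + 1 = 10 decoded tokens,
    # length_penalty = ((5 + 10) / 6) ** 0.6 = 2.5 ** 0.6, roughly 1.73.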

    # curr_scores has shape [batch_size, beam_size, vocab_size].
    curr_scores = log_probs / length_penalty
    # Flatten out to have shape [batch_size, beam_size * vocab_size].
    # Note that vocab size is not always known statically because the extended
    # vocab size can vary based on the input length.
    flat_curr_scores = tf.reshape(curr_scores, [batch_size, -1])

    topk_scores, topk_ids = tf.nn.top_k(flat_curr_scores, beam_size * 2)

    # Recovering the log probs because we will need to send them back
    topk_log_probs = topk_scores * length_penalty

    # Work out what beam the top probs are in.
    topk_beam_index = topk_ids // vocab_size
    topk_ids %= vocab_size  # Unflatten the ids
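    # For example, with beam_size = 2 and vocab_size = 5, a flattened id of 7
    # corresponds to beam 7 // 5 = 1 and token 7 % 5 = 2.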

    # The next three steps are to create coordinates for tf.gather_nd to pull
    # out the correct sequences from ids that we need to grow.
    # We will also use the coordinates to gather the booleans of the beam items
    # that survived.
    batch_pos = _compute_batch_indices(batch_size, beam_size * 2)

    # top beams will give us the actual coordinates to do the gather.
    # stacking will create a tensor of dimension batch * beam * 2, where the
    # last dimension contains the i,j gathering coordinates.
    topk_coordinates = tf.stack([batch_pos, topk_beam_index], axis=2)

    # Gather up the most probable 2*beams both for the ids and finished_in_alive
    # bools
    topk_seq = tf.gather_nd(alive_seq, topk_coordinates)

    # Append the most probable alive
    topk_ids_padded = _one_hot_tensor_3d(topk_ids, i + 1, decode_length + 1)
    topk_seq += topk_ids_padded

    topk_finished = tf.equal(topk_ids, eos_id)

    return topk_seq, topk_log_probs, topk_scores, topk_finished
Exemple #22
0
    def _build_loss(self):
        """Builds the loss tensor, to be minimized by the optimizer."""
        self.reader = reader.DataReader(
            self.data_dir,
            self.batch_size,
            self.img_height,
            self.img_width,
            SEQ_LENGTH,
            1,  # num_scales
            self.file_extension,
            self.random_scale_crop,
            reader.FLIP_RANDOM,
            self.random_color,
            self.imagenet_norm,
            self.shuffle,
            self.input_file,
            queue_size=self.queue_size,
        )

        (
            self.image_stack,
            self.image_stack_norm,
            self.seg_stack,
            self.intrinsic_mat,
            _,
        ) = self.reader.read_data()
        if self.learn_intrinsics:
            self.intrinsic_mat = None
        if self.intrinsic_mat is None and not self.learn_intrinsics:
            raise RuntimeError(
                "Could not read intrinsic matrix. Turn "
                "learn_intrinsics on to learn it instead of loading "
                "it.")
        self.export("self.image_stack", self.image_stack)

        object_masks = []
        for i in range(self.batch_size):
            object_ids = tf.unique(tf.reshape(self.seg_stack[i], [-1]))[0]
            object_masks_i = []
            for j in range(SEQ_LENGTH):
                current_seg = self.seg_stack[i, :, :, j * 3]  # (H, W)

                def process_obj_mask(obj_id):
                    """Create a mask for obj_id, skipping the background mask."""
                    mask = tf.logical_and(
                        tf.equal(current_seg, obj_id),  # pylint: disable=cell-var-from-loop
                        tf.not_equal(tf.cast(0, tf.uint8), obj_id),
                    )
                    # Leave out very small masks, which are most often errors.
                    size = tf.reduce_sum(tf.to_int32(mask))
                    mask = tf.logical_and(mask,
                                          tf.greater(size, MIN_OBJECT_AREA))
                    if not self.boxify:
                        return mask
                    # Complete the mask to its bounding box.
                    binary_obj_masks_y = tf.reduce_any(mask,
                                                       axis=1,
                                                       keepdims=True)
                    binary_obj_masks_x = tf.reduce_any(mask,
                                                       axis=0,
                                                       keepdims=True)
                    return tf.logical_and(binary_obj_masks_y,
                                          binary_obj_masks_x)

                object_mask = tf.map_fn(  # (N, H, W)
                    process_obj_mask, object_ids, dtype=tf.bool)
                object_mask = tf.reduce_any(object_mask, axis=0)
                object_masks_i.append(object_mask)
            object_masks.append(tf.stack(object_masks_i, axis=-1))

        self.seg_stack = tf.to_float(tf.stack(object_masks, axis=0))
        tf.summary.image("Masks", self.seg_stack)

        with tf.variable_scope(DEPTH_SCOPE):
            # Organized by ...[i][scale].  Note that the order is flipped in
            # variables in build_loss() below.
            self.disp = {}
            self.depth = {}

            # Parabolic ramp-up of the noise over LAYER_NORM_NOISE_RAMPUP_STEPS steps.
            # We stop at 0.5 because this is the value above which the multiplicative
            # noise we use can become negative. Further experimentation is needed to
            # find if non-negativity is indeed needed.
            noise_stddev = 0.5 * tf.square(
                tf.minimum(
                    tf.to_float(self.global_step) /
                    float(LAYER_NORM_NOISE_RAMPUP_STEPS),
                    1.0,
                ))
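            # For example, at global_step = LAYER_NORM_NOISE_RAMPUP_STEPS the
            # stddev reaches its cap of 0.5 * 1.0**2 = 0.5; halfway through the
            # ramp-up it is 0.5 * 0.5**2 = 0.125.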

            def _normalizer_fn(x, is_train, name="bn"):
                return randomized_layer_normalization.normalize(
                    x, is_train=is_train, name=name, stddev=noise_stddev)

            with tf.variable_scope(tf.get_variable_scope(),
                                   reuse=tf.AUTO_REUSE):
                for i in range(SEQ_LENGTH):
                    image = self.image_stack_norm[:, :, :, 3 * i:3 * (i + 1)]
                    self.depth[
                        i] = depth_prediction_net.depth_prediction_resnet18unet(
                            image, True, self.weight_reg, _normalizer_fn)
                    self.disp[i] = 1.0 / self.depth[i]

        with tf.name_scope("compute_loss"):
            self.reconstr_loss = 0
            self.smooth_loss = 0
            self.ssim_loss = 0
            self.depth_consistency_loss = 0

            # Smoothness.
            if self.smooth_weight > 0:
                for i in range(SEQ_LENGTH):
                    disp_smoothing = self.disp[i]
                    # Perform depth normalization, dividing by the mean.
                    mean_disp = tf.reduce_mean(disp_smoothing,
                                               axis=[1, 2, 3],
                                               keep_dims=True)
                    disp_input = disp_smoothing / mean_disp
                    self.smooth_loss += _depth_smoothness(
                        disp_input, self.image_stack[:, :, :,
                                                     3 * i:3 * (i + 1)])

            self.rot_loss = 0.0
            self.trans_loss = 0.0

            def add_result_to_loss_and_summaries(endpoints, i, j):
                tf.summary.image(
                    "valid_mask%d%d" % (i, j),
                    tf.expand_dims(endpoints["depth_proximity_weight"], -1),
                )

                self.depth_consistency_loss += endpoints["depth_error"]
                self.reconstr_loss += endpoints["rgb_error"]
                self.ssim_loss += 0.5 * endpoints["ssim_error"]
                self.rot_loss += endpoints["rotation_error"]
                self.trans_loss += endpoints["translation_error"]

            self.motion_smoothing = 0.0
            with tf.variable_scope(tf.get_variable_scope(),
                                   reuse=tf.AUTO_REUSE):
                for i in range(SEQ_LENGTH - 1):
                    j = i + 1
                    depth_i = self.depth[i][:, :, :, 0]
                    depth_j = self.depth[j][:, :, :, 0]
                    image_j = self.image_stack[:, :, :, 3 * j:3 * (j + 1)]
                    image_i = self.image_stack[:, :, :, i * 3:(i + 1) * 3]
                    # We select a pair of consecutive images (and their respective
                    # predicted depth maps). Now we have the network predict a motion
                    # field that connects the two. We feed the pair of images into the
                    # network, once in forward order and then in reverse order. The
                    # results are fed into the loss calculation. The following losses
                    # are calculated:
                    # - RGB and SSIM photometric consistency.
                    # - Cycle consistency of rotations and translations for every pixel.
                    # - L1 smoothness of the disparity and the motion field.
                    # - Depth consistency
                    rot, trans, trans_res, mat = motion_prediction_net.motion_field_net(
                        images=tf.concat([image_i, image_j], axis=-1),
                        weight_reg=self.weight_reg,
                    )
                    (
                        inv_rot,
                        inv_trans,
                        inv_trans_res,
                        inv_mat,
                    ) = motion_prediction_net.motion_field_net(
                        images=tf.concat([image_j, image_i], axis=-1),
                        weight_reg=self.weight_reg,
                    )

                    if self.learn_intrinsics:
                        intrinsic_mat = 0.5 * (mat + inv_mat)
                    else:
                        intrinsic_mat = self.intrinsic_mat[:, 0, :, :]

                    def dilate(x):
                        # Dilation by n pixels is roughly max pooling by 2 * n + 1.
                        p = self.foreground_dilation * 2 + 1
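                        # For example, self.foreground_dilation = 3 gives a
                        # pooling window of p = 7, growing the mask by about
                        # three pixels on each side.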
                        return tf.nn.max_pool(x, [1, p, p, 1], [1] * 4, "SAME")

                    trans += trans_res * dilate(self.seg_stack[:, :, :,
                                                               j:j + 1])
                    inv_trans += inv_trans_res * dilate(
                        self.seg_stack[:, :, :, i:i + 1])

                    tf.summary.image("trans%d%d" % (i, i + 1), trans)
                    tf.summary.image("trans%d%d" % (i + 1, i), inv_trans)

                    tf.summary.image("trans_res%d%d" % (i + 1, i),
                                     inv_trans_res)
                    tf.summary.image("trans_res%d%d" % (i, i + 1), trans_res)

                    self.motion_smoothing += _smoothness(trans)
                    self.motion_smoothing += _smoothness(inv_trans)
                    tf.summary.scalar(
                        "trans_stdev",
                        tf.sqrt(0.5 * tf.reduce_mean(
                            tf.square(trans) + tf.square(inv_trans))),
                    )

                    transformed_depth_j = transform_depth_map.using_motion_vector(
                        depth_j, trans, rot, intrinsic_mat)

                    add_result_to_loss_and_summaries(
                        consistency_losses.rgbd_and_motion_consistency_loss(
                            transformed_depth_j,
                            image_j,
                            depth_i,
                            image_i,
                            rot,
                            trans,
                            inv_rot,
                            inv_trans,
                        ),
                        i,
                        j,
                    )

                    transformed_depth_i = transform_depth_map.using_motion_vector(
                        depth_i, inv_trans, inv_rot, intrinsic_mat)

                    add_result_to_loss_and_summaries(
                        consistency_losses.rgbd_and_motion_consistency_loss(
                            transformed_depth_i,
                            image_i,
                            depth_j,
                            image_j,
                            inv_rot,
                            inv_trans,
                            rot,
                            trans,
                        ),
                        j,
                        i,
                    )

            # Build the total loss as composed of L1 reconstruction, SSIM, smoothing
            # and object size constraint loss as appropriate.
            self.reconstr_loss *= self.reconstr_weight
            self.export("self.reconstr_loss", self.reconstr_loss)
            self.total_loss = self.reconstr_loss
            if self.smooth_weight > 0:
                self.smooth_loss *= self.smooth_weight
                self.total_loss += self.smooth_loss
                self.export("self.smooth_loss", self.smooth_loss)
            if self.ssim_weight > 0:
                self.ssim_loss *= self.ssim_weight
                self.total_loss += self.ssim_loss
                self.export("self.ssim_loss", self.ssim_loss)

            if self.motion_smoothing_weight > 0:
                self.motion_smoothing *= self.motion_smoothing_weight
                self.total_loss += self.motion_smoothing
                self.export("self.motion_sm_loss", self.motion_smoothing)

            if self.depth_consistency_loss_weight:
                self.depth_consistency_loss *= self.depth_consistency_loss_weight
                self.total_loss += self.depth_consistency_loss
                self.export("self.depth_consistency_loss",
                            self.depth_consistency_loss)

            self.rot_loss *= self.rotation_consistency_weight
            self.trans_loss *= self.translation_consistency_weight
            self.export("rot_loss", self.rot_loss)
            self.export("trans_loss", self.trans_loss)

            self.total_loss += self.rot_loss
            self.total_loss += self.trans_loss

            self.export("self.total_loss", self.total_loss)
Example #23
0
    def _loop_cond(
        i,
        unused_alive_seq,
        alive_log_probs,
        unused_finished_seq,
        finished_scores,
        finished_in_finished,
    ):
        """Checking termination condition.

    We terminate when we have decoded up to decode_length or when the lowest
    scoring item in finished has a greater score than the highest-probability
    item in alive divided by the max length penalty. Optionally, we also
    terminate if all alive scores are below the lower bound.

    Args:
      i: loop index
      alive_log_probs: probabilities of the beams. [batch_size, beam_size]
      finished_scores: scores for each of these sequences.
        [batch_size, beam_size]
      finished_in_finished: finished bools for each of these sequences.
        [batch_size, beam_size]

    Returns:
      True to continue the loop, False to stop.
    """
        max_length_penalty = tf.pow(((5.0 + tf.to_float(decode_length)) / 6.0),
                                    alpha)
        # The best possible score of the most likely alive sequence.
        lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty

        # Now to compute the lowest score of a finished sequence in finished.
        # If a sequence isn't finished, we multiply its score by 0. Since
        # scores are all negative, taking the min will give us the score of
        # the lowest finished item.
        lowest_score_of_finished_in_finished = tf.reduce_min(
            finished_scores * tf.to_float(finished_in_finished), axis=1)
        # If none of the sequences have finished, the min will be 0, so we
        # replace it with negative infinity in that case. The score of any
        # sequence in alive will be much higher than negative infinity, so the
        # termination condition will not be met.
        lowest_score_of_finished_in_finished = _apply_negative_infinity_mask(
            lowest_score_of_finished_in_finished,
            tf.logical_not(tf.reduce_any(finished_in_finished, 1)),
        )

        # Will terminate beam search early if bound_is_met is True.
        bound_is_met = tf.reduce_all(
            tf.greater(lowest_score_of_finished_in_finished,
                       lower_bound_alive_scores))

        # Check if all alive scores are below minimum.
        if minimum_score:
            minimum_score_log = tf.log(minimum_score)
            bound_is_met = tf.logical_or(
                bound_is_met,
                tf.reduce_all(
                    tf.less(lower_bound_alive_scores, minimum_score_log)),
            )

        return tf.logical_and(tf.less(i, decode_length),
                              tf.logical_not(bound_is_met))
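
The docstring above describes the early-termination rule: stop once even the worst finished score beats the best alive score divided by the maximal length penalty ((5 + decode_length) / 6) ** alpha. A NumPy sketch of that comparison with made-up scores (the values, decode_length, and alpha are purely illustrative):

import numpy as np

decode_length, alpha = 20, 0.6
max_length_penalty = ((5.0 + decode_length) / 6.0) ** alpha

# Hypothetical scores: batch of 2, beam size 3 (log-probs, so all negative).
alive_log_probs = np.array([[-1.2, -2.0, -3.1],
                            [-0.8, -1.5, -2.2]])
finished_scores = np.array([[-0.9, -1.1, -1.4],
                            [-0.7, -0.9, -1.3]])

# Best achievable score of the top alive hypothesis per batch element.
lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty
# Worst finished score per batch element.
lowest_score_of_finished = finished_scores.min(axis=1)

# Stop early only if, for every batch element, even the worst finished
# hypothesis already beats the best achievable alive score.
bound_is_met = np.all(lowest_score_of_finished > lower_bound_alive_scores)
print(bound_is_met)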
Example #24
0
    def render_envmap(self, cubes, cube_centers, cube_side_lengths,
                      cube_rel_shapes, cube_nest_inds, ref_pose, env_pose,
                      theta_res, phi_res, r_res):
        """Render environment map from volumetric lights.

    Args:
      cubes: input list of cubes in multiscale volume
      cube_centers: position of cube centers
      cube_side_lengths: side lengths of cubes
      cube_rel_shapes: size of "footprint" of each cube within next coarser cube
      cube_nest_inds: indices for cube "footprints"
      ref_pose: c2w pose of ref camera
      env_pose: c2w pose of environment map camera
      theta_res: resolution of theta (width) for environment map
      phi_res: resolution of phi (height) for environment map
      r_res: number of spherical shells to sample for environment map rendering

    Returns:
      An environment map at the input pose
    """
        num_scales = len(cubes)

        env_c2w = env_pose
        env2ref = tf.matmul(tf.matrix_inverse(ref_pose), env_c2w)

        # cube-->sphere resampling
        all_shells_list = []
        all_rad_list = []
        for i in range(num_scales):
            if i == num_scales - 1:
                # "finest" resolution cube, don't zero out
                cube_removed = cubes[i]
            else:
                # zero out areas covered by finer resolution cubes
                cube_shape = cubes[i].get_shape().as_list()[1]

                zm_y, zm_x, zm_z = tf.meshgrid(
                    tf.range(cube_nest_inds[i][0],
                             cube_nest_inds[i][0] + cube_rel_shapes[i]),
                    tf.range(cube_nest_inds[i][1],
                             cube_nest_inds[i][1] + cube_rel_shapes[i]),
                    tf.range(cube_nest_inds[i][2],
                             cube_nest_inds[i][2] + cube_rel_shapes[i]),
                    indexing='ij')
                inds = tf.stack([zm_y, zm_x, zm_z], axis=-1)
                updates = tf.to_float(tf.ones_like(zm_x))
                zero_mask = 1.0 - tf.scatter_nd(
                    inds, updates, shape=[cube_shape, cube_shape, cube_shape])
                cube_removed = zero_mask[tf.newaxis, :, :, :,
                                         tf.newaxis] * cubes[i]

            spheres_i, rad_i = pj.spherical_cubevol_resample(
                cube_removed, env2ref, cube_centers[i], cube_side_lengths[i],
                phi_res, theta_res, r_res)
            all_shells_list.append(spheres_i)
            all_rad_list.append(rad_i)

        all_shells = tf.concat(all_shells_list, axis=3)
        all_rad = tf.concat(all_rad_list, axis=0)
        all_shells = pj.interleave_shells(all_shells, all_rad)
        all_shells_envmap = pj.over_composite(all_shells)

        return all_shells_envmap, all_shells_list
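
The zero-masking of each coarser cube builds a binary mask with tf.scatter_nd over the voxels covered by the next finer cube's footprint. A standalone sketch of the same construction on a toy 4x4x4 volume, assuming a TF1-compatible runtime; the nesting index and footprint size are hypothetical:

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

cube_shape = 4        # toy volume, 4 x 4 x 4
nest_ind = [1, 1, 1]  # hypothetical offset of the finer cube's footprint
rel_shape = 2         # hypothetical footprint side length

zm_y, zm_x, zm_z = tf.meshgrid(
    tf.range(nest_ind[0], nest_ind[0] + rel_shape),
    tf.range(nest_ind[1], nest_ind[1] + rel_shape),
    tf.range(nest_ind[2], nest_ind[2] + rel_shape),
    indexing='ij')
inds = tf.stack([zm_y, zm_x, zm_z], axis=-1)
updates = tf.ones_like(zm_x, dtype=tf.float32)
zero_mask = 1.0 - tf.scatter_nd(
    inds, updates, shape=[cube_shape, cube_shape, cube_shape])

with tf.Session() as sess:
    mask = sess.run(zero_mask)
# 2 x 2 x 2 = 8 voxels are zeroed out of 64.
print(int(mask.size - mask.sum()))  # 8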
Example #25
0
    def _log_prob(self, data, num_samples=1):
        """Compute a lower bound on the log likelihood."""
        # Due to memory issues, we need to use num_samples=1 here
        num_samples, proposal_num_samples = 1, num_samples
        batch_size = tf.shape(data)[0]
        # Sample from the proposal and compute the weights of the "unseen" samples.
        # We share these across the batch dimension.
        # [num_samples, K, data_size]
        proposal_samples = self.proposal.sample(num_samples * (self.K - 1))
        if not self.reparameterize_proposal_samples:
            proposal_samples = tf.stop_gradient(proposal_samples)

        # [num_samples, K]
        log_energy_proposal = tf.reshape(
            self.energy_fn(tf.reshape(proposal_samples, [-1] + self.data_dim)),
            [num_samples, self.K - 1])
        tf.summary.histogram("log_energy_proposal", log_energy_proposal)
        tf.summary.scalar("min_log_energy_proposal",
                          tf.reduce_min(log_energy_proposal))
        tf.summary.scalar("max_log_energy_proposal",
                          tf.reduce_max(log_energy_proposal))
        # [num_samples]
        proposal_lse = tf.reduce_logsumexp(log_energy_proposal, axis=1)

        # [batch_size, num_samples]
        tiled_proposal_lse = tf.tile(proposal_lse[tf.newaxis, :],
                                     [batch_size, 1])

        # Compute the weights of the observed data.
        # [batch_size, 1]
        log_energy_data = tf.reshape(self.energy_fn(data), [batch_size])
        tf.summary.histogram("log_energy_data", log_energy_data)
        tf.summary.scalar("min_log_energy_data",
                          tf.reduce_min(log_energy_data))
        tf.summary.scalar("max_log_energy_data",
                          tf.reduce_max(log_energy_data))

        # [batch_size, num_samples]
        tiled_log_energy_data = tf.tile(log_energy_data[:, tf.newaxis],
                                        [1, num_samples])

        # Add the weights of the proposal samples with the true data weights.
        # [batch_size, num_samples]
        # pylint: disable=invalid-name
        Z_hat = tf.reduce_logsumexp(tf.stack(
            [tiled_log_energy_data, tiled_proposal_lse], axis=-1),
                                    axis=-1)
        Z_hat -= tf.log(tf.to_float(self.K))
        # Perform the log-sum-exp reduction for IWAE
        # [batch_size]
        Z_hat = tf.reduce_logsumexp(Z_hat, axis=1) - tf.log(
            tf.to_float(num_samples))
        # pylint: enable=invalid-name

        try:
            # Try giving the proposal lower bound num_samples if it can use it.
            proposal_lp = self.proposal.log_prob(
                data, num_samples=proposal_num_samples)
        except TypeError:
            proposal_lp = self.proposal.log_prob(data)
        lower_bound = proposal_lp + log_energy_data - Z_hat
        return lower_bound
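
With num_samples = 1, the normalizer estimate above reduces to combining the observed datum's energy with the log-sum-exp of the K - 1 proposal-sample energies and subtracting log K. A scalar NumPy/SciPy sketch of that combination (K and the energy values are made up):

import numpy as np
from scipy.special import logsumexp

K = 4                                               # hypothetical sample count
log_energy_data = -0.5                              # energy of the observed datum
log_energy_proposal = np.array([-1.0, -2.0, -3.0])  # K - 1 proposal samples

proposal_lse = logsumexp(log_energy_proposal)
# log((exp(E(x)) + sum_k exp(E(x_k))) / K), matching Z_hat above for
# num_samples = 1 (the extra IWAE reduction is then a no-op).
z_hat = logsumexp([log_energy_data, proposal_lse]) - np.log(K)
print(z_hat)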
Example #26
0
    def format_network_input(self, ref_image, psv_src_images, ref_pose,
                             psv_src_poses, planes, intrinsics):
        """Format the network input.

    Args:
      ref_image: reference source image [batch, height, width, 3]
      psv_src_images: stack of source images (excluding the ref image) [batch,
        height, width, 3*(#source)]
      ref_pose: reference camera-to-world pose (where PSV is constructed)
        [batch, 4, 4]
      psv_src_poses: input poses (camera to world) [batch, 4, 4, #source]
      planes: list of scalar depth values for each plane
      intrinsics: camera intrinsics [batch, 3, 3]

    Returns:
      net_input: [batch, height, width, #planes, (#source+1)*3]
    """

        batch_size = tf.shape(psv_src_images)[0]
        height = tf.shape(psv_src_images)[1]
        width = tf.shape(psv_src_images)[2]
        _, _, _, num_psv_source = psv_src_poses.get_shape().as_list()
        num_planes = tf.shape(planes)[0]

        filler = tf.concat(
            [tf.zeros([batch_size, 1, 3]),
             tf.ones([batch_size, 1, 1])],
            axis=2)
        intrinsics_filler = tf.stack([
            tf.to_float(height),
            tf.to_float(width),
            tf.to_float(intrinsics[0, 0, 0])
        ],
                                     axis=0)[:, tf.newaxis]

        ref_pose_c2w = ref_pose
        ref_pose_c2w = tf.concat([
            tf.concat([
                ref_pose_c2w[:, :3, 0:1], ref_pose_c2w[:, :3, 1:2],
                -1.0 * ref_pose_c2w[:, :3, 2:3], ref_pose_c2w[:, :3, 3:]
            ],
                      axis=2), filler
        ],
                                 axis=1)
        ref_pose_c2w = tf.concat([ref_pose_c2w[0, :3, :], intrinsics_filler],
                                 axis=1)

        net_input = []
        for i in range(num_psv_source):
            curr_pose_c2w = psv_src_poses[:, :, :, i]
            curr_pose_c2w = tf.concat([
                tf.concat([
                    curr_pose_c2w[:, :3, 0:1], curr_pose_c2w[:, :3, 1:2],
                    -1.0 * curr_pose_c2w[:, :3, 2:3], curr_pose_c2w[:, :3, 3:]
                ], 2), filler
            ], 1)
            curr_pose_c2w = tf.concat(
                [curr_pose_c2w[0, :3, :], intrinsics_filler], axis=1)
            curr_image = psv_src_images[:, :, :, i * 3:(i + 1) * 3]
            curr_psv = pj.make_psv_homogs(curr_image, curr_pose_c2w,
                                          ref_pose_c2w, 1.0 / planes,
                                          num_planes)
            net_input.append(curr_psv[tf.newaxis, Ellipsis])

        net_input = tf.concat(net_input, axis=4)
        ref_img_stack = tf.tile(tf.expand_dims(ref_image, 3),
                                [1, 1, 1, num_planes, 1])
        net_input = tf.concat([ref_img_stack, net_input], axis=4)
        net_input.set_shape([1, None, None, None, 3 * (num_psv_source + 1)])

        return net_input
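
Both this helper and mpi_render_view below pack a camera-to-world pose and a compact intrinsics vector into one [3, 5] matrix: the third rotation column is negated and a [height, width, focal] column is appended. A NumPy sketch of that packing for a single hypothetical pose (all values are arbitrary):

import numpy as np

c2w = np.eye(4, dtype=np.float32)  # hypothetical camera-to-world pose
c2w[:3, 3] = [0.1, -0.2, 0.5]      # some translation
height, width, focal = 240.0, 320.0, 128.0

pose = np.concatenate([
    c2w[:3, 0:1], c2w[:3, 1:2], -1.0 * c2w[:3, 2:3], c2w[:3, 3:4]
], axis=1)                         # [3, 4] with the z column negated
intrinsics_filler = np.array([[height], [width], [focal]], dtype=np.float32)

packed = np.concatenate([pose, intrinsics_filler], axis=1)  # [3, 5]
print(packed.shape)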
Example #27
0
def accuracy(label, logits):
    """Computes accuracy from given label and logits."""
    return tf.reduce_mean(
        tf.to_float(tf.equal(label, tf.argmax(logits, axis=1))))
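
A quick usage sketch, assuming the accuracy function above is in scope and a TF1-compatible runtime; the labels and logits are dummy values:

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

labels = tf.constant([0, 1, 2], dtype=tf.int64)
logits = tf.constant(np.array([[2.0, 0.1, 0.1],   # argmax 0 -> correct
                               [0.1, 0.2, 3.0],   # argmax 2 -> wrong
                               [0.0, 0.1, 5.0]],  # argmax 2 -> correct
                              dtype=np.float32))

acc = accuracy(labels, logits)

with tf.Session() as sess:
    print(sess.run(acc))  # ~0.6667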
Example #28
0
    def mpi_render_view(self, input_mpi, ref_pose, tgt_pose, planes,
                        intrinsics):
        """Render a target view from MPI representation.

    Args:
      input_mpi: input MPI [batch, height, width, #planes, 4]
      ref_pose: reference camera pose [batch, 4, 4]
      tgt_pose: target pose to render from [batch, 4, 4]
      planes: list of depths for each plane
      intrinsics: camera intrinsics [batch, 3, 3]

    Returns:
      rendered view [batch, height, width, 3]
    """

        batch_size, _, _ = tgt_pose.get_shape().as_list()
        num_planes = tf.shape(planes)[0]
        height = tf.shape(input_mpi)[1]
        width = tf.shape(input_mpi)[2]

        rgba_layers = input_mpi

        # render target viewpoint
        filler = tf.concat(
            [tf.zeros([batch_size, 1, 3]),
             tf.ones([batch_size, 1, 1])],
            axis=2)
        intrinsics_filler = tf.stack(
            [tf.to_float(height),
             tf.to_float(width), intrinsics[0, 0, 0]],
            axis=0)[:, tf.newaxis]

        ref_pose_c2w = ref_pose
        ref_pose_c2w = tf.concat([
            tf.concat([
                ref_pose_c2w[:, :3, 0:1], ref_pose_c2w[:, :3, 1:2],
                -1.0 * ref_pose_c2w[:, :3, 2:3], ref_pose_c2w[:, :3, 3:]
            ],
                      axis=2), filler
        ],
                                 axis=1)
        ref_pose_c2w = tf.concat([ref_pose_c2w[0, :3, :], intrinsics_filler],
                                 axis=1)

        tgt_pose_c2w = tgt_pose
        tgt_pose_c2w = tf.concat([
            tf.concat([
                tgt_pose_c2w[:, :3, 0:1], tgt_pose_c2w[:, :3, 1:2],
                -1.0 * tgt_pose_c2w[:, :3, 2:3], tgt_pose_c2w[:, :3, 3:]
            ],
                      axis=2), filler
        ],
                                 axis=1)
        tgt_pose_c2w = tf.concat([tgt_pose_c2w[0, :3, :], intrinsics_filler],
                                 axis=1)

        rendering, alpha_acc, accum = pj.render_mpi_homogs(rgba_layers,
                                                           ref_pose_c2w,
                                                           tgt_pose_c2w,
                                                           1.0 / planes[0],
                                                           1.0 / planes[-1],
                                                           num_planes,
                                                           debug=False)

        return rendering, alpha_acc, accum
Example #29
0
def render_constellations(pred_points,
                          capsule_num,
                          canvas_size,
                          gt_points=None,
                          n_caps=2,
                          gt_presence=None,
                          pred_presence=None,
                          caps_presence_prob=None):
  """Renderes predicted and ground-truth points as gaussian blobs.

  Args:
    pred_points: [B, m, 2].
    capsule_num: [B, m] tensor indicating which capsule the corresponding point
      comes from. Plots from different capsules are plotted with different
      colors. Currently supported values: {0, 1, ..., 11}.
    canvas_size: tuple of ints
    gt_points: [B, k, 2]; plots ground-truth points if present.
    n_caps: integer, number of capsules.
    gt_presence: [B, k] binary tensor.
    pred_presence: [B, m] binary tensor.
    caps_presence_prob: [B, m], a tensor of presence probabilities for caps.

  Returns:
    [B, *canvas_size] tensor with plotted points
  """

  # convert coords to be in [0, side_length]
  pred_points = denormalize_coords(pred_points, canvas_size, rounded=True)

  # render predicted points
  batch_size, n_points = pred_points.shape[:2].as_list()
  capsule_num = tf.to_float(tf.one_hot(capsule_num, depth=n_caps))
  capsule_num = tf.reshape(capsule_num, [batch_size, n_points, 1, 1, n_caps, 1])

  color = tf.convert_to_tensor(_COLORS[:n_caps])
  color = tf.reshape(color, [1, 1, 1, 1, n_caps, 3]) * capsule_num
  color = tf.reduce_sum(color, -2)
  color = tf.squeeze(tf.squeeze(color, 3), 2)

  colored = render_by_scatter(canvas_size, pred_points, color, pred_presence)

  # Prepare a vertical separator between predicted and gt points.
  # Separator is composed of all supported colors and also serves as
  # a legend.
  # [b, h, w, 3]
  n_colors = _COLORS.shape[0]
  sep = tf.reshape(tf.convert_to_tensor(_COLORS), [1, 1, n_colors, 3])
  n_tiles = int(colored.shape[2]) // n_colors
  sep = snt.TileByDim([0, 1, 3], [batch_size, 3, n_tiles])(sep)
  sep = tf.reshape(sep, [batch_size, 3, n_tiles * n_colors, 3])

  pad = int(colored.shape[2]) - n_colors * n_tiles
  pad, r = pad // 2, pad % 2

  if caps_presence_prob is not None:
    n_caps = int(caps_presence_prob.shape[1])
    prob_pads = ([0, 0], [0, n_colors - n_caps])
    caps_presence_prob = tf.pad(caps_presence_prob, prob_pads)
    zeros = tf.zeros([batch_size, 3, n_colors, n_tiles, 3], dtype=tf.float32)

    shape = [batch_size, 1, n_colors, 1, 1]
    caps_presence_prob = tf.reshape(caps_presence_prob, shape)

    prob_vals = snt.MergeDims(2, 2)(caps_presence_prob + zeros)
    sep = tf.concat([sep, tf.ones_like(sep[:, :1]), prob_vals], 1)

  sep = tf.pad(sep, [(0, 0), (1, 1), (pad, pad + r), (0, 0)],
               constant_values=1.)

  # render gt points
  if gt_points is not None:
    gt_points = denormalize_coords(gt_points, canvas_size, rounded=True)

    gt_rendered = render_by_scatter(canvas_size, gt_points, colors=None,
                                    gt_presence=gt_presence)

    colored = tf.where(tf.cast(colored, bool), colored, gt_rendered)
    colored = tf.concat([gt_rendered, sep, colored], 1)

  res = tf.clip_by_value(colored, 0., 1.)
  return res
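
The per-point coloring above amounts to a one-hot lookup of each capsule index into a color palette. A NumPy sketch of that lookup with a tiny hypothetical palette (the palette, batch size, and capsule indices are made up, not the real _COLORS):

import numpy as np

palette = np.array([[1.0, 0.0, 0.0],   # hypothetical 3-color palette
                    [0.0, 1.0, 0.0],
                    [0.0, 0.0, 1.0]], dtype=np.float32)

capsule_num = np.array([[0, 2, 1, 2]])  # [B=1, m=4] capsule index per point
one_hot = np.eye(len(palette), dtype=np.float32)[capsule_num]  # [1, 4, 3]

# Each point receives the RGB of its capsule, as in render_constellations.
point_colors = one_hot @ palette        # [1, 4, 3]
print(point_colors[0])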
Example #30
0
 def target_width_fn():
     return tf.to_int32(
         tf.round(tf.to_float(orig_height) * new_aspect_ratio))