Exemple #1
0
def _pad_or_truncate_rows(tensor, delta):
    """Pads or truncates `tensor` along its first axis by `delta` rows.

    Args:
      tensor: A rank-2 tensor (the paddings built here describe two axes).
      delta: Scalar int tensor equal to the current number of rows minus the
        desired number of rows. A negative value appends `-delta` rows of
        padding, a positive value drops the last `delta` rows, and zero
        returns `tensor` unchanged.

    Returns:
      The adjusted tensor.
    """
    return tf.case(
        [(delta < 0,
          lambda: tf.pad(tensor, tf.stack([(0, -delta), (0, 0)]))),
         (delta > 0, lambda: tensor[0:-delta])],
        default=lambda: tensor)


def input_tensors_to_model_input(input_tensors,
                                 hparams,
                                 is_training,
                                 num_classes=constants.MIDI_PITCHES):
    """Processes an InputTensor into FeatureTensors and LabelTensors.

    Each per-frame tensor is reshaped to two dimensions and then padded or
    truncated along the frame axis so that all of them share `final_length`
    rows.

    Args:
      input_tensors: Object exposing `length`, `labels`, `label_weights`,
        `onsets`, `offsets`, `velocities`, `spec`, `note_sequence` and
        `sequence_id`.
      hparams: Hyperparameters. `truncated_length_secs`,
        `max_expected_train_example_len` and `drum_data_map` are read here;
        the object is also handed to `hparams_frame_size` and
        `hparams_frames_per_second`.
      is_training: Python bool. When true, the note sequence and sequence id
        are replaced by constant placeholders and
        `max_expected_train_example_len` (if set) fixes the output length.
      num_classes: Size of the last dimension of every label tensor.

    Returns:
      A `(features, labels)` tuple holding a `FeatureTensors` and a
      `LabelTensors`.
    """
    length = tf.cast(input_tensors.length, tf.int32)
    labels = tf.reshape(input_tensors.labels, (-1, num_classes))
    label_weights = tf.reshape(input_tensors.label_weights, (-1, num_classes))
    onsets = tf.reshape(input_tensors.onsets, (-1, num_classes))
    offsets = tf.reshape(input_tensors.offsets, (-1, num_classes))
    velocities = tf.reshape(input_tensors.velocities, (-1, num_classes))
    spec = tf.reshape(input_tensors.spec, (-1, hparams_frame_size(hparams)))

    # Slice specs and labels tensors so they are no longer than
    # truncated_length. The frame count is only derived when truncation is
    # actually requested, so a falsy `truncated_length_secs` (0 or None) never
    # takes part in arithmetic.
    if hparams.truncated_length_secs:
        hparams_truncated_length = tf.cast(
            hparams.truncated_length_secs * hparams_frames_per_second(hparams),
            tf.int32)
        truncated_length = tf.reduce_min([hparams_truncated_length, length])
    else:
        hparams_truncated_length = None
        truncated_length = length

    if is_training:
        truncated_note_sequence = tf.constant(0)
    else:
        truncated_note_sequence = truncate_note_sequence_op(
            input_tensors.note_sequence, truncated_length, hparams)

    # If max_expected_train_example_len is set, ensure that all examples are
    # padded to this length. This results in a fixed shape that can work on TPUs.
    if hparams.max_expected_train_example_len and is_training:
        # In this case, final_length is a constant.
        if hparams.truncated_length_secs:
            assert_op = tf.assert_equal(hparams.max_expected_train_example_len,
                                        hparams_truncated_length)
            # NOTE(review): no TF op is created inside this context (the
            # assignment below is plain Python), so the assert may never be
            # attached to anything that runs -- confirm this is intended.
            with tf.control_dependencies([assert_op]):
                final_length = hparams.max_expected_train_example_len
        else:
            final_length = hparams.max_expected_train_example_len
    else:
        # In this case, it is min(hparams.truncated_length, length)
        final_length = truncated_length

    spec = _pad_or_truncate_rows(spec, tf.shape(spec)[0] - final_length)

    # All label-like tensors are adjusted by the delta measured on `labels`.
    labels_delta = tf.shape(labels)[0] - final_length
    labels = _pad_or_truncate_rows(labels, labels_delta)
    label_weights = _pad_or_truncate_rows(label_weights, labels_delta)
    onsets = _pad_or_truncate_rows(onsets, labels_delta)
    offsets = _pad_or_truncate_rows(offsets, labels_delta)
    velocities = _pad_or_truncate_rows(velocities, labels_delta)

    features = FeatureTensors(
        spec=tf.reshape(spec, (final_length, hparams_frame_size(hparams), 1)),
        length=truncated_length,
        sequence_id=(tf.constant(0)
                     if is_training else input_tensors.sequence_id))
    labels = LabelTensors(
        labels=tf.reshape(labels, (final_length, num_classes)),
        label_weights=tf.reshape(label_weights, (final_length, num_classes)),
        onsets=tf.reshape(onsets, (final_length, num_classes)),
        offsets=tf.reshape(offsets, (final_length, num_classes)),
        velocities=tf.reshape(velocities, (final_length, num_classes)),
        note_sequence=truncated_note_sequence)

    if hparams.drum_data_map:
        labels_dict = labels._asdict()
        # Each group of fields is remapped with its own reduce mode.
        field_groups = (
            (('labels', 'onsets', 'offsets'), 'any'),
            (('label_weights', 'velocities'), 'max'),
        )
        for field_names, reduce_mode in field_groups:
            for k in field_names:
                labels_dict[k] = drum_mappings.map_pianoroll(
                    labels_dict[k],
                    mapping_name=hparams.drum_data_map,
                    reduce_mode=reduce_mode,
                    min_pitch=constants.MIN_MIDI_PITCH)
        # Only a string-typed note sequence is remapped; the training
        # placeholder built above with tf.constant(0) is left untouched.
        if labels_dict['note_sequence'].dtype == tf.string:
            labels_dict['note_sequence'] = tf.py_func(
                functools.partial(drum_mappings.map_sequences,
                                  mapping_name=hparams.drum_data_map),
                [labels_dict['note_sequence']],
                tf.string,
                name='get_drum_sequences',
                stateful=False)
            labels_dict['note_sequence'].set_shape(())
        labels = LabelTensors(**labels_dict)

    return features, labels
Exemple #2
0
def psnr(x_result, x_true, name='psnr'):
    """Computes `20 * log10(range) - 10 * log10(mse)` for a reconstruction.

    Args:
      x_result: Tensor to evaluate.
      x_true: Reference tensor. Its max-minus-min spread is used as the peak
        value, and the mean squared difference to `x_result` as the noise.
      name: Name scope under which the ops are created.

    Returns:
      The value of the expression above, built from `log10` calls.
    """
    with tf.name_scope(name):
        dynamic_range = tf.reduce_max(x_true) - tf.reduce_min(x_true)
        squared_error = (x_result - x_true)**2
        mean_squared_error = tf.reduce_mean(squared_error)
        peak_term = 20 * log10(dynamic_range)
        noise_term = 10 * log10(mean_squared_error)
        return peak_term - noise_term
 x1 = tf.matmul(net_dict["LLRa{0}".format(i)], W_odd2even)
 x2 = tf.add(x0, x1)
 x2 = tf.transpose(x2, [0, 2, 1])
 x2 = tf.reshape(x2, [batch_size, neurons_per_odd_layer * Z])
 x2 = tf.matmul(x2, Lift_Matrix1[0].transpose())
 x2 = tf.reshape(x2, [batch_size, neurons_per_odd_layer, Z])
 x2 = tf.transpose(x2, [0, 2, 1])
 x_tile = tf.tile(x2, multiples=[1, 1, neurons_per_odd_layer])
 W_input_reshape = tf.reshape(W_even2odd.transpose(), [-1])
 #check node update
 x_tile_mul = tf.multiply(x_tile, W_input_reshape)
 x2_1 = tf.reshape(
     x_tile_mul,
     [batch_size, Z, neurons_per_odd_layer, neurons_per_odd_layer])
 x2_abs = tf.add(tf.abs(x2_1), 10000 * (1 - tf.to_float(tf.abs(x2_1) > 0)))
 x3 = tf.reduce_min(x2_abs, axis=3)
 x2_2 = -x2_1
 x4 = tf.add(
     tf.zeros(
         (batch_size, Z, neurons_per_odd_layer, neurons_per_odd_layer)),
     1 - 2 * tf.to_float(x2_2 < 0))
 x4_prod = -tf.reduce_prod(x4, axis=3)
 x_output_0 = tf.multiply(x3, tf.sign(x4_prod))
 x_output_0 = tf.transpose(x_output_0, [0, 2, 1])
 x_output_0 = tf.reshape(x_output_0,
                         [batch_size, Z * neurons_per_odd_layer])
 x_output_0 = tf.matmul(x_output_0, Lift_Matrix2[0])
 x_output_0 = tf.reshape(x_output_0, [batch_size, neurons_per_odd_layer, Z])
 x_output_0 = tf.transpose(x_output_0, [0, 2, 1])
 # add learnable parameters
 x_output_1 = tf.add(
Exemple #4
0
def main(argv):
    """Renders an environment map image for every `.npz` file in the data dir.

    Builds the inference graph once, restores `model.ckpt` from
    `FLAGS.checkpoint_dir`, then feeds each `.npz` file found in
    `FLAGS.data_dir` (in sorted order) through the graph and writes the first
    three channels of the first rendered map to
    `FLAGS.output_dir/<index>.png`.

    Args:
      argv: Unused command-line arguments.

    Raises:
      ValueError: If `checkpoint_dir`, `data_dir` or `output_dir` is unset.
    """
    del argv  # unused

    if FLAGS.checkpoint_dir is None:
        raise ValueError("`checkpoint_dir` must be defined")
    if FLAGS.data_dir is None:
        raise ValueError("`data_dir` must be defined")
    if FLAGS.output_dir is None:
        raise ValueError("`output_dir` must be defined")

    # Set up placeholders
    ref_image = tf.placeholder(dtype=tf.float32,
                               shape=[None, height, width, 3])
    ref_depth = tf.placeholder(dtype=tf.float32, shape=[None, height, width])
    intrinsics = tf.placeholder(dtype=tf.float32, shape=[None, 3, 3])
    ref_pose = tf.placeholder(dtype=tf.float32, shape=[None, 4, 4])
    src_images = tf.placeholder(dtype=tf.float32,
                                shape=[None, height, width, 3])
    src_poses = tf.placeholder(dtype=tf.float32, shape=[None, 4, 4, 1])
    env_pose = tf.placeholder(dtype=tf.float32, shape=[None, 4, 4])

    # Set up model
    model = MLV()

    # We use the true depth bounds for testing
    # Adjust to estimated bounds for your dataset
    min_depth = tf.reduce_min(ref_depth)
    max_depth = tf.reduce_max(ref_depth)

    # Set up graph
    mpi_planes = pj.inv_depths(min_depth, max_depth, num_planes)

    pred = model.infer_mpi(src_images, ref_image, ref_pose, src_poses,
                           intrinsics, mpi_planes)
    rgba_layers = pred["rgba_layers"]

    (lightvols, lightvol_centers, lightvol_side_lengths, cube_rel_shapes,
     cube_nest_inds) = model.predict_lighting_vol(rgba_layers,
                                                  mpi_planes,
                                                  intrinsics,
                                                  cube_res,
                                                  scale_factors,
                                                  depth_clip=depth_clip)
    lightvols_out = nets.cube_net_multires(lightvols, cube_rel_shapes,
                                           cube_nest_inds)
    output_envmap, _ = model.render_envmap(lightvols_out, lightvol_centers,
                                           lightvol_side_lengths,
                                           cube_rel_shapes, cube_nest_inds,
                                           ref_pose, env_pose, theta_res,
                                           phi_res, r_res)

    # makedirs (rather than mkdir) so a nested output path can be created.
    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    input_files = sorted(
        [f for f in os.listdir(FLAGS.data_dir) if f.endswith(".npz")])
    print("found {:05d} input files".format(len(input_files)))

    with tf.Session() as sess:
        saver = tf.train.Saver()
        saver.restore(sess, os.path.join(FLAGS.checkpoint_dir, "model.ckpt"))

        for i, input_file in enumerate(input_files):
            print("running example:", i)

            # Load inputs. The path is joined (not concatenated) so that
            # `data_dir` works with or without a trailing separator, and the
            # archive is closed as soon as the example has been evaluated.
            with np.load(os.path.join(FLAGS.data_dir, input_file)) as batch:
                output_envmap_eval = sess.run(
                    output_envmap,
                    feed_dict={
                        ref_image: batch["ref_image"],
                        ref_depth: batch["ref_depth"],
                        intrinsics: batch["intrinsics"],
                        ref_pose: batch["ref_pose"],
                        src_images: batch["src_images"],
                        src_poses: batch["src_poses"],
                        env_pose: batch["env_pose"]
                    })

            # Write environment map image
            plt.imsave(os.path.join(FLAGS.output_dir, "{:05d}.png".format(i)),
                       output_envmap_eval[0, :, :, :3])
Exemple #5
0
def rnn(cell,
        inputs,
        sequence_length=None,
        initial_state=None,
        ff_keep_prob=1.,
        recur_keep_prob=1.,
        enforce_dropout=False,
        dtype=tf.float32,
        scope=None):
    """Runs `cell` over the time axis of `inputs` inside a `tf.while_loop`.

    Args:
      cell: Callable invoked as `cell(input_t, state)` that returns
        `(output, new_state)`. Its `output_size`, `state_size` and
        `zero_state` members are also used.
      inputs: Batch-major rank-3 tensor, (B,T,D).
      sequence_length: Optional per-example lengths (cast to int32). At time
        steps at or beyond an example's length the output is zero and the
        previous state is carried through.
      initial_state: Optional starting state; when omitted,
        `cell.zero_state(batch_size, dtype)` is used.
      ff_keep_prob: Keep probability for dropout on the inputs. The noise
        shape is `[1, batch, depth]` on the time-major inputs, so one mask is
        broadcast over the time axis.
      recur_keep_prob: Keep probability for a mask that multiplies the state
        before every cell call. Only the trailing `cell.output_size` columns
        of the state are masked; the preceding columns are multiplied by one.
      enforce_dropout: When not None, dropout goes through `tf.layers.dropout`
        with this value as its `training` argument; when None,
        `tf.nn.dropout` is used instead.
        NOTE(review): the default `False` therefore calls `tf.layers.dropout`
        with `training=False` -- confirm that this is the intended default.
      dtype: Data type for the zero state when `initial_state` is omitted.
      scope: Variable scope name; defaults to 'RNN'.

    Returns:
      A `(outputs, final_state)` tuple with `outputs` transposed back to
      batch-major layout.

    Raises:
      ValueError: If neither `initial_state` nor `dtype` is given.
    """
    # The loop below is time-major: (B,T,D) => (T,B,D).
    inputs = tf.transpose(inputs, [1, 0, 2])

    parallel_iterations = 32
    has_lengths = sequence_length is not None
    if has_lengths:
        sequence_length = tf.to_int32(sequence_length)

    with tf.variable_scope(scope or 'RNN'):
        time_steps, batch_size, _ = tf.unstack(tf.shape(inputs), 3)
        # Only the static feature depth is needed (for the dropout noise shape).
        _, _, const_depth = inputs.get_shape().as_list()

        if initial_state is None:
            if not dtype:
                raise ValueError(
                    'If no initial_state is provided, dtype must be.')
            state = cell.zero_state(batch_size, dtype)
        else:
            state = initial_state

        zero_output = tf.zeros(tf.stack([batch_size, cell.output_size]),
                               inputs.dtype)
        if has_lengths:
            min_sequence_length = tf.reduce_min(sequence_length)
            max_sequence_length = tf.reduce_max(sequence_length)

        start_time = tf.constant(0, dtype=tf.int32, name='time')

        output_ta = tf.TensorArray(dtype=inputs.dtype,
                                   size=time_steps,
                                   tensor_array_name='dynamic_rnn_output')

        input_ta = tf.TensorArray(dtype=inputs.dtype,
                                  size=time_steps,
                                  tensor_array_name='dynamic_rnn_input')

        # Selects between the tf.layers and tf.nn dropout implementations.
        use_layers_dropout = enforce_dropout is not None

        if ff_keep_prob < 1:
            noise_shape = tf.stack([1, batch_size, const_depth])
            if use_layers_dropout:
                inputs = tf.layers.dropout(inputs,
                                           1 - ff_keep_prob,
                                           noise_shape=noise_shape,
                                           training=enforce_dropout)
            else:
                inputs = tf.nn.dropout(inputs,
                                       ff_keep_prob,
                                       noise_shape=noise_shape)

        if recur_keep_prob < 1:
            ones = tf.ones(tf.stack([batch_size, cell.output_size]))
            if use_layers_dropout:
                output_mask = tf.layers.dropout(ones,
                                                1 - recur_keep_prob,
                                                training=enforce_dropout)
            else:
                output_mask = tf.nn.dropout(ones, recur_keep_prob)
            # Leading state blocks are passed through; only the last
            # output-sized block receives the dropout mask.
            num_passthrough_blocks = cell.state_size // cell.output_size - 1
            state_dropout = tf.concat(
                [ones] * num_passthrough_blocks + [output_mask], 1)
        else:
            state_dropout = 1

        input_ta = input_ta.unstack(inputs)

        def _time_step(step, prev_state, outputs_so_far):
            """Computes one loop iteration: reads, runs the cell, writes."""
            input_t = input_ta.read(step)

            def _run_cell():
                return cell(input_t, prev_state * state_dropout)

            def _skip_step():
                # Every sequence has ended: emit zeros, keep the state.
                return zero_output, prev_state

            def _masked_step():
                cell_output, cell_state = _run_cell()

                def _keep_all():
                    return (cell_output, cell_state)

                def _mask_finished():
                    masked_output = tf.where(step >= sequence_length,
                                             zero_output, cell_output)
                    masked_state = tf.where(step >= sequence_length,
                                            prev_state, cell_state)
                    return (masked_output, masked_state)

                return tf.cond(step < min_sequence_length, _keep_all,
                               _mask_finished)

            if has_lengths:
                output, next_state = tf.cond(step >= max_sequence_length,
                                             _skip_step, _masked_step)
            else:
                output, next_state = _run_cell()

            outputs_so_far = outputs_so_far.write(step, output)
            return (step + 1, next_state, outputs_so_far)

        _, final_state, output_final_ta = tf.while_loop(
            cond=lambda step, *_: step < time_steps,
            body=_time_step,
            loop_vars=(start_time, state, output_ta),
            parallel_iterations=parallel_iterations)

        time_major_outputs = output_final_ta.stack()

        # Back to batch-major: (T,B,D) => (B,T,D).
        outputs = tf.transpose(time_major_outputs, [1, 0, 2])
        return outputs, final_state
Exemple #6
0
    def model_fn(features, labels, mode, params=None):
        """Build model and optimizer.

        Args:
          features: Input tensor. It is split into `num_transforms` pieces
            along the last axis and the pieces are concatenated along axis 0.
          labels: Mapping providing the 'labels' and 'mask' entries.
          mode: A `tf.estimator.ModeKeys` value. TRAIN selects the training
            path; every other value takes the evaluation path.
          params: Mapping that provides 'batch_size' and, optionally,
            'context'.

        Returns:
          A `tf.estimator.tpu.TPUEstimatorSpec` carrying either the train op
          or the evaluation metrics.

        Raises:
          ValueError: On an unknown `FLAGS.train_mode` or `FLAGS.optimizer`,
            or when layer freezing is requested while pretraining.
        """
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # Check training mode.
        train_mode = FLAGS.train_mode
        if train_mode == 'pretrain':
            if FLAGS.fine_tune_after_block > -1:
                raise ValueError(
                    'Does not support layer freezing during pretraining,'
                    'should set fine_tune_after_block<=-1 for safety.')
            num_transforms = 2
        elif train_mode == 'finetune':
            num_transforms = 1
        else:
            raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))
        pretraining = train_mode == 'pretrain'

        # Split channels, and optionally apply extra batched augmentation.
        views = tf.split(features, num_or_size_splits=num_transforms, axis=-1)
        if FLAGS.use_blur and is_training and pretraining:
            views = data_util.batch_random_blur(views, FLAGS.image_size,
                                                FLAGS.image_size)
        # (num_transforms * bsz, h, w, c)
        features = tf.concat(views, 0)

        # Base network forward pass.
        with tf.variable_scope('base_model'):
            # Finetuning just the supervised (linear) head will not update BN
            # stats; pretraining or finetuning anything else will.
            head_only_finetune = (not pretraining
                                  and FLAGS.fine_tune_after_block >= 4)
            model_train_mode = False if head_only_finetune else is_training
            hiddens = model(features, is_training=model_train_mode)

        # Add head and loss.
        if pretraining:
            tpu_context = None
            if 'context' in params:
                tpu_context = params['context']
            hiddens_proj = model_util.projection_head(hiddens, is_training)
            contrast_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
                hiddens_proj,
                hidden_norm=FLAGS.hidden_norm,
                temperature=FLAGS.temperature,
                tpu_context=tpu_context if is_training else None)
            logits_sup = tf.zeros([params['batch_size'], num_classes])
        else:
            # Zero-filled stand-ins keep the contrastive outputs defined.
            contrast_loss = tf.zeros([])
            logits_con = tf.zeros([params['batch_size'], 10])
            labels_con = tf.zeros([params['batch_size'], 10])
            logits_sup = model_util.supervised_head(hiddens, num_classes,
                                                    is_training)
            obj_lib.add_supervised_loss(labels=labels['labels'],
                                        logits=logits_sup,
                                        weights=labels['mask'])

        # Add weight decay to loss, for non-LARS optimizers.
        model_util.add_weight_decay(adjust_per_optimizer=True)
        loss = tf.losses.get_total_loss()

        if pretraining:
            variables_to_train = tf.trainable_variables()
        else:
            # Gather the per-block collections after the frozen blocks.
            collection_prefix = 'trainable_variables_inblock_'
            variables_to_train = []
            for block in range(FLAGS.fine_tune_after_block + 1, 6):
                variables_to_train.extend(
                    tf.get_collection(collection_prefix + str(block)))
            assert variables_to_train, 'variables_to_train shouldn\'t be empty!'

        tf.logging.info(
            '===============Variables to train (begin)===============')
        tf.logging.info(variables_to_train)
        tf.logging.info(
            '================Variables to train (end)================')

        learning_rate = model_util.learning_rate_schedule(
            FLAGS.learning_rate, num_train_examples)

        if not is_training:
            # Evaluation path.
            def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask,
                          **kws):
                """Inner metric function."""
                metrics = {}
                # Every extra keyword tensor is reported as a masked mean.
                for key, value in kws.items():
                    metrics[key] = tf.metrics.mean(value, weights=mask)
                metrics['label_top_1_accuracy'] = tf.metrics.accuracy(
                    tf.argmax(labels_sup, 1),
                    tf.argmax(logits_sup, axis=1),
                    weights=mask)
                metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k(
                    tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask)
                metrics['contrastive_top_1_accuracy'] = tf.metrics.accuracy(
                    tf.argmax(labels_con, 1),
                    tf.argmax(logits_con, axis=1),
                    weights=mask)
                metrics['contrastive_top_5_accuracy'] = tf.metrics.recall_at_k(
                    tf.argmax(labels_con, 1), logits_con, k=5, weights=mask)
                return metrics

            # Scalars are filled out to one value per example in the batch.
            per_example_shape = (params['batch_size'], )
            metric_inputs = {
                'logits_sup': logits_sup,
                'labels_sup': labels['labels'],
                'logits_con': logits_con,
                'labels_con': labels_con,
                'mask': labels['mask'],
                'contrast_loss': tf.fill(per_example_shape, contrast_loss),
                'regularization_loss': tf.fill(
                    per_example_shape, tf.losses.get_regularization_loss()),
            }

            return tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=(metric_fn, metric_inputs),
                scaffold_fn=None)

        # Training path.
        if FLAGS.train_summary_steps > 0:
            # Compute stats for the summary.
            prob_con = tf.nn.softmax(logits_con)
            entropy_con = -tf.reduce_mean(
                tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1))

            def _scalar_summary(tag, value):
                """Writes one scalar summary at the current global step."""
                tf2.summary.scalar(tag,
                                   value,
                                   step=tf.train.get_global_step())

            summary_writer = tf2.summary.create_file_writer(FLAGS.model_dir)
            # TODO(iamtingchen): remove this control_dependencies in the future.
            with tf.control_dependencies([summary_writer.init()]):
                with summary_writer.as_default():
                    should_record = tf.math.equal(
                        tf.math.floormod(tf.train.get_global_step(),
                                         FLAGS.train_summary_steps), 0)
                    with tf2.summary.record_if(should_record):
                        contrast_hits = tf.equal(
                            tf.argmax(labels_con, 1),
                            tf.argmax(logits_con, axis=1))
                        contrast_acc = tf.reduce_mean(
                            tf.cast(contrast_hits, tf.float32))
                        label_hits = tf.equal(
                            tf.argmax(labels['labels'], 1),
                            tf.argmax(logits_sup, axis=1))
                        label_acc = tf.reduce_mean(
                            tf.cast(label_hits, tf.float32))
                        _scalar_summary('train_contrast_loss', contrast_loss)
                        _scalar_summary('train_contrast_acc', contrast_acc)
                        _scalar_summary('train_label_accuracy', label_acc)
                        _scalar_summary('contrast_entropy', entropy_con)
                        _scalar_summary('learning_rate', learning_rate)
                        _scalar_summary('input_mean', tf.reduce_mean(features))
                        _scalar_summary('input_max', tf.reduce_max(features))
                        _scalar_summary('input_min', tf.reduce_min(features))
                        _scalar_summary(
                            'num_labels',
                            tf.reduce_mean(
                                tf.reduce_sum(labels['labels'], -1)))

        optimizer_name = FLAGS.optimizer
        if optimizer_name == 'momentum':
            optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                   FLAGS.momentum,
                                                   use_nesterov=True)
        elif optimizer_name == 'adam':
            optimizer = tf.train.AdamOptimizer(learning_rate)
        elif optimizer_name == 'lars':
            optimizer = LARSOptimizer(
                learning_rate,
                momentum=FLAGS.momentum,
                weight_decay=FLAGS.weight_decay,
                exclude_from_weight_decay=['batch_normalization', 'bias'])
        else:
            raise ValueError('Unknown optimizer {}'.format(FLAGS.optimizer))

        if FLAGS.use_tpu:
            optimizer = tf.tpu.CrossShardOptimizer(optimizer)

        # The train op waits on the collected update ops and, when summaries
        # are enabled, on the v2 summary ops as well.
        control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        if FLAGS.train_summary_steps > 0:
            control_deps.extend(tf.summary.all_v2_summary_ops())
        with tf.control_dependencies(control_deps):
            train_op = optimizer.minimize(
                loss,
                global_step=tf.train.get_or_create_global_step(),
                var_list=variables_to_train)

        def _checkpoint_scaffold_fn():
            """Scaffold function to restore non-logits vars from checkpoint."""
            assignment_map = {
                v.op.name: v.op.name
                for v in tf.global_variables(FLAGS.variable_schema)
            }
            tf.train.init_from_checkpoint(FLAGS.checkpoint, assignment_map)

            if not FLAGS.zero_init_logits_layer:
                return tf.train.Scaffold()

            # Init op that initializes output layer parameters to zeros.
            output_layer_parameters = [
                var for var in tf.trainable_variables()
                if var.name.startswith('head_supervised')
            ]
            tf.logging.info(
                'Initializing output layer parameters %s to zero',
                [x.op.name for x in output_layer_parameters])
            with tf.control_dependencies(
                [tf.global_variables_initializer()]):
                init_op = tf.group([
                    tf.assign(x, tf.zeros_like(x))
                    for x in output_layer_parameters
                ])
            return tf.train.Scaffold(init_op=init_op)

        scaffold_fn = _checkpoint_scaffold_fn if FLAGS.checkpoint else None

        return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                 train_op=train_op,
                                                 loss=loss,
                                                 scaffold_fn=scaffold_fn)
def dot_product_area_attention(q,
                               k,
                               v,
                               bias,
                               dropout_rate=0.0,
                               image_shapes=None,
                               name=None,
                               attention_image_summary=None,
                               save_weights_to=None,
                               dropout_broadcast_dims=None,
                               max_area_width=1,
                               max_area_height=1,
                               memory_height=1,
                               area_key_mode="mean",
                               area_value_mode="sum",
                               top_k_areas=0,
                               area_temperature=1.0,
                               training=True):
    """Dot-product area attention.

  The shape of `k` is indexed here as [batch, heads, length_kv, depth_k].
  Keys and values are flattened over batch and heads, pooled into areas, and
  reshaped back before the usual softmax(q k^T) v attention is applied.

  Args:
    q: Tensor with shape [..., length_q, depth_k].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must
      match with q.
    bias: bias Tensor (see attention_bias()), or None for no bias.
    dropout_rate: a float.
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    attention_image_summary: the callback for making image summary of attention.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be stored there under
      a string key created from the variable scope (including name), and the
      logits under that key plus "/logits".
    dropout_broadcast_dims: an optional list of integers less than rank of q.
      Specifies in which dimensions to broadcast the dropout decisions.
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    memory_height: the height of the memory.
    area_key_mode: the mode for computing area keys, which can be "mean",
      "concat", "sum", "sample_concat", and "sample_sum".
    area_value_mode: the mode for computing area values, which can be
      "mean", "max", or "sum".
    top_k_areas: Use the top key areas for attention.
    area_temperature: the temperature for attention softmax.
    training: indicating if it is in the training mode.
  Returns:
    Tensor with shape [..., length_q, depth_v].
  Raises:
    ValueError: if `area_value_mode` is not one of the supported modes.
  """

    tf.logging.info(
        "dot_product_area_attention: "
        "area_h=%d, area_w=%d, mem_h=%d, "
        "area_key_mode=%s, area_value_mode=%s, "
        "area_temperature=%f", max_area_height, max_area_width, memory_height,
        area_key_mode, area_value_mode, area_temperature)
    with tf.variable_scope(name,
                           default_name="dot_product_area_attention",
                           values=[q, k, v]) as scope:
        mem_shape = common_layers.shape_list(k)
        batch_size = mem_shape[0]
        head_size = mem_shape[1]
        length = mem_shape[2]
        depth = mem_shape[3]
        # Values carry their own feature depth; reusing the key depth would
        # mis-shape `v` whenever depth_v != depth_k.
        depth_v = common_layers.shape_list(v)[-1]
        k_area = compute_area_key(tf.reshape(k, [-1, length, depth]),
                                  max_area_width=max_area_width,
                                  max_area_height=max_area_height,
                                  height=memory_height,
                                  mode=area_key_mode,
                                  training=training)
        if area_value_mode == "mean":
            v_area, _, _, _, _ = compute_area_features(
                tf.reshape(v, [-1, length, depth_v]),
                max_area_width=max_area_width,
                max_area_height=max_area_height,
                height=memory_height)
        elif area_value_mode == "max":
            v_area, _, _ = basic_pool(tf.reshape(v, [-1, length, depth_v]),
                                      max_area_width=max_area_width,
                                      max_area_height=max_area_height,
                                      height=memory_height,
                                      fn=tf.reduce_max)
        elif area_value_mode == "sum":
            _, _, v_area, _, _ = compute_area_features(
                tf.reshape(v, [-1, length, depth_v]),
                max_area_width=max_area_width,
                max_area_height=max_area_height,
                height=memory_height)
        else:
            raise ValueError("Unsupported area value mode=%s" %
                             area_value_mode)
        k = tf.reshape(k_area, [batch_size, head_size, -1, depth])
        v = tf.reshape(v_area, [batch_size, head_size, -1, depth_v])
        logits = tf.matmul(q, k,
                           transpose_b=True)  # [..., length_q, length_kv]
        if bias is not None:
            bias = common_layers.cast_like(bias, logits)
            with tf.name_scope("compute_area_att_bias", values=[bias]):
                bias_shape = common_layers.shape_list(bias)
                mem_length = bias_shape[-1]
                # 1.0 where the incoming bias is below -1, else 0.0.
                bias_values = tf.reshape(tf.to_float(tf.less(bias, -1)),
                                         [-1, mem_length, 1])
                _, _, padding_sum, _, _ = compute_area_features(
                    bias_values,
                    max_area_width=max_area_width,
                    max_area_height=max_area_height,
                    height=memory_height)
                # An area with a non-zero padding sum gets a -inf bias.
                bias = tf.where(tf.cast(tf.to_int32(padding_sum), tf.bool),
                                tf.fill(tf.shape(padding_sum), -np.inf),
                                tf.zeros_like(padding_sum, dtype=tf.float32))
                bias = tf.reshape(
                    bias, [bias_shape[0], bias_shape[1], bias_shape[2], -1])
            logits += bias
        logits = logits / area_temperature
        weights = tf.nn.softmax(logits, name="attention_weights")
        if top_k_areas > 0:
            tf.logging.info("area_attention top_k_areas=%d", top_k_areas)
            top_k = tf.minimum(
                common_layers.shape_list(weights)[-1], top_k_areas)
            top_weights, _ = tf.nn.top_k(weights, k=top_k)
            # Zero every weight below the k-th largest, then renormalize.
            min_values = tf.reduce_min(top_weights, -1, keepdims=True)
            weights = tf.where(tf.greater_equal(weights, min_values), weights,
                               tf.zeros_like(weights))
            weights = tf.div(weights, tf.reduce_sum(weights, -1,
                                                    keepdims=True))
        if save_weights_to is not None:
            save_weights_to[scope.name] = weights
            save_weights_to[scope.name + "/logits"] = logits
        # Drop out attention links for each head.
        weights = common_layers.dropout_with_broadcast_dims(
            weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
        if common_layers.should_generate_summaries(
        ) and attention_image_summary:
            attention_image_summary(weights, image_shapes)
        return tf.matmul(weights, v)
def discretize_in_bins(x):
  """Map every element of `x` to one of two equal-width bins.

  The histogram range is taken from `x` itself: it spans the minimum and the
  maximum value found in the tensor.

  Args:
    x: numeric tensor to discretize.

  Returns:
    The bin indices produced by `tf.histogram_fixed_width_bins` with
    `nbins=2`.
  """
  lowest = tf.reduce_min(x)
  highest = tf.reduce_max(x)
  return tf.histogram_fixed_width_bins(x, [lowest, highest], nbins=2)
Exemple #9
0
def make_density_image_summary(num_pts, bounds, model):
    """Register image summaries of model densities evaluated on a 2-D grid.

    A square grid is built from `tf.range(bounds[0], bounds[1])` with a step of
    `(bounds[1] - bounds[0]) / num_pts` along both axes, and flattened to a
    `[num_pts**2, 2]` batch of points. Which quantities are plotted depends on
    `FLAGS.model`:

      * "aem": `model.log_p` yields both `log_p_hat` and `log_q`.
      * "eim": `model.log_p` yields `log_p_hat`; q comes from `model.arnn`.
      * "energy_resnet_ssm", "aem_ssm", "aem_arsm", "gaussian_ssm": only
        `model.log_energy` is plotted, under the "p_density" tag.

    Every density is min-max rescaled, passed through the 'viridis' colormap
    and written to the "infrequent_summaries" collection. Any other value of
    `FLAGS.model` adds no summaries.

    Args:
      num_pts: number of grid steps per axis.
      bounds: indexable pair; `bounds[0]` is the lower and `bounds[1]` the
        upper grid limit.
      model: object exposing the methods listed above for the selected
        `FLAGS.model`.
    """

    def _as_grid(flat_values):
        # Fold the flat per-point values back into the square grid layout.
        return tf.reshape(flat_values, [num_pts, num_pts])

    def _rescale(density):
        # Min-max normalisation of the density grid.
        return (density - tf.reduce_min(density)) / (
            tf.reduce_max(density) - tf.reduce_min(density))

    def _colorize(density):
        # The colormap runs in Python through py_func.
        return tf.py_func(cm.get_cmap('viridis'), [density], [tf.float64])

    def _emit(tag, image):
        tf.summary.image(tag,
                         image,
                         max_outputs=1,
                         collections=["infrequent_summaries"])

    step = (bounds[1] - bounds[0]) / float(num_pts)
    axis_pts = tf.range(bounds[0], bounds[1], delta=step)
    grid_x, grid_y = tf.meshgrid(axis_pts, axis_pts)
    grid = tf.reshape(tf.stack([grid_x, grid_y], axis=-1), [num_pts**2, 2])

    if FLAGS.model == "aem":
        log_p_hat, log_q = model.log_p(
            grid,
            num_importance_samples=FLAGS.num_importance_samples,
            summarize=False)
        p_grid = _rescale(_as_grid(tf.exp(log_p_hat)))
        q_grid = _rescale(_as_grid(tf.exp(log_q)))

        p_image = _colorize(p_grid)
        q_image = _colorize(q_grid)
        _emit("p_density", p_image)
        _emit("q_density", q_image)
    elif FLAGS.model == "eim":
        log_p_hat = model.log_p(
            grid,
            num_importance_samples=FLAGS.num_importance_samples,
            summarize=False)
        p_grid = _rescale(_as_grid(tf.exp(log_p_hat)))
        _emit("p_density", _colorize(p_grid))

        _, q_dist = model.arnn(grid)
        q_values = tf.reduce_sum(q_dist.prob(grid), axis=-1)
        q_grid = _rescale(_as_grid(q_values))
        _emit("q_density", _colorize(q_grid))
    elif FLAGS.model in ("energy_resnet_ssm", "aem_ssm", "aem_arsm",
                         "gaussian_ssm"):
        log_energy = model.log_energy(grid, summarize=False)
        p_grid = _rescale(_as_grid(tf.exp(log_energy)))
        _emit("p_density", _colorize(p_grid))
Exemple #10
0
def barrier(S0, K, B, tau, r, q, v, M, N):
    """Estimate the price of a barrier option from simulated paths.

    Paths are generated by `paths(S0, tau, r, q, v, M, N)`. A path contributes
    `max(S[-1] - K, 0)` only when its minimum along axis 0 stays strictly above
    the barrier `B`; otherwise it contributes zero. The mean payoff is
    multiplied by `exp(-r * tau)`.

    Args:
      S0, tau, r, q, v, M, N: forwarded unchanged to `paths`.
      K: value subtracted from the last row of the paths (strike).
      B: barrier level compared against each path's minimum.

    Returns:
      Scalar tensor with the discounted mean payoff.
    """
    simulated = paths(S0, tau, r, q, v, M, N)
    path_minimum = tf.reduce_min(simulated, 0)
    # 1.0 for paths that never touched the barrier, 0.0 otherwise.
    alive = tf.to_float(tf.greater(path_minimum, B))
    terminal_payoff = tf.maximum(simulated[-1, :] - K, 0)
    payoffs = alive * terminal_payoff
    return tf.exp(-r * tau) * tf.reduce_mean(payoffs)
def _scan_step_fn(state, example, packed_length, queue_size, spacing,
                  num_sequences, token_dtype):  # pylint: disable=g-doc-args
    """Transform function used by tf.data.experimental.scan to process an example.

    This is written as a stateless function rather than a class method because
    we trace it with AutoGraph (in order to simplify the conditional), and this
    way we don't have to worry about handling re-tracing semantics.

    Args:
      See the SequenceDatasetPacker class for the authoritative description.
      As used in this function:
      state: tuple `(availability, contents, top_index)`. `availability` and
        `contents` are used as TensorArrays (read/write/stack and
        gather/scatter respectively); `top_index` is the queue entry that is
        emitted next.
      example: iterable of tensors, one per sequence; their shapes are
        concatenated to form the per-sequence lengths.
      packed_length: number of rows of `contents` owned by each queue entry.
      queue_size: number of entries in the queue rotation.
      spacing: extra room subtracted from the availability after placing an
        example.
      num_sequences: number of sequences in `example`.
      token_dtype: dtype of the dummy output and of the zeroed contents.

    Returns:
      The updated queue state, and either a packed example or a dummy sequence
      which will be filtered out downstream. The second element is the tuple
      `(tf.logical_not(any_can_fit), output_contents)`.
    """

    # Unpack the scan state; the names are rebound below as the TensorArrays
    # are updated.
    availability, contents, top_index = state

    lengths = tf.concat([tf.shape(i) for i in example], axis=0)
    start_availability = availability.stack()
    # An entry can take the example only if every sequence fits.
    can_fit = tf.reduce_all(tf.greater_equal(start_availability, lengths),
                            axis=1)
    any_can_fit = tf.reduce_any(can_fit, axis=0)

    # AutoGraph will convert this block to a tf.cond
    if any_can_fit:
        # This indicates where in the FFD queue rotation a given index sits
        shifted_range = (tf.range(queue_size, dtype=INDEX_DTYPE) -
                         top_index) % queue_size

        # Mark any indices which cannot accommodate the current example.
        exclusion_mask = tf.cast(tf.logical_not(can_fit),
                                 INDEX_DTYPE) * queue_size

        # Index in [0, queue_size) in which to place the sample. Note, this index
        # is the position in the actual TensorArray, not the index of the FFD queue.
        queue_index = (tf.reduce_min(shifted_range + exclusion_mask) +
                       top_index) % queue_size

        # NOTE(taylorrobie): We emit a non-empty Tensor for downstream checks.
        output_contents = -tf.ones((1, num_sequences), dtype=token_dtype)

    else:
        # Nothing fits: emit the entry at the top of the queue and reuse its
        # slot for the current example.
        index_range = top_index * packed_length + tf.range(packed_length)
        output_contents = contents.gather(index_range)

        # Reset the queue state.
        availability = availability.write(
            top_index,
            packed_length * tf.ones((num_sequences, ), dtype=INDEX_DTYPE))
        empty_contents = tf.zeros((packed_length, num_sequences * 2),
                                  dtype=token_dtype)
        contents = contents.scatter(index_range, empty_contents)

        queue_index = top_index
        top_index = (top_index + 1) % queue_size

    # Charge the chosen entry for the example plus the requested spacing.
    pre_assign_availability = availability.read(queue_index)
    space_left = pre_assign_availability - lengths - spacing
    availability = availability.write(queue_index, space_left)

    # ============================================================================
    # == Update contents =========================================================
    # ============================================================================
    # Consider the following case for a seq-to-seq packing:
    #   (padding is represented as underscores)
    #
    #   Queue starting state:
    #     [1, 3, 2, 4, 6, 1, _, _, _, _, _, ...]
    #     [5, 9, _, _, _, _, _, _, _, _, _, ...]
    #
    #   Examples:
    #     [4, 2, 4], [3]
    #
    #   Desired new queue state:
    #     [1, 3, 2, 4, 6, 1, _, _, 4, 2, 4, _, _, ...]
    #     [5, 9, _, _, 3, _, _, _, _, _, _, _, _, ...]
    #
    # This could be accomplished by creating a TensorArray for each of the two
    # sequences, and scattering into the respective arrays. However TensorArray
    # writes are extremely expensive relative to other operations. So instead we
    # store the contents in a single TensorArray of shape (packed_length, 2), and
    # we pad and concatenate the examples such that they can be added in a single
    # assign:
    #
    #              [_, _, _, _, 4, 2, 4]
    #              [3, _, _, _, _, _, _]
    #                        +
    #  [1, 3, 2, 4, 6, 1, _, _, _, _, _, ...]
    #  [5, 9, _, _, _, _, _, _, _, _, _, ...]
    #
    # And in practice, the extra work of padding is negligible compared to
    # the gain from vectorizing the TensorArray assign. We also store a bit mask
    # denoting where sequences start which is used to compute segment and
    # position metadata:
    #
    #              [_, _, _, _, 1, _, _]
    #              [1, _, _, _, _, _, _]
    #                        +
    #  [1, _, _, _, _, _, _, _, _, _, _, ...]
    #  [1, _, _, _, _, _, _, _, _, _, _, ...]
    #
    # Both the contents and the mask are concatenated in the same TensorArray
    # for performance.

    # Per-sequence write window inside the chosen entry, and the smallest span
    # [leftmost, rightmost) that covers all of the windows.
    start_index = packed_length - pre_assign_availability
    end_index = start_index + lengths
    leftmost = tf.reduce_min(start_index, axis=0)
    rightmost = tf.reduce_max(end_index, axis=0)
    delta = rightmost - leftmost
    pad_indices = [
        tf.stack((start_index[i] - leftmost, rightmost - end_index[i]))
        for i in range(num_sequences)
    ]

    # Pad every sequence to the common span so they can be stacked.
    padded_examples = [
        tf.pad(ex, padding[tf.newaxis, :])
        for ex, padding in zip(example, pad_indices)
    ]
    padded_examples = tf.transpose(tf.stack(padded_examples))
    mask_update = tf.one_hot(start_index - leftmost,
                             delta,
                             dtype=contents.dtype,
                             axis=0)

    content_update = tf.concat([padded_examples, mask_update], axis=1)

    index_range = (
        queue_index * packed_length +  # Offset into the right section.
        tf.range(delta, dtype=INDEX_DTYPE) + leftmost)
    # Add the update onto the rows that are already stored.
    contents = contents.scatter(index_range,
                                contents.gather(index_range) + content_update)

    state = (availability, contents, top_index)
    return state, (tf.logical_not(any_can_fit), output_contents)
def spherical_cubevol_resample(vol, env2ref, cube_center, side_length, n_phi,
                               n_theta, n_r):
    """Resample a cube volume onto spherical coordinates around a target point.

    Args:
      vol: [B,H,W,D,C], input volume; the last channel is treated as alpha
      env2ref: [B,4,4], relative pose transformation (transform env to ref)
      cube_center: [B,3], [x,y,z] coordinates for center of cube volume
      side_length: side length of cube
      n_phi: number of samples along vertical spherical coordinate dim
      n_theta: number of samples along horizontal spherical coordinate dim
      n_r: number of samples along radius spherical coordinate dim

    Returns:
      resampled: [B, n_phi, n_theta, n_r, C]
      r_vals: the n_r sampled radii, running from the largest to the smallest
        distance between a cube grid point and the translation part of
        `env2ref`
    """
    num_batches = tf.shape(vol)[0]
    # The cube resolution is read from the H dimension only.
    resolution = tf.to_float(tf.shape(vol)[1])
    grid_size = tf.to_int32(resolution)
    half_side = side_length / 2.0

    # Per-batch cube center, broadcastable against [B, ., ., .] grids.
    center_x = cube_center[:, 0, tf.newaxis, tf.newaxis, tf.newaxis]
    center_y = cube_center[:, 1, tf.newaxis, tf.newaxis, tf.newaxis]
    center_z = cube_center[:, 2, tf.newaxis, tf.newaxis, tf.newaxis]

    # Sample axes of the spherical grid.
    b_vals = tf.to_float(tf.range(num_batches))
    phi_vals = tf.linspace(0.0, np.pi, n_phi)
    theta_vals = tf.linspace(1.5 * np.pi, -0.5 * np.pi, n_theta)

    # Radii: span the distances from the env origin (expressed through the
    # translation column of env2ref) to every grid point of the cube.
    ascending = tf.linspace(-half_side, half_side, grid_size)
    descending = tf.linspace(half_side, -half_side, grid_size)
    y_c, x_c, z_c = tf.meshgrid(ascending,
                                ascending,
                                descending,
                                indexing='ij')
    cube_coords = tf.stack(
        [x_c + center_x, y_c + center_y, z_c + center_z], axis=4)
    translation = env2ref[:, :3, 3][:, tf.newaxis, tf.newaxis, tf.newaxis, :]
    distances = tf.norm(cube_coords - translation, axis=4)
    min_r = tf.reduce_min(distances, axis=[0, 1, 2, 3])
    max_r = tf.reduce_max(distances, axis=[0, 1, 2, 3])
    r_vals = tf.linspace(max_r, min_r, n_r)

    # Full spherical grid, still in the env frame.
    b, phi, theta, r = tf.meshgrid(b_vals,
                                   phi_vals,
                                   theta_vals,
                                   r_vals,
                                   indexing='ij')

    # Spherical -> cartesian (env frame, z points forwards).
    x_env = r * tf.cos(theta) * tf.sin(phi)
    z_env = r * tf.sin(theta) * tf.sin(phi)
    y_env = r * tf.cos(phi)

    # Move the homogeneous points into the ref frame.
    homogeneous = tf.stack([x_env, y_env, z_env,
                            tf.ones_like(x_env)], axis=-1)[..., tf.newaxis]
    in_ref = tfmm(env2ref, homogeneous)
    x_ref = in_ref[..., 0, 0]
    y_ref = in_ref[..., 1, 0]
    z_ref = in_ref[..., 2, 0]

    # Convert ref-frame coordinates into (fractional) volume indices; the y and
    # z axes are flipped relative to x.
    index_scale = (resolution - 1) / side_length
    x_inds = (x_ref - center_x + half_side) * index_scale
    y_inds = -(y_ref - center_y - half_side) * index_scale
    z_inds = -(z_ref - center_z - half_side) * index_scale
    sample_inds = tf.stack([b, x_inds, y_inds, z_inds], axis=-1)

    # Trilinear interpolation on alpha-premultiplied channels, followed by
    # un-premultiplication of the result.
    alpha = tf.clip_by_value(vol[..., -1:], 0.0, 1.0)
    premultiplied = tf.concat([vol[..., :-1] * alpha, alpha], axis=-1)

    sampled = sampling.trilerp_gather(premultiplied, sample_inds)

    sampled_alpha = sampled[..., -1:]
    sampled_channels = sampled[..., :-1] / (sampled_alpha + 1e-8)
    resampled = tf.concat([sampled_channels, sampled_alpha], axis=-1)

    return resampled, r_vals
Exemple #13
0
    def _quantizable_concat(self,
                            inputs,
                            axis,
                            is_training,
                            is_quantized=True,
                            default_min=0,
                            default_max=6,
                            ema_decay=0.999,
                            scope='quantized_concat'):
        """Concatenate tensors, optionally fake-quantizing them to a shared range.

        All inputs are quantized with the same min/max pair so that they share
        one range (adapted from
        experimental/gazelle/synthetic/model/tpu/utils.py).

        Args:
          inputs: list of tensors to concatenate.
          axis: dimension along which to concatenate.
          is_training: true if the graph is a training graph; the range
            variables are then updated with a moving average of the observed
            range.
          is_quantized: flag to enable/disable quantization.
          default_min: default min value for fake quant op.
          default_max: default max value for fake quant op.
          ema_decay: the moving average decay for the quantization variables.
          scope: Optional scope for variable_scope.

        Returns:
          Tensor resulting from concatenation of input tensors
        """
        if not is_quantized:
            return tf.concat(inputs, axis=axis)

        with tf.variable_scope(scope):
            min_var = self._quant_var('min', default_min)
            max_var = self._quant_var('max', default_max)
            if is_training:
                merged = tf.concat(inputs, axis=axis)
                tf.logging.info('concat_tensors: {}'.format(merged))
                # TFLite requires that 0.0 is always in the [min; max] range.
                observed_min = tf.minimum(tf.reduce_min(merged),
                                          0.0,
                                          name='SafeQuantRangeMin')
                observed_max = tf.maximum(tf.reduce_max(merged),
                                          0.0,
                                          name='SafeQuantRangeMax')
                # Track moving averages of the observed range and quantize
                # with the updated values.
                quant_min = moving_averages.assign_moving_average(
                    min_var, observed_min, ema_decay, name='AssignMinEma')
                quant_max = moving_averages.assign_moving_average(
                    max_var, observed_max, ema_decay, name='AssignMaxEma')
            else:
                # Eval graph: just use the values stored in the variables.
                quant_min, quant_max = min_var, max_var

            quantized = [
                tf.fake_quant_with_min_max_vars(tensor, quant_min, quant_max)
                for tensor in inputs
            ]
            return tf.concat(quantized, axis=axis)
Exemple #14
0
    def rnn_decoder(self,
                    encode_embed,
                    attention_states,
                    initial_state,
                    cell,
                    num_heads=1,
                    loop_function=None,
                    dtype=dtypes.float32,
                    scope=None,
                    initial_state_attention=False):
        """RNN decoder for the sequence-to-sequence model.

        Scores every position of `attention_states` against `initial_state`
        using the attention strategy selected by `self.hparams.att_strategy`
        ('multi', 'multi_add', 'NTN', 'elu', anything else falls back to
        additive attention). The per-head scores are scaled by a learned head
        weight, averaged over heads, and shifted so that the minimum of every
        row is zero.

        Args:
          encode_embed: sequence of tensors; only the first dimension of
            `encode_embed[0]` is read, as the batch size.
          attention_states: tensor with static shape
            [batch, attn_length, attn_size].
          initial_state: tensor with static shape [batch, state_size]; used
            as the attention query.
          cell: not referenced in this method.
          num_heads: number of attention heads.
          loop_function: not referenced in this method.
          dtype: dtype of the placeholder attention tensors created before
            the attention call.
          scope: optional variable scope name; defaults to "rnn_decoder".
          initial_state_attention: not referenced in this method.

        Returns:
          A tuple `(outputs, state)`: `outputs` is a list holding a single
          score tensor, and `state` is `initial_state` unchanged.
        """
        with tf.variable_scope(scope or "rnn_decoder"):
            batch_size = tf.shape(encode_embed[0])[0]  # Needed for reshaping.
            # cprint('batch_size: {}'.format(batch_size), 'green')  # Tensor("ranking_model/ranking_model/embedding_rnn_decoder/rnn_decoder/strided_slice_1:0", shape=(), dtype=int32)
            # cprint('batch_size.get_shape(): {}'.format(batch_size.get_shape()), 'red')  # ()
            # number of output vector in sequence
            attn_length = attention_states.get_shape()[1].value
            # the dimension size of each output vector
            attn_size = attention_states.get_shape()[2].value
            # the dimension size of state vector
            state_size = initial_state.get_shape()[1].value
            print(batch_size, attn_length, attn_size, state_size,
                  "batch_size, attn_length, attn_size, state_size")
            # To calculate W1 * h_t we use a 1-by-1 convolution, need to
            # reshape before.
            print(attention_states.get_shape(),
                  "attention_states.get_shape()")  # (?, 9, 186)
            hidden = tf.reshape(attention_states,
                                [-1, attn_length, 1, attn_size])
            # Per-head parameters, indexed by head number.
            # NOTE(review): hidden_features2 and u are filled below but are
            # not read by any of the attention strategies in this method.
            hidden_features = []
            hidden_features2 = []
            v = []
            u = []
            linear_w = []
            linear_b = []
            abstract_w = []
            abstract_b = []
            # Output widths of the layers used by the 'elu' strategy; the last
            # layer has a single unit.
            abstract_layers = [
                int((attn_size + state_size) / (2 + 2 * i)) for i in xrange(2)
            ] + [1]
            # Size of query vectors for attention.
            attention_vec_size = attn_size
            head_weights = []
            for a in xrange(num_heads):
                k = self.get_variable("AttnW_%d" % a,
                                      [1, 1, attn_size, attention_vec_size])
                hidden_features.append(
                    nn_ops.conv2d(hidden, k, [1, 1, 1, 1],
                                  "SAME"))  # [B,T,1,attn_vec_size]
                k2 = self.get_variable("AttnW2_%d" % a,
                                       [1, 1, attn_size, attention_vec_size])
                hidden_features2.append(
                    nn_ops.conv2d(hidden, k2, [1, 1, 1, 1], "SAME"))
                v.append(
                    self.get_variable("AttnV_%d" % a, [attention_vec_size]))
                u.append(
                    self.get_variable("AttnU_%d" % a, [attention_vec_size]))
                head_weights.append(
                    self.get_variable("head_weight_%d" % a, [1]))
                current_layer_size = attn_size + state_size
                linear_w.append(
                    self.get_variable("linearW_%d" % a,
                                      [1, 1, current_layer_size, 1]))
                linear_b.append(self.get_variable("linearB_%d" % a, [1]))
                abstract_w.append([])
                abstract_b.append([])
                # Each layer consumes the previous layer's output width.
                for i in xrange(len(abstract_layers)):
                    layer_size = abstract_layers[i]
                    abstract_w[a].append(
                        self.get_variable(
                            "Att_%d_layerW_%d" % (a, i),
                            [1, 1, current_layer_size, layer_size]))
                    abstract_b[a].append(
                        self.get_variable("Att_%d_layerB_%d" % (a, i),
                                          [layer_size]))
                    current_layer_size = layer_size

            def attention(query):
                """Put attention masks on hidden using hidden_features and query.

                Returns a pair `(aw, ds)` with one entry per head: the scaled
                attention scores and the score-weighted sums over `hidden`.
                """
                ds = []  # Results of attention reads will be stored here.
                aw = []  # Attention weights will be stored here
                # Repeat the query along the sequence axis so it can be
                # concatenated with every hidden vector.
                tiled_query = tf.tile(
                    tf.reshape(query, [-1, 1, 1, state_size]),
                    [1, attn_length, 1, 1])
                print(hidden.get_shape(),
                      "hidden.get_shape()")  # (?, 9, 1, 186)
                print(tiled_query.get_shape(),
                      "tiled_query.get_shape()")  # (?, 9, 1, 186)
                concat_input = tf.concat(axis=3, values=[hidden, tiled_query])
                #concat_input = tf.concat(3, [hidden, hidden])
                for a in xrange(num_heads):
                    with tf.variable_scope("Attention_%d" % a):
                        s = None
                        if self.hparams.att_strategy == 'multi':
                            print('Attention: multiply')
                            y = linear(
                                query, attention_vec_size, True
                            )  # The third argument is a boolean: whether to add a bias term or not.
                            y = tf.reshape(y, [-1, 1, 1, attention_vec_size])
                            # s = math_ops.reduce_sum(
                            # u[a] * math_ops.tanh(y * hidden_features[a]), [2,
                            # 3])
                            s = math_ops.reduce_sum(hidden * math_ops.tanh(y),
                                                    [2, 3])
                            # hidden_features[a] * math_ops.tanh(y), [2, 3])

                        elif self.hparams.att_strategy == 'multi_add':
                            print('Attention: multiply_add')
                            y = linear(query,
                                       attention_vec_size,
                                       True,
                                       scope='y')
                            y2 = linear(query,
                                        attention_vec_size,
                                        True,
                                        scope='y2')
                            y = tf.reshape(y, [-1, 1, 1, attention_vec_size])
                            y2 = tf.reshape(y2, [-1, 1, 1, attention_vec_size])
                            # s = math_ops.reduce_sum(
                            # u[a] * math_ops.tanh(y * hidden_features[a]), [2,
                            # 3])
                            # Sum of the multiplicative and additive scores.
                            s = math_ops.reduce_sum(hidden * math_ops.tanh(y2),
                                                    [2, 3])
                            s = s + math_ops.reduce_sum(
                                v[a] * math_ops.tanh(hidden_features[a] + y),
                                [2, 3])

                        elif self.hparams.att_strategy == 'NTN':
                            print('Attention: NTN')
                            y = linear(query, attn_size, False)
                            y = tf.tile(tf.reshape(y, [-1, 1, 1, attn_size]),
                                        [1, attn_length, 1, 1])
                            s = math_ops.reduce_sum(hidden * y,
                                                    [2, 3])  # bilinear
                            s = s + math_ops.reduce_sum(
                                nn_ops.conv2d(concat_input, linear_w[a],
                                              [1, 1, 1, 1], "SAME"),
                                [2, 3])  # linear
                            s = s + linear_b[a]  # bias
                            # print(s.get_shape())
                            # s = tf.tanh(s) #non linear

                        elif self.hparams.att_strategy == 'elu':
                            print('Attention: elu')

                            cur_input = concat_input
                            # for i in xrange(len(abstract_layers)):
                            #    cur_input = tf.contrib.layers.fully_connected(cur_input, abstract_layers[i], activation_fn=tf.nn.elu)
                            # Stack of 1x1 convolutions with ELU activations.
                            for i in xrange(len(abstract_layers)):
                                cur_input = nn_ops.conv2d(
                                    cur_input, abstract_w[a][i], [1, 1, 1, 1],
                                    "SAME")
                                cur_input = cur_input + abstract_b[a][i]
                                cur_input = tf.nn.elu(cur_input)
                            s = math_ops.reduce_sum(cur_input, [2, 3])

                        else:
                            print('Attention: add')
                            y = linear(query, attention_vec_size, True)
                            y = tf.reshape(y, [-1, 1, 1, attention_vec_size])
                            s = math_ops.reduce_sum(
                                v[a] * math_ops.tanh(hidden_features[a] + y),
                                [2, 3])

                        # Raw scores scaled by the head weight; no softmax is
                        # applied here.
                        att = s * head_weights[a]  # nn_ops.softmax(s)
                        aw.append(att)
                        # Now calculate the attention-weighted vector d.
                        d = math_ops.reduce_sum(
                            tf.reshape(att, [-1, attn_length, 1, 1]) * hidden,
                            [1, 2])
                        ds.append(tf.reshape(d, [-1, attn_size]))
                return aw, ds

            state = initial_state
            outputs = []
            prev = None
            batch_attn_size = tf.stack([batch_size, attn_size])
            batch_attw_size = tf.stack([batch_size, attn_length])
            # NOTE(review): these initial attns/attw values are replaced by the
            # attention(initial_state) call below before being read.
            attns = [
                tf.zeros(batch_attn_size, dtype=dtype)
                for _ in xrange(num_heads)
            ]
            attw = [
                1.0 / attn_length * tf.ones(batch_attw_size, dtype=dtype)
                for _ in xrange(num_heads)
            ]
            for a in attns:  # Ensure the second shape of attention vectors is set.
                a.set_shape([None, attn_size])

            # Directly use previous state
            attw, attns = attention(initial_state)
            # Average the scores over heads, then shift every row so that its
            # minimum is zero.
            aw = math_ops.reduce_sum(attw, 0)
            output = tf.scalar_mul(1.0 / float(num_heads), aw)
            output = output - tf.reduce_min(output, 1, keep_dims=True)
            outputs.append(output)

        return outputs, state
    def step_fn(self, params, model):
        """A single step for supervised learning.

        Builds one training step that updates both the main model (variables
        under MODEL_SCOPE) and a scorer (variables under SCORE_SCOPE) that
        assigns a weight to every training example.

        Args:
          params: hyper-parameter object; the attributes read here include
            train_dtypes, train_shapes, num_classes, num_replicas,
            label_smoothing, weight_decay, grad_bound, scorer_lr and
            scorer_wait_steps.
          model: callable building the network; it is invoked with
            `training=` and, for the scorer, `return_scores=True`.

        Returns:
          The outfeed enqueue op (a no-op when `common_utils.should_log`
          is false). `self.step_info` is set to the dtype/shape of every
          logged tensor.
        """
        (train_images, train_labels, valid_images,
         valid_labels) = tf.raw_ops.InfeedDequeueTuple(
             dtypes=params.train_dtypes, shapes=params.train_shapes)

        # Integer labels are expanded to one-hot floats.
        if train_labels.dtype == tf.int32:
            train_labels = tf.one_hot(train_labels,
                                      depth=params.num_classes,
                                      dtype=tf.float32)
        if valid_labels.dtype == tf.int32:
            valid_labels = tf.one_hot(valid_labels,
                                      depth=params.num_classes,
                                      dtype=tf.float32)
        global_step = tf.train.get_or_create_global_step()

        num_replicas = tf.cast(params.num_replicas, tf.float32)

        with tf.variable_scope(MODEL_SCOPE):
            train_logits = model(train_images, training=True)

        with tf.variable_scope(SCORE_SCOPE):
            score_logits = model(train_images,
                                 training=False,
                                 return_scores=True)
            # Subtract a gradient-free, cross-replica offset before the exp,
            # then normalise by the cross-replica sum so that the
            # probabilities sum to one over all replicas.
            score_m = tf.tpu.cross_replica_sum(tf.reduce_sum(score_logits))
            score_m = tf.stop_gradient(score_m) / float(params.num_replicas)
            score_e = tf.exp(score_logits - score_m)
            score_z = tf.tpu.cross_replica_sum(tf.reduce_sum(score_e))
            score_probs = score_e / score_z

        # train the main model
        # Per-example losses are weighted by the scorer's probabilities; the
        # stop_gradient keeps this loss from updating the scorer.
        cross_entropy = tf.losses.softmax_cross_entropy(
            onehot_labels=train_labels,
            logits=train_logits,
            label_smoothing=params.label_smoothing,
            reduction=tf.losses.Reduction.NONE)
        cross_entropy = tf.reduce_sum(cross_entropy *
                                      tf.stop_gradient(score_probs))

        l2_reg_rate = tf.cast(params.weight_decay / params.num_replicas,
                              tf.float32)
        weight_dec = common_utils.get_l2_loss(excluded_keywords=[SCORE_SCOPE])
        total_loss = cross_entropy + weight_dec * l2_reg_rate

        model_variables = [
            v for v in tf.trainable_variables() if MODEL_SCOPE in v.name
        ]
        train_gradients = tf.gradients(total_loss, model_variables)
        train_gradients = [
            tf.tpu.cross_replica_sum(g) for g in train_gradients
        ]
        train_gradients, grad_norm = tf.clip_by_global_norm(
            train_gradients, params.grad_bound)

        learning_rate, optimizer = common_utils.get_optimizer(params)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        # Skip the update when the global gradient norm is not finite.
        train_op = tf.cond(
            tf.math.is_finite(grad_norm), lambda: optimizer.
            apply_gradients(zip(train_gradients, model_variables),
                            global_step=global_step), tf.no_op)
        with tf.control_dependencies(update_ops + [train_op]):
            ema_train_op = common_utils.setup_ema(
                params, f'{MODEL_SCOPE}/{model.name}')

        # Everything below is ordered after the model (and EMA) update.
        with tf.control_dependencies([ema_train_op]):
            with tf.variable_scope(MODEL_SCOPE, reuse=True):
                valid_logits = model(valid_images, training=False)
                valid_cross_entropy = tf.losses.softmax_cross_entropy(
                    onehot_labels=valid_labels,
                    logits=valid_logits,
                    reduction=tf.losses.Reduction.MEAN) / float(
                        params.num_replicas)
                valid_gradients = tf.gradients(valid_cross_entropy,
                                               model_variables)
                valid_gradients = [
                    tf.tpu.cross_replica_sum(g) for g in valid_gradients
                ]

            # Inner product between training and validation gradients, used
            # as the signal for the scorer.
            dot_product = tf.add_n([
                tf.reduce_sum(g_t * g_v)
                for g_t, g_v in zip(train_gradients, valid_gradients)
            ])
            dot_product = tf.stop_gradient(dot_product)
            # Moving average (rate 0.01) of the dot product, subtracted from
            # it as a baseline after the average has been updated.
            dot_product_avg = tf.get_variable(name='dot_product_avg',
                                              shape=[],
                                              trainable=False)
            dot_product_update = tf.assign_sub(
                dot_product_avg, 0.01 * (dot_product_avg - dot_product))
            with tf.control_dependencies([dot_product_update]):
                dot_product = tf.identity(dot_product - dot_product_avg)

        # trains the scorer.
        score_entropy = tf.reduce_sum(-score_probs * tf.math.log(score_probs))
        score_entropy = tf.tpu.cross_replica_sum(score_entropy) / float(
            valid_images.shape[0].value)
        score_variables = [
            v for v in tf.trainable_variables() if SCORE_SCOPE in v.name
        ]
        # The centred dot product is a constant here (stop_gradient above), so
        # it only scales the gradient of the entropy term.
        score_gradients = tf.gradients(dot_product * score_entropy,
                                       score_variables)
        score_gradients = [
            tf.tpu.cross_replica_sum(g) for g in score_gradients
        ]
        score_optimizer = tf.train.GradientDescentOptimizer(
            learning_rate=params.scorer_lr, use_locking=True)
        # The scorer is left untouched for the first scorer_wait_steps steps.
        score_train_op = tf.cond(
            global_step < params.scorer_wait_steps, tf.no_op,
            lambda: score_optimizer.apply_gradients(
                zip(score_gradients, score_variables)))

        with tf.control_dependencies([score_train_op]):
            logs = collections.OrderedDict()
            logs['global_step'] = tf.cast(global_step, tf.float32)

            logs['model/total'] = total_loss
            logs['model/weight_decay'] = weight_dec / num_replicas
            logs['model/cross_entropy'] = cross_entropy
            logs['model/lr'] = tf.identity(learning_rate) / num_replicas
            logs['model/grad_norm'] = grad_norm / num_replicas

            logs['score/dot_product'] = dot_product / num_replicas
            logs['score/dot_product_avg'] = dot_product_avg / num_replicas
            logs['score/entropy'] = score_entropy
            logs['score/p_min'] = tf.reduce_min(score_probs) / num_replicas
            logs['score/p_max'] = tf.reduce_max(score_probs) / num_replicas

            # Every logged scalar is sent as a float32 tensor of shape [1].
            tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
            self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
            outfeed_enqueue_op = tf.cond(
                common_utils.should_log(params),
                lambda: tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors),
                tf.no_op)
        return outfeed_enqueue_op
def _weights_for_active_seps(power_sources, power_separated):
    """Flag separated signals that are loud enough to count as active.

    A separated signal is active when its power exceeds 1% of the smallest
    source power, where the minimum is taken along the last axis of
    `power_sources` (kept as a size-1 axis for broadcasting).

    Args:
      power_sources: tensor of source powers.
      power_separated: tensor of separated-signal powers.

    Returns:
      Boolean tensor, the result of the element-wise comparison.
    """
    weakest_source_power = tf.reduce_min(
        power_sources, axis=-1, keepdims=True)
    activity_threshold = 0.01 * weakest_source_power
    return tf.greater(power_separated, activity_threshold)
Exemple #17
0
def system_functions(model):
    """Return min(sum(x^2) - a, b - sum(x^2)) for the given model.

    Reads `model._x`, `model._a_const` and `model._b_const`.
    """
    squared_norm = tf.reduce_sum(tf.square(model._x))
    above_inner = tf.add(squared_norm, -model._a_const)
    below_outer = tf.add(-squared_norm, model._b_const)
    return tf.reduce_min([above_inner, below_outer])
 def ensemble_q(self, qs):
     """Blend the min and max of `qs` over its last axis.

     The weight `self._ensemble_q_lambda` is applied to the minimum and
     its complement to the maximum.
     """
     weight = self._ensemble_q_lambda
     lowest_q = tf.reduce_min(qs, axis=-1)
     highest_q = tf.reduce_max(qs, axis=-1)
     return weight * lowest_q + (1 - weight) * highest_q
Exemple #19
0
    def test_min_reduce(self):
        """Check conversion of a keepdims min-reduction over axis 3."""
        images = tf.placeholder(shape=(4, 32, 32, 3), dtype=tf.float32)
        reduced = tf.reduce_min(images, axis=3, keepdims=True)

        self._test_conversion('min_reduce', [images], [reduced])
Exemple #20
0
 def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
     """Loop predicate: continue while output has room and tiles remain.

     Relies on `max_output_size`, `num_boxes` and `_NMS_TILE_SIZE` from the
     enclosing scope.
     """
     output_not_full = tf.reduce_min(output_size) < max_output_size
     tiles_remaining = idx < num_boxes // _NMS_TILE_SIZE
     return tf.logical_and(output_not_full, tiles_remaining)
Exemple #21
0
def _pad_or_trim_frames(tensor, delta):
    """Zero-pads or truncates the first axis of a 2-D tensor.

    Args:
      tensor: 2-D tensor whose first axis should be adjusted.
      delta: scalar int tensor, current length minus the desired length.
        Negative values pad at the end, positive values drop trailing rows.

    Returns:
      `tensor` with `delta` rows removed from (or `-delta` zero rows appended
      to) its first axis; unchanged when `delta` is zero.
    """
    return tf.case(
        [(delta < 0,
          lambda: tf.pad(tensor, tf.stack([(0, -delta), (0, 0)]))),
         (delta > 0, lambda: tensor[0:-delta])],
        default=lambda: tensor)


def input_tensors_to_model_input(input_tensors, hparams, is_training,
                                 num_classes=None):
    """Processes an InputTensor into FeatureTensors and LabelTensors.

    All per-frame tensors are reshaped to 2-D and then padded or truncated
    along the frame axis to a common `final_length`.

    Args:
      input_tensors: object exposing `length`, `labels`, `label_weights`,
        `onsets`, `offsets`, `velocities`, `spec`, `note_sequence` and
        `sequence_id`.
      hparams: hyperparameters; `truncated_length_secs` and
        `max_expected_train_example_len` are read here.
      is_training: Python bool. When true, the note sequence and sequence id
        are replaced by constant placeholders.
      num_classes: width of the label tensors. Defaults to
        `constants.MIDI_PITCHES` when None.

    Returns:
      A `(FeatureTensors, LabelTensors)` tuple.
    """
    if num_classes is None:
        num_classes = constants.MIDI_PITCHES

    length = tf.cast(input_tensors.length, tf.int32)
    labels = tf.reshape(input_tensors.labels, (-1, num_classes))
    label_weights = tf.reshape(input_tensors.label_weights, (-1, num_classes))
    onsets = tf.reshape(input_tensors.onsets, (-1, num_classes))
    offsets = tf.reshape(input_tensors.offsets, (-1, num_classes))
    velocities = tf.reshape(input_tensors.velocities, (-1, num_classes))
    spec = tf.reshape(input_tensors.spec, (-1, hparams_frame_size(hparams)))

    # Slice specs and labels tensors so they are no longer than truncated_length.
    hparams_truncated_length = tf.cast(
        hparams.truncated_length_secs * hparams_frames_per_second(hparams),
        tf.int32)
    if hparams.truncated_length_secs:
        truncated_length = tf.reduce_min([hparams_truncated_length, length])
    else:
        truncated_length = length

    if is_training:
        truncated_note_sequence = tf.constant(0)
    else:
        truncated_note_sequence = truncate_note_sequence_op(
            input_tensors.note_sequence, truncated_length, hparams)

    # If max_expected_train_example_len is set, ensure that all examples are
    # padded to this length. This results in a fixed shape that can work on TPUs.
    if hparams.max_expected_train_example_len and is_training:
        # In this case, final_length is a constant.
        if hparams.truncated_length_secs:
            assert_op = tf.assert_equal(hparams.max_expected_train_example_len,
                                        hparams_truncated_length)
            # NOTE(review): if max_expected_train_example_len is a plain Python
            # number, no op is created inside this scope, so in graph mode the
            # assert is presumably never attached to anything -- confirm.
            with tf.control_dependencies([assert_op]):
                final_length = hparams.max_expected_train_example_len
        else:
            final_length = hparams.max_expected_train_example_len
    else:
        # In this case, it is min(hparams.truncated_length, length)
        final_length = truncated_length

    # The spectrogram and the label tensors may have different frame counts,
    # so each gets its own delta relative to final_length.
    spec_delta = tf.shape(spec)[0] - final_length
    spec = _pad_or_trim_frames(spec, spec_delta)
    labels_delta = tf.shape(labels)[0] - final_length
    labels = _pad_or_trim_frames(labels, labels_delta)
    label_weights = _pad_or_trim_frames(label_weights, labels_delta)
    onsets = _pad_or_trim_frames(onsets, labels_delta)
    offsets = _pad_or_trim_frames(offsets, labels_delta)
    velocities = _pad_or_trim_frames(velocities, labels_delta)

    features = FeatureTensors(
        spec=tf.reshape(spec,
                        (final_length, hparams_frame_size(hparams), 1)),
        length=truncated_length,
        sequence_id=tf.constant(0)
        if is_training else input_tensors.sequence_id)
    labels = LabelTensors(
        labels=tf.reshape(labels, (final_length, num_classes)),
        label_weights=tf.reshape(label_weights, (final_length, num_classes)),
        onsets=tf.reshape(onsets, (final_length, num_classes)),
        offsets=tf.reshape(offsets, (final_length, num_classes)),
        velocities=tf.reshape(velocities, (final_length, num_classes)),
        note_sequence=truncated_note_sequence)

    return features, labels
Exemple #22
0
def debugprint(x, name=''):
    """Pass `x` through tf.Print, logging its min, mean and max.

    The printed message is `name`, a tab, and the tensor's own name.
    """
    message = name + '\t' + x.name
    summary_stats = [tf.reduce_min(x), tf.reduce_mean(x), tf.reduce_max(x)]
    return tf.Print(x, summary_stats, message)
Exemple #23
0
    def _log_prob(self, data, num_samples=1):
        """Compute a lower bound on the log likelihood.

        Args:
          data: batch of observations; its first dimension is the batch size.
          num_samples: forwarded to `self.proposal.log_prob` when that method
            accepts it. The estimator itself always uses a single sample.

        Returns:
          Tensor of shape [batch_size] holding the lower bound per example.
        """
        # Due to memory issues the estimator is pinned to one sample; the
        # caller's value is only handed on to the proposal.
        proposal_num_samples = num_samples
        num_samples = 1
        batch_size = tf.shape(data)[0]
        num_proposal_draws = self.K - 1

        # Draw the "unseen" samples from the proposal. They are shared across
        # the batch dimension.
        proposal_samples = self.proposal.sample(
            num_samples * num_proposal_draws)
        if not self.reparameterize_proposal_samples:
            proposal_samples = tf.stop_gradient(proposal_samples)

        # [num_samples, K - 1]
        flat_proposal_samples = tf.reshape(proposal_samples,
                                           [-1] + self.data_dim)
        log_energy_proposal = tf.reshape(
            self.energy_fn(flat_proposal_samples),
            [num_samples, num_proposal_draws])
        tf.summary.histogram("log_energy_proposal", log_energy_proposal)
        tf.summary.scalar("min_log_energy_proposal",
                          tf.reduce_min(log_energy_proposal))
        tf.summary.scalar("max_log_energy_proposal",
                          tf.reduce_max(log_energy_proposal))

        # [num_samples] -> [batch_size, num_samples]
        proposal_lse = tf.reduce_logsumexp(log_energy_proposal, axis=1)
        tiled_proposal_lse = tf.tile(proposal_lse[tf.newaxis, :],
                                     [batch_size, 1])

        # Energy of the observed data, [batch_size].
        log_energy_data = tf.reshape(self.energy_fn(data), [batch_size])
        tf.summary.histogram("log_energy_data", log_energy_data)
        tf.summary.scalar("min_log_energy_data",
                          tf.reduce_min(log_energy_data))
        tf.summary.scalar("max_log_energy_data",
                          tf.reduce_max(log_energy_data))

        # [batch_size, num_samples]
        tiled_log_energy_data = tf.tile(log_energy_data[:, tf.newaxis],
                                        [1, num_samples])

        # Combine the data energy with the proposal energies and average over
        # the K terms, all in log space. [batch_size, num_samples]
        stacked_energies = tf.stack(
            [tiled_log_energy_data, tiled_proposal_lse], axis=-1)
        log_z_hat = tf.reduce_logsumexp(stacked_energies, axis=-1)
        log_z_hat -= tf.log(tf.to_float(self.K))
        # Log-mean-exp over the sample axis (the IWAE reduction). [batch_size]
        log_z_hat = tf.reduce_logsumexp(log_z_hat, axis=1) - tf.log(
            tf.to_float(num_samples))

        try:
            # Let the proposal use the requested sample count if it can.
            proposal_lp = self.proposal.log_prob(
                data, num_samples=proposal_num_samples)
        except TypeError:
            proposal_lp = self.proposal.log_prob(data)
        return proposal_lp + log_energy_data - log_z_hat
Exemple #24
0
    def create_id3_embedding(self, videos):
        """Embeds the given videos using the Inflated 3D Convolution network.

        The embedding tensor is looked up in the default graph by name. If it
        is not there yet, `self.model(videos)` is called first.

        Args:
          videos: <float32>[batch_size, num_frames, height=224, width=224,
            depth=3]. Expected range is [-1, 1]; this and a batch size of 16
            are enforced by assert ops when the graph runs.

        Returns:
          embedding: <float32>[batch_size, embedding_size]. embedding_size
                     depends on the model used.

        Raises:
          AssertionError: if the statically known batch size of `videos` is
            neither 16 nor unknown.
        """

        batch_size = 16

        # Making sure that we import the graph separately for
        # each different input video tensor.
        module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(
            videos.name).replace(":", "_")

        assert_ops = [
            tf.Assert(
                tf.reduce_max(videos) <= 1.001,
                ["max value in frame is > 1", videos]),
            tf.Assert(
                tf.reduce_min(videos) >= -1.001,
                ["min value in frame is < -1", videos]),
            tf.assert_equal(tf.shape(videos)[0],
                            batch_size,
                            ["invalid frame batch size: ",
                             tf.shape(videos)],
                            summarize=6),
        ]
        with tf.control_dependencies(assert_ops):
            videos = tf.identity(videos)

        module_scope = "%s_apply_default/" % module_name

        # To check whether the module has already been loaded into the graph, we look
        # for a given tensor name. If this tensor name exists, we assume the function
        # has been called before and the graph was imported. Otherwise we import it.
        # Note: in theory, the tensor could exist, but have wrong shapes.
        # This will happen if create_id3_embedding is called with a frames_placeholder
        # of wrong size/batch size, because even though that will throw a tf.Assert
        # on graph-execution time, it will insert the tensor (with wrong shape) into
        # the graph. This is why we need the following assert.
        #
        # Read the static dimension through as_list() so that an unknown batch
        # size yields None (which the check below allows) instead of failing
        # in an int() conversion.
        video_batch_size = videos.shape.as_list()[0]
        assert video_batch_size in [batch_size, -1, None], "Invalid batch size"

        # Name of the kinetics-i3d-400 logits layer.
        tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
        if not _is_in_graph(tensor_name):
            # NOTE(review): presumably self.model builds the I3D graph so that
            # tensor_name exists afterwards -- confirm against self.model.
            self.model(videos)

        return tf.get_default_graph().get_tensor_by_name(tensor_name)
def rescale(image):
    """Linearly stretch `image` so its values span [0, 255] as float32.

    An image whose values are all equal has zero range, so the division
    yields NaN.
    """
    pixels = tf.cast(image, tf.float32)
    lowest = tf.reduce_min(pixels)
    value_range = tf.reduce_max(pixels) - tf.reduce_min(pixels)
    return (pixels - lowest) / value_range * 255.
def dropblock(net,
              is_training,
              keep_prob,
              dropblock_size,
              data_format='channels_first'):
    """DropBlock: a regularization method for convolutional neural networks.

    DropBlock is structured dropout: contiguous square regions of a feature
    map are zeroed together, and the surviving activations are rescaled by
    the fraction of kept units.
    See https://arxiv.org/pdf/1810.12890.pdf for details.

    Args:
      net: `Tensor` input tensor with four dimensions and static shape.
      is_training: `bool` for whether the model is training.
      keep_prob: `float` or `Tensor` keep_prob parameter of DropBlock. "None"
        means no DropBlock.
      dropblock_size: `int` size of blocks to be dropped by DropBlock. It is
        clipped to the spatial size of `net`.
      data_format: `str` either "channels_first" for `[batch, channels,
        height, width]` or "channels_last" for `[batch, height, width,
        channels]`.

    Returns:
      A version of the input tensor with DropBlock applied, or `net` itself
      when not training or when `keep_prob` is None.

    Raises:
      ValueError: if width and height of the input tensor are not equal.
    """
    if not is_training or keep_prob is None:
        return net

    channels_last = data_format == 'channels_last'

    tf.logging.info(
        'Applying DropBlock: dropblock_size {}, net.shape {}'.format(
            dropblock_size, net.shape))

    static_dims = net.get_shape().as_list()
    if channels_last:
        _, width, height, _ = static_dims
    else:
        _, _, width, height = static_dims
    if width != height:
        raise ValueError('Input tensor with width!=height is not supported.')

    block_size = min(dropblock_size, width)
    # seed_drop_rate is the gamma parameter of DropBlock.
    seed_drop_rate = (1.0 - keep_prob) * width**2 / block_size**2 / (
        width - block_size + 1)**2

    # Only positions far enough from the border may seed a block, so that
    # every dropped block lies fully inside the feature map.
    first_valid = int(block_size // 2)
    end_valid = width - (block_size - 1) // 2
    w_i, h_i = tf.meshgrid(tf.range(width), tf.range(width))
    w_inside = tf.logical_and(w_i >= first_valid, w_i < end_valid)
    h_inside = tf.logical_and(h_i >= first_valid, h_i < end_valid)
    valid_block_center = tf.logical_and(w_inside, h_inside)

    # Add batch and channel axes so the mask broadcasts against `net`.
    valid_block_center = tf.expand_dims(valid_block_center, 0)
    channel_axis = -1 if channels_last else 0
    valid_block_center = tf.expand_dims(valid_block_center, channel_axis)

    randnoise = tf.random_uniform(net.shape, dtype=tf.float32)
    outside_valid_region = 1 - tf.cast(valid_block_center, dtype=tf.float32)
    seed_keep_rate = tf.cast((1 - seed_drop_rate), dtype=tf.float32)
    # 1.0 where the unit is kept, 0.0 where a block is seeded.
    block_pattern = tf.cast(
        outside_valid_region + seed_keep_rate + randnoise >= 1,
        dtype=tf.float32)

    if block_size == width:
        # The block covers the whole map: one seed drops everything.
        spatial_axes = [1, 2] if channels_last else [2, 3]
        block_pattern = tf.reduce_min(
            block_pattern, axis=spatial_axes, keepdims=True)
    else:
        # Min-pooling (negated max-pooling) grows each seed into a block.
        if channels_last:
            ksize = [1, block_size, block_size, 1]
            pool_format = 'NHWC'
        else:
            ksize = [1, 1, block_size, block_size]
            pool_format = 'NCHW'
        block_pattern = -tf.nn.max_pool(
            -block_pattern,
            ksize=ksize,
            strides=[1, 1, 1, 1],
            padding='SAME',
            data_format=pool_format)

    kept_count = tf.cast(tf.reduce_sum(block_pattern), tf.float32)
    total_count = tf.cast(tf.size(block_pattern), tf.float32)
    percent_ones = kept_count / total_count

    # Rescale the survivors so the expected activation magnitude is kept.
    return net / tf.cast(percent_ones, net.dtype) * tf.cast(
        block_pattern, net.dtype)
Exemple #27
0
    def get(self):
        """Provides input data to the graph.

        Reads fixed-length records from `self.path_to_db`. Each record holds,
        in order: float32 xyz keypoints (3 per keypoint), float32 uv
        coordinates plus a visibility flag (3 per keypoint), and a uint8
        image with three channels of size `self.image_size`. Derived items
        (normalized / local / canonical coordinates, optional hand crop,
        score maps) are added and everything is batched.

        Returns:
          dict mapping item names to tensors batched to `self.batch_size`.
        """
        # calculate size of each record (this lists what is contained in the db and how many bytes are occupied)
        record_bytes = 0

        encoding_bytes = 4
        kp_xyz_entries = 3 * self.num_kp
        record_bytes += encoding_bytes*kp_xyz_entries

        encoding_bytes = 4
        kp_uv_entries = 2 * self.num_kp
        record_bytes += encoding_bytes*kp_uv_entries

        kp_vis_entries = self.num_kp
        record_bytes += encoding_bytes*kp_vis_entries

        image_bytes = self.image_size[0] * self.image_size[1] * 3
        record_bytes += image_bytes

        # ---- READ DATA ITEMS ----
        # Start reader
        reader = tf.FixedLengthRecordReader(header_bytes=0, record_bytes=record_bytes)
        _, value = reader.read(tf.train.string_input_producer([self.path_to_db]))

        # decode to floats
        bytes_read = 0
        data_dict = dict()
        record_bytes_float32 = tf.decode_raw(value, tf.float32)

        # 1. Read keypoint xyz
        keypoint_xyz21 = tf.reshape(tf.slice(record_bytes_float32, [bytes_read//4], [kp_xyz_entries]), [self.num_kp, 3])
        bytes_read += encoding_bytes*kp_xyz_entries
        keypoint_xyz21 /= 1000.0  # scale to meters
        keypoint_xyz21 = self.convert_kp(keypoint_xyz21)

        # calculate wrist coord: extrapolate from keypoint 16 through keypoint 0
        if self.use_wrist_coord:
            wrist_xyz = keypoint_xyz21[16, :] + 2.0*(keypoint_xyz21[0, :] - keypoint_xyz21[16, :])
            keypoint_xyz21 = tf.concat([tf.expand_dims(wrist_xyz, 0),
                                        keypoint_xyz21[1:, :]], 0)

        data_dict['keypoint_xyz21'] = keypoint_xyz21

        # 2. Read keypoint uv AND VIS
        keypoint_uv_vis21 = tf.reshape(tf.slice(record_bytes_float32, [bytes_read//4], [kp_uv_entries+kp_vis_entries]), [self.num_kp, 3])
        bytes_read += encoding_bytes*(kp_uv_entries+kp_vis_entries)
        keypoint_uv_vis21 = self.convert_kp(keypoint_uv_vis21)
        keypoint_uv21 = keypoint_uv_vis21[:, :2]
        keypoint_vis21 = tf.equal(keypoint_uv_vis21[:, 2], 1.0)

        # calculate wrist vis
        if self.use_wrist_coord:
            wrist_vis = tf.logical_or(keypoint_vis21[16], keypoint_vis21[0])
            keypoint_vis21 = tf.concat([tf.expand_dims(wrist_vis, 0),
                                        keypoint_vis21[1:]], 0)

            wrist_uv = keypoint_uv21[16, :] + 2.0*(keypoint_uv21[0, :] - keypoint_uv21[16, :])
            keypoint_uv21 = tf.concat([tf.expand_dims(wrist_uv, 0),
                                       keypoint_uv21[1:, :]], 0)

        data_dict['keypoint_vis21'] = keypoint_vis21

        if self.coord_uv_noise:
            # Take the noise shape from the keypoints themselves. A hard-coded
            # [42, 2] cannot be added to a keypoint set of any other size.
            noise = tf.truncated_normal(tf.shape(keypoint_uv21), mean=0.0, stddev=self.coord_uv_noise_sigma)
            keypoint_uv21 += noise

        data_dict['keypoint_uv21'] = keypoint_uv21

        # decode to uint8
        record_bytes_uint8 = tf.decode_raw(value, tf.uint8)

        # 4. Read image
        image = tf.reshape(tf.slice(record_bytes_uint8, [bytes_read], [image_bytes]),
                           [self.image_size[0], self.image_size[1], 3])
        image = tf.cast(image, tf.float32)
        bytes_read += image_bytes

        # scale to [-0.5, 0.5]
        image = image / 255.0 - 0.5
        if self.hue_aug:
            image = tf.image.random_hue(image, self.hue_aug_max)
        data_dict['image'] = image

        # ---- CONSTANTS ----
        # Camera intrinsics
        sx = 822.79041
        sy = 822.79041
        tx = 318.47345
        ty = 250.31296
        data_dict['cam_mat'] = tf.constant([[sx, 0.0, tx], [0.0, sy, ty], [0.0, 0.0, 1.0]])

        # Hand side: this dataset only contains left hands
        data_dict['hand_side'] = tf.one_hot(tf.constant(0, dtype=tf.int32), depth=2, on_value=1.0, off_value=0.0, dtype=tf.float32)

        # Sanity check: every byte of the record must have been consumed.
        assert bytes_read == record_bytes, "Doesnt add up."

        # ---- DEPENDENT DATA ITEMS: XYZ representations ----
        # make coords relative to root joint
        kp_coord_xyz_root = keypoint_xyz21[0, :]  # this is the palm coord
        kp_coord_xyz21_rel = keypoint_xyz21 - kp_coord_xyz_root  # relative coords in metric coords
        index_root_bone_length = tf.sqrt(tf.reduce_sum(tf.square(kp_coord_xyz21_rel[12, :] - kp_coord_xyz21_rel[11, :])))
        data_dict['keypoint_scale'] = index_root_bone_length
        data_dict['keypoint_xyz21_normed'] = kp_coord_xyz21_rel / index_root_bone_length  # normalized by length of 12->11

        # calculate local coordinates
        kp_coord_xyz21_local = bone_rel_trafo(data_dict['keypoint_xyz21_normed'])
        kp_coord_xyz21_local = tf.squeeze(kp_coord_xyz21_local)
        data_dict['keypoint_xyz21_local'] = kp_coord_xyz21_local

        # calculate viewpoint and coords in canonical coordinates
        kp_coord_xyz21_rel_can, rot_mat = canonical_trafo(data_dict['keypoint_xyz21_normed'])
        kp_coord_xyz21_rel_can, rot_mat = tf.squeeze(kp_coord_xyz21_rel_can), tf.squeeze(rot_mat)
        data_dict['keypoint_xyz21_can'] = kp_coord_xyz21_rel_can
        data_dict['rot_mat'] = tf.matrix_inverse(rot_mat)

        # ---- DEPENDENT DATA ITEMS: HAND CROP ----
        if self.hand_crop:
            # keypoint 12 in (row, col) order serves as the crop center
            crop_center = keypoint_uv21[12, ::-1]

            # catch problem, when no valid kp available (happens almost never)
            crop_center = tf.cond(tf.reduce_all(tf.is_finite(crop_center)), lambda: crop_center,
                                  lambda: tf.constant([0.0, 0.0]))
            crop_center.set_shape([2, ])

            if self.crop_center_noise:
                noise = tf.truncated_normal([2], mean=0.0, stddev=self.crop_center_noise_sigma)
                crop_center += noise

            crop_scale_noise = tf.constant(1.0)
            if self.crop_scale_noise:
                crop_scale_noise = tf.squeeze(tf.random_uniform([1], minval=1.0, maxval=1.2))

            if not self.use_wrist_coord:
                wrist_uv = keypoint_uv21[16, :] + 2.0*(keypoint_uv21[0, :] - keypoint_uv21[16, :])
                keypoint_uv21 = tf.concat([tf.expand_dims(wrist_uv, 0),
                                           keypoint_uv21[1:, :]], 0)

            # select visible coords only
            kp_coord_h = tf.boolean_mask(keypoint_uv21[:, 1], keypoint_vis21)
            kp_coord_w = tf.boolean_mask(keypoint_uv21[:, 0], keypoint_vis21)
            kp_coord_hw = tf.stack([kp_coord_h, kp_coord_w], 1)

            # determine size of crop (measure spatial extend of hw coords first)
            min_coord = tf.maximum(tf.reduce_min(kp_coord_hw, 0), 0.0)
            max_coord = tf.minimum(tf.reduce_max(kp_coord_hw, 0), self.image_size)

            # find out larger distance wrt the center of crop
            crop_size_best = 2*tf.maximum(max_coord - crop_center, crop_center - min_coord)
            crop_size_best = tf.reduce_max(crop_size_best)
            crop_size_best = tf.minimum(tf.maximum(crop_size_best, 50.0), 500.0)

            # catch problem, when no valid kp available
            crop_size_best = tf.cond(tf.reduce_all(tf.is_finite(crop_size_best)), lambda: crop_size_best,
                                     lambda: tf.constant(200.0))
            crop_size_best.set_shape([])

            # calculate necessary scaling
            scale = tf.cast(self.crop_size, tf.float32) / crop_size_best
            scale = tf.minimum(tf.maximum(scale, 1.0), 10.0)
            scale *= crop_scale_noise
            data_dict['crop_scale'] = scale

            if self.crop_offset_noise:
                noise = tf.truncated_normal([2], mean=0.0, stddev=self.crop_offset_noise_sigma)
                crop_center += noise

            # Crop image
            img_crop = crop_image_from_xy(tf.expand_dims(image, 0), crop_center, self.crop_size, scale)
            data_dict['image_crop'] = tf.squeeze(img_crop)

            # Modify uv21 coordinates (crop_center is (row, col), hence the swapped indices)
            crop_center_float = tf.cast(crop_center, tf.float32)
            keypoint_uv21_u = (data_dict['keypoint_uv21'][:, 0] - crop_center_float[1]) * scale + self.crop_size // 2
            keypoint_uv21_v = (data_dict['keypoint_uv21'][:, 1] - crop_center_float[0]) * scale + self.crop_size // 2
            keypoint_uv21 = tf.stack([keypoint_uv21_u, keypoint_uv21_v], 1)
            data_dict['keypoint_uv21'] = keypoint_uv21

            # Modify camera intrinsics
            scale = tf.reshape(scale, [1, ])
            scale_matrix = tf.dynamic_stitch([[0], [1], [2],
                                              [3], [4], [5],
                                              [6], [7], [8]], [scale, [0.0], [0.0],
                                                               [0.0], scale, [0.0],
                                                               [0.0], [0.0], [1.0]])
            scale_matrix = tf.reshape(scale_matrix, [3, 3])

            trans1 = crop_center_float[0] * scale - self.crop_size // 2
            trans2 = crop_center_float[1] * scale - self.crop_size // 2
            trans1 = tf.reshape(trans1, [1, ])
            trans2 = tf.reshape(trans2, [1, ])
            trans_matrix = tf.dynamic_stitch([[0], [1], [2],
                                              [3], [4], [5],
                                              [6], [7], [8]], [[1.0], [0.0], -trans2,
                                                               [0.0], [1.0], -trans1,
                                                               [0.0], [0.0], [1.0]])
            trans_matrix = tf.reshape(trans_matrix, [3, 3])

            data_dict['cam_mat'] = tf.matmul(trans_matrix, tf.matmul(scale_matrix, data_dict['cam_mat']))

        # ---- DEPENDENT DATA ITEMS: Scoremap from the SUBSET of 21 keypoints ----
        # create scoremaps from the subset of 2D annotation
        keypoint_hw21 = tf.stack([keypoint_uv21[:, 1], keypoint_uv21[:, 0]], -1)

        scoremap_size = self.image_size

        if self.hand_crop:
            scoremap_size = (self.crop_size, self.crop_size)

        scoremap = self.create_multiple_gaussian_map(keypoint_hw21,
                                                     scoremap_size,
                                                     self.sigma,
                                                     valid_vec=keypoint_vis21)

        if self.scoremap_dropout:
            # Drops whole keypoint channels; the noise shape assumes 21 channels.
            scoremap = tf.nn.dropout(scoremap, self.scoremap_dropout_prob,
                                     noise_shape=[1, 1, 21])
            # Undo the 1/keep_prob rescaling applied by tf.nn.dropout.
            scoremap *= self.scoremap_dropout_prob

        data_dict['scoremap'] = scoremap

        if self.random_crop_to_size:
            # NOTE(review): 'hand_parts' and 'hand_mask' are never stored in
            # data_dict by this method, so enabling random_crop_to_size raises
            # KeyError here -- confirm whether this branch is meant to be used.
            tensor_stack = tf.concat([data_dict['image'],
                                      tf.expand_dims(tf.cast(data_dict['hand_parts'], tf.float32), -1),
                                      tf.cast(data_dict['hand_mask'], tf.float32)], 2)
            s = tensor_stack.get_shape().as_list()
            tensor_stack_cropped = tf.random_crop(tensor_stack,
                                                  [self.random_crop_size, self.random_crop_size, s[2]])
            data_dict = dict()  # delete everything else because the random cropping makes the data invalid anyway
            data_dict['image'], data_dict['hand_parts'], data_dict['hand_mask'] = tensor_stack_cropped[:, :, :3],\
                                                                                  tf.cast(tensor_stack_cropped[:, :, 3], tf.int32),\
                                                                                  tf.cast(tensor_stack_cropped[:, :, 4:], tf.int32)

        names, tensors = zip(*data_dict.items())

        if self.shuffle:
            tensors = tf.train.shuffle_batch_join([tensors],
                                                  batch_size=self.batch_size,
                                                  capacity=100,
                                                  min_after_dequeue=50,
                                                  enqueue_many=False)
        else:
            tensors = tf.train.batch_join([tensors],
                                          batch_size=self.batch_size,
                                          capacity=100,
                                          enqueue_many=False)

        return dict(zip(names, tensors))
Exemple #28
0
  def build_train_graph(self,
                        inputs,
                        min_depth,
                        max_depth,
                        num_mpi_planes,
                        learning_rate=0.0002,
                        beta1=0.9,
                        vgg_model_file=None,
                        global_step=0):
    """Construct the training computation graph.

    Args:
      inputs: dictionary of tensors (see 'input_data' below) needed for training
      min_depth: minimum depth for the PSV and MPI planes
      max_depth: maximum depth for the PSV and MPI planes
      num_mpi_planes: number of MPI planes to infer
      learning_rate: learning rate
      beta1: hyperparameter for Adam
      vgg_model_file: path to vgg weights (needed when vgg loss is used)
      global_step: current optimization step
    Returns:
      A train_op to be used for training.
    """
    print("starting to build graph")
    with tf.name_scope("input_size_randomization"):
      dim_choices = tf.constant([[1, 16], [2, 32], [4, 32], [4, 64], [4, 128],
                                 [8, 32], [8, 64], [8, 128]],
                                dtype=tf.int32)
      rand_dim = tf.random_shuffle(dim_choices)[0, :]
      height_div = rand_dim[0]
      width_div = rand_dim[0]
      num_mpi_planes = rand_dim[1]
      tf.summary.scalar("num_mpi_planes", num_mpi_planes)

    with tf.name_scope("setup"):
      mpi_planes = self.inv_depths(min_depth, max_depth, num_mpi_planes)

    with tf.name_scope("input_data"):
      raw_tgt_image = inputs["tgt_image"]
      raw_ref_image = inputs["ref_image"]
      raw_src_images = inputs["src_images"]

      _, img_height, img_width, _ = raw_src_images.get_shape().as_list(
      )
      img_height = img_height // height_div
      img_width = img_width // width_div

      raw_tgt_image = tf.image.convert_image_dtype(
          raw_tgt_image, dtype=tf.float32)
      raw_ref_image = tf.image.convert_image_dtype(
          raw_ref_image, dtype=tf.float32)
      raw_src_images = tf.image.convert_image_dtype(
          raw_src_images, dtype=tf.float32)
      raw_tgt_image = tf.image.resize_area(raw_tgt_image,
                                           [img_height, img_width])
      raw_ref_image = tf.image.resize_area(raw_ref_image,
                                           [img_height, img_width])
      raw_src_images = tf.image.resize_area(raw_src_images,
                                            [img_height, img_width])

      tgt_pose = inputs["tgt_pose"]
      ref_pose = inputs["ref_pose"]
      src_poses = inputs["src_poses"]
      intrinsics = inputs["intrinsics"]

      # Scale intrinsics based on size randomization
      intrinsics = tf.concat([
          intrinsics[:, 0:1, :] / tf.to_float(width_div),
          intrinsics[:, 1:2, :] / tf.to_float(height_div), intrinsics[:, 2:3, :]
      ],
                             axis=1)
      inputs["intrinsics"] = intrinsics

      _, num_source, _, _ = src_poses.get_shape().as_list()

    with tf.name_scope("inference"):
      print("setting up MPI inference")
      num_mpi_planes = tf.shape(mpi_planes)[0]
      pred = self.infer_mpi(raw_src_images, raw_ref_image, ref_pose, src_poses,
                            intrinsics, num_mpi_planes,
                            mpi_planes)
      rgba_layers = pred["rgba_layers"]
      rgba_layers_refine = pred["rgba_layers_refine"]
      stuff_behind = pred["stuff_behind"]
      refine_input_mpi = pred["refine_input_mpi"]
      psv = pred["psv"]

    with tf.name_scope("synthesis"):
      print("setting up rendering")
      rel_pose = tf.matmul(tgt_pose, tf.matrix_inverse(ref_pose))
      output_image, output_layers = self.mpi_render_view(
          rgba_layers, rel_pose, mpi_planes, intrinsics)
      output_alpha = output_layers[Ellipsis, -1]
      output_image_refine, _ = self.mpi_render_view(
          rgba_layers_refine, rel_pose, mpi_planes, intrinsics)

    with tf.name_scope("loss"):
      print("computing losses")
      # Mask loss for pixels outside reference frustum
      loss_mask = tf.where(
          tf.equal(
              tf.reduce_min(
                  tf.abs(tf.reduce_sum(output_layers, axis=-1)),
                  axis=3,
                  keep_dims=True), 0.0),
          tf.zeros_like(output_alpha[:, :, :, 0:1]),
          tf.ones_like(output_alpha[:, :, :, 0:1]))
      loss_mask = tf.stop_gradient(loss_mask)
      tf.summary.image("loss_mask", loss_mask)

      # Helper functions for loss
      def compute_error(real, fake, mask):
        return tf.reduce_mean(mask * tf.abs(fake - real))

      # Normalized VGG loss (from
      # https://github.com/CQFIO/PhotographicImageSynthesis)

      downsample = lambda tensor, ds: tf.nn.avg_pool(tensor, [1, ds, ds, 1],
                                                     [1, ds, ds, 1], "SAME")

      def vgg_loss(raw_tgt_image, output_image, loss_mask):
        """Compute VGG loss."""

        vgg_real = build_vgg19(raw_tgt_image * 255.0, vgg_model_file)
        rescaled_output_image = (output_image + 1.)/2. * 255.0
        vgg_fake = build_vgg19(
            rescaled_output_image, vgg_model_file, reuse=True)
        p0 = compute_error(vgg_real["input"], vgg_fake["input"], loss_mask)
        p1 = compute_error(vgg_real["conv1_2"],
                           vgg_fake["conv1_2"],
                           loss_mask)/2.6
        p2 = compute_error(vgg_real["conv2_2"],
                           vgg_fake["conv2_2"],
                           downsample(loss_mask, 2))/4.8
        p3 = compute_error(vgg_real["conv3_2"],
                           vgg_fake["conv3_2"],
                           downsample(loss_mask, 4))/3.7
        p4 = compute_error(vgg_real["conv4_2"],
                           vgg_fake["conv4_2"],
                           downsample(loss_mask, 8))/5.6
        p5 = compute_error(vgg_real["conv5_2"],
                           vgg_fake["conv5_2"],
                           downsample(loss_mask, 16))*10/1.5
        total_loss = p0+p1+p2+p3+p4+p5
        return total_loss, vgg_real, vgg_fake

      vgg_loss_initial, _, _ = vgg_loss(raw_tgt_image, output_image, loss_mask)
      tf.summary.scalar("vgg_loss_initial", vgg_loss_initial)
      total_loss = vgg_loss_initial

      vgg_loss_refine, _, _ = vgg_loss(raw_tgt_image, output_image_refine,
                                       loss_mask)
      tf.summary.scalar("vgg_loss_refine", vgg_loss_refine)
      total_loss += vgg_loss_refine

    with tf.name_scope("train_op"):
      print("setting up train op")
      train_vars = [var for var in tf.trainable_variables()]
      optim = tf.train.AdamOptimizer(learning_rate, beta1)
      grads_and_vars = optim.compute_gradients(total_loss, var_list=train_vars)
      train_op = [optim.apply_gradients(grads_and_vars)]

    # Summaries
    tf.summary.scalar("total_loss", total_loss)
    # Source images
    for i in range(num_source):
      src_image = raw_src_images[:, :, :, i*3:(i+1)*3]
      tf.summary.image("src_image_%d" % i, src_image)
    # Output image
    tf.summary.image("output_image", self.deprocess_image(output_image))
    # Refined output image
    tf.summary.image("output_image_refine",
                     self.deprocess_image(output_image_refine))
    # Target image
    tf.summary.image("tgt_image", raw_tgt_image)
    # Ref image
    tf.summary.image("ref_image", raw_ref_image)
    # Predicted color and alpha layers, and PSV
    num_summ = 16  # Number of plane summaries to show in tensorboard
    for i in range(num_summ):
      ind = tf.to_int32(i * num_mpi_planes/num_summ)
      rgb = rgba_layers[:, :, :, ind, :3]
      alpha = rgba_layers[:, :, :, ind, -1:]
      ref_plane = psv[:, :, :, ind, 3:6]
      source_plane = psv[:, :, :, ind, :3]
      output_rgb = output_layers[:, :, :, ind, :3]
      tf.summary.image("rgb_layer_%d" % i, self.deprocess_image(rgb))
      tf.summary.image("alpha_layer_%d" % i, alpha)
      tf.summary.image("rgba_layer_%d" % i, self.deprocess_image(rgb * alpha))
      tf.summary.image("psv_avg_%d" % i,
                       (self.deprocess_image(0.5*ref_plane + 0.5*source_plane)))
      tf.summary.image("output_rgb_%d" % i,
                       self.deprocess_image(output_rgb))
      tf.summary.image("psv_ref_%d" % i, self.deprocess_image(ref_plane))
      tf.summary.image("psv_source_%d" % i, self.deprocess_image(source_plane))

    # Cumulative rendered images and refined MPI
    for i in range(num_summ):
      ind = tf.to_int32(i * num_mpi_planes/num_summ)
      rgb = rgba_layers_refine[:, :, :, ind, :3]
      alpha = rgba_layers_refine[:, :, :, ind, 3:]
      render = stuff_behind[:, :, :, ind, :3]
      input_colors = refine_input_mpi[:, :, :, ind, :3]
      tf.summary.image("rgb_layer_refine_%d" % i, self.deprocess_image(rgb))
      tf.summary.image("alpha_layer_refine_%d" % i, alpha)
      tf.summary.image("rgba_layer_refine_%d" % i,
                       self.deprocess_image(rgb * alpha))
      tf.summary.image("cumulative_render_%d" % i, self.deprocess_image(render))
      tf.summary.image("input_colors_refine_%d" % i,
                       self.deprocess_image(input_colors))

    return train_op
Exemple #29
0
def too_close_condition(trip, depth_threshold=0.1):
    """Checks whether the smallest non-zero depth exceeds a threshold.

    Only channel 0 of the first three entries along the leading axis of
    `trip.depth` is inspected. Entries that are exactly 0.0 are overwritten
    with the maximum depth so that they cannot be picked as the minimum
    (presumably zero marks a missing depth measurement -- confirm against the
    data pipeline).

    Args:
      trip: an object exposing a `depth` tensor that is indexed as
        `depth[:3, :, :, 0]`.
      depth_threshold: value the smallest remaining depth is compared against.

    Returns:
      A boolean tensor that is True when the smallest remaining depth is
      strictly greater than `depth_threshold`.
    """
    frame_depths = trip.depth[:3, :, :, 0]
    is_zero = tf.equal(frame_depths, 0.0)
    # Broadcast the global maximum to the full shape so tf.where can select
    # from it element-wise.
    max_filled = tf.reduce_max(frame_depths) * tf.ones_like(frame_depths)
    nonzero_depths = tf.where(is_zero, max_filled, frame_depths)
    nearest_depth = tf.reduce_min(nonzero_depths)
    return tf.greater(nearest_depth, depth_threshold)
Exemple #30
0
def proposal_label_op(boxes,
                      gt_boxes,
                      gt_labels,
                      batch_size_per_im=512,
                      fg_fraction=0.25,
                      fg_thresh=0.5,
                      bg_thresh_hi=0.5,
                      bg_thresh_lo=0.):
    """Assigns the proposals with ground truth labels and performs subsampling.

    Given proposal `boxes`, `gt_boxes`, and `gt_labels`, the function uses the
    following algorithm to generate the final `batch_size_per_im` RoIs.
    1. Calculates the IoU between each proposal box and each gt_boxes.
    2. Assigns each proposal box with a ground truth class and box label by
       choosing the largest overlap.
    3. Samples `batch_size_per_im` boxes from all proposal boxes, and returns
       box_targets, class_targets, and RoIs.
    The reference implementations of #1 and #2 are here: https://github.com/facebookresearch/Detectron/blob/master/detectron/datasets/json_dataset.py  # pylint: disable=line-too-long
    The reference implementation of #3 is here: https://github.com/facebookresearch/Detectron/blob/master/detectron/roi_data/fast_rcnn.py.  # pylint: disable=line-too-long

    Args:
      boxes: a tensor with a shape of [batch_size, N, 4]. N is the number of
        proposals before groundtruth assignment (e.g., rpn_post_nms_topn). The
        last dimension is the pixel coordinates of scaled images in
        [ymin, xmin, ymax, xmax] form.
      gt_boxes: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES, 4].
        This tensor might have paddings with a value of -1. The coordinates of
        gt_boxes are in the pixel coordinates of the scaled image.
      gt_labels: a tensor with a shape of [batch_size, MAX_NUM_INSTANCES]. This
        tensor might have paddings with a value of -1.
      batch_size_per_im: an integer represents RoI minibatch size per image.
      fg_fraction: a float represents the target fraction of RoI minibatch that
        is labeled foreground (i.e., class > 0).
      fg_thresh: a float represents the overlap threshold for an RoI to be
        considered foreground (if >= fg_thresh).
      bg_thresh_hi: a float represents the overlap threshold for an RoI to be
        considered background (class = 0 if overlap in [LO, HI)).
      bg_thresh_lo: a float represents the overlap threshold for an RoI to be
        considered background (class = 0 if overlap in [LO, HI)).

    Returns:
      box_targets: a tensor with a shape of [batch_size, K, 4]. The tensor
        contains the ground truth pixel coordinates of the scaled images for
        each roi. K is the number of sample RoIs (e.g., batch_size_per_im).
      class_targets: an integer tensor with a shape of [batch_size, K]. The
        tensor contains the ground truth class for each roi.
      rois: a tensor with a shape of [batch_size, K, 4], representing the
        coordinates of the selected RoI.
      proposal_to_label_map: a tensor with a shape of [batch_size, K]. This
        tensor keeps the mapping between proposal to labels.
        proposal_to_label_map[i] means the index of the ground truth instance
        for the i-th proposal.
    """
    with tf.name_scope('proposal_label'):
        # Must be statically known: it drives the Python-level unrolling below.
        batch_size = boxes.shape[0]

        # The reference implementation intentionally includes ground truth boxes in
        # the proposals. see https://github.com/facebookresearch/Detectron/blob/master/detectron/datasets/json_dataset.py#L359.  # pylint: disable=line-too-long
        boxes = tf.concat([boxes, gt_boxes], axis=1)
        iou = box_utils.bbox_overlap(boxes, gt_boxes)

        (pre_sample_box_targets, pre_sample_class_targets, max_overlap,
         proposal_to_label_map) = _add_class_assignments(
             iou, gt_boxes, gt_labels)

        # Generates a random sample of RoIs comprising foreground and background
        # examples. reference: https://github.com/facebookresearch/Detectron/blob/master/detectron/roi_data/fast_rcnn.py#L132  # pylint: disable=line-too-long
        # Foreground is inclusive of the threshold (overlap >= fg_thresh), as
        # documented above and as done in the reference implementation. A
        # strict comparison would leave proposals whose overlap equals
        # fg_thresh in neither the foreground nor the [LO, HI) background set.
        positives = tf.greater_equal(max_overlap,
                                     fg_thresh * tf.ones_like(max_overlap))
        negatives = tf.logical_and(
            tf.greater_equal(max_overlap,
                             bg_thresh_lo * tf.ones_like(max_overlap)),
            tf.less(max_overlap, bg_thresh_hi * tf.ones_like(max_overlap)))
        # Background proposals get class 0 and label-map index 0.
        pre_sample_class_targets = tf.where(
            negatives, tf.zeros_like(pre_sample_class_targets),
            pre_sample_class_targets)
        proposal_to_label_map = tf.where(negatives,
                                         tf.zeros_like(proposal_to_label_map),
                                         proposal_to_label_map)

        # Handles ground truth paddings: a proposal is ignored when its
        # smallest IoU over the ground truth axis is negative.
        # NOTE(review): presumably bbox_overlap emits negative IoU for padded
        # entries -- confirm against box_utils.bbox_overlap.
        ignore_mask = tf.less(tf.reduce_min(iou, axis=2),
                              tf.zeros_like(max_overlap))
        # indicator includes both positive and negative labels.
        # labels includes only positives labels.
        # positives = indicator & labels.
        # negatives = indicator & !labels.
        # ignore = !indicator.
        labels = positives
        pos_or_neg = tf.logical_or(positives, negatives)
        indicator = tf.logical_and(pos_or_neg, tf.logical_not(ignore_mask))

        sampler = (
            balanced_positive_negative_sampler.BalancedPositiveNegativeSampler(
                positive_fraction=fg_fraction, is_static=True))
        # Batch-unroll the sub-sampling process: one boolean selection mask
        # per image, stacked along a new leading batch axis.
        all_samples = tf.stack(
            [
                sampler.subsample(indicator[i], batch_size_per_im, labels[i])
                for i in range(batch_size)
            ],
            axis=0)
        # A workaround to get the indices from the boolean tensors: after the
        # cast, selected entries are 1 and therefore rank first in top_k.
        _, samples_indices = tf.nn.top_k(tf.to_int32(all_samples),
                                         k=batch_size_per_im,
                                         sorted=True)
        # Constructs indices for gather: offset each image's indices by
        # image_index * boxes_per_image so they address the flattened
        # [batch_size * boxes_per_image, ...] tensors below.
        samples_indices = tf.reshape(
            samples_indices +
            tf.expand_dims(tf.range(batch_size) * tf.shape(boxes)[1], 1), [-1])
        rois = tf.reshape(
            tf.gather(tf.reshape(boxes, [-1, 4]), samples_indices),
            [batch_size, -1, 4])
        class_targets = tf.reshape(
            tf.gather(tf.reshape(pre_sample_class_targets, [-1, 1]),
                      samples_indices), [batch_size, -1])
        sample_box_targets = tf.reshape(
            tf.gather(tf.reshape(pre_sample_box_targets, [-1, 4]),
                      samples_indices), [batch_size, -1, 4])
        sample_proposal_to_label_map = tf.reshape(
            tf.gather(tf.reshape(proposal_to_label_map, [-1, 1]),
                      samples_indices), [batch_size, -1])
    return sample_box_targets, class_targets, rois, sample_proposal_to_label_map