Example #1
0
 def __init__(self,
              *,
              lr_kwargs,
              layers=None,
              preprocessors=None,
              learn_std=True,
              std_value=0.1,
              mu_range=None,
              log_std_range=None,
              **kwargs):
     super(GaussianPolicy, self).__init__(**kwargs)
     self.std_value = std_value
     self.lr_kwargs = lr_kwargs
     self.lr_scheduler = get_scheduler(lr_kwargs)
     self.schedulers += (self.lr_scheduler, )
     self.learn_std = learn_std
     self.mu_range = (-2.0, 2.0) if mu_range is None else mu_range
     self.log_std_range = (-10,
                           0.3) if log_std_range is None else log_std_range
     assert not self.discrete_action_space, "Action space for the Gaussian Policy must be continuous!"
     # Placeholders
     self.lr_ph = tf_v1.placeholder("float32", shape=(), name="lr_ph")
     # Create model
     if layers is None:
         layers = DEFAULT_LAYERS
     self.layers = layers
     self.preprocessors = preprocessors
     self.base_model = NeuralNetwork(self.scope,
                                     input_shapes=[self.obs_shape],
                                     layers=self.layers,
                                     preprocessors=self.preprocessors)
     self.mu = tf_v1.keras.layers.Dense(self.action_size, activation=None)(
         self.base_model.output)
     self.mu = tf_v1.clip_by_value(self.mu, *self.mu_range)
     if self.learn_std:
         self.log_std = tf_v1.keras.layers.Dense(self.action_size)(
             self.base_model.output)
         self.log_std = tf_v1.clip_by_value(self.log_std,
                                            *self.log_std_range)
         self.std = tf_v1.exp(self.log_std)
         self.raw_action_model = tf_v1.keras.Model(
             inputs=[self.base_model.input], outputs=[self.mu, self.std])
     else:
         self.std = tf_v1.constant([std_value] * self.action_size,
                                   dtype="float32")
         self.raw_action_model = tf_v1.keras.Model(
             inputs=[self.base_model.input], outputs=[self.mu])
     batch_size = tf_v1.shape(self.mu)[0]
     norm_dist = tfd.Normal(loc=tf_v1.zeros(self.action_size),
                            scale=tf_v1.ones(self.action_size))
     z = norm_dist.sample(batch_size)
     raw_actions = self.mu + z * self.std  # Reparameterization trick
     self.actions = tf_v1.tanh(raw_actions)
     self.deterministic_actions = tf_v1.tanh(self.mu)
     self.model = tf_v1.keras.Model(inputs=[self.base_model.input],
                                    outputs=[self.actions])
     # Loss parameters
     self._loss = None
     self.train_op = None
     # Summary parameters
     self.scalar_summaries += ("lr", )
     self.scalar_summaries_tf += ("loss", "mean_log_actions", "min_mu",
                                  "mean_mu", "max_mu", "min_std",
                                  "mean_std", "max_std")
     self.histogram_summaries_tf += ("actions", "mu", "std", "log_actions")
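A minimal eager-mode sketch of the sampling step above, with made-up shapes (a batch of 4 observations and a 2-dimensional action space); `tfd` is `tensorflow_probability.distributions`, as in the class:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

mu = tf.zeros([4, 2])          # made-up action means
std = tf.fill([4, 2], 0.1)     # made-up (fixed) standard deviation

# Reparameterization trick: sample eps ~ N(0, 1), then shift and scale it, so
# the action is a differentiable function of mu and std and gradients can flow
# back into the policy network.
eps = tfd.Normal(loc=tf.zeros_like(mu), scale=tf.ones_like(std)).sample()
raw_actions = mu + eps * std
actions = tf.tanh(raw_actions)  # squash into (-1, 1)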
Example #2
0
def generator(z,
              progress,
              num_filters_fn,
              resolution_schedule,
              num_blocks=None,
              kernel_size=3,
              colors=3,
              to_rgb_activation=None,
              simple_arch=False,
              scope='progressive_gan_generator',
              reuse=None):
  """Generator network for the progressive GAN model.

  Args:
    z: A `Tensor` of latent vector. The first dimension must be batch size.
    progress: A scalar float `Tensor` of training progress.
    num_filters_fn: A function that maps `block_id` to # of filters for the
        block.
    resolution_schedule: An object of `ResolutionSchedule`.
    num_blocks: An integer number of blocks. None means the maximum number of
        blocks, i.e. `resolution_schedule.num_resolutions`. Defaults to None.
    kernel_size: An integer of convolution kernel size.
    colors: Number of output color channels. Defaults to 3.
    to_rgb_activation: Activation function applied to the rgb output.
    simple_arch: If True, use a simpler architecture variant for lower memory
        usage and faster speed.
    scope: A string or variable scope.
    reuse: Whether to reuse `scope`. Defaults to None which means to inherit
        the reuse option of the parent scope.
  Returns:
    A `Tensor` of model output and a dictionary of model end points.
  """
  if num_blocks is None:
    num_blocks = resolution_schedule.num_resolutions

  start_h, start_w = resolution_schedule.start_resolutions
  final_h, final_w = resolution_schedule.final_resolutions

  def _conv2d(scope, x, kernel_size, filters, padding='SAME'):
    return layers.custom_conv2d(
        x=x,
        filters=filters,
        kernel_size=kernel_size,
        padding=padding,
        activation=lambda x: layers.pixel_norm(tf.nn.leaky_relu(x)),
        he_initializer_slope=0.0,
        scope=scope)

  def _to_rgb(x):
    return layers.custom_conv2d(
        x=x,
        filters=colors,
        kernel_size=1,
        padding='SAME',
        activation=to_rgb_activation,
        scope='to_rgb')

  he_init = contrib_layers.variance_scaling_initializer()

  end_points = {}

  with tf.variable_scope(scope, reuse=reuse):
    with tf.name_scope('input'):
      x = contrib_layers.flatten(z)
      end_points['latent_vector'] = x

    with tf.variable_scope(block_name(1)):
      if simple_arch:
        x_shape = tf.shape(x)
        x = tf.layers.dense(x, start_h*start_w*num_filters_fn(1),
                            kernel_initializer=he_init)
        x = tf.nn.relu(x)
        x = tf.reshape(x, [x_shape[0], start_h, start_w, num_filters_fn(1)])
      else:
        x = tf.expand_dims(tf.expand_dims(x, 1), 1)
        x = layers.pixel_norm(x)
        # Pad the 1 x 1 image to 2 * (start_h - 1) x 2 * (start_w - 1)
        # with zeros for the next conv.
        x = tf.pad(x, [[0] * 2, [start_h - 1] * 2, [start_w - 1] * 2, [0] * 2])
        # The output is start_h x start_w x num_filters_fn(1).
        x = _conv2d('conv0', x, (start_h, start_w), num_filters_fn(1), 'VALID')
        x = _conv2d('conv1', x, kernel_size, num_filters_fn(1))
      lods = [x]

    if resolution_schedule.scale_mode == 'H':
      strides = (resolution_schedule.scale_base, 1)
    else:
      strides = (resolution_schedule.scale_base,
                 resolution_schedule.scale_base)

    for block_id in range(2, num_blocks + 1):
      with tf.variable_scope(block_name(block_id)):
        if simple_arch:
          x = tf.layers.conv2d_transpose(
              x,
              num_filters_fn(block_id),
              kernel_size=kernel_size,
              strides=strides,
              padding='SAME',
              kernel_initializer=he_init)
          x = tf.nn.relu(x)
        else:
          x = resolution_schedule.upscale(x, resolution_schedule.scale_base)
          x = _conv2d('conv0', x, kernel_size, num_filters_fn(block_id))
          x = _conv2d('conv1', x, kernel_size, num_filters_fn(block_id))
        lods.append(x)

    outputs = []
    for block_id in range(1, num_blocks + 1):
      with tf.variable_scope(block_name(block_id)):
        if simple_arch:
          lod = lods[block_id - 1]
          lod = tf.layers.conv2d(
              lod,
              colors,
              kernel_size=1,
              padding='SAME',
              name='to_rgb',
              kernel_initializer=he_init)
          lod = to_rgb_activation(lod)
        else:
          lod = _to_rgb(lods[block_id - 1])
        scale = resolution_schedule.scale_factor(block_id)
        lod = resolution_schedule.upscale(lod, scale)
        end_points['upscaled_rgb_{}'.format(block_id)] = lod

        # alpha_i is used to replace lod_select. Note that sum(alpha_i) is
        # guaranteed to be 1.
        alpha = _generator_alpha(block_id, progress)
        end_points['alpha_{}'.format(block_id)] = alpha

        outputs.append(lod * alpha)

    predictions = tf.add_n(outputs)
    batch_size = z.shape[0].value
    predictions.set_shape([batch_size, final_h, final_w, colors])
    end_points['predictions'] = predictions

  return predictions, end_points
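The blending weights come from `_generator_alpha`, which is not shown in this snippet. One common formulation that satisfies the sum-to-1 property noted in the comment is a unit-width triangular window over `progress`; the sketch below illustrates that idea and is not necessarily the exact helper used here:

import tensorflow as tf

def generator_alpha_sketch(block_id, progress):
    # Peaks at progress == block_id - 1 and cross-fades linearly with the
    # neighbouring blocks, so the weights over all blocks sum to 1.
    return tf.maximum(
        0.0, tf.minimum(progress - (block_id - 2), block_id - progress))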
Example #3
0
def sample(logits):
    noise = tf.random_uniform(tf.shape(logits))
    return tf.argmax(logits - tf.log(-tf.log(noise)), 1)
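This is the Gumbel-max trick: -log(-log(u)) with u ~ Uniform(0, 1) is Gumbel noise, and the argmax of the perturbed logits is an exact sample from softmax(logits). A small TF1-style usage sketch, assuming the `sample` function above and made-up logits:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

logits = tf.constant([[1.0, 2.0, 0.5],
                      [0.1, 0.1, 3.0]])  # made-up [batch, num_classes] logits

with tf.Session() as sess:
    # One class index per row, distributed as softmax(logits) along axis 1.
    print(sess.run(sample(logits)))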
Example #4
0
  def __call__(self,
               mixed_features2d,
               cell_state,
               logits2d,
               is_training=False,
               policy="learned"):
    """Builds Saccader cell.

    Args:
      mixed_features2d: 4-D Tensor of shape [batch, height, width, channels].
      cell_state: 4-D Tensor of shape [batch, height, width, 1] with cell state.
      logits2d: 4-D Tensor of shape [batch, height, width, channels].
      is_training: (Boolean) To indicate training or inference modes.
      policy: (String) 'learned': uses learned policy, 'random': uses random
        policy, or 'center': uses center look policy.
    Returns:
      logits: Model logits.
      cell_state: New cell state.
      endpoints: Dictionary with cell parameters.
    """
    batch_size = tf.shape(mixed_features2d)[0]
    _, height, width, channels = mixed_features2d.shape.as_list()
    reuse = True if self.var_list else False
    position_channels = utils.position_channels(mixed_features2d)

    variables_before = set(tf.global_variables())
    with tf.variable_scope("saccader_cell", reuse=reuse):
      # Compute 2D weights of features across space.
      features_space_logits = tf.layers.dense(
          mixed_features2d, units=1,
          use_bias=False, name="attention_weights") / tf.math.sqrt(
              float(channels))

      features_space_logits += (cell_state * -1.e5)  # Mask used locations.
      features_space_weights = utils.softmax2d(features_space_logits)

      # Compute 1D weights of features across channels.
      features_channels_logits = tf.reduce_sum(
          mixed_features2d * features_space_weights, axis=[1, 2])
      features_channels_weights = tf.nn.softmax(
          features_channels_logits, axis=1)

      # Compute location probability.
      locations_logits2d = tf.reduce_sum(
          (mixed_features2d *
           features_channels_weights[:, tf.newaxis, tf.newaxis, :]),
          axis=-1, keepdims=True)

      locations_logits2d += (cell_state * -1e5)  # Mask used locations.
      locations_prob2d = utils.softmax2d(locations_logits2d)

    variables_after = set(tf.global_variables())
    # Compute best locations.
    locations_logits = tf.reshape(
        locations_logits2d, (batch_size, -1))
    all_positions = tf.reshape(
        position_channels, [batch_size, height*width, 2])

    best_locations_labels = tf.argmax(locations_logits, axis=-1)
    best_locations = utils.batch_gather_nd(
        all_positions, best_locations_labels, axis=1)

    # Sample locations.
    if policy == "learned":
      if is_training:
        # During training, sample locations from the learned distribution.
        dist = tfp.distributions.Categorical(logits=locations_logits)
        locations_labels = dist.sample()
        locations = utils.batch_gather_nd(
            all_positions, locations_labels, axis=1)
        # Ensures range [-1., 1.]
        locations = tf.clip_by_value(locations, -1., 1)
        tf.logging.info("Sampling locations.")
        tf.logging.info("==================================================")
      else:
        # At inference, use the most likely (argmax) location.
        locations = best_locations
        locations_labels = best_locations_labels
    elif policy == "random":
      # Use random policy for location.
      locations = tf.random_uniform(
          shape=(batch_size, 2),
          minval=-1.,
          maxval=1.)
      locations_labels = None
    elif policy == "center":
      # Use center look policy.
      locations = tf.zeros(
          shape=(batch_size, 2))
      locations_labels = None

    # Update cell_state.
    cell_state += utils.onehot2d(cell_state, locations)
    cell_state = tf.clip_by_value(cell_state, 0, 1)
    #########################################################################
    # Extract logits from the 2D logits.
    if self.soft_attention:
      logits = tf.reduce_sum(logits2d * locations_prob2d, axis=[1, 2])
    else:
      logits = gather_2d(logits2d, locations)
    ############################################################
    endpoints = {}
    endpoints["cell_outputs"] = {
        "locations": locations,
        "locations_labels": locations_labels,
        "best_locations": best_locations,
        "best_locations_labels": best_locations_labels,
        "locations_logits2d": locations_logits2d,
        "locations_prob2d": locations_prob2d,
        "cell_state": cell_state,
        "features_space_logits": features_space_logits,
        "features_space_weights": features_space_weights,
        "features_channels_logits": features_channels_logits,
        "features_channels_weights": features_channels_weights,
        "locations_logits": locations_logits,
        "all_positions": all_positions,
    }
    if not reuse:
      self.collect_variables(list(variables_after - variables_before))

    return logits, cell_state, endpoints
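A made-up 1-D illustration of the masking used above: adding a large negative constant to already-visited locations before the softmax drives their probability to essentially zero, so the cell does not re-select them:

import tensorflow as tf

logits = tf.constant([2.0, 1.0, 0.5])
visited = tf.constant([1.0, 0.0, 0.0])  # location 0 was selected earlier
masked = logits + visited * -1e5
probs = tf.nn.softmax(masked)           # ~[0.00, 0.62, 0.38]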
Example #5
0
    def _parse_function(*args):
        """Parses the tf example."""
        serialized_example = args[-1]

        context_feature_names = {
            dataset_descriptor.image_id: tf.FixedLenFeature([], tf.string),
        }
        sequence_feature_names = {}
        if flags.use_ref_exp:
            context_feature_names[REF_EXP_ID] = tf.FixedLenFeature([],
                                                                   tf.string)

        if flags.use_labels:
            if dataset_descriptor.has_candidate:
                context_feature_names[
                    SELECTED_CANDIDATE_ID] = tf.FixedLenFeature([], tf.int64)
                sequence_feature_names[
                    ELEMENTS_MASK_ID] = tf.FixedLenSequenceFeature([],
                                                                   tf.string)
            else:
                context_feature_names[
                    dataset_descriptor.label_id] = tf.FixedLenFeature(
                        [], tf.string)

        if dataset_descriptor.has_elements_boxes:
            sequence_feature_names[
                dataset_descriptor.
                elements_box_id] = tf.FixedLenSequenceFeature([4],
                                                              dtype=tf.float32)
        if flags.use_elements_texts:
            sequence_feature_names[
                dataset_descriptor.
                elements_text_id] = tf.FixedLenSequenceFeature([],
                                                               dtype=tf.string)
        if flags.use_elements_neighbors:
            sequence_feature_names[
                ELEMENTS_NEIGHBORS_ID] = tf.FixedLenSequenceFeature(
                    [], dtype=tf.string)
        if flags.use_elements_ref_match:
            sequence_feature_names[
                ELEMENTS_REF_MATCH_ID] = tf.FixedLenSequenceFeature(
                    [], dtype=tf.string)

        if flags.use_groundtruth_box:
            context_feature_names[GROUNDTRUTH_XMIN_ID] = tf.FixedLenFeature(
                [], tf.float32)
            context_feature_names[GROUNDTRUTH_XMAX_ID] = tf.FixedLenFeature(
                [], tf.float32)
            context_feature_names[GROUNDTRUTH_YMIN_ID] = tf.FixedLenFeature(
                [], tf.float32)
            context_feature_names[GROUNDTRUTH_YMAX_ID] = tf.FixedLenFeature(
                [], tf.float32)

        context_features, sequence_features = tf.parse_single_sequence_example(
            serialized_example,
            context_features=context_feature_names,
            sequence_features=sequence_feature_names,
        )

        features.update(context_features)
        features.update(sequence_features)

        if flags.use_elements_texts:
            features[ELEMENTS_TEXT_ID] = features.pop(
                dataset_descriptor.elements_text_id)
        if dataset_descriptor.has_elements_boxes:
            features[ELEMENTS_BOX_ID] = features.pop(
                dataset_descriptor.elements_box_id)

        image = features.pop(dataset_descriptor.image_id)
        image = tf.image.decode_image(image, channels=3)

        image = tf.cast(image, tf.float32)
        mean_pixel = tf.reshape(
            feature_extractor.mean_pixel(flags.model_variant), [1, 1, 3])

        features[IMAGE_PAD_WEIGHTS_ID] = tf.ones_like(image[:, :, 0:1])
        features[IMAGE_PAD_WEIGHTS_ID] = resize_im(
            features[IMAGE_PAD_WEIGHTS_ID], flags.image_size, 0, 1)
        features[IMAGE_PAD_WEIGHTS_ID] = tf.squeeze(
            features[IMAGE_PAD_WEIGHTS_ID], 2)

        if dataset_descriptor.has_elements_boxes:
            image = resize_im(image, flags.image_size, mean_pixel, 3, features)
        else:
            image = resize_im(image, flags.image_size, mean_pixel, 3)

        if flags.use_labels:
            if dataset_descriptor.has_candidate:
                features[ELEMENTS_MASK_ID] = tf.map_fn(
                    process_label,
                    features.pop(ELEMENTS_MASK_ID),
                    parallel_iterations=128,
                    dtype=tf.int32,
                    name="mask_map")
                features[LABEL_ID] = tf.gather_nd(
                    features[ELEMENTS_MASK_ID],
                    [features[SELECTED_CANDIDATE_ID]])
            else:
                label = features.pop(dataset_descriptor.label_id)
                label = process_label(label)
                features[LABEL_ID] = label

        if flags.use_elements_texts:
            features[ELEMENTS_EXIST_ID] = tf.ones_like(
                features[ELEMENTS_TEXT_ID], dtype=tf.int32)
        elif dataset_descriptor.has_elements_boxes:
            features[ELEMENTS_EXIST_ID] = tf.ones(tf.shape(
                features[ELEMENTS_BOX_ID])[:1],
                                                  dtype=tf.int32)

        if flags.use_elements_neighbors:
            features[ELEMENTS_NEIGHBORS_ID] = convert_string_neighbors(
                features[ELEMENTS_NEIGHBORS_ID])

        features[IMAGE_ID] = image

        return features
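A self-contained sketch of the `tf.parse_single_sequence_example` call above on a tiny hand-built `SequenceExample`; the feature names ('image/id', 'elements/box') are made up and stand in for the dataset-descriptor ids used in the real pipeline:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

example = tf.train.SequenceExample()
example.context.feature['image/id'].bytes_list.value.append(b'img_0')
boxes = example.feature_lists.feature_list['elements/box']
boxes.feature.add().float_list.value.extend([0.1, 0.1, 0.5, 0.5])
boxes.feature.add().float_list.value.extend([0.2, 0.2, 0.9, 0.9])

context, sequences = tf.parse_single_sequence_example(
    example.SerializeToString(),
    context_features={'image/id': tf.FixedLenFeature([], tf.string)},
    sequence_features={
        'elements/box': tf.FixedLenSequenceFeature([4], dtype=tf.float32)})
# context['image/id'] is a scalar string tensor; sequences['elements/box'] has
# shape [num_elements, 4].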
Example #6
0
def model_fn(features, labels, mode, params, config):
    """Builds the acoustic model."""
    del config
    hparams = params

    length = features.length
    spec = features.spec

    is_training = mode == tf.estimator.ModeKeys.TRAIN

    if is_training:
        onset_labels = labels.onsets
        offset_labels = labels.offsets
        velocity_labels = labels.velocities
        frame_labels = labels.labels
        frame_label_weights = labels.label_weights

    if hparams.stop_activation_gradient and not hparams.activation_loss:
        raise ValueError(
            'If stop_activation_gradient is true, activation_loss must be true.')

    losses = {}
    with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=is_training):
        with tf.variable_scope('onsets'):
            onset_outputs = acoustic_model(
                spec,
                hparams,
                lstm_units=hparams.onset_lstm_units,
                lengths=length,
                is_training=is_training)
            onset_probs = slim.fully_connected(
                onset_outputs,
                constants.MIDI_PITCHES,
                activation_fn=tf.sigmoid,
                scope='onset_probs')

            # onset_probs_flat is used during inference.
            onset_probs_flat = flatten_maybe_padded_sequences(onset_probs, length)
            if is_training:
                onset_labels_flat = flatten_maybe_padded_sequences(onset_labels, length)
                onset_losses = tf_utils.log_loss(onset_labels_flat, onset_probs_flat)
                tf.losses.add_loss(tf.reduce_mean(onset_losses))
                losses['onset'] = onset_losses
        with tf.variable_scope('offsets'):
            offset_outputs = acoustic_model(
                spec,
                hparams,
                lstm_units=hparams.offset_lstm_units,
                lengths=length,
                is_training=is_training)
            offset_probs = slim.fully_connected(
                offset_outputs,
                constants.MIDI_PITCHES,
                activation_fn=tf.sigmoid,
                scope='offset_probs')

            # offset_probs_flat is used during inference.
            offset_probs_flat = flatten_maybe_padded_sequences(offset_probs, length)
            if is_training:
                offset_labels_flat = flatten_maybe_padded_sequences(
                    offset_labels, length)
                offset_losses = tf_utils.log_loss(offset_labels_flat, offset_probs_flat)
                tf.losses.add_loss(tf.reduce_mean(offset_losses))
                losses['offset'] = offset_losses
        with tf.variable_scope('velocity'):
            velocity_outputs = acoustic_model(
                spec,
                hparams,
                lstm_units=hparams.velocity_lstm_units,
                lengths=length,
                is_training=is_training)
            velocity_values = slim.fully_connected(
                velocity_outputs,
                constants.MIDI_PITCHES,
                activation_fn=None,
                scope='onset_velocities')

            velocity_values_flat = flatten_maybe_padded_sequences(
                velocity_values, length)
            if is_training:
                velocity_labels_flat = flatten_maybe_padded_sequences(
                    velocity_labels, length)
                velocity_loss = tf.reduce_sum(
                    onset_labels_flat *
                    tf.square(velocity_labels_flat - velocity_values_flat),
                    axis=1)
                tf.losses.add_loss(tf.reduce_mean(velocity_loss))
                losses['velocity'] = velocity_loss

        with tf.variable_scope('frame'):
            if not hparams.share_conv_features:
                # TODO(eriche): this is broken when hparams.frame_lstm_units > 0
                activation_outputs = acoustic_model(
                    spec,
                    hparams,
                    lstm_units=hparams.frame_lstm_units,
                    lengths=length,
                    is_training=is_training)
                activation_probs = slim.fully_connected(
                    activation_outputs,
                    constants.MIDI_PITCHES,
                    activation_fn=tf.sigmoid,
                    scope='activation_probs')
            else:
                activation_probs = slim.fully_connected(
                    onset_outputs,
                    constants.MIDI_PITCHES,
                    activation_fn=tf.sigmoid,
                    scope='activation_probs')

            probs = []
            if hparams.stop_onset_gradient:
                probs.append(tf.stop_gradient(onset_probs))
            else:
                probs.append(onset_probs)

            if hparams.stop_activation_gradient:
                probs.append(tf.stop_gradient(activation_probs))
            else:
                probs.append(activation_probs)

            if hparams.stop_offset_gradient:
                probs.append(tf.stop_gradient(offset_probs))
            else:
                probs.append(offset_probs)

            combined_probs = tf.concat(probs, 2)

            if hparams.combined_lstm_units > 0:
                outputs = lstm_layer(
                    combined_probs,
                    hparams.batch_size,
                    hparams.combined_lstm_units,
                    lengths=length if hparams.use_lengths else None,
                    stack_size=hparams.combined_rnn_stack_size,
                    use_cudnn=hparams.use_cudnn,
                    is_training=is_training,
                    bidirectional=hparams.bidirectional)
            else:
                outputs = combined_probs

            frame_probs = slim.fully_connected(
                outputs,
                constants.MIDI_PITCHES,
                activation_fn=tf.sigmoid,
                scope='frame_probs')

        frame_probs_flat = flatten_maybe_padded_sequences(frame_probs, length)

        if is_training:
            frame_labels_flat = flatten_maybe_padded_sequences(frame_labels, length)
            frame_label_weights_flat = flatten_maybe_padded_sequences(
                frame_label_weights, length)
            if hparams.weight_frame_and_activation_loss:
                frame_loss_weights = frame_label_weights_flat
            else:
                frame_loss_weights = None
            frame_losses = tf_utils.log_loss(
                frame_labels_flat, frame_probs_flat, weights=frame_loss_weights)
            tf.losses.add_loss(tf.reduce_mean(frame_losses))
            losses['frame'] = frame_losses

            if hparams.activation_loss:
                if hparams.weight_frame_and_activation_loss:
                    activation_loss_weights = frame_label_weights
                else:
                    activation_loss_weights = None
                activation_losses = tf_utils.log_loss(
                    frame_labels_flat,
                    flatten_maybe_padded_sequences(activation_probs, length),
                    weights=activation_loss_weights)
                tf.losses.add_loss(tf.reduce_mean(activation_losses))
                losses['activation'] = activation_losses

    frame_predictions = frame_probs_flat > hparams.predict_frame_threshold
    onset_predictions = onset_probs_flat > hparams.predict_onset_threshold
    offset_predictions = offset_probs_flat > hparams.predict_offset_threshold

    frame_predictions = tf.expand_dims(frame_predictions, axis=0)
    onset_predictions = tf.expand_dims(onset_predictions, axis=0)
    offset_predictions = tf.expand_dims(offset_predictions, axis=0)
    velocity_values = tf.expand_dims(velocity_values_flat, axis=0)

    metrics_values = metrics.define_metrics(
        frame_probs=frame_probs,
        onset_probs=onset_probs,
        frame_predictions=frame_predictions,
        onset_predictions=onset_predictions,
        offset_predictions=offset_predictions,
        velocity_values=velocity_values,
        length=features.length,
        sequence_label=labels.note_sequence,
        frame_labels=labels.labels,
        sequence_id=features.sequence_id,
        hparams=hparams)

    for label, loss_collection in losses.items():
        loss_label = 'losses/' + label
        metrics_values[loss_label] = loss_collection

    def predict_sequence():
        """Convert frame predictions into a sequence (TF)."""

        def _predict(frame_probs, onset_probs, frame_predictions, onset_predictions,
                     offset_predictions, velocity_values):
            """Convert frame predictions into a sequence (Python)."""
            sequence = infer_util.predict_sequence(
                frame_probs=frame_probs,
                onset_probs=onset_probs,
                frame_predictions=frame_predictions,
                onset_predictions=onset_predictions,
                offset_predictions=offset_predictions,
                velocity_values=velocity_values,
                hparams=hparams,
                min_pitch=constants.MIN_MIDI_PITCH)
            return sequence.SerializeToString()

        sequence = tf.py_func(
            _predict,
            inp=[
                frame_probs[0],
                onset_probs[0],
                frame_predictions[0],
                onset_predictions[0],
                offset_predictions[0],
                velocity_values[0],
            ],
            Tout=tf.string,
            stateful=False)
        sequence.set_shape([])
        return tf.expand_dims(sequence, axis=0)

    predictions = {
        'frame_probs': tf.expand_dims(frame_probs_flat, axis=0),
        'frame_predictions': frame_predictions,
        'onset_predictions': onset_predictions,
        'offset_predictions': offset_predictions,
        'velocity_values': velocity_values,
        'sequence_predictions': predict_sequence(),
        # Include some features and labels in output because Estimator 'predict'
        # API does not give access to them.
        'sequence_ids': features.sequence_id,
        'sequence_labels': labels.note_sequence,
        'frame_labels': labels.labels,
    }
    for k, v in metrics_values.items():
        predictions[k] = tf.stack(v)

    metric_ops = {k: tf.metrics.mean(v) for k, v in metrics_values.items()}

    train_op = None
    loss = None
    if is_training:
        # Create pianoroll images with labels in the red channel and probs in
        # the green channel.
        images = {}
        onset_pianorolls = tf.concat([
            onset_labels[:, :, :, tf.newaxis], onset_probs[:, :, :, tf.newaxis],
            tf.zeros(tf.shape(onset_labels))[:, :, :, tf.newaxis]
        ],
            axis=3)
        images['OnsetPianorolls'] = onset_pianorolls
        offset_pianorolls = tf.concat([
            offset_labels[:, :, :, tf.newaxis], offset_probs[:, :, :, tf.newaxis],
            tf.zeros(tf.shape(offset_labels))[:, :, :, tf.newaxis]
        ],
            axis=3)
        images['OffsetPianorolls'] = offset_pianorolls
        activation_pianorolls = tf.concat([
            frame_labels[:, :, :, tf.newaxis], frame_probs[:, :, :, tf.newaxis],
            tf.zeros(tf.shape(frame_labels))[:, :, :, tf.newaxis]
        ],
            axis=3)
        images['ActivationPianorolls'] = activation_pianorolls
        for name, image in images.items():
            tf.summary.image(name, image)

        loss = tf.losses.get_total_loss()
        tf.summary.scalar('loss', loss)
        for label, loss_collection in losses.items():
            loss_label = 'losses/' + label
            tf.summary.scalar(loss_label, tf.reduce_mean(loss_collection))

        train_op = contrib_layers.optimize_loss(
            name='training',
            loss=loss,
            global_step=tf.train.get_or_create_global_step(),
            learning_rate=hparams.learning_rate,
            learning_rate_decay_fn=functools.partial(
                tf.train.exponential_decay,
                decay_steps=hparams.decay_steps,
                decay_rate=hparams.decay_rate,
                staircase=True),
            clip_gradients=hparams.clip_norm,
            optimizer='Adam')

    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions, loss=loss, train_op=train_op,
        eval_metric_ops=metric_ops)
Example #7
0
 def build_action_sampling(self):
     mean, logstd = self.parameters
     self.sample_ac = mean + tf.exp(logstd) * tf.random_normal(
         tf.shape(mean), 0, 1)
Example #8
0
def make_global_local_transformer_side_inputs_from_example_ids(
        long_example_ids: tf.Tensor,
        global_example_ids: tf.Tensor,
        sentence_ids: tf.Tensor,
        local_radius: int,
        relative_pos_max_distance: int,
        use_hard_g2l_mask: bool = False,
        use_hard_l2g_mask: bool = False,
        name: Optional[Text] = None) -> GlobalLocalTransformerSideInputs:
    """Makes side input tensors based on the given example and sentence ids.

  When packing examples (e.g. for pre-training), each example must have a
  unique id for `long_example_ids`/`global_example_ids`, and padding must
  also have a unique id distinct from all the example ids.

  When not packing examples, there will simply be two unique ids: one for
  example tokens, and another for padding.  Note that in this case, the classic
  BERT `input_mask` is a valid special case of `long_example_ids`.

  The other arguments have the same interpretation as in
  `make_global_local_transformer_side_inputs`.

  Args:
    long_example_ids: <int32>[batch_size, long_seq_len] Tensor of example ids of
      different packed examples.
    global_example_ids: <int32>[batch_size, global_seq_len] Tensor of example
      ids of different packed examples.
    sentence_ids: <int32>[batch_size, long_seq_len] Tensor of ids indicating
      which sentence each token belongs to. For this dataset, "sentence" refers
      to a real natural-language sentence, not a BERT "sentence" from the "next
      sentence prediction" task.
    local_radius: How many tokens to the left/right for input tokens to locally
      self-attend to. For example, a value of 1 would allow each token to only
      attend to 1 token to the left and 1 token to the right of it.
    relative_pos_max_distance: Maximum distance to use for relative position
      representations. All larger distances will be clipped to this value. Use 0
      to skip relative position representations entirely.
    use_hard_g2l_mask: If True, global tokens only attend to tokens of the
      corresponding sentences in the long input. If False, global tokens attend
      to all sentences within the corresponding global example.
    use_hard_l2g_mask: If True, long tokens only attend to tokens of the
      corresponding global tokens. If False, long tokens attend to all the
      global tokens within the corresponding global example.
    name: A name for the operation (optional).

  Returns:
    A `GlobalLocalTransformerSideInputs` with all relevant tensors set.
  """
    with tf.name_scope(name or 'make_global_local_transformer_side_inputs'):
        long_example_ids = tf.convert_to_tensor(long_example_ids)
        global_example_ids = tf.convert_to_tensor(global_example_ids)
        sentence_ids = tf.convert_to_tensor(sentence_ids)

        long_seq_len = tensor_utils.get_shape_list(long_example_ids)[1]
        global_seq_len = tensor_utils.get_shape_list(global_example_ids)[1]

        l2l_att_mask = feature_utils.make_local_segmented_att_mask(
            long_example_ids, local_radius)
        g2g_att_mask = feature_utils.make_segmented_att_mask(
            global_example_ids)

        l2g_att_mask = tf.cast(
            tf.equal(long_example_ids[:, :, tf.newaxis],
                     global_example_ids[:, tf.newaxis, :]), tf.int32)
        g2l_att_mask = tf.transpose(l2g_att_mask, perm=[0, 2, 1])

        if use_hard_g2l_mask:
            # Have each global token attend to just one sentence instead of having
            # it attend to all the sentences within a global example.
            global_range = tf.range(global_seq_len, dtype=sentence_ids.dtype)
            hard_g2l_att_mask = tf.cast(
                tf.equal(global_range[tf.newaxis, :, tf.newaxis],
                         sentence_ids[:, tf.newaxis, :]), tf.int32)
            g2l_att_mask *= hard_g2l_att_mask

        if use_hard_l2g_mask:
            # Have each long token attend to just the corresponding global token
            # instead of having it attend to all the global tokens within a
            # global example.
            global_range = tf.range(global_seq_len, dtype=sentence_ids.dtype)
            hard_l2g_att_mask = tf.cast(
                tf.equal(sentence_ids[:, :, tf.newaxis],
                         global_range[tf.newaxis, tf.newaxis, :]), tf.int32)
            l2g_att_mask *= hard_l2g_att_mask

        batch_size = tf.shape(long_example_ids)[0]

        l2l_relative_att_ids = None
        g2g_relative_att_ids = None
        l2g_relative_att_ids = None
        g2l_relative_att_ids = None

        if relative_pos_max_distance > 0:
            relative_pos_generator = feature_utils.RelativePositionGenerator(
                relative_pos_max_distance)
            l2l_relative_att_ids = relative_pos_generator.make_local_relative_att_ids(
                seq_len=long_seq_len,
                local_radius=local_radius,
                batch_size=batch_size)
            g2g_relative_att_ids = relative_pos_generator.make_relative_att_ids(
                seq_len=global_seq_len, batch_size=batch_size)
            global_range = tf.range(global_seq_len, dtype=sentence_ids.dtype)
            l2g_relative_att_ids = tf.cast(
                tf.equal(sentence_ids[:, :, tf.newaxis],
                         global_range[tf.newaxis, tf.newaxis, :]), tf.int32)
            g2l_relative_att_ids = tf.transpose(l2g_relative_att_ids,
                                                perm=[0, 2, 1])

            # For fused attention, l2l and l2g share the same relative vocabulary, as
            # do g2g and g2l, so we add an offset for l2g and g2l so their original
            # 0/1 ids don't collide with l2l and g2g relative position ids.
            l2g_relative_att_ids += relative_pos_generator.relative_vocab_size
            g2l_relative_att_ids += relative_pos_generator.relative_vocab_size

        return GlobalLocalTransformerSideInputs(
            l2l_att_mask=l2l_att_mask,
            g2g_att_mask=g2g_att_mask,
            l2g_att_mask=l2g_att_mask,
            g2l_att_mask=g2l_att_mask,
            l2l_relative_att_ids=l2l_relative_att_ids,
            g2g_relative_att_ids=g2g_relative_att_ids,
            l2g_relative_att_ids=l2g_relative_att_ids,
            g2l_relative_att_ids=g2l_relative_att_ids)
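A tiny eager sketch of the l2g mask construction above, with made-up packed ids: two examples (ids 1 and 2) packed into one long row, padding marked 0, and one global token per example:

import tensorflow as tf

long_example_ids = tf.constant([[1, 1, 2, 2, 0]])  # [batch=1, long_seq_len=5]
global_example_ids = tf.constant([[1, 2]])         # [batch=1, global_seq_len=2]

# A long token may attend to a global token only when both carry the same
# example id; padding (id 0) matches nothing, so its row is all zeros.
l2g_att_mask = tf.cast(
    tf.equal(long_example_ids[:, :, tf.newaxis],
             global_example_ids[:, tf.newaxis, :]), tf.int32)
# l2g_att_mask[0] == [[1, 0], [1, 0], [0, 1], [0, 1], [0, 0]]
g2l_att_mask = tf.transpose(l2g_att_mask, perm=[0, 2, 1])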
Example #9
0
 def image_summary_or_default_string(summary_name, image):
   """Returns image summaries for non-padded elements."""
   return tf.cond(
       tf.equal(tf.size(tf.shape(image)), 4),
       lambda: tf.summary.image(summary_name, image),
       lambda: tf.constant(''))
Example #10
0
def _generate_detections_v2(boxes,
                            scores,
                            max_total_size=100,
                            nms_iou_threshold=0.3,
                            score_threshold=0.05,
                            pre_nms_num_boxes=5000):
    """Generate the final detections given the model outputs.

  This unrolls over classes and uses while-loop-based NMS, which can be
  parallelized across the batch dimension.

  Args:
    boxes: a tensor with shape [batch_size, N, num_classes, 4] or [batch_size,
      N, 1, 4], which stacks box predictions from all feature levels. N is the
      number of total anchors on all levels.
    scores: a tensor with shape [batch_size, N, num_classes], which stacks class
      probability on all feature levels. The N is the number of total anchors on
      all levels. The num_classes is the number of classes predicted by the
      model. Note that the class_outputs here is the raw score.
    max_total_size: a scalar representing maximum number of boxes retained over
      all classes.
    nms_iou_threshold: a float representing the threshold for deciding whether
      boxes overlap too much with respect to IOU.
    score_threshold: a float representing the threshold for deciding when to
      remove boxes based on score.
    pre_nms_num_boxes: an int number of top candidate detections per class
      before NMS.

  Returns:
    nmsed_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
      representing top detected boxes in [y1, x1, y2, x2].
    nmsed_scores: `float` Tensor of shape [batch_size, max_total_size]
      representing sorted confidence scores for detected boxes. The values are
      between [0, 1].
    nmsed_classes: `int` Tensor of shape [batch_size, max_total_size]
      representing classes for detected boxes.
    valid_detections: `int` Tensor of shape [batch_size] only the top
      `valid_detections` boxes are valid detections.
  """
    with tf.name_scope('generate_detections'):
        # Normalizes maximum box coordinates to 1.
        normalizer = tf.reduce_max(boxes)
        boxes /= normalizer

        nmsed_boxes = []
        nmsed_classes = []
        nmsed_scores = []
        valid_detections = []
        num_classes_for_boxes = tf.shape(boxes)[2]
        total_anchors = tf.shape(scores)[1]
        num_classes = scores.get_shape().as_list()[2]
        # Selects top pre_nms_num scores and indices before NMS.
        scores, indices = _select_top_k_scores(
            scores, tf.minimum(total_anchors, pre_nms_num_boxes))
        for i in range(num_classes):
            boxes_i = boxes[:, :, tf.minimum(num_classes_for_boxes - 1, i), :]
            scores_i = scores[:, :, i]
            # Obtains pre_nms_num_boxes before running NMS.
            boxes_i = tf.gather(boxes_i,
                                indices[:, :, i],
                                batch_dims=1,
                                axis=1)

            # Filter out scores.
            boxes_i, scores_i = box_utils.filter_boxes_by_scores(
                boxes_i, scores_i, min_score_threshold=score_threshold)

            (nmsed_scores_i,
             nmsed_boxes_i) = nms.sorted_non_max_suppression_padded(
                 tf.cast(scores_i, tf.float32),
                 tf.cast(boxes_i, tf.float32),
                 max_total_size,
                 iou_threshold=nms_iou_threshold)
            nmsed_classes_i = tf.fill(tf.shape(nmsed_scores_i), i)
            nmsed_boxes.append(nmsed_boxes_i)
            nmsed_scores.append(nmsed_scores_i)
            nmsed_classes.append(nmsed_classes_i)
    nmsed_boxes = tf.concat(nmsed_boxes, axis=1)
    nmsed_scores = tf.concat(nmsed_scores, axis=1)
    nmsed_classes = tf.concat(nmsed_classes, axis=1)
    nmsed_scores, indices = tf.nn.top_k(nmsed_scores,
                                        k=max_total_size,
                                        sorted=True)
    nmsed_boxes = tf.gather(nmsed_boxes, indices, batch_dims=1, axis=1)
    nmsed_classes = tf.gather(nmsed_classes, indices, batch_dims=1)
    valid_detections = tf.reduce_sum(input_tensor=tf.cast(
        tf.greater(nmsed_scores, -1), tf.int32),
                                     axis=1)
    # De-normalizes box coordinates.
    nmsed_boxes *= normalizer
    return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
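A small eager example of the `tf.gather(..., batch_dims=1, axis=1)` call used above to pick per-image top-scoring boxes; the box values and indices are made up:

import tensorflow as tf

boxes = tf.reshape(tf.range(2 * 3 * 4, dtype=tf.float32), [2, 3, 4])  # [batch, N, 4]
indices = tf.constant([[2, 0],   # image 0: take box 2, then box 0
                       [1, 1]])  # image 1: take box 1 twice
picked = tf.gather(boxes, indices, batch_dims=1, axis=1)  # shape [2, 2, 4]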
Example #11
0
    def __call__(self, box_outputs, class_outputs, anchor_boxes, image_shape):
        """Generate final detections.

    Args:
      box_outputs: a tensor of shape [batch_size, K, num_classes * 4]
        representing the class-specific box coordinates relative to anchors.
      class_outputs: a tensor of shape [batch_size, K, num_classes]
        representing the class logits before applying the score activation.
      anchor_boxes: a tensor of shape [batch_size, K, 4] representing the
        corresponding anchor boxes w.r.t. `box_outputs`.
      image_shape: a tensor of shape [batch_size, 2] storing the image height
        and width w.r.t. the scaled image, i.e. the same image space as
        `box_outputs` and `anchor_boxes`.

    Returns:
      nmsed_boxes: `float` Tensor of shape [batch_size, max_total_size, 4]
        representing top detected boxes in [y1, x1, y2, x2].
      nmsed_scores: `float` Tensor of shape [batch_size, max_total_size]
        representing sorted confidence scores for detected boxes. The values are
        between [0, 1].
      nmsed_classes: `int` Tensor of shape [batch_size, max_total_size]
        representing classes for detected boxes.
      valid_detections: `int` Tensor of shape [batch_size] only the top
        `valid_detections` boxes are valid detections.
    """
        class_outputs = tf.nn.softmax(class_outputs, axis=-1)

        # Removes the background class.
        class_outputs_shape = tf.shape(class_outputs)
        num_locations = class_outputs_shape[1]
        num_classes = class_outputs_shape[-1]
        num_detections = num_locations * (num_classes - 1)

        class_outputs = tf.slice(class_outputs, [0, 0, 1], [-1, -1, -1])
        box_outputs = tf.reshape(box_outputs,
                                 [-1, num_locations, num_classes, 4])
        box_outputs = tf.slice(box_outputs, [0, 0, 1, 0], [-1, -1, -1, -1])
        anchor_boxes = tf.tile(tf.expand_dims(anchor_boxes, axis=2),
                               [1, 1, num_classes - 1, 1])
        box_outputs = tf.reshape(box_outputs, [-1, num_detections, 4])
        anchor_boxes = tf.reshape(anchor_boxes, [-1, num_detections, 4])

        # Box decoding.
        decoded_boxes = box_utils.decode_boxes(box_outputs,
                                               anchor_boxes,
                                               weights=[10.0, 10.0, 5.0, 5.0])

        # Box clipping
        decoded_boxes = box_utils.clip_boxes(decoded_boxes, image_shape)

        decoded_boxes = tf.reshape(decoded_boxes,
                                   [-1, num_locations, num_classes - 1, 4])

        if not self._apply_nms:
            return {
                'raw_boxes': decoded_boxes,
                'raw_scores': class_outputs,
            }

        nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections = (
            self._generate_detections(decoded_boxes, class_outputs))

        # Adds 1 to offset the background class which has index 0.
        nmsed_classes += 1

        return {
            'num_detections': valid_detections,
            'detection_boxes': nmsed_boxes,
            'detection_classes': nmsed_classes,
            'detection_scores': nmsed_scores,
        }
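A minimal eager sketch of the background-class removal above: with class 0 as the background, slicing off the first channel keeps only the foreground scores. The values are made up:

import tensorflow as tf

class_outputs = tf.constant([[[0.7, 0.2, 0.1],
                              [0.1, 0.3, 0.6]]])  # [batch=1, locations=2, 3 classes]
foreground = tf.slice(class_outputs, [0, 0, 1], [-1, -1, -1])  # [1, 2, 2], background dropped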
Example #12
0
def _generate_detections_per_image(boxes,
                                   scores,
                                   max_total_size=100,
                                   nms_iou_threshold=0.3,
                                   score_threshold=0.05,
                                   pre_nms_num_boxes=5000):
    """Generate the final detections per image given the model outputs.

  Args:
    boxes: a tensor with shape [N, num_classes, 4] or [N, 1, 4], which stacks
      box predictions from all feature levels. N is the number of total anchors
      on all levels.
    scores: a tensor with shape [N, num_classes], which stacks class probability
      on all feature levels. The N is the number of total anchors on all levels.
      The num_classes is the number of classes predicted by the model. Note that
      the class_outputs here is the raw score.
    max_total_size: a scalar representing maximum number of boxes retained over
      all classes.
    nms_iou_threshold: a float representing the threshold for deciding whether
      boxes overlap too much with respect to IOU.
    score_threshold: a float representing the threshold for deciding when to
      remove boxes based on score.
    pre_nms_num_boxes: an int number of top candidate detections per class
      before NMS.

  Returns:
    nmsed_boxes: `float` Tensor of shape [max_total_size, 4] representing top
      detected boxes in [y1, x1, y2, x2].
    nmsed_scores: `float` Tensor of shape [max_total_size] representing sorted
      confidence scores for detected boxes. The values are between [0, 1].
    nmsed_classes: `int` Tensor of shape [max_total_size] representing classes
      for detected boxes.
    valid_detections: `int` Tensor of shape [1] only the top `valid_detections`
      boxes are valid detections.
  """
    nmsed_boxes = []
    nmsed_scores = []
    nmsed_classes = []
    num_classes_for_box = boxes.get_shape().as_list()[1]
    num_classes = scores.get_shape().as_list()[1]
    for i in range(num_classes):
        boxes_i = boxes[:, min(num_classes_for_box - 1, i)]
        scores_i = scores[:, i]

        # Obtains pre_nms_num_boxes before running NMS.
        scores_i, indices = tf.nn.top_k(scores_i,
                                        k=tf.minimum(
                                            tf.shape(scores_i)[-1],
                                            pre_nms_num_boxes))
        boxes_i = tf.gather(boxes_i, indices)

        (nmsed_indices_i,
         nmsed_num_valid_i) = tf.image.non_max_suppression_padded(
             tf.cast(boxes_i, tf.float32),
             tf.cast(scores_i, tf.float32),
             max_total_size,
             iou_threshold=nms_iou_threshold,
             score_threshold=score_threshold,
             pad_to_max_output_size=True,
             name='nms_detections_' + str(i))
        nmsed_boxes_i = tf.gather(boxes_i, nmsed_indices_i)
        nmsed_scores_i = tf.gather(scores_i, nmsed_indices_i)
        # Sets scores of invalid boxes to -1.
        nmsed_scores_i = tf.where(
            tf.less(tf.range(max_total_size), [nmsed_num_valid_i]),
            nmsed_scores_i, -tf.ones_like(nmsed_scores_i))
        nmsed_classes_i = tf.fill([max_total_size], i)
        nmsed_boxes.append(nmsed_boxes_i)
        nmsed_scores.append(nmsed_scores_i)
        nmsed_classes.append(nmsed_classes_i)

    # Concats results from all classes and sort them.
    nmsed_boxes = tf.concat(nmsed_boxes, axis=0)
    nmsed_scores = tf.concat(nmsed_scores, axis=0)
    nmsed_classes = tf.concat(nmsed_classes, axis=0)
    nmsed_scores, indices = tf.nn.top_k(nmsed_scores,
                                        k=max_total_size,
                                        sorted=True)
    nmsed_boxes = tf.gather(nmsed_boxes, indices)
    nmsed_classes = tf.gather(nmsed_classes, indices)
    valid_detections = tf.reduce_sum(
        tf.cast(tf.greater(nmsed_scores, -1), tf.int32))
    return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
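A made-up illustration of the -1 sentinel convention above: padded NMS slots carry score -1, so counting entries greater than -1 yields the number of valid detections:

import tensorflow as tf

nmsed_scores = tf.constant([0.9, 0.7, 0.3, -1.0, -1.0])  # 3 real detections, 2 padded
valid_detections = tf.reduce_sum(
    tf.cast(tf.greater(nmsed_scores, -1), tf.int32))      # -> 3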
Example #13
0
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]
        is_real_example = None
        if "is_real_example" in features:
            is_real_example = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
        else:
            is_real_example = tf.ones(tf.shape(label_ids), dtype=tf.float32)

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        (total_loss, per_example_loss, probabilities, logits, predictions) = \
            create_model(
                albert_config,
                is_training,
                input_ids,
                input_mask,
                segment_ids,
                label_ids,
                num_labels,
                use_one_hot_embeddings,
                task_name,
                hub_module
            )

        tvars = tf.trainable_variables()
        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = modeling.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:

            train_op = optimization.create_optimizer(total_loss, learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps, use_tpu,
                                                     optimizer)

            output_spec = contrib_tpu.TPUEstimatorSpec(mode=mode,
                                                       loss=total_loss,
                                                       train_op=train_op,
                                                       scaffold_fn=scaffold_fn)
        elif mode == tf.estimator.ModeKeys.EVAL:
            if task_name not in ["sts-b", "cola", "nlpcc_dbqa"]:

                def metric_fn(per_example_loss, label_ids, logits,
                              is_real_example):
                    predictions = tf.argmax(logits,
                                            axis=-1,
                                            output_type=tf.int32)
                    accuracy = tf.metrics.accuracy(labels=label_ids,
                                                   predictions=predictions,
                                                   weights=is_real_example)
                    loss = tf.metrics.mean(values=per_example_loss,
                                           weights=is_real_example)
                    return {
                        "eval_accuracy": accuracy,
                        "eval_loss": loss,
                    }
            elif task_name == "sts-b":

                def metric_fn(per_example_loss, label_ids, logits,
                              is_real_example):
                    """Compute Pearson correlations for STS-B."""
                    # Display labels and predictions
                    concat1 = contrib_metrics.streaming_concat(logits)
                    concat2 = contrib_metrics.streaming_concat(label_ids)

                    # Compute Pearson correlation
                    pearson = contrib_metrics.streaming_pearson_correlation(
                        logits, label_ids, weights=is_real_example)

                    # Compute MSE
                    # mse = tf.metrics.mean(per_example_loss)
                    mse = tf.metrics.mean_squared_error(
                        label_ids, logits, weights=is_real_example)

                    loss = tf.metrics.mean(values=per_example_loss,
                                           weights=is_real_example)

                    return {
                        "pred": concat1,
                        "label_ids": concat2,
                        "pearson": pearson,
                        "MSE": mse,
                        "eval_loss": loss,
                    }
            elif task_name == "cola":

                def metric_fn(per_example_loss, label_ids, logits,
                              is_real_example):
                    """Compute Matthew's correlations for STS-B."""
                    predictions = tf.argmax(logits,
                                            axis=-1,
                                            output_type=tf.int32)
                    # https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
                    tp, tp_op = tf.metrics.true_positives(
                        predictions, label_ids, weights=is_real_example)
                    tn, tn_op = tf.metrics.true_negatives(
                        predictions, label_ids, weights=is_real_example)
                    fp, fp_op = tf.metrics.false_positives(
                        predictions, label_ids, weights=is_real_example)
                    fn, fn_op = tf.metrics.false_negatives(
                        predictions, label_ids, weights=is_real_example)

                    # Compute Matthew's correlation
                    mcc = tf.div_no_nan(
                        tp * tn - fp * fn,
                        tf.pow((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn),
                               0.5))

                    # Compute accuracy
                    accuracy = tf.metrics.accuracy(labels=label_ids,
                                                   predictions=predictions,
                                                   weights=is_real_example)

                    loss = tf.metrics.mean(values=per_example_loss,
                                           weights=is_real_example)

                    return {
                        "matthew_corr":
                        (mcc, tf.group(tp_op, tn_op, fp_op, fn_op)),
                        "eval_accuracy": accuracy,
                        "eval_loss": loss,
                    }

            elif task_name == "nlpcc_dbqa":

                def metric_fn(per_example_loss, label_ids, logits,
                              is_real_example):

                    predictions = tf.argmax(logits,
                                            axis=-1,
                                            output_type=tf.int32)
                    precision, precision_update_op = tf.metrics.precision(
                        labels=label_ids,
                        predictions=predictions,
                        weights=is_real_example,
                        name="precision")
                    recall, recall_update_op = tf.metrics.recall(
                        labels=label_ids,
                        predictions=predictions,
                        weights=is_real_example,
                        name='recall')

                    f1_score, f1_update_op = tf.metrics.mean(
                        (2 * (precision + 1e-7) *
                         (recall + 1e-7)) / (precision + recall + 2e-7),
                        name='f1_score')

                    # Compute accuracy
                    accuracy = tf.metrics.accuracy(labels=label_ids,
                                                   predictions=predictions,
                                                   weights=is_real_example)

                    loss = tf.metrics.mean(values=per_example_loss,
                                           weights=is_real_example)

                    return {
                        "precision": (precision, precision_update_op),
                        "recall": (recall, recall_update_op),
                        "f1_score": (f1_score, f1_update_op),
                        "eval_accuracy": accuracy,
                        "eval_loss": loss,
                    }

            eval_metrics = (metric_fn, [
                per_example_loss, label_ids, logits, is_real_example
            ])
            output_spec = contrib_tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            output_spec = contrib_tpu.TPUEstimatorSpec(mode=mode,
                                                       predictions={
                                                           "probabilities":
                                                           probabilities,
                                                           "predictions":
                                                           predictions
                                                       },
                                                       scaffold_fn=scaffold_fn)
        return output_spec
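The CoLA branch above assembles the Matthews correlation from streaming true/false positive and negative counts. As a sanity check, the same closed-form expression can be evaluated directly; this small NumPy sketch uses hypothetical confusion-matrix counts, not values from the code.

import numpy as np

tp, tn, fp, fn = 60.0, 25.0, 10.0, 5.0   # hypothetical counts
mcc = (tp * tn - fp * fn) / np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
print(round(mcc, 4))  # ~0.6634 for these counts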
Example #14
0
 def _to_constant_shape(tensor):
     tensor = tensor[:length]
     tensor = tf.pad(tensor, [(0, length - tf.shape(tensor)[0])])
     return tf.reshape(tensor, [length])
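In the snippet above, `length` is a closure variable from the enclosing scope. A minimal, self-contained sketch of the same truncate-or-pad pattern, with `length` made an explicit argument and run eagerly under TF 2.x (an assumption; the original is graph-mode code):

import tensorflow as tf

def to_constant_shape(tensor, length):
    tensor = tensor[:length]                                      # truncate if too long
    tensor = tf.pad(tensor, [(0, length - tf.shape(tensor)[0])])  # right-pad with zeros if too short
    return tf.reshape(tensor, [length])                           # pin the static shape

print(to_constant_shape(tf.constant([1, 2, 3, 4, 5, 6]), 4).numpy())  # [1 2 3 4]
print(to_constant_shape(tf.constant([1, 2]), 4).numpy())              # [1 2 0 0]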
Example #15
0
def test_compress(args):
    """Compresses an image."""

    # Load input image and add batch dimension.
    x = read_png(args.input_file)
    x = tf.expand_dims(x, 0)
    x.set_shape([1, None, None, 3])
    x_shape = tf.shape(x)

    net = DyTFC(192)
    net.build(x)

    sess = tf.Session()
    latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
    tf.train.Saver().restore(sess, save_path=latest)
    print(
        sess.run(
            tf.reduce_sum(tf.log(net.y_likelihoods), axis=(0, 1, 2)) /
            (-np.log(2) * net.num_pixels)))
    return

    #vnames = ['gaussian_conditional/quantized_cdf:0', 'gaussian_conditional/cdf_length:0']
    #old_cb_weights = net.conditional_bottleneck.get_weights()

    #print(old_cb_weights)
    #net.set_active(192)
    #net.build(x)
    #sess.run(tf.variables_initializer(get_uninitialized_variables(sess)))
    #sess.run(tf.variables_initializer([net.vst[name] for name in vnames]))
    #net.conditional_bottleneck.set_weights(old_cb_weights)

    #
    #tf.train.Saver().save(sess,"./sort128/model.ckpt")

    tensors = [
        net.string, net.side_string, net.x_shape[1:-1], net.y_shape[1:-1],
        net.z_shape[1:-1]
    ]

    arrays = sess.run(tensors)

    # Write a binary file with the shape information and the compressed string.
    packed = tfc.PackedTensors()
    packed.pack(tensors, arrays)
    with open(args.output_file, "wb") as f:
        f.write(packed.string)

    # If requested, transform the quantized image back and measure performance.
    if args.verbose:
        eval_bpp, mse, psnr, msssim, num_pixels = sess.run(
            [net.eval_bpp, net.mse, net.psnr, net.msssim, net.num_pixels])

        # The actual bits per pixel including overhead.
        bpp = len(packed.string) * 8 / num_pixels

        print("Mean squared error: {:0.4f}".format(mse))
        print("PSNR (dB): {:0.2f}".format(psnr))
        print("Multiscale SSIM: {:0.4f}".format(msssim))
        print("Multiscale SSIM (dB): {:0.2f}".format(-10 *
                                                     np.log10(1 - msssim)))
        print("Information content in bpp: {:0.4f}".format(eval_bpp))
        print("Actual bits per pixel: {:0.4f}".format(bpp))
Example #16
0
def train(train_list, val_list, debug_mode=True):
    print('Running PRLNet -Training!')
    # create folders to save trained model and results
    graph_dir = './graph'
    checkpt_dir = './model'
    ouput_dir = './output'
    exists_or_mkdir(graph_dir, need_remove=True)
    exists_or_mkdir(ouput_dir)
    exists_or_mkdir(checkpt_dir)

    # --------------------------------- load data ---------------------------------
    # data fetched at range: [-1,1]
    input_imgs, target_imgs, num = input_producer(train_list,
                                                  in_channels,
                                                  batch_size,
                                                  need_shuffle=True)
    if debug_mode:
        input_val, target_val, num_val = input_producer(val_list,
                                                        in_channels,
                                                        batch_size,
                                                        need_shuffle=False)

    pred_content, pred_detail, pred_imgs = gen_PRLNet(input_imgs,
                                                      out_channels,
                                                      is_train=True,
                                                      reuse=False)
    if debug_mode:
        _, _, pred_val = gen_PRLNet(input_val,
                                    out_channels,
                                    is_train=False,
                                    reuse=True)

    # --------------------------------- loss terms ---------------------------------
    with tf.name_scope('Loss') as loss_scp:
        target_224 = tf.image.resize_images(target_imgs,
                                            size=[224, 224],
                                            method=0,
                                            align_corners=False)
        predict_224 = tf.image.resize_images(pred_imgs,
                                             size=[224, 224],
                                             method=0,
                                             align_corners=False)
        vgg19_api = VGG19(
            "/project/dllau_uksr/mbgi222/InverseHalftoneExp6/Train_mode/vgg19.npy"
        )
        vgg_map_targets = vgg19_api.build((target_224 + 1) / 2,
                                          is_rgb=(in_channels == 3))
        vgg_map_predict = vgg19_api.build((predict_224 + 1) / 2,
                                          is_rgb=(in_channels == 3))

        content_loss = tf.losses.mean_squared_error(target_imgs, pred_content)
        vgg_loss = 2e-6 * tf.losses.mean_squared_error(vgg_map_targets,
                                                       vgg_map_predict)
        l1_loss = tf.reduce_mean(tf.abs(target_imgs - pred_imgs))
        mse_loss = tf.losses.mean_squared_error(target_imgs, pred_imgs)

        loss_op = content_loss + 2 * vgg_loss + l1_loss

    # --------------------------------- solver definition ---------------------------------
    global_step = tf.Variable(0, name='global_step', trainable=False)
    iters_per_epoch = np.floor_divide(num, batch_size)
    lr_decay = tf.train.polynomial_decay(
        learning_rate=learning_rate,
        global_step=global_step,
        decay_steps=iters_per_epoch * n_epochs,
        end_learning_rate=learning_rate / 100.0,
        power=0.9)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.name_scope('optimizer'):
        with tf.control_dependencies(update_ops):
            gen_vars = [
                var for var in tf.trainable_variables()
                if var.name.startswith("PRLNet")
            ]
            gen_optim = tf.train.AdamOptimizer(lr_decay, beta1)
            gen_grads_and_vars = gen_optim.compute_gradients(loss_op,
                                                             var_list=gen_vars)
            train_op = gen_optim.apply_gradients(gen_grads_and_vars,
                                                 global_step=global_step)

    # --------------------------------- model training ---------------------------------
    '''
    if debug_mode:
        with tf.name_scope('summarise') as sum_scope:
            tf.summary.scalar('loss', loss_op)
            tf.summary.scalar('learning rate', lr_decay)
            tf.summary.image('predicts', pred_imgs, max_outputs=9)
            summary_op = tf.summary.merge_all()
    '''

    with tf.name_scope("parameter_count"):
        num_parameters = tf.reduce_sum(
            [tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])

    # set GPU resources
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    #config.gpu_options.per_process_gpu_memory_fraction = 0.45

    saver = tf.train.Saver(max_to_keep=1)
    loss_list = []
    psnr_list = []
    with tf.Session(config=config) as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        sess.run(tf.global_variables_initializer())
        print(">>------------>>> [Training_Num] =%d" % num)
        print(">>------------>>> [Parameter_Num] =%d" %
              sess.run(num_parameters))
        '''
        if debug_mode:
            with tf.name_scope(sum_scope):
                summary_writer = tf.summary.FileWriter(graph_dir, graph=sess.graph)
        '''
        for epoch in range(0, n_epochs):
            start_time = time.time()
            epoch_loss, n_iters = 0, 0
            for step in range(0, num, batch_size):
                _, loss = sess.run([train_op, loss_op])
                epoch_loss += loss
                n_iters += 1
                # iteration information
                if n_iters % display_steps == 0:
                    tm = datetime.datetime.now().strftime(
                        '%Y-%m-%d %H:%M:%S.%f')
                    print("%s >> [%d/%d] iter: %d  loss: %4.4f" %
                          (tm, epoch, n_epochs, n_iters, loss))
                    '''
                    if debug_mode:
                        summary_str = sess.run(summary_op)
                        summary_writer.add_summary(summary_str, step)
                    '''

            # epoch information
            epoch_loss = epoch_loss / n_iters
            loss_list.append(epoch_loss)
            print(
                "[*] ----- Epoch: %d/%d | Loss: %4.4f | Time-consumed: %4.3f -----"
                % (epoch, n_epochs, epoch_loss, (time.time() - start_time)))

            if (epoch + 1) % save_epochs == 0:
                if debug_mode:
                    print("----- validating model ...")
                    mean_psnr, nn = 0, 0
                    for idx in range(0, num_val, batch_size):
                        predicts, groundtruths = sess.run(
                            [pred_val, target_val])
                        save_images_from_batch(predicts, ouput_dir, idx)
                        psnr = measure_psnr(predicts, groundtruths)
                        mean_psnr += psnr
                        nn += 1
                    psnr_list.append(mean_psnr / nn)
                    print("----- psnr:%4.4f" % (mean_psnr / nn))

                print("----- saving model  ...")
                saver.save(sess,
                           os.path.join(checkpt_dir, "model.cpkt"),
                           global_step=global_step)
                save_list(os.path.join(ouput_dir, "loss"), loss_list)
                save_list(os.path.join(ouput_dir, "psnr"), psnr_list)

        # stop data queue
        coord.request_stop()
        coord.join(threads)
        # write out the loss list
        save_list(os.path.join(ouput_dir, "loss"), loss_list)
        save_list(os.path.join(ouput_dir, "psnr"), psnr_list)
        print("Training finished!")

    return None
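The solver above anneals the learning rate with `tf.train.polynomial_decay` from `learning_rate` down to `learning_rate / 100` over all training steps with power 0.9. A plain-Python sketch of that schedule, with hypothetical values standing in for the module-level constants:

def poly_decay(step, lr0=1e-4, end_lr=1e-6, decay_steps=100_000, power=0.9):
    step = min(step, decay_steps)
    return (lr0 - end_lr) * (1.0 - step / decay_steps) ** power + end_lr

for s in (0, 50_000, 100_000):
    print(s, poly_decay(s))  # 1e-4 -> ~5.4e-05 -> 1e-06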
Example #17
0
def test_decompress(args):
    """Decompresses an image."""

    # Read the shape information and compressed string from the binary file.
    string = tf.placeholder(tf.string, [1])
    side_string = tf.placeholder(tf.string, [1])
    x_shape = tf.placeholder(tf.int32, [2])
    y_shape = tf.placeholder(tf.int32, [2])
    z_shape = tf.placeholder(tf.int32, [2])
    with open(args.input_file, "rb") as f:
        packed = tfc.PackedTensors(f.read())
    tensors = [string, side_string, x_shape, y_shape, z_shape]
    arrays = packed.unpack(tensors)

    # Instantiate model.
    synthesis_transform = SynthesisTransform(args.num_filters)
    hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters)
    entropy_bottleneck = tfc.EntropyBottleneck(dtype=tf.float32)

    # Decompress and transform the image back.
    z_shape = tf.concat([z_shape, [args.num_filters]], axis=0)
    z_hat = entropy_bottleneck.decompress(side_string,
                                          z_shape,
                                          channels=args.num_filters)
    sigma = hyper_synthesis_transform(z_hat)
    sigma = sigma[:, :y_shape[0], :y_shape[1], :]
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma,
                                                     scale_table,
                                                     dtype=tf.float32)
    y_hat_all = conditional_bottleneck.decompress(string)

    x = read_png("kodak/kodim01.png")
    x = tf.expand_dims(x, 0)
    x.set_shape([1, None, None, 3])
    x_shape = tf.shape(x)
    x *= 255

    active = 192
    y_hat = y_hat_all[:, :, :, :active]
    x_hat = synthesis_transform(y_hat)
    x_hat = tf.clip_by_value(x_hat, 0, 1)
    x_hat = tf.round(x_hat * 255)
    mse = tf.reduce_mean(tf.squared_difference(x, x_hat))
    psnr = tf.squeeze(tf.image.psnr(x_hat, x, 255))
    msssim = tf.squeeze(tf.image.ssim_multiscale(x_hat, x, 255))

    #x_hat = x_hat[0, :x_shape[0], :x_shape[1], :]
    #op = write_png(args.output_file, x_hat)

    sess = tf.Session()
    latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
    tf.train.Saver().restore(sess, save_path=latest)
    #sess.run(op, feed_dict=dict(zip(tensors, arrays)))

    #vmse, vpsnr, vmsssim = sess.run([mse, psnr, msssim], feed_dict=dict(zip(tensors, arrays)))
    #print(vmse, vpsnr, vmsssim)

    for active in range(192, 0, -8):
        y_hat = y_hat_all[:, :, :, :active]
        x_hat = synthesis_transform(y_hat)
        x_hat = tf.clip_by_value(x_hat, 0, 1)
        x_hat = tf.round(x_hat * 255)
        mse = tf.reduce_mean(tf.squared_difference(x, x_hat))
        psnr = tf.squeeze(tf.image.psnr(x_hat, x, 255))
        msssim = tf.squeeze(tf.image.ssim_multiscale(x_hat, x, 255))
        vmse, vpsnr, vmsssim = sess.run([mse, psnr, msssim],
                                        feed_dict=dict(zip(tensors, arrays)))
        print(active, vmse, vpsnr, vmsssim)
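For reference, the PSNR printed in the loop above relates to the MSE through PSNR = 10 * log10(255^2 / MSE) for 8-bit images. A quick numeric check with a hypothetical MSE value:

import numpy as np

mse = 42.0                               # hypothetical mean squared error
psnr = 10 * np.log10(255.0 ** 2 / mse)
print(round(psnr, 2))                    # ~31.9 dB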
Example #18
0
  def sample(self, n, max_length=None, z=None, c_input=None,
             **core_sampler_kwargs):
    """Sample from decoder with an optional conditional latent vector `z`.

    Args:
      n: Scalar number of samples to return.
      max_length: (Optional) maximum total length of samples. If given, must
        match `hparams.max_seq_len`.
      z: (Optional) Latent vectors to sample from. Required if model is
        conditional. Sized `[n, z_size]`.
      c_input: (Optional) Control sequence, sized `[max_length, control_depth]`.
      **core_sampler_kwargs: (Optional) Additional keyword arguments to pass to
        core sampler.
    Returns:
      samples: Sampled sequences with concatenated, possibly padded segments.
         Sized `[n, max_length, output_depth]`.
      decoder_results: The merged LstmDecodeResults from sampling.
    Raises:
      ValueError: If `z` is provided and its first dimension does not equal `n`,
        or if `c_input` is provided in re-encoder mode.
    """
    if z is not None and int(z.shape[0]) != n:
      raise ValueError(
          '`z` must have a first dimension that equals `n` when given. '
          'Got: %d vs %d' % (z.shape[0], n))
    z = tf.zeros([n, 0]) if z is None else z

    if self._hierarchical_encoder and c_input is not None:
      raise ValueError(
          'Re-encoder mode unsupported when conditioning on controls.')

    if max_length is not None:
      with tf.control_dependencies([
          tf.assert_equal(
              max_length, self._total_length,
              message='`max_length` must equal `hparams.max_seq_len` if given.')
      ]):
        max_length = tf.identity(max_length)

    if c_input is not None:
      # Reshape control sequence to hierarchy.
      c_input = tf.squeeze(
          self._reshape_to_hierarchy(tf.expand_dims(c_input, 0)),
          axis=len(self._level_lengths) - 1)

    core_max_length = self._level_lengths[-1]
    all_samples = []
    all_decode_results = []

    def base_sample_fn(embedding, hier_index):
      """Base function for sampling hierarchical decoder."""
      samples, decode_results = self._core_decoder.sample(
          n,
          max_length=core_max_length,
          z=embedding,
          c_input=c_input[hier_index] if c_input is not None else None,
          start_inputs=all_samples[-1][:, -1] if all_samples else None,
          **core_sampler_kwargs)
      all_samples.append(samples)
      all_decode_results.append(decode_results)
      if self._hierarchical_encoder:
        return self._hierarchical_encoder.level(0).encode(
            samples,
            decode_results.final_sequence_lengths)
      else:
        return tf.concat(tf.nest.flatten(decode_results.final_state), axis=-1)

    # Populate `all_samples` and `all_decode_results`.
    self._hierarchical_decode(z, base_sample_fn)

    all_samples = tf.concat(
        [tf.pad(s, [(0, 0), (0, core_max_length - tf.shape(s)[1]), (0, 0)])
         for s in all_samples],
        axis=1)
    return all_samples, self._merge_decode_results(all_decode_results)
Example #19
0
    def decode(self, serialized_example):
        """Decode the serialized example.

    Args:
      serialized_example: a single serialized tf.Example string.

    Returns:
      decoded_tensors: a dictionary of tensors with the following fields:
        - image: a uint8 tensor of shape [None, None, 3].
        - source_id: a string scalar tensor.
        - height: an integer scalar tensor.
        - width: an integer scalar tensor.
        - groundtruth_classes: an int64 tensor of shape [None].
        - groundtruth_is_crowd: a bool tensor of shape [None].
        - groundtruth_area: a float32 tensor of shape [None].
        - groundtruth_boxes: a float32 tensor of shape [None, 4].
        - groundtruth_instance_masks: a float32 tensor of shape
            [None, None, None].
        - groundtruth_instance_masks_png: a string tensor of shape [None].
    """
        parsed_tensors = tf.io.parse_single_example(serialized_example,
                                                    self._keys_to_features)
        for k in parsed_tensors:
            if isinstance(parsed_tensors[k], tf.SparseTensor):
                if parsed_tensors[k].dtype == tf.string:
                    parsed_tensors[k] = tf.sparse_tensor_to_dense(
                        parsed_tensors[k], default_value='')
                else:
                    parsed_tensors[k] = tf.sparse_tensor_to_dense(
                        parsed_tensors[k], default_value=0)

        image = self._decode_image(parsed_tensors)
        boxes = self._decode_boxes(parsed_tensors)
        areas = self._decode_areas(parsed_tensors)

        decode_image_shape = tf.logical_or(
            tf.equal(parsed_tensors['image/height'], -1),
            tf.equal(parsed_tensors['image/width'], -1))
        image_shape = tf.cast(tf.shape(image), dtype=tf.int64)

        parsed_tensors['image/height'] = tf.where(
            decode_image_shape, image_shape[0], parsed_tensors['image/height'])
        parsed_tensors['image/width'] = tf.where(decode_image_shape,
                                                 image_shape[1],
                                                 parsed_tensors['image/width'])

        is_crowds = tf.cond(
            tf.greater(tf.shape(parsed_tensors['image/object/is_crowd'])[0], 0),
            lambda: tf.cast(parsed_tensors['image/object/is_crowd'], dtype=tf.bool),
            lambda: tf.zeros_like(parsed_tensors['image/object/class/label'], dtype=tf.bool))  # pylint: disable=line-too-long
        if self._regenerate_source_id:
            source_id = _get_source_id_from_encoded_image(parsed_tensors)
        else:
            source_id = tf.cond(
                tf.greater(
                    tf.strings.length(parsed_tensors['image/source_id']),
                    0), lambda: parsed_tensors['image/source_id'],
                lambda: _get_source_id_from_encoded_image(parsed_tensors))
        if self._include_mask:
            masks = self._decode_masks(parsed_tensors)

        decoded_tensors = {
            'image': image,
            'source_id': source_id,
            'height': parsed_tensors['image/height'],
            'width': parsed_tensors['image/width'],
            'groundtruth_classes': parsed_tensors['image/object/class/label'],
            'groundtruth_is_crowd': is_crowds,
            'groundtruth_area': areas,
            'groundtruth_boxes': boxes,
        }
        if self._include_mask:
            decoded_tensors.update({
                'groundtruth_instance_masks':
                masks,
                'groundtruth_instance_masks_png':
                parsed_tensors['image/object/mask'],
            })
        return decoded_tensors
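The decoder above densifies every `SparseTensor` returned by `tf.io.parse_single_example` before further processing. A self-contained sketch of just that step, using TF 2.x API names (`tf.sparse.to_dense` in place of the older `tf.sparse_tensor_to_dense`) and the class-label feature key taken from the code above:

import tensorflow as tf

example = tf.train.Example(features=tf.train.Features(feature={
    "image/object/class/label": tf.train.Feature(
        int64_list=tf.train.Int64List(value=[1, 7, 7])),
}))
parsed = tf.io.parse_single_example(
    example.SerializeToString(),
    {"image/object/class/label": tf.io.VarLenFeature(tf.int64)})
labels = tf.sparse.to_dense(parsed["image/object/class/label"], default_value=0)
print(labels.numpy())  # [1 7 7]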
Example #20
0
  def reconstruction_loss(self, x_input, x_target, x_length, z=None,
                          c_input=None):
    """Reconstruction loss calculation.

    Args:
      x_input: Batch of decoder input sequences for teacher forcing, sized
        `[batch_size, max(x_length), output_depth]`.
      x_target: Batch of expected output sequences to compute loss against,
        sized `[batch_size, max(x_length), output_depth]`.
      x_length: Length of input/output sequences, sized `[batch_size]`.
      z: (Optional) Latent vectors. Required if model is conditional. Sized
        `[n, z_size]`.
      c_input: (Optional) Batch of control sequences, sized
          `[batch_size, max(x_length), control_depth]`. Required if conditioning
          on control sequences.

    Returns:
      r_loss: The reconstruction loss for each sequence in the batch.
      metric_map: Map from metric name to tf.metrics return values for logging.
      decode_results: The LstmDecodeResults.
    """
    batch_size = int(x_input.shape[0])

    has_z = z is not None
    z = tf.zeros([batch_size, 0]) if z is None else z
    repeated_z = tf.tile(
        tf.expand_dims(z, axis=1), [1, tf.shape(x_input)[1], 1])

    has_control = c_input is not None
    if c_input is None:
      c_input = tf.zeros([batch_size, tf.shape(x_input)[1], 0])

    sampling_probability_static = tf.get_static_value(
        self._sampling_probability)
    if sampling_probability_static == 0.0:
      # Use teacher forcing.
      x_input = tf.concat([x_input, repeated_z, c_input], axis=2)
      helper = contrib_seq2seq.TrainingHelper(x_input, x_length)
    else:
      # Use scheduled sampling.
      if has_z or has_control:
        auxiliary_inputs = tf.zeros([batch_size, tf.shape(x_input)[1], 0])
        if has_z:
          auxiliary_inputs = tf.concat([auxiliary_inputs, repeated_z], axis=2)
        if has_control:
          auxiliary_inputs = tf.concat([auxiliary_inputs, c_input], axis=2)
      else:
        auxiliary_inputs = None
      helper = contrib_seq2seq.ScheduledOutputTrainingHelper(
          inputs=x_input,
          sequence_length=x_length,
          auxiliary_inputs=auxiliary_inputs,
          sampling_probability=self._sampling_probability,
          next_inputs_fn=self._sample)

    decode_results = self._decode(
        z, helper=helper, input_shape=helper.inputs.shape[2:])
    flat_x_target = flatten_maybe_padded_sequences(x_target, x_length)
    flat_rnn_output = flatten_maybe_padded_sequences(
        decode_results.rnn_output, x_length)
    r_loss, metric_map = self._flat_reconstruction_loss(
        flat_x_target, flat_rnn_output)

    # Sum loss over sequences.
    cum_x_len = tf.concat([(0,), tf.cumsum(x_length)], axis=0)
    r_losses = []
    for i in range(batch_size):
      b, e = cum_x_len[i], cum_x_len[i + 1]
      r_losses.append(tf.reduce_sum(r_loss[b:e]))
    r_loss = tf.stack(r_losses)

    return r_loss, metric_map, decode_results
Example #21
0
def compress(args):
  """Compresses an image."""

  # Load input image and add batch dimension.
  x = read_png(args.input_file)
  x = tf.expand_dims(x, 0)
  x.set_shape([1, None, None, 3])
  x_shape = tf.shape(x)

  # Instantiate model.
  analysis_transform = AnalysisTransform(args.num_filters)
  synthesis_transform = SynthesisTransform(args.num_filters)
  hyper_analysis_transform = HyperAnalysisTransform(args.num_filters)
  hyper_synthesis_transform = HyperSynthesisTransform(args.num_filters)
  entropy_bottleneck = tfc.EntropyBottleneck()

  # Transform and compress the image.
  y = analysis_transform(x)
  y_shape = tf.shape(y)
  z = hyper_analysis_transform(abs(y))
  z_hat, z_likelihoods = entropy_bottleneck(z, training=False)
  sigma = hyper_synthesis_transform(z_hat)
  sigma = sigma[:, :y_shape[1], :y_shape[2], :]
  scale_table = np.exp(np.linspace(
      np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
  conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table)
  side_string = entropy_bottleneck.compress(z)
  string = conditional_bottleneck.compress(y)

  # Transform the quantized image back (if requested).
  y_hat, y_likelihoods = conditional_bottleneck(y, training=False)
  x_hat = synthesis_transform(y_hat)
  x_hat = x_hat[:, :x_shape[1], :x_shape[2], :]

  num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), dtype=tf.float32)

  # Total number of bits divided by number of pixels.
  eval_bpp = (tf.reduce_sum(tf.log(y_likelihoods)) +
              tf.reduce_sum(tf.log(z_likelihoods))) / (-np.log(2) * num_pixels)

  # Bring both images back to 0..255 range.
  x *= 255
  x_hat = tf.clip_by_value(x_hat, 0, 1)
  x_hat = tf.round(x_hat * 255)

  mse = tf.reduce_mean(tf.squared_difference(x, x_hat))
  psnr = tf.squeeze(tf.image.psnr(x_hat, x, 255))
  msssim = tf.squeeze(tf.image.ssim_multiscale(x_hat, x, 255))

  with tf.Session() as sess:
    # Load the latest model checkpoint, get the compressed string and the tensor
    # shapes.
    latest = tf.train.latest_checkpoint(checkpoint_dir=args.checkpoint_dir)
    tf.train.Saver().restore(sess, save_path=latest)
    tensors = [string, side_string,
               tf.shape(x)[1:-1], tf.shape(y)[1:-1], tf.shape(z)[1:-1]]
    arrays = sess.run(tensors)

    # Write a binary file with the shape information and the compressed string.
    packed = tfc.PackedTensors()
    packed.pack(tensors, arrays)
    with open(args.output_file, "wb") as f:
      f.write(packed.string)

    # If requested, transform the quantized image back and measure performance.
    if args.verbose:
      eval_bpp, mse, psnr, msssim, num_pixels = sess.run(
          [eval_bpp, mse, psnr, msssim, num_pixels])

      # The actual bits per pixel including overhead.
      bpp = len(packed.string) * 8 / num_pixels

      print("Mean squared error: {:0.4f}".format(mse))
      print("PSNR (dB): {:0.2f}".format(psnr))
      print("Multiscale SSIM: {:0.4f}".format(msssim))
      print("Multiscale SSIM (dB): {:0.2f}".format(-10 * np.log10(1 - msssim)))
      print("Information content in bpp: {:0.4f}".format(eval_bpp))
      print("Actual bits per pixel: {:0.4f}".format(bpp))
Example #22
0

caps1_output = squash(caps1_raw, name="caps1_output")
"""# Final Capsules

## Compute the Predicted Output Vectors
"""

W_init = tf.random_normal(shape=(1, caps1_n_caps, caps2_n_caps, caps2_n_dims,
                                 caps1_n_dims),
                          stddev=init_sigma,
                          dtype=tf.float32,
                          name="W_init")
W = tf.Variable(W_init, name="W")

batch_size = tf.shape(X)[0]
W_tiled = tf.tile(W, [batch_size, 1, 1, 1, 1], name="W_tiled")

caps1_output_expanded = tf.expand_dims(caps1_output,
                                       -1,
                                       name="caps1_output_expanded")
caps1_output_tile = tf.expand_dims(caps1_output_expanded,
                                   2,
                                   name="caps1_output_tile")
caps1_output_tiled = tf.tile(caps1_output_tile, [1, 1, caps2_n_caps, 1, 1],
                             name="caps1_output_tiled")

caps2_predicted = tf.matmul(W_tiled,
                            caps1_output_tiled,
                            name="caps2_predicted")
"""## Routing by agreement
Example #23
0
  def __call__(self,
               images,
               num_times,
               is_training=False,
               policy="learned",
               stop_gradient_after_representation=False):

    endpoints = {}
    reuse = True if self.var_list else False
    with tf.variable_scope(self.variable_scope+"/representation_network",
                           reuse=reuse):
      representation_logits, endpoints_ = self.representation_network(
          images, is_training)

    if not self.var_list_representation_network:
      self.var_list_representation_network = self.representation_network.var_list

    self.glimpse_shape = self.representation_network.receptive_field
    glimpse_size = tf.cast(self.glimpse_shape[0], dtype=tf.float32)
    image_size = tf.cast(images.shape.as_list()[1], dtype=tf.float32)
    # Ensure glimpses within image.
    location_scale = 1. - glimpse_size / image_size
    endpoints["location_scale"] = location_scale
    endpoints["representation_network"] = endpoints_
    endpoints["representation_network"]["logits"] = representation_logits
    features2d = endpoints_["features2d"]  # size [batch, 28, 28, 2048]
    logits2d = endpoints_["logits2d"]  # size [batch, 28, 28, 1001]
    what_features2d = endpoints_["features2d_lowd"]
    endpoints["logits2d"] = logits2d
    endpoints["features2d"] = features2d
    endpoints["what_features2d"] = what_features2d

    # Freeze the representation network weights.
    if stop_gradient_after_representation:
      features2d = tf.stop_gradient(features2d)
      logits2d = tf.stop_gradient(logits2d)
      what_features2d = tf.stop_gradient(what_features2d)

    # Attention network.
    variables_before = set(tf.global_variables())
    with tf.variable_scope(self.variable_scope, reuse=reuse):
      where_features2d = build_attention_network(
          features2d,
          self.config.attention_groups,
          self.config.attention_layers_per_group,
          is_training)
      endpoints["where_features2d"] = where_features2d
      # Mix what and where features.
      mixed_features2d = tf.layers.conv2d(
          tf.concat([where_features2d, what_features2d], axis=-1),
          filters=512,
          kernel_size=1,
          strides=1,
          activation=None,
          use_bias=True,
          name="mixed_features2d",
          padding="same")
    endpoints["mixed_features2d"] = mixed_features2d
    variables_after = set(tf.global_variables())
    if not self.var_list_attention_network:
      self.var_list_attention_network = list(
          variables_after - variables_before)
    # Unrolling the model in time.
    classification_logits_t = []
    locations_t = []
    best_locations_t = []
    locations_logits2d_t = []
    batch_size = tf.shape(mixed_features2d)[0]
    _, height, width, _ = mixed_features2d.shape.as_list()
    cell_state = tf.zeros((batch_size, height, width, 1), dtype=tf.float32)
    # Engineered policies.
    if policy in ["ordered_logits", "sobel_mean", "sobel_var"]:
      locations_t = engineered_policies(
          images,
          logits2d,
          utils.position_channels(logits2d) * location_scale,
          self.glimpse_shape,
          num_times,
          policy)

      best_locations_t = locations_t
      classification_logits_t = [
          gather_2d(logits2d, locations / location_scale)
          for locations in locations_t]
      # Run once just to create the variables (the output is unused).
      with tf.name_scope("time%d" % 0):
        with tf.variable_scope(self.variable_scope):
          self.saccader_cell(
              mixed_features2d,
              cell_state,
              logits2d,
              is_training=is_training,
              policy="random")

    # Other policies
    elif policy in ["learned", "random", "center"]:
      for t in range(num_times):
        endpoints["time%d" % t] = {}
        with tf.name_scope("time%d" % t):
          with tf.variable_scope(self.variable_scope):
            logits, cell_state, endpoints_ = self.saccader_cell(
                mixed_features2d,
                cell_state,
                logits2d,
                is_training=is_training,
                policy=policy)
          cell_outputs = endpoints_["cell_outputs"]
          endpoints["time%d" % t].update(endpoints_)
          classification_logits_t.append(logits)
          # Convert to center glimpse location on images space.
          locations_t.append(cell_outputs["locations"] * location_scale)
          best_locations_t.append(
              cell_outputs["best_locations"] * location_scale)
          locations_logits2d_t.append(cell_outputs["locations_logits2d"])
      endpoints["locations_logits2d_t"] = locations_logits2d_t
    else:
      raise ValueError(
          "policy can be either 'learned', 'random', or 'center'")

    if not self.var_list_saccader_cell:
      self.var_list_saccader_cell = self.saccader_cell.var_list
    self.collect_variables()
    endpoints["classification_logits_t"] = classification_logits_t
    logits = tf.reduce_mean(classification_logits_t, axis=0)
    return (logits, locations_t, best_locations_t, endpoints)
Example #24
0
    def compute_trace(x, Gx, num_trace_samples=2, num_power_series_terms=2):

        u_shape = tf.shape(x)
        u_shape = tf.concat([u_shape, [num_trace_samples]], axis=0)

        # shape (batch_size, height, width, num_channel, num_sample)
        u = tf.random.normal(u_shape)

        def loop_trace_samples(n, trace_total):
            """

      :param n: loop over n samples
      :param trace_total: (batch_size, )
      :return:
      """

            # shape (batch_size, h*w*c, 1)
            u_reshaped = tf.reshape(u[..., n], (u_shape[0], -1, 1))

            def loop_series_terms(k, output_grads, trace):
                """

        :param k: loop over k terms
        :param output_grads: (batch_size, height, width, num_channel)
        :param trace: (batch_size,)
        :return:
        """
                # shape (batch_size, height, width, num_channel)
                grads = tf.gradients(Gx, x, output_grads)[0]
                # shape (batch_size, 1, h*w*c)
                grads_reshaped = tf.reshape(grads, (u_shape[0], 1, -1))


                sign = tf.cond(tf.equal(k % 2, 0), lambda: 1.0, lambda: -1.0)
                trace = trace + tf.squeeze(
                    sign * tf.matmul(grads_reshaped, u_reshaped) /
                    tf.cast(k + 1, tf.float32),
                    axis=[1, 2])
                return k + 1, grads, trace

            _, _, trace_by_sample = tf.while_loop(
                cond=lambda k, _1, _2: k < num_power_series_terms,
                body=loop_series_terms,
                loop_vars=[
                    tf.constant(0, dtype=tf.int32), u[..., n],
                    tf.zeros(shape=u_shape[0])
                ])

            return n + 1, trace_total + trace_by_sample

        _, trace_all_samples = tf.while_loop(
            cond=lambda n, _: n < num_trace_samples,
            body=loop_trace_samples,
            loop_vars=[
                tf.constant(0, dtype=tf.int32),
                tf.zeros(shape=u_shape[0])
            ],
            shape_invariants=[tf.TensorShape(None),
                              tf.TensorShape([None])])

        # shape (batch_size, )
        trace_all_samples = trace_all_samples / num_trace_samples

        return trace_all_samples
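The nested loops above implement a Hutchinson-style stochastic trace estimator applied to the Jacobian of `Gx`: trace(A) ≈ E[uᵀ A u] for u ~ N(0, I), averaged over `num_trace_samples` probes and accumulated over `num_power_series_terms` power-series terms. A minimal NumPy sketch of the underlying trace estimate on an explicit matrix:

import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(50, 50))

num_samples = 10_000
u = rng.normal(size=(50, num_samples))            # one Gaussian probe per column
estimate = np.mean(np.sum(u * (A @ u), axis=0))   # mean of u^T A u
print(estimate, np.trace(A))                      # the two values should be close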
Example #25
0
def run(master, input_dataset_class, common_module, keypoint_profiles_module,
        models_module, input_example_parser_creator, keypoint_preprocessor_3d,
        create_model_input_fn, keypoint_distance_config_override,
        embedder_fn_kwargs):
  """Runs training pipeline.

  Args:
    master: BNS name of the TensorFlow master to use.
    input_dataset_class: An input dataset class that matches input table type.
    common_module: A Python module that defines common flags and constants.
    keypoint_profiles_module: A Python module that defines keypoint profiles.
    models_module: A Python module that defines base model architectures.
    input_example_parser_creator: A function handle for creating data parser
      function. If None, uses the default parser creator.
    keypoint_preprocessor_3d: A function handle for preprocessing raw 3D
      keypoints.
    create_model_input_fn: A function handle for creating model inputs.
    keypoint_distance_config_override: A dictionary for keypoint distance
      configuration to override the defaults. Ignored if empty.
    embedder_fn_kwargs: A dictionary of additional kwargs for creating the
      embedder function.
  """
  configs = _validate_and_setup(
      common_module=common_module,
      keypoint_profiles_module=keypoint_profiles_module,
      models_module=models_module,
      keypoint_distance_config_override=keypoint_distance_config_override,
      embedder_fn_kwargs=embedder_fn_kwargs)

  g = tf.Graph()
  with g.as_default():
    with tf.device(tf.train.replica_device_setter(FLAGS.num_ps_tasks)):

      def create_inputs():
        """Creates pipeline and model inputs."""
        inputs = pipeline_utils.read_batch_from_dataset_tables(
            FLAGS.input_table,
            batch_sizes=[int(x) for x in FLAGS.batch_size],
            num_instances_per_record=2,
            shuffle=True,
            num_epochs=None,
            keypoint_names_3d=configs['keypoint_profile_3d'].keypoint_names,
            keypoint_names_2d=configs['keypoint_profile_2d'].keypoint_names,
            min_keypoint_score_2d=FLAGS.min_input_keypoint_score_2d,
            shuffle_buffer_size=FLAGS.input_shuffle_buffer_size,
            common_module=common_module,
            dataset_class=input_dataset_class,
            input_example_parser_creator=input_example_parser_creator)

        (inputs[common_module.KEY_KEYPOINTS_3D],
         keypoint_preprocessor_side_outputs_3d) = keypoint_preprocessor_3d(
             inputs[common_module.KEY_KEYPOINTS_3D],
             keypoint_profile_3d=configs['keypoint_profile_3d'],
             normalize_keypoints_3d=True)
        inputs.update(keypoint_preprocessor_side_outputs_3d)

        inputs['model_inputs'], side_inputs = create_model_input_fn(
            inputs[common_module.KEY_KEYPOINTS_2D],
            inputs[common_module.KEY_KEYPOINT_MASKS_2D],
            inputs[common_module.KEY_PREPROCESSED_KEYPOINTS_3D],
            model_input_keypoint_type=FLAGS.model_input_keypoint_type,
            normalize_keypoints_2d=True,
            keypoint_profile_2d=configs['keypoint_profile_2d'],
            keypoint_profile_3d=configs['keypoint_profile_3d'],
            azimuth_range=configs['random_projection_azimuth_range'],
            elevation_range=configs['random_projection_elevation_range'],
            roll_range=configs['random_projection_roll_range'])
        data_utils.merge_dict(side_inputs, inputs)
        return inputs

      inputs = create_inputs()
      outputs, _ = configs['embedder_fn'](inputs['model_inputs'])
      summaries = {
          'train/batch_size':
              tf.shape(outputs[common_module.KEY_EMBEDDING_MEANS])[0]
      }

      def add_triplet_loss():
        """Adds triplet loss."""
        anchor_keypoints_3d, positive_keypoints_3d = tf.unstack(
            inputs[common_module.KEY_KEYPOINTS_3D], num=2, axis=1)

        anchor_keypoint_masks_3d, positive_keypoint_masks_3d = None, None
        if FLAGS.use_inferred_keypoint_masks_for_triplet_label:
          anchor_keypoint_masks_2d, positive_keypoint_masks_2d = tf.unstack(
              inputs[common_module.KEY_PREPROCESSED_KEYPOINT_MASKS_2D],
              num=2,
              axis=1)
          anchor_keypoint_masks_3d = keypoint_utils.transfer_keypoint_masks(
              anchor_keypoint_masks_2d,
              input_keypoint_profile=configs['keypoint_profile_2d'],
              output_keypoint_profile=configs['keypoint_profile_3d'],
              enforce_surjectivity=True)
          positive_keypoint_masks_3d = keypoint_utils.transfer_keypoint_masks(
              positive_keypoint_masks_2d,
              input_keypoint_profile=configs['keypoint_profile_2d'],
              output_keypoint_profile=configs['keypoint_profile_3d'],
              enforce_surjectivity=True)

        triplet_anchor_embeddings, triplet_positive_embeddings = tf.unstack(
            pipeline_utils.stack_embeddings(outputs,
                                            configs['triplet_embedding_keys']),
            axis=1)
        if FLAGS.use_normalized_embeddings_for_triplet_loss:
          triplet_anchor_embeddings = tf.math.l2_normalize(
              triplet_anchor_embeddings, axis=-1)
          triplet_positive_embeddings = tf.math.l2_normalize(
              triplet_positive_embeddings, axis=-1)

        triplet_anchor_mining_embeddings, triplet_positive_mining_embeddings = (
            tf.unstack(
                pipeline_utils.stack_embeddings(
                    outputs, configs['triplet_mining_embedding_keys']),
                axis=1))
        if FLAGS.use_normalized_embeddings_for_triplet_mining:
          triplet_anchor_mining_embeddings = tf.math.l2_normalize(
              triplet_anchor_mining_embeddings, axis=-1)
          triplet_positive_mining_embeddings = tf.math.l2_normalize(
              triplet_positive_mining_embeddings, axis=-1)

        triplet_loss, triplet_loss_summaries = (
            loss_utils.compute_keypoint_triplet_losses(
                anchor_embeddings=triplet_anchor_embeddings,
                positive_embeddings=triplet_positive_embeddings,
                match_embeddings=triplet_positive_embeddings,
                anchor_keypoints=anchor_keypoints_3d,
                match_keypoints=positive_keypoints_3d,
                margin=FLAGS.triplet_loss_margin,
                min_negative_keypoint_distance=(
                    configs['min_negative_keypoint_distance']),
                use_semi_hard=FLAGS.use_semi_hard_triplet_negatives,
                exclude_inactive_triplet_loss=(
                    FLAGS.exclude_inactive_triplet_loss),
                anchor_keypoint_masks=anchor_keypoint_masks_3d,
                match_keypoint_masks=positive_keypoint_masks_3d,
                embedding_sample_distance_fn=(
                    configs['triplet_embedding_sample_distance_fn']),
                keypoint_distance_fn=configs['keypoint_distance_fn'],
                anchor_mining_embeddings=triplet_anchor_mining_embeddings,
                positive_mining_embeddings=triplet_positive_mining_embeddings,
                match_mining_embeddings=triplet_positive_mining_embeddings,
                summarize_percentiles=FLAGS.summarize_percentiles))
        tf.losses.add_loss(triplet_loss, loss_collection=tf.GraphKeys.LOSSES)
        summaries.update(triplet_loss_summaries)
        summaries['train/triplet_loss'] = triplet_loss

      def add_kl_regularization_loss():
        """Adds KL regularization loss."""
        kl_regularization_loss, kl_regularization_loss_summaries = (
            loss_utils.compute_kl_regularization_loss(
                outputs[common_module.KEY_EMBEDDING_MEANS],
                stddevs=outputs[common_module.KEY_EMBEDDING_STDDEVS],
                prior_stddev=FLAGS.kl_regularization_prior_stddev,
                loss_weight=FLAGS.kl_regularization_loss_weight))
        tf.losses.add_loss(
            kl_regularization_loss,
            loss_collection=tf.GraphKeys.REGULARIZATION_LOSSES)
        summaries.update(kl_regularization_loss_summaries)
        summaries['train/kl_regularization_loss'] = kl_regularization_loss

      def add_positive_pairwise_loss():
        """Adds positive pairwise loss."""
        (positive_pairwise_anchor_embeddings,
         positive_pairwise_positive_embeddings) = tf.unstack(
             pipeline_utils.stack_embeddings(
                 outputs,
                 configs['positive_pairwise_embedding_keys'],
                 common_module=common_module),
             axis=1)
        if FLAGS.use_normalized_embeddings_for_positive_pairwise_loss:
          positive_pairwise_anchor_embeddings = tf.math.l2_normalize(
              positive_pairwise_anchor_embeddings, axis=-1)
          positive_pairwise_positive_embeddings = tf.math.l2_normalize(
              positive_pairwise_positive_embeddings, axis=-1)
        positive_pairwise_loss, positive_pairwise_loss_summaries = (
            loss_utils.compute_positive_pairwise_loss(
                positive_pairwise_anchor_embeddings,
                positive_pairwise_positive_embeddings,
                loss_weight=FLAGS.positive_pairwise_loss_weight,
                distance_fn=configs[
                    'positive_pairwise_embedding_sample_distance_fn']))
        tf.losses.add_loss(
            positive_pairwise_loss, loss_collection=tf.GraphKeys.LOSSES)
        summaries.update(positive_pairwise_loss_summaries)
        summaries['train/positive_pairwise_loss'] = positive_pairwise_loss

      add_triplet_loss()
      if FLAGS.kl_regularization_loss_weight > 0.0:
        add_kl_regularization_loss()
      if FLAGS.positive_pairwise_loss_weight > 0.0:
        add_positive_pairwise_loss()
      total_loss = tf.losses.get_total_loss()
      summaries['train/total_loss'] = total_loss

      if configs['summarize_matching_sigmoid_vars']:
        # Summarize variables used in matching sigmoid.
        raw_a, a, b = distance_utils.get_sigmoid_parameters(
            name='MatchingSigmoid',
            reuse=True,
            a_range=(None, FLAGS.sigmoid_a_max))
        # TODO(liuti): Currently the variable for `raw_a` is named `a` in
        # checkpoints, and true `a` may be referred to as `a_plus` for historic
        # reasons. Consolidate the naming.
        summaries.update({
            'train/MatchingSigmoid/a': raw_a,
            'train/MatchingSigmoid/a_plus': a,
            'train/MatchingSigmoid/b': b,
        })

      if FLAGS.use_moving_average:
        pipeline_utils.add_moving_average(FLAGS.moving_average_decay)

      learning_rate = FLAGS.learning_rate
      optimizer = pipeline_utils.get_optimizer(
          FLAGS.optimizer.upper(), learning_rate=learning_rate)
      init_fn = pipeline_utils.get_init_fn(
          train_dir=FLAGS.train_log_dir,
          model_checkpoint=FLAGS.init_model_checkpoint)
      train_op = tf_slim.learning.create_train_op(
          total_loss,
          optimizer,
          clip_gradient_norm=FLAGS.gradient_clip_norm,
          summarize_gradients=FLAGS.summarize_gradients)
      saver = tf.train.Saver(
          keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
          pad_step_number=True)
      summaries['train/learning_rate'] = learning_rate

      pipeline_utils.add_summary(scalars_to_summarize=summaries)

      if FLAGS.profile_only:
        pipeline_utils.profile()
        return

      tf_slim.learning.train(
          train_op,
          logdir=FLAGS.train_log_dir,
          log_every_n_steps=FLAGS.log_every_n_steps,
          master=master,
          is_chief=FLAGS.task == 0,
          number_of_steps=FLAGS.num_steps,
          init_fn=init_fn,
          save_summaries_secs=FLAGS.save_summaries_secs,
          startup_delay_steps=FLAGS.startup_delay_steps * FLAGS.task,
          saver=saver,
          save_interval_secs=FLAGS.save_interval_secs,
          session_config=tf.ConfigProto(
              allow_soft_placement=True, log_device_placement=False))
Example #26
0
 def _mapper(dataset):
     """Computes number of contexts."""
     for k in keys_to_map:
         size = tf.shape(dataset[k])[-1]
         dataset[prefix + k] = size
     return dataset
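In the mapper above, `keys_to_map` and `prefix` are closure variables from the enclosing scope. A self-contained sketch of how such a mapper is typically applied with `tf.data` (the key names here are assumptions, not from the original code):

import tensorflow as tf

keys_to_map = ["context_tokens"]
prefix = "num_"

def _mapper(example):
    for k in keys_to_map:
        example[prefix + k] = tf.shape(example[k])[-1]
    return example

ds = tf.data.Dataset.from_tensors(
    {"context_tokens": tf.constant(["a", "b", "c"])}).map(_mapper)
print(next(iter(ds))["num_context_tokens"].numpy())  # 3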
Example #27
0
def discriminator(x,
                  progress,
                  num_filters_fn,
                  resolution_schedule,
                  num_blocks=None,
                  kernel_size=3,
                  simple_arch=False,
                  scope='progressive_gan_discriminator',
                  reuse=None):
  """Discriminator network for the progressive GAN model.

  Args:
    x: A `Tensor` of NHWC format representing images of size `resolution`.
    progress: A scalar float `Tensor` of training progress.
    num_filters_fn: A function that maps `block_id` to # of filters for the
        block.
    resolution_schedule: An object of `ResolutionSchedule`.
    num_blocks: An integer number of blocks. None means the maximum number of
        blocks, i.e. `resolution_schedule.num_resolutions`. Defaults to None.
    kernel_size: An integer of convolution kernel size.
    simple_arch: Bool, use a simple architecture.
    scope: A string or variable scope.
    reuse: Whether to reuse `scope`. Defaults to None which means to inherit
        the reuse option of the parent scope.

  Returns:
    A `Tensor` of model output and a dictionary of model end points.
  """
  he_init = contrib_layers.variance_scaling_initializer()

  if num_blocks is None:
    num_blocks = resolution_schedule.num_resolutions

  def _conv2d(scope, x, kernel_size, filters, padding='SAME'):
    return layers.custom_conv2d(
        x=x,
        filters=filters,
        kernel_size=kernel_size,
        padding=padding,
        activation=tf.nn.leaky_relu,
        he_initializer_slope=0.0,
        scope=scope)

  def _from_rgb(x, block_id):
    return _conv2d('from_rgb', x, 1, num_filters_fn(block_id))

  if resolution_schedule.scale_mode == 'H':
    strides = (resolution_schedule.scale_base, 1)
  else:
    strides = (resolution_schedule.scale_base,
               resolution_schedule.scale_base)

  end_points = {}

  with tf.variable_scope(scope, reuse=reuse):
    x0 = x
    end_points['rgb'] = x0

    lods = []
    for block_id in range(num_blocks, 0, -1):
      with tf.variable_scope(block_name(block_id)):
        scale = resolution_schedule.scale_factor(block_id)
        lod = resolution_schedule.downscale(x0, scale)
        end_points['downscaled_rgb_{}'.format(block_id)] = lod
        if simple_arch:
          lod = tf.layers.conv2d(
              lod,
              num_filters_fn(block_id),
              kernel_size=1,
              padding='SAME',
              name='from_rgb',
              kernel_initializer=he_init)
          lod = tf.nn.relu(lod)
        else:
          lod = _from_rgb(lod, block_id)
        # alpha_i is used to replace lod_select.
        alpha = _discriminator_alpha(block_id, progress)
        end_points['alpha_{}'.format(block_id)] = alpha
      lods.append((lod, alpha))

    lods_iter = iter(lods)
    x, _ = six.next(lods_iter)
    for block_id in range(num_blocks, 1, -1):
      with tf.variable_scope(block_name(block_id)):
        if simple_arch:
          x = tf.layers.conv2d(
              x,
              num_filters_fn(block_id-1),
              strides=strides,
              kernel_size=kernel_size,
              padding='SAME',
              name='conv',
              kernel_initializer=he_init)
          x = tf.nn.relu(x)
        else:
          x = _conv2d('conv0', x, kernel_size, num_filters_fn(block_id))
          x = _conv2d('conv1', x, kernel_size, num_filters_fn(block_id - 1))
          x = resolution_schedule.downscale(x, resolution_schedule.scale_base)
        lod, alpha = six.next(lods_iter)
        x = alpha * lod + (1.0 - alpha) * x

    with tf.variable_scope(block_name(1)):
      x = layers.scalar_concat(x, layers.minibatch_mean_stddev(x))
      if simple_arch:
        x = tf.reshape(x, [tf.shape(x)[0], -1])  # flatten
        x = tf.layers.dense(x, num_filters_fn(0), name='last_conv',
                            kernel_initializer=he_init)
        x = tf.reshape(x, [tf.shape(x)[0], 1, 1, num_filters_fn(0)])
        x = tf.nn.relu(x)
      else:
        x = _conv2d('conv0', x, kernel_size, num_filters_fn(1))
        x = _conv2d('conv1', x, resolution_schedule.start_resolutions,
                    num_filters_fn(0), 'VALID')
      end_points['last_conv'] = x
      if simple_arch:
        logits = tf.layers.dense(x, 1, name='logits',
                                 kernel_initializer=he_init)
      else:
        logits = layers.custom_dense(x=x, units=1, scope='logits')
      end_points['logits'] = logits

  return logits, end_points
Example #28
0
def model_function(features, labels, mode, params, embeddings):
    """A model function satisfying the tf.estimator API.

  Args:
    features: Dictionary of feature tensors with keys:
        - question_tok: <string> [batch_size, max_question_len]
        - context_tok: <string> [batch_size, max_num_context, max_context_len]
        - question_tok_len: <int32> [batch_size]
        - num_context: <int32> [batch_size]
        - context_tok_len: <int32> [batch_size]
        - question_tok_wid: <int32> [batch_size, max_question_len]
        - context_tok_wid: <int32> [batch_size, max_num_context,
          max_context_len]
        - long_answer_indices: <int32> [batch_size]
    labels: <int32> [batch_size] for answer index (-1 = NULL).
    mode: One of the keys from tf.estimator.ModeKeys.
    params: Dictionary of hyperparameters.
    embeddings: An embedding_utils.PretrainedWordEmbeddings object.

  Returns:
    estimator_spec: A tf.estimator.EstimatorSpec object.
  """
    del params  # Unused.

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Add a dummy batch dimension if we are exporting the predictor.
        features = {k: tf.expand_dims(v, 0) for k, v in features.items()}

    embedding_weights, embedding_scaffold = embeddings.get_params(
        trainable=False)

    # Features.
    question_tok_len = features["question_tok_len"]
    question_tok_wid = features["question_tok_wid"]
    context_tok_wid = features["context_tok_wid"]
    num_context = features["num_context"]
    context_tok_len = features["context_tok_len"]

    # Truncate the contexts and labels to a certain maximum length.
    context_tok_wid, num_context, context_tok_len = (
        nq_long_utils.truncate_contexts(context_token_ids=context_tok_wid,
                                        num_contexts=num_context,
                                        context_len=context_tok_len,
                                        max_contexts=FLAGS.max_contexts,
                                        max_context_len=FLAGS.max_context_len))

    non_null_context_scores = nq_long_decatt_model.build_model(
        question_tok_wid=question_tok_wid,
        question_lens=question_tok_len,
        context_tok_wid=context_tok_wid,
        context_lens=context_tok_len,
        embedding_weights=embedding_weights,
        mode=mode)

    # Mask out contexts that are padding.
    num_context_mask = tf.log(
        tf.sequence_mask(num_context,
                         tensor_utils.shape(non_null_context_scores, 1),
                         dtype=tf.float32))
    non_null_context_scores += num_context_mask

    # <float> [batch_size, 1]
    null_score = tf.zeros([tf.shape(question_tok_wid)[0], 1])

    # Offset everything by 1 to account for null context.
    # [batch_size, 1 + max_contexts]
    context_scores = tf.concat([null_score, non_null_context_scores], 1)

    if mode != tf.estimator.ModeKeys.PREDICT:
        labels = nq_long_utils.truncate_labels(labels, FLAGS.max_contexts)

        # In the data, NULL is given index -1 but this is not compatible with
        # softmax so shift by 1.
        labels = labels + 1

        # Reweight null examples.
        weights = nq_long_utils.compute_null_weights(labels, FLAGS.null_weight)

        # When computing the loss we take only the first label.
        loss_labels = labels[:, 0]

        # []
        loss = tf.losses.sparse_softmax_cross_entropy(labels=loss_labels,
                                                      logits=context_scores,
                                                      weights=weights)

        optimizer = tf.train.AdagradOptimizer(
            learning_rate=FLAGS.learning_rate)
        train_op = optimizer.minimize(loss=loss,
                                      global_step=tf.train.get_global_step())

        # <int32> [batch_size]
        eval_predictions = tf.to_int32(tf.argmax(context_scores, 1))

        non_null_match, non_null_gold, non_null_predictions = (
            nq_long_utils.compute_match_stats(eval_predictions, labels))

        precision, precision_op = (tf.metrics.mean(
            non_null_match, weights=non_null_predictions))
        recall, recall_op = (tf.metrics.mean(non_null_match,
                                             weights=non_null_gold))

        f1, f1_op = (nq_long_utils.f1_metric(precision=precision,
                                             precision_op=precision_op,
                                             recall=recall,
                                             recall_op=recall_op))
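        # f1_metric is defined elsewhere; presumably it computes the streaming
        # harmonic mean 2 * P * R / (P + R) of the precision and recall above.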

        # Bogus metric until we figure out how to connect Ming Wei's eval code.
        eval_metric_ops = {
            "precision": (precision, precision_op),
            "recall": (recall, recall_op),
            "f1": (f1, f1_op)
        }
    else:
        loss = None
        train_op = None
        eval_metric_ops = {}

    # In the export, we never predict NULL since the eval metric will compute the
    # best possible F1.
    export_long_answer_idx = tf.to_int32(tf.argmax(non_null_context_scores, 1))
    export_long_answer_score = tf.reduce_max(non_null_context_scores, 1)
    predictions = dict(idx=export_long_answer_idx,
                       score=export_long_answer_score)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # Remove the dummy batch dimension if we are exporting the predictor.
        predictions = {k: tf.squeeze(v, 0) for k, v in predictions.items()}

    estimator_spec = tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        predictions=predictions,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        scaffold=embedding_scaffold)

    return estimator_spec
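
The model_fn above is returned as an EstimatorSpec, but the snippet does not show how it is attached to an estimator. Below is a minimal, hedged sketch of that wiring; `long_answer_model_fn`, `train_input_fn`, and `pretrained_embeddings` are hypothetical names standing in for the real definitions.

# Minimal sketch (assumption): wiring the model_fn above into a TF1 Estimator.
import functools

import tensorflow.compat.v1 as tf

def make_estimator(pretrained_embeddings, model_dir):
    # Bind the extra `embeddings` argument so the resulting callable exposes the
    # (features, labels, mode, params) signature an Estimator expects.
    model_fn = functools.partial(long_answer_model_fn,
                                 embeddings=pretrained_embeddings)
    run_config = tf.estimator.RunConfig(model_dir=model_dir,
                                        save_checkpoints_steps=1000)
    return tf.estimator.Estimator(model_fn=model_fn, config=run_config)

# estimator = make_estimator(pretrained_embeddings, "/tmp/nq_long_model")
# estimator.train(input_fn=train_input_fn, max_steps=100000)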
示例#29
0
def shuffle_network(inputs, hparams):
    """Neural Shuffle-Network with skip connections between blocks.

  Args:
    inputs: inputs to the Shuffle-Exchange network. The sequence length should
      be a power of 2.
    hparams: Model configuration.

  Returns:
    tf.Tensor: Outputs of the last Shuffle-Exchange layer.
  """
    def forward_step(state, layer_nr):
        with tf.variable_scope("forward"):
            last_state, residuals = state
            prev = residuals[layer_nr, :, :, :]
            switch = SwitchLayer("switch", hparams.dropout, hparams.mode)
            cur = switch(last_state, prev)
            return shuffle_layer(cur), residuals

    def reverse_step(state, layer_nr):
        with tf.variable_scope("reverse"):
            last_state, residuals = state
            prev = residuals[layer_nr, :, :, :]
            switch = SwitchLayer("reverse_switch", hparams.dropout,
                                 hparams.mode)
            cur = switch(last_state, prev)
            return reverse_shuffle_layer(cur), residuals

    input_shape = tf.shape(inputs)
    n_bits = tf.log(tf.cast(input_shape[1] - 1, tf.float32)) / tf.log(2.0)
    n_bits = tf.cast(n_bits, tf.int32) + 1
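    # n_bits works out to ceil(log2(sequence_length)), i.e. the number of
    # shuffle steps in one forward (or reverse) pass of a Benes block.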

    queue_shape = [n_bits * 2, input_shape[0], input_shape[1], input_shape[2]]
    residuals_queue = tf.zeros(queue_shape)
    block_out = tf.tanh(inputs)

    for k in range(hparams.num_hidden_layers):
        with tf.variable_scope("benes_block_" + str(k), reuse=tf.AUTO_REUSE):
            forward_outputs, _ = tf.scan(forward_step,
                                         tf.range(0, n_bits),
                                         initializer=(block_out,
                                                      residuals_queue),
                                         parallel_iterations=1,
                                         swap_memory=True)

            forward_tensors = [
                tf.expand_dims(block_out, axis=0), forward_outputs
            ]
            forward_outputs = tf.concat(forward_tensors, axis=0)
            forward_last = forward_outputs[-1, :, :, :]

            reverse_outputs, _ = tf.scan(reverse_step,
                                         tf.range(n_bits, n_bits * 2),
                                         initializer=(forward_last,
                                                      residuals_queue),
                                         parallel_iterations=1,
                                         swap_memory=True)

            block_out = reverse_outputs[-1, :, :, :]
            residuals_queue = tf.concat([forward_outputs, reverse_outputs],
                                        axis=0)

    last_layer = SwitchLayer("last_layer", hparams.dropout, hparams.mode)
    return last_layer(block_out, residuals_queue[n_bits * 2, :, :, :])
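
shuffle_layer and reverse_shuffle_layer are imported from the surrounding codebase and are not shown here. As a rough illustration only (an assumption about their behaviour, not the actual implementation), the core permutation of one shuffle step can be sketched as a cyclic rotation of each position's index bits:

# Rough sketch (assumption): a riffle-shuffle permutation for a sequence whose
# length is a power of two, expressed as a cyclic left rotation of index bits.
import numpy as np

def riffle_shuffle_indices(n):
    """Gather indices: position i reads from rotate_left(i) over log2(n) bits."""
    n_bits = int(np.log2(n))
    idx = np.arange(n)
    return ((idx << 1) | (idx >> (n_bits - 1))) & (n - 1)

x = np.arange(8)
print(x[riffle_shuffle_indices(8)])  # [0 2 4 6 1 3 5 7]: even positions, then odd

# A reverse shuffle would use the inverse permutation (a right rotation of the
# index bits), which interleaves the two halves of the sequence instead.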
示例#30
0
def main():

  # Build the model.
  learnable_model = learned_simulator.LearnedSimulator(
      num_dimensions=NUM_DIMENSIONS,
      connectivity_radius=0.05,
      graph_network_kwargs=dict(
          latent_size=128,
          mlp_hidden_size=128,
          mlp_num_hidden_layers=2,
          num_message_passing_steps=10,
      ),
      boundaries=DUMMY_BOUNDARIES,
      normalization_stats={"acceleration": DUMMY_STATS,
                           "velocity": DUMMY_STATS,
                           "context": DUMMY_CONTEXT_STATS,},
      num_particle_types=NUM_PARTICLE_TYPES,
      particle_type_embedding_size=16,
    )

  # Sample a batch of particle sequences with shape:
  # [TOTAL_NUM_PARTICLES, SEQUENCE_LENGTH, NUM_DIMENSIONS]
  sampled_position_sequences = [
      sample_random_position_sequence() for _ in range(BATCH_SIZE)]
  position_sequence_batch = tf.concat(sampled_position_sequences, axis=0)

  # Count how many particles are present in each element in the batch.
  # [BATCH_SIZE]
  n_particles_per_example = tf.stack(
      [tf.shape(seq)[0] for seq in sampled_position_sequences], axis=0)

  # Sample particle types.
  # [TOTAL_NUM_PARTICLES]
  particle_types = tf.random_uniform(
      [tf.shape(position_sequence_batch)[0]],
      0, NUM_PARTICLE_TYPES, dtype=tf.int32)

  # Sample global context.
  global_context = tf.random_uniform(
      [BATCH_SIZE, GLOBAL_CONTEXT_SIZE], -1., 1., dtype=tf.float32)

  # Separate input sequence from target sequence.
  # [TOTAL_NUM_PARTICLES, INPUT_SEQUENCE_LENGTH, NUM_DIMENSIONS]
  input_position_sequence = position_sequence_batch[:, :-1]
  # [TOTAL_NUM_PARTICLES, NUM_DIMENSIONS]
  target_next_position = position_sequence_batch[:, -1]

  # Single step of inference with the model to predict next position for each
  # particle [TOTAL_NUM_PARTICLES, NUM_DIMENSIONS].
  predicted_next_position = learnable_model(
      input_position_sequence, n_particles_per_example, global_context,
      particle_types)
  print(f"Per-particle output tensor: {predicted_next_position}")

  # Obtaining predicted and target normalized accelerations for training.
  position_sequence_noise = (
      noise_utils.get_random_walk_noise_for_position_sequence(
          input_position_sequence, noise_std_last_step=6.7e-4))

  # Both with shape [TOTAL_NUM_PARTICLES, NUM_DIMENSIONS]
  predicted_normalized_acceleration, target_normalized_acceleration = (
      learnable_model.get_predicted_and_target_normalized_accelerations(
          target_next_position, position_sequence_noise,
          input_position_sequence, n_particles_per_example, global_context,
          particle_types))
  print(f"Predicted norm. acceleration: {predicted_normalized_acceleration}")
  print(f"Target norm. acceleration: {target_normalized_acceleration}")

  with tf.train.SingularMonitoredSession() as sess:
    sess.run([predicted_next_position,
              predicted_normalized_acceleration,
              target_normalized_acceleration])
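
The example above stops after running the forward passes and does not build a training objective. Below is a minimal sketch of how one could be added inside main(), assuming an L2 loss on the normalized accelerations; this is an assumption in the spirit of the original training setup, not part of the snippet.

  # Minimal sketch (assumption): L2 loss on normalized accelerations and a
  # plain TF1 train op, reusing the tensors defined above in main().
  loss = tf.reduce_mean(
      tf.reduce_sum(
          tf.squared_difference(predicted_normalized_acceleration,
                                target_normalized_acceleration),
          axis=-1))
  optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
  train_op = optimizer.minimize(
      loss, global_step=tf.train.get_or_create_global_step())

  with tf.train.SingularMonitoredSession() as sess:
    for _ in range(10):
      loss_value, _ = sess.run([loss, train_op])
      print(f"loss: {loss_value}")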