Example #1
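All of the snippets below target the TensorFlow 1.x API and are shown without their imports. A plausible header, assuming `common` is a project-local helper module that provides shape_list() and to_float() (both used throughout), would be:

import tensorflow as tf                            # TF 1.x (tf.layers, tf.AUTO_REUSE, ...)
import tensorflow_probability as tfp               # used by kl_divergence (Example #7)
from tensorflow.contrib import rnn as contrib_rnn  # ConvLSTMCell (Example #8)
import common  # assumed project-local module with shape_list() and to_float()
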
def scheduled_sample_count(ground_truth_x, generated_x, batch_size,
                           scheduled_sample_var):
    """Sample batch with specified mix of groundtruth and generated data points.

  Args:
    ground_truth_x: tensor of ground-truth data points.
    generated_x: tensor of generated data points.
    batch_size: batch size
    scheduled_sample_var: number of ground-truth examples to include in batch.

  Returns:
    New batch with num_ground_truth sampled from ground_truth_x and the rest
    from generated_x.
  """
    num_ground_truth = scheduled_sample_var
    idx = tf.random_shuffle(tf.range(batch_size))
    ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth))
    generated_idx = tf.gather(idx, tf.range(num_ground_truth, batch_size))

    ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx)
    generated_examps = tf.gather(generated_x, generated_idx)

    output = tf.dynamic_stitch([ground_truth_idx, generated_idx],
                               [ground_truth_examps, generated_examps])
    # If the batch size is statically known, set it on the output shape.
    if isinstance(batch_size, int):
        output.set_shape([batch_size] + common.shape_list(output)[1:])
    return output
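
A minimal sketch of how this can be driven; the tensors and the ground-truth count below are illustrative, not part of the original code:

# Hypothetical usage: mix 4 ground-truth rows with generated rows.
batch_size = 8
ground_truth_x = tf.random_normal([batch_size, 64])
generated_x = tf.random_normal([batch_size, 64])
num_ground_truth = tf.constant(4)  # typically annealed toward 0 during training
mixed = scheduled_sample_count(ground_truth_x, generated_x,
                               batch_size, num_ground_truth)
# Row i of `mixed` comes from ground_truth_x or generated_x, chosen by the shuffle.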
Example #2
def lstm_cell(inputs,
              state,
              num_units,
              use_peepholes=False,
              cell_clip=None,  # 0.0 would make tf.nn.rnn_cell.LSTMCell clip the cell state to zero.
              initializer=None,
              num_proj=None,
              num_unit_shards=None,
              num_proj_shards=None,
              reuse=None,
              name=None):
    """Full LSTM cell."""
    input_shape = common.shape_list(inputs)
    cell = tf.nn.rnn_cell.LSTMCell(num_units,
                                   use_peepholes=use_peepholes,
                                   cell_clip=cell_clip,
                                   initializer=initializer,
                                   num_proj=num_proj,
                                   num_unit_shards=num_unit_shards,
                                   num_proj_shards=num_proj_shards,
                                   reuse=reuse,
                                   name=name,
                                   state_is_tuple=False)
    if state is None:
        state = cell.zero_state(input_shape[0], tf.float32)
    outputs, new_state = cell(inputs, state)
    return outputs, new_state
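
An illustrative two-step unroll with made-up shapes; name and reuse are passed so the two calls share one set of weights:

# Hypothetical usage: run the same cell twice, sharing parameters.
x_t = tf.random_normal([16, 32])  # (batch, features)
out1, state = lstm_cell(x_t, None, num_units=64, name="lstm", reuse=tf.AUTO_REUSE)
out2, state = lstm_cell(x_t, state, num_units=64, name="lstm", reuse=tf.AUTO_REUSE)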
Example #3
def inject_additional_input(layer, inputs, name, mode="concat"):
    """Injects the additional input into the layer.

  Args:
    layer: layer that the input should be injected to.
    inputs: inputs to be injected.
    name: TF scope name.
    mode: how the infor should be added to the layer: "concat" concats as
      additional channels. "multiplicative" broadcasts inputs and multiply them
      to the channels. "multi_additive" broadcasts inputs and multiply and add
      to the channels.

  Returns:
    updated layer.

  Raises:
    ValueError: in case of unknown mode.
  """
    inputs = common.to_float(inputs)
    layer_shape = common.shape_list(layer)
    input_shape = common.shape_list(inputs)
    zeros_mask = tf.zeros(layer_shape, dtype=tf.float32)
    if mode == "concat":
        emb = encode_to_shape(inputs, layer_shape, name)
        layer = tf.concat(values=[layer, emb], axis=-1)
    elif mode == "multiplicative":
        filters = layer_shape[-1]
        input_reshaped = tf.reshape(inputs, [-1, 1, 1, input_shape[-1]])
        input_mask = tf.layers.dense(input_reshaped, filters, name=name)
        input_broad = input_mask + zeros_mask
        layer *= input_broad
    elif mode == "multi_additive":
        filters = layer_shape[-1]
        input_reshaped = tf.reshape(inputs, [-1, 1, 1, input_shape[-1]])
        input_mul = tf.layers.dense(input_reshaped,
                                    filters,
                                    name=name + "_mul")
        layer *= tf.nn.sigmoid(input_mul)
        input_add = tf.layers.dense(input_reshaped,
                                    filters,
                                    name=name + "_add")
        layer += input_add
    else:
        raise ValueError("Unknown injection mode: %s" % mode)

    return layer
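
A sketch of the "multiplicative" path with made-up shapes (the "concat" mode additionally depends on an encode_to_shape helper that is not shown here):

# Hypothetical usage: modulate a feature map with a 10-d conditioning vector.
features = tf.random_normal([8, 16, 16, 32])  # (batch, height, width, channels)
action = tf.random_normal([8, 10])            # (batch, input_dims)
modulated = inject_additional_input(features, action, "action_inject",
                                    mode="multiplicative")
# `modulated` keeps the shape of `features`: (8, 16, 16, 32).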
Example #4
def basic_lstm(inputs, state, num_units, name=None):
    """Basic LSTM."""
    input_shape = common.shape_list(inputs)
    # tf.AUTO_REUSE shares the cell parameters across time steps.
    cell = tf.nn.rnn_cell.BasicLSTMCell(num_units,
                                        name=name,
                                        reuse=tf.AUTO_REUSE)
    if state is None:
        state = cell.zero_state(input_shape[0], tf.float32)
    outputs, new_state = cell(inputs, state)
    return outputs, new_state
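
Illustrative only; because the cell is built with tf.AUTO_REUSE, calling it in a Python loop with the same name shares one set of parameters across time steps:

# Hypothetical unroll over a short sequence.
seq = tf.random_normal([5, 16, 32])  # (time, batch, features)
state, outputs = None, []
for t in range(5):
    out, state = basic_lstm(seq[t], state, num_units=64, name="rnn")
    outputs.append(out)  # each `out` is (16, 64)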
Example #5
def pad_to_same_length(x, y, final_length_divisible_by=1, axis=1):
    """Pad tensors x and y on axis 1 so that they have the same length."""
    if axis not in [1, 2]:
        raise ValueError("Only axis=1 and axis=2 supported for now.")
    with tf.name_scope("pad_to_same_length", values=[x, y]):
        x_length = common.shape_list(x)[axis]
        y_length = common.shape_list(y)[axis]
        if (isinstance(x_length, int) and isinstance(y_length, int)
                and x_length == y_length and final_length_divisible_by == 1):
            return x, y
        max_length = tf.maximum(x_length, y_length)
        if final_length_divisible_by > 1:
            # Round max_length up to the nearest multiple of the given number.
            max_length += final_length_divisible_by - 1
            max_length //= final_length_divisible_by
            max_length *= final_length_divisible_by
        length_diff1 = max_length - x_length
        length_diff2 = max_length - y_length

        def padding_list(length_diff, arg):
            if axis == 1:
                return [[[0, 0], [0, length_diff]],
                        tf.zeros([tf.rank(arg) - 2, 2], dtype=tf.int32)]
            return [[[0, 0], [0, 0], [0, length_diff]],
                    tf.zeros([tf.rank(arg) - 3, 2], dtype=tf.int32)]

        paddings1 = tf.concat(padding_list(length_diff1, x), axis=0)
        paddings2 = tf.concat(padding_list(length_diff2, y), axis=0)
        res_x = tf.pad(x, paddings1)
        res_y = tf.pad(y, paddings2)
        # Static shapes are preserved except along the padded axis.
        x_shape = x.shape.as_list()
        x_shape[axis] = None
        res_x.set_shape(x_shape)
        y_shape = y.shape.as_list()
        y_shape[axis] = None
        res_y.set_shape(y_shape)
        return res_x, res_y
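
A quick illustration with made-up shapes; both inputs are padded to the longer length, rounded up to a multiple of 8:

# Hypothetical usage: align two sequences of different lengths.
x = tf.zeros([4, 10, 128])  # (batch, length, depth)
y = tf.zeros([4, 13, 128])
x_pad, y_pad = pad_to_same_length(x, y, final_length_divisible_by=8)
# Both now have length 16 on axis 1: max(10, 13) rounded up to a multiple of 8.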
Example #6
def tile_and_concat(image, latent, concat_latent=True):
    """Tile latent and concatenate to image across depth.

  Args:
    image: 4-D Tensor, (batch_size X height X width X channels)
    latent: 2-D Tensor, (batch_size X latent_dims)
    concat_latent: If set to False, the image is returned as is.

  Returns:
    concat_latent: 4-D Tensor, (batch_size X height X width X channels+1)
      latent tiled and concatenated to the image across the channels.
  """
    if not concat_latent:
        return image
    image_shape = common.shape_list(image)
    latent_shape = common.shape_list(latent)
    height, width = image_shape[1], image_shape[2]
    latent_dims = latent_shape[1]
    height_multiples = height // latent_dims
    pad = height - (height_multiples * latent_dims)
    latent = tf.reshape(latent, (-1, latent_dims, 1, 1))
    latent = tf.tile(latent, (1, height_multiples, width, 1))
    # Split the padding between top and bottom; pad - pad // 2 keeps the total
    # correct when pad is odd (pad // 2 on both sides would lose a row).
    latent = tf.pad(latent, [[0, 0], [pad // 2, pad - pad // 2], [0, 0], [0, 0]])
    return tf.concat([image, latent], axis=-1)
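
A small sketch with made-up shapes; here height (64) is an exact multiple of latent_dims (8), so no padding is needed:

# Hypothetical usage: condition a 64x64 image on an 8-d latent.
frames = tf.random_normal([4, 64, 64, 3])
z = tf.random_normal([4, 8])
out = tile_and_concat(frames, z)
# `out` is (4, 64, 64, 4): the tiled latent contributes one extra channel.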
Example #7
def kl_divergence(mu, log_var, mu_p=0.0, log_var_p=0.0):
    """KL divergence of diagonal gaussian N(mu,exp(log_var)) and N(0,1).

  Args:
    mu: mu parameter of the distribution.
    log_var: log(var) parameter of the distribution.
    mu_p: optional mu from a learned prior distribution
    log_var_p: optional log(var) from a learned prior distribution

  Returns:
    the KL loss.
  """

    batch_size = common.shape_list(mu)[0]
    prior_distribution = tfp.distributions.Normal(
        mu_p, tf.exp(tf.multiply(0.5, log_var_p)))
    posterior_distribution = tfp.distributions.Normal(
        mu, tf.exp(tf.multiply(0.5, log_var)))
    kld = tfp.distributions.kl_divergence(posterior_distribution,
                                          prior_distribution)
    return tf.reduce_sum(kld) / common.to_float(batch_size)
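
A small sketch; in practice mu and log_var would come from an encoder network:

# Hypothetical usage: KL against the default N(0, 1) prior.
mu = tf.random_normal([32, 16])       # (batch, latent_dims)
log_var = tf.random_normal([32, 16])
kl_loss = kl_divergence(mu, log_var)  # scalar: summed over dims, averaged over batch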
Example #8
def conv_lstm_2d(inputs,
                 state,
                 output_channels,
                 kernel_size=5,
                 name=None,
                 spatial_dims=None):
    """2D Convolutional LSTM."""
    input_shape = common.shape_list(inputs)
    batch_size, input_channels = input_shape[0], input_shape[-1]
    if spatial_dims is None:
        input_shape = input_shape[1:]
    else:
        input_shape = spatial_dims + [input_channels]

    cell = contrib_rnn.ConvLSTMCell(2,
                                    input_shape,
                                    output_channels,
                                    [kernel_size, kernel_size],
                                    name=name)
    if state is None:
        state = cell.zero_state(batch_size, tf.float32)
    outputs, new_state = cell(inputs, state)
    return outputs, new_state
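
An illustrative single step with made-up shapes; the returned state is fed back in on the next call:

# Hypothetical usage: one ConvLSTM step over a (batch, height, width, channels) map.
frames = tf.random_normal([4, 32, 32, 16])
out, state = conv_lstm_2d(frames, None, output_channels=32, name="conv_lstm")
# `out` is (4, 32, 32, 32); pass `state` back in for the next time step.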