def scheduled_sample_count(ground_truth_x,
                           generated_x,
                           batch_size,
                           scheduled_sample_var):
    """Mix ground-truth and generated examples into one batch, by count.

    The batch positions `0..batch_size-1` are randomly shuffled. The first
    `scheduled_sample_var` shuffled positions are filled from `ground_truth_x`
    and the remaining positions from `generated_x`; every output row keeps the
    batch position it had in the tensor it was taken from.

    Args:
      ground_truth_x: tensor of ground-truth data points.
      generated_x: tensor of generated data points.
      batch_size: batch size (Python int or scalar tensor).
      scheduled_sample_var: number of ground-truth examples to include in the
        batch.

    Returns:
      New batch with `scheduled_sample_var` rows taken from `ground_truth_x`
      and the rest taken from `generated_x`.
    """
    shuffled_positions = tf.random_shuffle(tf.range(batch_size))
    truth_positions = tf.gather(
        shuffled_positions, tf.range(scheduled_sample_var))
    generated_positions = tf.gather(
        shuffled_positions, tf.range(scheduled_sample_var, batch_size))

    truth_rows = tf.gather(ground_truth_x, truth_positions)
    generated_rows = tf.gather(generated_x, generated_positions)

    # dynamic_stitch writes each gathered row back to its batch position.
    mixed = tf.dynamic_stitch(
        [truth_positions, generated_positions],
        [truth_rows, generated_rows])

    # The stitched result loses its static batch dimension; restore it when
    # the batch size is a plain Python int.
    if isinstance(batch_size, int):
        trailing_dims = common.shape_list(mixed)[1:]
        mixed.set_shape([batch_size] + trailing_dims)
    return mixed
def lstm_cell(inputs,
              state,
              num_units,
              use_peepholes=False,
              cell_clip=0.0,
              initializer=None,
              num_proj=None,
              num_unit_shards=None,
              num_proj_shards=None,
              reuse=None,
              name=None):
    """Run one step of a full `tf.nn.rnn_cell.LSTMCell`.

    Args:
      inputs: input tensor for this step; its first dimension is used as the
        batch size when a zero state has to be created.
      state: previous cell state, or None to start from a float32 zero state.
      num_units: forwarded to `LSTMCell`.
      use_peepholes: forwarded to `LSTMCell`.
      cell_clip: forwarded to `LSTMCell`.
      initializer: forwarded to `LSTMCell`.
      num_proj: forwarded to `LSTMCell`.
      num_unit_shards: forwarded to `LSTMCell`.
      num_proj_shards: forwarded to `LSTMCell`.
      reuse: forwarded to `LSTMCell`.
      name: forwarded to `LSTMCell`.

    Returns:
      A pair `(outputs, new_state)` as returned by calling the cell. The cell
      is built with `state_is_tuple=False`.
    """
    batch_size = common.shape_list(inputs)[0]
    rnn_cell = tf.nn.rnn_cell.LSTMCell(
        num_units,
        use_peepholes=use_peepholes,
        cell_clip=cell_clip,
        initializer=initializer,
        num_proj=num_proj,
        num_unit_shards=num_unit_shards,
        num_proj_shards=num_proj_shards,
        reuse=reuse,
        name=name,
        state_is_tuple=False)
    if state is None:
        state = rnn_cell.zero_state(batch_size, tf.float32)
    step_output, next_state = rnn_cell(inputs, state)
    return step_output, next_state
def inject_additional_input(layer, inputs, name, mode="concat"):
    """Injects the additional input into the layer.

    Args:
      layer: layer that the input should be injected to.
      inputs: inputs to be injected; cast to float before use.
      name: TF scope name, used to name the projection layers.
      mode: how the information should be added to the layer:
        "concat" encodes the inputs to the layer's shape (via
          `encode_to_shape`) and concatenates them along the last axis.
        "multiplicative" projects the inputs to the layer's channel count,
          broadcasts them to the layer's shape and multiplies the layer by
          them.
        "multi_additive" multiplies the layer by the sigmoid of one
          projection of the inputs and then adds a second projection.

    Returns:
      updated layer.

    Raises:
      ValueError: in case of unknown mode.
    """
    inputs = common.to_float(inputs)
    layer_shape = common.shape_list(layer)
    input_shape = common.shape_list(inputs)
    if mode == "concat":
        emb = encode_to_shape(inputs, layer_shape, name)
        layer = tf.concat(values=[layer, emb], axis=-1)
    elif mode == "multiplicative":
        filters = layer_shape[-1]
        # NOTE(review): the [-1, 1, 1, C] reshape presumably targets a 4-D
        # layer (batch, height, width, channels) -- confirm against callers.
        input_reshaped = tf.reshape(inputs, [-1, 1, 1, input_shape[-1]])
        input_mask = tf.layers.dense(input_reshaped, filters, name=name)
        # Adding zeros of the layer's shape explicitly broadcasts the mask to
        # the full layer shape. The zeros tensor is only created in this
        # branch so the other modes do not allocate a layer-sized tensor they
        # never use.
        input_broad = input_mask + tf.zeros(layer_shape, dtype=tf.float32)
        layer *= input_broad
    elif mode == "multi_additive":
        filters = layer_shape[-1]
        input_reshaped = tf.reshape(inputs, [-1, 1, 1, input_shape[-1]])
        input_mul = tf.layers.dense(
            input_reshaped, filters, name=name + "_mul")
        layer *= tf.nn.sigmoid(input_mul)
        input_add = tf.layers.dense(
            input_reshaped, filters, name=name + "_add")
        layer += input_add
    else:
        raise ValueError("Unknown injection mode: %s" % mode)

    return layer
def basic_lstm(inputs, state, num_units, name=None):
    """Run one step of a `tf.nn.rnn_cell.BasicLSTMCell`.

    Args:
      inputs: input tensor for this step; its first dimension is used as the
        batch size when a zero state has to be created.
      state: previous cell state, or None to start from a float32 zero state.
      num_units: forwarded to `BasicLSTMCell`.
      name: forwarded to `BasicLSTMCell`.

    Returns:
      A pair `(outputs, new_state)` as returned by calling the cell.
    """
    batch_size = common.shape_list(inputs)[0]
    # AUTO_REUSE so that the parameters are shared across time-steps.
    rnn_cell = tf.nn.rnn_cell.BasicLSTMCell(
        num_units, name=name, reuse=tf.AUTO_REUSE)
    if state is None:
        state = rnn_cell.zero_state(batch_size, tf.float32)
    step_output, next_state = rnn_cell(inputs, state)
    return step_output, next_state
def pad_to_same_length(x, y, final_length_divisible_by=1, axis=1):
    """Zero-pad tensors x and y along `axis` so they have the same length.

    Padding is appended at the end of `axis`. When
    `final_length_divisible_by` is greater than 1, the common length is
    rounded up to the nearest multiple of it.

    Args:
      x: first tensor.
      y: second tensor.
      final_length_divisible_by: the padded length is made divisible by this
        number when it is greater than 1.
      axis: axis to pad along; only 1 and 2 are supported.

    Returns:
      A pair `(res_x, res_y)` of padded tensors. The inputs are returned
      untouched when both lengths are statically known, equal, and no
      divisibility constraint is requested.

    Raises:
      ValueError: if `axis` is not 1 or 2.
    """
    if axis not in [1, 2]:
        raise ValueError("Only axis=1 and axis=2 supported for now.")
    with tf.name_scope("pad_to_same_length", values=[x, y]):
        x_length = common.shape_list(x)[axis]
        y_length = common.shape_list(y)[axis]
        lengths_are_static = (
            isinstance(x_length, int) and isinstance(y_length, int))
        if (lengths_are_static and x_length == y_length
                and final_length_divisible_by == 1):
            return x, y

        target_length = tf.maximum(x_length, y_length)
        if final_length_divisible_by > 1:
            # Round up to the nearest larger-or-equal multiple.
            target_length = (
                (target_length + (final_length_divisible_by - 1))
                // final_length_divisible_by) * final_length_divisible_by
        x_deficit = target_length - x_length
        y_deficit = target_length - y_length

        def build_paddings(deficit, tensor):
            # No padding on the axes before `axis`, `deficit` trailing
            # elements on `axis` itself, and zero rows for every axis after.
            leading_rows = [[0, 0]] * axis + [[0, deficit]]
            trailing_rows = tf.zeros(
                [tf.rank(tensor) - (axis + 1), 2], dtype=tf.int32)
            return tf.concat([leading_rows, trailing_rows], axis=0)

        x_paddings = build_paddings(x_deficit, x)
        y_paddings = build_paddings(y_deficit, y)
        res_x = tf.pad(x, x_paddings)
        res_y = tf.pad(y, y_paddings)

        # Static shapes are unchanged except along the padded axis, which is
        # no longer statically known.
        for padded, source in ((res_x, x), (res_y, y)):
            static_shape = source.shape.as_list()
            static_shape[axis] = None
            padded.set_shape(static_shape)
        return res_x, res_y
def tile_and_concat(image, latent, concat_latent=True):
    """Tile latent and concatenate to image across depth.

    The latent vector is laid out along the height axis, repeated
    `height // latent_dims` times, tiled across the full width, zero-padded
    to the image height, and appended to the image as one extra channel.

    Args:
      image: 4-D Tensor, (batch_size X height X width X channels)
      latent: 2-D Tensor, (batch_size X latent_dims)
      concat_latent: If set to False, the image is returned as is.

    Returns:
      concat_latent: 4-D Tensor, (batch_size X height X width X channels+1)
        latent tiled and concatenated to the image across the channels.
    """
    if not concat_latent:
        return image
    image_shape = common.shape_list(image)
    latent_shape = common.shape_list(latent)
    height, width = image_shape[1], image_shape[2]
    latent_dims = latent_shape[1]
    height_multiples = height // latent_dims
    pad = height - (height_multiples * latent_dims)
    latent = tf.reshape(latent, (-1, latent_dims, 1, 1))
    latent = tf.tile(latent, (1, height_multiples, width, 1))
    # Split the padding between top and bottom. The bottom takes the
    # remainder so that an odd `pad` still sums to exactly `pad` rows;
    # padding `pad // 2` on both sides would leave the latent one row short
    # of the image height and make the concat below fail.
    pad_top = pad // 2
    pad_bottom = pad - pad_top
    latent = tf.pad(latent, [[0, 0], [pad_top, pad_bottom], [0, 0], [0, 0]])
    return tf.concat([image, latent], axis=-1)
def kl_divergence(mu, log_var, mu_p=0.0, log_var_p=0.0):
    """KL divergence between a diagonal gaussian posterior and a prior.

    Computes KL(N(mu, exp(log_var)) || N(mu_p, exp(log_var_p))). With the
    default prior parameters the prior is N(0, 1).

    Args:
      mu: mu parameter of the posterior distribution.
      log_var: log(var) parameter of the posterior distribution.
      mu_p: optional mu from a learned prior distribution.
      log_var_p: optional log(var) from a learned prior distribution.

    Returns:
      the KL loss: the sum of the element-wise KL terms divided by the size
      of the first dimension of `mu`.
    """
    batch_size = common.shape_list(mu)[0]

    # Standard deviation is exp(0.5 * log(var)).
    prior_stddev = tf.exp(tf.multiply(0.5, log_var_p))
    prior = tfp.distributions.Normal(mu_p, prior_stddev)

    posterior_stddev = tf.exp(tf.multiply(0.5, log_var))
    posterior = tfp.distributions.Normal(mu, posterior_stddev)

    elementwise_kl = tfp.distributions.kl_divergence(posterior, prior)
    total_kl = tf.reduce_sum(elementwise_kl)
    return total_kl / common.to_float(batch_size)
def conv_lstm_2d(inputs, state, output_channels,
                 kernel_size=5, name=None, spatial_dims=None):
    """Run one step of a 2D convolutional LSTM.

    Args:
      inputs: input tensor; its first dimension is taken as the batch size
        and its last dimension as the number of input channels.
      state: previous cell state, or None to start from a float32 zero state.
      output_channels: number of output channels of the cell.
      kernel_size: side length of the square convolution kernel.
      name: name forwarded to the `ConvLSTMCell`.
      spatial_dims: optional sequence (list or tuple) of spatial dimensions
        to use instead of the ones read from `inputs`.

    Returns:
      A pair `(outputs, new_state)` as returned by calling the cell.
    """
    input_shape = common.shape_list(inputs)
    batch_size, input_channels = input_shape[0], input_shape[-1]
    if spatial_dims is None:
        # Drop the batch dimension; the cell wants the per-example shape.
        input_shape = input_shape[1:]
    else:
        # list() lets callers pass a tuple as well as a list; tuple + list
        # would otherwise raise TypeError.
        input_shape = list(spatial_dims) + [input_channels]

    cell = contrib_rnn.ConvLSTMCell(
        2, input_shape, output_channels,
        [kernel_size, kernel_size], name=name)
    if state is None:
        state = cell.zero_state(batch_size, tf.float32)
    outputs, new_state = cell(inputs, state)
    return outputs, new_state