def isemhash_bottleneck(x, bottleneck_size, bottleneck_noise,
                        discretize_warmup_steps, mode,
                        isemhash_noise_dev=0.5, isemhash_mix_prob=0.5):
  """Improved semantic hashing bottleneck."""
  with tf.variable_scope("isemhash_bottleneck"):
    x = tf.layers.dense(x, bottleneck_size, name="dense")
    y = common_layers.saturating_sigmoid(x)
    if isemhash_noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
      noise = tf.truncated_normal(
          common_layers.shape_list(x), mean=0.0, stddev=isemhash_noise_dev)
      y = common_layers.saturating_sigmoid(x + noise)
    d = tf.to_float(tf.less(0.5, y)) + y - tf.stop_gradient(y)
    d = 2.0 * d - 1.0  # Move from [0, 1] to [-1, 1].
    if mode == tf.estimator.ModeKeys.TRAIN:  # Flip some bits.
      noise = tf.random_uniform(common_layers.shape_list(x))
      noise = 2.0 * tf.to_float(tf.less(bottleneck_noise, noise)) - 1.0
      d *= noise
    d = common_layers.mix(d, 2.0 * y - 1.0, discretize_warmup_steps,
                          mode == tf.estimator.ModeKeys.TRAIN,
                          max_prob=isemhash_mix_prob)
    return d
def bottleneck(x, hparams, filter_size, name):
  """Bottleneck."""

  def embed1(x):
    if hparams.bottleneck_kind == "semhash":
      c = int_to_bit(x, c_size)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      return h1a + h1b
    elif hparams.bottleneck_kind == "gumbel-softmax":
      hot = tf.one_hot(x, hparams.v_size)
      with tf.variable_scope(name, reuse=True):
        return tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")

  def embed(x):
    with tf.variable_scope(name, reuse=True):
      h1 = embed1(x)
      h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
      res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
      return res

  with tf.variable_scope(name):
    c_size = hparams.c_size
    l = tf.constant(0.0)
    if hparams.bottleneck_kind == "dense":
      c = tf.layers.dense(x, c_size, name="vcc")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "semhash":
      c = tf.layers.dense(x, c_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(c)
      tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      # l = tf.reduce_mean(y_clean * (1.0 - y_clean))
      if hparams.noise_dev > 0 and hparams.mode == tf.estimator.ModeKeys.TRAIN:
        dev = hparams.noise_dev
        noise = tf.truncated_normal(tf.shape(c), mean=0.0, stddev=dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(hparams.startup_steps * 2)
      pd *= hparams.d_mix
      pd = pd if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.cond(tf.less(tf.random_uniform([]), pd),
                  lambda: y_discrete, lambda: y)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, c_size)
    if hparams.bottleneck_kind == "gumbel-softmax":
      _, hot, l = dae(x, hparams, name)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
    h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
    res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
    return res, c, l, embed
def bit_vae(x, hparams, name):
  with tf.variable_scope(name):
    bity = tf.layers.dense(x, hparams.z_size, name="bity")
    dev = common_layers.inverse_lin_decay(hparams.startup_steps) * 1.5
    noise = tf.random_normal(tf.shape(bity), mean=0.0, stddev=dev)
    y = common_layers.saturating_sigmoid(bity + noise)
    tf.summary.histogram("bit", tf.reshape(y, [-1]))

    def discrete_y():
      d = tf.to_float(tf.less(0.5, y))
      return tf.stop_gradient(d) + y - tf.stop_gradient(y)

    y = tf.cond(tf.less(tf.train.get_global_step(), hparams.startup_steps),
                lambda: y, discrete_y)
    # Flatten and predict for loss.
    y_flat = tf.reshape(y, [-1, hparams.z_size, 1, 1])
    hsize = hparams.hidden_size
    hparams.hidden_size = hsize // 2
    emb0 = tf.get_variable("emb0", [hparams.hidden_size])
    emb1 = tf.get_variable("emb1", [hparams.hidden_size])
    emb0 = tf.reshape(emb0, [1, 1, 1, hparams.hidden_size])
    emb1 = tf.reshape(emb1, [1, 1, 1, hparams.hidden_size])
    y_emb = y_flat * emb1 + (1 - y_flat) * emb0
    y_logit = decode(None, None, y_emb, None, None, hparams, "dbit")
    hparams.hidden_size = hsize
    y_pred = tf.nn.log_softmax(tf.layers.dense(y_logit, 2, name="y_pred"))
    y_flat = tf.reshape(y_flat, [-1])
    y_pred = tf.reshape(y_pred, [-1, 2])
    loss = -(y_flat * y_pred[:, 1] + (1 - y_flat) * y_pred[:, 0])
    # Get the final z and return.
    z = tf.layers.dense(y, hparams.z_size, name="after_bit")
    return z, tf.reduce_mean(loss)
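# The `discrete_y` branch above relies on the straight-through estimator that
# also appears in the semhash bottlenecks in this file: the forward pass emits
# the hard threshold d = [y > 0.5], while the `+ y - tf.stop_gradient(y)` term
# makes the backward pass treat the op as the identity. A minimal sketch
# isolating the trick (TF1 session semantics assumed; tensors are
# illustrative only):


def _straight_through_example():
  y = tf.constant([0.2, 0.7, 0.9])
  d = tf.to_float(tf.less(0.5, y))  # forward value: [0., 1., 1.]
  st = tf.stop_gradient(d) + y - tf.stop_gradient(y)
  grad = tf.gradients(tf.reduce_sum(st), y)[0]  # identity grad: [1., 1., 1.]
  with tf.Session() as sess:
    return sess.run([st, grad])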
def isemhash_bottleneck(x, bottleneck_bits, bottleneck_noise,
                        discretize_warmup_steps, mode,
                        isemhash_noise_dev=0.5, isemhash_mix_prob=0.5):
  """Improved semantic hashing bottleneck."""
  with tf.variable_scope("isemhash_bottleneck"):
    x = tf.layers.dense(x, bottleneck_bits, name="dense")
    y = common_layers.saturating_sigmoid(x)
    if isemhash_noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
      noise = tf.truncated_normal(
          common_layers.shape_list(x), mean=0.0, stddev=isemhash_noise_dev)
      y = common_layers.saturating_sigmoid(x + noise)
    d = tf.to_float(tf.less(0.5, y)) + y - tf.stop_gradient(y)
    d = 2.0 * d - 1.0  # Move from [0, 1] to [-1, 1].
    if mode == tf.estimator.ModeKeys.TRAIN:  # Flip some bits.
      noise = tf.random_uniform(common_layers.shape_list(x))
      noise = 2.0 * tf.to_float(tf.less(bottleneck_noise, noise)) - 1.0
      d *= noise
    d = common_layers.mix(d, 2.0 * y - 1.0, discretize_warmup_steps,
                          mode == tf.estimator.ModeKeys.TRAIN,
                          max_prob=isemhash_mix_prob)
    return d, 0.0
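# A hedged usage sketch for the bottleneck above (batch size, width, and all
# hyperparameter values are illustrative assumptions, not canonical settings):


def _isemhash_usage_example():
  x = tf.random_normal([16, 64])  # a batch of activations
  d, extra_loss = isemhash_bottleneck(
      x,
      bottleneck_bits=14,
      bottleneck_noise=0.1,  # per-bit flip threshold during training
      discretize_warmup_steps=16000,
      mode=tf.estimator.ModeKeys.TRAIN)
  # d has shape [16, 14] with entries in [-1, 1]; this version always
  # returns 0.0 as the extra loss.
  return d, extra_loss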
def testSaturatingSigmoid(self):
  x = np.array([-120.0, -100.0, 0.0, 100.0, 120.0], dtype=np.float32)
  y = common_layers.saturating_sigmoid(tf.constant(x))
  res = self.evaluate(y)
  self.assertAllClose(res, [0.0, 0.0, 0.5, 1.0, 1.0])
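# For reference, a NumPy sketch consistent with the test values above. In
# tensor2tensor, saturating_sigmoid is a rescaled sigmoid clipped so that it
# reaches exactly 0 and 1; the 1.2/0.1 constants below are believed to match
# the library but should be treated as an assumption:


def _saturating_sigmoid_reference():
  import numpy as np
  x = np.array([-120.0, -100.0, 0.0, 100.0, 120.0], dtype=np.float64)
  y = 1.0 / (1.0 + np.exp(-x))  # plain sigmoid
  return np.clip(1.2 * y - 0.1, 0.0, 1.0)  # -> [0., 0., 0.5, 1., 1.]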
def discrete_bottleneck(x,
                        hidden_size,
                        z_size,
                        filter_size,
                        name,
                        mode=None,
                        startup_steps=50000,
                        bottleneck_kind="dvq",
                        num_blocks=2,
                        num_residuals=1,
                        reshape_method="slice",
                        projection_tensors=None,
                        means=None,
                        beta=0.25,
                        noise_dev=1.,
                        decay=0.999,
                        discrete_mix=0.5,
                        random_top_k=1,
                        soft_em=False,
                        num_samples=1,
                        epsilon=1e-5,
                        softmax_k=0,
                        kl_warmup_steps=150000,
                        ema=True,
                        ema_count=None,
                        ema_means=None,
                        summary=True):
  """Discretization bottleneck for latent variables.

  Args:
    x: Input to the discretization bottleneck.
    hidden_size: Dimension of the latent state.
    z_size: Number of bits used to produce a discrete code; discrete codes
      range from 1 to 2**z_size.
    filter_size: Filter size to be used for the embedding function.
    name: Name for the bottleneck scope.
    mode: Mode represents whether we are training or testing for bottlenecks
      that differ in behavior (Default: None).
    startup_steps: Number of steps after which latent predictor is trained
      (Default: 50000).
    bottleneck_kind: Kind of discretization bottleneck to use; one of dvq,
      semhash, gumbel-softmax (Default: dvq).
    num_blocks: Number of blocks to use for decomposed vector quantization
      (Default: 2).
    num_residuals: Number of residual units used to compute nearest
      neighbors (Default: 1).
    reshape_method: Method to reshape for DVQ (Default: slice).
    projection_tensors: If the reshape method is project, then these are the
      tensors used to project (Default: None).
    means: The embedding table for dvq (Default: None).
    beta: Beta factor for the DVQ loss (Default: 0.25).
    noise_dev: Stddev for noise added for semhash (Default: 1.).
    decay: Decay factor for the exponential moving average (Default: 0.999).
    discrete_mix: Factor for mixing discrete and non-discrete input for
      semhash (Default: 0.5).
    random_top_k: Noisy top-k for DVQ (Default: 1).
    soft_em: If True then use soft EM rather than hard EM (Default: False).
    num_samples: Number of samples for soft EM (Default: 1).
    epsilon: Epsilon parameter for DVQ (Default: 1e-5).
    softmax_k: If > 1 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    ema: If True update embeddings using exponential moving averages
      (Default: True).
    ema_count: Table of counts for each embedding corresponding to how many
      examples in a batch it was the closest to (Default: None).
    ema_means: Exponentially averaged version of the embeddings
      (Default: None).
    summary: If True, then write summaries (Default: True).

  Returns:
    Embedding to pass to the decoder, discrete latent, loss, and the
    embedding function.

  Raises:
    ValueError: If projection_tensors is None for reshape_method project, if
      ema_count or ema_means is None while ema is True, or if the reshape
      method or bottleneck kind is unknown.
  """
  tf.logging.info("Shape of x = {}".format(common_layers.shape_list(x)))
  block_v_size = None
  if bottleneck_kind == "dvq":
    # Define the dvq parameters.
    assert means is not None

    # Check that the block dimensions add up.
    if hidden_size % num_blocks != 0:
      raise ValueError("num_blocks does not divide hidden size")
    if z_size % num_residuals != 0:
      raise ValueError("num_residuals does not divide embedding table size")
    z_size_per_residual = int(z_size / num_residuals)
    if z_size_per_residual % num_blocks != 0:
      raise ValueError("num_blocks does not divide embedding table size")
    block_v_size = 2**(z_size_per_residual / num_blocks)
    block_v_size = int(block_v_size)

    # Set the reshape method corresponding to projections or slices.
    if reshape_method == "slice":
      reshape_fn = partial(
          slice_hidden, hidden_size=hidden_size, num_blocks=num_blocks)
    elif reshape_method == "project":
      if projection_tensors is None:
        raise ValueError(
            "Projection tensors is None for reshape_method project")
      reshape_fn = partial(
          project_hidden,
          projection_tensors=projection_tensors,
          hidden_size=hidden_size,
          num_blocks=num_blocks)
    else:
      raise ValueError("Unknown reshape_method")

    # Check if the ema settings make sense.
    if ema:
      if ema_count is None:
        raise ValueError("ema_count is None but ema is True")
      if ema_means is None:
        raise ValueError("ema_means is None but ema is True")

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    l = tf.constant(0.0)
    if bottleneck_kind == "dense":
      c = tf.layers.dense(x, z_size, name="vcc")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    elif bottleneck_kind == "vae":
      c, l, _, _ = vae(x, z_size, "vae")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    elif bottleneck_kind == "semhash":
      c = tf.layers.dense(x, z_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(c)
      if summary:
        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      if noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
        noise = tf.truncated_normal(
            common_layers.shape_list(c), mean=0.0, stddev=noise_dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(startup_steps * 2)
      pd *= discrete_mix
      pd = pd if mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(
          tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]), pd),
          y_discrete, y)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, z_size)
    elif bottleneck_kind == "gumbel-softmax":
      _, hot, l = gumbel_softmax(x, name, z_size, mode, softmax_k,
                                 kl_warmup_steps, summary)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hidden_size, name="dae_dense")
    elif bottleneck_kind == "dvq":
      x_reshaped = reshape_fn(x)
      x_res = x_reshaped
      x_means_hot = []
      x_means = 0
      l = 0
      for i in range(num_residuals):
        x_means_hot_res, x_means_res, q_loss_res, e_loss_res = embedding_lookup(
            x_res, means[i], num_blocks, block_v_size, random_top_k, soft_em,
            num_samples)

        # Update the ema variables.
        if ema:
          tf.logging.info("Using EMA with beta = {}".format(beta))
          updated_ema_count_res = moving_averages.assign_moving_average(
              ema_count[i],
              tf.reduce_sum(
                  tf.reshape(
                      x_means_hot_res, shape=[-1, num_blocks, block_v_size]),
                  axis=0),
              decay,
              zero_debias=False)

          dw = tf.matmul(
              tf.transpose(x_means_hot_res, perm=[1, 2, 0]),
              tf.transpose(x_res, perm=[1, 0, 2]))
          updated_ema_means_res = moving_averages.assign_moving_average(
              ema_means[i], dw, decay, zero_debias=False)
          n = tf.reduce_sum(updated_ema_count_res, axis=-1, keep_dims=True)
          updated_ema_count_res = ((updated_ema_count_res + epsilon) /
                                   (n + 2**z_size * epsilon) * n)
          updated_ema_means_res /= tf.expand_dims(
              updated_ema_count_res, axis=-1)

          with tf.control_dependencies([e_loss_res]):
            update_means_res = tf.assign(means[i], updated_ema_means_res)
            with tf.control_dependencies([update_means_res]):
              l += beta * e_loss_res
        else:
          l += q_loss_res + beta * e_loss_res

        # Update the residuals.
        x_res -= x_means_res
        x_means += x_means_res
        x_means_hot.append(x_means_hot_res)

      # Get the discrete latent representation.
      x_means_hot = tf.stack(x_means_hot, axis=1)
      x_means_idx = tf.argmax(x_means_hot, axis=-1)

      # Get the binary representation.
      x_means_bits = int_to_bit(
          x_means_idx,
          num_bits=int(z_size / (num_residuals * num_blocks)),
          base=2)
      shape = common_layers.shape_list(x_means_bits)
      new_shape = shape[:-2]
      new_shape[-1] = z_size
      x_means_bits = tf.reshape(x_means_bits, shape=new_shape)
      c = bit_to_int(tf.to_int32(x_means_bits), num_bits=z_size, base=2)

      # Adjust the shape of c.
      shape_x = common_layers.shape_list(x)
      new_shape = shape_x[:-1]
      c = tf.reshape(c, new_shape)

      x_means = tf.reshape(x_means, shape_x)
      x_reshaped = tf.reshape(x_reshaped, shape_x)
      h1 = x_reshaped + tf.stop_gradient(x_means - x_reshaped)
    else:
      raise ValueError("Unknown discretization method.")

    h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
    res = tf.layers.dense(tf.nn.relu(h2), hidden_size, name="vcfin")

    embed_fn = partial(
        embed,
        hidden_size=hidden_size,
        z_size=z_size,
        filter_size=filter_size,
        name=name,
        bottleneck_kind=bottleneck_kind,
        num_blocks=num_blocks,
        num_residuals=num_residuals,
        block_v_size=block_v_size,
        means=means)
    return res, c, l, embed_fn
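# A hedged call sketch for the dvq path with EMA disabled, so only the
# codebook list `means` is needed. Every shape here is an illustrative
# assumption about the block layout implied by the EMA update above (per
# residual: [num_blocks, block_v_size, hidden_size // num_blocks]):


def _dvq_usage_example():
  hidden_size, z_size = 64, 8
  num_blocks, num_residuals = 2, 1
  block_v_size = 2**(z_size // (num_residuals * num_blocks))  # 16 codes/block
  means = [
      tf.get_variable("means_%d" % i,
                      [num_blocks, block_v_size, hidden_size // num_blocks])
      for i in range(num_residuals)
  ]
  x = tf.random_normal([32, hidden_size])  # illustrative batch of latents
  res, c, loss, embed_fn = discrete_bottleneck(
      x, hidden_size, z_size, filter_size=128, name="dvq_bottleneck",
      mode=tf.estimator.ModeKeys.TRAIN, num_blocks=num_blocks,
      num_residuals=num_residuals, means=means, ema=False)
  # res feeds the decoder, c holds integer codes, loss is added to training.
  return res, c, loss, embed_fn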
def testSaturatingSigmoid(self):
  x = np.array([-120.0, -100.0, 0.0, 100.0, 120.0], dtype=np.float32)
  with self.test_session() as session:
    y = common_layers.saturating_sigmoid(tf.constant(x))
    res = session.run(y)
    self.assertAllClose(res, [0.0, 0.0, 0.5, 1.0, 1.0])
def bottleneck(x,
               hparams,
               filter_size,
               name,
               means=None,
               ema_count=None,
               ema_means=None):
  """Bottleneck."""
  if hparams.bottleneck_kind == "vq-vae":
    assert means is not None
    if hparams.ema:
      assert ema_count is not None
      assert ema_means is not None

  def embed(x):
    """Embedding function; must be compatible with the code later."""
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
      if hparams.bottleneck_kind == "semhash":
        c = int_to_bit(x, z_size)
        h1a = tf.layers.dense(c, filter_size, name="vch1a")
        h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
        h1 = h1a + h1b
      elif hparams.bottleneck_kind == "gumbel-softmax":
        hot = tf.one_hot(x, hparams.v_size)
        h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
      elif hparams.bottleneck_kind == "vq-vae":
        if hparams.ema:
          means_embed = ema_means
        else:
          means_embed = means
        h1 = tf.gather(means_embed, x)
      elif hparams.bottleneck_kind == "rounding":
        h1 = x
      h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
      return tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    z_size = hparams.z_size
    l = tf.constant(0.0)
    if hparams.bottleneck_kind == "dense":
      c = tf.layers.dense(x, z_size, name="vcc")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "vae":
      c, l, _, _ = vae(x, z_size, "vae")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "semhash":
      c = tf.layers.dense(x, z_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(c)
      if _DO_SUMMARIES:
        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      if hparams.noise_dev > 0 and hparams.mode == tf.estimator.ModeKeys.TRAIN:
        dev = hparams.noise_dev
        noise = tf.truncated_normal(
            common_layers.shape_list(c), mean=0.0, stddev=dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(hparams.startup_steps * 2)
      pd *= hparams.d_mix
      pd = pd if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(
          tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]), pd),
          y_discrete, y)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, z_size)
    if hparams.bottleneck_kind == "gumbel-softmax":
      _, hot, l = dae(x, hparams, name)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
    if hparams.bottleneck_kind == "vq-vae":
      x_means_hot, x_means, q_loss, e_loss = kmeans(x, means, hparams)
      c = tf.argmax(x_means_hot, axis=-1)

      # Update the ema variables.
      if hparams.ema:
        tf.logging.info("Using EMA with beta = {}".format(hparams.beta))
        x_means_hot_flat = tf.reshape(x_means_hot, shape=[-1, hparams.v_size])
        updated_ema_count = moving_averages.assign_moving_average(
            ema_count,
            tf.reduce_sum(x_means_hot_flat, axis=0),
            hparams.decay,
            zero_debias=False)
        x_flat = tf.reshape(x, [-1, hparams.hidden_size])
        dw = tf.matmul(x_means_hot_flat, x_flat, transpose_a=True)
        updated_ema_means = moving_averages.assign_moving_average(
            ema_means, dw, hparams.decay, zero_debias=False)
        n = tf.reduce_sum(updated_ema_count)
        updated_ema_count = ((updated_ema_count + hparams.epsilon) /
                             (n + hparams.v_size * hparams.epsilon) * n)
        updated_ema_means /= tf.expand_dims(updated_ema_count, axis=-1)

        with tf.control_dependencies([e_loss]):
          update_means = tf.assign(means, updated_ema_means)
          with tf.control_dependencies([update_means]):
            l = hparams.beta * e_loss
      else:
        l = q_loss + hparams.beta * e_loss
      h1 = tf.stop_gradient(x_means) + x - tf.stop_gradient(x)
    if hparams.bottleneck_kind == "rounding":
      h = tf.layers.dense(x, 1, name="vcc")

      # Make h between 0 and 1.
      h = tf.sigmoid(h)

      # Multiply by v_size to get it into [0, v_size].
      h *= hparams.v_size

      # Use the rounding bottleneck.
      h1 = h + tf.stop_gradient(tf.round(h) - h)
      c = tf.squeeze(tf.round(h), axis=-1)
      c = tf.to_int32(c)
    h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
    res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
    return res, c, l, embed
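# The EMA branch above mirrors the moving-average codebook update described
# in the appendix of the VQ-VAE paper: per-code usage counts and per-code
# vector sums are tracked with exponential moving averages, the counts are
# Laplace-smoothed, and the codebook is their ratio. A NumPy sketch of one
# update step (shapes and constants are illustrative):


def _ema_codebook_update_example():
  import numpy as np
  decay, epsilon, v_size, hidden = 0.999, 1e-5, 4, 2
  ema_count = np.ones(v_size)                  # running usage per code
  ema_means = np.random.randn(v_size, hidden)  # running vector sums
  x_flat = np.random.randn(6, hidden)          # flattened encoder outputs
  assign = np.eye(v_size)[np.random.randint(v_size, size=6)]  # one-hot codes

  ema_count = decay * ema_count + (1 - decay) * assign.sum(axis=0)
  ema_means = decay * ema_means + (1 - decay) * assign.T.dot(x_flat)
  n = ema_count.sum()
  smoothed = (ema_count + epsilon) / (n + v_size * epsilon) * n
  return ema_means / smoothed[:, None]         # updated codebook entries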
def bottleneck(x, hparams, filter_size, name):
  """Bottleneck."""

  def embed(x):
    """Embedding function; must be compatible with the code later."""
    with tf.variable_scope(name, reuse=True):
      if hparams.bottleneck_kind == "semhash":
        c = int_to_bit(x, z_size)
        h1a = tf.layers.dense(c, filter_size, name="vch1a")
        h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
        h1 = h1a + h1b
      elif hparams.bottleneck_kind == "gumbel-softmax":
        hot = tf.one_hot(x, hparams.v_size)
        h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
      elif hparams.bottleneck_kind == "vq-vae":
        means = tf.get_variable(
            name="means", shape=[hparams.v_size, hparams.hidden_size])
        h1 = tf.gather(means, x)
      elif hparams.bottleneck_kind == "rounding":
        h1 = x
      h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
      return tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")

  with tf.variable_scope(name):
    z_size = hparams.z_size
    l = tf.constant(0.0)
    if hparams.bottleneck_kind == "dense":
      c = tf.layers.dense(x, z_size, name="vcc")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "vae":
      c, l, _, _ = vae(x, z_size, "vae")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "semhash":
      c = tf.layers.dense(x, z_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(c)
      if _DO_SUMMARIES:
        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      if hparams.noise_dev > 0 and hparams.mode == tf.estimator.ModeKeys.TRAIN:
        dev = hparams.noise_dev
        noise = tf.truncated_normal(
            common_layers.shape_list(c), mean=0.0, stddev=dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(hparams.startup_steps * 2)
      pd *= hparams.d_mix
      pd = pd if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(
          tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]), pd),
          y_discrete, y)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, z_size)
    if hparams.bottleneck_kind == "gumbel-softmax":
      _, hot, l = dae(x, hparams, name)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
    if hparams.bottleneck_kind == "vq-vae":
      means = tf.get_variable(
          name="means", shape=[hparams.v_size, hparams.hidden_size])
      x_means_hot, x_means, l = kmeans(x, means, hparams, name="vq-vae-kmeans")
      h1 = tf.stop_gradient(x_means) + x - tf.stop_gradient(x)
      c = tf.argmax(x_means_hot, axis=-1)
    if hparams.bottleneck_kind == "rounding":
      h = tf.layers.dense(x, 1, name="vcc")

      # Make h between 0 and 1.
      h = tf.sigmoid(h)

      # Multiply by v_size to get it into [0, v_size].
      h *= hparams.v_size

      # Use the rounding bottleneck.
      h1 = h + tf.stop_gradient(tf.round(h) - h)
      c = tf.squeeze(tf.round(h), axis=-1)
      c = tf.to_int32(c)
    h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
    res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
    return res, c, l, embed
def discrete_bottleneck(x,
                        hidden_size,
                        z_size,
                        filter_size,
                        name,
                        mode=None,
                        startup_steps=50000,
                        bottleneck_kind="dvq",
                        num_blocks=2,
                        num_residuals=1,
                        reshape_method="slice",
                        projection_tensors=None,
                        means=None,
                        beta=0.25,
                        noise_dev=1.,
                        decay=0.999,
                        discrete_mix=0.5,
                        random_top_k=1,
                        soft_em=False,
                        num_samples=1,
                        epsilon=1e-5,
                        softmax_k=0,
                        kl_warmup_steps=150000,
                        ema=True,
                        ema_count=None,
                        ema_means=None,
                        summary=True):
  """Discretization bottleneck for latent variables.

  Args:
    x: Input to the discretization bottleneck.
    hidden_size: Dimension of the latent state.
    z_size: Number of bits used to produce a discrete code; discrete codes
      range from 1 to 2**z_size.
    filter_size: Filter size to be used for the embedding function.
    name: Name for the bottleneck scope.
    mode: Mode represents whether we are training or testing for bottlenecks
      that differ in behavior (Default: None).
    startup_steps: Number of steps after which latent predictor is trained
      (Default: 50000).
    bottleneck_kind: Kind of discretization bottleneck to use; one of dvq,
      semhash, gumbel-softmax (Default: dvq).
    num_blocks: Number of blocks to use for decomposed vector quantization
      (Default: 2).
    num_residuals: Number of residual units used to compute nearest
      neighbors (Default: 1).
    reshape_method: Method to reshape for DVQ (Default: slice).
    projection_tensors: If the reshape method is project, then these are the
      tensors used to project (Default: None).
    means: The embedding table for dvq (Default: None).
    beta: Beta factor for the DVQ loss (Default: 0.25).
    noise_dev: Stddev for noise added for semhash (Default: 1.).
    decay: Decay factor for the exponential moving average (Default: 0.999).
    discrete_mix: Factor for mixing discrete and non-discrete input for
      semhash (Default: 0.5).
    random_top_k: Noisy top-k for DVQ (Default: 1).
    soft_em: If True then use soft EM rather than hard EM (Default: False).
    num_samples: Number of samples for soft EM (Default: 1).
    epsilon: Epsilon parameter for DVQ (Default: 1e-5).
    softmax_k: If > 1 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    ema: If True update embeddings using exponential moving averages
      (Default: True).
    ema_count: Table of counts for each embedding corresponding to how many
      examples in a batch it was the closest to (Default: None).
    ema_means: Exponentially averaged version of the embeddings
      (Default: None).
    summary: If True, then write summaries (Default: True).

  Returns:
    Embedding to pass to the decoder, discrete latent, loss, and the
    embedding function.

  Raises:
    ValueError: If projection_tensors is None for reshape_method project, if
      ema_count or ema_means is None while ema is True, or if the reshape
      method or bottleneck kind is unknown.
  """
  block_v_size = None
  if bottleneck_kind == "dvq":
    # Define the dvq parameters.
    assert means is not None

    # Check that the block dimensions add up.
    if hidden_size % num_blocks != 0:
      raise ValueError("num_blocks does not divide hidden size")
    if z_size % num_residuals != 0:
      raise ValueError("num_residuals does not divide embedding table size")
    z_size_per_residual = int(z_size / num_residuals)
    if z_size_per_residual % num_blocks != 0:
      raise ValueError("num_blocks does not divide embedding table size")
    block_v_size = 2**(z_size_per_residual / num_blocks)
    block_v_size = int(block_v_size)

    # Set the reshape method corresponding to projections or slices.
    if reshape_method == "slice":
      reshape_fn = partial(
          slice_hidden, hidden_size=hidden_size, num_blocks=num_blocks)
    elif reshape_method == "project":
      if projection_tensors is None:
        raise ValueError(
            "Projection tensors is None for reshape_method project")
      reshape_fn = partial(
          project_hidden,
          projection_tensors=projection_tensors,
          hidden_size=hidden_size,
          num_blocks=num_blocks)
    else:
      raise ValueError("Unknown reshape_method")

    # Check if the ema settings make sense.
    if ema:
      if ema_count is None:
        raise ValueError("ema_count is None but ema is True")
      if ema_means is None:
        raise ValueError("ema_means is None but ema is True")

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    l = tf.constant(0.0)
    if bottleneck_kind == "dense":
      c = tf.layers.dense(x, z_size, name="vcc")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    elif bottleneck_kind == "vae":
      c, l, _, _ = vae(x, z_size, "vae")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    elif bottleneck_kind == "semhash":
      c = tf.layers.dense(x, z_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(c)
      if summary:
        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      if noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
        noise = tf.truncated_normal(
            common_layers.shape_list(c), mean=0.0, stddev=noise_dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(startup_steps * 2)
      pd *= discrete_mix
      pd = pd if mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(
          tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]), pd),
          y_discrete, y)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, z_size)
    elif bottleneck_kind == "gumbel-softmax":
      _, hot, l = gumbel_softmax(x, name, z_size, mode, softmax_k,
                                 kl_warmup_steps, summary)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hidden_size, name="dae_dense")
    elif bottleneck_kind == "dvq":
      x_reshaped = reshape_fn(x)
      x_res = x_reshaped
      x_means_hot = []
      x_means = 0
      l = 0
      for i in range(num_residuals):
        x_means_hot_res, x_means_res, q_loss_res, e_loss_res = embedding_lookup(
            x_res, means[i], num_blocks, block_v_size, random_top_k, soft_em,
            num_samples)

        # Update the ema variables.
        if ema:
          tf.logging.info("Using EMA with beta = {}".format(beta))
          updated_ema_count_res = moving_averages.assign_moving_average(
              ema_count[i],
              tf.reduce_sum(
                  tf.reshape(
                      x_means_hot_res, shape=[-1, num_blocks, block_v_size]),
                  axis=0),
              decay,
              zero_debias=False)

          dw = tf.matmul(
              tf.transpose(x_means_hot_res, perm=[1, 2, 0]),
              tf.transpose(x_res, perm=[1, 0, 2]))
          updated_ema_means_res = moving_averages.assign_moving_average(
              ema_means[i], dw, decay, zero_debias=False)
          n = tf.reduce_sum(updated_ema_count_res, axis=-1, keep_dims=True)
          updated_ema_count_res = ((updated_ema_count_res + epsilon) /
                                   (n + 2**z_size * epsilon) * n)
          # pylint: disable=g-no-augmented-assignment
          updated_ema_means_res = updated_ema_means_res / tf.expand_dims(
              updated_ema_count_res, axis=-1)
          # pylint: enable=g-no-augmented-assignment

          with tf.control_dependencies([e_loss_res]):
            update_means_res = tf.assign(means[i], updated_ema_means_res)
            with tf.control_dependencies([update_means_res]):
              l += beta * e_loss_res
        else:
          l += q_loss_res + beta * e_loss_res

        # Update the residuals.
        x_res -= x_means_res
        x_means += x_means_res
        x_means_hot.append(x_means_hot_res)

      # Get the discrete latent representation.
      x_means_hot = tf.stack(x_means_hot, axis=1)
      x_means_idx = tf.argmax(x_means_hot, axis=-1)

      # Get the binary representation.
      x_means_bits = int_to_bit(
          x_means_idx,
          num_bits=int(z_size / (num_residuals * num_blocks)),
          base=2)
      shape = common_layers.shape_list(x_means_bits)
      new_shape = shape[:-2]
      new_shape[-1] = z_size
      x_means_bits = tf.reshape(x_means_bits, shape=new_shape)
      c = bit_to_int(tf.to_int32(x_means_bits), num_bits=z_size, base=2)

      # Adjust the shape of c.
      shape_x = common_layers.shape_list(x)
      new_shape = shape_x[:-1]
      c = tf.reshape(c, new_shape)

      # If we are doing soft EM then c is x_means_hot.
      if soft_em:
        c = x_means_hot
        new_shape.append(block_v_size)
        c = tf.reshape(c, new_shape)

      x_means = tf.reshape(x_means, shape_x)
      x_reshaped = tf.reshape(x_reshaped, shape_x)
      h1 = x_reshaped + tf.stop_gradient(x_means - x_reshaped)
    else:
      raise ValueError("Unknown discretization method.")

    res = h1

    embed_fn = partial(
        embed,
        hidden_size=hidden_size,
        z_size=z_size,
        filter_size=filter_size,
        name=name,
        bottleneck_kind=bottleneck_kind,
        soft_em=soft_em,
        num_blocks=num_blocks,
        num_residuals=num_residuals,
        block_v_size=block_v_size,
        means=means)
    return res, c, l, embed_fn
def discrete_bottleneck(x,
                        hidden_size,
                        z_size,
                        filter_size,
                        name,
                        mode=None,
                        startup_steps=50000,
                        bottleneck_kind='dvq',
                        num_blocks=2,
                        reshape_method='slice',
                        projection_tensors=None,
                        means=None,
                        beta=0.25,
                        noise_dev=1.,
                        decay=0.999,
                        discrete_mix=0.5,
                        random_top_k=1,
                        epsilon=1e-5,
                        softmax_k=0,
                        kl_warmup_steps=150000,
                        ema=True,
                        ema_count=None,
                        ema_means=None,
                        summary=True,
                        dp_strength=1.0,
                        dp_decay=1.0,
                        dp_alpha=0.5):
  """Discretization bottleneck for latent variables.

  Args:
    x: Input to the discretization bottleneck.
    hidden_size: Dimension of the latent state.
    z_size: Number of bits used to produce a discrete code; discrete codes
      range from 1 to 2**z_size.
    filter_size: Filter size to be used for the embedding function.
    name: Name for the bottleneck scope.
    mode: Mode represents whether we are training or testing for bottlenecks
      that differ in behavior (Default: None).
    startup_steps: Number of steps after which latent predictor is trained
      (Default: 50000).
    bottleneck_kind: Kind of discretization bottleneck to use; one of dvq,
      semhash, gumbel-softmax (Default: dvq).
    num_blocks: Number of blocks to use for decomposed vector quantization.
    reshape_method: Method to reshape for DVQ (Default: slice).
    projection_tensors: If the reshape method is project, then these are the
      tensors used to project (Default: None).
    means: The embedding table for dvq (Default: None).
    beta: Beta factor for the DVQ loss (Default: 0.25).
    noise_dev: Stddev for noise added for semhash (Default: 1.).
    decay: Decay factor for the exponential moving average (Default: 0.999).
    discrete_mix: Factor for mixing discrete and non-discrete input for
      semhash (Default: 0.5).
    random_top_k: Noisy top-k for DVQ (Default: 1).
    epsilon: Epsilon parameter for DVQ (Default: 1e-5).
    softmax_k: If > 1 then do top-k softmax (Default: 0).
    kl_warmup_steps: Number of steps for kl warmup (Default: 150000).
    ema: If True update embeddings using exponential moving averages
      (Default: True).
    ema_count: Table of counts for each embedding corresponding to how many
      examples in a batch it was the closest to (Default: None).
    ema_means: Exponentially averaged version of the embeddings
      (Default: None).
    summary: If True, then write summaries (Default: True).
    dp_strength: Strength of Dirichlet Process loss prior (Default: 1.0).
    dp_decay: Decay the dp_strength using an exponential decay with this
      rate (Default: 1.0).
    dp_alpha: Alpha term (pseudo-count) in Dirichlet Process (Default: 0.5).

  Returns:
    Embedding to pass to the decoder, discrete latent, loss, and the
    embedding function.

  Raises:
    ValueError: If projection_tensors is None for reshape_method project, if
      ema_count or ema_means is None while ema is True, or if the reshape
      method or bottleneck kind is unknown.
  """
  block_v_size = None
  if bottleneck_kind == 'dvq':
    # Define the dvq parameters.
    assert means is not None

    # Check that the block dimensions add up.
    if hidden_size % num_blocks != 0:
      raise ValueError('num_blocks does not divide hidden size')
    if 2**z_size % num_blocks != 0:
      raise ValueError('num_blocks does not divide embedding table size')
    block_v_size = 2**(z_size / num_blocks)
    block_v_size = int(block_v_size)

    # Set the reshape method corresponding to projections or slices.
    if reshape_method == 'slice':
      reshape_fn = partial(
          slice_hidden, hidden_size=hidden_size, num_blocks=num_blocks)
    elif reshape_method == 'project':
      if projection_tensors is None:
        raise ValueError(
            'Projection tensors is None for reshape_method project')
      reshape_fn = partial(
          project_hidden,
          projection_tensors=projection_tensors,
          hidden_size=hidden_size,
          num_blocks=num_blocks)
    else:
      raise ValueError('Unknown reshape_method')

    # Check if the ema settings make sense.
    if ema:
      if ema_count is None:
        raise ValueError('ema_count is None but ema is True')
      if ema_means is None:
        raise ValueError('ema_means is None but ema is True')

  with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
    l = tf.constant(0.0)
    if bottleneck_kind == 'dense':
      c = tf.layers.dense(x, z_size, name='vcc')
      h1 = tf.layers.dense(c, filter_size, name='vch1')
    elif bottleneck_kind == 'vae':
      c, l, _, _ = vae(x, z_size, 'vae')
      h1 = tf.layers.dense(c, filter_size, name='vch1')
    elif bottleneck_kind == 'semhash':
      c = tf.layers.dense(x, z_size, name='vcc')
      y_clean = common_layers.saturating_sigmoid(c)
      if summary:
        tf.summary.histogram('y_clean', tf.reshape(y_clean, [-1]))
      if noise_dev > 0 and mode == tf.estimator.ModeKeys.TRAIN:
        noise = tf.truncated_normal(
            common_layers.shape_list(c), mean=0.0, stddev=noise_dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(startup_steps * 2)
      pd *= discrete_mix
      pd = pd if mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.where(
          tf.less(tf.random_uniform([common_layers.shape_list(y)[0]]), pd),
          y_discrete, y)
      h1a = tf.layers.dense(c, filter_size, name='vch1a')
      h1b = tf.layers.dense(1.0 - c, filter_size, name='vch1b')
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, z_size)
    elif bottleneck_kind == 'gumbel-softmax':
      _, hot, l = gumbel_softmax(x, name, z_size, mode, softmax_k,
                                 kl_warmup_steps, summary)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hidden_size, name='dae_dense')
    elif bottleneck_kind == 'dvq':
      x_reshaped = reshape_fn(x)
      x_means_hot, x_means, q_loss, e_loss = embedding_lookup(
          x_reshaped, means, num_blocks, block_v_size, random_top_k)

      # Get the discrete latent representation.
      x_means_idx = tf.argmax(x_means_hot, axis=-1)

      # Get the binary representation.
      x_means_bits = int_to_bit(
          x_means_idx, num_bits=int(z_size / num_blocks), base=2)
      shape = common_layers.shape_list(x_means_bits)
      new_shape = shape[:-1]
      new_shape[-1] = z_size
      x_means_bits = tf.reshape(x_means_bits, shape=new_shape)
      c = bit_to_int(tf.to_int32(x_means_bits), num_bits=z_size, base=2)

      # Adjust the shape of c.
      shape_x = common_layers.shape_list(x)
      new_shape = shape_x[:-1]
      c = tf.reshape(c, new_shape)

      # Update the ema variables.
      if ema:
        tf.logging.info('Using EMA with beta = {}'.format(beta))
        updated_ema_count = moving_averages.assign_moving_average(
            ema_count,
            tf.reduce_sum(
                tf.reshape(x_means_hot, shape=[-1, num_blocks, block_v_size]),
                axis=0),
            decay,
            zero_debias=False)

        # Adding a term that puts a Dirichlet prior over cluster
        # probabilities. Hopefully it'll encourage rich-get-richer behaviors.
        dp_prior_loss = 0.
        if dp_strength > 0.0:
          # Decay dp_strength over time to make it less important.
          dp_strength = tf.train.exponential_decay(
              dp_strength,
              global_step=tf.to_int32(tf.train.get_global_step()),
              decay_steps=20000,
              decay_rate=dp_decay)
          dp_count = ema_count + dp_alpha
          p = dp_count / tf.reduce_sum(dp_count, 1, keepdims=True)
          dp_prior_loss = tf.log(p)
          dp_prior_loss = -1.0 * tf.reduce_sum(dp_prior_loss)
          dp_prior_loss /= (num_blocks * block_v_size)

        x_means_hot_flat = tf.reshape(
            x_means_hot, shape=[-1, num_blocks, block_v_size])
        dw = tf.matmul(
            tf.transpose(x_means_hot_flat, perm=[1, 2, 0]),
            tf.transpose(x_reshaped, perm=[1, 0, 2]))
        updated_ema_means = moving_averages.assign_moving_average(
            ema_means, dw, decay, zero_debias=False)
        n = tf.reduce_sum(updated_ema_count, axis=-1, keep_dims=True)
        updated_ema_count = ((updated_ema_count + epsilon) /
                             (n + 2**z_size * epsilon) * n)
        updated_ema_means /= tf.expand_dims(updated_ema_count, axis=-1)

        with tf.control_dependencies([e_loss]):
          update_means = tf.assign(means, updated_ema_means)
          with tf.control_dependencies([update_means]):
            l = beta * e_loss + dp_strength * dp_prior_loss
      else:
        l = q_loss + beta * e_loss

      x_means = tf.reshape(x_means, shape_x)
      x_reshaped = tf.reshape(x_reshaped, shape_x)
      h1 = x_reshaped + tf.stop_gradient(x_means - x_reshaped)
    else:
      raise ValueError('Unknown discretization method.')

    h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name='vch2')
    res = tf.layers.dense(tf.nn.relu(h2), hidden_size, name='vcfin')

    embed_fn = partial(
        embed,
        hidden_size=hidden_size,
        z_size=z_size,
        filter_size=filter_size,
        name=name,
        bottleneck_kind=bottleneck_kind,
        num_blocks=num_blocks,
        block_v_size=block_v_size,
        means=means)
    return res, c, l, embed_fn
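# A NumPy sketch of the Dirichlet-process prior arithmetic used above:
# pseudo-counts are added to the EMA usage counts, normalized per block into
# probabilities, and the negative log-probabilities are averaged over the
# whole table (shapes and the alpha value are illustrative):


def _dp_prior_loss_example():
  import numpy as np
  num_blocks, block_v_size, dp_alpha = 2, 4, 0.5
  ema_count = np.random.rand(num_blocks, block_v_size)  # running code usage
  dp_count = ema_count + dp_alpha                       # add pseudo-counts
  p = dp_count / dp_count.sum(axis=1, keepdims=True)    # per-block probs
  return -np.log(p).sum() / (num_blocks * block_v_size)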
def bottleneck(x, hparams, filter_size, name):
  """Bottleneck."""

  def embed(x):
    """Embedding function; must be compatible with the code later."""
    with tf.variable_scope(name, reuse=True):
      if hparams.bottleneck_kind == "semhash":
        c = int_to_bit(x, z_size)
        h1a = tf.layers.dense(c, filter_size, name="vch1a")
        h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
        h1 = h1a + h1b
      elif hparams.bottleneck_kind == "gumbel-softmax":
        hot = tf.one_hot(x, hparams.v_size)
        h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
      elif hparams.bottleneck_kind == "vq-vae":
        means = tf.get_variable(
            name="means", shape=[hparams.v_size, hparams.hidden_size])
        h1 = tf.gather(means, x)
      h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
      return tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")

  with tf.variable_scope(name):
    z_size = hparams.z_size
    l = tf.constant(0.0)
    if hparams.bottleneck_kind == "dense":
      c = tf.layers.dense(x, z_size, name="vcc")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "vae":
      c, l, _, _ = vae(x, z_size, "vae")
      h1 = tf.layers.dense(c, filter_size, name="vch1")
    if hparams.bottleneck_kind == "semhash":
      c = tf.layers.dense(x, z_size, name="vcc")
      y_clean = common_layers.saturating_sigmoid(c)
      if _DO_SUMMARIES:
        tf.summary.histogram("y_clean", tf.reshape(y_clean, [-1]))
      if hparams.noise_dev > 0 and hparams.mode == tf.estimator.ModeKeys.TRAIN:
        dev = hparams.noise_dev
        noise = tf.truncated_normal(tf.shape(c), mean=0.0, stddev=dev)
        y = common_layers.saturating_sigmoid(c + noise)
      else:
        y = y_clean
      d = tf.to_float(tf.less(0.5, y))
      y_discrete = tf.stop_gradient(d) + y - tf.stop_gradient(y)
      pd = common_layers.inverse_exp_decay(hparams.startup_steps * 2)
      pd *= hparams.d_mix
      pd = pd if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      c = tf.cond(tf.less(tf.random_uniform([]), pd),
                  lambda: y_discrete, lambda: y)
      h1a = tf.layers.dense(c, filter_size, name="vch1a")
      h1b = tf.layers.dense(1.0 - c, filter_size, name="vch1b")
      h1 = h1a + h1b
      dx = tf.to_int32(tf.stop_gradient(d))
      c = bit_to_int(dx, z_size)
    if hparams.bottleneck_kind == "gumbel-softmax":
      _, hot, l = dae(x, hparams, name)
      c = tf.argmax(hot, axis=-1)
      h1 = tf.layers.dense(hot, hparams.hidden_size, name="dae_dense")
    if hparams.bottleneck_kind == "vq-vae":
      means = tf.get_variable(
          name="means", shape=[hparams.v_size, hparams.hidden_size])
      x_means_hot, x_means, l = kmeans(x, means, hparams, name="vq-vae-kmeans")
      h1 = x_means
      c = tf.argmax(x_means_hot, axis=-1)
    h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2")
    res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin")
    return res, c, l, embed
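# A hedged construction sketch for calling this bottleneck; the field names
# mirror the hparams attributes read above, while the concrete values and the
# use of tf.contrib.training.HParams (TF1) are illustrative assumptions:


def _bottleneck_usage_example():
  hparams = tf.contrib.training.HParams(
      bottleneck_kind="semhash",
      z_size=14,
      v_size=2**14,
      hidden_size=64,
      noise_dev=0.5,
      d_mix=0.5,
      startup_steps=10000,
      mode=tf.estimator.ModeKeys.TRAIN)
  x = tf.random_normal([8, 64])  # an illustrative latent batch
  res, c, loss, embed_fn = bottleneck(x, hparams, filter_size=128, name="vc")
  # embed_fn(c) re-embeds the discrete codes with the same (reused) weights.
  return res, c, loss, embed_fn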