def get_mlp_vae_family(cfg):
  """Gets a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_vae_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}

  datasets = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Build the sonnet module."""
    flat_img = snt.BatchFlatten()(batch["image"])
    latent_size = cfg["enc_hidden_units"][-1]

    def encoder_fn(net):
      hidden_units = cfg["enc_hidden_units"][:-1] + [latent_size * 2]
      mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
      outputs = mod(net)
      return generative_utils.LogStddevNormal(outputs)

    encoder = snt.Module(encoder_fn, name="encoder")

    def decoder_fn(net):
      hidden_units = cfg["dec_hidden_units"] + [
          flat_img.shape.as_list()[1] * 2
      ]
      mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
      net = mod(net)
      net = tf.clip_by_value(net, -10, 10)
      return generative_utils.QuantizedNormal(mu_log_sigma=net)

    decoder = snt.Module(decoder_fn, name="decoder")

    zshape = tf.stack([tf.shape(flat_img)[0], 2 * latent_size])
    prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

    log_p_x, kl_term = generative_utils.log_prob_elbo_components(
        encoder, decoder, prior, flat_img)
    elbo = log_p_x - kl_term

    metrics = {
        "kl_term": tf.reduce_mean(kl_term),
        "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
        "log_p_x": tf.reduce_mean(log_p_x),
        "elbo": tf.reduce_mean(elbo),
        # Note: this is the log of the negative mean ELBO, not of -log_p_x.
        "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo)),
    }

    return base.LossAndAux(-tf.reduce_mean(elbo), metrics)

  return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
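# A minimal usage sketch for the family above. The cfg keys mirror the ones
# read in `get_mlp_vae_family`; the concrete values are illustrative, not
# drawn from `sample_mlp_vae_family_cfg`. The training loss is the negative
# ELBO, -(log p(x|z) - KL(q(z|x) || p(z))), averaged over the batch, with the
# aux metrics exposing both terms separately.
#
#   cfg = {
#       "enc_hidden_units": [128, 32],   # the last entry is the latent size
#       "dec_hidden_units": [128],
#       "activation": "relu",
#       "w_init": ("glorot_normal", {}),
#       "dataset": ("mnist", {}),
#   }
#   task = get_mlp_vae_family(cfg)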
def get_mlp_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}

  # cfg["dataset"] contains (dname, extra_info).
  dataset = utils.get_image_dataset(cfg["dataset"])

  def _fn(batch):
    image = utils.maybe_center(cfg["center_data"], batch["image"])
    hidden_units = cfg["layer_sizes"] + [batch["label_onehot"].shape[1]]
    net = snt.BatchFlatten()(image)
    mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
    logits = mod(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_fn), dataset)
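# A hedged sketch of the cfg consumed by `get_mlp_family`. Per the comment in
# the body, cfg["dataset"] is a (dname, extra_info) tuple, and cfg["w_init"]
# is an (initializer name, kwargs) tuple as in the initializer test below;
# the values here are made up for illustration:
#
#   cfg = {
#       "layer_sizes": [64, 64],    # hidden layers; the output size is
#                                   # appended from the label shape
#       "activation": "relu",
#       "w_init": ("he_normal", {}),
#       "dataset": ("mnist", {}),   # (dname, extra_info)
#       "center_data": True,        # rescale images before the MLP
#   }
#   task = get_mlp_family(cfg)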
def _build(batch):
  """Builds the sonnet module."""
  # Shape is [batch size, sequence length].
  inp = batch["text"]

  # Clip the vocab to be at most vocab_size.
  inp = tf.minimum(inp,
                   tf.to_int64(tf.reshape(cfg["vocab_size"] - 1, [1, 1])))

  embed = snt.Embed(vocab_size=cfg["vocab_size"], embed_dim=cfg["embed_dim"])
  embedded_chars = embed(inp)

  rnn = utils.get_rnn_core(cfg["core"])

  batch_size = inp.shape.as_list()[0]
  state = rnn.initial_state(batch_size, trainable=cfg["trainable_init"])

  outputs, _ = tf.nn.dynamic_rnn(rnn, embedded_chars, initial_state=state)

  w_init = utils.get_initializer(cfg["w_init"])
  pred_logits = snt.BatchApply(
      snt.Linear(cfg["vocab_size"], initializers={"w": w_init}))(
          outputs[:, :-1])
  actual_output_tokens = inp[:, 1:]

  flat_s = [pred_logits.shape[0] * pred_logits.shape[1], pred_logits.shape[2]]
  flat_pred_logits = tf.reshape(pred_logits, flat_s)
  flat_actual_tokens = tf.reshape(actual_output_tokens, [flat_s[0]])

  loss_vec = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=flat_actual_tokens, logits=flat_pred_logits)
  return tf.reduce_mean(loss_vec)
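# A small, self-contained check of the next-token alignment used in `_build`
# above: logits computed from positions 0..T-2 are scored against the tokens
# at positions 1..T-1, so the flattened logits and targets have matching
# leading dimensions. The shapes (batch=2, T=5, vocab=11) are hypothetical.
def _demo_next_token_shift():
  """Illustrative sketch only; not part of the task family code."""
  import numpy as np  # local import to keep the sketch self-contained

  pred_logits = np.zeros((2, 5 - 1, 11))        # logits for positions 0..T-2
  targets = np.arange(10).reshape(2, 5)[:, 1:]  # tokens at positions 1..T-1
  flat_s = [pred_logits.shape[0] * pred_logits.shape[1], pred_logits.shape[2]]
  flat_pred_logits = pred_logits.reshape(flat_s)
  flat_targets = targets.reshape(flat_s[0])
  # One target token per row of flattened logits.
  assert flat_pred_logits.shape[0] == flat_targets.shape[0]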
def get_nvp_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_nvp_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  datasets = utils.get_image_dataset(cfg["dataset"])
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])

  def _build(batch):
    dist = distribution_with_nvp_bijectors(
        batch["image"],
        num_bijectors=cfg["num_bijectors"],
        layers=cfg["hidden_units"],
        activation=act_fn,
        w_init=w_init)
    return neg_log_p(dist, batch["image"])

  base_model_fn = lambda: snt.Module(_build)

  return base.DatasetModelTask(base_model_fn, datasets)
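# A hedged cfg sketch for `get_nvp_family`; the keys mirror the reads above
# and the values are illustrative. The loss is the negative log-probability
# of the image under the flow built by `distribution_with_nvp_bijectors`:
#
#   cfg = {
#       "num_bijectors": 4,
#       "hidden_units": [32, 32],    # layers of each coupling network
#       "activation": "relu",
#       "w_init": ("orthogonal", {}),
#       "dataset": ("mnist", {}),
#   }
#   task = get_nvp_family(cfg)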
def get_mlp_ae_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_ae_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  datasets = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Builds the sonnet module."""
    flat_img = snt.BatchFlatten()(batch["image"])
    if cfg["output_type"] in ["tanh", "linear_center"]:
      flat_img = flat_img * 2.0 - 1.0

    hidden_units = cfg["hidden_units"] + [flat_img.shape.as_list()[1]]
    mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
    outputs = mod(flat_img)

    if cfg["output_type"] == "sigmoid":
      outputs = tf.nn.sigmoid(outputs)
    elif cfg["output_type"] == "tanh":
      outputs = tf.tanh(outputs)
    elif cfg["output_type"] in ["linear", "linear_center"]:
      # Nothing to be done to the outputs.
      pass
    else:
      raise ValueError("Invalid output_type [%s]." % cfg["output_type"])

    reduce_fn = getattr(tf, cfg["reduction_type"])
    if cfg["loss_type"] == "l2":
      loss_vec = reduce_fn(tf.square(outputs - flat_img), axis=1)
    elif cfg["loss_type"] == "l1":
      loss_vec = reduce_fn(tf.abs(outputs - flat_img), axis=1)
    else:
      raise ValueError("Unsupported loss_type [%s]." % cfg["loss_type"])

    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
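# A hedged cfg sketch for `get_mlp_ae_family` (illustrative values). Note the
# target-range logic above: for "tanh" and "linear_center" the flat image is
# rescaled from [0, 1] to [-1, 1] so the target matches the head's output
# range, while "sigmoid" keeps targets in [0, 1]:
#
#   cfg = {
#       "hidden_units": [128, 32, 128],   # bottleneck in the middle
#       "activation": "relu",
#       "w_init": ("glorot_normal", {}),
#       "dataset": ("mnist", {}),
#       "output_type": "sigmoid",         # or "tanh", "linear", "linear_center"
#       "loss_type": "l2",                # or "l1"
#       "reduction_type": "reduce_mean",  # a tf.reduce_* name, e.g. reduce_sum
#   }
#   task = get_mlp_ae_family(cfg)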
def get_conv_pooling_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by
      `sample_conv_pooling_family_cfg`.

  Returns:
    A task for the given config.
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  hidden_units = cfg["hidden_units"]

  dataset = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Builds the sonnet module."""
    image = utils.maybe_center(cfg["center_data"], batch["image"])
    net = snt.nets.ConvNet2D(
        hidden_units,
        kernel_shapes=[(3, 3)],
        strides=cfg["strides"],
        paddings=cfg["padding"],
        activation=act_fn,
        use_bias=cfg["use_bias"],
        initializers=init,
        activate_final=True)(
            image)

    if cfg["pool_type"] == "mean":
      net = tf.reduce_mean(net, axis=[1, 2])
    elif cfg["pool_type"] == "max":
      net = tf.reduce_max(net, axis=[1, 2])
    elif cfg["pool_type"] == "squared_mean":
      net = tf.reduce_mean(net**2, axis=[1, 2])
    else:
      raise ValueError("Unsupported pool_type [%s]." % cfg["pool_type"])

    logits = snt.Linear(batch["label_onehot"].shape[1], initializers=init)(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_build), dataset)
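# A hedged cfg sketch for `get_conv_pooling_family` (values illustrative;
# the per-layer stride and padding formats are assumptions matching what
# `snt.nets.ConvNet2D` accepts). All conv layers use 3x3 kernels, and the
# chosen global pooling collapses the spatial dimensions before the final
# linear classifier:
#
#   cfg = {
#       "hidden_units": [32, 64],     # channels per conv layer
#       "activation": "relu",
#       "w_init": ("he_normal", {}),
#       "dataset": ("cifar10", {}),
#       "center_data": True,
#       "strides": [(2, 2), (1, 1)],  # one stride per layer
#       "padding": ["SAME", "SAME"],  # one padding per layer
#       "use_bias": True,
#       "pool_type": "mean",          # or "max", "squared_mean"
#   }
#   task = get_conv_pooling_family(cfg)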
def test_sample_get_initializer(self):
  rng = np.random.RandomState(123)
  sampled_init = []
  num = 3000
  for _ in range(num):
    init_name, args = utils.sample_initializer(rng)
    sampled_init.append(init_name)
    # Smoke test to ensure the graph builds.
    out = utils.get_initializer((init_name, args))((10, 10))
    self.assertIsInstance(out, tf.Tensor)

  uniques, counts = np.unique(sampled_init, return_counts=True)
  counts_map = {str(u): c for u, c in zip(uniques, counts)}
  # 13 is the total sum of the unnormalized sampling probabilities.
  amount_per_n = num / float(13)
  self.assertNear(counts_map["he_normal"], amount_per_n * 2, 40)
  self.assertNear(counts_map["orthogonal"], amount_per_n, 40)
  self.assertNear(counts_map["glorot_normal"], amount_per_n * 2, 40)
def get_rnn_text_classification_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by
      `sample_rnn_text_classification_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}

  def _build(batch):
    """Build the sonnet module.

    Args:
      batch: A dictionary with keys "label", "label_onehot", and "text"
        mapping to tensors. The "text" consists of int tokens. These tokens
        are truncated to the length of the vocab before performing an
        embedding lookup.

    Returns:
      Loss of the batch.
    """
    vocab_size = cfg["vocab_size"]
    max_token = cfg["dataset"][1]["max_token"]
    if max_token:
      vocab_size = min(max_token, vocab_size)

    # Clip the max token to be vocab_size - 1.
    tokens = tf.minimum(
        tf.to_int32(batch["text"]),
        tf.to_int32(tf.reshape(vocab_size - 1, [1, 1])))

    embed = snt.Embed(vocab_size=vocab_size, embed_dim=cfg["embed_dim"])
    embedded_tokens = embed(tokens)

    rnn = utils.get_rnn_core(cfg["core"])

    batch_size = tokens.shape.as_list()[0]
    state = rnn.initial_state(batch_size, trainable=cfg["trainable_init"])
    outputs, _ = tf.nn.dynamic_rnn(rnn, embedded_tokens, initial_state=state)

    if cfg["loss_compute"] == "last":
      rnn_output = outputs[:, -1]  # grab the last output
    elif cfg["loss_compute"] == "avg":
      rnn_output = tf.reduce_mean(outputs, 1)  # average over length
    elif cfg["loss_compute"] == "max":
      rnn_output = tf.reduce_max(outputs, 1)
    else:
      raise ValueError("Unsupported loss_compute [%s]." % cfg["loss_compute"])

    logits = snt.Linear(
        batch["label_onehot"].shape[1], initializers=init)(
            rnn_output)

    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  datasets = utils.get_text_dataset(cfg["dataset"])
  return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
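# A small, self-contained check (hypothetical shapes) of the three
# `loss_compute` reductions above: each maps RNN outputs of shape
# [batch, time, units] to a single [batch, units] summary vector that
# feeds the final linear classifier.
def _demo_loss_compute_modes():
  """Illustrative sketch only; not part of the task family code."""
  outputs = tf.zeros([4, 7, 16])    # hypothetical [batch, time, units]
  last = outputs[:, -1]             # "last": [4, 16]
  avg = tf.reduce_mean(outputs, 1)  # "avg":  [4, 16]
  maxed = tf.reduce_max(outputs, 1) # "max":  [4, 16]
  return last, avg, maxed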