def get_mlp_family(cfg):
  """Builds an MLP image-classification task from a sampled config.

  Args:
    cfg: config specifying the model generated by `sample_mlp_family_cfg`.

  Returns:
    base.BaseTask whose loss is the mean softmax cross-entropy of an MLP
    classifier applied to flattened images.
  """
  activation = utils.get_activation(cfg["activation"])
  initializers = {"w": utils.get_initializer(cfg["w_init"])}
  # cfg["dataset"] holds a (dataset_name, extra_info) pair.
  image_dataset = utils.get_image_dataset(cfg["dataset"])

  # NOTE: the build function keeps the name `_fn`; snt.Module derives its
  # variable-scope name from the callable when no explicit name is given.
  def _fn(batch):
    labels = batch["label_onehot"]
    centered = utils.maybe_center(cfg["center_data"], batch["image"])
    # Final layer width equals the number of classes (one-hot depth).
    layer_sizes = cfg["layer_sizes"] + [labels.shape[1]]
    flat = snt.BatchFlatten()(centered)
    mlp = snt.nets.MLP(
        layer_sizes, activation=activation, initializers=initializers)
    logits = mlp(flat)
    per_example = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=labels, logits=logits)
    return tf.reduce_mean(per_example)

  return base.DatasetModelTask(lambda: snt.Module(_fn), image_dataset)
def get_mlp_vae_family(cfg):
  """Builds an MLP variational-autoencoder task from a sampled config.

  Args:
    cfg: config specifying the model generated by `sample_mlp_vae_family_cfg`.

  Returns:
    base.BaseTask whose loss is the negative batch-mean ELBO, with scalar
    summary metrics attached as auxiliary outputs.
  """
  activation = utils.get_activation(cfg["activation"])
  initializers = {"w": utils.get_initializer(cfg["w_init"])}
  datasets = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Build the sonnet module."""
    flat_img = snt.BatchFlatten()(batch["image"])
    latent_size = cfg["enc_hidden_units"][-1]

    def encoder_fn(net):
      # All but the last configured width are hidden layers; the output layer
      # emits 2 * latent_size values that parameterize LogStddevNormal.
      layer_sizes = cfg["enc_hidden_units"][:-1] + [2 * latent_size]
      encoder_mlp = snt.nets.MLP(
          layer_sizes, activation=activation, initializers=initializers)
      return generative_utils.LogStddevNormal(encoder_mlp(net))

    encoder = snt.Module(encoder_fn, name="encoder")

    def decoder_fn(net):
      # Two outputs per flattened pixel, handed to QuantizedNormal as
      # `mu_log_sigma`.
      num_pixels = flat_img.shape.as_list()[1]
      layer_sizes = cfg["dec_hidden_units"] + [2 * num_pixels]
      decoder_mlp = snt.nets.MLP(
          layer_sizes, activation=activation, initializers=initializers)
      # Clip to [-10, 10]; presumably to keep the log-sigma half from
      # producing extreme scales -- confirm.
      params = tf.clip_by_value(decoder_mlp(net), -10, 10)
      return generative_utils.QuantizedNormal(mu_log_sigma=params)

    decoder = snt.Module(decoder_fn, name="decoder")

    # Prior parameters are all zeros with shape [batch, 2 * latent_size]
    # (presumably zero mean / zero log-stddev, i.e. a unit normal -- verify
    # against LogStddevNormal).
    prior_shape = tf.stack([tf.shape(flat_img)[0], 2 * latent_size])
    prior = generative_utils.LogStddevNormal(tf.zeros(shape=prior_shape))

    log_p_x, kl_term = generative_utils.log_prob_elbo_components(
        encoder, decoder, prior, flat_img)
    elbo = log_p_x - kl_term

    mean_kl = tf.reduce_mean(kl_term)
    mean_elbo = tf.reduce_mean(elbo)
    metrics = {
        "kl_term": mean_kl,
        "log_kl_term": tf.log(mean_kl),
        "log_p_x": tf.reduce_mean(log_p_x),
        "elbo": mean_elbo,
        # NOTE(review): despite its name this is log(-mean ELBO), not
        # log(-mean log_p_x); confirm this is intended.
        "log_neg_log_p_x": tf.log(-mean_elbo),
    }
    return base.LossAndAux(-mean_elbo, metrics)

  return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
def get_nvp_family(cfg):
  """Builds a RealNVP-style density-modeling task from a sampled config.

  Args:
    cfg: config specifying the model generated by `sample_nvp_family_cfg`.

  Returns:
    base.BaseTask whose loss is `neg_log_p` of the batch images under a
    distribution built by `distribution_with_nvp_bijectors`.
  """
  datasets = utils.get_image_dataset(cfg["dataset"])
  activation = utils.get_activation(cfg["activation"])
  initializer = utils.get_initializer(cfg["w_init"])

  def _build(batch):
    images = batch["image"]
    flow = distribution_with_nvp_bijectors(
        images,
        num_bijectors=cfg["num_bijectors"],
        layers=cfg["hidden_units"],
        activation=activation,
        w_init=initializer)
    return neg_log_p(flow, images)

  return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
def _get_fully_connected(cfg):
  """Get a fully connected problem from the given config.

  Args:
    cfg: dict with keys "n_features", "n_classes", "hidden_sizes",
      "activation", "n_samples" and "bs".

  Returns:
    A 3-tuple of (problem spec for `pg.FullyConnected`, dataset from
    `losg_datasets.random_mlp`, batch size `cfg["bs"]`).
  """
  n_features = cfg["n_features"]
  n_classes = cfg["n_classes"]
  spec_kwargs = {
      "hidden_sizes": tuple(cfg["hidden_sizes"]),
      "activation": utils.get_activation(cfg["activation"]),
  }
  spec = problem_spec.Spec(
      pg.FullyConnected, (n_features, n_classes), spec_kwargs)
  dataset = losg_datasets.random_mlp(n_features, cfg["n_samples"])
  return spec, dataset, cfg["bs"]
def get_mlp_ae_family(cfg):
  """Get a task for the given cfg.

  Builds an MLP autoencoder that reconstructs its flattened input image.

  Args:
    cfg: config specifying the model generated by `sample_mlp_ae_family_cfg`.
      Keys read: "activation", "w_init", "dataset", "hidden_units",
      "output_type", "reduction_type", "loss_type".

  Returns:
    base.BaseTask for the given config.

  Raises:
    ValueError: when the module is built, if `cfg["output_type"]` or
      `cfg["loss_type"]` is not supported.
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  datasets = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Builds the sonnet module."""
    flat_img = snt.BatchFlatten()(batch["image"])
    if cfg["output_type"] in ["tanh", "linear_center"]:
      # Affinely remap targets (x -> 2x - 1). Presumably images are in
      # [0, 1], so targets land in [-1, 1] to match a tanh output -- confirm.
      flat_img = flat_img * 2.0 - 1.0

    # Last layer reconstructs every flattened input element.
    hidden_units = cfg["hidden_units"] + [flat_img.shape.as_list()[1]]
    mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
    outputs = mod(flat_img)

    if cfg["output_type"] == "sigmoid":
      outputs = tf.nn.sigmoid(outputs)
    elif cfg["output_type"] == "tanh":
      outputs = tf.tanh(outputs)
    elif cfg["output_type"] in ["linear", "linear_center"]:
      # Nothing to be done to the outputs.
      pass
    else:
      raise ValueError("Invalid output_type [%s]." % cfg["output_type"])

    # "reduction_type" names a tf function (presumably "reduce_sum" or
    # "reduce_mean") used to reduce the per-element error over axis 1.
    reduce_fn = getattr(tf, cfg["reduction_type"])

    if cfg["loss_type"] == "l2":
      loss_vec = reduce_fn(tf.square(outputs - flat_img), axis=1)
    elif cfg["loss_type"] == "l1":
      loss_vec = reduce_fn(tf.abs(outputs - flat_img), axis=1)
    else:
      # Report the offending loss_type, not the reduction_type.
      raise ValueError("Unsupported loss_type [%s]." % cfg["loss_type"])

    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
def get_conv_pooling_family(cfg):
  """Get a task for the given cfg.

  Builds a conv net, applies global pooling, then a linear classifier trained
  with softmax cross-entropy.

  Args:
    cfg: config specifying the model generated by
      `sample_conv_pooling_family_cfg`.

  Returns:
    A task for the given config.

  Raises:
    ValueError: when the module is built, if `cfg["pool_type"]` is not one
      of "mean", "max" or "squared_mean".
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  hidden_units = cfg["hidden_units"]
  dataset = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Builds the sonnet module."""
    image = utils.maybe_center(cfg["center_data"], batch["image"])
    # A single-element kernel_shapes list is applied to every conv layer.
    net = snt.nets.ConvNet2D(
        hidden_units,
        kernel_shapes=[(3, 3)],
        strides=cfg["strides"],
        paddings=cfg["padding"],
        activation=act_fn,
        use_bias=cfg["use_bias"],
        initializers=init,
        activate_final=True)(
            image)

    # Global pooling over axes 1 and 2 (the spatial axes under the default
    # NHWC layout).
    if cfg["pool_type"] == "mean":
      net = tf.reduce_mean(net, axis=[1, 2])
    elif cfg["pool_type"] == "max":
      net = tf.reduce_max(net, axis=[1, 2])
    elif cfg["pool_type"] == "squared_mean":
      net = tf.reduce_mean(net**2, axis=[1, 2])
    else:
      # Fail fast with a clear message instead of passing an un-pooled
      # feature map into snt.Linear.
      raise ValueError("Invalid pool_type [%s]." % cfg["pool_type"])

    logits = snt.Linear(batch["label_onehot"].shape[1], initializers=init)(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_build), dataset)
def test_sample_get_activation(self):
  """Sampled activations build Tensors and follow the expected frequencies."""
  num = 4000
  rng = np.random.RandomState(123)
  sampled_acts = []
  for _ in range(num):
    aname = utils.sample_activation(rng)
    sampled_acts.append(aname)
    # Smoke test to ensure the graph builds for this activation.
    self.assertIsInstance(
        utils.get_activation(aname)(tf.constant(1.0)), tf.Tensor)

  uniques, counts = np.unique(sampled_acts, return_counts=True)
  counts_map = dict(zip(map(str, uniques), counts))

  # 16 is the total of the unnormalized sampling weights.
  amount_per_unit = num / 16.0
  for name, weight in (("relu", 6), ("tanh", 3), ("swish", 1)):
    self.assertNear(counts_map[name], amount_per_unit * weight, 40)