Code example #1
0
def get_mlp_family(cfg):
  """Build a classification task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  activation = utils.get_activation(cfg["activation"])
  initializers = {"w": utils.get_initializer(cfg["w_init"])}
  # cfg["dataset"] contains (dname, extra_info)
  dataset = utils.get_image_dataset(cfg["dataset"])

  def _forward(batch):
    # Optionally zero-center the pixel data before the MLP.
    images = utils.maybe_center(cfg["center_data"], batch["image"])
    flat = snt.BatchFlatten()(images)
    # The final layer width is the number of classes.
    sizes = cfg["layer_sizes"] + [batch["label_onehot"].shape[1]]
    logits = snt.nets.MLP(
        sizes, activation=activation, initializers=initializers)(flat)
    losses = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(losses)

  return base.DatasetModelTask(lambda: snt.Module(_forward), dataset)
Code example #2
0
def get_mlp_vae_family(cfg):
    """Gets a VAE task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_vae_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
    activation = utils.get_activation(cfg["activation"])
    initializers = {"w": utils.get_initializer(cfg["w_init"])}

    datasets = utils.get_image_dataset(cfg["dataset"])

    def _build(batch):
        """Build the sonnet module computing the negative ELBO loss."""
        flat_img = snt.BatchFlatten()(batch["image"])
        latent_size = cfg["enc_hidden_units"][-1]

        def encoder_fn(net):
            # Last encoder layer outputs mean and log-stddev, hence * 2.
            sizes = cfg["enc_hidden_units"][:-1] + [latent_size * 2]
            out = snt.nets.MLP(sizes,
                               activation=activation,
                               initializers=initializers)(net)
            return generative_utils.LogStddevNormal(out)

        def decoder_fn(net):
            # Decoder predicts mean and log-stddev per flattened pixel.
            sizes = cfg["dec_hidden_units"] + [
                flat_img.shape.as_list()[1] * 2
            ]
            out = snt.nets.MLP(sizes,
                               activation=activation,
                               initializers=initializers)(net)
            # Clip to keep the output distribution numerically stable.
            out = tf.clip_by_value(out, -10, 10)
            return generative_utils.QuantizedNormal(mu_log_sigma=out)

        encoder = snt.Module(encoder_fn, name="encoder")
        decoder = snt.Module(decoder_fn, name="decoder")

        # Standard-normal prior over the latent (mean and log-stddev zeros).
        zshape = tf.stack([tf.shape(flat_img)[0], 2 * latent_size])
        prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

        log_p_x, kl_term = generative_utils.log_prob_elbo_components(
            encoder, decoder, prior, flat_img)
        elbo = log_p_x - kl_term

        metrics = {
            "kl_term": tf.reduce_mean(kl_term),
            "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
            "log_p_x": tf.reduce_mean(log_p_x),
            "elbo": tf.reduce_mean(elbo),
            "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
        }

        # Training minimizes the negative ELBO; metrics ride along as aux.
        return base.LossAndAux(-tf.reduce_mean(elbo), metrics)

    return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
Code example #3
0
File: nvp.py  Project: MitchellTesla/google-research
def get_nvp_family(cfg):
    """Get a NVP density-modeling task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_nvp_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
    datasets = utils.get_image_dataset(cfg["dataset"])
    activation = utils.get_activation(cfg["activation"])
    initializer = utils.get_initializer(cfg["w_init"])

    def _build(batch):
        images = batch["image"]
        dist = distribution_with_nvp_bijectors(
            images,
            num_bijectors=cfg["num_bijectors"],
            layers=cfg["hidden_units"],
            activation=activation,
            w_init=initializer)
        # Loss is the negative log-probability of the images.
        return neg_log_p(dist, images)

    return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
Code example #4
0
def _get_fully_connected(cfg):
    """Builds a (spec, dataset, batch_size) problem tuple from the config."""
    spec = problem_spec.Spec(
        pg.FullyConnected,
        (cfg["n_features"], cfg["n_classes"]),
        {
            "hidden_sizes": tuple(cfg["hidden_sizes"]),
            "activation": utils.get_activation(cfg["activation"]),
        })
    dataset = losg_datasets.random_mlp(cfg["n_features"], cfg["n_samples"])
    return (spec, dataset, cfg["bs"])
Code example #5
0
def get_mlp_ae_family(cfg):
    """Get an MLP autoencoder task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_ae_family_cfg`.

  Returns:
    base.BaseTask for the given config.

  Raises:
    ValueError: if cfg["output_type"] or cfg["loss_type"] is unsupported
      (raised at graph-build time, inside the returned task's model fn).
  """
    act_fn = utils.get_activation(cfg["activation"])
    w_init = utils.get_initializer(cfg["w_init"])
    init = {"w": w_init}

    datasets = utils.get_image_dataset(cfg["dataset"])

    def _build(batch):
        """Builds the sonnet module."""
        flat_img = snt.BatchFlatten()(batch["image"])

        # For tanh / centered-linear outputs, rescale reconstruction targets
        # from [0, 1] to [-1, 1] so they match the output range.
        if cfg["output_type"] in ["tanh", "linear_center"]:
            flat_img = flat_img * 2.0 - 1.0

        # The MLP reconstructs its own input, so the final layer width
        # equals the flattened image size.
        hidden_units = cfg["hidden_units"] + [flat_img.shape.as_list()[1]]
        mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
        outputs = mod(flat_img)

        if cfg["output_type"] == "sigmoid":
            outputs = tf.nn.sigmoid(outputs)
        elif cfg["output_type"] == "tanh":
            outputs = tf.tanh(outputs)
        elif cfg["output_type"] in ["linear", "linear_center"]:
            # nothing to be done to the outputs
            pass
        else:
            raise ValueError("Invalid output_type [%s]." % cfg["output_type"])

        # reduction_type names a tf reduction, e.g. "reduce_mean".
        reduce_fn = getattr(tf, cfg["reduction_type"])
        if cfg["loss_type"] == "l2":
            loss_vec = reduce_fn(tf.square(outputs - flat_img), axis=1)
        elif cfg["loss_type"] == "l1":
            loss_vec = reduce_fn(tf.abs(outputs - flat_img), axis=1)
        else:
            # Bug fix: this message previously interpolated
            # cfg["reduction_type"], hiding the actual offending value.
            raise ValueError("Unsupported loss_type [%s]." % cfg["loss_type"])

        return tf.reduce_mean(loss_vec)

    return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
Code example #6
0
def get_conv_pooling_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by
      `sample_conv_pooling_family_cfg`.

  Returns:
    A task for the given config.

  Raises:
    ValueError: if cfg["pool_type"] is not one of "mean", "max", or
      "squared_mean" (raised at graph-build time, inside the model fn).
  """

  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  hidden_units = cfg["hidden_units"]

  dataset = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Builds the sonnet module."""
    image = utils.maybe_center(cfg["center_data"], batch["image"])

    net = snt.nets.ConvNet2D(
        hidden_units,
        kernel_shapes=[(3, 3)],
        strides=cfg["strides"],
        paddings=cfg["padding"],
        activation=act_fn,
        use_bias=cfg["use_bias"],
        initializers=init,
        activate_final=True)(
            image)
    # Global spatial pooling collapses the (H, W) axes to a feature vector.
    if cfg["pool_type"] == "mean":
      net = tf.reduce_mean(net, axis=[1, 2])
    elif cfg["pool_type"] == "max":
      net = tf.reduce_max(net, axis=[1, 2])
    elif cfg["pool_type"] == "squared_mean":
      net = tf.reduce_mean(net**2, axis=[1, 2])
    else:
      # Previously an unknown pool_type fell through silently, leaving net
      # 4-D and failing later in snt.Linear with a confusing shape error.
      raise ValueError("Invalid pool_type [%s]." % cfg["pool_type"])

    logits = snt.Linear(batch["label_onehot"].shape[1], initializers=init)(net)

    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)

    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_build), dataset)
Code example #7
0
    def test_sample_get_activation(self):
        """Checks sampled activation frequencies match the unnormalized probs."""
        rng = np.random.RandomState(123)
        num_samples = 4000
        names = []
        for _ in range(num_samples):
            name = utils.sample_activation(rng)
            names.append(name)
            # smoke test to ensure graph builds
            tensor = utils.get_activation(name)(tf.constant(1.0))
            self.assertIsInstance(tensor, tf.Tensor)

        values, freqs = np.unique(names, return_counts=True)
        freq_of = {str(v): f for v, f in zip(values, freqs)}
        # 16 is the total sum of unnormalized probs
        expected_per_unit = num_samples / float(16)
        self.assertNear(freq_of["relu"], expected_per_unit * 6, 40)
        self.assertNear(freq_of["tanh"], expected_per_unit * 3, 40)
        self.assertNear(freq_of["swish"], expected_per_unit, 40)