Example #1
0
def get_mlp_vae_family(cfg):
    """Gets a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_vae_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
    act_fn = utils.get_activation(cfg["activation"])
    w_init = utils.get_initializer(cfg["w_init"])
    init = {"w": w_init}

    datasets = utils.get_image_dataset(cfg["dataset"])

    def _build(batch):
        """Build the sonnet module."""
        flat_img = snt.BatchFlatten()(batch["image"])
        latent_size = cfg["enc_hidden_units"][-1]

        def encoder_fn(net):
            hidden_units = cfg["enc_hidden_units"][:-1] + [latent_size * 2]
            mod = snt.nets.MLP(hidden_units,
                               activation=act_fn,
                               initializers=init)
            outputs = mod(net)
            return generative_utils.LogStddevNormal(outputs)

        encoder = snt.Module(encoder_fn, name="encoder")

        def decoder_fn(net):
            hidden_units = cfg["dec_hidden_units"] + [
                flat_img.shape.as_list()[1] * 2
            ]
            mod = snt.nets.MLP(hidden_units,
                               activation=act_fn,
                               initializers=init)
            net = mod(net)
            net = tf.clip_by_value(net, -10, 10)
            return generative_utils.QuantizedNormal(mu_log_sigma=net)

        decoder = snt.Module(decoder_fn, name="decoder")
        zshape = tf.stack([tf.shape(flat_img)[0], 2 * latent_size])
        prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

        log_p_x, kl_term = generative_utils.log_prob_elbo_components(
            encoder, decoder, prior, flat_img)
        elbo = log_p_x - kl_term

        metrics = {
            "kl_term": tf.reduce_mean(kl_term),
            "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
            "log_p_x": tf.reduce_mean(log_p_x),
            "elbo": tf.reduce_mean(elbo),
            "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
        }

        return base.LossAndAux(-tf.reduce_mean(elbo), metrics)

    return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
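# Minimal usage sketch (not from the source): in the task suite the cfg is
# produced by `sample_mlp_vae_family_cfg`; the keys below are exactly the ones
# read above, but the concrete values are illustrative assumptions.
hypothetical_vae_cfg = {
    "activation": "relu",             # resolved via utils.get_activation
    "w_init": ("glorot_normal", {}),  # (init_name, args); args format assumed
    "dataset": ("mnist", {}),         # (dname, extra_info); names assumed
    "enc_hidden_units": [128, 32],    # the last entry is the latent size
    "dec_hidden_units": [128],
}
vae_task = get_mlp_vae_family(hypothetical_vae_cfg)  # a base.DatasetModelTask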
Example #2
0
def get_mlp_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  # cfg["dataset"] contains (dname, extra_info)

  dataset = utils.get_image_dataset(cfg["dataset"])

  def _fn(batch):
    image = utils.maybe_center(cfg["center_data"], batch["image"])
    hidden_units = cfg["layer_sizes"] + [batch["label_onehot"].shape[1]]
    net = snt.BatchFlatten()(image)
    mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
    logits = mod(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_fn), dataset)
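# Illustrative only: a hand-written cfg with the keys consumed above. In the
# source these come from `sample_mlp_family_cfg`; the values here are assumed.
hypothetical_mlp_cfg = {
    "activation": "relu",
    "w_init": ("he_normal", {}),   # (init_name, args); args format assumed
    "dataset": ("mnist", {}),      # (dname, extra_info); names assumed
    "center_data": True,           # forwarded to utils.maybe_center
    "layer_sizes": [128, 128],     # the output layer is appended from the label size
}
mlp_task = get_mlp_family(hypothetical_mlp_cfg)  # a base.DatasetModelTask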
Example #3
0
  def _build(batch):
    """Builds the sonnet module."""
    # Shape is [batch size, sequence length]
    inp = batch["text"]

    # Clip the vocab to be at most vocab_size.
    inp = tf.minimum(inp,
                     tf.to_int64(tf.reshape(cfg["vocab_size"] - 1, [1, 1])))

    embed = snt.Embed(vocab_size=cfg["vocab_size"], embed_dim=cfg["embed_dim"])
    embedded_chars = embed(inp)

    rnn = utils.get_rnn_core(cfg["core"])
    batch_size = inp.shape.as_list()[0]

    state = rnn.initial_state(batch_size, trainable=cfg["trainable_init"])

    outputs, state = tf.nn.dynamic_rnn(rnn, embedded_chars, initial_state=state)

    w_init = utils.get_initializer(cfg["w_init"])
    pred_logits = snt.BatchApply(
        snt.Linear(cfg["vocab_size"], initializers={"w": w_init}))(
            outputs[:, :-1])
    actual_output_tokens = inp[:, 1:]

    flat_s = [pred_logits.shape[0] * pred_logits.shape[1], pred_logits.shape[2]]
    flat_pred_logits = tf.reshape(pred_logits, flat_s)
    flat_actual_tokens = tf.reshape(actual_output_tokens, [flat_s[0]])

    loss_vec = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=flat_actual_tokens, logits=flat_pred_logits)
    return tf.reduce_mean(loss_vec)
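# The enclosing getter for this language-model `_build` is not shown above.
# As a sketch, these are the cfg keys it reads, with assumed example values:
hypothetical_lm_cfg = {
    "vocab_size": 256,             # input tokens are clipped to vocab_size - 1
    "embed_dim": 64,
    "core": ...,                   # forwarded to utils.get_rnn_core; format not shown
    "trainable_init": True,        # whether the RNN initial state is trainable
    "w_init": ("orthogonal", {}),  # (init_name, args); args format assumed
}
# The loss is next-token prediction: logits over outputs[:, :-1] against inp[:, 1:].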
Example #4
0
def get_nvp_family(cfg):
    """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_nvp_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
    datasets = utils.get_image_dataset(cfg["dataset"])
    act_fn = utils.get_activation(cfg["activation"])
    w_init = utils.get_initializer(cfg["w_init"])

    def _build(batch):
        dist = distribution_with_nvp_bijectors(
            batch["image"],
            num_bijectors=cfg["num_bijectors"],
            layers=cfg["hidden_units"],
            activation=act_fn,
            w_init=w_init)
        return neg_log_p(dist, batch["image"])

    base_model_fn = lambda: snt.Module(_build)

    return base.DatasetModelTask(base_model_fn, datasets)
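# Sketch of the cfg shape consumed above; values are assumptions rather than
# outputs of `sample_nvp_family_cfg`.
hypothetical_nvp_cfg = {
    "dataset": ("mnist", {}),         # (dname, extra_info); names assumed
    "activation": "relu",
    "w_init": ("glorot_normal", {}),  # (init_name, args); args format assumed
    "num_bijectors": 4,               # number of real-NVP bijectors stacked
    "hidden_units": [64, 64],         # layer sizes used inside each bijector
}
nvp_task = get_nvp_family(hypothetical_nvp_cfg)  # a base.DatasetModelTask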
Example #5
0
def get_mlp_ae_family(cfg):
    """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_ae_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
    act_fn = utils.get_activation(cfg["activation"])
    w_init = utils.get_initializer(cfg["w_init"])
    init = {"w": w_init}

    datasets = utils.get_image_dataset(cfg["dataset"])

    def _build(batch):
        """Builds the sonnet module."""
        flat_img = snt.BatchFlatten()(batch["image"])

        if cfg["output_type"] in ["tanh", "linear_center"]:
            flat_img = flat_img * 2.0 - 1.0

        hidden_units = cfg["hidden_units"] + [flat_img.shape.as_list()[1]]
        mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
        outputs = mod(flat_img)

        if cfg["output_type"] == "sigmoid":
            outputs = tf.nn.sigmoid(outputs)
        elif cfg["output_type"] == "tanh":
            outputs = tf.tanh(outputs)
        elif cfg["output_type"] in ["linear", "linear_center"]:
            # nothing to be done to the outputs
            pass
        else:
            raise ValueError("Invalid output_type [%s]." % cfg["output_type"])

        reduce_fn = getattr(tf, cfg["reduction_type"])
        if cfg["loss_type"] == "l2":
            loss_vec = reduce_fn(tf.square(outputs - flat_img), axis=1)
        elif cfg["loss_type"] == "l1":
            loss_vec = reduce_fn(tf.abs(outputs - flat_img), axis=1)
        else:
            raise ValueError("Unsupported loss_type [%s]." %
                             cfg["reduction_type"])

        return tf.reduce_mean(loss_vec)

    return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
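# Illustrative cfg for the autoencoder family; values are assumptions. Note the
# coupling above between "output_type" and input scaling: "tanh" and
# "linear_center" rescale flattened images to [-1, 1] before encoding.
hypothetical_ae_cfg = {
    "activation": "relu",
    "w_init": ("glorot_normal", {}),  # (init_name, args); args format assumed
    "dataset": ("mnist", {}),         # (dname, extra_info); names assumed
    "hidden_units": [128, 32, 128],   # bottleneck in the middle; output layer appended
    "output_type": "sigmoid",         # one of: sigmoid, tanh, linear, linear_center
    "loss_type": "l2",                # or "l1"
    "reduction_type": "reduce_mean",  # any tf reduction name, e.g. "reduce_sum"
}
ae_task = get_mlp_ae_family(hypothetical_ae_cfg)  # a base.DatasetModelTask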
Example #6
0
def get_conv_pooling_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by
      `sample_conv_pooling_family_cfg`.

  Returns:
    A task for the given config.
  """

  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  hidden_units = cfg["hidden_units"]

  dataset = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Builds the sonnet module."""
    image = utils.maybe_center(cfg["center_data"], batch["image"])

    net = snt.nets.ConvNet2D(
        hidden_units,
        kernel_shapes=[(3, 3)],
        strides=cfg["strides"],
        paddings=cfg["padding"],
        activation=act_fn,
        use_bias=cfg["use_bias"],
        initializers=init,
        activate_final=True)(
            image)
    if cfg["pool_type"] == "mean":
      net = tf.reduce_mean(net, axis=[1, 2])
    elif cfg["pool_type"] == "max":
      net = tf.reduce_max(net, axis=[1, 2])
    elif cfg["pool_type"] == "squared_mean":
      net = tf.reduce_mean(net**2, axis=[1, 2])

    logits = snt.Linear(batch["label_onehot"].shape[1], initializers=init)(net)

    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)

    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_build), dataset)
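# Hypothetical cfg for the conv + global-pooling classifier; values are
# assumptions. "strides" and "padding" are passed straight to snt.nets.ConvNet2D
# and "pool_type" selects the spatial reduction before the linear readout.
hypothetical_conv_cfg = {
    "activation": "relu",
    "w_init": ("he_normal", {}),   # (init_name, args); args format assumed
    "dataset": ("cifar10", {}),    # (dname, extra_info); names assumed
    "center_data": True,
    "hidden_units": [32, 64],      # output channels per conv layer
    "strides": [1, 2],             # exact format assumed; forwarded unchanged
    "padding": ["SAME", "SAME"],   # exact format assumed; forwarded unchanged
    "use_bias": True,
    "pool_type": "mean",           # one of: mean, max, squared_mean
}
conv_task = get_conv_pooling_family(hypothetical_conv_cfg)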
Example #7
0
    def test_sample_get_initializer(self):
        rng = np.random.RandomState(123)
        sampled_init = []
        num = 3000
        for _ in range(num):
            init_name, args = utils.sample_initializer(rng)
            sampled_init.append(init_name)
            # smoke test to ensure graph builds
            out = utils.get_initializer((init_name, args))((10, 10))
            self.assertIsInstance(out, tf.Tensor)

        uniques, counts = np.unique(sampled_init, return_counts=True)
        counts_map = {str(u): c for u, c in zip(uniques, counts)}
        # 13 is the total sum of unnormalized probs
        amount_per_n = num / float(13)
        self.assertNear(counts_map["he_normal"], amount_per_n * 2, 40)
        self.assertNear(counts_map["orthogonal"], amount_per_n, 40)
        self.assertNear(counts_map["glorot_normal"], amount_per_n * 2, 40)
Example #8
0
def get_rnn_text_classification_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by
      `sample_rnn_text_classification_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """

  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}

  def _build(batch):
    """Build the sonnet module.

    Args:
      batch: A dictionary with keys "label", "label_onehot", and "text" mapping
        to tensors. The "text" consists of int tokens. These tokens are
        truncated to the length of the vocab before performing an embedding
        lookup.

    Returns:
      Loss of the batch.
    """
    vocab_size = cfg["vocab_size"]
    max_token = cfg["dataset"][1]["max_token"]
    if max_token:
      vocab_size = min(max_token, vocab_size)

    # Clip the max token to be vocab_size-1.
    tokens = tf.minimum(
        tf.to_int32(batch["text"]),
        tf.to_int32(tf.reshape(vocab_size - 1, [1, 1])))

    embed = snt.Embed(vocab_size=vocab_size, embed_dim=cfg["embed_dim"])
    embedded_tokens = embed(tokens)
    rnn = utils.get_rnn_core(cfg["core"])

    batch_size = tokens.shape.as_list()[0]

    state = rnn.initial_state(batch_size, trainable=cfg["trainable_init"])

    outputs, _ = tf.nn.dynamic_rnn(rnn, embedded_tokens, initial_state=state)
    if cfg["loss_compute"] == "last":
      rnn_output = outputs[:, -1]  # grab the last output
    elif cfg["loss_compute"] == "avg":
      rnn_output = tf.reduce_mean(outputs, 1)  # average over length
    elif cfg["loss_compute"] == "max":
      rnn_output = tf.reduce_max(outputs, 1)
    else:
      raise ValueError("Not supported loss_compute [%s]" % cfg["loss_compute"])

    logits = snt.Linear(
        batch["label_onehot"].shape[1], initializers=init)(
            rnn_output)

    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  datasets = utils.get_text_dataset(cfg["dataset"])
  return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
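# Illustrative cfg for the RNN text-classification family (values assumed).
# Note that cfg["dataset"][1]["max_token"], when set, further caps vocab_size.
hypothetical_rnn_cls_cfg = {
    "w_init": ("glorot_normal", {}),  # (init_name, args); args format assumed
    "dataset": ("imdb_reviews", {"max_token": 10000}),  # (dname, extra_info); names assumed
    "vocab_size": 10000,
    "embed_dim": 64,
    "core": ...,                 # forwarded to utils.get_rnn_core; format not shown
    "trainable_init": True,
    "loss_compute": "last",      # one of: last, avg, max
}
rnn_cls_task = get_rnn_text_classification_family(hypothetical_rnn_cls_cfg)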