Example #1
    def test_dataset_model_task(self):
        def one_dataset_fn(scale):
            dataset = tf.data.Dataset.from_tensor_slices(
                [scale * tf.ones([10, 2])])
            return dataset.repeat()

        all_datasets = datasets.Datasets(one_dataset_fn(1), one_dataset_fn(2),
                                         one_dataset_fn(3), one_dataset_fn(4))

        def fn(inp):
            out = snt.Linear(10, initializers={"w":
                                               tf.initializers.ones()})(inp)
            loss = tf.reduce_mean(out)
            return loss

        task = base.DatasetModelTask(lambda: snt.Module(fn), all_datasets)

        param_dict = task.initial_params()

        self.assertLen(param_dict, 2)

        with self.test_session():
            train_loss = task.call_split(param_dict, datasets.Split.TRAIN)
            self.assertNear(train_loss.eval(), 2.0, 1e-8)
            test_loss = task.call_split(param_dict, datasets.Split.TEST)
            self.assertNear(test_loss.eval(), 8.0, 1e-8)
            grads = task.gradients(train_loss, param_dict)
            np_grad = grads["BaseModel/fn/linear/w"].eval()
            self.assertNear(np_grad[0, 0], 0.1, 1e-5)
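The constants asserted above follow from the all-ones "w" initializer and the constant-valued batches; a minimal NumPy sketch of that arithmetic (assuming snt.Linear's default zero-initialised bias, which the test relies on):

import numpy as np

def expected_loss(scale):
    x = scale * np.ones([10, 2])   # the batch produced by one_dataset_fn(scale)
    w = np.ones([2, 10])           # "w" is initialised to ones; the bias stays zero
    out = x @ w                    # every entry of the [10, 10] output equals 2 * scale
    return out.mean()              # so the scalar loss is 2 * scale

print(expected_loss(1))  # 2.0 -> the TRAIN split, built from one_dataset_fn(1)
print(expected_loss(4))  # 8.0 -> the TEST split, built from one_dataset_fn(4)
# d(loss)/d w[k, j] = (1 / (10 * 10)) * sum_i x[i, k] = scale / 10, i.e. 0.1 on TRAIN,
# matching the check on grads["BaseModel/fn/linear/w"][0, 0].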
Example #2
def get_mlp_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  # cfg["dataset"] contains (dname, extra_info)

  dataset = utils.get_image_dataset(cfg["dataset"])

  def _fn(batch):
    image = utils.maybe_center(cfg["center_data"], batch["image"])
    hidden_units = cfg["layer_sizes"] + [batch["label_onehot"].shape[1]]
    net = snt.BatchFlatten()(image)
    mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
    logits = mod(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_fn), dataset)
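For orientation, a call to get_mlp_family might look like the sketch below. The literal values are placeholders only: the real encodings of "activation", "w_init", and "dataset" are whatever `sample_mlp_family_cfg` produces and utils.get_activation / utils.get_initializer / utils.get_image_dataset expect.

# Hypothetical cfg -- value formats are assumptions, not the library's spec.
cfg = {
    "activation": "relu",            # resolved via utils.get_activation
    "w_init": ("he_normal", {}),     # resolved via utils.get_initializer
    "dataset": ("mnist", {}),        # (dname, extra_info), as noted in the comment above
    "center_data": True,             # whether utils.maybe_center shifts the images
    "layer_sizes": [128, 64],        # hidden sizes; the output layer is appended per batch
}
task = get_mlp_family(cfg)           # -> base.DatasetModelTask
params = task.initial_params()       # parameters of the underlying snt.Module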
Example #3
def get_nvp_family(cfg):
    """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_nvp_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
    datasets = utils.get_image_dataset(cfg["dataset"])
    act_fn = utils.get_activation(cfg["activation"])
    w_init = utils.get_initializer(cfg["w_init"])

    def _build(batch):
        dist = distribution_with_nvp_bijectors(
            batch["image"],
            num_bijectors=cfg["num_bijectors"],
            layers=cfg["hidden_units"],
            activation=act_fn,
            w_init=w_init)
        return neg_log_p(dist, batch["image"])

    base_model_fn = lambda: snt.Module(_build)

    return base.DatasetModelTask(base_model_fn, datasets)
Example #4
def get_mlp_vae_family(cfg):
    """Gets a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_vae_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
    act_fn = utils.get_activation(cfg["activation"])
    w_init = utils.get_initializer(cfg["w_init"])
    init = {"w": w_init}

    datasets = utils.get_image_dataset(cfg["dataset"])

    def _build(batch):
        """Build the sonnet module."""
        flat_img = snt.BatchFlatten()(batch["image"])
        latent_size = cfg["enc_hidden_units"][-1]

        def encoder_fn(net):
            hidden_units = cfg["enc_hidden_units"][:-1] + [latent_size * 2]
            mod = snt.nets.MLP(hidden_units,
                               activation=act_fn,
                               initializers=init)
            outputs = mod(net)
            return generative_utils.LogStddevNormal(outputs)

        encoder = snt.Module(encoder_fn, name="encoder")

        def decoder_fn(net):
            hidden_units = cfg["dec_hidden_units"] + [
                flat_img.shape.as_list()[1] * 2
            ]
            mod = snt.nets.MLP(hidden_units,
                               activation=act_fn,
                               initializers=init)
            net = mod(net)
            net = tf.clip_by_value(net, -10, 10)
            return generative_utils.QuantizedNormal(mu_log_sigma=net)

        decoder = snt.Module(decoder_fn, name="decoder")
        zshape = tf.stack([tf.shape(flat_img)[0], 2 * latent_size])
        prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

        log_p_x, kl_term = generative_utils.log_prob_elbo_components(
            encoder, decoder, prior, flat_img)
        elbo = log_p_x - kl_term

        metrics = {
            "kl_term": tf.reduce_mean(kl_term),
            "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
            "log_p_x": tf.reduce_mean(log_p_x),
            "elbo": tf.reduce_mean(elbo),
            "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
        }

        return base.LossAndAux(-tf.reduce_mean(elbo), metrics)

    return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
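Since elbo = log_p_x - kl_term (the two components returned by generative_utils.log_prob_elbo_components), the returned loss is the usual negative evidence lower bound, loss = -E[elbo] = E[KL(q(z|x) || p(z))] - E[log p(x|z)], which the auxiliary metrics above simply break out term by term.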
Example #5
def _():
    base_model_fn = rnn_classification(
        lambda: snt.VanillaRNN(64, activation=tf.nn.relu),
        embed_dim=64,
        aggregate_method="avg")
    dataset = imdb_subword(128, 32)
    return base.DatasetModelTask(base_model_fn, dataset)
Example #6
def convnetloss1():
  base_model_fn = ce_pool_loss([32, 64, 128, 128],
                               tf.nn.relu,
                               use_batch_norm=True)
  dataset = datasets.get_image_datasets(
      "sun397_32x32", batch_size=128, shuffle_buffer=5000)
  return base.DatasetModelTask(base_model_fn, dataset)
Example #7
def _():
  base_model_fn = ce_pool_loss([32, 64, 128, 128],
                               tf.nn.relu,
                               use_batch_norm=True)
  dataset = datasets.get_image_datasets(
      "coil100_32x32", batch_size=128, shuffle_buffer=5000, num_per_valid=800)
  return base.DatasetModelTask(base_model_fn, dataset)
Example #8
def _():
  base_model_fn = ce_pool_loss([32, 32, 32, 64, 64],
                               tf.nn.relu,
                               use_batch_norm=True)
  dataset = datasets.get_image_datasets(
      "food101_32x32", batch_size=128, shuffle_buffer=5000)
  return base.DatasetModelTask(base_model_fn, dataset)
Example #9
def _():
    base_model_fn = three_layer_conv_vae_loss_fn([64, 128, 256],
                                                 [256, 128, 64], 128,
                                                 tf.nn.relu)

    dataset = datasets.get_image_datasets("cifar10", batch_size=128)
    return base.DatasetModelTask(base_model_fn, dataset)
Example #10
def _():
    base_model_fn = ce_flatten_loss([32, 64, 64], tf.nn.relu, [])
    dataset = datasets.get_image_datasets("colorectal_histology_32x32",
                                          batch_size=128,
                                          shuffle_buffer=5000,
                                          num_per_valid=700)
    return base.DatasetModelTask(base_model_fn, dataset)
Example #11
def _():  # pylint: disable=missing-docstring
  init = {}
  init["w"] = contrib_layers.variance_scaling_initializer()
  base_model_fn = ce_flatten_loss([32, 64, 128],
                                  tf.nn.tanh, [64, 32],
                                  initializers=init)
  dataset = datasets.get_image_datasets("cifar100", batch_size=64)
  return base.DatasetModelTask(base_model_fn, dataset)
Example #12
def _():
  init = {}
  init["w"] = tf.initializers.he_normal()
  base_model_fn = ce_flatten_loss([32, 64, 128],
                                  tf.nn.tanh, [64, 32],
                                  initializers=init)
  dataset = datasets.get_image_datasets("cifar10", batch_size=8)
  return base.DatasetModelTask(base_model_fn, dataset)
Example #13
def copy_fn(c):
    base_model_fn = sequence_to_sequence_rnn(lambda: _rnn_mod_map[c[0]](c[1]))
    return base.DatasetModelTask(
        base_model_fn,
        datasets.copy_sequence(c[2],
                               sequence_length=c[3],
                               num_separator=1,
                               num_tokens=c[4]))
Example #14
def _make_task(cfg):
  loss_type, dataset_name = cfg
  dataset = datasets.get_image_datasets(
      dataset_name, batch_size=128, shuffle_buffer=5000)
  num_classes = dataset.train.output_shapes["label_onehot"].as_list()[1]
  base_model_fn = fc_loss_fn([128, 128, 128, num_classes], loss_type,
                             tf.nn.relu)
  return base.DatasetModelTask(base_model_fn, dataset)
Example #15
def _():  # pylint: disable=missing-docstring
  init = {}
  init["w"] = tf.initializers.he_normal()
  base_model_fn = ce_pool_loss([32, 64, 128],
                               tf.nn.tanh,
                               initializers=init,
                               pool="max")
  dataset = datasets.get_image_datasets("cifar10", batch_size=64)
  return base.DatasetModelTask(base_model_fn, dataset)
Example #16
def get_mlp_ae_family(cfg):
    """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_ae_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
    act_fn = utils.get_activation(cfg["activation"])
    w_init = utils.get_initializer(cfg["w_init"])
    init = {"w": w_init}

    datasets = utils.get_image_dataset(cfg["dataset"])

    def _build(batch):
        """Builds the sonnet module."""
        flat_img = snt.BatchFlatten()(batch["image"])

        if cfg["output_type"] in ["tanh", "linear_center"]:
            flat_img = flat_img * 2.0 - 1.0

        hidden_units = cfg["hidden_units"] + [flat_img.shape.as_list()[1]]
        mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
        outputs = mod(flat_img)

        if cfg["output_type"] == "sigmoid":
            outputs = tf.nn.sigmoid(outputs)
        elif cfg["output_type"] == "tanh":
            outputs = tf.tanh(outputs)
        elif cfg["output_type"] in ["linear", "linear_center"]:
            # nothing to be done to the outputs
            pass
        else:
            raise ValueError("Invalid output_type [%s]." % cfg["output_type"])

        reduce_fn = getattr(tf, cfg["reduction_type"])
        if cfg["loss_type"] == "l2":
            loss_vec = reduce_fn(tf.square(outputs - flat_img), axis=1)
        elif cfg["loss_type"] == "l1":
            loss_vec = reduce_fn(tf.abs(outputs - flat_img), axis=1)
        else:
            raise ValueError("Unsupported loss_type [%s]." %
                             cfg["reduction_type"])

        return tf.reduce_mean(loss_vec)

    return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
Example #17
def get_conv_pooling_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by
      `sample_conv_pooling_family_cfg`.

  Returns:
    A task for the given config.
  """

  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  hidden_units = cfg["hidden_units"]

  dataset = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Builds the sonnet module."""
    image = utils.maybe_center(cfg["center_data"], batch["image"])

    net = snt.nets.ConvNet2D(
        hidden_units,
        kernel_shapes=[(3, 3)],
        strides=cfg["strides"],
        paddings=cfg["padding"],
        activation=act_fn,
        use_bias=cfg["use_bias"],
        initializers=init,
        activate_final=True)(
            image)
    if cfg["pool_type"] == "mean":
      net = tf.reduce_mean(net, axis=[1, 2])
    elif cfg["pool_type"] == "max":
      net = tf.reduce_max(net, axis=[1, 2])
    elif cfg["pool_type"] == "squared_mean":
      net = tf.reduce_mean(net**2, axis=[1, 2])

    logits = snt.Linear(batch["label_onehot"].shape[1], initializers=init)(net)

    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)

    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_build), dataset)
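The pool_type branch performs global spatial pooling, so the classifier input size does not depend on the image resolution; a shape sketch (the spatial extent is whatever the strides and padding in cfg produce):

# net after ConvNet2D:              [batch, H, W, hidden_units[-1]]
# tf.reduce_mean(net, axis=[1, 2])  -> [batch, hidden_units[-1]]
# the final snt.Linear then maps    [batch, hidden_units[-1]] -> [batch, num_classes]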
Example #18
def _build_lm_task(cfg, dataset):
    """Builds a language modeling task."""
    def _build(batch):
        """Builds the sonnet module."""
        # Shape is [batch size, sequence length]
        inp = batch["text"]

        # Clip the vocab to be at most vocab_size.
        inp = tf.minimum(
            inp, tf.to_int64(tf.reshape(cfg["vocab_size"] - 1, [1, 1])))

        embed = snt.Embed(vocab_size=cfg["vocab_size"],
                          embed_dim=cfg["embed_dim"])
        embedded_chars = embed(inp)

        rnn = utils.get_rnn_core(cfg["core"])
        batch_size = inp.shape.as_list()[0]

        state = rnn.initial_state(batch_size, trainable=cfg["trainable_init"])

        outputs, state = tf.nn.dynamic_rnn(rnn,
                                           embedded_chars,
                                           initial_state=state)

        w_init = utils.get_initializer(cfg["w_init"])
        pred_logits = snt.BatchApply(
            snt.Linear(cfg["vocab_size"],
                       initializers={"w": w_init}))(outputs[:, :-1])
        actual_output_tokens = inp[:, 1:]

        flat_s = [
            pred_logits.shape[0] * pred_logits.shape[1], pred_logits.shape[2]
        ]
        flat_pred_logits = tf.reshape(pred_logits, flat_s)
        flat_actual_tokens = tf.reshape(actual_output_tokens, [flat_s[0]])

        loss_vec = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=flat_actual_tokens, logits=flat_pred_logits)
        return tf.reduce_mean(loss_vec)

    return base.DatasetModelTask(lambda: snt.Module(_build), dataset)
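The [:, :-1] / [:, 1:] slicing above is the usual next-token (teacher-forcing) alignment; a tiny illustrative sketch for one stand-in row of batch["text"]:

# tokens for one row:          t0   t1   t2   t3
# RNN outputs kept:            o0   o1   o2        (outputs[:, :-1])
# logits at step i score:      t1   t2   t3        (pred_logits)
# targets:                     t1   t2   t3        (inp[:, 1:])
# Both sides are then flattened to [batch * (T - 1), ...] before the sparse softmax loss.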
Example #19
def _():
    base_model_fn = fc_vae_loss_fn([128, 64], [64, 128], 32, tf.nn.relu)
    dataset = datasets.get_image_datasets("mnist", batch_size=64)
    return base.DatasetModelTask(base_model_fn, dataset)
Example #20
def _():
  base_model_fn = _fc_batch_norm_loss_fn([64, 64, 64, 64, 64, 10], tf.nn.relu)
  return base.DatasetModelTask(
      base_model_fn, datasets.get_image_datasets("cifar10", batch_size=128))
Example #21
def _():
  base_model_fn = _fc_layer_norm_loss_fn([128, 128, 128, 10], tf.tanh)
  return base.DatasetModelTask(
      base_model_fn, datasets.get_image_datasets("cifar10", batch_size=128))
Example #22
def _():
  base_model_fn = _fc_dropout_loss_fn([128, 128, 10],
                                      tf.nn.relu,
                                      keep_probs=0.2)
  return base.DatasetModelTask(
      base_model_fn, datasets.get_image_datasets("cifar10", batch_size=128))
Example #23
def _():
    base_model_fn = get_loss_fn(9, layers=(128, 128))
    dataset = datasets.get_image_datasets("mnist", batch_size=64)
    return base.DatasetModelTask(base_model_fn, dataset)
Example #24
def _():
    base_model_fn = get_loss_fn(3, (1024, 1024))
    dataset = datasets.get_image_datasets("cifar10", batch_size=64)
    return base.DatasetModelTask(base_model_fn, dataset)
Example #25
def _():
    base_model_fn = get_loss_fn(2, (2048, 2048))
    dataset = datasets.get_image_datasets("mnist", batch_size=64)
    return base.DatasetModelTask(base_model_fn, dataset)
Example #26
def _():
    base_model_fn = fc_ae_loss_fn([32, 32, 32], tf.nn.relu)
    dataset = datasets.get_image_datasets("mnist", batch_size=128)
    return base.DatasetModelTask(base_model_fn, dataset)
Example #27
def _():
    base_model_fn = fc_vae_loss_fn([128, 64], [64, 128], 32, tf.nn.relu)
    dataset = datasets.get_image_datasets("food101_32x32",
                                          batch_size=256,
                                          shuffle_buffer=5000)
    return base.DatasetModelTask(base_model_fn, dataset)
Example #28
def associative_fn(c):
    base_model_fn = sequence_to_sequence_rnn(lambda: _rnn_mod_map[c[0]](c[1]))
    return base.DatasetModelTask(
        base_model_fn,
        datasets.associative_sequence(c[2], num_pairs=c[3], num_tokens=c[4]))
Example #29
def _():
    base_model_fn = fc_vae_loss_fn([128], [128], 32, tf.nn.relu)
    dataset = datasets.get_image_datasets("cifar10", batch_size=128)
    return base.DatasetModelTask(base_model_fn, dataset)
Example #30
def _():
    base_model_fn = teacher_force_language_modeling(lambda: snt.GRU(256),
                                                    embed_dim=64)
    dataset = lm1b_byte(128, 128)
    return base.DatasetModelTask(base_model_fn, dataset)