Example #1
    def _fn(batch):
        """Build the loss."""
        shapes = [(8, 8), (16, 16), (32, 32)]

        def encoder_fn(net):
            """Encoder for VAE."""
            net = snt.nets.ConvNet2D(enc_units,
                                     kernel_shapes=[(3, 3)],
                                     strides=[2, 2, 2],
                                     paddings=[snt.SAME],
                                     activation=activation_fn,
                                     activate_final=True)(net)

            flat_dims = int(np.prod(net.shape.as_list()[1:]))
            net = tf.reshape(net, [-1, flat_dims])
            net = snt.Linear(2 * n_z)(net)
            return generative_utils.LogStddevNormal(net)

        encoder = snt.Module(encoder_fn, name="encoder")

        def decoder_fn(net):
            """Decoder for VAE."""
            net = snt.Linear(4 * 4 * 32)(net)
            net = tf.reshape(net, [-1, 4, 4, 32])
            net = snt.nets.ConvNet2DTranspose(dec_units,
                                              shapes,
                                              kernel_shapes=[(3, 3)],
                                              strides=[2, 2, 2],
                                              paddings=[snt.SAME],
                                              activation=activation_fn,
                                              activate_final=True)(net)

            outchannel = batch["image"].shape.as_list()[3]
            net = snt.Conv2D(2 * outchannel, kernel_shape=(1, 1))(net)
            net = tf.clip_by_value(net, -10, 10)

            return generative_utils.QuantizedNormal(mu_log_sigma=net)

        decoder = snt.Module(decoder_fn, name="decoder")

        zshape = tf.stack([tf.shape(batch["image"])[0], 2 * n_z])
        prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

        input_image = (batch["image"] - 0.5) * 2

        log_p_x, kl_term = generative_utils.log_prob_elbo_components(
            encoder, decoder, prior, input_image)

        elbo = log_p_x - kl_term

        metrics = {
            "kl_term": tf.reduce_mean(kl_term),
            "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
            "log_p_x": tf.reduce_mean(log_p_x),
            "elbo": tf.reduce_mean(elbo),
            "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
        }

        return base.LossAndAux(-tf.reduce_mean(elbo), metrics)
Example #2
def get_mlp_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  # cfg["dataset"] contains (dname, extra_info)

  dataset = utils.get_image_dataset(cfg["dataset"])

  def _fn(batch):
    image = utils.maybe_center(cfg["center_data"], batch["image"])
    hidden_units = cfg["layer_sizes"] + [batch["label_onehot"].shape[1]]
    net = snt.BatchFlatten()(image)
    mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
    logits = mod(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_fn), dataset)
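
For orientation, here is a minimal sketch of consuming the returned task, reusing the DatasetModelTask interface exercised in the test example further below. The cfg literal is purely illustrative; the exact value formats expected by utils.get_activation, utils.get_initializer and utils.get_image_dataset are assumptions.

# Illustrative only: real configs come from sample_mlp_family_cfg, and the
# value formats below are assumptions.
cfg = {
    "activation": "relu",
    "w_init": "glorot_normal",
    "dataset": ("mnist", {}),  # (dname, extra_info)
    "center_data": True,
    "layer_sizes": [128, 128],
}
task = get_mlp_family(cfg)
params = task.initial_params()
train_loss = task.call_split(params, datasets.Split.TRAIN)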
Example #3
    def testInitialStateInModule(self):
        # Check that scopes play nicely with initial states created inside modules.
        batch_size = 6

        def module_build():
            core = snt.DeepRNN([snt.LSTM(4), snt.LSTM(5)])
            initial_state1 = core.initial_state(batch_size,
                                                dtype=tf.float32,
                                                trainable=True)
            initial_state2 = core.initial_state(batch_size + 1,
                                                dtype=tf.float32,
                                                trainable=True)

            return initial_state1, initial_state2

        initial_state_module = snt.Module(module_build)
        initial_state = initial_state_module()

        self.evaluate(tf.global_variables_initializer())
        initial_state_value = self.evaluate(initial_state)
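        # initial_state_value[i][j][k]: i picks the state built with batch_size
        # vs batch_size + 1, j picks the DeepRNN layer (LSTM(4) vs LSTM(5)),
        # and k picks the LSTM hidden/cell component.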
        self.assertEqual(initial_state_value[0][0][0].shape, (batch_size, 4))
        self.assertEqual(initial_state_value[1][0][0].shape,
                         (batch_size + 1, 4))
        self.assertEqual(initial_state_value[0][0][1].shape, (batch_size, 4))
        self.assertEqual(initial_state_value[1][0][1].shape,
                         (batch_size + 1, 4))
        self.assertEqual(initial_state_value[0][1][0].shape, (batch_size, 5))
        self.assertEqual(initial_state_value[1][1][0].shape,
                         (batch_size + 1, 5))
        self.assertEqual(initial_state_value[0][1][1].shape, (batch_size, 5))
        self.assertEqual(initial_state_value[1][1][1].shape,
                         (batch_size + 1, 5))
Example #4
def main(unused_argv):
    inputs = tf.random_uniform(shape=[10, 32, 32, 3])
    targets = tf.random_uniform(shape=[10, 10])

    # The line below takes custom_build and wraps it to construct a sonnet Module.
    module_with_build_args = snt.Module(custom_build, name='simple_net')

    train_model_outputs = module_with_build_args(inputs,
                                                 is_training=True,
                                                 keep_prob=tf.constant(0.5))
    test_model_outputs = module_with_build_args(inputs,
                                                is_training=False,
                                                keep_prob=tf.constant(1.0))
    loss = tf.nn.l2_loss(targets - train_model_outputs)
    # Ensure the moving averages for the BatchNorm modules are updated.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_step = tf.train.GradientDescentOptimizer(
            learning_rate=1e-3).minimize(loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in xrange(100):
            sess.run(train_step)
        # Check that evaluating train_model_outputs twice returns the same value.
        train_outputs, train_outputs_2 = sess.run(
            [train_model_outputs, train_model_outputs])
        assert (train_outputs == train_outputs_2).all()
        # Check that there is indeed a difference between train_model_outputs and
        # test_model_outputs.
        train_outputs, test_outputs = sess.run(
            [train_model_outputs, test_model_outputs])
        assert (train_outputs != test_outputs).any()
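
The custom_build function wrapped above is defined elsewhere in the original example. Below is a plausible sketch only, assuming a small convnet with batch norm and dropout driven by is_training and keep_prob; the layer sizes are assumptions chosen to match the [10, 10] targets.

def custom_build(inputs, is_training, keep_prob):
    """Hypothetical build function for the snt.Module wrapper above."""
    outputs = snt.Conv2D(output_channels=32, kernel_shape=4, stride=2)(inputs)
    outputs = snt.BatchNorm()(outputs, is_training=is_training)
    outputs = tf.nn.relu(outputs)
    outputs = snt.BatchFlatten()(outputs)
    outputs = tf.nn.dropout(outputs, keep_prob=keep_prob)
    outputs = snt.Linear(output_size=10)(outputs)
    return outputs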
Example #5
def get_nvp_family(cfg):
    """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_nvp_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
    datasets = utils.get_image_dataset(cfg["dataset"])
    act_fn = utils.get_activation(cfg["activation"])
    w_init = utils.get_initializer(cfg["w_init"])

    def _build(batch):
        dist = distribution_with_nvp_bijectors(
            batch["image"],
            num_bijectors=cfg["num_bijectors"],
            layers=cfg["hidden_units"],
            activation=act_fn,
            w_init=w_init)
        return neg_log_p(dist, batch["image"])

    base_model_fn = lambda: snt.Module(_build)

    return base.DatasetModelTask(base_model_fn, datasets)
Example #6
def get_mlp_vae_family(cfg):
    """Gets a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_vae_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
    act_fn = utils.get_activation(cfg["activation"])
    w_init = utils.get_initializer(cfg["w_init"])
    init = {"w": w_init}

    datasets = utils.get_image_dataset(cfg["dataset"])

    def _build(batch):
        """Build the sonnet module."""
        flat_img = snt.BatchFlatten()(batch["image"])
        latent_size = cfg["enc_hidden_units"][-1]

        def encoder_fn(net):
            hidden_units = cfg["enc_hidden_units"][:-1] + [latent_size * 2]
            mod = snt.nets.MLP(hidden_units,
                               activation=act_fn,
                               initializers=init)
            outputs = mod(net)
            return generative_utils.LogStddevNormal(outputs)

        encoder = snt.Module(encoder_fn, name="encoder")

        def decoder_fn(net):
            hidden_units = cfg["dec_hidden_units"] + [
                flat_img.shape.as_list()[1] * 2
            ]
            mod = snt.nets.MLP(hidden_units,
                               activation=act_fn,
                               initializers=init)
            net = mod(net)
            net = tf.clip_by_value(net, -10, 10)
            return generative_utils.QuantizedNormal(mu_log_sigma=net)

        decoder = snt.Module(decoder_fn, name="decoder")
        zshape = tf.stack([tf.shape(flat_img)[0], 2 * latent_size])
        prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

        log_p_x, kl_term = generative_utils.log_prob_elbo_components(
            encoder, decoder, prior, flat_img)
        elbo = log_p_x - kl_term

        metrics = {
            "kl_term": tf.reduce_mean(kl_term),
            "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
            "log_p_x": tf.reduce_mean(log_p_x),
            "elbo": tf.reduce_mean(elbo),
            "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
        }

        return base.LossAndAux(-tf.reduce_mean(elbo), metrics)

    return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
Example #7
    def test_dataset_model_task(self):
        def one_dataset_fn(scale):
            dataset = tf.data.Dataset.from_tensor_slices(
                [scale * tf.ones([10, 2])])
            return dataset.repeat()

        all_datasets = datasets.Datasets(one_dataset_fn(1), one_dataset_fn(2),
                                         one_dataset_fn(3), one_dataset_fn(4))

        def fn(inp):
            out = snt.Linear(10, initializers={"w":
                                               tf.initializers.ones()})(inp)
            loss = tf.reduce_mean(out)
            return loss

        task = base.DatasetModelTask(lambda: snt.Module(fn), all_datasets)

        param_dict = task.initial_params()

        self.assertLen(param_dict, 2)

        with self.test_session():
            train_loss = task.call_split(param_dict, datasets.Split.TRAIN)
            self.assertNear(train_loss.eval(), 2.0, 1e-8)
            test_loss = task.call_split(param_dict, datasets.Split.TEST)
            self.assertNear(test_loss.eval(), 8.0, 1e-8)
            grads = task.gradients(train_loss, param_dict)
            np_grad = grads["BaseModel/fn/linear/w"].eval()
            self.assertNear(np_grad[0, 0], 0.1, 1e-5)
Example #8
    def test_stack_with_snt_activation(self, activation_fn):
        conv = snt.Conv2D(output_channels=5, kernel_shape=3, padding='VALID')
        linear = snt.Linear(23)
        module = snt.Sequential([
            conv,
            snt.Module(activation_fn),
            snt.BatchFlatten(),
            linear,
        ])

        network = ibp.VerifiableModelWrapper(module)
        network(self._inputs)

        v_layers = auto_verifier.VerifiableLayerBuilder(network).build_layers()

        self.assertLen(v_layers, 3)

        self.assertIsInstance(v_layers[0], layers.Conv)
        self.assertIs(conv, v_layers[0].module)
        self.assertIsInstance(v_layers[0].input_node, ibp.ModelInputWrapper)

        self.assertIsInstance(v_layers[1], layers.Activation)
        self.assertEqual(activation_fn.__name__, v_layers[1].activation)
        self.assertIs(v_layers[0].output_node, v_layers[1].input_node)

        self.assertIsInstance(v_layers[2], layers.Linear)
        self.assertIs(linear, v_layers[2].module)

        self.assertIs(v_layers[2].output_node, network.output_module)
Example #9
    def __init__(self,
                 edge_model_fn=None,
                 left_node_model_fn=None,
                 right_node_model_fn=None,
                 global_model_fn=None,
                 name="Bipartite_graph_independent"):
        """Initializes the BipartiteGraphIndependent module.

    Args:
      edge_model_fn: A callable that returns an edge model function. The
        callable must return a Sonnet module (or equivalent). If passed `None`,
        will pass through inputs (the default).
      left_node_model_fn: A callable that returns a node model function for the
        left nodes. The callable must return a Sonnet module (or equivalent).
        If passed `None`, will pass through inputs (the default).
      right_node_model_fn: A callable that returns a node model function for
        the right nodes. The callable must return a Sonnet module (or
        equivalent). If passed `None`, will pass through inputs (the default).
      global_model_fn: A callable that returns a global model function. The
        callable must return a Sonnet module (or equivalent). If passed `None`,
        will pass through inputs (the default).
      name: The module name.
    """
        super(BipartiteGraphIndependent, self).__init__(name=name)

        with self._enter_variable_scope():
            # snt.Module is used below so that the ops and variables created by
            # the edge/node/global_model_fns are scoped analogously to how the
            # Edge/Node/GlobalBlock classes scope them.
            if edge_model_fn is None:
                self._edge_model = lambda x: x
            else:
                self._edge_model = snt.Module(lambda x: edge_model_fn()(x),
                                              name="edge_model")  # pylint: disable=unnecessary-lambda
            if left_node_model_fn is None:
                self._left_node_model = lambda x: x
            else:
                self._left_node_model = snt.Module(
                    lambda x: left_node_model_fn()(x), name="left_node_model")  # pylint: disable=unnecessary-lambda
            if right_node_model_fn is None:
                self._right_node_model = lambda x: x
            else:
                self._right_node_model = snt.Module(
                    lambda x: right_node_model_fn()(x),
                    name="right_node_model")  # pylint: disable=unnecessary-lambda
            if global_model_fn is None:
                self._global_model = lambda x: x
            else:
                self._global_model = snt.Module(lambda x: global_model_fn()(x),
                                                name="global_model")  # pylint: disable=unnecessary-lambda
Example #10
def get_loss_fn(num_bijectors, layers):
    def _fn(batch):
        dist = maf.dist_with_maf_bijectors(batch["image"],
                                           num_bijectors=num_bijectors,
                                           layers=layers)
        return maf.neg_log_p(dist, batch["image"])

    return lambda: snt.Module(_fn)
Example #11
  def _build(batch):
    """Build the sonnet module."""
    net = snt.BatchFlatten()(batch["image"])
    # shift to be zero mean
    net = (net - 0.5) * 2

    n_inp = net.shape.as_list()[1]

    def encoder_fn(x):
      mlp_encoding = snt.nets.MLP(
          name="mlp_encoder",
          output_sizes=enc_units + [2 * n_z],
          activation=activation)
      return generative_utils.LogStddevNormal(mlp_encoding(x))

    encoder = snt.Module(encoder_fn, name="encoder")

    def decoder_fn(x):
      mlp_decoding = snt.nets.MLP(
          name="mlp_decoder",
          output_sizes=dec_units + [2 * n_inp],
          activation=activation)
      net = mlp_decoding(x)
      net = tf.clip_by_value(net, -10, 10)
      return generative_utils.QuantizedNormal(mu_log_sigma=net)

    decoder = snt.Module(decoder_fn, name="decoder")

    zshape = tf.stack([tf.shape(net)[0], 2 * n_z])

    prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

    log_p_x, kl_term = generative_utils.log_prob_elbo_components(
        encoder, decoder, prior, net)

    elbo = log_p_x - kl_term

    metrics = {
        "kl_term": tf.reduce_mean(kl_term),
        "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
        "log_p_x": tf.reduce_mean(log_p_x),
        "elbo": tf.reduce_mean(elbo),
        "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
    }

    return base.LossAndAux(-tf.reduce_mean(elbo), metrics)
Example #12
 def __init__(self, params, shrink_factor=1, name='ResBlock'):
     super(ResBlock, self).__init__(name=name)
     self.shrink_factor = shrink_factor
     with self._enter_variable_scope():
         self.layers, self.layer_names = layer_factory(params)
         self._cores = Struct(out=snt.Module(self._custom_build),
                              preprocess=snt.BatchNorm(decay_rate=0.99,
                                                       scale=True,
                                                       fused=True))
Example #13
def conv_ae_loss_fn(enc_units, dec_units, num_latents, activation_fn):
    """Convolutional autoencoder loss module helper.

  This creates a callable that returns a sonnet module for the loss.

  Args:
    enc_units: list of integers giving the number of units in each encoder
      convnet layer
    dec_units: list of integers giving the number of units in each decoder
      convnet layer
    num_latents: int size of the middle (latent) layer of the autoencoder
    activation_fn: callable activation function used in the convnet

  Returns:
    callable that returns a sonnet module representing the loss
  """
    def _fn(batch):
        """Make the loss from the given batch."""
        net = batch["image"]
        net = snt.nets.ConvNet2D(enc_units,
                                 kernel_shapes=[(3, 3)],
                                 strides=[2, 2],
                                 paddings=[snt.SAME],
                                 activation=activation_fn,
                                 activate_final=True)(batch["image"])

        flat_dims = int(np.prod(net.shape.as_list()[1:]))
        net = tf.reshape(net, [-1, flat_dims])
        net = snt.Linear(num_latents)(net)

        if batch["image"].shape.as_list()[1] == 28:
            net = snt.Linear(7 * 7 * 32)(net)
            net = tf.reshape(net, [-1, 7, 7, 32])
            shapes = [(14, 14), (28, 28)]
        elif batch["image"].shape.as_list()[1] == 32:
            net = snt.Linear(8 * 8 * 32)(net)
            net = tf.reshape(net, [-1, 8, 8, 32])
            shapes = [(16, 16), (32, 32)]
        else:
            raise ValueError("Only 28x28, or 32x32 supported")

        net = snt.nets.ConvNet2DTranspose(dec_units,
                                          shapes,
                                          kernel_shapes=[(3, 3)],
                                          strides=[2, 2],
                                          paddings=[snt.SAME],
                                          activation=activation_fn,
                                          activate_final=True)(net)

        outchannel = batch["image"].shape.as_list()[3]
        net = snt.Conv2D(outchannel, kernel_shape=(1, 1))(net)

        loss_vec = tf.reduce_mean(
            tf.square(batch["image"] - tf.nn.sigmoid(net)), [1, 2, 3])
        return tf.reduce_mean(loss_vec)

    return lambda: snt.Module(_fn)
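
A hypothetical usage of the helper; the unit counts are illustrative, and the batch must carry 28x28 or 32x32 images as required by the decoder branch above.

# Hypothetical usage; values are illustrative.
loss_module_fn = conv_ae_loss_fn(enc_units=[32, 64],
                                 dec_units=[64, 32],
                                 num_latents=16,
                                 activation_fn=tf.nn.relu)
loss_module = loss_module_fn()  # builds an snt.Module around _fn
# loss = loss_module(batch)     # batch["image"]: [bs, 28, 28, c] or [bs, 32, 32, c]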
Example #14
 def test_count_parameters_on_module(self):
     module = snt.Module()
     # Weights of a 2D convolution with 2 filters.
     module.conv = snt.Conv2D(output_channels=2,
                              kernel_shape=3,
                              name="conv")
     module.conv(tf.ones(
         (2, 5, 5, 3)))  # 3 * 3*3 * 2 + 2 (bias) = 56 parameters
     self.assertEqual(56, parameter_overview.count_parameters(module))
Example #15
    def test_count_parameters_empty(self):
        module = snt.Module()
        snt.allow_empty_variables(module)

        # No variables.
        self.assertEqual(0, parameter_overview.count_parameters(module))

        # Single variable.
        module.var = tf.Variable([0, 1])
        self.assertEqual(2, parameter_overview.count_parameters(module))
Example #17
def save_policy(policy_network: snt.Sequential,
                input_shape: Tuple[int, int],
                save_path: str) -> None:
    @tf.function(input_signature=[tf.TensorSpec(input_shape)])
    def inference(x):
        return policy_network(x)

    to_save = snt.Module()
    to_save.inference = inference
    to_save.all_variables = list(policy_network.variables)
    tf.saved_model.save(to_save, save_path)
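
A matching load-side sketch (not part of the original snippet): tf.saved_model.load restores the saved module, and the exported inference function can then be called on a batch with the declared input shape.

# Sketch: restoring and querying the policy saved above.
loaded = tf.saved_model.load(save_path)
observations = tf.zeros(input_shape)  # dummy batch matching the signature
actions = loaded.inference(observations)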
Example #18
    def test_get_parameter_overview_empty(self):
        module = snt.Module()
        snt.allow_empty_variables(module)

        # No variables.
        self.assertEqual(EMPTY_PARAMETER_OVERVIEW,
                         parameter_overview.get_parameter_overview(module))

        module.conv = snt.Conv2D(output_channels=2, kernel_shape=3)
        # Variables not yet created (happens in the first forward pass).
        self.assertEqual(EMPTY_PARAMETER_OVERVIEW,
                         parameter_overview.get_parameter_overview(module))
Example #19
 def test_get_parameter_overview_on_module(self):
   module = snt.Module()
   # Weights of a 2D convolution with 2 filters.
   module.conv = snt.Conv2D(output_channels=2, kernel_shape=3, name="conv")
   module.conv(tf.ones((2, 5, 5, 3)))  # 3 * 3*3 * 2 + 2 (bias) = 56 parameters
   for v in module.variables:
     v.assign(tf.ones_like(v))
   self.assertEqual(
       SNT_CONV2D_PARAMETER_OVERVIEW,
       parameter_overview.get_parameter_overview(module, include_stats=False))
   self.assertEqual(SNT_CONV2D_PARAMETER_OVERVIEW_WITH_STATS,
                    parameter_overview.get_parameter_overview(module))
Example #20
    def _edge_model(self):
        common_embedding_module = snt.Module(
            partial(common_embedding,
                    num_types=self._num_edge_types,
                    type_embedding_dim=self._type_embedding_dim))

        return snt.Sequential([
            common_embedding_module,
            snt.nets.MLP([self._latent_size] * self._num_layers,
                         activate_final=True),
            snt.LayerNorm()
        ])
Example #21
 def _node_model(self):
     node_embedding_module = snt.Module(
         partial(node_embedding,
                 num_types=self._num_node_types,
                 type_embedding_dim=self._type_embedding_dim,
                 attr_encoders=self._attr_embedders,
                 attr_embedding_dim=self._attr_embedding_dim))
     return snt.Sequential([
         node_embedding_module,
         snt.nets.MLP([self._latent_size] * self._num_layers,
                      activate_final=True),
         snt.LayerNorm()
     ])
Example #22
def ce_pool_loss(
    hidden_units,
    activation_fn,
    initializers=None,
    pool="avg",
    use_batch_norm=False,
):
  """Helper function to make a sonnet loss.

  This creates a cross entropy loss from a convnet whose final conv layer is
  pooled spatially.

  Args:
    hidden_units: list of hidden unit sizes
    activation_fn: callable activation function used in the convnet
    initializers: optional dict of initializers used to initialize the
      convnet weights.
    pool: str the type of pooling. Supported values are "max" or "avg".
    use_batch_norm: bool whether or not to use batch norm in the convnet

  Returns:
    callable that returns a sonnet module representing the loss.
  """
  if not initializers:
    initializers = {}

  def _fn(batch):
    """Make the loss."""
    net = snt.nets.ConvNet2D(
        hidden_units,
        kernel_shapes=[(3, 3)],
        strides=[2] + [1] * (len(hidden_units) - 1),
        paddings=[snt.SAME],
        activation=activation_fn,
        initializers=initializers,
        use_batch_norm=use_batch_norm,
        activate_final=True)(
            batch["image"], is_training=True)
    # average pool
    if pool == "avg":
      net = tf.reduce_mean(net, [1, 2])
    elif pool == "max":
      net = tf.reduce_max(net, [1, 2])
    else:
      raise ValueError("pool type not supported")

    num_classes = batch["label_onehot"].shape.as_list()[1]
    logits = snt.Linear(num_classes)(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return lambda: snt.Module(_fn)
Example #23
def classification_probe(features, labels, n_classes, labeled=None):
    """Classification probe with stopped gradient on features."""
    def _classification_probe(features):
        logits = snt.Linear(n_classes)(tf.stop_gradient(features))
        xe = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                            labels=labels)
        if labeled is not None:
            xe = xe * tf.cast(labeled, tf.float32)
        xe = tf.reduce_mean(xe)
        acc = tf.reduce_mean(
            tf.cast(tf.equal(tf.argmax(logits, axis=1), labels), tf.float32))
        return xe, acc

    return snt.Module(_classification_probe)(features)
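
A hypothetical call, probing 64-dimensional features against 10 classes; the shapes and tensors are illustrative only.

# Hypothetical usage of classification_probe.
features = tf.random_normal([32, 64])
labels = tf.random_uniform([32], maxval=10, dtype=tf.int64)
xe_loss, accuracy = classification_probe(features, labels, n_classes=10)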
Example #24
def rnn_classification(core_fn,
                       vocab_size=10000,
                       embed_dim=64,
                       aggregate_method="last"):
    """Helper for RNN based text classification tasks.

  Args:
    core_fn: callable callable that returns a sonnet RNN core
    vocab_size: int number of words to use for the embedding table. All index
      higher than this will be clipped
    embed_dim: int size of the embedding dim
    aggregate_method: str how to aggregate the sequence of features. If 'last'
      grab the last hidden features. If 'avg' compute the average over the full
      sequence.

  Returns:
    a callable that returns a sonnet module representing the loss.
  """
    def _build(batch):
        """Build the loss sonnet module."""
        # TODO(lmetz) these are dense updates.... so keeping this small for now.
        tokens = tf.minimum(batch["text"],
                            tf.to_int64(tf.reshape(vocab_size - 1, [1, 1])))
        embed = snt.Embed(vocab_size=vocab_size, embed_dim=embed_dim)
        embedded_tokens = embed(tokens)
        rnn = core_fn()
        bs = tokens.shape.as_list()[0]

        state = rnn.initial_state(bs, trainable=True)

        outputs, state = tf.nn.dynamic_rnn(rnn,
                                           embedded_tokens,
                                           initial_state=state)
        if aggregate_method == "last":
            last_output = outputs[:, -1]  # grab the last output
        elif aggregate_method == "avg":
            last_output = tf.reduce_mean(outputs, [1])  # average over length
        else:
            raise ValueError("not supported aggregate_method")

        logits = snt.Linear(2)(last_output)

        loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=batch["label_onehot"], logits=logits)

        return tf.reduce_mean(loss_vec)

    return lambda: snt.Module(_build)
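
A hypothetical instantiation pairing the helper with an LSTM core; the hidden size and keyword values are assumptions.

# Hypothetical usage of rnn_classification.
loss_module_fn = rnn_classification(lambda: snt.LSTM(64),
                                    vocab_size=10000,
                                    embed_dim=64,
                                    aggregate_method="avg")
loss_module = loss_module_fn()  # snt.Module wrapping _build
# loss = loss_module(batch)     # batch with "text" and "label_onehot" tensors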
Example #25
def sequence_to_sequence_rnn(core_fn):
    """A RNN based model for sequence to sequence prediction.

  This module takes a batch of data containing:
  * input: a [batch_size, seq_lenth, feature] onehot tensor.
  * output : a [batch_size, seq_lenth, feature] onehot tensor.
  * loss_mask: a [batch_size, seq_lenth] tensor.

  The input sequence encoded is passed it through a RNN, then a linear layer to
  the prediction dimension specified by the output. A cross entropy loss is then
  done comparing the predicted output with the actual outputs. A weighted
  average is then done using weights specified by the loss_mask.

  Args:
    core_fn: A fn that returns a sonnet RNNCore.

  Returns:
    A Callable that returns a snt.Module.
  """
    def _build(batch):
        """Build the sonnet module."""
        rnn = core_fn()

        initial_state = rnn.initial_state(batch["input"].shape[0])
        outputs, _ = tf.nn.dynamic_rnn(rnn,
                                       batch["input"],
                                       initial_state=initial_state,
                                       dtype=tf.float32,
                                       time_major=False)

        pred_logits = snt.BatchApply(snt.Linear(
            batch["output"].shape[2]))(outputs)

        flat_shape = [
            pred_logits.shape[0] * pred_logits.shape[1], pred_logits.shape[2]
        ]
        flat_pred_logits = tf.reshape(pred_logits, flat_shape)
        flat_actual_tokens = tf.reshape(batch["output"], flat_shape)
        flat_mask = tf.reshape(batch["loss_mask"], [flat_shape[0]])

        loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
            labels=flat_actual_tokens, logits=flat_pred_logits)
        total_loss = tf.reduce_sum(flat_mask * loss_vec)
        mean_loss = total_loss / tf.reduce_sum(flat_mask)

        return mean_loss

    return lambda: snt.Module(_build)
Example #26
def get_mlp_ae_family(cfg):
    """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_ae_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
    act_fn = utils.get_activation(cfg["activation"])
    w_init = utils.get_initializer(cfg["w_init"])
    init = {"w": w_init}

    datasets = utils.get_image_dataset(cfg["dataset"])

    def _build(batch):
        """Builds the sonnet module."""
        flat_img = snt.BatchFlatten()(batch["image"])

        if cfg["output_type"] in ["tanh", "linear_center"]:
            flat_img = flat_img * 2.0 - 1.0

        hidden_units = cfg["hidden_units"] + [flat_img.shape.as_list()[1]]
        mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
        outputs = mod(flat_img)

        if cfg["output_type"] == "sigmoid":
            outputs = tf.nn.sigmoid(outputs)
        elif cfg["output_type"] == "tanh":
            outputs = tf.tanh(outputs)
        elif cfg["output_type"] in ["linear", "linear_center"]:
            # nothing to be done to the outputs
            pass
        else:
            raise ValueError("Invalid output_type [%s]." % cfg["output_type"])

        reduce_fn = getattr(tf, cfg["reduction_type"])
        if cfg["loss_type"] == "l2":
            loss_vec = reduce_fn(tf.square(outputs - flat_img), axis=1)
        elif cfg["loss_type"] == "l1":
            loss_vec = reduce_fn(tf.abs(outputs - flat_img), axis=1)
        else:
            raise ValueError("Unsupported loss_type [%s]." %
                             cfg["reduction_type"])

        return tf.reduce_mean(loss_vec)

    return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
Example #27
def get_conv_pooling_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by
      `sample_conv_pooling_family_cfg`.

  Returns:
    A task for the given config.
  """

  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  hidden_units = cfg["hidden_units"]

  dataset = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Builds the sonnet module."""
    image = utils.maybe_center(cfg["center_data"], batch["image"])

    net = snt.nets.ConvNet2D(
        hidden_units,
        kernel_shapes=[(3, 3)],
        strides=cfg["strides"],
        paddings=cfg["padding"],
        activation=act_fn,
        use_bias=cfg["use_bias"],
        initializers=init,
        activate_final=True)(
            image)
    if cfg["pool_type"] == "mean":
      net = tf.reduce_mean(net, axis=[1, 2])
    elif cfg["pool_type"] == "max":
      net = tf.reduce_max(net, axis=[1, 2])
    elif cfg["pool_type"] == "squared_mean":
      net = tf.reduce_mean(net**2, axis=[1, 2])

    logits = snt.Linear(batch["label_onehot"].shape[1], initializers=init)(net)

    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)

    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_build), dataset)
Example #28
def get_loss_fn(num_bijectors, layers):
    """Helper for constructing a NVP based loss.

  Args:
    num_bijectors: int the number of bijectors to use.
    layers: list with the number of units per layer for each bijector.

  Returns:
    callable that returns a sonnet module representing the loss.
  """
    def _fn(batch):
        dist = nvp.distribution_with_nvp_bijectors(batch["image"],
                                                   num_bijectors=num_bijectors,
                                                   layers=layers)
        return nvp.neg_log_p(dist, batch["image"])

    return lambda: snt.Module(_fn)
Example #29
def ce_flatten_loss(hidden_units,
                    activation_fn,
                    hidden_layers,
                    initializers=None):
  """Helper function to make a sonnet loss.

  This creates a cross entropy loss from a conv net whose last conv layer is
  flattened and run through an MLP instead of being pooled.

  Args:
    hidden_units: list of hidden unit sizes
    activation_fn: callable activation function used in the convnet
    hidden_layers: list of hidden layer sizes for the classification MLP
    initializers: optional dict of initializers used to initialize the
      convnet weights

  Returns:
    callable that returns a sonnet module representing the loss
  """
  if not initializers:
    initializers = {}

  def _fn(batch):
    """Build the loss."""
    net = snt.nets.ConvNet2D(
        hidden_units,
        kernel_shapes=[(3, 3)],
        strides=[2] + [1] * (len(hidden_units) - 1),
        paddings=[snt.SAME],
        activation=activation_fn,
        initializers=initializers,
        activate_final=True)(
            batch["image"])

    lastdims = int(np.prod(net.shape.as_list()[1:]))
    net = tf.reshape(net, [-1, lastdims])
    for s in hidden_layers:
      net = activation_fn(snt.Linear(s, initializers=initializers)(net))

    num_classes = batch["label_onehot"].shape.as_list()[1]
    logits = snt.Linear(num_classes, initializers=initializers)(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return lambda: snt.Module(_fn)
Example #30
def teacher_force_language_modeling(core_fn, embed_dim=32):
    """Helper for teacher forced language modeling.

  Args:
    core_fn: callable that returns a sonnet RNN core.
    embed_dim: int size of the embedding table.

  Returns:
    callable that returns a sonnet module representing the loss.
  """
    def _fn(batch):
        """Compute the loss from the given batch."""
        # Shape is [bs, seq len, features]
        inp = batch["text"]
        mask = batch["mask"]
        embed = snt.Embed(vocab_size=256, embed_dim=embed_dim)
        embedded_chars = embed(inp)

        rnn = core_fn()
        bs = inp.shape.as_list()[0]

        state = rnn.initial_state(bs, trainable=True)

        outputs, state = tf.nn.dynamic_rnn(rnn,
                                           embedded_chars,
                                           initial_state=state)

        pred_logits = snt.BatchApply(snt.Linear(256))(outputs[:, :-1])
        actual_tokens = inp[:, 1:]

        flat_s = [
            pred_logits.shape[0] * pred_logits.shape[1], pred_logits.shape[2]
        ]
        f_pred_logits = tf.reshape(pred_logits, flat_s)
        f_actual_tokens = tf.reshape(actual_tokens, [flat_s[0]])
        f_mask = tf.reshape(mask[:, 1:], [flat_s[0]])

        loss_vec = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=f_actual_tokens, logits=f_pred_logits)
        total_loss = tf.reduce_sum(f_mask * loss_vec)
        mean_loss = total_loss / tf.reduce_sum(f_mask)

        return mean_loss

    return lambda: snt.Module(_fn)