def _fn(batch):
  """Build the loss."""
  shapes = [(8, 8), (16, 16), (32, 32)]

  def encoder_fn(net):
    """Encoder for VAE."""
    net = snt.nets.ConvNet2D(
        enc_units,
        kernel_shapes=[(3, 3)],
        strides=[2, 2, 2],
        paddings=[snt.SAME],
        activation=activation_fn,
        activate_final=True)(net)
    flat_dims = int(np.prod(net.shape.as_list()[1:]))
    net = tf.reshape(net, [-1, flat_dims])
    net = snt.Linear(2 * n_z)(net)
    return generative_utils.LogStddevNormal(net)

  encoder = snt.Module(encoder_fn, name="encoder")

  def decoder_fn(net):
    """Decoder for VAE."""
    net = snt.Linear(4 * 4 * 32)(net)
    net = tf.reshape(net, [-1, 4, 4, 32])
    net = snt.nets.ConvNet2DTranspose(
        dec_units,
        shapes,
        kernel_shapes=[(3, 3)],
        strides=[2, 2, 2],
        paddings=[snt.SAME],
        activation=activation_fn,
        activate_final=True)(net)
    outchannel = batch["image"].shape.as_list()[3]
    net = snt.Conv2D(2 * outchannel, kernel_shape=(1, 1))(net)
    net = tf.clip_by_value(net, -10, 10)
    return generative_utils.QuantizedNormal(mu_log_sigma=net)

  decoder = snt.Module(decoder_fn, name="decoder")

  zshape = tf.stack([tf.shape(batch["image"])[0], 2 * n_z])
  prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

  input_image = (batch["image"] - 0.5) * 2

  log_p_x, kl_term = generative_utils.log_prob_elbo_components(
      encoder, decoder, prior, input_image)
  elbo = log_p_x - kl_term

  metrics = {
      "kl_term": tf.reduce_mean(kl_term),
      "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
      "log_p_x": tf.reduce_mean(log_p_x),
      "elbo": tf.reduce_mean(elbo),
      "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
  }

  return base.LossAndAux(-tf.reduce_mean(elbo), metrics)
def get_mlp_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}

  # cfg["dataset"] contains (dname, extra_info)
  dataset = utils.get_image_dataset(cfg["dataset"])

  def _fn(batch):
    image = utils.maybe_center(cfg["center_data"], batch["image"])
    hidden_units = cfg["layer_sizes"] + [batch["label_onehot"].shape[1]]
    net = snt.BatchFlatten()(image)
    mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
    logits = mod(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_fn), dataset)
def testInitialStateInModule(self):
  # Check that scopes play nicely with initial states created inside modules.
  batch_size = 6

  def module_build():
    core = snt.DeepRNN([snt.LSTM(4), snt.LSTM(5)])
    initial_state1 = core.initial_state(
        batch_size, dtype=tf.float32, trainable=True)
    initial_state2 = core.initial_state(
        batch_size + 1, dtype=tf.float32, trainable=True)
    return initial_state1, initial_state2

  initial_state_module = snt.Module(module_build)
  initial_state = initial_state_module()
  self.evaluate(tf.global_variables_initializer())
  initial_state_value = self.evaluate(initial_state)
  self.assertEqual(initial_state_value[0][0][0].shape, (batch_size, 4))
  self.assertEqual(initial_state_value[1][0][0].shape, (batch_size + 1, 4))
  self.assertEqual(initial_state_value[0][0][1].shape, (batch_size, 4))
  self.assertEqual(initial_state_value[1][0][1].shape, (batch_size + 1, 4))
  self.assertEqual(initial_state_value[0][1][0].shape, (batch_size, 5))
  self.assertEqual(initial_state_value[1][1][0].shape, (batch_size + 1, 5))
  self.assertEqual(initial_state_value[0][1][1].shape, (batch_size, 5))
  self.assertEqual(initial_state_value[1][1][1].shape, (batch_size + 1, 5))
def main(unused_argv):
  inputs = tf.random_uniform(shape=[10, 32, 32, 3])
  targets = tf.random_uniform(shape=[10, 10])

  # The line below takes custom_build and wraps it to construct a sonnet Module.
  module_with_build_args = snt.Module(custom_build, name='simple_net')

  train_model_outputs = module_with_build_args(
      inputs, is_training=True, keep_prob=tf.constant(0.5))
  test_model_outputs = module_with_build_args(
      inputs, is_training=False, keep_prob=tf.constant(1.0))
  loss = tf.nn.l2_loss(targets - train_model_outputs)

  # Ensure the moving averages for the BatchNorm modules are updated.
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_step = tf.train.GradientDescentOptimizer(
        learning_rate=1e-3).minimize(loss)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in xrange(100):
      sess.run(train_step)

    # Check that evaluating train_model_outputs twice returns the same value.
    train_outputs, train_outputs_2 = sess.run(
        [train_model_outputs, train_model_outputs])
    assert (train_outputs == train_outputs_2).all()

    # Check that there is indeed a difference between train_model_outputs and
    # test_model_outputs.
    train_outputs, test_outputs = sess.run(
        [train_model_outputs, test_model_outputs])
    assert (train_outputs != test_outputs).any()
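# `custom_build` is referenced in main() above but not defined in this excerpt.
# The sketch below is a plausible, assumed implementation (not taken from the
# source): a build function whose extra arguments (`is_training`, `keep_prob`)
# are threaded through each call, which is why it is wrapped in snt.Module.
def custom_build(inputs, is_training, keep_prob):
  """A hypothetical custom build function to wrap into a sonnet Module."""
  outputs = snt.Conv2D(output_channels=32, kernel_shape=4, stride=2)(inputs)
  outputs = snt.BatchNorm()(outputs, is_training=is_training)
  outputs = tf.nn.relu(outputs)
  outputs = snt.BatchFlatten()(outputs)
  outputs = tf.nn.dropout(outputs, keep_prob=keep_prob)
  outputs = snt.Linear(output_size=10)(outputs)  # Matches the [10, 10] targets.
  return outputs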
def get_nvp_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_nvp_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  datasets = utils.get_image_dataset(cfg["dataset"])
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])

  def _build(batch):
    dist = distribution_with_nvp_bijectors(
        batch["image"],
        num_bijectors=cfg["num_bijectors"],
        layers=cfg["hidden_units"],
        activation=act_fn,
        w_init=w_init)
    return neg_log_p(dist, batch["image"])

  base_model_fn = lambda: snt.Module(_build)

  return base.DatasetModelTask(base_model_fn, datasets)
def get_mlp_vae_family(cfg):
  """Gets a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_vae_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  datasets = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Build the sonnet module."""
    flat_img = snt.BatchFlatten()(batch["image"])
    latent_size = cfg["enc_hidden_units"][-1]

    def encoder_fn(net):
      hidden_units = cfg["enc_hidden_units"][:-1] + [latent_size * 2]
      mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
      outputs = mod(net)
      return generative_utils.LogStddevNormal(outputs)

    encoder = snt.Module(encoder_fn, name="encoder")

    def decoder_fn(net):
      hidden_units = cfg["dec_hidden_units"] + [
          flat_img.shape.as_list()[1] * 2
      ]
      mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
      net = mod(net)
      net = tf.clip_by_value(net, -10, 10)
      return generative_utils.QuantizedNormal(mu_log_sigma=net)

    decoder = snt.Module(decoder_fn, name="decoder")

    zshape = tf.stack([tf.shape(flat_img)[0], 2 * latent_size])
    prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

    log_p_x, kl_term = generative_utils.log_prob_elbo_components(
        encoder, decoder, prior, flat_img)
    elbo = log_p_x - kl_term

    metrics = {
        "kl_term": tf.reduce_mean(kl_term),
        "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
        "log_p_x": tf.reduce_mean(log_p_x),
        "elbo": tf.reduce_mean(elbo),
        "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
    }

    return base.LossAndAux(-tf.reduce_mean(elbo), metrics)

  return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
def test_dataset_model_task(self):

  def one_dataset_fn(scale):
    dataset = tf.data.Dataset.from_tensor_slices([scale * tf.ones([10, 2])])
    return dataset.repeat()

  all_datasets = datasets.Datasets(
      one_dataset_fn(1), one_dataset_fn(2), one_dataset_fn(3),
      one_dataset_fn(4))

  def fn(inp):
    out = snt.Linear(10, initializers={"w": tf.initializers.ones()})(inp)
    loss = tf.reduce_mean(out)
    return loss

  task = base.DatasetModelTask(lambda: snt.Module(fn), all_datasets)
  param_dict = task.initial_params()
  self.assertLen(param_dict, 2)

  with self.test_session():
    train_loss = task.call_split(param_dict, datasets.Split.TRAIN)
    self.assertNear(train_loss.eval(), 2.0, 1e-8)

    test_loss = task.call_split(param_dict, datasets.Split.TEST)
    self.assertNear(test_loss.eval(), 8.0, 1e-8)

    grads = task.gradients(train_loss, param_dict)
    np_grad = grads["BaseModel/fn/linear/w"].eval()
    self.assertNear(np_grad[0, 0], 0.1, 1e-5)
def test_stack_with_snt_activation(self, activation_fn):
  conv = snt.Conv2D(output_channels=5, kernel_shape=3, padding='VALID')
  linear = snt.Linear(23)
  module = snt.Sequential([
      conv,
      snt.Module(activation_fn),
      snt.BatchFlatten(),
      linear,
  ])
  network = ibp.VerifiableModelWrapper(module)
  network(self._inputs)

  v_layers = auto_verifier.VerifiableLayerBuilder(network).build_layers()

  self.assertLen(v_layers, 3)
  self.assertIsInstance(v_layers[0], layers.Conv)
  self.assertIs(conv, v_layers[0].module)
  self.assertIsInstance(v_layers[0].input_node, ibp.ModelInputWrapper)
  self.assertIsInstance(v_layers[1], layers.Activation)
  self.assertEqual(activation_fn.__name__, v_layers[1].activation)
  self.assertIs(v_layers[0].output_node, v_layers[1].input_node)
  self.assertIsInstance(v_layers[2], layers.Linear)
  self.assertIs(linear, v_layers[2].module)
  self.assertIs(v_layers[2].output_node, network.output_module)
def __init__(self,
             edge_model_fn=None,
             left_node_model_fn=None,
             right_node_model_fn=None,
             global_model_fn=None,
             name="Bipartite_graph_independent"):
  """Initializes the BipartiteGraphIndependent module.

  Args:
    edge_model_fn: A callable that returns an edge model function. The
      callable must return a Sonnet module (or equivalent). If passed `None`,
      will pass through inputs (the default).
    left_node_model_fn: A callable that returns a model function for the left
      nodes. The callable must return a Sonnet module (or equivalent). If
      passed `None`, will pass through inputs (the default).
    right_node_model_fn: A callable that returns a model function for the
      right nodes. The callable must return a Sonnet module (or equivalent).
      If passed `None`, will pass through inputs (the default).
    global_model_fn: A callable that returns a global model function. The
      callable must return a Sonnet module (or equivalent). If passed `None`,
      will pass through inputs (the default).
    name: The module name.
  """
  super(BipartiteGraphIndependent, self).__init__(name=name)

  with self._enter_variable_scope():
    # The use of snt.Module below is to ensure the ops and variables that
    # result from the edge/node/global_model_fns are scoped analogous to how
    # the Edge/Node/GlobalBlock classes do.
    if edge_model_fn is None:
      self._edge_model = lambda x: x
    else:
      self._edge_model = snt.Module(
          lambda x: edge_model_fn()(x), name="edge_model")  # pylint: disable=unnecessary-lambda
    if left_node_model_fn is None:
      self._left_node_model = lambda x: x
    else:
      self._left_node_model = snt.Module(
          lambda x: left_node_model_fn()(x), name="left_node_model")  # pylint: disable=unnecessary-lambda
    if right_node_model_fn is None:
      self._right_node_model = lambda x: x
    else:
      self._right_node_model = snt.Module(
          lambda x: right_node_model_fn()(x), name="right_node_model")  # pylint: disable=unnecessary-lambda
    if global_model_fn is None:
      self._global_model = lambda x: x
    else:
      self._global_model = snt.Module(
          lambda x: global_model_fn()(x), name="global_model")  # pylint: disable=unnecessary-lambda
def get_loss_fn(num_bijectors, layers):

  def _fn(batch):
    dist = maf.dist_with_maf_bijectors(
        batch["image"], num_bijectors=num_bijectors, layers=layers)
    return maf.neg_log_p(dist, batch["image"])

  return lambda: snt.Module(_fn)
def _build(batch):
  """Build the sonnet module."""
  net = snt.BatchFlatten()(batch["image"])
  # shift to be zero mean
  net = (net - 0.5) * 2

  n_inp = net.shape.as_list()[1]

  def encoder_fn(x):
    mlp_encoding = snt.nets.MLP(
        name="mlp_encoder",
        output_sizes=enc_units + [2 * n_z],
        activation=activation)
    return generative_utils.LogStddevNormal(mlp_encoding(x))

  encoder = snt.Module(encoder_fn, name="encoder")

  def decoder_fn(x):
    mlp_decoding = snt.nets.MLP(
        name="mlp_decoder",
        output_sizes=dec_units + [2 * n_inp],
        activation=activation)
    net = mlp_decoding(x)
    net = tf.clip_by_value(net, -10, 10)
    return generative_utils.QuantizedNormal(mu_log_sigma=net)

  decoder = snt.Module(decoder_fn, name="decoder")

  zshape = tf.stack([tf.shape(net)[0], 2 * n_z])
  prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

  log_p_x, kl_term = generative_utils.log_prob_elbo_components(
      encoder, decoder, prior, net)
  elbo = log_p_x - kl_term

  metrics = {
      "kl_term": tf.reduce_mean(kl_term),
      "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
      "log_p_x": tf.reduce_mean(log_p_x),
      "elbo": tf.reduce_mean(elbo),
      "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
  }

  return base.LossAndAux(-tf.reduce_mean(elbo), metrics)
def __init__(self, params, shrink_factor=1, name='ResBlock'):
  super(ResBlock, self).__init__(name=name)
  self.shrink_factor = shrink_factor
  with self._enter_variable_scope():
    self.layers, self.layer_names = layer_factory(params)
    self._cores = Struct(
        out=snt.Module(self._custom_build),
        preprocess=snt.BatchNorm(decay_rate=0.99, scale=True, fused=True))
def conv_ae_loss_fn(enc_units, dec_units, num_latents, activation_fn):
  """Convolutional autoencoder loss module helper.

  This creates a callable that returns a sonnet module for the loss.

  Args:
    enc_units: list
      list of integers containing the encoder convnet number of units
    dec_units: list
      list of integers containing the decoder convnet number of units
    num_latents: int
      size of the middle layer of the autoencoder
    activation_fn: callable
      activation function used in the convnet

  Returns:
    callable that returns a sonnet module representing the loss
  """

  def _fn(batch):
    """Make the loss from the given batch."""
    net = batch["image"]
    net = snt.nets.ConvNet2D(
        enc_units,
        kernel_shapes=[(3, 3)],
        strides=[2, 2],
        paddings=[snt.SAME],
        activation=activation_fn,
        activate_final=True)(batch["image"])

    flat_dims = int(np.prod(net.shape.as_list()[1:]))
    net = tf.reshape(net, [-1, flat_dims])
    net = snt.Linear(num_latents)(net)

    if batch["image"].shape.as_list()[1] == 28:
      net = snt.Linear(7 * 7 * 32)(net)
      net = tf.reshape(net, [-1, 7, 7, 32])
      shapes = [(14, 14), (28, 28)]
    elif batch["image"].shape.as_list()[1] == 32:
      net = snt.Linear(8 * 8 * 32)(net)
      net = tf.reshape(net, [-1, 8, 8, 32])
      shapes = [(16, 16), (32, 32)]
    else:
      raise ValueError("Only 28x28, or 32x32 supported")

    net = snt.nets.ConvNet2DTranspose(
        dec_units,
        shapes,
        kernel_shapes=[(3, 3)],
        strides=[2, 2],
        paddings=[snt.SAME],
        activation=activation_fn,
        activate_final=True)(net)

    outchannel = batch["image"].shape.as_list()[3]
    net = snt.Conv2D(outchannel, kernel_shape=(1, 1))(net)

    loss_vec = tf.reduce_mean(
        tf.square(batch["image"] - tf.nn.sigmoid(net)), [1, 2, 3])
    return tf.reduce_mean(loss_vec)

  return lambda: snt.Module(_fn)
def test_count_parameters_on_module(self):
  module = snt.Module()
  # Weights of a 2D convolution with 2 filters.
  module.conv = snt.Conv2D(output_channels=2, kernel_shape=3, name="conv")
  module.conv(tf.ones((2, 5, 5, 3)))  # 3 * 3*3 * 2 + 2 (bias) = 56 parameters
  self.assertEqual(56, parameter_overview.count_parameters(module))
def test_count_parameters_empty(self):
  module = snt.Module()
  snt.allow_empty_variables(module)
  # No variables.
  self.assertEqual(0, parameter_overview.count_parameters(module))

  # Single variable.
  module.var = tf.Variable([0, 1])
  self.assertEqual(2, parameter_overview.count_parameters(module))
def _build(batch):
  """Build the sonnet module."""
  flat_img = snt.BatchFlatten()(batch["image"])
  latent_size = cfg["enc_hidden_units"][-1]

  def encoder_fn(net):
    hidden_units = cfg["enc_hidden_units"][:-1] + [latent_size * 2]
    mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
    outputs = mod(net)
    return generative_utils.LogStddevNormal(outputs)

  encoder = snt.Module(encoder_fn, name="encoder")

  def decoder_fn(net):
    hidden_units = cfg["dec_hidden_units"] + [flat_img.shape.as_list()[1] * 2]
    mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
    net = mod(net)
    net = tf.clip_by_value(net, -10, 10)
    return generative_utils.QuantizedNormal(mu_log_sigma=net)

  decoder = snt.Module(decoder_fn, name="decoder")

  zshape = tf.stack([tf.shape(flat_img)[0], 2 * latent_size])
  prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

  log_p_x, kl_term = generative_utils.log_prob_elbo_components(
      encoder, decoder, prior, flat_img)
  elbo = log_p_x - kl_term

  metrics = {
      "kl_term": tf.reduce_mean(kl_term),
      "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
      "log_p_x": tf.reduce_mean(log_p_x),
      "elbo": tf.reduce_mean(elbo),
      "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
  }

  return base.LossAndAux(-tf.reduce_mean(elbo), metrics)
def save_policy(policy_network: snt.Sequential, input_shape: Tuple[int, int],
                save_path: str) -> None:

  @tf.function(input_signature=[tf.TensorSpec(input_shape)])
  def inference(x):
    return policy_network(x)

  to_save = snt.Module()
  to_save.inference = inference
  to_save.all_variables = list(policy_network.variables)
  tf.saved_model.save(to_save, save_path)
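# A minimal, hypothetical usage sketch for save_policy. The policy layout,
# observation size (4), action count (2) and save path below are illustrative
# assumptions, not taken from the source.
policy_network = snt.Sequential([snt.Linear(32), tf.nn.relu, snt.Linear(2)])
policy_network(tf.zeros([1, 4]))  # Run once so the variables are created.
save_policy(policy_network, input_shape=(1, 4), save_path="/tmp/policy")

# The exported SavedModel can later be restored and queried through the
# `inference` signature defined above:
#   loaded = tf.saved_model.load("/tmp/policy")
#   action_logits = loaded.inference(tf.zeros([1, 4]))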
def test_get_parameter_overview_empty(self):
  module = snt.Module()
  snt.allow_empty_variables(module)
  # No variables.
  self.assertEqual(EMPTY_PARAMETER_OVERVIEW,
                   parameter_overview.get_parameter_overview(module))

  module.conv = snt.Conv2D(output_channels=2, kernel_shape=3)
  # Variables not yet created (happens in the first forward pass).
  self.assertEqual(EMPTY_PARAMETER_OVERVIEW,
                   parameter_overview.get_parameter_overview(module))
def test_get_parameter_overview_on_module(self):
  module = snt.Module()
  # Weights of a 2D convolution with 2 filters.
  module.conv = snt.Conv2D(output_channels=2, kernel_shape=3, name="conv")
  module.conv(tf.ones((2, 5, 5, 3)))  # 3 * 3*3 * 2 + 2 (bias) = 56 parameters
  for v in module.variables:
    v.assign(tf.ones_like(v))
  self.assertEqual(
      SNT_CONV2D_PARAMETER_OVERVIEW,
      parameter_overview.get_parameter_overview(module, include_stats=False))
  self.assertEqual(SNT_CONV2D_PARAMETER_OVERVIEW_WITH_STATS,
                   parameter_overview.get_parameter_overview(module))
def _edge_model(self):
  common_embedding_module = snt.Module(
      partial(
          common_embedding,
          num_types=self._num_edge_types,
          type_embedding_dim=self._type_embedding_dim))
  return snt.Sequential([
      common_embedding_module,
      snt.nets.MLP([self._latent_size] * self._num_layers,
                   activate_final=True),
      snt.LayerNorm()
  ])
def _node_model(self):
  node_embedding_module = snt.Module(
      partial(
          node_embedding,
          num_types=self._num_node_types,
          type_embedding_dim=self._type_embedding_dim,
          attr_encoders=self._attr_embedders,
          attr_embedding_dim=self._attr_embedding_dim))
  return snt.Sequential([
      node_embedding_module,
      snt.nets.MLP([self._latent_size] * self._num_layers,
                   activate_final=True),
      snt.LayerNorm()
  ])
def ce_pool_loss(hidden_units,
                 activation_fn,
                 initializers=None,
                 pool="avg",
                 use_batch_norm=False):
  """Helper function to make a sonnet loss.

  This creates a cross entropy loss with a conv net whose last layer is pooled.

  Args:
    hidden_units: list
      list of hidden unit sizes
    activation_fn: callable
      activation function used in the convnet
    initializers: optional dict
      dictionary of initializers used to initialize the convnet weights.
    pool: str
      the type of pooling. Supported values are max or avg.
    use_batch_norm: boolean
      to use batch norm or not in the convnet

  Returns:
    callable that returns a sonnet module representing the loss.
  """
  if not initializers:
    initializers = {}

  def _fn(batch):
    """Make the loss."""
    net = snt.nets.ConvNet2D(
        hidden_units,
        kernel_shapes=[(3, 3)],
        strides=[2] + [1] * (len(hidden_units) - 1),
        paddings=[snt.SAME],
        activation=activation_fn,
        initializers=initializers,
        use_batch_norm=use_batch_norm,
        activate_final=True)(
            batch["image"], is_training=True)

    # Pool over the spatial dimensions.
    if pool == "avg":
      net = tf.reduce_mean(net, [1, 2])
    elif pool == "max":
      net = tf.reduce_max(net, [1, 2])
    else:
      raise ValueError("pool type not supported")

    num_classes = batch["label_onehot"].shape.as_list()[1]
    logits = snt.Linear(num_classes)(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return lambda: snt.Module(_fn)
def classification_probe(features, labels, n_classes, labeled=None):
  """Classification probe with stopped gradient on features."""

  def _classification_probe(features):
    logits = snt.Linear(n_classes)(tf.stop_gradient(features))
    xe = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=labels)
    if labeled is not None:
      xe = xe * tf.cast(labeled, tf.float32)
    xe = tf.reduce_mean(xe)
    acc = tf.reduce_mean(
        tf.cast(tf.equal(tf.argmax(logits, axis=1), labels), tf.float32))
    return xe, acc

  return snt.Module(_classification_probe)(features)
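# A minimal, hypothetical usage sketch for classification_probe: probing frozen
# features with a linear classifier. The feature source, shapes, and class
# count below are illustrative assumptions.
features = tf.random_normal([8, 128])  # e.g. activations from some encoder.
labels = tf.constant([0, 1, 2, 3, 0, 1, 2, 3], dtype=tf.int64)
probe_xe, probe_acc = classification_probe(features, labels, n_classes=4)
# probe_xe only trains the probe's linear layer; tf.stop_gradient prevents the
# loss from reaching whatever produced the features.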
def rnn_classification(core_fn,
                       vocab_size=10000,
                       embed_dim=64,
                       aggregate_method="last"):
  """Helper for RNN based text classification tasks.

  Args:
    core_fn: callable
      callable that returns a sonnet RNN core
    vocab_size: int
      number of words to use for the embedding table. All indices above this
      will be clipped
    embed_dim: int
      size of the embedding dim
    aggregate_method: str
      how to aggregate the sequence of features. If 'last' grab the last
      hidden features. If 'avg' compute the average over the full sequence.

  Returns:
    a callable that returns a sonnet module representing the loss.
  """

  def _build(batch):
    """Build the loss sonnet module."""
    # TODO(lmetz) these are dense updates.... so keeping this small for now.
    tokens = tf.minimum(batch["text"],
                        tf.to_int64(tf.reshape(vocab_size - 1, [1, 1])))
    embed = snt.Embed(vocab_size=vocab_size, embed_dim=embed_dim)
    embedded_tokens = embed(tokens)
    rnn = core_fn()
    bs = tokens.shape.as_list()[0]

    state = rnn.initial_state(bs, trainable=True)
    outputs, state = tf.nn.dynamic_rnn(
        rnn, embedded_tokens, initial_state=state)

    if aggregate_method == "last":
      last_output = outputs[:, -1]  # grab the last output
    elif aggregate_method == "avg":
      last_output = tf.reduce_mean(outputs, [1])  # average over length
    else:
      raise ValueError("not supported aggregate_method")

    logits = snt.Linear(2)(last_output)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return lambda: snt.Module(_build)
def sequence_to_sequence_rnn(core_fn):
  """An RNN based model for sequence to sequence prediction.

  This module takes a batch of data containing:
    * input: a [batch_size, seq_length, feature] onehot tensor.
    * output: a [batch_size, seq_length, feature] onehot tensor.
    * loss_mask: a [batch_size, seq_length] tensor.

  The input sequence is passed through an RNN, then a linear layer mapping to
  the prediction dimension specified by the output. A cross entropy loss is
  then computed between the predicted and actual outputs, and a weighted
  average is taken using the weights specified by the loss_mask.

  Args:
    core_fn: A fn that returns a sonnet RNNCore.

  Returns:
    A callable that returns a snt.Module.
  """

  def _build(batch):
    """Build the sonnet module."""
    rnn = core_fn()
    initial_state = rnn.initial_state(batch["input"].shape[0])
    outputs, _ = tf.nn.dynamic_rnn(
        rnn,
        batch["input"],
        initial_state=initial_state,
        dtype=tf.float32,
        time_major=False)
    pred_logits = snt.BatchApply(snt.Linear(batch["output"].shape[2]))(outputs)

    flat_shape = [
        pred_logits.shape[0] * pred_logits.shape[1], pred_logits.shape[2]
    ]
    flat_pred_logits = tf.reshape(pred_logits, flat_shape)
    flat_actual_tokens = tf.reshape(batch["output"], flat_shape)
    flat_mask = tf.reshape(batch["loss_mask"], [flat_shape[0]])

    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=flat_actual_tokens, logits=flat_pred_logits)
    total_loss = tf.reduce_sum(flat_mask * loss_vec)
    mean_loss = total_loss / tf.reduce_sum(flat_mask)

    return mean_loss

  return lambda: snt.Module(_build)
def get_mlp_ae_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_ae_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  datasets = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Builds the sonnet module."""
    flat_img = snt.BatchFlatten()(batch["image"])
    if cfg["output_type"] in ["tanh", "linear_center"]:
      flat_img = flat_img * 2.0 - 1.0

    hidden_units = cfg["hidden_units"] + [flat_img.shape.as_list()[1]]
    mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
    outputs = mod(flat_img)

    if cfg["output_type"] == "sigmoid":
      outputs = tf.nn.sigmoid(outputs)
    elif cfg["output_type"] == "tanh":
      outputs = tf.tanh(outputs)
    elif cfg["output_type"] in ["linear", "linear_center"]:
      # nothing to be done to the outputs
      pass
    else:
      raise ValueError("Invalid output_type [%s]." % cfg["output_type"])

    reduce_fn = getattr(tf, cfg["reduction_type"])
    if cfg["loss_type"] == "l2":
      loss_vec = reduce_fn(tf.square(outputs - flat_img), axis=1)
    elif cfg["loss_type"] == "l1":
      loss_vec = reduce_fn(tf.abs(outputs - flat_img), axis=1)
    else:
      raise ValueError("Unsupported loss_type [%s]." % cfg["loss_type"])

    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_build), datasets)
def get_conv_pooling_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by
      `sample_conv_pooling_family_cfg`.

  Returns:
    A task for the given config.
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  hidden_units = cfg["hidden_units"]
  dataset = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Builds the sonnet module."""
    image = utils.maybe_center(cfg["center_data"], batch["image"])
    net = snt.nets.ConvNet2D(
        hidden_units,
        kernel_shapes=[(3, 3)],
        strides=cfg["strides"],
        paddings=cfg["padding"],
        activation=act_fn,
        use_bias=cfg["use_bias"],
        initializers=init,
        activate_final=True)(
            image)

    if cfg["pool_type"] == "mean":
      net = tf.reduce_mean(net, axis=[1, 2])
    elif cfg["pool_type"] == "max":
      net = tf.reduce_max(net, axis=[1, 2])
    elif cfg["pool_type"] == "squared_mean":
      net = tf.reduce_mean(net**2, axis=[1, 2])
    else:
      raise ValueError("Unsupported pool_type [%s]." % cfg["pool_type"])

    logits = snt.Linear(batch["label_onehot"].shape[1], initializers=init)(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_build), dataset)
def get_loss_fn(num_bijectors, layers):
  """Helper for constructing a NVP based loss.

  Args:
    num_bijectors: int
      the number of bijectors to use.
    layers: list
      list with number of units per layer for each bijector.

  Returns:
    callable that returns a sonnet module representing the loss.
  """

  def _fn(batch):
    dist = nvp.distribution_with_nvp_bijectors(
        batch["image"], num_bijectors=num_bijectors, layers=layers)
    return nvp.neg_log_p(dist, batch["image"])

  return lambda: snt.Module(_fn)
def ce_flatten_loss(hidden_units,
                    activation_fn,
                    hidden_layers,
                    initializers=None):
  """Helper function to make a sonnet loss.

  This creates a cross entropy loss, conv net where the last conv layer is
  flattened and run through an MLP instead of pooled.

  Args:
    hidden_units: list
      list of hidden unit sizes
    activation_fn: callable
      activation function used in the convnet
    hidden_layers: list
      hidden layers of the classification MLP
    initializers: optional dict
      dictionary of initializers used to initialize the convnet weights

  Returns:
    callable that returns a sonnet module representing the loss
  """
  if not initializers:
    initializers = {}

  def _fn(batch):
    """Build the loss."""
    net = snt.nets.ConvNet2D(
        hidden_units,
        kernel_shapes=[(3, 3)],
        strides=[2] + [1] * (len(hidden_units) - 1),
        paddings=[snt.SAME],
        activation=activation_fn,
        initializers=initializers,
        activate_final=True)(
            batch["image"])

    lastdims = int(np.prod(net.shape.as_list()[1:]))
    net = tf.reshape(net, [-1, lastdims])

    for s in hidden_layers:
      net = activation_fn(snt.Linear(s, initializers=initializers)(net))

    num_classes = batch["label_onehot"].shape.as_list()[1]
    logits = snt.Linear(num_classes, initializers=initializers)(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return lambda: snt.Module(_fn)
def teacher_force_language_modeling(core_fn, embed_dim=32):
  """Helper for teacher forced language modeling.

  Args:
    core_fn: callable
      callable that returns a sonnet RNN core.
    embed_dim: int
      size of the embedding table.

  Returns:
    callable that returns a sonnet module representing the loss.
  """

  def _fn(batch):
    """Compute the loss from the given batch."""
    # Shape is [bs, seq len, features]
    inp = batch["text"]
    mask = batch["mask"]
    embed = snt.Embed(vocab_size=256, embed_dim=embed_dim)
    embedded_chars = embed(inp)

    rnn = core_fn()
    bs = inp.shape.as_list()[0]

    state = rnn.initial_state(bs, trainable=True)
    outputs, state = tf.nn.dynamic_rnn(
        rnn, embedded_chars, initial_state=state)

    pred_logits = snt.BatchApply(snt.Linear(256))(outputs[:, :-1])
    actual_tokens = inp[:, 1:]

    flat_s = [
        pred_logits.shape[0] * pred_logits.shape[1], pred_logits.shape[2]
    ]
    f_pred_logits = tf.reshape(pred_logits, flat_s)
    f_actual_tokens = tf.reshape(actual_tokens, [flat_s[0]])
    f_mask = tf.reshape(mask[:, 1:], [flat_s[0]])

    loss_vec = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=f_actual_tokens, logits=f_pred_logits)
    total_loss = tf.reduce_sum(f_mask * loss_vec)
    mean_loss = total_loss / tf.reduce_sum(f_mask)

    return mean_loss

  return lambda: snt.Module(_fn)