def test_dataset_model_task(self):
  """Checks DatasetModelTask losses and gradients on constant datasets."""

  def one_dataset_fn(scale):
    # Infinite dataset that repeatedly yields one constant [10, 2] batch.
    return tf.data.Dataset.from_tensor_slices(
        [scale * tf.ones([10, 2])]).repeat()

  all_datasets = datasets.Datasets(one_dataset_fn(1), one_dataset_fn(2),
                                   one_dataset_fn(3), one_dataset_fn(4))

  def fn(inp):
    # All-ones weights make the expected loss analytically computable.
    # NOTE: the function name `fn` is load-bearing -- the gradient key
    # "BaseModel/fn/linear/w" below is derived from it.
    out = snt.Linear(10, initializers={"w": tf.initializers.ones()})(inp)
    return tf.reduce_mean(out)

  task = base.DatasetModelTask(lambda: snt.Module(fn), all_datasets)
  param_dict = task.initial_params()
  self.assertLen(param_dict, 2)

  with self.test_session():
    train_loss = task.call_split(param_dict, datasets.Split.TRAIN)
    self.assertNear(train_loss.eval(), 2.0, 1e-8)

    test_loss = task.call_split(param_dict, datasets.Split.TEST)
    self.assertNear(test_loss.eval(), 8.0, 1e-8)

    grads = task.gradients(train_loss, param_dict)
    np_grad = grads["BaseModel/fn/linear/w"].eval()
    self.assertNear(np_grad[0, 0], 0.1, 1e-5)
def get_mlp_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  act_fn = utils.get_activation(cfg["activation"])
  init = {"w": utils.get_initializer(cfg["w_init"])}
  dataset = utils.get_image_dataset(cfg["dataset"])

  def _fn(batch):
    # Optionally center the image data, then classify with an MLP whose
    # final layer width matches the number of classes.
    image = utils.maybe_center(cfg["center_data"], batch["image"])
    flat = snt.BatchFlatten()(image)
    num_classes = batch["label_onehot"].shape[1]
    logits = snt.nets.MLP(
        cfg["layer_sizes"] + [num_classes],
        activation=act_fn,
        initializers=init)(flat)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_fn), dataset)
def get_nvp_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_nvp_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  # Named `dataset` (not `datasets`) so the `datasets` module import is not
  # shadowed; this also matches the other family builders in this file.
  dataset = utils.get_image_dataset(cfg["dataset"])
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])

  def _build(batch):
    # Negative log-likelihood of the images under an NVP flow.
    dist = distribution_with_nvp_bijectors(
        batch["image"],
        num_bijectors=cfg["num_bijectors"],
        layers=cfg["hidden_units"],
        activation=act_fn,
        w_init=w_init)
    return neg_log_p(dist, batch["image"])

  base_model_fn = lambda: snt.Module(_build)
  return base.DatasetModelTask(base_model_fn, dataset)
def get_mlp_vae_family(cfg):
  """Gets a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_vae_family_cfg`.

  Returns:
    base.BaseTask for the given config.
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  # Renamed from `datasets` to avoid shadowing the `datasets` module import.
  dataset = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Build the sonnet module."""
    flat_img = snt.BatchFlatten()(batch["image"])
    latent_size = cfg["enc_hidden_units"][-1]

    def encoder_fn(net):
      # Final layer emits concatenated [mu, log_sigma], hence the * 2.
      hidden_units = cfg["enc_hidden_units"][:-1] + [latent_size * 2]
      mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
      outputs = mod(net)
      return generative_utils.LogStddevNormal(outputs)

    encoder = snt.Module(encoder_fn, name="encoder")

    def decoder_fn(net):
      # Per-pixel [mu, log_sigma] output, clipped for numerical stability.
      hidden_units = cfg["dec_hidden_units"] + [
          flat_img.shape.as_list()[1] * 2
      ]
      mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
      net = mod(net)
      net = tf.clip_by_value(net, -10, 10)
      return generative_utils.QuantizedNormal(mu_log_sigma=net)

    decoder = snt.Module(decoder_fn, name="decoder")

    # Standard-normal prior, parameterized the same way as the encoder output.
    zshape = tf.stack([tf.shape(flat_img)[0], 2 * latent_size])
    prior = generative_utils.LogStddevNormal(tf.zeros(shape=zshape))

    log_p_x, kl_term = generative_utils.log_prob_elbo_components(
        encoder, decoder, prior, flat_img)
    elbo = log_p_x - kl_term

    metrics = {
        "kl_term": tf.reduce_mean(kl_term),
        "log_kl_term": tf.log(tf.reduce_mean(kl_term)),
        "log_p_x": tf.reduce_mean(log_p_x),
        "elbo": tf.reduce_mean(elbo),
        "log_neg_log_p_x": tf.log(-tf.reduce_mean(elbo))
    }

    # The training loss is the negative mean ELBO; metrics ride along as aux.
    return base.LossAndAux(-tf.reduce_mean(elbo), metrics)

  return base.DatasetModelTask(lambda: snt.Module(_build), dataset)
def _():
  """IMDB subword classification with a 64-unit VanillaRNN, avg-pooled."""
  model_fn = rnn_classification(
      lambda: snt.VanillaRNN(64, activation=tf.nn.relu),
      embed_dim=64,
      aggregate_method="avg")
  return base.DatasetModelTask(model_fn, imdb_subword(128, 32))
def convnetloss1():
  """Batch-normed conv-net cross-entropy task on sun397_32x32."""
  dataset = datasets.get_image_datasets(
      "sun397_32x32", batch_size=128, shuffle_buffer=5000)
  model_fn = ce_pool_loss([32, 64, 128, 128], tf.nn.relu, use_batch_norm=True)
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """Batch-normed conv-net cross-entropy task on coil100_32x32."""
  dataset = datasets.get_image_datasets(
      "coil100_32x32", batch_size=128, shuffle_buffer=5000, num_per_valid=800)
  model_fn = ce_pool_loss([32, 64, 128, 128], tf.nn.relu, use_batch_norm=True)
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """Batch-normed 5-layer conv-net cross-entropy task on food101_32x32."""
  dataset = datasets.get_image_datasets(
      "food101_32x32", batch_size=128, shuffle_buffer=5000)
  model_fn = ce_pool_loss(
      [32, 32, 32, 64, 64], tf.nn.relu, use_batch_norm=True)
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """Three-layer convolutional VAE task on cifar10."""
  dataset = datasets.get_image_datasets("cifar10", batch_size=128)
  model_fn = three_layer_conv_vae_loss_fn(
      [64, 128, 256], [256, 128, 64], 128, tf.nn.relu)
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """Conv-then-flatten cross-entropy task on colorectal_histology_32x32."""
  dataset = datasets.get_image_datasets(
      "colorectal_histology_32x32",
      batch_size=128,
      shuffle_buffer=5000,
      num_per_valid=700)
  model_fn = ce_flatten_loss([32, 64, 64], tf.nn.relu, [])
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """Tanh conv-then-flatten task on cifar100 with variance-scaling init."""
  init = {"w": contrib_layers.variance_scaling_initializer()}
  model_fn = ce_flatten_loss([32, 64, 128],
                             tf.nn.tanh, [64, 32],
                             initializers=init)
  dataset = datasets.get_image_datasets("cifar100", batch_size=64)
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """Tanh conv-then-flatten task on cifar10 (batch 8) with He-normal init."""
  init = {"w": tf.initializers.he_normal()}
  model_fn = ce_flatten_loss([32, 64, 128],
                             tf.nn.tanh, [64, 32],
                             initializers=init)
  dataset = datasets.get_image_datasets("cifar10", batch_size=8)
  return base.DatasetModelTask(model_fn, dataset)
def copy_fn(c):
  """Seq-to-seq copy task; c selects the RNN core, its size, and the data.

  Assumes c is (rnn name, rnn size, <dataset arg>, sequence length,
  num tokens) -- inferred from usage; confirm against callers.
  """
  model_fn = sequence_to_sequence_rnn(lambda: _rnn_mod_map[c[0]](c[1]))
  dataset = datasets.copy_sequence(
      c[2], sequence_length=c[3], num_separator=1, num_tokens=c[4])
  return base.DatasetModelTask(model_fn, dataset)
def _make_task(cfg):
  """Builds a 3-hidden-layer FC task from a (loss_type, dataset_name) pair."""
  loss_type, dataset_name = cfg
  dataset = datasets.get_image_datasets(
      dataset_name, batch_size=128, shuffle_buffer=5000)
  # Width of the final layer is the label dimensionality of the dataset.
  num_classes = dataset.train.output_shapes["label_onehot"].as_list()[1]
  model_fn = fc_loss_fn([128, 128, 128, num_classes], loss_type, tf.nn.relu)
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """Tanh conv-net with max pooling on cifar10, He-normal init."""
  init = {"w": tf.initializers.he_normal()}
  model_fn = ce_pool_loss([32, 64, 128],
                          tf.nn.tanh,
                          initializers=init,
                          pool="max")
  dataset = datasets.get_image_datasets("cifar10", batch_size=64)
  return base.DatasetModelTask(model_fn, dataset)
def get_mlp_ae_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by `sample_mlp_ae_family_cfg`.

  Returns:
    base.BaseTask for the given config.

  Raises:
    ValueError: if cfg["output_type"] or cfg["loss_type"] is not recognized.
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  # Renamed from `datasets` to avoid shadowing the `datasets` module import.
  dataset = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Builds the sonnet module."""
    flat_img = snt.BatchFlatten()(batch["image"])
    if cfg["output_type"] in ["tanh", "linear_center"]:
      # Rescale targets to [-1, 1] to match the centered output range.
      flat_img = flat_img * 2.0 - 1.0

    # Autoencoder: the final layer reconstructs the flattened image.
    hidden_units = cfg["hidden_units"] + [flat_img.shape.as_list()[1]]
    mod = snt.nets.MLP(hidden_units, activation=act_fn, initializers=init)
    outputs = mod(flat_img)

    if cfg["output_type"] == "sigmoid":
      outputs = tf.nn.sigmoid(outputs)
    elif cfg["output_type"] == "tanh":
      outputs = tf.tanh(outputs)
    elif cfg["output_type"] in ["linear", "linear_center"]:
      # nothing to be done to the outputs
      pass
    else:
      raise ValueError("Invalid output_type [%s]." % cfg["output_type"])

    reduce_fn = getattr(tf, cfg["reduction_type"])
    if cfg["loss_type"] == "l2":
      loss_vec = reduce_fn(tf.square(outputs - flat_img), axis=1)
    elif cfg["loss_type"] == "l1":
      loss_vec = reduce_fn(tf.abs(outputs - flat_img), axis=1)
    else:
      # Bug fix: the message previously interpolated cfg["reduction_type"],
      # misreporting which config key was actually invalid.
      raise ValueError("Unsupported loss_type [%s]." % cfg["loss_type"])

    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_build), dataset)
def get_conv_pooling_family(cfg):
  """Get a task for the given cfg.

  Args:
    cfg: config specifying the model generated by
      `sample_conv_pooling_family_cfg`.

  Returns:
    A task for the given config.

  Raises:
    ValueError: if cfg["pool_type"] is not "mean", "max", or "squared_mean".
  """
  act_fn = utils.get_activation(cfg["activation"])
  w_init = utils.get_initializer(cfg["w_init"])
  init = {"w": w_init}
  hidden_units = cfg["hidden_units"]
  dataset = utils.get_image_dataset(cfg["dataset"])

  def _build(batch):
    """Builds the sonnet module."""
    image = utils.maybe_center(cfg["center_data"], batch["image"])
    net = snt.nets.ConvNet2D(
        hidden_units,
        kernel_shapes=[(3, 3)],
        strides=cfg["strides"],
        paddings=cfg["padding"],
        activation=act_fn,
        use_bias=cfg["use_bias"],
        initializers=init,
        activate_final=True)(
            image)

    # Pool the spatial dimensions away, leaving a [batch, channels] tensor.
    if cfg["pool_type"] == "mean":
      net = tf.reduce_mean(net, axis=[1, 2])
    elif cfg["pool_type"] == "max":
      net = tf.reduce_max(net, axis=[1, 2])
    elif cfg["pool_type"] == "squared_mean":
      net = tf.reduce_mean(net**2, axis=[1, 2])
    else:
      # Previously an unknown pool_type fell through silently, so the 4-D
      # tensor reached snt.Linear and failed with a confusing shape error.
      raise ValueError("Invalid pool_type [%s]." % cfg["pool_type"])

    logits = snt.Linear(batch["label_onehot"].shape[1], initializers=init)(net)
    loss_vec = tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=batch["label_onehot"], logits=logits)
    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_build), dataset)
def _build_lm_task(cfg, dataset):
  """Builds a language modeling task."""

  def _build(batch):
    """Builds the sonnet module."""
    # Shape is [batch size, sequence length].
    tokens = batch["text"]
    # Clip the vocab to be at most vocab_size.
    tokens = tf.minimum(
        tokens, tf.to_int64(tf.reshape(cfg["vocab_size"] - 1, [1, 1])))

    embed = snt.Embed(vocab_size=cfg["vocab_size"], embed_dim=cfg["embed_dim"])
    embedded_tokens = embed(tokens)

    rnn = utils.get_rnn_core(cfg["core"])
    batch_size = tokens.shape.as_list()[0]
    init_state = rnn.initial_state(batch_size, trainable=cfg["trainable_init"])
    outputs, _ = tf.nn.dynamic_rnn(
        rnn, embedded_tokens, initial_state=init_state)

    # Predict token t+1 from the RNN output at position t.
    w_init = utils.get_initializer(cfg["w_init"])
    pred_logits = snt.BatchApply(
        snt.Linear(cfg["vocab_size"],
                   initializers={"w": w_init}))(outputs[:, :-1])
    targets = tokens[:, 1:]

    # Flatten [batch, time] into one axis for the sparse cross entropy.
    flat_s = [
        pred_logits.shape[0] * pred_logits.shape[1], pred_logits.shape[2]
    ]
    flat_pred_logits = tf.reshape(pred_logits, flat_s)
    flat_targets = tf.reshape(targets, [flat_s[0]])
    loss_vec = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=flat_targets, logits=flat_pred_logits)
    return tf.reduce_mean(loss_vec)

  return base.DatasetModelTask(lambda: snt.Module(_build), dataset)
def _():
  """Fully connected VAE (128-64 encoder, 64-128 decoder) on mnist."""
  dataset = datasets.get_image_datasets("mnist", batch_size=64)
  model_fn = fc_vae_loss_fn([128, 64], [64, 128], 32, tf.nn.relu)
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """Six-layer batch-normalized FC network on cifar10."""
  dataset = datasets.get_image_datasets("cifar10", batch_size=128)
  model_fn = _fc_batch_norm_loss_fn([64, 64, 64, 64, 64, 10], tf.nn.relu)
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """Tanh FC network with layer norm on cifar10."""
  dataset = datasets.get_image_datasets("cifar10", batch_size=128)
  model_fn = _fc_layer_norm_loss_fn([128, 128, 128, 10], tf.tanh)
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """FC network with aggressive dropout (keep prob 0.2) on cifar10."""
  dataset = datasets.get_image_datasets("cifar10", batch_size=128)
  model_fn = _fc_dropout_loss_fn([128, 128, 10], tf.nn.relu, keep_probs=0.2)
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """get_loss_fn task with depth 9 and (128, 128) layers on mnist."""
  dataset = datasets.get_image_datasets("mnist", batch_size=64)
  model_fn = get_loss_fn(9, layers=(128, 128))
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """get_loss_fn task with depth 3 and wide (1024, 1024) layers on cifar10."""
  dataset = datasets.get_image_datasets("cifar10", batch_size=64)
  model_fn = get_loss_fn(3, (1024, 1024))
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """get_loss_fn task with depth 2 and wide (2048, 2048) layers on mnist."""
  dataset = datasets.get_image_datasets("mnist", batch_size=64)
  model_fn = get_loss_fn(2, (2048, 2048))
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """Small fully connected autoencoder on mnist."""
  dataset = datasets.get_image_datasets("mnist", batch_size=128)
  model_fn = fc_ae_loss_fn([32, 32, 32], tf.nn.relu)
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """Fully connected VAE on food101_32x32 with a large batch."""
  dataset = datasets.get_image_datasets(
      "food101_32x32", batch_size=256, shuffle_buffer=5000)
  model_fn = fc_vae_loss_fn([128, 64], [64, 128], 32, tf.nn.relu)
  return base.DatasetModelTask(model_fn, dataset)
def associative_fn(c):
  """Seq-to-seq associative recall task; c selects the RNN core and data.

  Assumes c is (rnn name, rnn size, <dataset arg>, num pairs, num tokens)
  -- inferred from usage; confirm against callers.
  """
  model_fn = sequence_to_sequence_rnn(lambda: _rnn_mod_map[c[0]](c[1]))
  dataset = datasets.associative_sequence(
      c[2], num_pairs=c[3], num_tokens=c[4])
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """Single-hidden-layer fully connected VAE on cifar10."""
  dataset = datasets.get_image_datasets("cifar10", batch_size=128)
  model_fn = fc_vae_loss_fn([128], [128], 32, tf.nn.relu)
  return base.DatasetModelTask(model_fn, dataset)
def _():
  """Teacher-forced byte-level language modeling on lm1b with a 256 GRU."""
  model_fn = teacher_force_language_modeling(
      lambda: snt.GRU(256), embed_dim=64)
  return base.DatasetModelTask(model_fn, lm1b_byte(128, 128))