예제 #1
0
    def test_dataset_model_task(self):
        """Checks DatasetModelTask split losses and parameter gradients.

        Four infinitely repeating datasets are built, each yielding a single
        [10, 2] tensor filled with 1, 2, 3 or 4 respectively. The model is one
        10-unit linear layer whose weights are initialized to ones, followed
        by a mean over its [10, 10] output. The assertions expect:
          * TRAIN to read the dataset filled with 1 (loss 2.0);
          * TEST to read the dataset filled with 4 (loss 8.0).
        Both assume the bias starts at zero (the Sonnet default; not shown
        here).
        """

        def make_constant_dataset(fill):
            # One element of shape [10, 2] (the list wraps a single tensor),
            # repeated forever.
            batch = fill * tf.ones([10, 2])
            return tf.data.Dataset.from_tensor_slices([batch]).repeat()

        all_datasets = datasets.Datasets(
            *[make_constant_dataset(fill) for fill in (1, 2, 3, 4)])

        # NOTE(review): keep this function named `fn`. The gradient key
        # checked below contains "fn", so presumably snt.Module derives its
        # scope name from this callable's name.
        def fn(inp):
            ones_init = {"w": tf.initializers.ones()}
            return tf.reduce_mean(snt.Linear(10, initializers=ones_init)(inp))

        task = base.DatasetModelTask(lambda: snt.Module(fn), all_datasets)
        params = task.initial_params()

        # Presumably the linear layer's weight and bias.
        self.assertLen(params, 2)

        with self.test_session():

            def assert_split_loss(split, expected):
                # Builds the loss for `split`, checks its value, hands it back.
                loss = task.call_split(params, split)
                self.assertNear(loss.eval(), expected, 1e-8)
                return loss

            train_loss = assert_split_loss(datasets.Split.TRAIN, 2.0)
            assert_split_loss(datasets.Split.TEST, 8.0)

            # d(mean)/d(w[0, 0]) sums x[i, 0] = 1 over 10 rows and divides by
            # the 10 * 10 outputs, giving 0.1.
            grads = task.gradients(train_loss, params)
            w_grad = grads["BaseModel/fn/linear/w"].eval()
            self.assertNear(w_grad[0, 0], 0.1, 1e-5)
예제 #2
0
def _make_just_train(dataset, just_train):
    """Optionally replace every split of `dataset` with its training split.

    Args:
      dataset: expected to be a `datasets.Datasets`-like object. Only its
        `train` attribute is read, and only when `just_train` is truthy.
      just_train: when truthy, a new `datasets.Datasets` is built whose four
        positional slots all hold `dataset.train`. Otherwise `dataset` is
        returned as-is.

    Returns:
      Either the very same `dataset` object or the new train-only Datasets.
    """
    if not just_train:
        return dataset
    # Read `dataset.train` once per slot and fill all four positions with it.
    train_splits = [dataset.train for _ in range(4)]
    return datasets.Datasets(*train_splits)