def test_dataset_model_task(self):
  """Checks DatasetModelTask losses and gradients on constant-valued data."""

  def constant_dataset(scale):
    # Infinite dataset that repeatedly yields one [10, 2] tensor of `scale`s.
    batch = scale * tf.ones([10, 2])
    return tf.data.Dataset.from_tensor_slices([batch]).repeat()

  # One dataset per split, with scales 1..4 so each split's loss differs.
  splits = datasets.Datasets(
      constant_dataset(1), constant_dataset(2), constant_dataset(3),
      constant_dataset(4))

  # NOTE: this inner function must be named `fn` — Sonnet embeds the wrapped
  # function's name in the variable scope, and the gradient-dict key below
  # ("BaseModel/fn/linear/w") depends on it.
  def fn(inp):
    # All-ones linear layer; the loss is the mean activation, so for an
    # input filled with `scale` the loss is 2 * scale (two input features).
    out = snt.Linear(10, initializers={"w": tf.initializers.ones()})(inp)
    return tf.reduce_mean(out)

  task = base.DatasetModelTask(lambda: snt.Module(fn), splits)
  params = task.initial_params()
  self.assertLen(params, 2)

  with self.test_session():
    # Train split uses scale 1 -> loss 2.0.
    loss_train = task.call_split(params, datasets.Split.TRAIN)
    self.assertNear(loss_train.eval(), 2.0, 1e-8)

    # Test split uses scale 4 -> loss 8.0.
    loss_test = task.call_split(params, datasets.Split.TEST)
    self.assertNear(loss_test.eval(), 8.0, 1e-8)

    # d(mean)/dw for a unit input and 10 output units is 1/10.
    grads = task.gradients(loss_train, params)
    w_grad = grads["BaseModel/fn/linear/w"].eval()
    self.assertNear(w_grad[0, 0], 0.1, 1e-5)
def _make_just_train(dataset, just_train):
  """Optionally collapses a Datasets object down to its training split.

  Args:
    dataset: A `datasets.Datasets` container of per-split datasets.
    just_train: If True, every split is replaced by the training split.

  Returns:
    The original `dataset` when `just_train` is False; otherwise a new
    `datasets.Datasets` whose four splits are all `dataset.train`.
  """
  if not just_train:
    return dataset
  train = dataset.train
  return datasets.Datasets(train, train, train, train)