コード例 #1
0
def _get_fully_connected(cfg):
    """Build a fully connected problem, its dataset and batch size from ``cfg``.

    Args:
      cfg: Mapping read for the keys "n_features", "n_classes",
        "hidden_sizes", "activation", "n_samples" and "bs".

    Returns:
      A 3-tuple ``(spec, dataset, batch_size)``:
        * ``spec`` is a ``problem_spec.Spec`` for ``pg.FullyConnected`` with
          positional args ``(n_features, n_classes)`` and keyword args
          ``hidden_sizes`` (converted to a tuple) and ``activation``
          (resolved through ``utils.get_activation``).
        * ``dataset`` is ``losg_datasets.random_mlp(n_features, n_samples)``.
        * ``batch_size`` is ``cfg["bs"]``, passed through unchanged.
    """
    n_features = cfg["n_features"]
    problem_args = (n_features, cfg["n_classes"])
    problem_kwargs = {
        "hidden_sizes": tuple(cfg["hidden_sizes"]),
        "activation": utils.get_activation(cfg["activation"]),
    }
    # The spec is constructed before the dataset, matching the original order.
    spec = problem_spec.Spec(pg.FullyConnected, problem_args, problem_kwargs)
    dataset = losg_datasets.random_mlp(n_features, cfg["n_samples"])
    return spec, dataset, cfg["bs"]
コード例 #2
0
  def testRandomMlp(self):
    """Checks the size, feature width and binary labels of random_mlp output."""
    n_samples = 4
    n_features = 3
    n_layers = 2
    width = 10

    # A fixed seed keeps the generated dataset reproducible across runs.
    ds = datasets.random_mlp(
        n_features, n_samples, n_layers=n_layers, width=width, random_seed=200)
    self.assertEqual(n_samples, ds.size)
    self.assertLen(ds.labels, n_samples)
    # Verify every sample (not just the first) has the requested feature width.
    self.assertLen(ds.data, n_samples)
    for row in ds.data:
      self.assertLen(row, n_features)
    # Labels must be binary; a generator avoids building a throwaway list.
    self.assertTrue(all(label in (0, 1) for label in ds.labels))
コード例 #3
0
def _get_outward_snake(cfg):
    """Build an outward snake problem, its dataset and batch size from ``cfg``.

    Args:
      cfg: Mapping read for the keys "dim", "n_samples" and "bs".

    Returns:
      A 3-tuple ``(spec, dataset, batch_size)``: a ``problem_spec.Spec`` for
      ``pg.OutwardSnake`` with positional args ``(dim,)`` and no keyword args,
      the result of ``losg_datasets.random_mlp(dim, n_samples)``, and
      ``cfg["bs"]`` unchanged.
    """
    dim = cfg["dim"]
    spec = problem_spec.Spec(pg.OutwardSnake, (dim,), {})
    dataset = losg_datasets.random_mlp(dim, cfg["n_samples"])
    return spec, dataset, cfg["bs"]
コード例 #4
0
def _get_dependency_chain(cfg):
    """Build a dependency chain problem, its dataset and batch size from ``cfg``.

    Args:
      cfg: Mapping read for the keys "dim", "n_samples" and "bs".

    Returns:
      A 3-tuple ``(spec, dataset, batch_size)``: a ``problem_spec.Spec`` for
      ``pg.DependencyChain`` with positional args ``(dim,)`` and no keyword
      args, the result of ``losg_datasets.random_mlp(dim, n_samples)``, and
      ``cfg["bs"]`` unchanged.
    """
    dim = cfg["dim"]
    spec = problem_spec.Spec(pg.DependencyChain, (dim,), {})
    dataset = losg_datasets.random_mlp(dim, cfg["n_samples"])
    return spec, dataset, cfg["bs"]