Exemplo n.º 1
0
    def test_fixed_seed_dim_sample_quadratic(self):
        """Check that a task built with a fixed `seed` gives a repeatable loss.

        The task is constructed once with `seed=23`. Its loss is evaluated,
        the global variable initializer is run again, and the loss is
        evaluated a second time; the two values must agree to within 1e-9.
        """
        dims = 10

        # Distributions are built in a fixed order: initial point, A, B, C.
        start_point = np.random.normal(loc=0, scale=1, size=[dims])
        initial_dist = quadratic_helper.ConstantDistribution(start_point)

        eigenvalues = np.linspace(1.0, 2.0, dims).astype(np.float32)
        A_dist = quadratic_helper.FixedEigenSpectrumMatrixDistribution(
            eigenvalues)

        ones_vector = tf.ones([dims], dtype=tf.float32)
        B_dist = quadratic_helper.ConstantDistribution(ones_vector)
        C_dist = quadratic_helper.ConstantDistribution(1.0)

        task = quadratic_helper.QuadraticBasedTask(
            dims=dims,
            initial_dist=initial_dist,
            A_dist=A_dist,
            B_dist=B_dist,
            C_dist=C_dist,
            seed=23,
        )
        params = task.initial_params()
        loss = task.call_split(params, None)

        def reinitialize_and_evaluate(sess):
            # Re-run the initializer so each evaluation starts from freshly
            # initialized variables.
            sess.run(tf.global_variables_initializer())
            return sess.run(loss)

        with self.test_session() as sess:
            first_loss = reinitialize_and_evaluate(sess)
            second_loss = reinitialize_and_evaluate(sess)
            self.assertNear(first_loss, second_loss, 1e-9)
Exemplo n.º 2
0
def get_quadratic_family(cfg, seed=None):
    """Build the quadratic task described by `cfg`.

    Args:
      cfg: config specifying the model, as generated by
        `sample_quadratic_family_cfg`. The keys read here are "dims",
        "A_dist", "initial_dist", "output_fn", "loss_scale", "noise" and,
        when `seed` is None, "seed".
      seed: optional int seed used to generate the instance of the task. Note
        this is not the seed used to generate the cfg, but just an instance
        of the given cfg. When None (the default), the value stored in
        `cfg["seed"]` is used instead.

    Returns:
      The `quadratic_helper.QuadraticBasedTask` built for the given config.
    """
    dims = cfg["dims"]
    # pylint: disable=invalid-name
    A_dist = _get_distribution_over_matrix(dims, cfg["A_dist"])

    initial_dist = _get_vector_dist(dims, cfg["initial_dist"])

    # The linear and constant terms are fixed at zero for this family.
    B_dist = quadratic_helper.ConstantDistribution(tf.zeros([dims]))
    C_dist = quadratic_helper.ConstantDistribution(0.)

    output_fn = _get_output_fn(cfg["output_fn"], cfg["loss_scale"])

    # An explicitly passed seed takes precedence. Previously the argument was
    # unconditionally overwritten by cfg["seed"], so it had no effect.
    if seed is None:
        seed = cfg["seed"]

    # A falsy cfg["noise"] means no noise distribution is applied to A.
    noise_cfg = cfg["noise"]
    A_noise_dist = (
        _get_distribution_over_matrix(dims, noise_cfg["A_noise"])
        if noise_cfg else None)

    return quadratic_helper.QuadraticBasedTask(
        dims=dims,
        initial_dist=initial_dist,
        A_dist=A_dist,
        A_noise_dist=A_noise_dist,
        B_dist=B_dist,
        C_dist=C_dist,
        seed=seed,
        output_fn=output_fn)