示例#1
0
  def test_jvp_zeros(self):
    """jit(foo) must give the same result as calling foo eagerly.

    foo returns the jvp of a closure that captures foo's own argument.
    """
    def foo(x):
      # `bar` closes over the outer `x`.
      def bar(y):
        return np.sin(x * y)
      primals = (3 * x,)
      tangents = (2 * x,)
      return jvp(bar, primals, tangents)

    jitted_result = jit(foo)(0.5)
    eager_result = foo(0.5)
    jtu.check_eq(jitted_result, eager_result)
示例#2
0
def test_update_params():
    """Check `_update_params` on nested params with Delta priors on two leaves.

    Afterwards `params` holds `ParamShape` placeholders at the leaves named in
    `prior` ("a.b.c.d" and "a.f"), and `new_params` holds the Delta values at
    those leaves (broadcast to the leaf's shape for "a.f").
    """
    params = {"a": {"b": {"c": {"d": 1}, "e": np.array(2)}, "f": np.ones(4)}}
    prior = {"a.b.c.d": dist.Delta(4), "a.f": dist.Delta(5)}
    new_params = deepcopy(params)

    with handlers.seed(rng_seed=0):
        _update_params(params, new_params, prior)

    expected_shapes = {
        "a": {"b": {"c": {"d": ParamShape(())}, "e": 2}, "f": ParamShape((4,))}
    }
    assert params == expected_shapes

    expected_values = {
        "a": {
            "b": {"c": {"d": np.array(4.0)}, "e": np.array(2)},
            "f": np.full((4,), 5.0),
        }
    }
    test_util.check_eq(new_params, expected_values)
示例#3
0
 def test_save_restore_checkpoints_w_float_steps(self):
   """Save/restore round-trip using float step numbers.

   Checks that:
   - a float step appears verbatim in the checkpoint name ('test_0.0');
   - saving step 1.0 after step 2.0 raises `InvalidCheckpointError`;
   - `keep=2` leaves both 'test_2.0' and 'test_3.0' on disk;
   - restoring afterwards returns the newest checkpoint's contents.
   """
   tmp_dir = self.create_tempdir().full_path
   test_object0 = {'a': np.array([0, 0, 0], np.int32),
                   'b': np.array([0, 0, 0], np.int32)}
   test_object1 = {'a': np.array([1, 2, 3], np.int32),
                   'b': np.array([1, 1, 1], np.int32)}
   test_object2 = {'a': np.array([4, 5, 6], np.int32),
                   'b': np.array([2, 2, 2], np.int32)}
   # Create leftover temporary checkpoint, which should be ignored.
   # Write to it and close it: tf.io.gfile.GFile only creates the file on its
   # first write (assuming that is the `gfile` imported here -- confirm), so a
   # bare, never-closed GFile(...) may not put 'test_tmp' on disk at all.
   with gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w') as tmp_file:
     tmp_file.write('')
   checkpoints.save_checkpoint(
       tmp_dir, test_object1, 0.0, prefix='test_', keep=1)
   self.assertIn('test_0.0', os.listdir(tmp_dir))
   new_object = checkpoints.restore_checkpoint(
       tmp_dir, test_object0, prefix='test_')
   jtu.check_eq(new_object, test_object1)
   checkpoints.save_checkpoint(
       tmp_dir, test_object1, 2.0, prefix='test_', keep=1)
   # Saving a step older than the latest one (2.0) is expected to fail.
   with self.assertRaises(errors.InvalidCheckpointError):
     checkpoints.save_checkpoint(
         tmp_dir, test_object2, 1.0, prefix='test_', keep=1)
   checkpoints.save_checkpoint(
       tmp_dir, test_object2, 3.0, prefix='test_', keep=2)
   self.assertIn('test_3.0', os.listdir(tmp_dir))
   self.assertIn('test_2.0', os.listdir(tmp_dir))
   # Restore from disk so the newest checkpoint (step 3.0, holding
   # test_object2) is actually verified rather than the earlier restore.
   new_object = checkpoints.restore_checkpoint(
       tmp_dir, test_object0, prefix='test_')
   jtu.check_eq(new_object, test_object2)
示例#4
0
    def test_optimized_lstm_cell_matches_regular(self):
        """OptimizedLSTMCell must reproduce LSTMCell's output and parameters.

        Both cells are initialised from identical PRNG keys, inputs and
        carries. Their outputs must agree to rtol=1e-6 and their parameter
        trees must be equal under `check_eq`.
        """

        def init_cell(cell_cls):
            # Same randomness for every cell under test.
            root = random.PRNGKey(0)
            input_key, init_key = random.split(root)
            inputs = random.normal(input_key, (2, 3))
            c0, h0 = cell_cls.initialize_carry(root, (2, ), 4)
            self.assertEqual(c0.shape, (2, 4))
            self.assertEqual(h0.shape, (2, 4))
            cell = cell_cls()
            (_, output), variables = cell.init_with_output(
                init_key, (c0, h0), inputs)
            return output, variables

        y, lstm_params = init_cell(nn.LSTMCell)
        y_opt, lstm_opt_params = init_cell(nn.OptimizedLSTMCell)

        np.testing.assert_allclose(y, y_opt, rtol=1e-6)
        jtu.check_eq(lstm_params, lstm_opt_params)
示例#5
0
def test_update_params():
    """Run `_update_params` on nested params with Delta priors on 'a.b.c.d' and 'a.f'.

    The first argument must end up with `ParamShape` placeholders at those
    leaves. The second argument (a deep copy) must end up with the Delta
    values there.
    """
    original = {'a': {'b': {'c': {'d': 1}, 'e': np.array(2)}, 'f': np.ones(4)}}
    prior = {'a.b.c.d': dist.Delta(4), 'a.f': dist.Delta(5)}
    updated = deepcopy(original)
    with handlers.seed(rng_seed=0):
        _update_params(original, updated, prior)

    # Shapes recorded in `original`.
    shapes_c = {'d': ParamShape(())}
    assert original == {'a': {'b': {'c': shapes_c, 'e': 2},
                              'f': ParamShape((4, ))}}

    # Values written into `updated`.
    values_c = {'d': np.array(4.)}
    test_util.check_eq(updated, {'a': {'b': {'c': values_c, 'e': np.array(2)},
                                       'f': np.full((4, ), 5.)}})
示例#6
0
    def test_optimized_lstm_cell_matches_regular(self):
        """OptimizedLSTMCell must match LSTMCell output and parameters.

        Both cells are initialised through the legacy `init`/`nn.Model` API
        from the same PRNG keys and inputs. Outputs are compared to rtol=1e-6
        and parameter trees with `check_eq`.
        """

        def make_inputs(cell_cls):
            # Identical keys, input batch and carry for each cell under test.
            rng = random.PRNGKey(0)
            key1, key2 = random.split(rng)
            x = random.normal(key1, (2, 3))
            c0, h0 = cell_cls.initialize_carry(rng, (2, ), 4)
            self.assertEqual(c0.shape, (2, 4))
            self.assertEqual(h0.shape, (2, 4))
            return key2, (c0, h0), x

        init_key, carry, x = make_inputs(nn.LSTMCell)
        (_, y), lstm_init_params = nn.LSTMCell.init(init_key, carry, x)
        lstm = nn.Model(nn.LSTMCell, lstm_init_params)

        init_key, carry, x = make_inputs(nn.OptimizedLSTMCell)
        # Presumably named 'LSTMCell' so both parameter trees line up for
        # check_eq -- TODO confirm.
        (_, y_opt), opt_init_params = nn.OptimizedLSTMCell.partial(
            name='LSTMCell').init(init_key, carry, x)
        lstm_opt = nn.Model(nn.OptimizedLSTMCell.partial(name='LSTMCell'),
                            opt_init_params)

        onp.testing.assert_allclose(y, y_opt, rtol=1e-6)
        jtu.check_eq(lstm.params, lstm_opt.params)
示例#7
0
def test_fori_collect():
    """Check fori_collect(1, 3, ...) against hand-computed iterates of `step`.

    Only the 'i' leaf is kept, via the `transform` argument.
    """
    def step(state):
        i, j = state['i'], state['j']
        return {'i': i + j, 'j': i - j}

    def only_i(state):
        return {'i': state['i']}

    start = {'i': jnp.array([0.]), 'j': jnp.array([1.])}
    expected = {'i': jnp.array([[0.], [2.]])}
    collected = fori_collect(1, 3, step, start, transform=only_i)
    check_eq(collected, expected)
示例#8
0
  def test_jvp_zeros(self):
    """jit(foo) must agree with eager foo; foo returns a jvp of a closure."""
    def foo(x):
      def bar(y):
        # The pair (x, y) goes through core.pack and is unpacked straight
        # back into two values before the multiply.
        packed_x, packed_y = core.pack((x, y))
        return np.sin(packed_x * packed_y)
      return jvp(bar, (3 * x,), (2 * x,))

    jitted_result = jit(foo)(0.5)
    jtu.check_eq(jitted_result, foo(0.5))
示例#9
0
def test_fori_collect():
    """Check fori_collect(1, 3, ...) against hand-computed iterates of `f`.

    Only the "i" leaf is kept, via the `transform` argument.
    """

    def f(state):
        i, j = state["i"], state["j"]
        return {"i": i + j, "j": i - j}

    start = {"i": jnp.array([0.0]), "j": jnp.array([1.0])}
    expected_tree = {"i": jnp.array([[0.0], [2.0]])}
    check_eq(
        fori_collect(1, 3, f, start, transform=lambda s: {"i": s["i"]}),
        expected_tree,
    )
示例#10
0
def test_iaf():
    """Check AutoIAFNormal's `sample_posterior` and `get_transform`.

    This exercises the substitute logic of those exposed methods. Their
    outputs are compared with a transform rebuilt by hand from the guide's own
    parameters: the IAF layers (each after the first preceded by a reversal
    permutation), followed by `_unpack_latent`.
    """
    n_points, dim = 3000, 3
    latent_dim = dim + 1  # `coefs` (dim values) plus the scalar `offset`
    data = random.normal(random.PRNGKey(0), (n_points, dim))
    true_coefs = jnp.arange(1.0, dim + 1.0)
    true_logits = jnp.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=true_logits).sample(random.PRNGKey(1))

    def model(data, labels):
        coefs = numpyro.sample("coefs",
                               dist.Normal(jnp.zeros(dim), jnp.ones(dim)))
        offset = numpyro.sample("offset", dist.Uniform(-1, 1))
        logits = offset + jnp.sum(coefs * data, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    adam = optim.Adam(0.01)
    guide = AutoIAFNormal(model)
    svi = SVI(model, guide, adam, Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(1), data, labels)
    params = svi.get_params(svi_state)

    x = random.normal(random.PRNGKey(0), (latent_dim, ))
    rng_key = random.PRNGKey(1)
    actual_sample = guide.sample_posterior(rng_key, params)
    actual_output = guide._unpack_latent(guide.get_transform(params)(x))

    flows = []
    for i in range(guide.num_flows):
        if i:
            # Every layer after the first is preceded by a reversal.
            flows.append(
                transforms.PermuteTransform(jnp.arange(latent_dim)[::-1]))
        _, arn_apply = AutoregressiveNN(
            latent_dim,
            [latent_dim, latent_dim],
            permutation=jnp.arange(latent_dim),
            skip_connections=guide._skip_connections,
            nonlinearity=guide._nonlinearity,
        )
        arn_params = params["auto_arn__{}$params".format(i)]
        flows.append(InverseAutoregressiveTransform(partial(arn_apply,
                                                            arn_params)))
    flows.append(guide._unpack_latent)

    transform = transforms.ComposeTransform(flows)
    _, rng_key_sample = random.split(rng_key)
    base_draw = dist.Normal(jnp.zeros(latent_dim), 1).sample(rng_key_sample)
    expected_sample = transform(base_draw)
    expected_output = transform(x)

    assert_allclose(actual_sample["coefs"], expected_sample["coefs"])
    # `offset` has Uniform(-1, 1) support: map the hand-built value into it.
    to_interval = transforms.biject_to(constraints.interval(-1, 1))
    assert_allclose(actual_sample["offset"],
                    to_interval(expected_sample["offset"]))
    check_eq(actual_output, expected_output)
示例#11
0
def test_fori_collect_return_last(progbar):
    """fori_collect with return_last_val=True also returns the final state.

    Starting from {'i': 0} and incrementing 'i' each step over the range
    [2, 4), the collected tree is [3, 4] and the returned last state has
    'i' == 4. Run with and without a progress bar (`progbar`).
    """
    def f(x):
        # Build a fresh dict rather than mutating `x` in place: bodies passed
        # to JAX loop/jit transformations should be pure functions.
        return {'i': x['i'] + 1}

    # The second return value is the LAST state, not the initial one.
    tree, last_state = fori_collect(2, 4, f, {'i': 0},
                                    transform=lambda a: {'i': a['i']},
                                    return_last_val=True,
                                    progbar=progbar)
    expected_tree = {'i': jnp.array([3, 4])}
    expected_last_state = {'i': jnp.array(4)}
    check_eq(last_state, expected_last_state)
    check_eq(tree, expected_tree)
示例#12
0
 def test_save_restore_checkpoints_target_none(self):
   """restore_checkpoint(target=None) returns the saved state dict.

   A dict round-trips unchanged. A tuple comes back as a dict keyed by the
   stringified element indices.
   """
   tmp_dir = self.create_tempdir().full_path

   # The target pytree is a dict, so it equals its restored state dict.
   dict_object = {'a': np.array([0, 0, 0], np.int32),
                  'b': np.array([0, 0, 0], np.int32)}
   checkpoints.save_checkpoint(tmp_dir, dict_object, 0)
   restored = checkpoints.restore_checkpoint(tmp_dir, target=None)
   jtu.check_eq(restored, dict_object)

   # The target pytree is a tuple; its state dict is {'0': ..., '1': ...}.
   tuple_object = (np.array([0, 0, 0], np.int32),
                   np.array([1, 1, 1], np.int32))
   checkpoints.save_checkpoint(tmp_dir, tuple_object, 1)
   restored = checkpoints.restore_checkpoint(tmp_dir, target=None)
   expected_state = {str(idx): leaf for idx, leaf in enumerate(tuple_object)}
   jtu.check_eq(restored, expected_state)
示例#13
0
def test_dynamic_supports():
    """A TransformReparam'd `loc` must match an explicitly scaled `loc`.

    `actual_model` samples `loc` as Uniform(0, 1) pushed through
    AffineTransform(0, alpha) under TransformReparam, which yields a
    `loc_base` latent site. `expected_model` samples Uniform(0, 1) and
    multiplies by `alpha` by hand. At SVI initialisation both must produce
    the same AutoDiagonalNormal parameters, medians and loss.
    """
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000,))

    def actual_model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        with numpyro.handlers.reparam(config={"loc": TransformReparam()}):
            loc = numpyro.sample(
                "loc",
                dist.TransformedDistribution(
                    dist.Uniform(0, 1), transforms.AffineTransform(0, alpha)
                ),
            )
        with numpyro.plate("N", len(data)):
            numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = numpyro.sample("alpha", dist.Uniform(0, 1))
        loc = numpyro.sample("loc", dist.Uniform(0, 1)) * alpha
        with numpyro.plate("N", len(data)):
            numpyro.sample("obs", dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)

    def init_svi(model):
        # Initialise SVI for `model` and return everything the test compares.
        guide = AutoDiagonalNormal(model)
        svi = SVI(model, guide, adam, Trace_ELBO())
        svi_state = svi.init(rng_key_init, data)
        opt_params = adam.get_params(svi_state.optim_state)
        params = svi.get_params(svi_state)
        median = guide.median(params)
        loss = svi.evaluate(svi_state, data)
        return opt_params, params, median, loss

    (actual_opt_params, actual_params, actual_values,
     actual_loss) = init_svi(actual_model)
    (expected_opt_params, expected_params, expected_values,
     expected_loss) = init_svi(expected_model)

    # auto_loc / auto_scale must coincide.
    check_eq(actual_opt_params, expected_opt_params)
    check_eq(actual_params, expected_params)
    # Latent medians: the reparam'd base site lines up with the plain `loc`.
    assert_allclose(actual_values["alpha"], expected_values["alpha"])
    assert_allclose(actual_values["loc_base"], expected_values["loc"])
    assert_allclose(actual_loss, expected_loss)
示例#14
0
    def test_overwrite_checkpoints(self):
        """Exercise `overwrite=True` and unusual checkpoint-directory spellings.

        Saving at an existing step (0), or at a step older than the latest
        one (1 after 2), raises `InvalidCheckpointError` unless
        `overwrite=True` is passed. The final section also saves and restores
        through a relative path and through a path with a trailing '//'.
        """
        tmp_dir = self.create_tempdir().full_path
        test_object0 = {'a': np.array([0, 0, 0], np.int32)}
        test_object = {'a': np.array([1, 2, 3], np.int32)}

        checkpoints.save_checkpoint(tmp_dir, test_object0, 0, keep=1)
        with self.assertRaises(errors.InvalidCheckpointError):
            checkpoints.save_checkpoint(tmp_dir, test_object, 0, keep=1)
        checkpoints.save_checkpoint(tmp_dir,
                                    test_object,
                                    0,
                                    keep=1,
                                    overwrite=True)
        new_object = checkpoints.restore_checkpoint(tmp_dir, test_object0)
        jtu.check_eq(new_object, test_object)
        checkpoints.save_checkpoint(tmp_dir,
                                    test_object0,
                                    2,
                                    keep=1,
                                    overwrite=True)
        new_object = checkpoints.restore_checkpoint(tmp_dir, test_object)
        jtu.check_eq(new_object, test_object0)
        with self.assertRaises(errors.InvalidCheckpointError):
            checkpoints.save_checkpoint(tmp_dir, test_object, 1, keep=1)
        checkpoints.save_checkpoint(tmp_dir,
                                    test_object,
                                    1,
                                    keep=1,
                                    overwrite=True)
        new_object = checkpoints.restore_checkpoint(tmp_dir, test_object0)
        jtu.check_eq(new_object, test_object)
        # os.chdir changes the working directory of the whole test process;
        # restore the original one when this test finishes so later tests
        # are unaffected.
        self.addCleanup(os.chdir, os.getcwd())
        os.chdir(os.path.dirname(tmp_dir))
        rel_tmp_dir = './' + os.path.basename(tmp_dir)
        checkpoints.save_checkpoint(rel_tmp_dir, test_object, 3, keep=1)
        new_object = checkpoints.restore_checkpoint(rel_tmp_dir, test_object0)
        jtu.check_eq(new_object, test_object)
        non_norm_dir_path = tmp_dir + '//'
        checkpoints.save_checkpoint(non_norm_dir_path, test_object, 4, keep=1)
        new_object = checkpoints.restore_checkpoint(non_norm_dir_path,
                                                    test_object0)
        jtu.check_eq(new_object, test_object)
def test_dynamic_supports():
    """Uniform(0, alpha) for `loc` must behave like Uniform(0, 1) * alpha.

    Both models are initialised with SVI + AutoDiagonalNormal +
    AutoContinuousELBO from the same key. Optimiser and guide parameters and
    losses must agree. The median `loc` of `actual_model` must equal
    alpha * loc of `expected_model`.
    """
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))

    def actual_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, alpha))
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = numpyro.sample('alpha', dist.Uniform(0, 1))
        loc = numpyro.sample('loc', dist.Uniform(0, 1)) * alpha
        numpyro.sample('obs', dist.Normal(loc, 0.1), obs=data)

    adam = optim.Adam(0.01)
    rng_key_init = random.PRNGKey(1)

    def init_svi(model):
        # Initialise SVI for `model`; return the quantities under comparison.
        guide = AutoDiagonalNormal(model)
        svi = SVI(model, guide, adam, AutoContinuousELBO())
        svi_state = svi.init(rng_key_init, data)
        opt_params = adam.get_params(svi_state.optim_state)
        params = svi.get_params(svi_state)
        median = guide.median(params)
        loss = svi.evaluate(svi_state, data)
        return opt_params, params, median, loss

    (actual_opt_params, actual_params, actual_values,
     actual_loss) = init_svi(actual_model)
    (expected_opt_params, expected_params, expected_values,
     expected_loss) = init_svi(expected_model)

    # auto_loc / auto_scale must coincide.
    check_eq(actual_opt_params, expected_opt_params)
    check_eq(actual_params, expected_params)
    # Latent medians: the dynamic-support `loc` is the rescaled one.
    assert_allclose(actual_values['alpha'], expected_values['alpha'])
    scaled_loc = expected_values['alpha'] * expected_values['loc']
    assert_allclose(actual_values['loc'], scaled_loc)
    assert_allclose(actual_loss, expected_loss)
示例#16
0
def test_dynamic_supports():
    """Uniform(0, alpha) for `loc` must behave like Uniform(0, 1) * alpha.

    Both models are initialised through the legacy functional `svi` API with
    AutoDiagonalNormal, using the same keys. Parameters, constrained base
    values and losses must agree. The median `loc` of `actual_model` must
    equal alpha * loc of `expected_model`.
    """
    true_coef = 0.9
    data = true_coef + random.normal(random.PRNGKey(0), (1000, ))

    def actual_model(data):
        alpha = sample('alpha', dist.Uniform(0, 1))
        loc = sample('loc', dist.Uniform(0, alpha))
        sample('obs', dist.Normal(loc, 0.1), obs=data)

    def expected_model(data):
        alpha = sample('alpha', dist.Uniform(0, 1))
        loc = sample('loc', dist.Uniform(0, 1)) * alpha
        sample('obs', dist.Normal(loc, 0.1), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    rng_guide, rng_init, _ = random.split(random.PRNGKey(1), 3)

    def init_svi(model):
        # Initialise SVI for `model`; return the quantities under comparison.
        guide = AutoDiagonalNormal(rng_guide, model, get_params)
        svi_init, _, svi_eval = svi(model, guide, elbo, opt_init,
                                    opt_update, get_params)
        opt_state, constrain_fn = svi_init(rng_init, (data, ), (data, ))
        params = get_params(opt_state)
        base_values = constrain_fn(params)
        median = guide.median(opt_state)
        loss = svi_eval(random.PRNGKey(1), opt_state, (data, ), (data, ))
        return params, base_values, median, loss

    (actual_params, actual_base_values, actual_values,
     actual_loss) = init_svi(actual_model)
    (expected_params, expected_base_values, expected_values,
     expected_loss) = init_svi(expected_model)

    check_eq(actual_params, expected_params)
    check_eq(actual_base_values, expected_base_values)
    assert_allclose(actual_values['alpha'], expected_values['alpha'])
    scaled_loc = expected_values['alpha'] * expected_values['loc']
    assert_allclose(actual_values['loc'], scaled_loc)
    assert_allclose(actual_loss, expected_loss)
示例#17
0
def test_jitted_update_fn():
    """A jitted `svi.update` must yield the same params as the eager call."""
    data = np.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        beta = numpyro.sample("beta", dist.Beta(1., 1.))
        numpyro.sample("obs", dist.Bernoulli(beta), obs=data)

    def guide(data):
        positive = constraints.positive
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    svi = SVI(model, guide, optim.Adam(0.05), ELBO())
    svi_state = svi.init(random.PRNGKey(1), data)

    # `update` returns a tuple whose first element is the new SVI state.
    eager_state = svi.update(svi_state, data)[0]
    expected = svi.get_params(eager_state)

    jitted_update = jit(svi.update)
    jitted_state = jitted_update(svi_state, data=data)[0]
    actual = svi.get_params(jitted_state)

    check_eq(actual, expected)
示例#18
0
  def test_overwrite_checkpoints(self):
    """Saving at an existing or older step needs `overwrite=True`.

    Without it, re-using step 0, or saving step 1 after step 2, must raise
    `InvalidCheckpointError`. With it, the save succeeds and is restorable.
    """
    tmp_dir = self.create_tempdir().full_path
    test_object0 = {'a': np.array([0, 0, 0], np.int32)}
    test_object = {'a': np.array([1, 2, 3], np.int32)}

    def save(obj, step, **kwargs):
      checkpoints.save_checkpoint(tmp_dir, obj, step, keep=1, **kwargs)

    def check_restored(target, expected):
      restored = checkpoints.restore_checkpoint(tmp_dir, target)
      jtu.check_eq(restored, expected)

    save(test_object0, 0)
    # Same step again, no overwrite: rejected.
    with self.assertRaises(errors.InvalidCheckpointError):
      save(test_object, 0)
    save(test_object, 0, overwrite=True)
    check_restored(test_object0, test_object)

    save(test_object0, 2, overwrite=True)
    check_restored(test_object, test_object0)

    # Older step (1 after 2), no overwrite: rejected.
    with self.assertRaises(errors.InvalidCheckpointError):
      save(test_object, 1)
    save(test_object, 1, overwrite=True)
    check_restored(test_object0, test_object)
示例#19
0
def test_fori_collect_thinning():
    """Check fori_collect's `thinning` against hand-computed expectations.

    Each case is (lower, upper, thinning, expected collected values), always
    starting from [-1] and adding 1.0 per step.
    """
    def f(x):
        return x + 1.0

    cases = [
        (0, 9, 2, [[2], [4], [6], [8]]),
        (0, 9, 3, [[2], [5], [8]]),
        (0, 9, 4, [[4], [8]]),
        (12, 37, 5, [[16], [21], [26], [31], [36]]),
    ]
    for lower, upper, thinning, expected in cases:
        actual = fori_collect(lower, upper, f, jnp.array([-1]),
                              thinning=thinning)
        check_eq(actual, jnp.array(expected))
示例#20
0
 def test_save_restore_checkpoints(self):
     """End-to-end save/restore with prefixes, `keep`, steps and paths.

     Covers:
     - restoring from an empty directory (the target comes back unchanged);
     - ignoring a leftover 'test_tmp' file;
     - `keep=1` / `keep=2` retention and restoring the latest checkpoint;
     - restoring a specific `step` or a specific checkpoint path;
     - a missing path returning the target unchanged;
     - `ValueError` for a step that does not exist.
     """
     tmp_dir = pathlib.Path(self.create_tempdir().full_path)
     test_object0 = {
         'a': np.array([0, 0, 0], np.int32),
         'b': np.array([0, 0, 0], np.int32)
     }
     test_object1 = {
         'a': np.array([1, 2, 3], np.int32),
         'b': np.array([1, 1, 1], np.int32)
     }
     test_object2 = {
         'a': np.array([4, 5, 6], np.int32),
         'b': np.array([2, 2, 2], np.int32)
     }
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object0)
     # Create leftover temporary checkpoint, which should be ignored.
     # Write to it and close it: tf.io.gfile.GFile only creates the file on its
     # first write (assuming that is the `gfile` imported here -- confirm), so
     # a bare, never-closed GFile(...) may not put 'test_tmp' on disk at all.
     with gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w') as tmp_file:
         tmp_file.write('')
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 0,
                                 prefix='test_',
                                 keep=1)
     self.assertIn('test_0', os.listdir(tmp_dir))
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object1)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 1,
                                 prefix='test_',
                                 keep=1)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object2,
                                 2,
                                 prefix='test_',
                                 keep=1)
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object2)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object2,
                                 3,
                                 prefix='test_',
                                 keep=2)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 4,
                                 prefix='test_',
                                 keep=2)
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object1)
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 step=3,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object2)
     # Restore a specific path.
     new_object = checkpoints.restore_checkpoint(
         os.path.join(tmp_dir, 'test_3'), test_object0)
     jtu.check_eq(new_object, test_object2)
     # If a specific path is given but does not exist, the same behavior as
     # for an empty directory applies: the target is returned unchanged.
     new_object = checkpoints.restore_checkpoint(
         os.path.join(tmp_dir, 'test_not_there'), test_object0)
     jtu.check_eq(new_object, test_object0)
     with self.assertRaises(ValueError):
         checkpoints.restore_checkpoint(tmp_dir,
                                        test_object0,
                                        step=5,
                                        prefix='test_')
示例#21
0
 def test_jit(self, f, args):
     """`jit(f)` applied to `args` must match calling `f` directly."""
     compiled_out = jit(f)(*args)
     eager_out = f(*args)
     jtu.check_eq(compiled_out, eager_out)
示例#22
0
 def test_save_restore_checkpoints(self):
     """End-to-end save/restore with prefixes, `keep` and explicit steps.

     Covers:
     - restoring from an empty directory (the target comes back unchanged);
     - ignoring a leftover 'test_tmp' file;
     - `keep=1` / `keep=2` retention and restoring the latest checkpoint;
     - restoring a specific `step`;
     - `ValueError` for a step that does not exist.
     """
     tmp_dir = self.create_tempdir().full_path
     test_object0 = {
         'a': np.array([0, 0, 0], np.int32),
         'b': np.array([0, 0, 0], np.int32)
     }
     test_object1 = {
         'a': np.array([1, 2, 3], np.int32),
         'b': np.array([1, 1, 1], np.int32)
     }
     test_object2 = {
         'a': np.array([4, 5, 6], np.int32),
         'b': np.array([2, 2, 2], np.int32)
     }
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object0)
     # Create leftover temporary checkpoint, which should be ignored.
     # Write to it and close it: tf.io.gfile.GFile only creates the file on its
     # first write (assuming that is the `gfile` imported here -- confirm), so
     # a bare, never-closed GFile(...) may not put 'test_tmp' on disk at all.
     with gfile.GFile(os.path.join(tmp_dir, 'test_tmp'), 'w') as tmp_file:
         tmp_file.write('')
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 0,
                                 prefix='test_',
                                 keep=1)
     self.assertIn('test_0', os.listdir(tmp_dir))
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object1)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 1,
                                 prefix='test_',
                                 keep=1)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object2,
                                 2,
                                 prefix='test_',
                                 keep=1)
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object2)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object2,
                                 3,
                                 prefix='test_',
                                 keep=2)
     checkpoints.save_checkpoint(tmp_dir,
                                 test_object1,
                                 4,
                                 prefix='test_',
                                 keep=2)
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object1)
     new_object = checkpoints.restore_checkpoint(tmp_dir,
                                                 test_object0,
                                                 step=3,
                                                 prefix='test_')
     jtu.check_eq(new_object, test_object2)
     with self.assertRaises(ValueError):
         checkpoints.restore_checkpoint(tmp_dir,
                                        test_object0,
                                        step=5,
                                        prefix='test_')