    def test_multiple_sequences(self):
        tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
        model = FlaxBertModel.from_pretrained("bert-base-cased")

        sequences = [
            "this is an example sentence", "this is another", "and a third one"
        ]
        encodings = tokenizer(sequences,
                              return_tensors=TensorType.JAX,
                              padding=True,
                              truncation=True)

        @jax.jit
        def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
            return model(input_ids, attention_mask, token_type_ids)

        with self.subTest("JIT Disabled"):
            with jax.disable_jit():
                tokens, pooled = model_jitted(**encodings)
                self.assertEqual(tokens.shape, (3, 7, 768))
                self.assertEqual(pooled.shape, (3, 768))

        with self.subTest("JIT Enabled"):
            jitted_tokens, jitted_pooled = model_jitted(**encodings)

            self.assertEqual(jitted_tokens.shape, (3, 7, 768))
            self.assertEqual(jitted_pooled.shape, (3, 768))

    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):

                # TODO later: have some way to initialize easily a Flax model from config, for now I go through PT
                pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
                pt_model_class = getattr(transformers, pt_model_class_name)
                pt_model = pt_model_class(config).eval()

                model = convert_pt_model_to_flax(pt_model, config, model_class)

                @jax.jit
                def model_jitted(input_ids,
                                 attention_mask=None,
                                 token_type_ids=None):
                    return model(input_ids, attention_mask, token_type_ids)

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**inputs_dict)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**inputs_dict)

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)
Example #3
 def inference_speed_memory(self, batch_size, seq_length):
     # input_ids = np.random.randint(0, self.vocab_size, (batch_size, seq_length))
     key = jax.random.PRNGKey(0)
     input_ids = jax.random.randint(key, (batch_size, seq_length), 0, self.vocab_size)
     @jax.jit
     def ref_step():
         out = self.model(input_ids=input_ids)
         return out[0]
     if jax.local_devices()[0].platform == 'gpu':
         nvml.nvmlInit()
         ref_step().block_until_ready()
         handle = nvml.nvmlDeviceGetHandleByIndex(0)
         meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
         max_bytes_in_use = meminfo.used
         memory = Memory(max_bytes_in_use)
         # shutdown nvml
         nvml.nvmlShutdown()
     else:
         memory = None
     timeit.repeat("ref_step().block_until_ready()", repeat=1, number=2,globals=locals())
     if self.jit:
         runtimes = timeit.repeat("ref_step().block_until_ready()", repeat=self.repeat,number=3,globals=locals())
     else:
         with jax.disable_jit():
             runtimes = timeit.repeat("ref_step().block_until_ready()",repeat=self.repeat,number=3,globals=locals())
     return float(np.min(runtimes)/3.0), memory

    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(
                    inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(input_ids, pixel_values, **kwargs):
                    return model(input_ids=input_ids,
                                 pixel_values=pixel_values,
                                 **kwargs).to_tuple()

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict)

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict)

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs[:4],
                                                 outputs[:4]):
                    self.assertEqual(jitted_output.shape, output.shape)
Example #5
    def test_decode(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                model = model_class(config)
                encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])

                prepared_inputs_dict = {
                    "decoder_input_ids": inputs_dict["decoder_input_ids"],
                    "decoder_attention_mask": inputs_dict["decoder_attention_mask"],
                    "encoder_outputs": encoder_outputs,
                }

                @jax.jit
                def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs):
                    return model.decode(
                        decoder_input_ids=decoder_input_ids,
                        decoder_attention_mask=decoder_attention_mask,
                        encoder_outputs=encoder_outputs,
                    )

                with self.subTest("JIT Enabled"):
                    jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = decode_jitted(**prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)
Example #6
def main(argv):
  del argv

  if FLAGS.jax_debug_nans:
    config.update("jax_debug_nans", True)

  def run_training_loop():
    optimizer_fun = functools.partial(
        ppo.optimizer_fun, step_size=FLAGS.learning_rate)

    ppo.training_loop(
        env_name=FLAGS.env_name,
        epochs=FLAGS.epochs,
        policy_and_value_net_fun=functools.partial(
            ppo.policy_and_value_net, bottom_layers=common_layers()),
        policy_and_value_optimizer_fun=optimizer_fun,
        batch_size=FLAGS.batch_size,
        num_optimizer_steps=FLAGS.num_optimizer_steps,
        boundary=FLAGS.boundary,
        max_timestep=FLAGS.max_timestep,
        random_seed=FLAGS.random_seed)

  if FLAGS.jax_debug_nans or FLAGS.disable_jit:
    with jax.disable_jit():
      run_training_loop()
  else:
    run_training_loop()
Example #7
def test_welford_covariance(jitted, diagonal, regularize):
    with optional(jitted,
                  disable_jit()), optional(jitted,
                                           control_flow_prims_disabled()):
        np.random.seed(0)
        loc = np.random.randn(3)
        a = np.random.randn(3, 3)
        target_cov = np.matmul(a, a.T)
        x = np.random.multivariate_normal(loc, target_cov, size=(2000, ))
        x = device_put(x)

        @jit
        def get_cov(x):
            wc_init, wc_update, wc_final = welford_covariance(
                diagonal=diagonal)
            wc_state = wc_init(3)
            wc_state = fori_loop(0, 2000, lambda i, val: wc_update(x[i], val),
                                 wc_state)
            cov, cov_inv_sqrt = wc_final(wc_state, regularize=regularize)
            return cov, cov_inv_sqrt

        cov, cov_inv_sqrt = get_cov(x)

        if diagonal:
            diag_cov = jnp.diagonal(target_cov)
            assert_allclose(cov, diag_cov, rtol=0.06)
            assert_allclose(cov_inv_sqrt,
                            jnp.sqrt(jnp.reciprocal(diag_cov)),
                            rtol=0.06)
        else:
            assert_allclose(cov, target_cov, rtol=0.06)
            assert_allclose(cov_inv_sqrt,
                            jnp.linalg.cholesky(jnp.linalg.inv(cov)),
                            rtol=0.06)
Example #8
def test_ellipsoid_clustering():
    import pylab as plt
    from jax import disable_jit, jit
    points = jnp.concatenate([random.uniform(random.PRNGKey(0), shape=(30, 2)),
                              1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    mask = jnp.ones(points.shape[0], jnp.bool_)
    mu, C = bounding_ellipsoid(points, mask)
    radii, rotation = ellipsoid_params(C)
    # plt.plot(y[0, :], y[1, :])
    log_VS = log_ellipsoid_volume(radii) - jnp.log(5)

    with disable_jit():
        cluster_id, ellipsoid_parameters = \
            jit(lambda key, points, log_VS: ellipsoid_clustering(random.PRNGKey(0), points, 4, log_VS)
                )(random.PRNGKey(0), points, log_VS)
        mu, radii, rotation = ellipsoid_parameters
        print(mu, radii, rotation, jnp.bincount(cluster_id, minlength=0, length=4))

    for i, (mu, radii, rotation) in enumerate(zip(mu, radii, rotation)):
        y = mu[:, None] + rotation @ jnp.diag(radii) @ x
        plt.plot(y[0, :], y[1, :])
        mask = cluster_id == i
        plt.scatter(points[mask, 0], points[mask, 1], c=plt.cm.jet(i / len(ellipsoid_parameters)))

    plt.show()
    def test_encode(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(
                    inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def encode_jitted(input_ids, attention_mask=None, **kwargs):
                    return model.encode(input_ids=input_ids,
                                        attention_mask=attention_mask)

                with self.subTest("JIT Enabled"):
                    jitted_outputs = encode_jitted(
                        **prepared_inputs_dict).to_tuple()

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = encode_jitted(
                            **prepared_inputs_dict).to_tuple()

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)
Example #10
def test_sample_multi_ellipsoid():
    import pylab as plt
    from jax import disable_jit, jit, vmap
    points = jnp.concatenate([random.uniform(random.PRNGKey(0), shape=(30, 2)),
                              1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    mask = jnp.ones(points.shape[0], jnp.bool_)
    mu, C = bounding_ellipsoid(points, mask)
    radii, rotation = ellipsoid_params(C)
    y = mu[:, None] + rotation @ jnp.diag(radii) @ x
    # plt.plot(y[0, :], y[1, :])
    log_VS = log_ellipsoid_volume(radii) - jnp.log(5)

    with disable_jit():
        cluster_id, ellipsoid_parameters = \
            jit(lambda key, points, log_VS: ellipsoid_clustering(random.PRNGKey(0), points, 4, log_VS)
                )(random.PRNGKey(0), points, log_VS)

        mu, radii, rotation = ellipsoid_parameters
        # print(mu, radii, rotation)
        u = vmap(lambda key: sample_multi_ellipsoid(key, mu, radii, rotation, unit_cube_constraint=True)[1])(random.split(random.PRNGKey(0),1000))
    plt.scatter(u[:, 0], u[:, 1], marker='+')
    for i, (mu, radii, rotation) in enumerate(zip(mu, radii, rotation)):
        y = mu[:, None] + rotation @ jnp.diag(radii) @ x
        plt.plot(y[0, :], y[1, :])
        mask = cluster_id == i
        # plt.scatter(points[mask, 0], points[mask, 1], c=plt.cm.jet(i / len(ellipsoid_parameters)))
    plt.show()
Example #11
 def test_disable_jit_odeint_with_vmap(self):
     # https://github.com/google/jax/issues/2598
     with jax.disable_jit():
         t = jnp.array([0.0, 1.0])
         x0_eval = jnp.zeros((5, 2))
         f = lambda x0: odeint(lambda x, _t: x, x0, t)
         jax.vmap(f)(x0_eval)  # doesn't crash
Example #12
def test_cluster_split():
    import pylab as plt
    from jax import disable_jit
    points = jnp.concatenate([random.uniform(random.PRNGKey(0), shape=(30, 2)),
                              1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    mask = jnp.zeros(points.shape[0], jnp.bool_)
    mu, C = bounding_ellipsoid(points, jnp.ones(points.shape[0], jnp.bool_))
    radii, rotation = ellipsoid_params(C)
    y = mu[:, None] + rotation @ jnp.diag(radii) @ x
    plt.plot(y[0, :], y[1, :])
    log_VS = log_ellipsoid_volume(radii) - jnp.log(5)
    with disable_jit():
        cluster_id, log_VS1, mu1, radii1, rotation1, log_VS2, mu2, radii2, rotation2, do_split = \
            cluster_split(random.PRNGKey(0), points, mask, log_VS, log_ellipsoid_volume(radii), kmeans_init=True)
        print(jnp.logaddexp(log_ellipsoid_volume(radii1), log_ellipsoid_volume(radii2)), log_ellipsoid_volume(radii))
        print(log_VS1, mu1, radii1, rotation1, log_VS2, mu2, radii2, rotation2, do_split)
        print(cluster_id)

    y = mu1[:, None] + rotation1 @ jnp.diag(radii1) @ x
    plt.plot(y[0, :], y[1, :])

    y = mu2[:, None] + rotation2 @ jnp.diag(radii2) @ x
    plt.plot(y[0, :], y[1, :])

    mask = cluster_id == 0
    plt.scatter(points[mask, 0], points[mask, 1])
    mask = cluster_id == 1
    plt.scatter(points[mask, 0], points[mask, 1])

    plt.show()
Example #13
def test_param_tracking():
    from jax import jit, numpy as jnp, disable_jit, make_jaxpr
    shape = {
        'a': (4, ),
        'b': (4, ),
        'c': (4, ),
        'd': (4, ),
        'e': (4, ),
        'f': (4, ),
    }
    sample = dict_multimap(jnp.ones, shape)
    n = jnp.array(10)
    log_L = jnp.array(0.)

    @jit
    def test_jax(sample, n, log_L):
        tracked = TrackedExpectation(
            {
                k: lambda sample, n, log_L: jnp.ones(shape[k])
                for k in shape.keys()
            }, shape)
        tracked.update(sample, n, log_L)

        return (tracked.evidence_mean(), tracked.evidence_variance(),
                tracked.information_gain_mean())
        # return (evidence.state, H.state, m.state, M.state)

    print()
    print(len(str(make_jaxpr(test_jax)(sample, n, log_L))))
    with disable_jit():
        print(test_jax(sample, n, log_L))
Example #14
 def run(shared_input, clients):
     with jax.disable_jit():
         for client_id, client_batches, client_input in clients:
             step_results = []
             try:
                 state = client_init(shared_input, client_input)
             except Exception as e:
                 raise ForEachClientError(
                     e,
                     stage='client_init',
                     client_id=client_id,
                     client_init=client_init,
                     shared_input=shared_input,
                     client_input=client_input) from e
             for batch in client_batches:
                 try:
                     state, step_result = client_step(state, batch)
                 except Exception as e:
                     raise ForEachClientError(e,
                                              stage='client_step',
                                              client_id=client_id,
                                              client_step=client_step,
                                              state=state,
                                              batch=batch) from e
                 step_results.append(step_result)
             try:
                 output = client_final(shared_input, state)
             except Exception as e:
                 raise ForEachClientError(e,
                                          stage='client_final',
                                          client_id=client_id,
                                          client_final=client_final,
                                          shared_input=shared_input,
                                          state=state) from e
             yield client_id, output, step_results
Example #15
def main(argv):
    del argv

    if FLAGS.jax_debug_nans:
        config.update("jax_debug_nans", True)

    # Make an env here.
    env = make_env()
    assert env

    def run_training_loop():
        """Runs the training loop."""
        policy_net_fun = None
        value_net_fun = None
        policy_and_value_net_fun = None
        policy_optimizer_fun = None
        value_optimizer_fun = None
        policy_and_value_optimizer_fun = None

        if FLAGS.combined_policy_and_value_function:
            policy_and_value_net_fun = functools.partial(
                ppo.policy_and_value_net, bottom_layers=common_layers())
            policy_and_value_optimizer_fun = get_optimizer_fun(
                FLAGS.learning_rate)
        else:
            policy_net_fun = functools.partial(ppo.policy_net,
                                               bottom_layers=common_layers())
            value_net_fun = functools.partial(ppo.value_net,
                                              bottom_layers=common_layers())
            policy_optimizer_fun = get_optimizer_fun(
                FLAGS.policy_only_learning_rate)
            value_optimizer_fun = get_optimizer_fun(
                FLAGS.value_only_learning_rate)

        ppo.training_loop(
            env=env,
            epochs=FLAGS.epochs,
            policy_net_fun=policy_net_fun,
            value_net_fun=value_net_fun,
            policy_and_value_net_fun=policy_and_value_net_fun,
            policy_optimizer_fun=policy_optimizer_fun,
            value_optimizer_fun=value_optimizer_fun,
            policy_and_value_optimizer_fun=policy_and_value_optimizer_fun,
            batch_size=FLAGS.batch_size,
            num_optimizer_steps=FLAGS.num_optimizer_steps,
            policy_only_num_optimizer_steps=FLAGS.policy_only_num_optimizer_steps,
            value_only_num_optimizer_steps=FLAGS.value_only_num_optimizer_steps,
            target_kl=FLAGS.target_kl,
            boundary=FLAGS.boundary,
            max_timestep=FLAGS.max_timestep,
            random_seed=FLAGS.random_seed)

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        with jax.disable_jit():
            run_training_loop()
    else:
        run_training_loop()
Example #16
    def testLogSumExpNans(self):
        # Regression test for https://github.com/google/jax/issues/7634
        with jax.debug_nans(True):
            with jax.disable_jit():
                result = lsp_special.logsumexp(1.0)
                self.assertEqual(result, 1.0)

                result = lsp_special.logsumexp(1.0, b=1.0)
                self.assertEqual(result, 1.0)
Example #17
def check_preprocessors(space,
                        *preprocessors,
                        num_samples=20,
                        random_seed=None):
    r"""

    Check whether two preprocessors are the same.

    Parameters
    ----------
    space : gym.Space

        The domain of the preprocessors.

    \*preprocessors

        Preprocessor functions, which are functions with input signature: :code:`func(rng: PRNGKey,
        x: Element[space]) -> Any`.

    num_samples : positive int

        The number of samples on which to run the checks.

    random_seed : int, optional

        Seed for the internal PRNG sequence that is passed to the preprocessors.

    Returns
    -------
    match : bool

        Whether the preprocessors match.

    """
    if len(preprocessors) < 2:
        raise ValueError(
            "need at least two preprocessors in order to run test")

    def test_leaves(a, b):
        assert type(a) is type(b)
        return onp.testing.assert_allclose(onp.asanyarray(a),
                                           onp.asanyarray(b))

    rngs = hk.PRNGSequence(
        onp.random.RandomState(random_seed).randint(jnp.iinfo('int32').max))
    p0, *ps = preprocessors

    with jax.disable_jit():
        for _ in range(num_samples):
            x = space.sample()
            y0 = p0(next(rngs), x)
            for p in ps:
                y = p(next(rngs), x)
                if jax.tree_structure(y) != jax.tree_structure(y0):
                    return False
                try:
                    jax.tree_multimap(test_leaves, y, y0)
                except AssertionError:
                    return False
    return True
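
A minimal usage sketch for check_preprocessors (the space and the two preprocessor functions below are illustrative, not taken from the original project):

import gym
import jax.numpy as jnp

space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,))

# two preprocessors that should agree numerically on every sample
preprocess_a = lambda rng, x: jnp.asarray(x, dtype=jnp.float32)
preprocess_b = lambda rng, x: jnp.array(x) * 1.0

assert check_preprocessors(space, preprocess_a, preprocess_b,
                           num_samples=5, random_seed=13)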
Example #18
    def test_autodownload_pretrained_r50(self):
        fname, _ = urllib.request.urlretrieve(
            "https://upload.wikimedia.org/wikipedia/commons/e/e4/A_French_Bulldog.jpg"
        )
        im = np.array(PIL.Image.open(fname).resize([224, 224])) / np.float32(255)

        r50 = elegy.nets.resnet.ResNet50(weights="imagenet")
        with jax.disable_jit():
            assert elegy.Model(r50).predict(im[np.newaxis]).argmax() == 245
Example #19
def main(argv):
    del argv

    if FLAGS.jax_debug_nans:
        config.update("jax_debug_nans", True)

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        with jax.disable_jit():
            run_training_loop()
    else:
        run_training_loop()
Example #20
def main(argv):
  del argv
  logging.info("Starting PPO Main.")

  if FLAGS.jax_debug_nans:
    config.update("jax_debug_nans", True)

  if FLAGS.use_tpu:
    config.update("jax_platform_name", "tpu")
  else:
    config.update("jax_platform_name", "gpu")


  gin_configs = FLAGS.config or []
  gin.parse_config_files_and_bindings(FLAGS.config_file, gin_configs)

  # TODO(pkozakowski): Find a better way to determine this.
  if "OnlineTuneEnv" in FLAGS.env_problem_name:
    # TODO(pkozakowski): Separate env output dirs by train/eval and epoch.
    env_kwargs = {"output_dir": os.path.join(FLAGS.output_dir, "envs")}
  else:
    env_kwargs = {}

  # Make an env here.
  env = make_env(batch_size=FLAGS.batch_size, **env_kwargs)
  assert env

  eval_env = make_env(batch_size=FLAGS.eval_batch_size, **env_kwargs)
  assert eval_env

  def run_training_loop():
    """Runs the training loop."""
    logging.info("Starting the training loop.")

    policy_and_value_net_fn = functools.partial(
        ppo.policy_and_value_net,
        bottom_layers_fn=common_layers,
        two_towers=FLAGS.two_towers)
    policy_and_value_optimizer_fn = get_optimizer_fn(FLAGS.learning_rate)

    ppo.training_loop(
        output_dir=FLAGS.output_dir,
        env=env,
        eval_env=eval_env,
        env_name=str(FLAGS.env_problem_name),
        policy_and_value_net_fn=policy_and_value_net_fn,
        policy_and_value_optimizer_fn=policy_and_value_optimizer_fn,
    )

  if FLAGS.jax_debug_nans or FLAGS.disable_jit:
    with jax.disable_jit():
      run_training_loop()
  else:
    run_training_loop()
Example #21
    def train_speed_memory(self, batch_size, seq_length):
        key = jax.random.PRNGKey(0)
        input_ids = jax.random.randint(key, (batch_size, seq_length), 0, self.vocab_size)
        targets = jax.random.randint(key, (batch_size, seq_length), 0, self.vocab_size)
        labels = jax.random.randint(key, (batch_size, seq_length), 0, 2)
        # input_ids = np.random.randint(0, self.vocab_size, (batch_size, seq_length))
        # targets = np.random.randint(0, self.vocab_size, (batch_size, seq_length))
        # labels = np.random.randint(0,2, (batch_size, seq_length))
        @jax.jit
        def train_step():

            def loss_fn(params):
                token_mask = jnp.where(labels > 0, 1.0, 0.0).astype(self.dtype)
                logits = self.model(input_ids=input_ids, train=True, params=params, dropout_rng=jax.random.PRNGKey(0))[0]
                loss, normalizing_factor = cross_entropy(logits,targets, token_mask)
                jax.profiler.save_device_memory_profile(f"memory/{workload[0]}_{workload[1]}_memory.prof", "gpu")
                return loss / normalizing_factor
            if self.fp16 and jax.local_devices()[0].platform == 'gpu':
                grad_fn = self.dynamic_scale.value_and_grad(loss_fn)
                dyn_scale, is_fin, loss, grad = grad_fn(self.model.params)
            else:
                grad_fn = jax.value_and_grad(loss_fn)
                loss, grad = grad_fn(self.model.params)
            return tree_flatten(grad)[0]


        if jax.local_devices()[0].platform == 'gpu':
            nvml.nvmlInit()
            train_step()
            handle = nvml.nvmlDeviceGetHandleByIndex(0)
            meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
            max_bytes_in_use = meminfo.used
            memory = Memory(max_bytes_in_use)
            # shutdown nvml
            nvml.nvmlShutdown()
        else:
            memory = None
        # timeit.repeat(train_step,repeat=1,number=2)
        timeit.repeat("for i in train_step():i.block_until_ready()", repeat=1, number=2,globals=locals())
        if self.jit:
            # runtimes = timeit.repeat(train_step,repeat=self.repeat,number=3)
            runtimes = timeit.repeat("for i in train_step():i.block_until_ready()", repeat=self.repeat, number=3,globals=locals())
        else:
            with jax.disable_jit():
                # runtimes = timeit.repeat(train_step, repeat=self.repeat, number=3)
                runtimes = timeit.repeat("for i in train_step():i.block_until_ready()", repeat=self.repeat, number=3,globals=locals())


        return float(np.min(runtimes)/3.0), memory
Example #22
def train_rl(output_dir,
             n_epochs=10000,
             light_rl=True,
             light_rl_trainer=light_trainers.PolicyGradient):
    """Train the RL agent.

  Args:
    output_dir: Output directory.
    n_epochs: Number epochs to run the training for.
    light_rl: deprecated, always True, left out for old gin configs.
    light_rl_trainer: which light RL trainer to use (experimental).
  """
    del light_rl
    tf_np.set_allow_float64(FLAGS.tf_allow_float64)
    task = rl_task.RLTask()
    env_name = task.env_name

    if FLAGS.jax_debug_nans:
        config.update('jax_debug_nans', True)

    if FLAGS.use_tpu:
        config.update('jax_platform_name', 'tpu')
    else:
        config.update('jax_platform_name', '')

    trainer = light_rl_trainer(task=task, output_dir=output_dir)

    def light_training_loop():
        """Run the trainer for n_epochs and call close on it."""
        try:
            logging.info('Starting RL training for %d epochs.', n_epochs)
            trainer.run(n_epochs, n_epochs_is_total_epochs=True)
            logging.info('Completed RL training for %d epochs.', n_epochs)
            trainer.close()
            logging.info('Trainer is now closed.')
        except Exception as e:
            raise e
        finally:
            logging.info(
                'Encountered an exception, still calling trainer.close()')
            trainer.close()
            logging.info('Trainer is now closed.')

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        fastmath.disable_jit()
        with jax.disable_jit():
            light_training_loop()
    else:
        light_training_loop()
Example #23
File: jax_to_tf.py  Project: qiuminxu/jax
    def wrapped_fun(*args):
        # TODO(necula): remove the jit disabling once we handle all control-flow.
        # Disabling the jit helps to avoid some unsupported jax primitives.
        # E.g. scan will be statically unrolled.
        def doit():
            f = lu.wrap_init(fun)
            args_flat, in_tree = tree_util.tree_flatten((args, {}))
            flat_fun, out_tree = flatten_fun(f, in_tree)
            out_flat = _interpret_fun(flat_fun, args_flat)
            return tree_util.tree_unflatten(out_tree(), out_flat)

        if _jit_state.disable_jit:
            with jax.disable_jit():
                return doit()
        else:
            return doit()
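
A standalone sketch (not part of jax_to_tf.py) of the behaviour the comment above describes: under jax.disable_jit(), lax.scan runs its body eagerly as a Python loop over the leading axis, so the loop is effectively statically unrolled instead of being staged out as a single scan primitive.

import jax
import jax.numpy as jnp
from jax import lax

def cumsum_via_scan(xs):
    def step(carry, x):
        carry = carry + x
        return carry, carry
    _, out = lax.scan(step, jnp.float32(0.0), xs)
    return out

with jax.disable_jit():
    # each scan step executes eagerly with concrete values
    print(cumsum_via_scan(jnp.arange(5.0)))  # [ 0.  1.  3.  6. 10.]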
    def test_get_image_features(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = FlaxCLIPModel(config)

        @jax.jit
        def model_jitted(pixel_values):
            return model.get_image_features(pixel_values=pixel_values)

        with self.subTest("JIT Enabled"):
            jitted_output = model_jitted(inputs_dict["pixel_values"])

        with self.subTest("JIT Disabled"):
            with jax.disable_jit():
                output = model_jitted(inputs_dict["pixel_values"])

        self.assertEqual(jitted_output.shape, output.shape)
        self.assertTrue(np.allclose(jitted_output, output, atol=1e-3))
 def test_lax_dot_has_integer_inputs_in_quantized_dot(
         self, mock_dot_general, act_distribution, prefer_int8_to_int32_dot,
         prec):
     weight_params = QuantOps.WeightParams(prec=prec,
                                           axis=(0, ),
                                           half_shift=False)
     act_params = QuantOps.ActHParams(input_distribution=act_distribution,
                                      bounds=jnp.array([[3.0, 1.5]]),
                                      prec=prec,
                                      half_shift=False)
     act = self.lhs
     if act_distribution == 'positive':
         act = jnp.abs(act)
     # We need this context manager to stop Jax from trying to compile the arms
     # of the `lax.cond` call in `dot_general_aqt`. By default, Jax will always
     # try to compile the functions passed to `lax.cond`, even if outside of a
     # JITed context. JIT compilation is incompatible with using a mock for the
     # call to 'dot_general' because during compilation Jax will expect
     # 'dot_general' to return a tracer and will throw an error if it returns a
     # mock instead. By explicitly using jax.disable_jit, Jax will not try to
     # compile the arms of lax.cond, and so using a mock will work fine.
     with jax.disable_jit():
         quantization.quantized_dot(
             w=self.rhs,
             act=act,
             weight_params=weight_params,
             act_hparams=act_params,
             get_bounds_params=None,
             quant_type=QuantType.aqt,
             prefer_int8_to_int32_dot=prefer_int8_to_int32_dot)
     act_inputs, weight_inputs = mock_dot_general.call_args[0]
     self.assert_is_integer_in_range(act_inputs,
                                     prec=prec,
                                     distribution=act_distribution)
     self.assert_is_integer_in_range(weight_inputs,
                                     prec=prec,
                                     distribution='symmetric')
     if prefer_int8_to_int32_dot and not (act_distribution == 'positive'
                                          and prec == 8):
         expected_input_dtype = jnp.int8
     else:
         expected_input_dtype = jnp.float32
     self.assertEqual(act_inputs.dtype, expected_input_dtype)
     self.assertEqual(weight_inputs.dtype, expected_input_dtype)
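
A self-contained sketch (hypothetical, not taken from the AQT test above) of the pattern the comment explains: with jax.disable_jit() the selected branch of lax.cond runs as plain Python, so a mocked primitive can return an ordinary array instead of a tracer.

import unittest.mock as mock

import jax
import jax.numpy as jnp
from jax import lax

def pick_dot(flag, a, b):
    # lax.cond would normally trace and compile both branches
    return lax.cond(flag,
                    lambda _: jnp.dot(a, b),
                    lambda _: -jnp.dot(a, b),
                    0)

with mock.patch("jax.numpy.dot", return_value=jnp.zeros((2, 2))) as fake_dot:
    with jax.disable_jit():
        out = pick_dot(True, jnp.ones((2, 3)), jnp.ones((3, 2)))
assert fake_dot.called and out.shape == (2, 2)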
Example #26
    def test_jit_compilation(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            with self.subTest(model_class.__name__):
                prepared_inputs_dict = self._prepare_for_class(
                    inputs_dict, model_class)
                model = model_class(config)

                @jax.jit
                def model_jitted(input_ids,
                                 attention_mask=None,
                                 token_type_ids=None):
                    return model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                    ).to_tuple()

                with self.subTest("JIT Enabled"):
                    jitted_outputs = model_jitted(**prepared_inputs_dict)

                with self.subTest("JIT Disabled"):
                    with jax.disable_jit():
                        outputs = model_jitted(**prepared_inputs_dict)

                self.assertEqual(len(outputs), len(jitted_outputs))
                for jitted_output, output in zip(jitted_outputs, outputs):
                    self.assertEqual(jitted_output.shape, output.shape)

                @jax.jit
                def model_jitted_return_dict(input_ids,
                                             attention_mask=None,
                                             token_type_ids=None):
                    return model(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                    )

                # jitted function cannot return OrderedDict
                with self.assertRaises(TypeError):
                    model_jitted_return_dict(**prepared_inputs_dict)
    def test_get_text_features(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = FlaxCLIPModel(config)

        @jax.jit
        def model_jitted(input_ids, attention_mask, **kwargs):
            return model.get_text_features(input_ids=input_ids,
                                           attention_mask=attention_mask)

        with self.subTest("JIT Enabled"):
            jitted_output = model_jitted(**inputs_dict)

        with self.subTest("JIT Disabled"):
            with jax.disable_jit():
                output = model_jitted(**inputs_dict)

        self.assertEqual(jitted_output.shape, output.shape)
        self.assertTrue(np.allclose(jitted_output, output, atol=1e-3))
Example #28
File: debug.py  Project: fehiepsi/jaxns
def debug_mvee():
    import pylab as plt

    n = random.normal(random.PRNGKey(0), (10000,2))
    n = n /jnp.linalg.norm(n, axis=1, keepdims=True)
    angle = jnp.arctan2(n[:,1], n[:,0])
    plt.hist(angle, bins=100)
    plt.show()
    N = 120
    D = 2
    points = random.uniform(random.PRNGKey(0), (N, D))

    from jax import disable_jit
    with disable_jit():
        center, radii, rotation = minimum_volume_enclosing_ellipsoid(points, 0.01)

    plt.hist(jnp.linalg.norm((rotation.T @ (points.T - center[:, None])) / radii[:, None], axis=0))
    plt.show()
    print(center, radii, rotation)
    plt.scatter(points[:, 0], points[:, 1])
    theta = jnp.linspace(0., jnp.pi*2, 100)
    ellipsis = center[:, None] + rotation @ jnp.stack([radii[0]*jnp.cos(theta), radii[1]*jnp.sin(theta)], axis=0)
    plt.plot(ellipsis[0,:], ellipsis[1,:])

    for i in range(1000):
        y = sample_ellipsoid(random.PRNGKey(i), center, radii, rotation)
        plt.scatter(y[0], y[1])



    C = jnp.linalg.pinv(jnp.cov(points, rowvar=False, bias=True))
    p = (N - D - 1)/N
    def q(p):
        return p + p**2/(4.*(D-1))
    C = C / q(p)
    c = jnp.mean(points, axis=0)
    W, Q, Vh = jnp.linalg.svd(C)
    radii = jnp.reciprocal(jnp.sqrt(Q))
    rotation = Vh.conj().T
    ellipsis = c[:, None] + rotation @ jnp.stack([radii[0] * jnp.cos(theta), radii[1] * jnp.sin(theta)], axis=0)
    plt.plot(ellipsis[0, :], ellipsis[1, :])

    plt.show()
def test_multiple_sentences(jit):
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    model = FlaxRobertaModel.from_pretrained("roberta-base")

    sentences = ["this is an example sentence", "this is another", "and a third one"]
    encodings = tokenizer(sentences, return_tensors=TensorType.JAX, padding=True, truncation=True)

    @jax.jit
    def model_jitted(input_ids, attention_mask):
        return model(input_ids, attention_mask)

    if jit == "disable_jit":
        with jax.disable_jit():
            tokens, pooled = model_jitted(**encodings)
    else:
        tokens, pooled = model_jitted(**encodings)

    assert tokens.shape == (3, 7, 768)
    assert pooled.shape == (3, 768)
Example #30
def test_accept_order():
    def constraint(u):
        return u ** 2 > 0.5

    def accept_prob(u):
        if u > 0.9:
            return 1
        if u > 0.8:
            return 0.5
        if u > 0.7:
            return 0.25
        return 0.

    def f1(key):
        while True:
            key, u_key = random.split(key, 2)
            u = random.uniform(u_key)
            if constraint(u):
                key, a_key = random.split(key, 2)
                if random.uniform(a_key) < accept_prob(u):
                    return u

    def f2(key):
        while True:
            key, u_key, a_key = random.split(key, 3)
            u = random.uniform(u_key)
            if random.uniform(a_key) < accept_prob(u):
                if constraint(u):
                    return u

    from jax import vmap, disable_jit
    import pylab as plt
    keys = random.split(random.PRNGKey(0), 1000)
    with disable_jit():
        u1 = jnp.array([f1(key) for key in keys])
        u2 = jnp.array([f2(key) for key in keys])

    print(u1)

    plt.hist(u1, bins='auto', alpha=0.5)
    plt.hist(u2, bins='auto', alpha=0.5)
    plt.show()