def test_multiple_sequences(self):
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
    model = FlaxBertModel.from_pretrained("bert-base-cased")

    sequences = ["this is an example sentence", "this is another", "and a third one"]
    encodings = tokenizer(sequences, return_tensors=TensorType.JAX, padding=True, truncation=True)

    @jax.jit
    def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
        return model(input_ids, attention_mask, token_type_ids)

    with self.subTest("JIT Disabled"):
        with jax.disable_jit():
            tokens, pooled = model_jitted(**encodings)
            self.assertEqual(tokens.shape, (3, 7, 768))
            self.assertEqual(pooled.shape, (3, 768))

    with self.subTest("JIT Enabled"):
        jitted_tokens, jitted_pooled = model_jitted(**encodings)
        self.assertEqual(jitted_tokens.shape, (3, 7, 768))
        self.assertEqual(jitted_pooled.shape, (3, 768))
def test_jit_compilation(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        with self.subTest(model_class.__name__):
            # TODO later: have some way to initialize easily a Flax model from config, for now I go through PT
            pt_model_class_name = model_class.__name__[4:]  # Skip the "Flax" at the beginning
            pt_model_class = getattr(transformers, pt_model_class_name)
            pt_model = pt_model_class(config).eval()
            model = convert_pt_model_to_flax(pt_model, config, model_class)

            @jax.jit
            def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
                return model(input_ids, attention_mask, token_type_ids)

            with self.subTest("JIT Disabled"):
                with jax.disable_jit():
                    outputs = model_jitted(**inputs_dict)

            with self.subTest("JIT Enabled"):
                jitted_outputs = model_jitted(**inputs_dict)

            self.assertEqual(len(outputs), len(jitted_outputs))
            for jitted_output, output in zip(jitted_outputs, outputs):
                self.assertEqual(jitted_output.shape, output.shape)
def inference_speed_memory(self, batch_size, seq_length):
    # input_ids = np.random.randint(0, self.vocab_size, (batch_size, seq_length))
    key = jax.random.PRNGKey(0)
    input_ids = jax.random.randint(key, (batch_size, seq_length), 0, self.vocab_size)

    @jax.jit
    def ref_step():
        out = self.model(input_ids=input_ids)
        return out[0]

    if jax.local_devices()[0].platform == 'gpu':
        nvml.nvmlInit()
        ref_step().block_until_ready()
        handle = nvml.nvmlDeviceGetHandleByIndex(0)
        meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
        max_bytes_in_use = meminfo.used
        memory = Memory(max_bytes_in_use)
        # shutdown nvml
        nvml.nvmlShutdown()
    else:
        memory = None

    # Warm-up call to trigger compilation before timing.
    timeit.repeat("ref_step().block_until_ready()", repeat=1, number=2, globals=locals())

    if self.jit:
        runtimes = timeit.repeat("ref_step().block_until_ready()",
                                 repeat=self.repeat, number=3, globals=locals())
    else:
        with jax.disable_jit():
            runtimes = timeit.repeat("ref_step().block_until_ready()",
                                     repeat=self.repeat, number=3, globals=locals())

    return float(np.min(runtimes) / 3.0), memory
def test_jit_compilation(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        with self.subTest(model_class.__name__):
            prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            model = model_class(config)

            @jax.jit
            def model_jitted(input_ids, pixel_values, **kwargs):
                return model(input_ids=input_ids, pixel_values=pixel_values, **kwargs).to_tuple()

            with self.subTest("JIT Enabled"):
                jitted_outputs = model_jitted(**prepared_inputs_dict)

            with self.subTest("JIT Disabled"):
                with jax.disable_jit():
                    outputs = model_jitted(**prepared_inputs_dict)

            self.assertEqual(len(outputs), len(jitted_outputs))
            for jitted_output, output in zip(jitted_outputs[:4], outputs[:4]):
                self.assertEqual(jitted_output.shape, output.shape)
def test_decode(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        with self.subTest(model_class.__name__):
            model = model_class(config)
            encoder_outputs = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])

            prepared_inputs_dict = {
                "decoder_input_ids": inputs_dict["decoder_input_ids"],
                "decoder_attention_mask": inputs_dict["decoder_attention_mask"],
                "encoder_outputs": encoder_outputs,
            }

            @jax.jit
            def decode_jitted(decoder_input_ids, decoder_attention_mask, encoder_outputs):
                return model.decode(
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    encoder_outputs=encoder_outputs,
                )

            with self.subTest("JIT Enabled"):
                jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()

            with self.subTest("JIT Disabled"):
                with jax.disable_jit():
                    outputs = decode_jitted(**prepared_inputs_dict).to_tuple()

            self.assertEqual(len(outputs), len(jitted_outputs))
            for jitted_output, output in zip(jitted_outputs, outputs):
                self.assertEqual(jitted_output.shape, output.shape)
def main(argv):
    del argv

    if FLAGS.jax_debug_nans:
        config.update("jax_debug_nans", True)

    def run_training_loop():
        optimizer_fun = functools.partial(ppo.optimizer_fun, step_size=FLAGS.learning_rate)
        ppo.training_loop(
            env_name=FLAGS.env_name,
            epochs=FLAGS.epochs,
            policy_and_value_net_fun=functools.partial(
                ppo.policy_and_value_net, bottom_layers=common_layers()),
            policy_and_value_optimizer_fun=optimizer_fun,
            batch_size=FLAGS.batch_size,
            num_optimizer_steps=FLAGS.num_optimizer_steps,
            boundary=FLAGS.boundary,
            max_timestep=FLAGS.max_timestep,
            random_seed=FLAGS.random_seed)

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        with jax.disable_jit():
            run_training_loop()
    else:
        run_training_loop()
def test_welford_covariance(jitted, diagonal, regularize):
    with optional(jitted, disable_jit()), optional(jitted, control_flow_prims_disabled()):
        np.random.seed(0)
        loc = np.random.randn(3)
        a = np.random.randn(3, 3)
        target_cov = np.matmul(a, a.T)
        x = np.random.multivariate_normal(loc, target_cov, size=(2000,))
        x = device_put(x)

        @jit
        def get_cov(x):
            wc_init, wc_update, wc_final = welford_covariance(diagonal=diagonal)
            wc_state = wc_init(3)
            wc_state = fori_loop(0, 2000, lambda i, val: wc_update(x[i], val), wc_state)
            cov, cov_inv_sqrt = wc_final(wc_state, regularize=regularize)
            return cov, cov_inv_sqrt

        cov, cov_inv_sqrt = get_cov(x)

        if diagonal:
            diag_cov = jnp.diagonal(target_cov)
            assert_allclose(cov, diag_cov, rtol=0.06)
            assert_allclose(cov_inv_sqrt, jnp.sqrt(jnp.reciprocal(diag_cov)), rtol=0.06)
        else:
            assert_allclose(cov, target_cov, rtol=0.06)
            assert_allclose(cov_inv_sqrt, jnp.linalg.cholesky(jnp.linalg.inv(cov)), rtol=0.06)
def test_ellipsoid_clustering():
    import pylab as plt
    from jax import disable_jit, jit

    points = jnp.concatenate([random.uniform(random.PRNGKey(0), shape=(30, 2)),
                              1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    mask = jnp.ones(points.shape[0], jnp.bool_)

    mu, C = bounding_ellipsoid(points, mask)
    radii, rotation = ellipsoid_params(C)
    # plt.plot(y[0, :], y[1, :])
    log_VS = log_ellipsoid_volume(radii) - jnp.log(5)

    with disable_jit():
        cluster_id, ellipsoid_parameters = \
            jit(lambda key, points, log_VS: ellipsoid_clustering(random.PRNGKey(0), points, 4, log_VS)
                )(random.PRNGKey(0), points, log_VS)
        mu, radii, rotation = ellipsoid_parameters

    print(mu, radii, rotation, jnp.bincount(cluster_id, minlength=0, length=4))

    for i, (mu, radii, rotation) in enumerate(zip(mu, radii, rotation)):
        y = mu[:, None] + rotation @ jnp.diag(radii) @ x
        plt.plot(y[0, :], y[1, :])
        mask = cluster_id == i
        plt.scatter(points[mask, 0], points[mask, 1], c=plt.cm.jet(i / len(ellipsoid_parameters)))
    plt.show()
def test_encode(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        with self.subTest(model_class.__name__):
            prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            model = model_class(config)

            @jax.jit
            def encode_jitted(input_ids, attention_mask=None, **kwargs):
                return model.encode(input_ids=input_ids, attention_mask=attention_mask)

            with self.subTest("JIT Enabled"):
                jitted_outputs = encode_jitted(**prepared_inputs_dict).to_tuple()

            with self.subTest("JIT Disabled"):
                with jax.disable_jit():
                    outputs = encode_jitted(**prepared_inputs_dict).to_tuple()

            self.assertEqual(len(outputs), len(jitted_outputs))
            for jitted_output, output in zip(jitted_outputs, outputs):
                self.assertEqual(jitted_output.shape, output.shape)
def test_sample_multi_ellipsoid():
    import pylab as plt
    from jax import disable_jit, jit, vmap

    points = jnp.concatenate([random.uniform(random.PRNGKey(0), shape=(30, 2)),
                              1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    mask = jnp.ones(points.shape[0], jnp.bool_)

    mu, C = bounding_ellipsoid(points, mask)
    radii, rotation = ellipsoid_params(C)
    y = mu[:, None] + rotation @ jnp.diag(radii) @ x
    # plt.plot(y[0, :], y[1, :])
    log_VS = log_ellipsoid_volume(radii) - jnp.log(5)

    with disable_jit():
        cluster_id, ellipsoid_parameters = \
            jit(lambda key, points, log_VS: ellipsoid_clustering(random.PRNGKey(0), points, 4, log_VS)
                )(random.PRNGKey(0), points, log_VS)
        mu, radii, rotation = ellipsoid_parameters

    # print(mu, radii, rotation)
    u = vmap(lambda key: sample_multi_ellipsoid(key, mu, radii, rotation,
                                                unit_cube_constraint=True)[1])(
        random.split(random.PRNGKey(0), 1000))
    plt.scatter(u[:, 0], u[:, 1], marker='+')

    for i, (mu, radii, rotation) in enumerate(zip(mu, radii, rotation)):
        y = mu[:, None] + rotation @ jnp.diag(radii) @ x
        plt.plot(y[0, :], y[1, :])
        mask = cluster_id == i
        # plt.scatter(points[mask, 0], points[mask, 1], c=plt.cm.jet(i / len(ellipsoid_parameters)))
    plt.show()
def test_disable_jit_odeint_with_vmap(self):
    # https://github.com/google/jax/issues/2598
    with jax.disable_jit():
        t = jnp.array([0.0, 1.0])
        x0_eval = jnp.zeros((5, 2))
        f = lambda x0: odeint(lambda x, _t: x, x0, t)
        jax.vmap(f)(x0_eval)  # doesn't crash
def test_cluster_split():
    import pylab as plt
    from jax import disable_jit

    points = jnp.concatenate([random.uniform(random.PRNGKey(0), shape=(30, 2)),
                              1.25 + random.uniform(random.PRNGKey(0), shape=(10, 2))],
                             axis=0)
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    x = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=0)
    mask = jnp.zeros(points.shape[0], jnp.bool_)

    mu, C = bounding_ellipsoid(points, jnp.ones(points.shape[0], jnp.bool_))
    radii, rotation = ellipsoid_params(C)
    y = mu[:, None] + rotation @ jnp.diag(radii) @ x
    plt.plot(y[0, :], y[1, :])
    log_VS = log_ellipsoid_volume(radii) - jnp.log(5)

    with disable_jit():
        cluster_id, log_VS1, mu1, radii1, rotation1, log_VS2, mu2, radii2, rotation2, do_split = \
            cluster_split(random.PRNGKey(0), points, mask, log_VS, log_ellipsoid_volume(radii),
                          kmeans_init=True)
        print(jnp.logaddexp(log_ellipsoid_volume(radii1), log_ellipsoid_volume(radii2)),
              log_ellipsoid_volume(radii))
        print(log_VS1, mu1, radii1, rotation1, log_VS2, mu2, radii2, rotation2, do_split)
        print(cluster_id)

    y = mu1[:, None] + rotation1 @ jnp.diag(radii1) @ x
    plt.plot(y[0, :], y[1, :])
    y = mu2[:, None] + rotation2 @ jnp.diag(radii2) @ x
    plt.plot(y[0, :], y[1, :])
    mask = cluster_id == 0
    plt.scatter(points[mask, 0], points[mask, 1])
    mask = cluster_id == 1
    plt.scatter(points[mask, 0], points[mask, 1])
    plt.show()
def test_param_tracking():
    from jax import jit, numpy as jnp, disable_jit, make_jaxpr

    shape = {
        'a': (4,),
        'b': (4,),
        'c': (4,),
        'd': (4,),
        'e': (4,),
        'f': (4,),
    }
    sample = dict_multimap(jnp.ones, shape)
    n = jnp.array(10)
    log_L = jnp.array(0.)

    @jit
    def test_jax(sample, n, log_L):
        tracked = TrackedExpectation(
            {
                # Bind k as a default argument so each lambda captures its own key
                # (a bare closure would late-bind to the last k in the loop).
                k: lambda sample, n, log_L, k=k: jnp.ones(shape[k])
                for k in shape.keys()
            },
            shape)
        tracked.update(sample, n, log_L)
        return (tracked.evidence_mean(), tracked.evidence_variance(),
                tracked.information_gain_mean())
        # return (evidence.state, H.state, m.state, M.state)

    print()
    print(len(str(make_jaxpr(test_jax)(sample, n, log_L))))
    with disable_jit():
        print(test_jax(sample, n, log_L))
def run(shared_input, clients):
    with jax.disable_jit():
        for client_id, client_batches, client_input in clients:
            step_results = []
            try:
                state = client_init(shared_input, client_input)
            except Exception as e:
                raise ForEachClientError(
                    e, stage='client_init',
                    client_id=client_id,
                    client_init=client_init,
                    shared_input=shared_input,
                    client_input=client_input) from e
            for batch in client_batches:
                try:
                    state, step_result = client_step(state, batch)
                except Exception as e:
                    raise ForEachClientError(
                        e, stage='client_step',
                        client_id=client_id,
                        client_step=client_step,
                        state=state,
                        batch=batch) from e
                step_results.append(step_result)
            try:
                output = client_final(shared_input, state)
            except Exception as e:
                raise ForEachClientError(
                    e, stage='client_final',
                    client_id=client_id,
                    client_final=client_final,
                    shared_input=shared_input,
                    state=state) from e
            yield client_id, output, step_results
def main(argv):
    del argv

    if FLAGS.jax_debug_nans:
        config.update("jax_debug_nans", True)

    # Make an env here.
    env = make_env()
    assert env

    def run_training_loop():
        """Runs the training loop."""
        policy_net_fun = None
        value_net_fun = None
        policy_and_value_net_fun = None
        policy_optimizer_fun = None
        value_optimizer_fun = None
        policy_and_value_optimizer_fun = None

        if FLAGS.combined_policy_and_value_function:
            policy_and_value_net_fun = functools.partial(
                ppo.policy_and_value_net, bottom_layers=common_layers())
            policy_and_value_optimizer_fun = get_optimizer_fun(FLAGS.learning_rate)
        else:
            policy_net_fun = functools.partial(ppo.policy_net, bottom_layers=common_layers())
            value_net_fun = functools.partial(ppo.value_net, bottom_layers=common_layers())
            policy_optimizer_fun = get_optimizer_fun(FLAGS.policy_only_learning_rate)
            value_optimizer_fun = get_optimizer_fun(FLAGS.value_only_learning_rate)

        ppo.training_loop(
            env=env,
            epochs=FLAGS.epochs,
            policy_net_fun=policy_net_fun,
            value_net_fun=value_net_fun,
            policy_and_value_net_fun=policy_and_value_net_fun,
            policy_optimizer_fun=policy_optimizer_fun,
            value_optimizer_fun=value_optimizer_fun,
            policy_and_value_optimizer_fun=policy_and_value_optimizer_fun,
            batch_size=FLAGS.batch_size,
            num_optimizer_steps=FLAGS.num_optimizer_steps,
            policy_only_num_optimizer_steps=FLAGS.policy_only_num_optimizer_steps,
            value_only_num_optimizer_steps=FLAGS.value_only_num_optimizer_steps,
            target_kl=FLAGS.target_kl,
            boundary=FLAGS.boundary,
            max_timestep=FLAGS.max_timestep,
            random_seed=FLAGS.random_seed)

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        with jax.disable_jit():
            run_training_loop()
    else:
        run_training_loop()
def testLogSumExpNans(self):
    # Regression test for https://github.com/google/jax/issues/7634
    with jax.debug_nans(True):
        with jax.disable_jit():
            result = lsp_special.logsumexp(1.0)
            self.assertEqual(result, 1.0)

            result = lsp_special.logsumexp(1.0, b=1.0)
            self.assertEqual(result, 1.0)
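# --- Hedged sketch (illustrative, not part of the suite above; it assumes a recent JAX
# release where jax.debug_nans and jax.disable_jit are usable as context managers,
# as the test above already does). With debug_nans enabled and jit disabled, a NaN
# produced by any op is raised eagerly as a FloatingPointError at the offending call.
import jax
import jax.numpy as jnp

with jax.debug_nans(True), jax.disable_jit():
    try:
        jnp.log(-1.0)  # evaluates to NaN, so debug_nans raises immediately
    except FloatingPointError as err:
        print("caught:", err)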
def check_preprocessors(space, *preprocessors, num_samples=20, random_seed=None):
    r"""
    Check whether two preprocessors are the same.

    Parameters
    ----------
    space : gym.Space
        The domain of the preprocessors.
    \*preprocessors
        Preprocessor functions, which are functions with input signature:
        :code:`func(rng: PRNGKey, x: Element[space]) -> Any`.
    num_samples : positive int
        The number of samples on which to run the checks.
    random_seed : int, optional
        Seed used to initialize the internal PRNG sequence.

    Returns
    -------
    match : bool
        Whether the preprocessors match.
    """
    if len(preprocessors) < 2:
        raise ValueError("need at least two preprocessors in order to run test")

    def test_leaves(a, b):
        assert type(a) is type(b)
        return onp.testing.assert_allclose(onp.asanyarray(a), onp.asanyarray(b))

    rngs = hk.PRNGSequence(onp.random.RandomState(random_seed).randint(jnp.iinfo('int32').max))

    p0, *ps = preprocessors

    with jax.disable_jit():
        for _ in range(num_samples):
            x = space.sample()
            y0 = p0(next(rngs), x)
            for p in ps:
                y = p(next(rngs), x)
                if jax.tree_structure(y) != jax.tree_structure(y0):
                    return False
                try:
                    jax.tree_multimap(test_leaves, y, y0)
                except AssertionError:
                    return False
    return True
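# --- Hedged usage sketch for check_preprocessors (illustrative only; the Box space and
# the two toy preprocessors below are assumptions for the example, not taken from the
# library). Both preprocessors cast the sampled observation to float32, so they should
# be reported as matching.
import gym
import jax.numpy as jnp

box = gym.spaces.Box(low=0.0, high=1.0, shape=(3,))
p1 = lambda rng, x: jnp.asarray(x, dtype=jnp.float32)
p2 = lambda rng, x: jnp.array(x, dtype=jnp.float32)
assert check_preprocessors(box, p1, p2, num_samples=5, random_seed=13)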
def test_autodownload_pretrained_r50(self):
    fname, _ = urllib.request.urlretrieve(
        "https://upload.wikimedia.org/wikipedia/commons/e/e4/A_French_Bulldog.jpg")
    im = np.array(PIL.Image.open(fname).resize([224, 224])) / np.float32(255)

    r50 = elegy.nets.resnet.ResNet50(weights="imagenet")
    with jax.disable_jit():
        assert elegy.Model(r50).predict(im[np.newaxis]).argmax() == 245
def main(argv):
    del argv

    if FLAGS.jax_debug_nans:
        config.update("jax_debug_nans", True)

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        with jax.disable_jit():
            run_training_loop()
    else:
        run_training_loop()
def main(argv):
    del argv

    logging.info("Starting PPO Main.")

    if FLAGS.jax_debug_nans:
        config.update("jax_debug_nans", True)

    if FLAGS.use_tpu:
        config.update("jax_platform_name", "tpu")
    else:
        config.update("jax_platform_name", "gpu")

    gin_configs = FLAGS.config or []
    gin.parse_config_files_and_bindings(FLAGS.config_file, gin_configs)

    # TODO(pkozakowski): Find a better way to determine this.
    if "OnlineTuneEnv" in FLAGS.env_problem_name:
        # TODO(pkozakowski): Separate env output dirs by train/eval and epoch.
        env_kwargs = {"output_dir": os.path.join(FLAGS.output_dir, "envs")}
    else:
        env_kwargs = {}

    # Make an env here.
    env = make_env(batch_size=FLAGS.batch_size, **env_kwargs)
    assert env
    eval_env = make_env(batch_size=FLAGS.eval_batch_size, **env_kwargs)
    assert eval_env

    def run_training_loop():
        """Runs the training loop."""
        logging.info("Starting the training loop.")

        policy_and_value_net_fn = functools.partial(
            ppo.policy_and_value_net,
            bottom_layers_fn=common_layers,
            two_towers=FLAGS.two_towers)
        policy_and_value_optimizer_fn = get_optimizer_fn(FLAGS.learning_rate)

        ppo.training_loop(
            output_dir=FLAGS.output_dir,
            env=env,
            eval_env=eval_env,
            env_name=str(FLAGS.env_problem_name),
            policy_and_value_net_fn=policy_and_value_net_fn,
            policy_and_value_optimizer_fn=policy_and_value_optimizer_fn,
        )

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        with jax.disable_jit():
            run_training_loop()
    else:
        run_training_loop()
def train_speed_memory(self, batch_size, seq_length):
    key = jax.random.PRNGKey(0)
    input_ids = jax.random.randint(key, (batch_size, seq_length), 0, self.vocab_size)
    targets = jax.random.randint(key, (batch_size, seq_length), 0, self.vocab_size)
    labels = jax.random.randint(key, (batch_size, seq_length), 0, 2)
    # input_ids = np.random.randint(0, self.vocab_size, (batch_size, seq_length))
    # targets = np.random.randint(0, self.vocab_size, (batch_size, seq_length))
    # labels = np.random.randint(0, 2, (batch_size, seq_length))

    @jax.jit
    def train_step():

        def loss_fn(params):
            token_mask = jnp.where(labels > 0, 1.0, 0.0).astype(self.dtype)
            logits = self.model(input_ids=input_ids, train=True, params=params,
                                dropout_rng=jax.random.PRNGKey(0))[0]
            loss, normalizing_factor = cross_entropy(logits, targets, token_mask)
            # NOTE: `workload` is expected to be defined in the enclosing scope.
            jax.profiler.save_device_memory_profile(
                f"memory/{workload[0]}_{workload[1]}_memory.prof", "gpu")
            return loss / normalizing_factor

        if self.fp16 and jax.local_devices()[0].platform == 'gpu':
            grad_fn = self.dynamic_scale.value_and_grad(loss_fn)
            dyn_scale, is_fin, loss, grad = grad_fn(self.model.params)
        else:
            grad_fn = jax.value_and_grad(loss_fn)
            loss, grad = grad_fn(self.model.params)
        return tree_flatten(grad)[0]

    if jax.local_devices()[0].platform == 'gpu':
        nvml.nvmlInit()
        train_step()
        handle = nvml.nvmlDeviceGetHandleByIndex(0)
        meminfo = nvml.nvmlDeviceGetMemoryInfo(handle)
        max_bytes_in_use = meminfo.used
        memory = Memory(max_bytes_in_use)
        # shutdown nvml
        nvml.nvmlShutdown()
    else:
        memory = None

    # timeit.repeat(train_step, repeat=1, number=2)
    timeit.repeat("for i in train_step():i.block_until_ready()",
                  repeat=1, number=2, globals=locals())

    if self.jit:
        # runtimes = timeit.repeat(train_step, repeat=self.repeat, number=3)
        runtimes = timeit.repeat("for i in train_step():i.block_until_ready()",
                                 repeat=self.repeat, number=3, globals=locals())
    else:
        with jax.disable_jit():
            # runtimes = timeit.repeat(train_step, repeat=self.repeat, number=3)
            runtimes = timeit.repeat("for i in train_step():i.block_until_ready()",
                                     repeat=self.repeat, number=3, globals=locals())

    return float(np.min(runtimes) / 3.0), memory
def train_rl(output_dir, n_epochs=10000, light_rl=True,
             light_rl_trainer=light_trainers.PolicyGradient):
    """Train the RL agent.

    Args:
      output_dir: Output directory.
      n_epochs: Number of epochs to run the training for.
      light_rl: deprecated, always True, left out for old gin configs.
      light_rl_trainer: which light RL trainer to use (experimental).
    """
    del light_rl
    tf_np.set_allow_float64(FLAGS.tf_allow_float64)
    task = rl_task.RLTask()
    env_name = task.env_name

    if FLAGS.jax_debug_nans:
        config.update('jax_debug_nans', True)

    if FLAGS.use_tpu:
        config.update('jax_platform_name', 'tpu')
    else:
        config.update('jax_platform_name', '')

    trainer = light_rl_trainer(task=task, output_dir=output_dir)

    def light_training_loop():
        """Run the trainer for n_epochs and call close on it."""
        try:
            logging.info('Starting RL training for %d epochs.', n_epochs)
            trainer.run(n_epochs, n_epochs_is_total_epochs=True)
            logging.info('Completed RL training for %d epochs.', n_epochs)
            trainer.close()
            logging.info('Trainer is now closed.')
        except Exception as e:
            raise e
        finally:
            logging.info('Encountered an exception, still calling trainer.close()')
            trainer.close()
            logging.info('Trainer is now closed.')

    if FLAGS.jax_debug_nans or FLAGS.disable_jit:
        fastmath.disable_jit()
        with jax.disable_jit():
            light_training_loop()
    else:
        light_training_loop()
def wrapped_fun(*args):
    # TODO(necula): remove the jit disabling once we handle all control-flow.
    # Disabling the jit helps to avoid some unsupported jax primitives.
    # E.g. scan will be statically unrolled.
    def doit():
        f = lu.wrap_init(fun)
        args_flat, in_tree = tree_util.tree_flatten((args, {}))
        flat_fun, out_tree = flatten_fun(f, in_tree)
        out_flat = _interpret_fun(flat_fun, args_flat)
        return tree_util.tree_unflatten(out_tree(), out_flat)

    if _jit_state.disable_jit:
        with jax.disable_jit():
            return doit()
    else:
        return doit()
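# --- Hedged sketch (standalone illustration; not part of the module above). It shows the
# behaviour the TODO comment relies on: under jax.disable_jit(), control-flow primitives
# such as lax.scan execute as ordinary Python loops over concrete values, so a Python-level
# interpreter never sees staged-out, unsupported primitives.
import jax
import jax.numpy as jnp
from jax import lax

def cumsum_via_scan(xs):
    # Running cumulative sum expressed with lax.scan.
    def step(carry, x):
        carry = carry + x
        return carry, carry
    _, ys = lax.scan(step, jnp.float32(0.0), xs)
    return ys

with jax.disable_jit():
    print(cumsum_via_scan(jnp.arange(5, dtype=jnp.float32)))  # [ 0.  1.  3.  6. 10.]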
def test_get_image_features(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    model = FlaxCLIPModel(config)

    @jax.jit
    def model_jitted(pixel_values):
        return model.get_image_features(pixel_values=pixel_values)

    with self.subTest("JIT Enabled"):
        jitted_output = model_jitted(inputs_dict["pixel_values"])

    with self.subTest("JIT Disabled"):
        with jax.disable_jit():
            output = model_jitted(inputs_dict["pixel_values"])

    self.assertEqual(jitted_output.shape, output.shape)
    self.assertTrue(np.allclose(jitted_output, output, atol=1e-3))
def test_lax_dot_has_integer_inputs_in_quantized_dot(self, mock_dot_general,
                                                     act_distribution,
                                                     prefer_int8_to_int32_dot,
                                                     prec):
    weight_params = QuantOps.WeightParams(prec=prec, axis=(0,), half_shift=False)
    act_params = QuantOps.ActHParams(input_distribution=act_distribution,
                                     bounds=jnp.array([[3.0, 1.5]]),
                                     prec=prec,
                                     half_shift=False)
    act = self.lhs
    if act_distribution == 'positive':
        act = jnp.abs(act)
    # We need this context manager to stop Jax from trying to compile the arms
    # of the `lax.cond` call in `dot_general_aqt`. By default, Jax will always
    # try to compile the functions passed to `lax.cond`, even if outside of a
    # JITed context. JIT compilation is incompatible with using a mock for the
    # call to 'dot_general' because during compilation Jax will expect
    # 'dot_general' to return a tracer and will throw an error if it returns a
    # mock instead. By explicitly using jax.disable_jit, Jax will not try to
    # compile the arms of lax.cond and so using a mock will work fine.
    with jax.disable_jit():
        quantization.quantized_dot(w=self.rhs,
                                   act=act,
                                   weight_params=weight_params,
                                   act_hparams=act_params,
                                   get_bounds_params=None,
                                   quant_type=QuantType.aqt,
                                   prefer_int8_to_int32_dot=prefer_int8_to_int32_dot)

    act_inputs, weight_inputs = mock_dot_general.call_args[0]
    self.assert_is_integer_in_range(act_inputs, prec=prec, distribution=act_distribution)
    self.assert_is_integer_in_range(weight_inputs, prec=prec, distribution='symmetric')

    if prefer_int8_to_int32_dot and not (act_distribution == 'positive' and prec == 8):
        expected_input_dtype = jnp.int8
    else:
        expected_input_dtype = jnp.float32
    self.assertEqual(act_inputs.dtype, expected_input_dtype)
    self.assertEqual(weight_inputs.dtype, expected_input_dtype)
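# --- Hedged sketch (standalone; not part of the test above). It illustrates the comment
# above: under jax.disable_jit(), lax.cond dispatches on the concrete predicate and calls
# only the chosen branch as plain Python, instead of tracing/compiling both arms. This is
# why a mock substituted for dot_general does not need to return a tracer.
import jax
import jax.numpy as jnp
from jax import lax

def true_branch(x):
    return x + 1.0

def false_branch(x):
    return x - 1.0

with jax.disable_jit():
    out = lax.cond(True, true_branch, false_branch, jnp.float32(3.0))
    print(out)  # 4.0, produced by calling true_branch directly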
def test_jit_compilation(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        with self.subTest(model_class.__name__):
            prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
            model = model_class(config)

            @jax.jit
            def model_jitted(input_ids, attention_mask=None, token_type_ids=None):
                return model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                ).to_tuple()

            with self.subTest("JIT Enabled"):
                jitted_outputs = model_jitted(**prepared_inputs_dict)

            with self.subTest("JIT Disabled"):
                with jax.disable_jit():
                    outputs = model_jitted(**prepared_inputs_dict)

            self.assertEqual(len(outputs), len(jitted_outputs))
            for jitted_output, output in zip(jitted_outputs, outputs):
                self.assertEqual(jitted_output.shape, output.shape)

            @jax.jit
            def model_jitted_return_dict(input_ids, attention_mask=None, token_type_ids=None):
                return model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )

            # jitted function cannot return OrderedDict
            with self.assertRaises(TypeError):
                model_jitted_return_dict(**prepared_inputs_dict)
def test_get_text_features(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    model = FlaxCLIPModel(config)

    @jax.jit
    def model_jitted(input_ids, attention_mask, **kwargs):
        return model.get_text_features(input_ids=input_ids, attention_mask=attention_mask)

    with self.subTest("JIT Enabled"):
        jitted_output = model_jitted(**inputs_dict)

    with self.subTest("JIT Disabled"):
        with jax.disable_jit():
            output = model_jitted(**inputs_dict)

    self.assertEqual(jitted_output.shape, output.shape)
    self.assertTrue(np.allclose(jitted_output, output, atol=1e-3))
def debug_mvee():
    import pylab as plt

    n = random.normal(random.PRNGKey(0), (10000, 2))
    n = n / jnp.linalg.norm(n, axis=1, keepdims=True)
    angle = jnp.arctan2(n[:, 1], n[:, 0])
    plt.hist(angle, bins=100)
    plt.show()

    N = 120
    D = 2
    points = random.uniform(random.PRNGKey(0), (N, D))

    from jax import disable_jit
    with disable_jit():
        center, radii, rotation = minimum_volume_enclosing_ellipsoid(points, 0.01)
    plt.hist(jnp.linalg.norm((rotation.T @ (points.T - center[:, None])) / radii[:, None], axis=0))
    plt.show()

    print(center, radii, rotation)
    plt.scatter(points[:, 0], points[:, 1])
    theta = jnp.linspace(0., jnp.pi * 2, 100)
    ellipsis = center[:, None] + rotation @ jnp.stack([radii[0] * jnp.cos(theta),
                                                       radii[1] * jnp.sin(theta)], axis=0)
    plt.plot(ellipsis[0, :], ellipsis[1, :])

    for i in range(1000):
        y = sample_ellipsoid(random.PRNGKey(i), center, radii, rotation)
        plt.scatter(y[0], y[1])

    C = jnp.linalg.pinv(jnp.cov(points, rowvar=False, bias=True))
    p = (N - D - 1) / N

    def q(p):
        return p + p ** 2 / (4. * (D - 1))

    C = C / q(p)
    c = jnp.mean(points, axis=0)
    W, Q, Vh = jnp.linalg.svd(C)
    radii = jnp.reciprocal(jnp.sqrt(Q))
    rotation = Vh.conj().T
    ellipsis = c[:, None] + rotation @ jnp.stack([radii[0] * jnp.cos(theta),
                                                  radii[1] * jnp.sin(theta)], axis=0)
    plt.plot(ellipsis[0, :], ellipsis[1, :])
    plt.show()
def test_multiple_sentences(jit):
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    model = FlaxRobertaModel.from_pretrained("roberta-base")

    sentences = ["this is an example sentence", "this is another", "and a third one"]
    encodings = tokenizer(sentences, return_tensors=TensorType.JAX, padding=True, truncation=True)

    @jax.jit
    def model_jitted(input_ids, attention_mask):
        return model(input_ids, attention_mask)

    if jit == "disable_jit":
        with jax.disable_jit():
            tokens, pooled = model_jitted(**encodings)
    else:
        tokens, pooled = model_jitted(**encodings)

    assert tokens.shape == (3, 7, 768)
    assert pooled.shape == (3, 768)
def test_accept_order():
    def constraint(u):
        return u ** 2 > 0.5

    def accept_prob(u):
        if u > 0.9:
            return 1
        if u > 0.8:
            return 0.5
        if u > 0.7:
            return 0.25
        return 0.

    def f1(key):
        # Check the constraint first, then accept/reject.
        while True:
            key, u_key = random.split(key, 2)
            u = random.uniform(u_key)
            if constraint(u):
                key, a_key = random.split(key, 2)
                if random.uniform(a_key) < accept_prob(u):
                    return u

    def f2(key):
        # Accept/reject first, then check the constraint.
        while True:
            key, u_key, a_key = random.split(key, 3)
            u = random.uniform(u_key)
            if random.uniform(a_key) < accept_prob(u):
                if constraint(u):
                    return u

    from jax import vmap, disable_jit
    import pylab as plt

    keys = random.split(random.PRNGKey(0), 1000)
    with disable_jit():
        u1 = jnp.array([f1(key) for key in keys])
        u2 = jnp.array([f2(key) for key in keys])
    print(u1)
    plt.hist(u1, bins='auto', alpha=0.5)
    plt.hist(u2, bins='auto', alpha=0.5)
    plt.show()