def compile(self):
    """Build the jitted training step and, if configured, the validation loss.

    The training step evaluates the transformed training loss together with
    its gradient, passes the gradient through the optimizer, and applies the
    resulting updates to the parameters. The initial `TrainerState` is
    created from the first batch drawn from `self.train_iter`.
    """
    train_loss = hk.transform_with_state(self.train_loss_fn)
    loss_and_grad = jax.value_and_grad(train_loss.apply, has_aux=True)

    init_key = jax.random.PRNGKey(42)
    first_batch = next(self.train_iter)
    init_params, init_aux = train_loss.init(init_key, first_batch)
    init_optim = self.optimizer.init(init_params)
    self.state: TrainerState = TrainerState(
        init_params, init_aux, init_optim, init_key)

    def _update_ops(state: TrainerState, inputs):
        # One half of the split key is consumed by this step, the other is
        # carried forward in the returned state.
        step_key, carry_key = jax.random.split(state.rng, 2)
        (loss, new_aux), grads = loss_and_grad(
            state.params, state.aux, step_key, inputs)
        updates, new_optim = self.optimizer.update(
            grads, state.optim, state.params)
        new_params = optax.apply_updates(state.params, updates)
        return TrainerState(new_params, new_aux, new_optim, carry_key), loss

    self.jit_update_ops = jax.jit(_update_ops)

    if self.val_loss_fn is None:
        return

    val_loss = hk.transform_with_state(self.val_loss_fn)

    def _val_loss_fn(inputs, state: TrainerState):
        # Only the first element of the (output, new_state) pair is returned.
        result = val_loss.apply(state.params, state.aux, state.rng, inputs)
        return result[0]

    self.jit_val_loss_fn = jax.jit(_val_loss_fn)
def test_vmap(self, module_fn: ModuleFn, shape, dtype):
    """Check that `hk.vmap` inside a transform matches `jax.vmap` around it.

    Args:
        module_fn: Zero-argument callable that builds the module under test.
        shape: Shape of the random input before the leading axis of size 2
            is added.
        dtype: Input dtype. Integer dtypes get random integers, all other
            dtypes get uniform samples.
    """
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
        x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
        x = jax.random.uniform(rng, shape, dtype)

    # Expand our input since we will map over it.
    x = jnp.broadcast_to(x, (2,) + x.shape)

    f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda
    f_mapped = hk.transform_with_state(
        lambda x: hk.vmap(lambda x: module_fn()(x))(x))  # pylint: disable=unnecessary-lambda
    params, state = f_mapped.init(rng, x)

    # JAX vmap with explicitly unmapped params/state/rng. This should be
    # equivalent to `f_mapped.apply(..)` (since by default hk.vmap does not map
    # params/state/rng).
    v_apply = jax.vmap(f.apply,
                       in_axes=(None, None, None, 0),
                       out_axes=(0, None))

    module_type = descriptors.module_type(module_fn)
    atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)

    # `jax.tree_multimap` has been removed from JAX; `tree_map` accepts
    # several trees and applies the function leaf-wise across them.
    jax.tree_util.tree_map(assert_allclose,
                           f_mapped.apply(params, state, rng, x),
                           v_apply(params, state, rng, x))
def test_logits_config(self, resnet_class, resnet_v2):
    """Check the logits weights are zeros by default and honour `logits_config`."""
    image = jnp.ones([2, 64, 64, 3])
    rng = jax.random.PRNGKey(0)

    def init_params(**resnet_kwargs):
        # Builds the network inside a Haiku transform and returns only its
        # parameters; the state returned by `init` is discarded.
        def forward(img):
            net = resnet_class(1000, resnet_v2=resnet_v2, **resnet_kwargs)
            return net(img, is_training=True)

        net_params, _ = hk.transform_with_state(forward).init(rng, image)
        return net_params

    default_params = init_params()
    logits_keys = [key for key in default_params if "/logits" in key]
    self.assertLen(logits_keys, 1)
    logits_key = logits_keys[0]

    # Check logits params are zeros
    default_w = default_params[logits_key]["w"]
    np.testing.assert_allclose(jnp.zeros_like(default_w), default_w)

    # Check logits params are ones
    modified_params = init_params(logits_config={"w_init": jnp.ones})
    modified_w = modified_params[logits_key]["w"]
    np.testing.assert_allclose(jnp.ones_like(modified_w), modified_w)
def setUp(self):
    """Create the network settings and jitted inference/update functions."""
    super(GatedLinearNetworkTest, self).setUp()

    self._name = "test_network"
    self._rng = hk.PRNGSequence(jax.random.PRNGKey(42))
    self._output_sizes = (4, 5, 6)
    self._context_dim = 2
    self._bias_len = 3

    def gln_factory():
        return gaussian.GatedLinearNetwork(
            output_sizes=self._output_sizes,
            context_dim=self._context_dim,
            bias_len=self._bias_len,
            name=self._name,
        )

    def inference_fn(inputs, side_info):
        network = gln_factory()
        return network.inference(inputs, side_info, 0.5)

    def batch_inference_fn(inputs, side_info):
        batched = jax.vmap(inference_fn, in_axes=(0, 0))
        return batched(inputs, side_info)

    def update_fn(inputs, side_info, label, learning_rate):
        network = gln_factory()
        params, predictions, unused_loss = network.update(
            inputs, side_info, label, learning_rate, 0.5)
        return predictions, params

    def batch_update_fn(inputs, side_info, label, learning_rate):
        # The learning rate is shared across the batch (in_axes entry None).
        batched = jax.vmap(update_fn, in_axes=(0, 0, 0, None))
        predictions, params = batched(inputs, side_info, label, learning_rate)
        # Per-example parameters are averaged over the leading axis.
        avg_params = tree.map_structure(lambda p: jnp.mean(p, axis=0), params)
        return predictions, avg_params

    def to_pure(fn):
        # Haiku transform functions.
        return hk.without_apply_rng(hk.transform_with_state(fn))

    single = to_pure(inference_fn)
    batched = to_pure(batch_inference_fn)
    self._init_fn = single.init
    self._batch_init_fn = batched.init

    self._inference_fn = jax.jit(single.apply)
    self._batch_inference_fn = jax.jit(batched.apply)
    self._update_fn = jax.jit(to_pure(update_fn).apply)
    self._batch_update_fn = jax.jit(to_pure(batch_update_fn).apply)
def __init__(
    self,
    mode: str,
    init_rng: jnp.ndarray,
    config: config_dict.ConfigDict,
):
    """Set up the experiment: hide devices from TF, validate mode, init state.

    Args:
        mode: One of 'train', 'eval' or 'train_eval_multithreaded'.
        init_rng: Random key forwarded to the base class and stored on self.
        config: Experiment configuration, stored on self.

    Raises:
        ValueError: If `mode` is not one of the accepted values.
    """
    super(Experiment, self).__init__(mode=mode, init_rng=init_rng)

    # Make both accelerator types invisible to TensorFlow.
    for device_type in ('GPU', 'TPU'):
        tf.config.experimental.set_visible_devices([], device_type=device_type)

    valid_modes = ('train', 'eval', 'train_eval_multithreaded')
    if mode not in valid_modes:
        raise ValueError(f'Invalid mode {mode}.')

    self.mode = mode
    self.config = config
    self.init_rng = init_rng

    self.forward = hk.transform_with_state(self._forward_fn)
    self._predictions = None

    # Needed for checkpoint restore.
    self._params = None
    self._ema_params = None
    self._network_state = None
    self._ema_network_state = None
    self._opt_state = None

    # Track what has started.
    self._training = False
    self._evaluating = False
def get_model_fn_t(
    num_initial_atom_features: int,
    num_atom_features: int,
    num_bond_features: int,
    num_convs: int,
    num_hidden_layers: int,
    num_hidden_features: int,
    max_num_neighbors: int,
    batch_size: int,
):
    """Return a Haiku transform (with state) that runs a CGCNN on a batch.

    Every argument is forwarded unchanged to the `CGCNN` constructor under
    the keyword of the same name.
    """
    cgcnn_kwargs = dict(
        num_initial_atom_features=num_initial_atom_features,
        num_atom_features=num_atom_features,
        num_bond_features=num_bond_features,
        num_convs=num_convs,
        num_hidden_layers=num_hidden_layers,
        num_hidden_features=num_hidden_features,
        max_num_neighbors=max_num_neighbors,
        batch_size=batch_size,
        name="cgcnn",
    )
    # Batch entries, in the positional order the model is called with.
    input_keys = (
        "neighbor_indices",
        "atom_features",
        "bond_features",
        "num_atoms",
        "segment_ids",
    )

    def model_fn(batch: Batch, is_training: bool) -> jnp.ndarray:
        net = CGCNN(**cgcnn_kwargs)
        model_inputs = [batch[key] for key in input_keys]
        return net(*model_inputs, is_training)

    return hk.transform_with_state(model_fn)
def get_model(model_name, data_info, **kwargs):
    """Look up a model factory by name and return its Haiku `(apply, init)`.

    Args:
        model_name: Key selecting one of the registered model factories.
        data_info: Passed as the first argument to the chosen factory.
        **kwargs: Extra keyword arguments forwarded to the chosen factory.

    Returns:
        A tuple `(apply, init)` taken from `hk.transform_with_state`.

    Raises:
        KeyError: If `model_name` is not a registered name.
    """
    factories = dict(
        lenet=make_lenet5_fn,
        resnet20=make_resnet20_fn,
        resnet20_frn=make_resnet20_frn_fn,
        resnet20_frn_swish=functools.partial(
            make_resnet20_frn_fn, activation=jax.nn.swish),
        cnn_lstm=make_cnn_lstm,
        smooth_cnn_lstm=make_smooth_cnn_lstm,
        mlp_regression=make_mlp_regression,
        mlp_regression_small=make_mlp_regression_small,
        mlp_classification=make_mlp_classification,
        logistic_regression=make_logistic_regression,
    )
    make_net_fn = factories[model_name]
    transformed = hk.transform_with_state(make_net_fn(data_info, **kwargs))
    return transformed.apply, transformed.init
def test_numpy_and_jax_results_close(self, module_fn: ModuleFn, shape, dtype):
    """Check init/apply agree whether inputs are JAX arrays or NumPy arrays.

    Args:
        module_fn: Zero-argument callable that builds the module under test.
        shape: Shape of the all-ones input.
        dtype: Dtype of the all-ones input.
    """
    # `jax.tree_multimap` has been removed from JAX and the top-level
    # `jax.tree_map` alias was deprecated later; `jax.tree_util.tree_map`
    # covers both the single-tree and the multi-tree use.
    tree_map = jax.tree_util.tree_map

    f = hk.transform_with_state(lambda x: module_fn()(x))  # pylint: disable=unnecessary-lambda

    rng = jax.random.PRNGKey(42)
    x = jnp.ones(shape, dtype)
    params, state = f.init(rng, x)
    out, new_state = f.apply(params, state, rng, x)

    np_rng = np.asarray(rng)
    np_x = np.asarray(x)

    with self.subTest('init'):
        params2, state2 = f.init(np_rng, np_x)
        tree_map(np.testing.assert_allclose, params, params2)
        tree_map(np.testing.assert_allclose, state, state2)

    with self.subTest('apply'):
        np_params = tree_map(np.asarray, params)
        np_state = tree_map(np.asarray, state)
        out2, new_state2 = f.apply(np_params, np_state, np_rng, np_x)
        tree_map(np.testing.assert_allclose, out, out2)
        tree_map(np.testing.assert_allclose, new_state, new_state2)
def test_forward_shape(self):
    """Test output shape of SparseGCNPredicator"""
    transformed = hk.transform_with_state(self.__forward)

    # A fresh key is drawn for init and another one for apply.
    init_key = next(self.key)
    params, state = transformed.init(init_key, *self.input_data)

    apply_key = next(self.key)
    outputs = transformed.apply(params, state, apply_key, *self.input_data)
    preds = outputs[0]

    expected_shape = (batch_size, n_out)
    assert preds.shape == expected_shape
def test_forward_shape(self):
    """Test output shape of PadGCN"""
    transformed = hk.transform_with_state(self.__forward)

    # A fresh key is drawn for init and another one for apply.
    init_key = next(self.key)
    params, state = transformed.init(init_key, *self.input_data)

    apply_key = next(self.key)
    outputs = transformed.apply(params, state, apply_key, *self.input_data)
    preds = outputs[0]

    expected_shape = (batch_size, max_node_size, hidden_feats[-1])
    assert preds.shape == expected_shape
def __init__(self, mode, init_rng, config):
    """Store the settings, clear cached state and wrap the step functions.

    Args:
        mode: Forwarded to the base class and stored on self.
        init_rng: Forwarded to the base class and stored on self.
        config: Experiment configuration, stored on self.
    """
    super(Experiment, self).__init__(mode=mode, init_rng=init_rng)

    self.mode, self.init_rng, self.config = mode, init_rng, config

    # Checkpointed experiment state.
    self._params = self._state = self._opt_state = None

    # Input pipelines.
    self._train_input = self._eval_input = None

    self.forward = hk.transform_with_state(self._forward_fn)

    # NOTE: `params, state, opt_state` (arguments 0, 1, 2) are "donated", which
    # allows JAX (on some backends) to reuse the device memory associated with
    # these inputs to store the outputs of our function (which also start with
    # `params, state, opt_state`).
    donated = (0, 1, 2)
    self._update_func = jax.pmap(
        self._update_func, axis_name='i', donate_argnums=donated)
    self._eval_batch = jax.jit(self._eval_batch)
def __init__(self, func, observation_space, action_space=None, random_seed=None):
    """Haiku-transform `func`, initialize its params/state and jit helpers.

    Args:
        func: The forward-pass function to be Haiku-transformed.
        observation_space: A `Space` instance.
        action_space: Optional `Space` instance.
        random_seed: Seed assigned to `self.random_seed`.

    Raises:
        TypeError: If `observation_space`, or a non-None `action_space`, is
            not a `Space` instance.
    """
    if not isinstance(observation_space, Space):
        raise TypeError(
            f"observation_space must be derived from gym.Space, got: {type(observation_space)}")
    self.observation_space = observation_space

    if action_space is not None:
        if not isinstance(action_space, Space):
            raise TypeError(
                f"action_space must be derived from gym.Space, got: {type(action_space)}")
    self.action_space = action_space

    self.random_seed = random_seed  # also initializes self.rng via RandomStateMixin
    self._jitted_funcs = {}

    # Haiku-transform the provided func
    example_data = self._check_signature(func)
    # The jitted apply takes (params, state, rng) before func's own arguments,
    # so func's static argument positions are shifted by 3.
    static_argnums = tuple(i + 3 for i in example_data.inputs.static_argnums)
    transformed = hk.transform_with_state(func)
    self._function = jit(transformed.apply, static_argnums=static_argnums)

    # init function params and state
    self._params, self._function_state = transformed.init(
        self.rng, *example_data.inputs.args)

    # check if output has the expected shape etc.
    output, _ = self._function(
        self.params, self.function_state, self.rng, *example_data.inputs.args)
    self._check_output(output, example_data.output)

    def soft_update_func(old, new, tau):
        # `jax.tree_multimap` has been removed from JAX; `tree_map` accepts
        # several trees and combines them leaf by leaf.
        return jax.tree_util.tree_map(
            lambda a, b: (1 - tau) * a + tau * b, old, new)

    self._soft_update_func = jit(soft_update_func)
def test_graph_conditioned_transformer_learns(self):
    """Train a small Graph2TextTransformer for 500 steps; final loss must be < 1."""
    int32, float32 = np.int32, np.float32
    graphs = jraph.GraphsTuple(
        nodes=np.ones((4, 3), dtype=float32),
        edges=np.ones((3, 1), dtype=float32),
        senders=np.array([0, 2, 3], dtype=int32),
        receivers=np.array([1, 3, 2], dtype=int32),
        n_node=np.array([2, 2], dtype=int32),
        n_edge=np.array([1, 2], dtype=int32),
        globals=None,
    )
    seqs = np.array([[1, 2, 2, 0],
                     [1, 3, 3, 3]], dtype=int32)
    vocab_size = seqs.max() + 1
    embed_dim = 8
    max_graph_size = graphs.n_node.max()
    logging.info('Training seqs: %r', seqs)

    # Inputs are every token but the last; labels are the inputs shifted by one.
    x, y = seqs[:, :-1], seqs[:, 1:]

    def forward(graphs, inputs, labels, max_graph_size):
        input_mask = (labels != 0).astype(jnp.float32)
        net = models.Graph2TextTransformer(
            vocab_size=vocab_size,
            emb_dim=embed_dim,
            num_layers=2,
            num_heads=4,
            cutoffs=[],
            gnn_embed_dim=embed_dim,
            gnn_num_layers=2)
        return net.loss(
            graphs, max_graph_size, False, inputs, labels, mask=input_mask)

    init_fn, apply_fn = hk.transform_with_state(forward)
    rng = hk.PRNGSequence(8)
    params, state = init_fn(next(rng), graphs, x, y, max_graph_size)

    def apply(*args, **kwargs):
        # Rearranged so the first output is the value being differentiated and
        # everything else travels as auxiliary data.
        out, new_state = apply_fn(*args, **kwargs)
        return out[0], (out[1], new_state)

    # Positional argument 6 is `max_graph_size`.
    apply = jax.jit(apply, static_argnums=6)
    loss_and_grad = jax.value_and_grad(apply, has_aux=True)

    optimizer = optax.chain(optax.scale_by_adam(), optax.scale(-1e-3))
    opt_state = optimizer.init(params)

    for step in range(1, 501):
        (loss, (metrics, state)), grad = loss_and_grad(
            params, state, next(rng), graphs, x, y, max_graph_size)
        updates, opt_state = optimizer.update(grad, opt_state, params)
        params = optax.apply_updates(params, updates)
        if step % 100 == 0:
            logging.info(
                'Step %d, %r', step,
                {k: float(v) for k, v in metrics.items()})

    logging.info('Loss: %.8f', loss)
    self.assertLess(loss, 1.0)
def test_bow_transformer_runs(self):
    """Run init and one loss evaluation of Bow2TextTransformer; log the result."""
    bow = np.array([[0, 0, 1, 0, 2, 0, 0, 1],
                    [0, 1, 0, 0, 1, 0, 1, 0],
                    [1, 0, 0, 0, 1, 0, 0, 1]], dtype=np.int32)
    seqs = np.array([[1, 2, 3, 0, 0],
                     [2, 4, 5, 6, 0],
                     [3, 3, 5, 1, 2]], dtype=np.int32)
    # Inputs are every token but the last; labels are the inputs shifted by one.
    x, y = seqs[:, :-1], seqs[:, 1:]
    vocab_size = seqs.max() + 1

    def forward(bow, inputs, labels):
        net = models.Bow2TextTransformer(
            vocab_size=vocab_size,
            emb_dim=16,
            num_layers=2,
            num_heads=4,
            cutoffs=[])
        return net.loss(bow, inputs, labels)

    transformed = hk.transform_with_state(forward)
    key = hk.PRNGSequence(8)
    params, state = transformed.init(next(key), bow, x, y)
    (loss, metrics), _ = transformed.apply(params, state, next(key), bow, x, y)

    logging.info('loss: %g', loss)
    logging.info('metrics: %r', metrics)
def test_vmap(
    self,
    module_fn: ModuleFn,
    shape: Shape,
    dtype: DType,
):
    """Check apply on a batch matches apply under `jax.vmap` over that batch.

    Args:
        module_fn: Zero-argument callable that builds the module under test.
        shape: Full input shape; the first entry is used as the batch size
            and the remainder as the shape of a single sample.
        dtype: Input dtype. Integer dtypes get random integers, all other
            dtypes get uniform samples.
    """
    batch_size, shape = shape[0], shape[1:]

    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
        sample = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
        sample = jax.random.uniform(rng, shape, dtype)
    batch = jnp.broadcast_to(sample, (batch_size,) + sample.shape)

    def g(x):
        return module_fn()(x)

    f = hk.transform_with_state(g)

    # Ensure application under vmap is the same.
    params, state = f.init(rng, sample)
    v_apply = jax.vmap(f.apply, in_axes=(None, None, None, 0))

    # `jax.tree_multimap` has been removed from JAX; `tree_map` accepts
    # several trees and applies the function leaf-wise across them.
    jax.tree_util.tree_map(
        lambda a, b: np.testing.assert_allclose(a, b, atol=DEFAULT_ATOL),
        f.apply(params, state, rng, batch),
        v_apply(params, state, rng, batch))
def test_jit(
    self,
    module_fn: ModuleFn,
    shape: Shape,
    dtype: DType,
):
    """Check init and apply give matching results with and without `jax.jit`.

    Args:
        module_fn: Zero-argument callable that builds the module under test.
        shape: Shape of the random input.
        dtype: Input dtype. Integer dtypes get random integers, all other
            dtypes get uniform samples.
    """
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
        x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
        x = jax.random.uniform(rng, shape, dtype)

    def g(x):
        return module_fn()(x)

    f = hk.transform_with_state(g)

    atol = CUSTOM_ATOL.get(module_type(module_fn), DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)

    # `jax.tree_multimap` has been removed from JAX; `tree_map` accepts
    # several trees and applies the function leaf-wise across them.
    tree_map = jax.tree_util.tree_map

    # Ensure initialization under jit is the same.
    tree_map(assert_allclose,
             f.init(rng, x),
             jax.jit(f.init)(rng, x))

    # Ensure application under jit is the same.
    params, state = f.init(rng, x)
    tree_map(assert_allclose,
             f.apply(params, state, rng, x),
             jax.jit(f.apply)(params, state, rng, x))
def test_hk_remat(
    self,
    module_fn: descriptors.ModuleFn,
    shape: Shape,
    dtype: DType,
):
    """Check gradients agree between `jax.remat` and `hk.remat`.

    Args:
        module_fn: Zero-argument callable that builds the module under test.
        shape: Shape of the random input.
        dtype: Input dtype. Integer dtypes get random integers, all other
            dtypes get uniform samples.
    """
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
        x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
        x = jax.random.uniform(rng, shape, dtype)

    def g(x, remat=False):
        mod = module_fn()
        if remat:
            mod = hk.remat(mod)
        out = mod(x)
        # Modules that return a dict are reduced via their 'loss' entry.
        if isinstance(out, dict):
            out = out['loss']
        return jnp.mean(out)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    grad_jax_remat = jax.grad(jax.remat(f.apply), has_aux=True)
    grad_hk_remat = jax.grad(functools.partial(f.apply, remat=True),
                             has_aux=True)

    params, state = f.init(rng, x)

    # `jax.tree_multimap` has been removed from JAX; `tree_map` accepts
    # several trees and applies the function leaf-wise across them.
    jax.tree_util.tree_map(assert_allclose,
                           grad_jax_remat(params, state, rng, x),
                           grad_hk_remat(params, state, rng, x))
def test_optimize_rng_use_under_jit(
    self,
    module_fn: descriptors.ModuleFn,
    shape: Shape,
    dtype: DType,
):
    """Check `optimize_rng_use` gives matching results with and without jit.

    Args:
        module_fn: Zero-argument callable that builds the module under test.
        shape: Shape of the random input.
        dtype: Input dtype. Integer dtypes get random integers, all other
            dtypes get uniform samples.
    """
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
        x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
        x = jax.random.uniform(rng, shape, dtype)

    def g(x):
        return module_fn()(x)

    f = hk.transform_with_state(hk.experimental.optimize_rng_use(g))

    module_type = descriptors.module_type(module_fn)
    atol = CUSTOM_ATOL.get(module_type, DEFAULT_ATOL)
    assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)

    # `jax.tree_multimap` has been removed from JAX; `tree_map` accepts
    # several trees and applies the function leaf-wise across them.
    tree_map = jax.tree_util.tree_map

    params, state = jax.jit(f.init)(rng, x)
    tree_map(assert_allclose, (params, state), f.init(rng, x))

    if module_type in (hk.nets.VectorQuantizer, hk.nets.VectorQuantizerEMA):
        # For stochastic modules just test apply runs.
        jax.device_get(jax.jit(f.apply)(params, state, rng, x))
    else:
        tree_map(assert_allclose,
                 jax.jit(f.apply)(params, state, rng, x),
                 f.apply(params, state, rng, x))
def test_hk_jit(
    self,
    module_fn: descriptors.ModuleFn,
    shape: Shape,
    dtype: DType,
    init: bool,
):
    """Check `hk.jit` on the module matches `jax.jit` on the whole function.

    Args:
        module_fn: Zero-argument callable that builds the module under test.
        shape: Shape of the random input.
        dtype: Input dtype. Integer dtypes get random integers, all other
            dtypes get uniform samples.
        init: If True compare the init functions, otherwise compare apply.
    """
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
        x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
        x = jax.random.uniform(rng, shape, dtype)

    def g(x, jit=False):
        mod = module_fn()
        if jit:
            mod = hk.jit(mod)
        return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-4)

    # `jax.tree_multimap` has been removed from JAX; `tree_map` accepts
    # several trees and applies the function leaf-wise across them.
    tree_map = jax.tree_util.tree_map

    # NOTE: We shard init/apply tests since some modules are expensive to jit
    # (e.g. ResNet50 takes ~60s to compile and we compile it twice per test).
    if init:
        tree_map(assert_allclose,
                 jax.jit(f.init)(rng, x),
                 f.init(rng, x, jit=True))
    else:
        params, state = f.init(rng, x)
        tree_map(assert_allclose,
                 jax.jit(f.apply)(params, state, rng, x),
                 f.apply(params, state, rng, x, jit=True))
def __init__(
    self,
    output_dim: int,
    activation_fn: str = "relu",
    stochastic_parameters: bool = False,
    linear_model: bool = False,
    dropout: bool = False,
    dropout_rate: float = 0.0,
):
    """Wrapper of resnet50_fsvi.

    Args:
        output_dim: Dimension of the output.
        activation_fn: Name of the activation function (e.g. "relu",
            "tanh"); looked up in `ACTIVATION_DICT`.
        stochastic_parameters: When True, a variational distribution over
            the parameters is kept.
        linear_model: When True, only the last layer gets a variational
            distribution.
        dropout: When True, dropout is applied.
        dropout_rate: Rate used when dropout is applied.
    """
    # Resolve the activation by name before storing any settings.
    activation = ACTIVATION_DICT[activation_fn]

    self.output_dim = output_dim
    self.linear_model = linear_model
    self.dropout = dropout
    self.dropout_rate = dropout_rate
    self.activation_fn = activation
    self.stochastic_parameters = stochastic_parameters

    # The forward function is built from the settings stored above.
    forward_fn = self.make_forward_fn()
    self.forward = hk.transform_with_state(forward_fn)
def test_profiler_name_scopes(
    self,
    module_fn: descriptors.ModuleFn,
    shape: Shape,
    dtype: DType,
):
    """Check outputs are unchanged when profiler name scopes are enabled.

    Args:
        module_fn: Zero-argument callable that builds the module under test.
        shape: Shape of the random input.
        dtype: Input dtype. Integer dtypes get random integers, all other
            dtypes get uniform samples.
    """
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
        x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
        x = jax.random.uniform(rng, shape, dtype)

    def g(x, name_scopes=False):
        hk.experimental.profiler_name_scopes(enabled=name_scopes)
        mod = module_fn()
        return mod(x)

    f = hk.transform_with_state(g)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

    try:
        params, state = f.init(rng, x)
        # `jax.tree_multimap` has been removed from JAX; `tree_map` accepts
        # several trees and applies the function leaf-wise across them.
        jax.tree_util.tree_map(
            assert_allclose,
            f.apply(params, state, rng, x),
            f.apply(params, state, rng, x, name_scopes=True))
    finally:
        # Reset the setting even when the comparison fails, so a failure
        # here does not leave name scopes enabled for later tests.
        # TODO(lenamartens): flip to True when default changes
        hk.experimental.profiler_name_scopes(enabled=False)
def test_forward_shape_with_batch_norm(self):
    """Test output shape of PadGCNLayer with BatchNorm"""
    transformed = hk.transform_with_state(self.__forward_with_batch_norm)

    # A fresh key is drawn for init and another one for apply.
    init_key = next(self.key)
    params, state = transformed.init(init_key, *self.input_data)

    apply_key = next(self.key)
    outputs = transformed.apply(params, state, apply_key, *self.input_data)
    preds = outputs[0]

    expected_shape = (batch_size, max_node_size, out_feats)
    assert preds.shape == expected_shape
def test_abstract_to_dot(self, module_fn: ModuleFn, shape, dtype):
    """Check `abstract_to_dot` returns a result given abstract params/state."""
    def forward(x):
        return module_fn()(x)

    f = hk.transform_with_state(forward)

    rng = jax.random.PRNGKey(42)
    x = np.ones(shape, dtype)

    # `eval_shape` is used for init, so params and state are abstract values.
    params, state = jax.eval_shape(f.init, rng, x)

    to_dot = hk.experimental.abstract_to_dot(f.apply)
    dot = to_dot(params, state, rng, x)
    self.assertIsNotNone(dot)
def test_transformer_with_extra_runs(self):
    """Run init and one loss evaluation of TransformerXL with extra inputs."""
    extra = np.array([[1, 1, 0, 0],
                      [2, 2, 2, 2],
                      [3, 3, 3, 0]], dtype=np.int32)
    seqs = np.array([[1, 2, 3, 0, 0],
                     [2, 4, 5, 6, 0],
                     [3, 3, 5, 1, 2]], dtype=np.int32)
    # Inputs are every token but the last; labels are the inputs shifted by one.
    x, y = seqs[:, :-1], seqs[:, 1:]
    vocab_size = seqs.max() + 1
    extra_vocab_size = extra.max() + 1

    def forward(inputs, labels, extra):
        # Both masks are computed from the raw ids, before `extra` is embedded.
        input_mask = (labels != 0).astype(jnp.float32)
        extra_mask = (extra != 0).astype(jnp.float32)
        embedder = hk.Embed(vocab_size=extra_vocab_size, embed_dim=16)
        extra_embedded = embedder(extra)
        net = models.TransformerXL(
            vocab_size=vocab_size,
            emb_dim=16,
            num_layers=2,
            num_heads=4,
            cutoffs=[],
        )
        return net.loss(inputs, labels, mask=input_mask,
                        extra=extra_embedded, extra_mask=extra_mask)

    transformed = hk.transform_with_state(forward)
    key = hk.PRNGSequence(8)
    params, state = transformed.init(next(key), x, y, extra)
    (loss, metrics), _ = transformed.apply(params, state, next(key), x, y, extra)

    logging.info('loss: %g', loss)
    logging.info('metrics: %r', metrics)
def test_transformer_param_count(self):
    """Check the configured TransformerXL has exactly 58,704,438 parameters."""
    seqs = np.array([[1, 2, 3, 0, 0],
                     [3, 3, 5, 1, 2]], dtype=np.int32)
    # Inputs are every token but the last; labels are the inputs shifted by one.
    x, y = seqs[:, :-1], seqs[:, 1:]
    vocab_size = 267_735

    model_kwargs = dict(
        vocab_size=vocab_size,
        emb_dim=210,
        num_layers=2,
        num_heads=10,
        dropout_prob=0.0,
        dropout_attn_prob=0.0,
        self_att_init_scale=0.02,
        dense_init_scale=0.02,
        dense_dim=2100,
        cutoffs=(20000, 40000, 200000),  # WikiText-103
        relative_pos_clamp_len=None,
    )

    def forward(inputs, labels):
        input_mask = (labels != 0).astype(jnp.float32)
        net = models.TransformerXL(**model_kwargs)
        return net.loss(inputs, labels, mask=input_mask, cache_steps=2)

    transformed = hk.transform_with_state(forward)
    key = hk.PRNGSequence(8)
    params, state = transformed.init(next(key), x, y)
    (loss, metrics), _ = transformed.apply(params, state, next(key), x, y)

    logging.info('loss: %g', loss)
    logging.info('metrics: %r', metrics)

    self.assertEqual(tree_size(params), 58_704_438)
def __init__(self, loss_fn, optimizer, devices=None, has_graph=False):
    """Transform the loss for train/eval and pmap the init/update/eval steps.

    Args:
        loss_fn: Loss function accepting an `is_training` keyword; it is
            transformed once with `is_training=True` and once with
            `is_training=False`.
        optimizer: Optimizer to use; `None` selects `optax.identity()`.
        devices: Devices handed to `jax.pmap`. When `None`, the local
            devices of every process index are collected.
        has_graph: When True, the last argument of each step is a batch
            whose 'max_graph_size' entry is split off and passed as a
            statically broadcast argument.
    """
    train_loss = functools.partial(loss_fn, is_training=True)
    eval_loss = functools.partial(loss_fn, is_training=False)
    self._net_init_fn, self._apply_fn = hk.transform_with_state(train_loss)
    _, self._eval_apply_fn = hk.transform_with_state(eval_loss)

    self._optimizer = optax.identity() if optimizer is None else optimizer

    self._num_devices = jax.local_device_count()
    if devices is not None:
        self._num_devices = min(self._num_devices, len(devices))
    else:
        devices = [
            device
            for host_id in range(jax.process_count())
            for device in jax.local_devices(host_id)
        ]

    def _pmap(f, static_broadcasted_argnums=()):
        return jax.pmap(
            f,
            axis_name='i',
            devices=devices,
            static_broadcasted_argnums=static_broadcasted_argnums)

    def handle_graph_size(fn):
        def _fn(*args):
            # Work on a copy so the caller's batch keeps its entry.
            batch = args[-1].copy()
            max_graph_size = batch.pop('max_graph_size')
            return fn(*args[:-1], batch, max_graph_size)
        return _fn

    # Try to jit.
    if not has_graph:
        self._init_fn = _pmap(self._init)
        self._update_fn = _pmap(self._update)
        self._eval_fn = _pmap(self._eval)
        return

    # If the model contains full graphs, we need to set the max_graph_size
    # as a statically broadcasted argument.
    self._init_fn = handle_graph_size(_pmap(self._init, 4))
    self._update_fn = handle_graph_size(_pmap(self._update, 2))
    self._eval_fn = handle_graph_size(_pmap(self._eval, 2))
def __init__(self, *args, **kwargs):
    """Wrap an S3D network in a Haiku transform; params/state start unset.

    Args:
        *args: Positional arguments forwarded to the `s3d.S3D` constructor.
        **kwargs: Keyword arguments forwarded to the `s3d.S3D` constructor.
    """
    def forward(*call_args, **call_kwargs):
        # The network is constructed on every call of the transformed function.
        net = s3d.S3D(
            *args,
            normalize_fn=normalization.get_normalize_fn(),
            **kwargs)
        return net(*call_args, **call_kwargs)

    self._model = hk.transform_with_state(forward)
    self._rng = jax.random.PRNGKey(42)
    self._params = None
    self._state = None
def test_hk_scan(self, module_fn: descriptors.ModuleFn, shape, dtype, init):
    """Check `hk.scan` inside a transform matches `jax.lax.scan` around it.

    Args:
        module_fn: Zero-argument callable that builds the module under test.
        shape: Shape of a single scan element.
        dtype: Input dtype. Integer dtypes get random integers, all other
            dtypes get uniform samples.
        init: If True only params and state from init are compared;
            otherwise the scanned outputs and final state are compared.
    """
    rng = jax.random.PRNGKey(42)
    if jnp.issubdtype(dtype, jnp.integer):
        x = jax.random.randint(rng, shape, 0, np.prod(shape), dtype)
    else:
        x = jax.random.uniform(rng, shape, dtype)

    def f(x):
        mod = module_fn()
        return mod(x)

    def u_f(xs):
        mod = module_fn()

        def s(carry, x):
            y = mod(x)
            return carry, y

        _, ys = hk.scan(s, (), xs)
        return ys

    u_f = hk.transform_with_state(u_f)
    f = hk.transform_with_state(f)

    assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-4)

    # `jax.tree_multimap` has been removed from JAX; `tree_map` accepts
    # several trees and applies the function leaf-wise across them.
    tree_map = jax.tree_util.tree_map

    # The scanned input is the same element repeated 8 times.
    xs = jnp.broadcast_to(x, (8,) + x.shape)
    params, state = f.init(rng, x)

    if init:
        u_params, u_state = u_f.init(rng, xs)
        tree_map(assert_allclose, u_params, params)
        tree_map(assert_allclose, u_state, state)
        return

    def fun(state, x):
        y, state = f.apply(params, state, rng, x)
        return state, y

    s_state, s_ys = jax.lax.scan(fun, state, xs)
    u_ys, u_state = u_f.apply(params, state, rng, xs)

    tree_map(assert_allclose, u_ys, s_ys)
    tree_map(assert_allclose, u_state, s_state)
def __init__(self, random_seed, num_classes, batch_size, max_steps,
             enable_double_transpose, checkpoint_to_evaluate,
             allow_train_from_scratch, freeze_backbone, network_config,
             optimizer_config, lr_schedule_config, evaluation_config,
             checkpointing_config):
    """Constructs the experiment.

    Args:
        random_seed: seed used when initializing the network weights.
        num_classes: number of classes; used for the online evaluation.
        batch_size: total batch size; should be a multiple of the number of
            available accelerators.
        max_steps: number of training steps; used for the lr/target network
            ema schedules.
        enable_double_transpose: see dataset.py; only has effect on TPU.
        checkpoint_to_evaluate: path of the checkpoint to evaluate.
        allow_train_from_scratch: whether training may proceed without a
            checkpoint to evaluate (training from scratch).
        freeze_backbone: whether the backbone resnet stays frozen (linear
            evaluation) or is trainable (fine-tuning).
        network_config: configuration for the network; passed as keyword
            arguments to the backbone function.
        optimizer_config: configuration for the optimizer.
        lr_schedule_config: configuration for the learning rate schedule.
        evaluation_config: the evaluation configuration.
        checkpointing_config: configuration for checkpointing; passed as
            keyword arguments to the `Checkpointer`.
    """
    # Plain settings copied from the constructor arguments.
    self._random_seed = random_seed
    self._num_classes = num_classes
    self._batch_size = batch_size
    self._max_steps = max_steps
    self._enable_double_transpose = enable_double_transpose
    self._checkpoint_to_evaluate = checkpoint_to_evaluate
    self._allow_train_from_scratch = allow_train_from_scratch
    self._freeze_backbone = freeze_backbone
    self._optimizer_config = optimizer_config
    self._lr_schedule_config = lr_schedule_config
    self._evaluation_config = evaluation_config

    # Checkpointed experiment state.
    self._experiment_state = None

    # Input pipelines.
    self._train_input = None
    self._eval_input = None

    # The backbone is transformed with state, the classifier without.
    configured_backbone = functools.partial(self._backbone_fn, **network_config)
    backbone = hk.transform_with_state(configured_backbone)
    classifier = hk.transform(self._classif_fn)
    self.forward_backbone = hk.without_apply_rng(backbone)
    self.forward_classif = hk.without_apply_rng(classifier)

    self.update_pmap = jax.pmap(self._update_func, axis_name='i')
    self.eval_batch_jit = jax.jit(self._eval_batch)

    self._is_backbone_training = not self._freeze_backbone

    self._checkpointer = checkpointing.Checkpointer(**checkpointing_config)
def test_info_and_html(self, module_fn: ModuleFn, shape, dtype):
    """Check model info can be built from apply and rendered as an HTML page."""
    x = jnp.ones(shape, dtype)

    def forward():
        # The input is captured from the enclosing scope, so the transformed
        # function takes no arguments of its own.
        return module_fn()(x)

    f = hk.transform_with_state(forward)
    rng = jax.random.PRNGKey(42)
    params, state = f.init(rng)

    make_info = jaxpr_info.make_model_info(f.apply)
    info = make_info(params, state, rng)

    module_name = descriptors.module_type(module_fn).__name__
    if module_name != 'Sequential':
        self.assertNotEmpty(info.expressions)

    html_page = jaxpr_info.as_html_page(info)
    self.assertIsNotNone(html_page)