def loss_fn(model):
  """Denoising score-matching loss for a noise-conditional score model.

  Runs `model` on `perturbed_data` (also passing `y=class_labels` when
  `class_conditional` is set), compares the flattened scores against the
  target `-noise / used_sigmas**2`, and weights each example's squared error
  by `used_sigmas.squeeze()**anneal_power`.

  Args:
    model: flax `nn.Model` to evaluate.

  Returns:
    `(loss, new_model_state)`, or `(loss, new_model_state, losses)` with the
    per-example losses when `loss_per_sigma` is truthy.
  """
  # Extra keyword arguments for the model; only conditional models take `y`.
  cond_kwargs = {'y': class_labels} if class_conditional else {}

  if train:
    # Training: stateful layers may update the collection.
    with nn.stateful(state.model_state) as new_model_state:
      with nn.stochastic(run_rng):
        scores = model(perturbed_data, labels, train=train, **cond_kwargs)
  else:
    # Evaluation: the collection is read-only, so the state passes through.
    with nn.stateful(state.model_state, mutable=False):
      with nn.stochastic(run_rng):
        scores = model(perturbed_data, labels, train=train, **cond_kwargs)
    new_model_state = state.model_state

  flat_scores = scores.reshape((scores.shape[0], -1))
  target = -1 / (used_sigmas**2) * noise
  flat_target = target.reshape((target.shape[0], -1))

  sq_err = ((flat_scores - flat_target)**2).sum(axis=-1)
  losses = 1 / 2. * sq_err * used_sigmas.squeeze()**anneal_power
  loss = jnp.mean(losses)

  return (loss, new_model_state, losses) if loss_per_sigma else (
      loss, new_model_state)
def loss_fn(model):
  """Mean squared error between the model's prediction and the true `noise`.

  Args:
    model: flax `nn.Model` called as `model(perturbed_data, T, train=train)`.

  Returns:
    `(loss, new_model_state)`. In eval mode the state is returned unchanged.
  """
  if not train:
    # Evaluation: read-only collection; state passes through untouched.
    with nn.stateful(state.model_state, mutable=False):
      with nn.stochastic(run_rng):
        prediction = model(perturbed_data, T, train=train)
    new_model_state = state.model_state
  else:
    with nn.stateful(state.model_state) as new_model_state:
      with nn.stochastic(run_rng):
        prediction = model(perturbed_data, T, train=train)

  flat_prediction = prediction.reshape((prediction.shape[0], -1))
  flat_noise = noise.reshape((noise.shape[0], -1))
  return jnp.mean((flat_prediction - flat_noise)**2), new_model_state
def create_model(key, batch_size, image_size, model_dtype, space_to_depth):
  """Initialize a ResNet-50 model.

  Args:
    key: PRNG key used for parameter initialization.
    batch_size: per-replica batch size used for the init input shape.
    image_size: side length of the (square) input image.
    model_dtype: dtype of the model and of the init input.
    space_to_depth: if true, the input is a 2x2 space-to-depth transform of
      the image (half the spatial size, 4x the channels).

  Returns:
    `(model, init_state)`: the initialized `nn.Model` and the state
    collection produced during initialization.
  """
  if space_to_depth:
    side, channels = image_size // 2, 3 * 2 * 2
  else:
    side, channels = image_size, 3
  input_shape = (batch_size, side, side, channels)

  model_type = models.FakeResNet if FLAGS.fake_model else models.ResNet

  span = FLAGS.batchnorm_span
  if span is None:
    span = max(batch_size, 64)

  shared_kwargs = dict(
      num_classes=1000, dtype=model_dtype, conv0_space_to_depth=space_to_depth)
  # Cross-replica batch norm is only used when the span covers more than one
  # replica but no more than all devices.
  cross_replica = FLAGS.distributed_batchnorm and (
      batch_size < span <= batch_size * jax.device_count())
  if cross_replica:
    mllogger.event('model_bn_span', span)
    model_def = model_type.partial(
        axis_name='batch',
        axis_index_groups=local_replica_groups(span // batch_size),
        **shared_kwargs)
  else:
    mllogger.event('model_bn_span', batch_size)
    model_def = model_type.partial(**shared_kwargs)

  with nn.stateful() as init_state:
    _, params = model_def.init_by_shape(key, [(input_shape, model_dtype)])
  return nn.Model(model_def, params), init_state
def evaluate_batch(self, flax_module, batch_stats, batch):
  """Evaluates every metric in `self.metrics_bundle` on the given batch.

  Args:
    flax_module: the flax module to run in inference mode.
    batch_stats: state collection, used read-only.
    batch: dict with 'inputs' and 'targets', plus optional 'weights' and
      packed-example keys (positions / segmentations).

  Returns:
    Dict mapping each metric name to its value summed over the examples and
    across the 'batch' axis (via `lax.psum`).
  """
  # TODO(ankugarg): Augment with other metrics like log-perplexity.
  packed_keys = ('inputs_positions', 'targets_positions',
                 'inputs_segmentation', 'targets_segmentation')
  with nn.stateful(batch_stats, mutable=False):
    logits = flax_module(batch['inputs'], batch['targets'],
                         *[batch.get(k) for k in packed_keys],
                         train=False)

  weights = batch.get('weights')
  targets = batch['targets']
  if self.dataset_meta_data['apply_one_hot_in_loss']:
    targets = one_hot(targets, logits.shape[-1])

  return {
      key: jnp.sum(lax.psum(self.metrics_bundle[key](logits, targets, weights),
                            axis_name='batch'))
      for key in self.metrics_bundle
  }
def training_cost(self, flax_module, batch_stats, batch, dropout_rng):
  """Return `self.loss_fn` on the batch plus an optional L2 penalty.

  Args:
    flax_module: the flax module to run in training mode.
    batch_stats: state collection; a mutated copy is returned.
    batch: dict with 'inputs', 'targets', optional 'weights' and optional
      packed-example keys (positions / segmentations).
    dropout_rng: PRNG key for the stochastic context.

  Returns:
    `(total_loss, new_batch_stats)`.
  """
  with nn.stateful(batch_stats) as new_batch_stats:
    with nn.stochastic(dropout_rng):
      # Positions and segmentations are required when examples are packed.
      logits = flax_module(
          batch['inputs'],
          batch['targets'],
          batch.get('inputs_positions'),
          batch.get('targets_positions'),
          batch.get('inputs_segmentation'),
          batch.get('targets_segmentation'),
          train=True)

  weights = batch.get('weights')
  targets = batch['targets']
  if self.dataset_meta_data['apply_one_hot_in_loss']:
    targets = one_hot(targets, logits.shape[-1])

  smoothing = self.hps.get('label_smoothing')
  if smoothing is not None:
    targets = model_utils.apply_label_smoothing(targets, smoothing)

  total_loss = self.loss_fn(logits, targets, weights)

  if self.hps.get('l2_decay_factor'):
    penalty = model_utils.l2_regularization(
        flax_module.params, self.hps.l2_decay_rank_threshold)
    total_loss += 0.5 * self.hps.l2_decay_factor * penalty

  return total_loss, new_batch_stats
def loss_fn(model):
  """Photometric MSE loss for NeRF with optional coarse-level supervision.

  Renders `batch["rays"]` with `model` and compares the RGB of the last
  (main) prediction against the first three channels of `batch["pixels"]`.
  When the model also returns a coarse prediction (`ret[0]`), its MSE is
  added to the total loss.

  Args:
    model: flax `nn.Model` called as
      `model(key_0, key_1, origins, directions, viewdirs)`.

  Returns:
    `(loss + loss_c, (new_model_state, stats))`. `stats` is a list of
    `utils.Stats`: the main prediction first, then the coarse one if present.

  Raises:
    ValueError: if the model does not return 1 or 2 sets of outputs.
  """
  with nn.stateful(state.model_state) as new_model_state:
    rays = batch["rays"]
    ret = model(key_0, key_1, rays.origins, rays.directions, rays.viewdirs)
  if len(ret) not in (1, 2):
    # Note the trailing space: adjacent literals are concatenated.
    raise ValueError(
        "ret should contain either 1 set of output (coarse only), or 2 sets "
        "of output (coarse as ret[0] and fine as ret[1]).")

  target_rgb = batch["pixels"][Ellipsis, :3]

  # The main prediction is always at the end of the ret list.
  rgb, unused_disp, unused_acc = ret[-1]
  loss = ((rgb - target_rgb)**2).mean()
  stats = [utils.Stats(loss=loss, psnr=utils.compute_psnr(loss))]

  if len(ret) > 1:
    # If there are both coarse and fine predictions, we compute the loss for
    # the coarse prediction (ret[0]) as well.
    rgb_c, unused_disp_c, unused_acc_c = ret[0]
    loss_c = ((rgb_c - target_rgb)**2).mean()
    stats.append(utils.Stats(loss=loss_c, psnr=utils.compute_psnr(loss_c)))
  else:
    loss_c = 0.

  return loss + loss_c, (new_model_state, stats)
def impl_loss_fn(model_params):
  """Apply `module` to `batch["image"]` and sum the configured loss terms.

  `loss_fn` may be a single callable or a list/tuple of callables, each
  called as `fn(logits, labels, stats)`.

  Args:
    model_params: parameters passed to `module.call`.

  Returns:
    `(loss, (logits, new_model_state, stats))`.
  """
  with nn.stochastic(rng):
    with nn.stateful(state.model_state) as new_model_state:
      logits, stats = module.call(model_params, batch["image"])

  if isinstance(loss_fn, (list, tuple)):
    loss_terms = loss_fn
  else:
    loss_terms = [loss_fn]
  labels = batch["label"]
  loss = sum(term(logits, labels, stats) for term in loss_terms)
  return loss, (logits, new_model_state, stats)
def initialize(flax_module_def, initializer, loss_fn, input_shape,
               output_shape, hps, rng, metrics_logger):
  """Run the given initializer.

  Initialization happens in three phases:
    1. the default initializer specified by the model constructor;
    2. optional rescaling from `hps.layer_rescale_factors`;
    3. the black-box `initializer` (the default is a no-op).

  Args:
    flax_module_def: An uninitialized flax module definition.
    initializer: An initializer defined in init_lib.
    loss_fn: A loss function.
    input_shape: The input shape of a single data example, or a list of such
      shapes (e.g. encoder and decoder inputs for seq2seq models).
    output_shape: The output shape of a single data example.
    hps: A dictionary specifying the model and initializer hparams.
    rng: An rng key to seed the initialization.
    metrics_logger: Used for black box initializers that have learning curves.

  Returns:
    A tuple (model, batch_stats), where model is the initialized flax.nn.Model
    and batch_stats is the collection used for batch norm.
  """
  model_dtype = utils.dtype_from_str(hps.model_dtype)

  # init_by_shape takes one (shape, dtype) spec per model input. Classification
  # models have a single input shape; seq2seq models pass a list of shapes.
  # TODO(gilmer,ankugarg): Support initializers for list of tuples.
  per_example_shapes = (
      input_shape if isinstance(input_shape, list) else [input_shape])
  input_specs = [((hps.batch_size, *shape), model_dtype)
                 for shape in per_example_shapes]

  params_rng, init_rng, dropout_rng = jax.random.split(rng, num=3)
  with nn.stateful() as batch_stats:
    with nn.stochastic(dropout_rng):
      # flax_module_def.create can OOM for larger models, so initialize by
      # shape instead.
      # TODO(gilmer) Link to flax issue when bug reporting process finalizes.
      _, params = flax_module_def.init_by_shape(
          params_rng, input_specs, train=False)
  model = nn.Model(flax_module_def, params)

  if hps.get('layer_rescale_factors'):
    model = model_utils.rescale_layers(model, hps.layer_rescale_factors)

  # batch_stats is intentionally not handed to the initializer: it runs batch
  # norm in train mode and does not need to maintain the statistics.
  # TODO(gilmer): instead of passing in the loss, pass in the model and get
  # the loss from that.
  new_model = initializer(loss_fn, model, hps, input_shape, output_shape,
                          init_rng, metrics_logger)
  return new_model, batch_stats
def eval_step(model, state, batch, prev_metrics, image_format, space_to_depth):
  """Evaluate one `(images, labels)` batch and add its metrics to a running total.

  Returns:
    A pytree with the same structure as `prev_metrics`, equal to
    `prev_metrics + compute_metrics(logits, labels)` leaf-wise.
  """
  images, labels = batch
  prepared = normalize_images(
      maybe_transpose_images(images, image_format), space_to_depth)

  with nn.stateful(state, mutable=False):
    logits = model(prepared, train=False)

  batch_metrics = compute_metrics(logits, labels)
  return jax.tree_multimap(jnp.add, prev_metrics, batch_metrics)
def create_model(key, batch_size, image_size, model_dtype):
  """Create a 1000-class ResNet and the state collection from its init.

  Returns:
    `(model, init_state)`.
  """
  resnet = models.ResNet.partial(num_classes=1000, dtype=model_dtype)
  shape_spec = [((batch_size, image_size, image_size, 3), model_dtype)]
  with nn.stateful() as init_state:
    _, initial_params = resnet.init_by_shape(key, shape_spec)
  return nn.Model(resnet, initial_params), init_state
def _create_model(key):
  """Initialize `flax_module` (bound to `model_kwargs`) at `input_shape`.

  Returns:
    `(model, init_state)`.
  """
  bound_module = flax_module.partial(**model_kwargs)
  # NOTE(review): `key` seeds both the stochastic context and parameter init;
  # confirm that reusing the same key is intended.
  with nn.stateful() as init_state, nn.stochastic(key):
    _, initial_params = bound_module.init_by_shape(
        key, [(input_shape, jnp.float32)])
  return nn.Model(bound_module, initial_params), init_state
def create_model(module, input_shape, rng):
  """Instantiates the model.

  Runs `module.init` on a float32 array of ones with shape `input_shape`.

  Returns:
    `(model, init_params, init_state)`.
  """
  model_rng, init_rng = jax.random.split(rng)
  dummy_input = jnp.ones(input_shape, dtype=jnp.float32)
  with nn.stochastic(model_rng):
    with nn.stateful() as init_state:
      _, init_params = module.init(init_rng, dummy_input)
  return nn.Model(module, init_params), init_params, init_state
def eval_step(model, state, batch, prev_metrics):
  """Evaluate one batch and accumulate its metrics into `prev_metrics`.

  Returns:
    `prev_metrics + compute_metrics(logits, batch['label'])`, leaf-wise.
  """
  images = batch['image']
  labels = batch['label']
  if FLAGS.transpose_images:
    # Moves the last axis to the front; presumably the host feeds HWCN and
    # the model expects NHWC (TODO confirm).
    images = jnp.transpose(images, [3, 0, 1, 2])
  images = normalize_images(images)

  with nn.stateful(state, mutable=False):
    logits = model(images, train=False)

  return jax.tree_multimap(jnp.add, prev_metrics,
                           compute_metrics(logits, labels))
def loss_fn(model):
  """Loss function used for training.

  Returns:
    `(mean_loss, (new_state, logits))`, where mean_loss is the summed cross
    entropy divided by its weight sum.
  """
  with nn.stateful(state) as new_state, nn.stochastic(dropout_rng):
    logits = model(inputs, train=True)
  summed_loss, weight_sum = train_utils.compute_weighted_cross_entropy(
      logits, targets, num_classes=num_classes, weights=None)
  return summed_loss / weight_sum, (new_state, logits)
def test_different_eval_batch_size(self):
  """Test virtual BN can use a different batch size for evals."""
  feature_size = 7
  batch_size = 10
  train_shape = (batch_size, 3, 3, feature_size)
  train_x = 2.0 * jnp.ones(train_shape)

  vbn_def = normalization.VirtualBatchNorm.partial(
      momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC')
  vbn_model, vbn_state = _init(vbn_def, jax.random.PRNGKey(0), train_shape)

  # One training pass to update the running statistics.
  with nn.stateful(vbn_state) as vbn_state:
    vbn_model(train_x)

  # Eval on a batch of 13, which differs from the virtual batch size.
  eval_x = jnp.ones((13, 3, 3, feature_size))
  with nn.stateful(vbn_state) as vbn_state:
    vbn_model(eval_x, use_running_average=True)
def render_image(state, rays, render_fn, rng, normalize_disp, chunk=8192):
  """Render all the pixels of an image (in test mode).

  Args:
    state: model_utils.TrainState.
    rays: a `Rays` namedtuple, the rays to be rendered.
    render_fn: function, jit-ed render function.
    rng: jnp.ndarray, random number generator (used in training mode only).
    normalize_disp: bool, if true then normalize `disp` to [0, 1].
    chunk: int, the size of chunks to render sequentially.

  Returns:
    rgb: jnp.ndarray, rendered color image.
    disp: jnp.ndarray, rendered disparity image.
    acc: jnp.ndarray, rendered accumulated weights per pixel.
  """
  height, width = rays[0].shape[:2]
  num_rays = height * width
  rays = datasets.ray_fn(lambda r: r.reshape((num_rays, -1)), rays)

  unused_rng, key_0, key_1 = jax.random.split(rng, 3)
  model = state.optimizer.target
  model_state = state.model_state
  host_id = jax.host_id()
  results = []
  with nn.stateful(model_state, mutable=False):
    for i in range(0, num_rays, chunk):
      # pylint: disable=cell-var-from-loop
      print(" " + "X" * int((i / num_rays) * 78), end="\r")
      chunk_rays = datasets.ray_fn(lambda r: r[i:i + chunk], rays)
      chunk_size = chunk_rays[0].shape[0]
      # Pad with edge rays so the chunk splits evenly across all devices.
      rays_remaining = chunk_size % jax.device_count()
      if rays_remaining != 0:
        padding = jax.device_count() - rays_remaining
        chunk_rays = datasets.ray_fn(
            lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode="edge"),
            chunk_rays)
      else:
        padding = 0
      # After padding the number of chunk_rays is always divisible by
      # host_count (assuming each host has the same number of local devices).
      # The per-host share must come from the PADDED size: using the
      # unpadded chunk_size would drop the padding rays and make `shard`
      # fail on chunks not divisible by the device count.
      rays_per_host = chunk_rays[0].shape[0] // jax.host_count()
      start, stop = host_id * rays_per_host, (host_id + 1) * rays_per_host
      chunk_rays = datasets.ray_fn(lambda r: shard(r[start:stop]), chunk_rays)
      chunk_results = render_fn(key_0, key_1, model, chunk_rays)[-1]
      results.append([unshard(x[0], padding) for x in chunk_results])
      # pylint: enable=cell-var-from-loop
  print("")
  rgb, disp, acc = [jnp.concatenate(r, axis=0) for r in zip(*results)]
  # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
  if normalize_disp:
    disp = (disp - disp.min()) / (disp.max() - disp.min())
  return (rgb.reshape((height, width, -1)), disp.reshape(
      (height, width, -1)), acc.reshape((height, width, -1)))
def eval_step(model, state, batch, num_classes, flatten_input=True):
  """Compute evaluation metrics for one batch.

  Args:
    model: flax `nn.Model`.
    state: state collection; averaged over the 'batch' axis if non-empty.
    batch: dict with 'inputs' and 'targets' (missing keys become None).
    num_classes: number of classes, forwarded to `compute_metrics`.
    flatten_input: if true, reshape inputs to `(batch, -1)` first.

  Returns:
    The result of `compute_metrics(logits, targets, num_classes, weights=None)`.
  """
  inputs = batch.get('inputs', None)
  targets = batch.get('targets', None)
  if flatten_input:
    inputs = inputs.reshape(inputs.shape[0], -1)

  # Only sync the state across devices when it actually has leaves.
  has_state = bool(jax.tree_leaves(state))
  if has_state:
    state = jax.lax.pmean(state, 'batch')

  with nn.stateful(state, mutable=False):
    logits = model(inputs, train=False)
  return compute_metrics(logits, targets, num_classes, weights=None)
def test_create_model(self):
  """A forward pass keeps the state's structure/shapes and yields 1000 logits."""
  model, init_state = train.create_model(
      random.PRNGKey(0), 8, 224, jnp.float32)
  images = random.normal(random.PRNGKey(1), (8, 224, 224, 3))

  with nn.stateful(init_state) as updated_state:
    logits = model(images)

  init_shapes = jax.tree_map(onp.shape, init_state.as_dict())
  updated_shapes = jax.tree_map(onp.shape, updated_state.as_dict())
  self.assertEqual(init_shapes, updated_shapes)
  self.assertEqual(logits.shape, (8, 1000))
def test_batch_norm(self):
  """BatchNorm normalizes in training and updates its running averages."""
  data_key, init_key = random.split(random.PRNGKey(0))
  x = random.normal(data_key, (4, 3, 2))
  bn_def = nn.BatchNorm.partial(momentum=0.9)

  with nn.stateful() as init_state:
    y, initial_params = bn_def.init(init_key, x)
  model = nn.Model(bn_def, initial_params)

  # The init pass output has ~zero mean and ~unit variance per feature.
  onp.testing.assert_allclose(y.mean((0, 1)), onp.array([0., 0.]), atol=1e-4)
  onp.testing.assert_allclose(y.var((0, 1)), onp.array([1., 1.]), rtol=1e-4)

  with nn.stateful(init_state) as updated_state:
    model(x)
  running = updated_state['/']
  # Per the expected values: one momentum-0.9 step from mean=0 and var=1.
  onp.testing.assert_allclose(
      running['mean'], 0.1 * x.mean((0, 1), keepdims=False), atol=1e-4)
  onp.testing.assert_allclose(
      running['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False), rtol=1e-4)
def render_image(state, data, render_fn, rng, chunk=8192):
  """Render all the pixels of an image (in test mode).

  Args:
    state: model_utils.TrainState.
    data: dict, test example.
    render_fn: function, jit-ed render function.
    rng: jnp.ndarray, random number generator (used in training mode only).
    chunk: int, the size of chunks to render sequentially.

  Returns:
    rgb: jnp.ndarray, rendered color image.
    disp: jnp.ndarray, rendered disparity image.
    acc: jnp.ndarray, rendered accumulated weights per pixel.
  """
  h, w = data["rays"].shape[:2]
  total = h * w
  flat_rays = data["rays"].reshape((total, -1))

  unused_rng, key_0, key_1 = jax.random.split(rng, 3)
  model = state.optimizer.target
  host_id = jax.host_id()
  n_devices = jax.device_count()
  n_hosts = jax.host_count()

  rgb_parts, disp_parts, acc_parts = [], [], []
  with nn.stateful(state.model_state, mutable=False):
    for begin in range(0, total, chunk):
      print(" " + "X" * int((begin / total) * 78), end="\r")
      chunk_rays = flat_rays[begin:begin + chunk]
      # Edge-pad so the chunk divides evenly across devices.
      leftover = chunk_rays.shape[0] % n_devices
      padding = 0 if leftover == 0 else n_devices - leftover
      if padding:
        chunk_rays = jnp.pad(chunk_rays, ((0, padding), (0, 0)), mode="edge")
      # After padding the number of chunk_rays is always divisible by
      # host_count.
      per_host = chunk_rays.shape[0] // n_hosts
      lo = host_id * per_host
      chunk_rays = shard(chunk_rays[lo:lo + per_host])
      final_output = render_fn(key_0, key_1, model, chunk_rays)[-1]
      rgb_parts.append(unshard(final_output[0][0], padding))
      disp_parts.append(unshard(final_output[1][0], padding))
      acc_parts.append(unshard(final_output[2][0], padding))
  print("")

  rgb = jnp.concatenate(rgb_parts, axis=0)
  disp = jnp.concatenate(disp_parts, axis=0)
  acc = jnp.concatenate(acc_parts, axis=0)
  # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
  if flat_rays.shape[-1] > 6:
    disp = (disp - disp.min()) / (disp.max() - disp.min())
  return tuple(img.reshape((h, w, -1)) for img in (rgb, disp, acc))
def _evaluate_batch(flax_module, batch_stats, batch, metrics_bundle,
                    apply_one_hot_in_loss):
  """Evaluates metrics on the given batch.

  Each metric_fn in metrics_bundle is assumed to have the API
  metric_fn(logits, targets, weights) and to return an array of shape
  [batch_size]. The aggregate metric is obtained by summing over all batches
  and dividing by the total number of samples seen (the 'denominator'
  metric). Only metrics of the form 1/N sum f(inputs, targets) are supported.

  Note, the caller is responsible for dividing by metrics['denominator'] when
  computing the mean of each metric.

  Args:
    flax_module: A flax.nn.Module
    batch_stats: A flax.nn.Collection object tracking batch_stats.
    batch: A dictionary with keys 'inputs', 'targets', 'weights'.
    metrics_bundle: A group of metrics to use for evaluation.
    apply_one_hot_in_loss: Whether the targets must be one-hot encoded before
      being passed to the metrics.

  Returns:
    A dictionary with the same keys as metrics, but mapping to the summed
    metric across the sharded batch_dim.
  """
  with nn.stateful(batch_stats, mutable=False):
    logits = flax_module(batch['inputs'], train=False)

  targets = batch['targets']
  if apply_one_hot_in_loss:
    targets = one_hot(targets, logits.shape[-1])

  # Weights might not be defined; fall back to one weight per example.
  weights = batch.get('weights')
  if weights is None:
    weights = jnp.ones(targets.shape[0])

  # The psum is required to evaluate correctly with multihost: only host 0
  # reports metrics, so values are aggregated across all hosts. psum maps an
  # array of shape [n_global_devices, batch_size] -> [batch_size] by summing
  # over devices; the outer jnp.sum then sums over the batch dim, so the
  # result is summed over every sample in the sharded batch.
  return {
      key: jnp.sum(lax.psum(metrics_bundle[key](logits, targets, weights),
                            axis_name='batch'))
      for key in metrics_bundle
  }
def create_model_optimizer(n_bins):
  """Build a ResNet-50 (stages [3, 4, 6, 3]) and an Adam optimizer over it.

  The init input has shape `(1, training_data.shape[1], 1)`.

  NOTE(review): the state collection produced during init is not returned;
  if the network uses stateful layers (e.g. batch norm), callers cannot
  recover that state from here. Confirm this is intended.

  Returns:
    `(model, optimizer)`.
  """
  resnet50 = ResNet.partial(stage_sizes=[3, 4, 6, 3], block_cls=ResNetBlock)
  module = resnet50.partial(n_bins=n_bins, dtype=jnp.float32)
  shape_spec = [((1, training_data.shape[1], 1), jnp.float32)]
  with nn.stateful() as unused_init_state:
    _, initial_params = module.init_by_shape(jax.random.PRNGKey(0), shape_spec)
  model = nn.Model(module, initial_params)
  optimizer = optim.Adam(learning_rate=learning_rate).create(model)
  return model, optimizer
def loss_fn(model):
  """Loss function used for training.

  Cross entropy plus an L2 penalty (coefficient 0.5 * 1e-4) on every
  parameter with more than one dimension, scaled by `FLAGS.loss_scaling`.

  Returns:
    `(scaled_loss, (new_model_state, logits))`.
  """
  with nn.stateful(state.model_state) as new_model_state:
    logits = model(batch['image'])

  weight_decay = 0.0001
  # 1-D parameters are excluded from the penalty (presumably biases and
  # normalization scales/offsets).
  l2_sum = sum(jnp.sum(p**2)
               for p in jax.tree_leaves(model.params) if p.ndim > 1)
  total = (cross_entropy_loss(logits, batch['label'])
           + weight_decay * 0.5 * l2_sum)
  return total * FLAGS.loss_scaling, (new_model_state, logits)
def test_batch_norm(self):
  """Test virtual BN recovers BN when the virtual size equals batch size."""
  rng = jax.random.PRNGKey(0)
  batch_size, feature_size = 10, 7
  input_shape = (batch_size, 3, 3, feature_size)
  half_shape = (batch_size // 2, 3, 3, feature_size)
  # First half of the batch is all 2s, second half all 9s.
  x = jnp.concatenate((2.0 * jnp.ones(half_shape), 9.0 * jnp.ones(half_shape)))

  bn_model, bn_state = _init(
      nn.BatchNorm.partial(momentum=0.9), rng, input_shape)
  vbn_model, vbn_state = _init(
      normalization.VirtualBatchNorm.partial(
          momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC'),
      rng, input_shape)

  # Two training steps for each layer, threading the state through.
  for _ in range(2):
    with nn.stateful(bn_state) as bn_state:
      bn_y = bn_model(x)
  for _ in range(2):
    with nn.stateful(vbn_state) as vbn_state:
      vbn_y = vbn_model(x)

  # Test that the layer forward passes are the same.
  np.testing.assert_allclose(bn_y, vbn_y, atol=1e-4)
  # Test that virtual and regular BN produce the same EMAs.
  np.testing.assert_allclose(
      bn_state['/']['mean'],
      np.squeeze(vbn_state['/']['batch_norm_running_mean'], 0),
      atol=1e-4)
  np.testing.assert_allclose(
      bn_state['/']['var'],
      np.squeeze(vbn_state['/']['batch_norm_running_var'], 0),
      atol=1e-4)
def test_module_state(self):
  """Module state works with explicit collections and with nn.stateful."""

  class StatefulModule(nn.Module):

    def apply(self, x, coll=None):
      state = self.state('state', x.shape, nn.initializers.zeros,
                         collection=coll)
      state.value += x

  ones = jnp.array([
      1.,
  ])

  # Without any collection, using state must raise.
  with self.assertRaises(ValueError):
    StatefulModule.call({}, ones)

  # Explicitly passed collection.
  with nn.Collection().mutate() as explicit:
    self.assertEqual(explicit.as_dict(), {})
    StatefulModule.init(random.PRNGKey(0), ones, explicit)
    self.assertEqual(explicit.as_dict(), {'/': {'state': ones}})
  self.assertEqual(explicit.as_dict(), {'/': {'state': ones}})
  with explicit.mutate() as explicit_clone:
    # The mutable copy starts as a clone of the original.
    self.assertEqual(explicit_clone.as_dict(), explicit.as_dict())
    StatefulModule.call({}, ones, explicit_clone)
    self.assertEqual(explicit_clone.as_dict(), {'/': {'state': ones + ones}})

  # Implicit collection via nn.stateful.
  with nn.stateful() as implicit:
    self.assertEqual(implicit.as_dict(), {})
    StatefulModule.init(random.PRNGKey(0), ones)
    self.assertEqual(implicit.as_dict(), {'/': {'state': ones}})
  with nn.stateful(implicit) as implicit_clone:
    # The new collection starts as a clone of the given one.
    self.assertEqual(implicit_clone.as_dict(), implicit.as_dict())
    StatefulModule.call({}, ones)
    self.assertEqual(implicit_clone.as_dict(), {'/': {'state': ones + ones}})
  self.assertEqual(implicit_clone.as_dict(), {'/': {'state': ones + ones}})
def test_collection_serialization(self):
  """Collections holding dataclass state round-trip through state dicts."""

  @struct.dataclass
  class DummyDataClass:
    x: float

    @classmethod
    def initializer(cls, key, shape):
      del shape, key
      return cls(x=0.)

  class StatefulModule(nn.Module):

    def apply(self):
      counter = self.state('state', (), DummyDataClass.initializer)
      counter.value = counter.value.replace(x=counter.value.x + 1.)

  # Init increments the freshly created state from 0 to 1.
  with nn.stateful() as initial:
    self.assertEqual(initial.as_dict(), {})
    StatefulModule.init(random.PRNGKey(0))
  self.assertEqual(initial.as_dict(),
                   {'/': {'state': DummyDataClass(x=1.)}})

  # A call on a copy increments it to 2.
  with nn.stateful(initial) as stepped:
    StatefulModule.call({})
  self.assertEqual(stepped.as_dict(),
                   {'/': {'state': DummyDataClass(x=2.)}})

  as_state_dict = serialization.to_state_dict(stepped)
  self.assertEqual(as_state_dict, {'/': {'state': {'x': 2.}}})

  restored = serialization.from_state_dict(initial, as_state_dict)
  # Restoring into `initial` as a template must not mutate it.
  self.assertEqual(initial.as_dict(),
                   {'/': {'state': DummyDataClass(x=1.)}})
  self.assertEqual(stepped.as_dict(), restored.as_dict())
def test_autoencoder_model(self, model_str):
  """Test forward pass of the autoencoder models."""
  model_cls = models.get_model(model_str)
  hps = copy.copy(models.get_model_hparams(model_str))
  hps.update({'output_shape': OUTPUT_SHAPE[model_str]})
  model = model_cls(hps, {}, 'sigmoid_binary_cross_entropy',
                    'binary_autoencoder_metrics')
  xs = jnp.array(np.random.normal(size=INPUT_SHAPE[model_str]))
  expected_shape = (INPUT_SHAPE[model_str][0], *OUTPUT_SHAPE[model_str])

  _, params_rng = jax.random.split(jax.random.PRNGKey(0))
  with nn.stateful() as batch_stats:
    with nn.stochastic(params_rng):
      _, flax_module = model.flax_module_def.create(params_rng, xs)

  # The training-mode forward pass works and may mutate batch_stats.
  with nn.stateful(batch_stats) as new_batch_stats:
    with nn.stochastic(params_rng):
      outputs = flax_module(xs)
  self.assertEqual(outputs.shape, expected_shape)

  # For models that carry batch stats, the training pass must change them.
  if batch_stats.as_dict():
    before, _ = ravel_pytree(batch_stats)
    after, _ = ravel_pytree(new_batch_stats)
    self.assertFalse(jnp.array_equal(before, after))

  # Inference mode with read-only batch stats.
  with nn.stateful(batch_stats, mutable=False):
    outputs = flax_module(xs, train=False)
  self.assertEqual(outputs.shape, expected_shape)
def create_model(key, batch_size, image_size, model_dtype):
  """Model creation.

  Builds a (Fake)ResNet with `FLAGS.resnet_layers` layers, using
  cross-replica batch norm over the 'batch' axis when
  `FLAGS.distributed_batchnorm` is set.

  Returns:
    `(model, init_state)`.
  """
  resnet_cls = models.FakeResNet if FLAGS.fake_model else models.ResNet
  bn_axis = 'batch' if FLAGS.distributed_batchnorm else None
  model_def = resnet_cls.partial(
      num_classes=1000,
      axis_name=bn_axis,
      parameters={
          'dtype': model_dtype,
          'conv0_space_to_depth': False
      },
      num_layers=FLAGS.resnet_layers)

  shape_spec = [((batch_size, image_size, image_size, 3), model_dtype)]
  with nn.stateful() as init_state:
    _, model = model_def.create_by_shape(key, shape_spec)
  return model, init_state
def _predict_batch(flax_module, batch_stats, batch, output_activation_fn=None):
  """Compute predictions for a batch of data.

  NOTE: We assume that batch_stats has been sync'ed.

  Args:
    flax_module: A flax.nn.Module
    batch_stats: A flax.nn.Collection object tracking batch_stats.
    batch: A dictionary with key 'inputs'.
    output_activation_fn: Optional function applied to the logits, e.g. an
      activation from jax.nn.

  Returns:
    The logits of shape [batch_size, num_classes], or
    `output_activation_fn(logits)` when an activation function is given.
  """
  with nn.stateful(batch_stats, mutable=False):
    logits = flax_module(batch['inputs'], train=False)
  return output_activation_fn(logits) if output_activation_fn else logits
def loss_fn(model):
  """Loss function used for training.

  Sums a weighted cross-entropy term (when `classification_targets` is given)
  and a weighted MSE term (when `regression_targets` is given). Each term is
  normalized by its weight sum, clamped below by `epsilon`.

  Returns:
    `(mean_loss, (outputs, batch_stats))`. When a regression loss is
    computed, it is also stored in `outputs['regression_loss']`.
  """
  # Stateful collection for tracking internal state like activations.
  with nn.stateful() as batch_stats:
    with nn.stochastic(dropout_rng):
      outputs = model(inputs, train=True, cache=None)

  if isinstance(outputs, dict):
    logits = outputs.get('logits', None)
    regression_predictions = outputs.get('regression', None)
  else:
    logits, regression_predictions = outputs, None

  mean_loss = 0.0

  if classification_targets is not None:
    xent_sum, xent_weight = utils.compute_weighted_cross_entropy(
        logits,
        classification_targets,
        token_weights=classification_weights,
        example_weights=example_weights)
    # The epsilon clamp handles the case where nothing is masked out in BERT
    # (only occurs with very short sequences).
    mean_loss += xent_sum / jnp.maximum(xent_weight, epsilon)

  if regression_targets is not None:
    mse_sum, mse_weight = utils.compute_weighted_mse(
        regression_predictions, regression_targets, weights=regression_weights)
    mean_regression_loss = mse_sum / jnp.maximum(mse_weight, epsilon)
    outputs['regression_loss'] = mean_regression_loss
    # TODO(ddohan): Allow weighting each loss separately.
    mean_loss += mean_regression_loss

  return mean_loss, (outputs, batch_stats)