Code example #1
File: losses.py  Project: zrqohbug/google-research
    def loss_fn(model):
        if train:
            with nn.stateful(state.model_state) as new_model_state:
                with nn.stochastic(run_rng):
                    if not class_conditional:
                        scores = model(perturbed_data, labels, train=train)
                    else:
                        scores = model(perturbed_data,
                                       labels,
                                       y=class_labels,
                                       train=train)
        else:
            with nn.stateful(state.model_state, mutable=False):
                with nn.stochastic(run_rng):
                    if not class_conditional:
                        scores = model(perturbed_data, labels, train=train)
                    else:
                        scores = model(perturbed_data,
                                       labels,
                                       y=class_labels,
                                       train=train)

            new_model_state = state.model_state

        scores = scores.reshape((scores.shape[0], -1))
        target = -1 / (used_sigmas**2) * noise
        target = target.reshape((target.shape[0], -1))
        losses = 1 / 2. * ((scores - target)**2).sum(
            axis=-1) * used_sigmas.squeeze()**anneal_power
        loss = jnp.mean(losses)

        if loss_per_sigma:
            return loss, new_model_state, losses
        else:
            return loss, new_model_state
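
The loss_fn above returns a (loss, new_model_state) pair (plus the per-sigma
losses when loss_per_sigma is set), which is the shape jax.value_and_grad
expects with has_aux=True. Below is a minimal sketch of how such a loss_fn is
typically consumed inside a pmapped train step; the optimizer handling is an
assumption for illustration and is not part of the snippet:

    # Hypothetical caller: `state.optimizer` is assumed to be a flax.optim
    # Optimizer wrapping the model, and loss_per_sigma=False so loss_fn
    # returns a two-element (loss, aux) pair.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, new_model_state), grad = grad_fn(state.optimizer.target)
    grad = jax.lax.pmean(grad, axis_name='batch')          # average over devices
    new_optimizer = state.optimizer.apply_gradient(grad)   # take one update step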
Code example #2
File: losses.py  Project: zrqohbug/google-research
    def loss_fn(model):
        if train:
            with nn.stateful(state.model_state) as new_model_state:
                with nn.stochastic(run_rng):
                    scores = model(perturbed_data, T, train=train)
        else:
            with nn.stateful(state.model_state, mutable=False):
                with nn.stochastic(run_rng):
                    scores = model(perturbed_data, T, train=train)

            new_model_state = state.model_state

        scores = scores.reshape((scores.shape[0], -1))
        target = noise.reshape((noise.shape[0], -1))
        loss = jnp.mean((scores - target)**2)
        return loss, new_model_state
Code example #3
def create_model(key, batch_size, image_size, model_dtype, space_to_depth):
    """Initialize a ResNet-50 model."""
    if space_to_depth:
        input_shape = (batch_size, image_size // 2, image_size // 2, 3 * 2 * 2)
    else:
        input_shape = (batch_size, image_size, image_size, 3)
    model_type = models.FakeResNet if FLAGS.fake_model else models.ResNet
    batchnorm_span = FLAGS.batchnorm_span
    if batchnorm_span is None:
        batchnorm_span = max(batch_size, 64)
    if FLAGS.distributed_batchnorm and (batch_size < batchnorm_span <=
                                        batch_size * jax.device_count()):
        mllogger.event('model_bn_span', batchnorm_span)
        model_def = model_type.partial(num_classes=1000,
                                       axis_name='batch',
                                       axis_index_groups=local_replica_groups(
                                           batchnorm_span // batch_size),
                                       dtype=model_dtype,
                                       conv0_space_to_depth=space_to_depth)
    else:
        mllogger.event('model_bn_span', batch_size)
        model_def = model_type.partial(num_classes=1000,
                                       dtype=model_dtype,
                                       conv0_space_to_depth=space_to_depth)
    with nn.stateful() as init_state:
        _, params = model_def.init_by_shape(key, [(input_shape, model_dtype)])
    model = nn.Model(model_def, params)
    return model, init_state
Code example #4
    def evaluate_batch(self, flax_module, batch_stats, batch):
        """Evaluates cross_entopy on the given batch."""

        # TODO(ankugarg): Augment with other metrics like log-perplexity.
        with nn.stateful(batch_stats, mutable=False):
            logits = flax_module(batch['inputs'],
                                 batch['targets'],
                                 batch.get('inputs_positions'),
                                 batch.get('targets_positions'),
                                 batch.get('inputs_segmentation'),
                                 batch.get('targets_segmentation'),
                                 train=False)

        weights = batch.get('weights')
        targets = batch['targets']
        if self.dataset_meta_data['apply_one_hot_in_loss']:
            targets = one_hot(batch['targets'], logits.shape[-1])

        # Compute each metric in the bundle and aggregate across devices.
        evaluated_metrics = {}
        for key in self.metrics_bundle:
            per_example_metrics = self.metrics_bundle[key](logits, targets,
                                                           weights)
            evaluated_metrics[key] = jnp.sum(
                lax.psum(per_example_metrics, axis_name='batch'))

        return evaluated_metrics
Code example #5
    def training_cost(self, flax_module, batch_stats, batch, dropout_rng):
        """Return cross entropy loss with (optional) L2 penalty on the weights."""

        with nn.stateful(batch_stats) as new_batch_stats:
            with nn.stochastic(dropout_rng):
                # inputs/targets positions and segmentations are required
                # when we have packed examples.
                logits = flax_module(batch['inputs'],
                                     batch['targets'],
                                     batch.get('inputs_positions'),
                                     batch.get('targets_positions'),
                                     batch.get('inputs_segmentation'),
                                     batch.get('targets_segmentation'),
                                     train=True)

        weights = batch.get('weights')
        targets = batch['targets']

        if self.dataset_meta_data['apply_one_hot_in_loss']:
            targets = one_hot(batch['targets'], logits.shape[-1])
        # Optionally apply label smoothing.
        if self.hps.get('label_smoothing') is not None:
            targets = model_utils.apply_label_smoothing(
                targets, self.hps.get('label_smoothing'))
        total_loss = self.loss_fn(logits, targets, weights)

        if self.hps.get('l2_decay_factor'):
            l2_loss = model_utils.l2_regularization(
                flax_module.params, self.hps.l2_decay_rank_threshold)
            total_loss += 0.5 * self.hps.l2_decay_factor * l2_loss
        return total_loss, new_batch_stats
Code example #6
 def loss_fn(model):
     with nn.stateful(state.model_state) as new_model_state:
         rays = batch["rays"]
         ret = model(key_0, key_1, rays.origins, rays.directions,
                     rays.viewdirs)
     if len(ret) not in (1, 2):
         raise ValueError(
             "ret should contain either 1 set of output (coarse only), or 2 sets"
             "of output (coarse as ret[0] and fine as ret[1]).")
     # The main prediction is always at the end of the ret list.
     rgb, unused_disp, unused_acc = ret[-1]
     loss = ((rgb - batch["pixels"][Ellipsis, :3])**2).mean()
     psnr = utils.compute_psnr(loss)
     stats = [utils.Stats(loss=loss, psnr=psnr)]
     if len(ret) > 1:
         # If there are both coarse and fine predictions, we compute the loss for
         # the coarse prediction (ret[0]) as well.
         rgb_c, unused_disp_c, unused_acc_c = ret[0]
         loss_c = ((rgb_c - batch["pixels"][Ellipsis, :3])**2).mean()
         psnr_c = utils.compute_psnr(loss_c)
         stats.append(utils.Stats(loss=loss_c, psnr=psnr_c))
     else:
         loss_c = 0.
         psnr_c = 0.
     return loss + loss_c, (new_model_state, stats)
Code example #7
 def impl_loss_fn(model_params):
     with nn.stochastic(rng), nn.stateful(
             state.model_state) as new_model_state:
         logits, stats = module.call(model_params, batch["image"])
     losses = loss_fn if isinstance(loss_fn, (list, tuple)) else [loss_fn]
     loss = sum(l(logits, batch["label"], stats) for l in losses)
     return loss, (logits, new_model_state, stats)
Code example #8
def initialize(flax_module_def, initializer, loss_fn, input_shape,
               output_shape, hps, rng, metrics_logger):
    """Run the given initializer.

  We initialize in 3 phases. First we run the default initializer that is
  specified by the model constructor. Next we apply any rescaling as specified
  by hps.layer_rescale_factors. Finally we run the black box initializer
  provided by the initializer arg (the default is noop).

  Args:
    flax_module_def: An uninitialized flax module definition.
    initializer: An initializer defined in init_lib.
    loss_fn: A loss function.
    input_shape: The input shape of a single data example.
    output_shape: The output shape of a single data example.
    hps: A dictionary specifying the model and initializer hparams.
    rng: An rng key to seed the initialization.
    metrics_logger: Used for black box initializers that have learning curves.

  Returns:
    A tuple (model, batch_stats), where model is the initialized
    flax.nn.Model and batch_stats is the collection used for batch norm.
  """
    model_dtype = utils.dtype_from_str(hps.model_dtype)
    # init_by_shape should either pass in a tuple or a list of tuples.
    # For example, for vision tasks typically input_shape is (image_shape)
    # For seq2seq tasks, shape can be a list of two tuples corresponding to
    # input_sequence_shape for encoder and output_sequence_shape for decoder.
    # TODO(gilmer,ankugarg): Support initializers for list of tuples.
    if isinstance(input_shape, list):  # Typical case for seq2seq models
        input_specs = [((hps.batch_size, *x), model_dtype)
                       for x in input_shape]
    else:  # Typical case for classification models
        input_specs = [((hps.batch_size, *input_shape), model_dtype)]
    params_rng, init_rng, dropout_rng = jax.random.split(rng, num=3)

    with nn.stateful() as batch_stats:
        with nn.stochastic(dropout_rng):
            # Using flax_module_def.create can OOM for larger models, so we must use
            # create by shape here.
            # TODO(gilmer) Link to flax issue when bug reporting process finalizes.
            _, params = flax_module_def.init_by_shape(params_rng,
                                                      input_specs,
                                                      train=False)
    model = nn.Model(flax_module_def, params)

    if hps.get('layer_rescale_factors'):
        model = model_utils.rescale_layers(model, hps.layer_rescale_factors)
    # We don't pass batch_stats to the initializer, the initializer will just
    # run batch_norm in train mode and does not need to maintain the batch_stats.
    # TODO(gilmer): We hardcode here weighted_cross_entropy, but this will need
    # to change for other models. Maybe have meta_loss_inner as an initializer
    # hyper_param?
    # TODO(gilmer): instead of passing in weighted_xent, pass in the model and get
    # the loss from that.
    new_model = initializer(loss_fn, model, hps, input_shape, output_shape,
                            init_rng, metrics_logger)

    return new_model, batch_stats
Code example #9
def eval_step(model, state, batch, prev_metrics, image_format, space_to_depth):
    images, labels = batch
    images = maybe_transpose_images(images, image_format)
    images = normalize_images(images, space_to_depth)
    with nn.stateful(state, mutable=False):
        logits = model(images, train=False)
    metrics = compute_metrics(logits, labels)
    return jax.tree_multimap(jnp.add, prev_metrics, metrics)
Code example #10
def create_model(key, batch_size, image_size, model_dtype):
    input_shape = (batch_size, image_size, image_size, 3)
    module = models.ResNet.partial(num_classes=1000, dtype=model_dtype)
    with nn.stateful() as init_state:
        _, initial_params = module.init_by_shape(key,
                                                 [(input_shape, model_dtype)])
        model = nn.Model(module, initial_params)
    return model, init_state
Code example #11
 def _create_model(key):
     module = flax_module.partial(**model_kwargs)
     with nn.stateful() as init_state:
         with nn.stochastic(key):
             _, initial_params = module.init_by_shape(
                 key, [(input_shape, jnp.float32)])
             model = nn.Model(module, initial_params)
     return model, init_state
Code example #12
def create_model(module, input_shape, rng):
    """Instanciates the model."""
    model_rng, init_rng = jax.random.split(rng)
    with nn.stochastic(model_rng), nn.stateful() as init_state:
        x = jnp.ones(input_shape, dtype=jnp.float32)
        _, init_params = module.init(init_rng, x)
    model = nn.Model(module, init_params)
    return model, init_params, init_state
Code example #13
def eval_step(model, state, batch, prev_metrics):
    images, labels = batch['image'], batch['label']
    if FLAGS.transpose_images:
        images = jnp.transpose(images, [3, 0, 1, 2])
    images = normalize_images(images)
    with nn.stateful(state, mutable=False):
        logits = model(images, train=False)
    metrics = compute_metrics(logits, labels)
    return jax.tree_multimap(jnp.add, prev_metrics, metrics)
Code example #14
 def loss_fn(model):
     """Loss function used for training."""
     with nn.stateful(state) as new_state:
         with nn.stochastic(dropout_rng):
             logits = model(inputs, train=True)
     loss, weight_sum = train_utils.compute_weighted_cross_entropy(
         logits, targets, num_classes=num_classes, weights=None)
     mean_loss = loss / weight_sum
     return mean_loss, (new_state, logits)
Code example #15
    def test_different_eval_batch_size(self):
        """Test virtual BN can use a different batch size for evals."""
        rng = jax.random.PRNGKey(0)
        batch_size = 10
        feature_size = 7
        input_shape = (batch_size, 3, 3, feature_size)
        x = 2.0 * jnp.ones(input_shape)

        vbn_model_cls = normalization.VirtualBatchNorm.partial(
            momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC')
        vbn_model, vbn_state = _init(vbn_model_cls, rng, input_shape)

        with nn.stateful(vbn_state) as vbn_state:
            vbn_model(x)

        with nn.stateful(vbn_state) as vbn_state:
            vbn_model(jnp.ones((13, 3, 3, feature_size)),
                      use_running_average=True)
Code example #16
def render_image(state, rays, render_fn, rng, normalize_disp, chunk=8192):
    """Render all the pixels of an image (in test mode).

  Args:
    state: model_utils.TrainState.
    rays: a `Rays` namedtuple, the rays to be rendered.
    render_fn: function, jit-ed render function.
    rng: jnp.ndarray, random number generator (used in training mode only).
    normalize_disp: bool, if true then normalize `disp` to [0, 1].
    chunk: int, the size of chunks to render sequentially.

  Returns:
    rgb: jnp.ndarray, rendered color image.
    disp: jnp.ndarray, rendered disparity image.
    acc: jnp.ndarray, rendered accumulated weights per pixel.
  """
    height, width = rays[0].shape[:2]
    num_rays = height * width
    rays = datasets.ray_fn(lambda r: r.reshape((num_rays, -1)), rays)

    unused_rng, key_0, key_1 = jax.random.split(rng, 3)
    model = state.optimizer.target
    model_state = state.model_state
    host_id = jax.host_id()
    results = []
    with nn.stateful(model_state, mutable=False):
        for i in range(0, num_rays, chunk):
            # pylint: disable=cell-var-from-loop
            print("  " + "X" * int((i / num_rays) * 78), end="\r")
            chunk_rays = datasets.ray_fn(lambda r: r[i:i + chunk], rays)
            chunk_size = chunk_rays[0].shape[0]
            rays_remaining = chunk_size % jax.device_count()
            rays_per_host = chunk_size // jax.host_count()
            if rays_remaining != 0:
                padding = jax.device_count() - rays_remaining
                chunk_rays = datasets.ray_fn(
                    lambda r: jnp.pad(r, ((0, padding), (0, 0)), mode="edge"),
                    chunk_rays)
            else:
                padding = 0
            # After padding the number of chunk_rays is always divisible by
            # host_count.
            start, stop = host_id * rays_per_host, (host_id +
                                                    1) * rays_per_host
            chunk_rays = datasets.ray_fn(lambda r: shard(r[start:stop]),
                                         chunk_rays)
            chunk_results = render_fn(key_0, key_1, model, chunk_rays)[-1]
            results.append([unshard(x[0], padding) for x in chunk_results])
            # pylint: enable=cell-var-from-loop
        print("")
    rgb, disp, acc = [jnp.concatenate(r, axis=0) for r in zip(*results)]
    # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
    if normalize_disp:
        disp = (disp - disp.min()) / (disp.max() - disp.min())
    return (rgb.reshape((height, width, -1)), disp.reshape(
        (height, width, -1)), acc.reshape((height, width, -1)))
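
render_image relies on shard/unshard helpers from the project's utils to split
the padded chunk of rays across local devices and to strip the padding again
afterwards. Here is a hedged sketch of what such helpers typically look like;
the actual utilities in the project may differ:

    def shard(xs):
        """Reshape [n, ...] arrays to [local_device_count, n // local_device_count, ...]."""
        return jax.tree_map(
            lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)

    def unshard(x, padding=0):
        """Collapse the device axis back into the batch axis and drop padded rows."""
        y = x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])
        return y[:y.shape[0] - padding] if padding > 0 else y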
Code example #17
def eval_step(model, state, batch, num_classes, flatten_input=True):
    eval_keys = ['inputs', 'targets']
    (inputs, targets) = [batch.get(k, None) for k in eval_keys]
    if flatten_input:
        inputs = inputs.reshape(inputs.shape[0], -1)
    if jax.tree_leaves(state):
        state = jax.lax.pmean(state, 'batch')
    with nn.stateful(state, mutable=False):
        logits = model(inputs, train=False)
    return compute_metrics(logits, targets, num_classes, weights=None)
Code example #18
File: train_test.py  Project: vballoli/flax
 def test_create_model(self):
     model, state = train.create_model(random.PRNGKey(0), 8, 224,
                                       jnp.float32)
     x = random.normal(random.PRNGKey(1), (8, 224, 224, 3))
     with nn.stateful(state) as new_state:
         y = model(x)
     state = jax.tree_map(onp.shape, state.as_dict())
     new_state = jax.tree_map(onp.shape, new_state.as_dict())
     self.assertEqual(state, new_state)
     self.assertEqual(y.shape, (8, 1000))
Code example #19
File: nn_test.py  Project: zhang-yd15/flax
 def test_batch_norm(self):
   rng = random.PRNGKey(0)
   key1, key2 = random.split(rng)
   x = random.normal(key1, (4, 3, 2))
   model_cls = nn.BatchNorm.partial(momentum=0.9)
   with nn.stateful() as state_0:
     y, initial_params = model_cls.init(key2, x)
     model = nn.Model(model_cls, initial_params)
   mean = y.mean((0, 1))
   var = y.var((0, 1))
   onp.testing.assert_allclose(mean, onp.array([0., 0.]), atol=1e-4)
   onp.testing.assert_allclose(var, onp.array([1., 1.]), rtol=1e-4)
   with nn.stateful(state_0) as state:
     y = model(x)
   ema = state['/']
   onp.testing.assert_allclose(
       ema['mean'], 0.1 * x.mean((0, 1), keepdims=False), atol=1e-4)
   onp.testing.assert_allclose(
       ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False), rtol=1e-4)
Code example #20
def render_image(state, data, render_fn, rng, chunk=8192):
    """Render all the pixels of an image (in test mode).

  Args:
    state: model_utils.TrainState.
    data: dict, test example.
    render_fn: function, jit-ed render function.
    rng: jnp.ndarray, random number generator (used in training mode only).
    chunk: int, the size of chunks to render sequentially.

  Returns:
    rgb: jnp.ndarray, rendered color image.
    disp: jnp.ndarray, rendered disparity image.
    acc: jnp.ndarray, rendered accumulated weights per pixel.
  """
    rays = data["rays"]
    h, w = rays.shape[:2]
    rays = rays.reshape((h * w, -1))
    unused_rng, key_0, key_1 = jax.random.split(rng, 3)
    model = state.optimizer.target
    model_state = state.model_state
    host_id = jax.host_id()
    rgb = []
    disp = []
    acc = []
    with nn.stateful(model_state, mutable=False):
        for i in range(0, rays.shape[0], chunk):
            print("  " + "X" * int((i / rays.shape[0]) * 78), end="\r")
            chunk_rays = rays[i:i + chunk]
            remainder = chunk_rays.shape[0] % jax.device_count()
            if remainder != 0:
                padding = jax.device_count() - remainder
                chunk_rays = jnp.pad(chunk_rays, ((0, padding), (0, 0)),
                                     mode="edge")
            else:
                padding = 0
            # After padding the number of chunk_rays is always divisible by
            # host_count.
            per_host_rays = chunk_rays.shape[0] // jax.host_count()
            chunk_rays = chunk_rays[(host_id * per_host_rays):((host_id + 1) *
                                                               per_host_rays)]
            chunk_rays = shard(chunk_rays)
            ret = render_fn(key_0, key_1, model, chunk_rays)
            rgb.append(unshard(ret[-1][0][0], padding))
            disp.append(unshard(ret[-1][1][0], padding))
            acc.append(unshard(ret[-1][2][0], padding))
        print("")
    rgb = jnp.concatenate(rgb, axis=0)
    disp = jnp.concatenate(disp, axis=0)
    # Normalize disp for visualization for ndc_rays in llff front-facing scenes.
    if rays.shape[-1] > 6:
        disp = (disp - disp.min()) / (disp.max() - disp.min())
    acc = jnp.concatenate(acc, axis=0)
    return (rgb.reshape((h, w, -1)), disp.reshape(
        (h, w, -1)), acc.reshape((h, w, -1)))
Code example #21
File: base_model.py  Project: cshallue/init2winit
def _evaluate_batch(flax_module, batch_stats, batch, metrics_bundle,
                    apply_one_hot_in_loss):
    """Evaluates metrics on the given batch.

  Currently we assume each metric_fn in metrics_bundle has the API:
    metric_fn(logits, targets, weights)
  and returns an array of shape [batch_size]. We also assume that to compute
  the aggregate metric, one should sum across all batches, then divide by the
  total samples seen (calculated by the 'denominator' metric). In this way we
  currently only support metrics of the form (1/N) * sum_i f(inputs_i, targets_i).
  Note that the caller is responsible for dividing by metrics['denominator']
  when computing the mean of each metric.

  Args:
    flax_module: A flax.nn.Module
    batch_stats: A flax.nn.Collection object tracking batch_stats.
    batch: A dictionary with keys 'inputs', 'targets', 'weights'.
    metrics_bundle: A group of metrics to use for evaluation.
    apply_one_hot_in_loss: Indicates whether or not the targets are one hot
      encoded.

  Returns:
    A dictionary with the same keys as metrics, but mapping to the summed metric
    across the sharded batch_dim.

  """
    with nn.stateful(batch_stats, mutable=False):
        logits = flax_module(batch['inputs'], train=False)
    targets = batch['targets']

    if apply_one_hot_in_loss:
        targets = one_hot(batch['targets'], logits.shape[-1])

    # Map the dict values (which are metric functions) over (logits, targets, weights).
    weights = batch.get('weights')  # Weights might not be defined.
    eval_batch_size = targets.shape[0]
    if weights is None:
        weights = jnp.ones(eval_batch_size)

    # This psum is required to correctly evaluate with multihost. Only host 0
    # will report the metrics, so we must aggregate across all hosts. The psum
    # will map an array of shape [n_global_devices, batch_size] -> [batch_size]
    # by summing across the devices dimension. The outer sum then sums across the
    # batch dim. The result is that we have summed across all samples in the
    # sharded batch.

    evaluated_metrics = {}
    for key in metrics_bundle:
        per_example_metrics = metrics_bundle[key](logits, targets, weights)
        evaluated_metrics[key] = jnp.sum(
            lax.psum(per_example_metrics, axis_name='batch'))

    return evaluated_metrics
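
The docstring notes that this function only sums the metrics; the caller is
expected to divide by the summed 'denominator' entry to recover means. A small
illustrative sketch of that caller-side normalization (the variable names and
the presence of a 'denominator' metric in the bundle are assumptions):

    # `summed` is a dict as returned by _evaluate_batch, accumulated by
    # elementwise addition over all evaluation batches.
    denominator = summed['denominator']
    mean_metrics = {k: v / denominator
                    for k, v in summed.items() if k != 'denominator'}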
Code example #22
 def create_model_optimizer(n_bins):
     ResNet50 = ResNet.partial(stage_sizes=[3, 4, 6, 3],
                               block_cls=ResNetBlock)
     module = ResNet50.partial(n_bins=n_bins, dtype=jnp.float32)
     input_shape = (1, training_data.shape[1], 1)
     with nn.stateful() as init_state:
         _, initial_params = module.init_by_shape(
             jax.random.PRNGKey(0), [(input_shape, jnp.float32)])
         model = nn.Model(module, initial_params)
     optimizer = optim.Adam(learning_rate=learning_rate).create(model)
     return model, optimizer
Code example #23
 def loss_fn(model):
     """loss function used for training."""
     with nn.stateful(state.model_state) as new_model_state:
         logits = model(batch['image'])
     loss = cross_entropy_loss(logits, batch['label'])
     weight_penalty_params = jax.tree_leaves(model.params)
     weight_decay = 0.0001
     weight_l2 = sum(
         [jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1])
     weight_penalty = weight_decay * 0.5 * weight_l2
     loss = loss + weight_penalty
     return loss * FLAGS.loss_scaling, (new_model_state, logits)
Code example #24
    def test_batch_norm(self):
        """Test virtual BN recovers BN when the virtual size equals batch size."""
        rng = jax.random.PRNGKey(0)
        batch_size = 10
        feature_size = 7
        input_shape = (batch_size, 3, 3, feature_size)
        half_input_shape = (batch_size // 2, 3, 3, feature_size)
        twos = 2.0 * jnp.ones(half_input_shape)
        nines = 9.0 * jnp.ones(half_input_shape)
        x = jnp.concatenate((twos, nines))

        bn_model_cls = nn.BatchNorm.partial(momentum=0.9)
        bn_model, bn_state = _init(bn_model_cls, rng, input_shape)

        vbn_model_cls = normalization.VirtualBatchNorm.partial(
            momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC')
        vbn_model, vbn_state = _init(vbn_model_cls, rng, input_shape)

        with nn.stateful(bn_state) as bn_state:
            bn_y = bn_model(x)
        with nn.stateful(bn_state) as bn_state:
            bn_y = bn_model(x)

        with nn.stateful(vbn_state) as vbn_state:
            vbn_y = vbn_model(x)
        with nn.stateful(vbn_state) as vbn_state:
            vbn_y = vbn_model(x)

        # Test that the layer forward passes are the same.
        np.testing.assert_allclose(bn_y, vbn_y, atol=1e-4)

        # Test that virtual and regular BN produce the same EMAs.
        np.testing.assert_allclose(
            bn_state['/']['mean'],
            np.squeeze(vbn_state['/']['batch_norm_running_mean'], 0),
            atol=1e-4)
        np.testing.assert_allclose(
            bn_state['/']['var'],
            np.squeeze(vbn_state['/']['batch_norm_running_var'], 0),
            atol=1e-4)
Code example #25
File: nn_test.py  Project: wdevazelhes/flax
    def test_module_state(self):
        class StatefulModule(nn.Module):
            def apply(self, x, coll=None):
                state = self.state('state',
                                   x.shape,
                                   nn.initializers.zeros,
                                   collection=coll)
                state.value += x

        x = jnp.array([
            1.,
        ])
        # no collection should raise an error
        with self.assertRaises(ValueError):
            StatefulModule.call({}, x)

        # pass collection explicitly
        with nn.Collection().mutate() as state:
            self.assertEqual(state.as_dict(), {})
            StatefulModule.init(random.PRNGKey(0), x, state)
            self.assertEqual(state.as_dict(), {'/': {'state': x}})
        self.assertEqual(state.as_dict(), {'/': {'state': x}})
        with state.mutate() as new_state:
            # assert new_state is a clone of state
            self.assertEqual(new_state.as_dict(), state.as_dict())
            StatefulModule.call({}, x, new_state)
        self.assertEqual(new_state.as_dict(), {'/': {'state': x + x}})

        # use stateful
        with nn.stateful() as state:
            self.assertEqual(state.as_dict(), {})
            StatefulModule.init(random.PRNGKey(0), x)
        self.assertEqual(state.as_dict(), {'/': {'state': x}})
        with nn.stateful(state) as new_state:
            # assert new_state is a clone of state
            self.assertEqual(new_state.as_dict(), state.as_dict())
            StatefulModule.call({}, x)
            self.assertEqual(new_state.as_dict(), {'/': {'state': x + x}})
        self.assertEqual(new_state.as_dict(), {'/': {'state': x + x}})
Code example #26
File: serialization_test.py  Project: shafiahmed/flax
    def test_collection_serialization(self):
        @struct.dataclass
        class DummyDataClass:
            x: float

            @classmethod
            def initializer(cls, key, shape):
                del shape, key
                return cls(x=0.)

        class StatefulModule(nn.Module):
            def apply(self):
                state = self.state('state', (), DummyDataClass.initializer)
                state.value = state.value.replace(x=state.value.x + 1.)

        # use stateful
        with nn.stateful() as state:
            self.assertEqual(state.as_dict(), {})
            StatefulModule.init(random.PRNGKey(0))
        self.assertEqual(state.as_dict(),
                         {'/': {
                             'state': DummyDataClass(x=1.)
                         }})
        with nn.stateful(state) as new_state:
            StatefulModule.call({})
        self.assertEqual(new_state.as_dict(),
                         {'/': {
                             'state': DummyDataClass(x=2.)
                         }})
        serialized_state_dict = serialization.to_state_dict(new_state)
        self.assertEqual(serialized_state_dict, {'/': {'state': {'x': 2.}}})
        deserialized_state = serialization.from_state_dict(
            state, serialized_state_dict)
        self.assertEqual(state.as_dict(),
                         {'/': {
                             'state': DummyDataClass(x=1.)
                         }})
        self.assertEqual(new_state.as_dict(), deserialized_state.as_dict())
Code example #27
File: test_models.py  Project: cshallue/init2winit
    def test_autoencoder_model(self, model_str):
        """Test forward pass of the autoencoder models."""

        model_cls = models.get_model(model_str)
        model_hps = models.get_model_hparams(model_str)
        loss = 'sigmoid_binary_cross_entropy'
        metrics = 'binary_autoencoder_metrics'
        hps = copy.copy(model_hps)
        hps.update({'output_shape': OUTPUT_SHAPE[model_str]})
        rng = jax.random.PRNGKey(0)
        model = model_cls(hps, {}, loss, metrics)
        xs = jnp.array(np.random.normal(size=INPUT_SHAPE[model_str]))
        rng, params_rng = jax.random.split(rng)
        with nn.stateful() as batch_stats:
            with nn.stochastic(params_rng):
                _, flax_module = model.flax_module_def.create(params_rng, xs)

        # Check that the forward pass works with mutated batch_stats.
        with nn.stateful(batch_stats) as new_batch_stats:
            with nn.stochastic(params_rng):
                outputs = flax_module(xs)
                self.assertEqual(
                    outputs.shape,
                    tuple([INPUT_SHAPE[model_str][0]] +
                          list(OUTPUT_SHAPE[model_str])))

        # If it's a batch norm model check the batch stats changed.
        if batch_stats.as_dict():
            bflat, _ = ravel_pytree(batch_stats)
            new_bflat, _ = ravel_pytree(new_batch_stats)
            self.assertFalse(jnp.array_equal(bflat, new_bflat))

        # Test batch_norm in inference mode.
        with nn.stateful(batch_stats, mutable=False):
            outputs = flax_module(xs, train=False)
        self.assertEqual(
            outputs.shape,
            tuple([INPUT_SHAPE[model_str][0]] + list(OUTPUT_SHAPE[model_str])))
Code example #28
def create_model(key, batch_size, image_size, model_dtype):
    """Model creation."""
    input_shape = (batch_size, image_size, image_size, 3)
    model_type = models.FakeResNet if FLAGS.fake_model else models.ResNet
    model_def = model_type.partial(
        num_classes=1000,
        axis_name='batch' if FLAGS.distributed_batchnorm else None,
        parameters={
            'dtype': model_dtype,
            'conv0_space_to_depth': False
        },
        num_layers=FLAGS.resnet_layers)
    with nn.stateful() as init_state:
        _, model = model_def.create_by_shape(key, [(input_shape, model_dtype)])
    return model, init_state
Code example #29
File: base_model.py  Project: cshallue/init2winit
def _predict_batch(flax_module, batch_stats, batch, output_activation_fn=None):
    """Compute predictions for a batch of data.

  NOTE: We assume that batch_stats has been sync'ed.

  Args:
    flax_module: A flax.nn.Module
    batch_stats: A flax.nn.Collection object tracking batch_stats.
    batch: A dictionary with keys 'inputs', 'targets', 'weights'.
    output_activation_fn: An output activation function from jax.nn.functions

  Returns:
    An array of shape [batch_size, num_classes] that contains all the logits.
  """
    with nn.stateful(batch_stats, mutable=False):
        logits = flax_module(batch['inputs'], train=False)
    if output_activation_fn:
        return output_activation_fn(logits)
    return logits
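
Since output_activation_fn is documented as a function from jax.nn, a typical
call site might pass jax.nn.softmax to turn the logits into per-class
probabilities (a usage sketch, not taken from the project):

    probs = _predict_batch(flax_module, batch_stats, batch,
                           output_activation_fn=jax.nn.softmax)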
Code example #30
    def loss_fn(model):
        """Loss function used for training."""
        # Stateful collection for tracking internal state like activations.
        with nn.stateful() as batch_stats:
            with nn.stochastic(dropout_rng):
                outputs = model(inputs, train=True, cache=None)

            if isinstance(outputs, dict):
                logits = outputs.get('logits', None)
                regression_predictions = outputs.get('regression', None)
            else:
                logits = outputs
                regression_predictions = None

        mean_loss = 0.0

        # Classification loss
        if classification_targets is not None:
            classification_loss, classification_weight_sum = utils.compute_weighted_cross_entropy(
                logits,
                classification_targets,
                token_weights=classification_weights,
                example_weights=example_weights)
            classification_weight_sum = jnp.maximum(classification_weight_sum,
                                                    epsilon)
            # Handle case where nothing is masked out in BERT
            # (Only occurs with very short sequences).
            mean_classification_loss = classification_loss / classification_weight_sum
            mean_loss += mean_classification_loss

        if regression_targets is not None:
            regression_loss, regression_weight_sum = utils.compute_weighted_mse(
                regression_predictions,
                regression_targets,
                weights=regression_weights)
            regression_weight_sum = jnp.maximum(regression_weight_sum, epsilon)
            mean_regression_loss = regression_loss / regression_weight_sum
            outputs['regression_loss'] = mean_regression_loss

            # TODO(ddohan): Allow weighting each loss separately.
            mean_loss += mean_regression_loss

        return mean_loss, (outputs, batch_stats)