コード例 #1
0
def create_model(module, input_shape, rng):
  """Instantiates `module` and returns it with its initial parameters/state.

  Parameters are produced by calling `module.init` on an all-ones float32
  dummy input of shape `input_shape`. The call runs inside an
  `nn.stochastic` context (keyed by one half of `rng`) and an `nn.stateful`
  context, whose collected state is returned as `init_state`.

  Args:
    module: Module definition providing `init(rng, x)`.
    input_shape: Shape of the dummy input used for initialization.
    rng: PRNG key; split into a key for `nn.stochastic` and a key for `init`.

  Returns:
    A tuple `(model, init_params, init_state)` where `model` is
    `nn.Model(module, init_params)`.
  """
  stochastic_key, params_key = jax.random.split(rng)
  with nn.stochastic(stochastic_key), nn.stateful() as init_state:
    dummy_batch = jnp.ones(input_shape, dtype=jnp.float32)
    _unused_output, init_params = module.init(params_key, dummy_batch)
  return nn.Model(module, init_params), init_params, init_state
コード例 #2
0
 def load(self, i=''):
     """Restores saved parameters from `<model_dir>/<name>params<i>`.

     Args:
       i: Suffix appended to the parameter file name (string).

     Returns:
       1 if no saved parameters exist at that path yet (logged as
       'still training'); 0 after `self.opt_params` has been loaded and,
       for flax models (`self.flaxd`), `self.model` has been rebuilt from it.
     """
     params_path = os.path.join(self.model_dir, self.name + 'params' + i)
     # NOTE(review): `Exists` is defined outside this snippet; presumably a
     # file-existence check (e.g. gfile exists) — confirm.
     if Exists(params_path):
         self.opt_params = serialization.load_params(params_path)
         if self.flaxd:
             # Rebuild the flax model around the freshly loaded parameters.
             self.model = nn.Model(self.model_def, self.opt_params)
         return 0
     logging.info('still training')
     return 1
コード例 #3
0
 def _create_flax_module():
     """Initializes the flax module and its mutable state on a dummy batch.

     The dummy batch has leading size `hparams.batch_size // jax.device_count()`
     (integer division, so any remainder is dropped), the trailing dims of
     `input_shape`, and dtype `model_input_dtype`.

     Returns:
       A tuple `(flax_module, init_model_state, num_trainable_params)`: the
       `nn.Model` built from the initial parameters, the state collected under
       `nn.stateful()`, and the value `model_utils.log_param_shapes` returns
       for that model.
     """
     per_device_batch = hparams.batch_size // jax.device_count()
     dummy_shape = (per_device_batch,) + tuple(input_shape[1:])
     model_rng, init_rng = jax.random.split(rng)
     with nn.stateful() as init_model_state, nn.stochastic(model_rng):
         _, initial_params = flax_module_def.init_by_shape(
             init_rng, [(dummy_shape, model_input_dtype)])
     flax_module = nn.Model(flax_module_def, initial_params)
     return (flax_module, init_model_state,
             model_utils.log_param_shapes(flax_module))
コード例 #4
0
    def setup_transformers(self, hidden_reps_dim):
        """Sets up one linear transformer per ordered environment pair.

        The keys of `self.dataset.splits.train`, cast to int, are used as
        environment ids. For every ordered pair of distinct ids (so both
        (a, b) and (b, a) are included), a separate `nn.Model` of the module
        returned by `self.get_transformer_module(hidden_reps_dim)` is
        initialized on a `(1, hidden_reps_dim)` float32 input. It is stored in
        `self.state_transformers` under that pair. Each model gets its own key
        split from `nn.make_rng()`. NOTE(review): presumably this must
        therefore run inside an `nn.stochastic` context — confirm.

        Args:
          hidden_reps_dim: int; Dimensionality of the representational space
            (size of the representations used for computing the domain
            mapping loss).
        """
        transformer_class = self.get_transformer_module(hidden_reps_dim)
        self.state_transformers = {}
        env_ids = [int(key) for key in self.dataset.splits.train.keys()]

        key = nn.make_rng()
        # All ordered pairs of distinct environments (both permutations).
        for env_pair in itertools.permutations(env_ids, 2):
            key, pair_key = jax.random.split(key)
            _, pair_params = transformer_class.init_by_shape(
                pair_key, [((1, hidden_reps_dim), jnp.float32)])
            self.state_transformers[env_pair] = nn.Model(
                transformer_class, pair_params)
コード例 #5
0
def train(config, workdir):
  """Runs the training loop for the score model.

  Creates the output directories and initializes the model, optimizer and
  `mutils.State`. Resumes from the latest "checkpoints-meta" snapshot if one
  exists, then runs the pmapped training step for every step from the resumed
  step up to and including `config.training.n_iters`. Along the way it:
    * logs the training loss every 50 steps (host 0 only);
    * evaluates one eval batch every 100 steps;
    * writes a pre-emption snapshot every
      `config.training.snapshot_freq_for_preemption` steps (host 0 only);
    * writes a checkpoint, and optionally samples, whenever `step + 1` is a
      multiple of `config.training.snapshot_freq` and on the final step.

  The loss is DDPM-style when `config.training.loss` is "ddpm"
  (case-insensitive) and NCSN-style otherwise. Substrings of the loss name
  also select a class-conditional model ("conditional") and, for NCSN,
  continuous noise levels ("continuous").

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints, samples and TensorBoard
      summaries. If it contains a checkpoint, training is resumed from the
      latest one.
  """

  # Create directories for experimental logs
  tf.io.gfile.makedirs(workdir)
  sample_dir = os.path.join(workdir, "samples")
  tf.io.gfile.makedirs(sample_dir)
  rng = jax.random.PRNGKey(config.seed)
  tb_dir = os.path.join(workdir, "tensorboard")
  tf.io.gfile.makedirs(tb_dir)
  # Only host 0 writes TensorBoard summaries. `writer` stays undefined on the
  # other hosts, so every later use of it is guarded by `jax.host_id() == 0`.
  if jax.host_id() == 0:
    writer = tensorboard.SummaryWriter(tb_dir)

  # Initialize model.
  # NOTE: the order of the `jax.random.split` calls fixes which keys model
  # init and training receive; reordering them changes results for a seed.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  # Whether the generative model is conditioned on class labels
  # NOTE(review): substring test — a loss name containing "unconditional"
  # would also enable this; confirm the loss naming scheme.
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      # Dummy inputs for shape-based init: a batch with one image per local
      # device, plus an int32 vector of the same batch length (presumably
      # noise-level indices — TODO confirm).
      input_shape = (jax.local_device_count(), config.data.image_size,
                     config.data.image_size, 3)
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        # One more int32 input of the same shape, presumably class labels.
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  optimizer = losses.get_optimizer(config).create(ncsn)

  # `params_ema` (presumably the EMA of the weights, given `ema_rate`) starts
  # from the freshly initialized parameters.
  state = mutils.State(step=0, optimizer=optimizer, lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args

  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  # Create checkpoints directory and the initial checkpoint
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  ckpt = utils.Checkpoint(
      checkpoint_dir,
      max_to_keep=None)
  # The return value is discarded: the state training resumes from is taken
  # from the "checkpoints-meta" directory below.
  ckpt.restore_or_initialize(state)

  # Save intermediate checkpoints to resume training automatically
  checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta")
  ckpt_meta = utils.Checkpoint(
      checkpoint_meta_dir,
      max_to_keep=1)
  state = ckpt_meta.restore_or_initialize(state)
  # Presumably the restored step/rng when a snapshot exists, otherwise the
  # initial values set above.
  initial_step = int(state.step)
  rng = state.rng

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  train_ds, eval_ds, _ = datasets.get_dataset(ds_rng, config)
  train_iter = iter(train_ds)  # pytype: disable=wrong-arg-types
  eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
  scaler = datasets.get_data_scaler(config)  # data normalizer
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Distribute training.
  optimize_fn = losses.optimization_manager(config)
  if config.training.loss.lower() == "ddpm":
    # Use score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    train_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params,
                                   train=True, optimize_fn=optimize_fn)
    eval_step = functools.partial(losses.ddpm_loss, ddpm_params=ddpm_params,
                                  train=False)
  else:
    # Use score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    # Whether to use a continuous distribution of noise levels
    continuous = "continuous" in config.training.loss.lower()
    train_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=True,
        optimize_fn=optimize_fn,
        anneal_power=config.training.anneal_power)
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        class_conditional=class_conditional,
        continuous=continuous,
        train=False,
        anneal_power=config.training.anneal_power)

  p_train_step = jax.pmap(train_step, axis_name="batch")
  p_eval_step = jax.pmap(eval_step, axis_name="batch")
  # Replicate the state across local devices for the pmapped steps.
  state = flax_utils.replicate(state)

  num_train_steps = config.training.n_iters

  logging.info("Starting training loop at step %d.", initial_step)
  # Give each host its own random stream.
  rng = jax.random.fold_in(rng, jax.host_id())
  # Inclusive upper bound: step `num_train_steps` itself is also trained.
  for step in range(initial_step, num_train_steps + 1):
    # `step` is a Python integer. `state.step` is JAX integer on the GPU/TPU
    # devices.

    # Convert data to JAX arrays. Use ._numpy() to avoid copy.
    batch = jax.tree_map(lambda x: scaler(x._numpy()), next(train_iter))  # pylint: disable=protected-access

    # One key per local device for the pmapped step, plus a new carry key.
    rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
    next_rng = jnp.asarray(next_rng)
    loss, state = p_train_step(next_rng, state, batch)
    # Keep a single (first-device) copy of the per-device loss for logging.
    loss = flax.jax_utils.unreplicate(loss)

    # Quick indication that training is happening.
    logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)

    if jax.host_id() == 0 and step % 50 == 0:
      logging.info("step: %d, training_loss: %.5e", step, loss)
      writer.scalar("training_loss", loss, step)

    # Save a temporary checkpoint to resume training after pre-emption.
    # Only host 0 writes it. Its current host-level `rng` is stored; on
    # resume that rng is folded with the host id again before this loop.
    if step % config.training.snapshot_freq_for_preemption == 0 and jax.host_id(
    ) == 0:
      saved_state = flax_utils.unreplicate(state)
      saved_state = saved_state.replace(rng=rng)
      ckpt_meta.save(saved_state)

    # Report the loss on an evaluation dataset.
    # Every host runs the pmapped eval step; only host 0 logs the result.
    if step % 100 == 0:
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), next(eval_iter))  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, state, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      if jax.host_id() == 0:
        logging.info("step: %d, eval_loss: %.5e", step, eval_loss)
        writer.scalar("eval_loss", eval_loss, step)

    # Save a checkpoint periodically and generate samples.
    # Fires when `step + 1` is a multiple of `snapshot_freq`, and always on
    # the final step.
    if (step +
        1) % config.training.snapshot_freq == 0 or step == num_train_steps:
      # Save the checkpoint.
      if jax.host_id() == 0:
        saved_state = flax_utils.unreplicate(state)
        saved_state = saved_state.replace(rng=rng)
        ckpt.save(saved_state)

      # Generate and save samples
      # Every host samples and writes into its own iter_<step>_host_<id> dir.
      if config.training.snapshot_sampling:
        rng, sample_rng = jax.random.split(rng)
        init_shape = tuple(train_ds.element_spec["image"].shape)
        samples = sampling.get_samples(sample_rng,
                                       config,
                                       flax_utils.unreplicate(state),
                                       init_shape,
                                       scaler,
                                       inverse_scaler,
                                       class_conditional=class_conditional)
        this_sample_dir = os.path.join(
            sample_dir, "iter_{}_host_{}".format(step, jax.host_id()))
        tf.io.gfile.makedirs(this_sample_dir)

        if config.sampling.final_only:  # Do not save intermediate samples
          sample = samples[-1]
          # Merge the two leading axes into a single image axis. The PNG grid
          # is built from these un-quantized values; only the .np dump below
          # is clipped and cast to uint8.
          image_grid = sample.reshape((-1, *sample.shape[2:]))
          # floor(sqrt(number of images)) images per grid row.
          nrow = int(np.sqrt(image_grid.shape[0]))
          sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.np"), "wb") as fout:
            np.save(fout, sample)

          with tf.io.gfile.GFile(
              os.path.join(this_sample_dir, "sample.png"), "wb") as fout:
            utils.save_image(image_grid, fout, nrow=nrow, padding=2)
        else:  # Save all intermediate samples produced during sampling.
          for i, sample in enumerate(samples):
            # Same layout/quantization as the final-only branch, per index i.
            image_grid = sample.reshape((-1, *sample.shape[2:]))
            nrow = int(np.sqrt(image_grid.shape[0]))
            sample = np.clip(sample * 255, 0, 255).astype(np.uint8)
            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir, "sample_{}.np".format(i)),
                "wb") as fout:
              np.save(fout, sample)

            with tf.io.gfile.GFile(
                os.path.join(this_sample_dir, "sample_{}.png".format(i)),
                "wb") as fout:
              utils.save_image(image_grid, fout, nrow=nrow, padding=2)
コード例 #6
0
def _load_state_with_retries(ckpt_filename, state, retry_delays=(60, 120)):
  """Loads `ckpt_filename` into `state`, retrying while the read fails.

  A checkpoint that has only just appeared on disk may not be readable yet.
  After each failed attempt this sleeps for the next delay in `retry_delays`
  (seconds) and tries again; the exception from the final attempt propagates.
  Only `Exception` subclasses trigger a retry, so KeyboardInterrupt and
  SystemExit still abort immediately (a bare `except:` would swallow them).

  Args:
    ckpt_filename: Path of the checkpoint file to read.
    state: Template state passed to `utils.load_state_dict`.
    retry_delays: Seconds to sleep before each retry.

  Returns:
    The result of the first successful `utils.load_state_dict` call.
  """
  for delay in retry_delays:
    try:
      return utils.load_state_dict(ckpt_filename, state)
    except Exception as err:  # pylint: disable=broad-except
      logging.warning("Could not load %s (%s); retrying in %d seconds.",
                      ckpt_filename, err, delay)
      time.sleep(delay)
  return utils.load_state_dict(ckpt_filename, state)


def evaluate(config,
             workdir,
             eval_folder = "eval"):
  """Evaluate trained models: loss, samples, and IS/FID/KID.

  Processes every checkpoint id from the resume point up to
  `config.eval.end_ckpt` (inclusive). For each one it:
    * waits for `<workdir>/checkpoints/ckpt-<id>.flax` to exist and loads it;
    * computes the loss over the whole evaluation dataset;
    * samples in `num_rounds` rounds, saving each round's samples and
      Inception statistics under `<eval_dir>/ckpt_<id>_host_<host>`;
    * on host 0, aggregates the statistics of all hosts into Inception score,
      FID and KID, written to `report_<id>.npz` (or `report_all_<id>.npz` when
      `config.eval.num_partitions` is set).
  Progress is recorded in per-host `meta_<host>_` checkpoints, so an
  interrupted run resumes after the last finished round.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints.
    eval_folder: The subfolder of `workdir` for storing evaluation results.
      Defaults to "eval".
  """
  # Create eval_dir
  eval_dir = os.path.join(workdir, eval_folder)
  tf.io.gfile.makedirs(eval_dir)

  rng = jax.random.PRNGKey(config.seed + 1)

  # Build input pipeline.
  rng, ds_rng = jax.random.split(rng)
  _, eval_ds, _ = datasets.get_dataset(ds_rng, config, evaluation=True)
  scaler = datasets.get_data_scaler(config)
  inverse_scaler = datasets.get_data_inverse_scaler(config)

  # Initialize model.
  rng, model_rng = jax.random.split(rng)
  model_name = config.model.name
  ncsn_def = mutils.get_model(model_name).partial(config=config)
  rng, run_rng = jax.random.split(rng)
  class_conditional = "conditional" in config.training.loss.lower()
  with nn.stateful() as init_model_state:
    with nn.stochastic(run_rng):
      # Per-device element shape: drops the leading axis of the eval spec.
      input_shape = tuple(eval_ds.element_spec["image"].shape[1:])
      input_list = [(input_shape, jnp.float32), (input_shape[:1], jnp.int32)]
      if class_conditional:
        input_list.append(input_list[-1])
      _, initial_params = ncsn_def.init_by_shape(
          model_rng, input_list, train=True)
      ncsn = nn.Model(ncsn_def, initial_params)

  # This state only serves as the template that checkpoints are loaded into.
  optimizer = losses.get_optimizer(config).create(ncsn)
  state = mutils.State(step=0, optimizer=optimizer, lr=config.optim.lr,
                       model_state=init_model_state,
                       ema_rate=config.model.ema_rate,
                       params_ema=initial_params,
                       rng=rng)  # pytype: disable=wrong-keyword-args

  del ncsn, init_model_state  # Do not keep a copy of the initial model.

  checkpoint_dir = os.path.join(workdir, "checkpoints")
  if config.training.loss.lower() == "ddpm":
    # Use the score matching loss with DDPM-type perturbation.
    ddpm_params = mutils.get_ddpm_params()
    eval_step = functools.partial(
        losses.ddpm_loss, ddpm_params=ddpm_params, train=False)
  else:
    # Use the score matching loss with NCSN-type perturbation.
    sigmas = mutils.get_sigmas(config)
    continuous = "continuous" in config.training.loss.lower()
    eval_step = functools.partial(
        losses.ncsn_loss,
        sigmas=sigmas,
        continuous=continuous,
        class_conditional=class_conditional,
        train=False,
        anneal_power=config.training.anneal_power)

  p_eval_step = jax.pmap(eval_step, axis_name="batch")

  rng = jax.random.fold_in(rng, jax.host_id())

  # A data class for checkpointing.
  @flax.struct.dataclass
  class EvalMeta:
    ckpt_id: int
    round_id: int
    rng: Any

  # Add one additional round to get the exact number of samples as required.
  num_rounds = config.eval.num_samples // config.eval.batch_size + 1

  eval_meta = EvalMeta(ckpt_id=config.eval.begin_ckpt, round_id=-1, rng=rng)
  eval_meta = checkpoints.restore_checkpoint(
      eval_dir, eval_meta, step=None, prefix=f"meta_{jax.host_id()}_")

  # Resume inside the recorded checkpoint if it still has rounds left,
  # otherwise start the next checkpoint from round 0.
  if eval_meta.round_id < num_rounds - 1:
    begin_ckpt = eval_meta.ckpt_id
    begin_round = eval_meta.round_id + 1
  else:
    begin_ckpt = eval_meta.ckpt_id + 1
    begin_round = 0

  rng = eval_meta.rng
  # Use inceptionV3 for images with higher resolution
  inceptionv3 = config.data.image_size >= 256
  inception_model = evaluation.get_inception_model(inceptionv3=inceptionv3)

  logging.info("begin checkpoint: %d", begin_ckpt)
  for ckpt in range(begin_ckpt, config.eval.end_ckpt + 1):
    ckpt_filename = os.path.join(checkpoint_dir, "ckpt-{}.flax".format(ckpt))

    # Wait if the target checkpoint hasn't been produced yet.
    waiting_message_printed = False
    while not tf.io.gfile.exists(ckpt_filename):
      if not waiting_message_printed and jax.host_id() == 0:
        logging.warning("Waiting for the arrival of ckpt-%d.flax", ckpt)
        waiting_message_printed = True
      time.sleep(10)

    # In case the file was just written and not ready to read from yet.
    state = _load_state_with_retries(ckpt_filename, state)

    pstate = flax.jax_utils.replicate(state)
    eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types

    # Compute the loss function on the full evaluation dataset.
    all_losses = []
    for i, batch in enumerate(eval_iter):
      rng, *next_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
      next_rng = jnp.asarray(next_rng)
      eval_batch = jax.tree_map(lambda x: scaler(x._numpy()), batch)  # pylint: disable=protected-access
      eval_loss, _ = p_eval_step(next_rng, pstate, eval_batch)
      eval_loss = flax.jax_utils.unreplicate(eval_loss)
      all_losses.append(eval_loss)
      if (i + 1) % 1000 == 0 and jax.host_id() == 0:
        logging.info("Finished %dth step loss evaluation", i + 1)

    all_losses = jnp.asarray(all_losses)

    state = jax.device_put(state)
    # Sampling and computing statistics for Inception scores, FIDs, and KIDs.
    # Designed to be pre-emption safe. Automatically resumes when interrupted.
    for r in range(begin_round, num_rounds):
      if jax.host_id() == 0:
        logging.info("sampling -- ckpt: %d, round: %d", ckpt, r)
      rng, sample_rng = jax.random.split(rng)
      init_shape = tuple(eval_ds.element_spec["image"].shape)

      this_sample_dir = os.path.join(
          eval_dir, f"ckpt_{ckpt}_host_{jax.host_id()}")
      tf.io.gfile.makedirs(this_sample_dir)
      samples = sampling.get_samples(sample_rng, config, state, init_shape,
                                     scaler, inverse_scaler,
                                     class_conditional=class_conditional)
      # Keep only the final sampling iterate, quantized to uint8 HWC images.
      samples = samples[-1]
      samples = np.clip(samples * 255., 0, 255).astype(np.uint8)
      samples = samples.reshape(
          (-1, config.data.image_size, config.data.image_size, 3))
      with tf.io.gfile.GFile(
          os.path.join(this_sample_dir, f"samples_{r}.npz"), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(io_buffer, samples=samples)
        fout.write(io_buffer.getvalue())

      gc.collect()
      latents = evaluation.run_inception_distributed(samples, inception_model,
                                                     inceptionv3=inceptionv3)
      gc.collect()
      with tf.io.gfile.GFile(
          os.path.join(this_sample_dir, f"statistics_{r}.npz"), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(
            io_buffer, pool_3=latents["pool_3"], logits=latents["logits"])
        fout.write(io_buffer.getvalue())

      eval_meta = eval_meta.replace(ckpt_id=ckpt, round_id=r, rng=rng)
      # Save an intermediate checkpoint directly if not the last round.
      # Otherwise save eval_meta after computing the Inception scores and FIDs
      if r < num_rounds - 1:
        checkpoints.save_checkpoint(
            eval_dir,
            eval_meta,
            step=ckpt * num_rounds + r,
            keep=1,
            prefix=f"meta_{jax.host_id()}_")

    # Compute inception scores, FIDs and KIDs.
    if jax.host_id() == 0:
      # Load all statistics that have been previously computed and saved.
      all_logits = []
      all_pools = []
      for host in range(jax.host_count()):
        this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}_host_{host}")

        # Poll until this host has written statistics for every round.
        stats = tf.io.gfile.glob(
            os.path.join(this_sample_dir, "statistics_*.npz"))
        wait_message = False
        while len(stats) < num_rounds:
          if not wait_message:
            logging.warning("Waiting for statistics on host %d", host)
            wait_message = True
          stats = tf.io.gfile.glob(
              os.path.join(this_sample_dir, "statistics_*.npz"))
          time.sleep(1)

        for stat_file in stats:
          with tf.io.gfile.GFile(stat_file, "rb") as fin:
            stat = np.load(fin)
            if not inceptionv3:
              all_logits.append(stat["logits"])
            all_pools.append(stat["pool_3"])

      # Trim the surplus from the extra sampling round.
      if not inceptionv3:
        all_logits = np.concatenate(
            all_logits, axis=0)[:config.eval.num_samples]
      all_pools = np.concatenate(all_pools, axis=0)[:config.eval.num_samples]

      # Load pre-computed dataset statistics.
      data_stats = evaluation.load_dataset_stats(config)
      data_pools = data_stats["pool_3"]

      if hasattr(config.eval, "num_partitions"):
        # Divide samples into several partitions and compute FID/KID/IS on them.
        assert not inceptionv3
        fids = []
        kids = []
        inception_scores = []
        partition_size = config.eval.num_samples // config.eval.num_partitions
        tf_data_pools = tf.convert_to_tensor(data_pools)
        for i in range(config.eval.num_partitions):
          this_pools = all_pools[i * partition_size:(i + 1) * partition_size]
          this_logits = all_logits[i * partition_size:(i + 1) * partition_size]
          inception_scores.append(
              tfgan.eval.classifier_score_from_logits(this_logits))
          fids.append(
              tfgan.eval.frechet_classifier_distance_from_activations(
                  data_pools, this_pools))
          this_pools = tf.convert_to_tensor(this_pools)
          kids.append(
              tfgan.eval.kernel_classifier_distance_from_activations(
                  tf_data_pools, this_pools).numpy())

        fids = np.asarray(fids)
        inception_scores = np.asarray(inception_scores)
        kids = np.asarray(kids)
        with tf.io.gfile.GFile(os.path.join(eval_dir, f"report_all_{ckpt}.npz"),
                               "wb") as f:
          io_buffer = io.BytesIO()
          np.savez_compressed(
              io_buffer, all_losses=all_losses, mean_loss=all_losses.mean(),
              ISs=inception_scores, fids=fids, kids=kids)
          f.write(io_buffer.getvalue())

      else:
        # Compute FID/KID/IS on all samples together.
        if not inceptionv3:
          inception_score = tfgan.eval.classifier_score_from_logits(all_logits)
        else:
          inception_score = -1

        fid = tfgan.eval.frechet_classifier_distance_from_activations(
            data_pools, all_pools)
        # Hack to get tfgan KID work for eager execution.
        tf_data_pools = tf.convert_to_tensor(data_pools)
        tf_all_pools = tf.convert_to_tensor(all_pools)
        kid = tfgan.eval.kernel_classifier_distance_from_activations(
            tf_data_pools, tf_all_pools).numpy()
        del tf_data_pools, tf_all_pools

        logging.info(
            "ckpt-%d --- loss: %.6e, inception_score: %.6e, FID: %.6e, KID: %.6e",
            ckpt, all_losses.mean(), inception_score, fid, kid)

        with tf.io.gfile.GFile(os.path.join(eval_dir, f"report_{ckpt}.npz"),
                               "wb") as f:
          io_buffer = io.BytesIO()
          np.savez_compressed(
              io_buffer, all_losses=all_losses, mean_loss=all_losses.mean(),
              IS=inception_score, fid=fid, kid=kid)
          f.write(io_buffer.getvalue())
    else:
      # For host_id() != 0.
      # Use file existence to emulate synchronization across hosts.
      if hasattr(config.eval, "num_partitions"):
        assert not inceptionv3
        while not tf.io.gfile.exists(
            os.path.join(eval_dir, f"report_all_{ckpt}.npz")):
          time.sleep(1.)

      else:
        while not tf.io.gfile.exists(
            os.path.join(eval_dir, f"report_{ckpt}.npz")):
          time.sleep(1.)

    # Save eval_meta after computing IS/KID/FID to mark the end of evaluation
    # for this checkpoint. `r` is the last round index: the rounds loop always
    # runs at least once because `begin_round <= num_rounds - 1`.
    checkpoints.save_checkpoint(
        eval_dir,
        eval_meta,
        step=ckpt * num_rounds + r,
        keep=1,
        prefix=f"meta_{jax.host_id()}_")

    begin_round = 0

  # Remove all meta files after finishing evaluation.
  meta_files = tf.io.gfile.glob(
      os.path.join(eval_dir, f"meta_{jax.host_id()}_*"))
  for file in meta_files:
    tf.io.gfile.remove(file)
コード例 #7
0
    def __init__(self,
                 n=2**7 - 1,
                 rng=None,
                 channels=8,
                 loss=loss_gmres,
                 iter_gmres=lambda i: 10,
                 training_iter=500,
                 name='net',
                 model_dir=None,
                 lr=3e-4,
                 k=0.0,
                 n_test=10,
                 beta1=0.9,
                 beta2=0.999,
                 lr_og=3e-3,
                 flaxd=False):
        """Builds the preconditioner network, its optimizer and settings.

        There are two back ends:
          * `flaxd=False`: a `stax.serial` network of nested `UNetBlock` /
            `UnbiasedConv` layers, trained with `optimizers.adam`.
          * `flaxd=True`: `flax_cnn.new_CNN` wrapped in `nn.Model`, trained
            with `flax.optim.Adam`.

        Args:
          n: Grid size; passed to `meshes.Mesh` and used as the spatial
            extent of the network input shape `(-1, n, n, 1)`.
          rng: PRNG key for parameter init; `random.PRNGKey(1)` when None.
          channels: Number of inner channels of the network.
          loss: Training loss, stored as `self.loss`.
          iter_gmres: NOTE(review): currently ignored. `self.iter_gmres` is
            always a function returning a random draw from a fixed list.
            Confirm the override is intentional.
          training_iter, name, model_dir, k, n_test: Stored as attributes.
          lr: Adam learning rate (stax back end: from step index 100 on).
          k: Stored as `self.k`.
          beta1, beta2: Adam moment decay rates.
          lr_og: Stax back end only; learning rate for step indices < 100.
          flaxd: Selects the flax back end instead of stax.
        """
        self.n = n
        self.n_test = n_test
        self.mesh = meshes.Mesh(n)
        self.in_shape = (-1, n, n, 1)
        self.inner_channels = channels

        # Draws a count (presumably GMRES iterations) from NumPy's global RNG;
        # the step index argument is not used.
        def _random_gmres_iters(i):
            del i  # unused
            return onp.random.choice([5, 10, 10, 10, 10, 15, 15, 15, 20, 25])

        self.iter_gmres = _random_gmres_iters
        self.training_iter = training_iter
        self.name = name
        self.k = k
        self.model_dir = model_dir
        self.test_loss = loss_gmresR_flax if flaxd else loss_gmresR
        self.beta1 = beta1
        self.beta2 = beta2
        init_key = random.PRNGKey(1) if rng is None else rng

        if flaxd:
            model_def = flax_cnn.new_CNN.partial(
                inner_channels=self.inner_channels)
            out_shape, net_params = model_def.init_by_shape(
                init_key, [(self.in_shape, np.float32)])
            self.model_def = model_def
            self.model = nn.Model(model_def, net_params)
            self.net_apply = lambda param, x: nn.Model(model_def, param)(x)
        else:
            width = self.inner_channels

            def same_conv():
                return UnbiasedConv(width, (3, 3), padding='SAME')

            # Layers are constructed in the same order a single nested
            # literal would evaluate them: 2 convs, inner block, 2 convs.
            leading = [same_conv(), same_conv()]
            inner_unet = UNetBlock(
                width, (3, 3),
                stax.serial(same_conv(), same_conv(), same_conv()),
                strides=(2, 2),
                padding='VALID')
            trailing = [same_conv(), same_conv()]
            outer_unet = UNetBlock(
                1, (3, 3),
                stax.serial(*leading, inner_unet, *trailing),
                strides=(2, 2),
                padding='VALID')
            self.net_init, self.net_apply = stax.serial(outer_unet)
            out_shape, net_params = self.net_init(init_key, self.in_shape)

        self.out_shape = out_shape
        self.net_params = net_params
        self.loss = loss
        self.lr_og = lr_og
        self.lr = lr
        if flaxd:
            self.step = self.step_flax
            self.optimizer = flax.optim.Adam(learning_rate=lr,
                                             beta1=beta1,
                                             beta2=beta2).create(self.model)
        else:
            # `lr_og` while the step index is below 100, `lr` afterwards.
            self.opt_init, self.opt_update, self.get_params = optimizers.adam(
                step_size=lambda i: np.where(i < 100, lr_og, lr),
                b1=beta1,
                b2=beta2)
            self.opt_state = self.opt_init(self.net_params)
            self.step = self.step_notflax
        self.alpha = lambda i: 0.0
        self.flaxd = flaxd
        self.preconditioner = (self.preconditioner_flaxed
                               if flaxd else self.preconditioner_unflaxed)