Code example #1
 def test_ignores_incomplete_checkpoint(self):
     base_dir = tempfile.mkdtemp()
     state = TrainState(step=1)
     ckpt = checkpoint.Checkpoint(base_dir)
     # Initializes.
     state = ckpt.restore_or_initialize(state)
     state = TrainState(step=0)
     # Restores step=1.
     state = ckpt.restore_or_initialize(state)
     self.assertEqual(state.step, 1)
     state = TrainState(step=2)
     # Failed save: the step=2 Flax file is stored, but the TensorFlow checkpoint fails.
     ckpt.tf_checkpoint_manager.save = None
     with self.assertRaisesRegex(TypeError,
                                 r"'NoneType' object is not callable"):
         ckpt.save(state)
     files = os.listdir(base_dir)
     self.assertIn("ckpt-2.flax", files)
     self.assertNotIn("ckpt-2.index", files)
     ckpt = checkpoint.Checkpoint(base_dir)
     state = TrainState(step=0)
     # Restores step=1.
     state = ckpt.restore_or_initialize(state)
     self.assertEqual(state.step, 1)
     # Stores step=2.
     state = TrainState(step=2)
     path = ckpt.save(state)
     self.assertEqual(_checkpoint_number(path), 2)
     files = os.listdir(base_dir)
     self.assertIn("ckpt-2.flax", files)
     self.assertIn("ckpt-2.index", files)
     state = TrainState(step=0)
     # Restores step=2.
     state = ckpt.restore_or_initialize(state)
     self.assertEqual(state.step, 2)
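
These tests reference a few helpers that are not included in this listing. A minimal sketch consistent with how they are used (hypothetical reconstructions; the definitions in the actual CLU test file may differ):

import os
from flax import struct

@struct.dataclass
class TrainState:
    # Registered as a flax dataclass so the checkpoint can (de)serialize it.
    step: int

@struct.dataclass
class TrainStateExtended:
    step: int
    name: str = struct.field(pytree_node=False)

class NotTrainState:
    # Deliberately not registered for flax serialization (see example #15).
    pass

def _checkpoint_number(path):
    # CLU checkpoint paths end in "-<number>", e.g. ".../ckpt-2".
    return int(os.path.basename(path).split("-")[-1])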
Code example #2
 def test_overwrite(self):
     base_dir = tempfile.mkdtemp()
     tf_step = tf.Variable(1)
     state = TrainState(step=1)
     ckpt = checkpoint.Checkpoint(base_dir, dict(step=tf_step))
     # Initialize step=1.
     state = ckpt.restore_or_initialize(state)
     self.assertEqual(state.step, 1)
     self.assertEqual(tf_step.numpy(), 1)
     checkpoint_info = checkpoint.CheckpointInfo.from_path(
         ckpt.current_checkpoint)
     # Stores steps 2, 3, 4, 5
     for _ in range(4):
         tf_step.assign_add(1)
         state = state.replace(step=state.step + 1)
         ckpt.save(state)
     latest_checkpoint = str(checkpoint_info._replace(number=5))
     self.assertEqual(ckpt.current_checkpoint, latest_checkpoint)
     self.assertEqual(ckpt.latest_checkpoint, latest_checkpoint)
     # Restores at step=1
     ckpt = checkpoint.Checkpoint(base_dir, dict(step=tf_step))
     state = ckpt.restore(state, checkpoint=str(checkpoint_info))
     self.assertEqual(state.step, 1)
     self.assertEqual(tf_step.numpy(), 1)
     self.assertNotEqual(ckpt.current_checkpoint, ckpt.latest_checkpoint)
     self.assertEqual(ckpt.current_checkpoint, str(checkpoint_info))
     self.assertEqual(ckpt.latest_checkpoint, latest_checkpoint)
     # Overwrites step=2, deletes 3, 4, 5.
     tf_step.assign_add(1)
     state = state.replace(step=state.step + 1)
     ckpt.save(state)
     latest_checkpoint = str(checkpoint_info._replace(number=2))
     self.assertEqual(ckpt.current_checkpoint, latest_checkpoint)
     self.assertEqual(ckpt.latest_checkpoint, latest_checkpoint)
Code example #3
 def test_fails_if_save_counter_mismatch(self):
   base_dir = tempfile.mkdtemp()
   ckpt = checkpoint.Checkpoint(base_dir, max_to_keep=1)
   state = TrainState(step=1)
   state = ckpt.restore_or_initialize(state)
   ckpt.save(state)
   ckpt = checkpoint.Checkpoint(base_dir, max_to_keep=1)
   state = TrainState(step=2)
    with self.assertRaisesRegex(RuntimeError, r"^Expected.*to match"):
     ckpt.save(state)
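
The mismatch arises because the second Checkpoint object starts with a fresh TensorFlow save counter while the directory already contains ckpt-1. Judging from the other tests, the remedy is to restore before saving so the counter is re-synced; a minimal sketch of that assumption:

ckpt = checkpoint.Checkpoint(base_dir, max_to_keep=1)
state = TrainState(step=2)
# restore_or_initialize() re-syncs the internal counter with the files
# already in base_dir (and here restores step=1 over our initial state).
state = ckpt.restore_or_initialize(state)
ckpt.save(state)  # No longer raises RuntimeError.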
Code example #4
 def test_restore_flax_alone(self):
     base_dir = tempfile.mkdtemp()
     ds_iter = iter(_make_dataset())
     ckpt = checkpoint.Checkpoint(base_dir, dict(ds_iter=ds_iter))
     state = TrainState(step=1)
     # Initializes.
     state = ckpt.restore_or_initialize(state)
     state = TrainState(step=0)
     ckpt = checkpoint.Checkpoint(base_dir)
     # Restores step=1.
     state = ckpt.restore_or_initialize(state)
     self.assertEqual(state.step, 1)
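
Several of these tests obtain a checkpointable tf.data iterator from a `_make_dataset` helper that is not shown. A minimal sketch consistent with its usage, assuming elements are dicts with distinct "x" and "y" values per step (the real helper may differ, e.g. by using random tensors):

import tensorflow as tf

def _make_dataset():
    # A deterministic pipeline whose iterator state can be saved and restored
    # when passed to the checkpoint as dict(ds_iter=ds_iter).
    return tf.data.Dataset.range(8).map(
        lambda i: {"x": tf.fill([4], i), "y": tf.fill([2], i + 100)})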
Code example #5
    def test_restore_dict(self):
        base_dir = tempfile.mkdtemp()
        ds_iter = iter(_make_dataset())
        ckpt = checkpoint.Checkpoint(base_dir, dict(ds_iter=ds_iter))
        with self.assertRaisesRegex(FileNotFoundError,
                                    r"No checkpoint found at"):
            ckpt.restore_dict()
        with self.assertRaisesRegex(FileNotFoundError,
                                    r"Checkpoint invalid does not exist"):
            ckpt.restore_dict(checkpoint="invalid")

        state = TrainState(step=1)
        ckpt.save(state)

        state_dict = ckpt.restore_dict()
        self.assertEqual(state_dict, dict(step=1))
        first_checkpoint = ckpt.latest_checkpoint

        new_state = TrainState(step=2)
        ckpt.save(new_state)

        self.assertEqual(ckpt.restore_dict(checkpoint=first_checkpoint),
                         dict(step=1))
        self.assertEqual(ckpt.restore_dict(), dict(step=2))
        self.assertEqual(ckpt.restore_dict(checkpoint=ckpt.latest_checkpoint),
                         dict(step=2))
Code example #6
 def test_restores_tf_state(self):
     base_dir = tempfile.mkdtemp()
     ds_iter = iter(_make_dataset())
     ckpt = checkpoint.Checkpoint(base_dir, dict(ds_iter=ds_iter))
     features0 = next(ds_iter)  # Advance iterator by one.
     del features0
     state = TrainState(step=1)
     # Initialize at features1.
     state = ckpt.restore_or_initialize(state)
     features1 = next(ds_iter)
     features2 = next(ds_iter)
     self.assertNotAllEqual(features1["x"], features2["x"])
     self.assertNotAllEqual(features1["y"], features2["y"])
     # Restore at features1.
     state = ckpt.restore_or_initialize(state)
     features1_restored = next(ds_iter)
     self.assertAllEqual(features1["x"], features1_restored["x"])
     self.assertAllEqual(features1["y"], features1_restored["y"])
     # Save at features2.
     path = ckpt.save(state)
     self.assertEqual(_checkpoint_number(path), 2)
     features2 = next(ds_iter)
     features3 = next(ds_iter)
     self.assertNotAllEqual(features2["x"], features3["x"])
     self.assertNotAllEqual(features2["y"], features3["y"])
     # Restore at features2.
     state = ckpt.restore_or_initialize(state)
     features2_restored = next(ds_iter)
     self.assertAllEqual(features2["x"], features2_restored["x"])
     self.assertAllEqual(features2["y"], features2_restored["y"])
     # Restore at features2 as dictionary.
     state = ckpt.restore_dict()
     features2_restored = next(ds_iter)
     self.assertAllEqual(features2["x"], features2_restored["x"])
     self.assertAllEqual(features2["y"], features2_restored["y"])
Code example #7
 def test_restores_flax_state(self):
     base_dir = tempfile.mkdtemp()
     state = TrainState(step=1)
     ckpt = checkpoint.Checkpoint(base_dir, max_to_keep=2)
     # Initializes.
     state = ckpt.restore_or_initialize(state)
     state = TrainState(step=0)
     # Restores step=1.
     state = ckpt.restore_or_initialize(state)
     self.assertEqual(state.step, 1)
     state = TrainState(step=2)
     # Stores step=2.
     path = ckpt.save(state)
     self.assertEqual(_checkpoint_number(path), 2)
     state = TrainState(step=0)
     # Restores step=2.
     state = ckpt.restore(state)
     self.assertEqual(state.step, 2)
     state = TrainState(step=3)
     # Stores step=3
     path2 = ckpt.save(state)
     self.assertEqual(_checkpoint_number(path2), 3)
     state = TrainState(step=0)
     # Restores step=2.
     state = ckpt.restore(state, path)
     self.assertEqual(state.step, 2)
Code example #8
 def test_initialize_mkdir(self):
   base_dir = os.path.join(tempfile.mkdtemp(), "test")
   state = TrainState(step=1)
   ckpt = checkpoint.Checkpoint(base_dir)
   self.assertIsNone(ckpt.latest_checkpoint)
   self.assertFalse(os.path.isdir(base_dir))
   state = ckpt.restore_or_initialize(state)
   self.assertIsNotNone(ckpt.latest_checkpoint)
   self.assertTrue(os.path.isdir(base_dir))
Code example #9
 def test_fails_when_restoring_superset(self):
     base_dir = tempfile.mkdtemp()
     ckpt = checkpoint.Checkpoint(base_dir)
     state = TrainState(step=0)
      # Initializes with TrainState.
     state = ckpt.restore_or_initialize(state)
     state = TrainStateExtended(step=1, name="test")
     # Restores with TrainStateExtended.
     with self.assertRaisesRegex(ValueError, r"^Missing field"):
         state = ckpt.restore_or_initialize(state)
Code example #10
 def test_load_state_dict(self):
   base_dir = tempfile.mkdtemp()
   state = TrainState(step=1)
   ckpt = checkpoint.Checkpoint(base_dir)
   # Initializes.
   state = ckpt.restore_or_initialize(state)
   # Load via load_state_dict().
   flax_dict = checkpoint.load_state_dict(base_dir)
   self.assertEqual(flax_dict, dict(step=1))
    with self.assertRaisesRegex(FileNotFoundError, r"^No checkpoint found"):
     checkpoint.load_state_dict(tempfile.mkdtemp())
Code example #11
 def test_max_to_keep(self):
     base_dir = tempfile.mkdtemp()
     state = TrainState(step=1)
     ckpt = checkpoint.Checkpoint(base_dir, max_to_keep=1)
     state = ckpt.restore_or_initialize(state)
     files1 = os.listdir(base_dir)
     state = TrainState(step=2)
     path = ckpt.save(state)
     self.assertEqual(_checkpoint_number(path), 2)
     files2 = os.listdir(base_dir)
     self.assertEqual(len(files1), len(files2))
     self.assertNotEqual(files1, files2)
Code example #12
def main(_):
    jax.config.update('jax_enable_x64', True)

    config: config_dict.ConfigDict = _CONFIG.value
    logging.info(config)

    key = jax.random.PRNGKey(config.seed)
    key, psi_key, phi_key = jax.random.split(key, 3)
    Psi = jax.random.normal(psi_key, (config.S, config.T), dtype=jnp.float64)
    Phi = jax.random.normal(phi_key, (config.S, config.d), dtype=jnp.float64)
    # Wrap feature matrix in np array to allow for indexing.
    Phi = np.array(Phi)

    chkpt_manager = checkpoint.Checkpoint(base_directory=_WORKDIR.value)

    initial_step = 0
    initial_step, Phi = chkpt_manager.restore_or_initialize(
        (initial_step, Phi))

    optimal_subspace = compute_optimal_subspace(Psi, config.d)

    workdir = epath.Path(_WORKDIR.value)
    workdir.mkdir(exist_ok=True)

    Phis = train(workdir=workdir,
                 initial_step=initial_step,
                 chkpt_manager=chkpt_manager,
                 Phi=Phi,
                 Psi=Psi,
                 optimal_subspace=optimal_subspace,
                 num_epochs=config.num_epochs,
                 learning_rate=config.lr,
                 key=key,
                 method=config.method,
                 lissa_kappa=config.kappa,
                 optimizer=config.optimizer,
                 covariance_batch_size=config.covariance_batch_size,
                 main_batch_size=config.main_batch_size,
                 weight_batch_size=config.weight_batch_size,
                 estimate_feature_norm=config.estimate_feature_norm)

    with (workdir / 'phis.pkl').open('wb') as fout:
        pickle.dump(Phis, fout, protocol=4)
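
`compute_optimal_subspace` is not shown in this listing, but the train() docstring in example #18 describes optimal_subspace as the top-d left singular vectors of Psi. A plausible sketch on that basis (a hypothetical reconstruction, not the original implementation):

import jax.numpy as jnp

def compute_optimal_subspace(Psi, d):
    # Top-d left singular vectors of Psi.
    U, _, _ = jnp.linalg.svd(Psi, full_matrices=False)
    return U[:, :d]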
Code example #13
File: experiment.py  Project: afcarl/google-research
    def evaluate(self, workdir, dir_name='eval', ckpt_name=None):
        """Perform one evaluation."""
        checkpoint_dir = os.path.join(workdir, 'checkpoints-0')
        ckpt = checkpoint.Checkpoint(checkpoint_dir)
        state_dict = ckpt.restore_dict(os.path.join(checkpoint_dir, ckpt_name))
        ema_params = flax.core.FrozenDict(state_dict['ema_params'])
        step = int(state_dict['step'])

        # Distribute training.
        ema_params = flax_utils.replicate(ema_params)

        eval_logdir = os.path.join(workdir, dir_name)
        tf.io.gfile.makedirs(eval_logdir)
        writer = metric_writers.create_default_writer(
            eval_logdir, just_logging=jax.process_index() > 0)

        outputs = self._eval_epoch(params=ema_params)
        outputs = flax_utils.unreplicate(outputs)
        scalars, images = outputs['scalars'], outputs['images']
        writer.write_scalars(step, scalars)
        writer.write_images(step, images)
Code example #14
def evaluate(base_dir, config, *, train_state):
    """Eval function."""
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'eval'))

    writer = create_default_writer()

    key = jax.random.PRNGKey(config.eval.seed)
    model_init_key, ds_key = jax.random.split(key)

    linear_module = LinearModule(config.eval.num_tasks)
    params = linear_module.init(model_init_key,
                                jnp.zeros((config.encoder.embedding_dim, )))
    lr = optax.cosine_decay_schedule(config.eval.learning_rate,
                                     config.num_eval_steps)
    optim = optax.adam(lr)

    ds = dataset.get_dataset(config, ds_key, num_tasks=config.eval.num_tasks)
    ds_iter = iter(ds)

    state = TrainState.create(apply_fn=linear_module.apply,
                              params=params,
                              tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_eval_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = EvalMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_eval_steps)):
            with jax.profiler.StepTraceAnnotation('eval', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = evaluate_step(train_state, state, metrics,
                                               states, targets)

            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = EvalMetrics.empty()

            for hook in hooks:
                hook(step)

        # Finally, evaluate on the true(ish) test aux task matrix.
        states, targets = dataset.EvalDataset(config, ds_key).get_batch()

        @jax.jit
        def loss_fn():
            outputs = train_state.apply_fn(train_state.params, states)
            phis = outputs.phi
            predictions = jax.vmap(state.apply_fn,
                                   in_axes=(None, 0))(state.params, phis)
            return jnp.mean(optax.l2_loss(predictions, targets))

        test_loss = loss_fn()
        writer.write_scalars(config.num_eval_steps + 1,
                             {'test_loss': test_loss})
Code example #15
 def test_fails_if_not_registered(self):
     base_dir = tempfile.mkdtemp()
     not_state = NotTrainState()
     ckpt = checkpoint.Checkpoint(base_dir)
     with self.assertRaisesRegex(TypeError, r"serialize"):
         ckpt.restore_or_initialize(not_state)
Code example #16
 def test_checkpoint_name(self):
     base_dir = tempfile.mkdtemp()
     state = TrainState(step=1)
     ckpt = checkpoint.Checkpoint(base_dir, checkpoint_name="test")
     path = ckpt.save(state)
     self.assertIn("test", path)
Code example #17
def train_and_evaluate(config, workdir):
    """Execute model training and evaluation loop.

  Args:
    config: Hyperparameter configuration for training and evaluation.
    workdir: Directory where the tensorboard summaries are written to.

  Returns:
    The train state (which includes the `.params`).
  """
    # Seed for reproducibility.
    rng = jax.random.PRNGKey(config.rng_seed)

    # Set up logging.
    summary_writer = metric_writers.create_default_writer(workdir)
    summary_writer.write_hparams(dict(config))

    # Get datasets.
    rng, dataset_rng = jax.random.split(rng)
    dataset = input_pipeline.get_dataset(config, dataset_rng)
    graph, labels, masks = jax.tree_map(jnp.asarray, dataset)
    labels = jax.nn.one_hot(labels, config.num_classes)
    train_mask = masks['train']
    train_indices = jnp.where(train_mask)[0]
    train_labels = labels[train_indices]
    num_training_nodes = len(train_indices)

    # Get subgraphs.
    if config.differentially_private_training:
        graph = jax.tree_map(np.asarray, graph)
        subgraphs = get_subgraphs(graph, pad_to=config.pad_subgraphs_to)
        graph = jax.tree_map(jnp.asarray, graph)

        # We only need the subgraphs for training nodes.
        train_subgraphs = subgraphs[train_indices]
        del subgraphs
    else:
        train_subgraphs = None

    # Initialize privacy accountant.
    training_privacy_accountant = privacy_accountants.get_training_privacy_accountant(
        config, num_training_nodes, compute_max_terms_per_node(config))

    # Construct and initialize model.
    rng, init_rng = jax.random.split(rng)
    estimation_indices = get_estimation_indices(train_indices, config)
    state = create_train_state(init_rng, config, graph, train_labels,
                               train_subgraphs, estimation_indices)

    # Set up checkpointing of the model.
    checkpoint_dir = os.path.join(workdir, 'checkpoints')
    ckpt = checkpoint.Checkpoint(checkpoint_dir, max_to_keep=2)
    state = ckpt.restore_or_initialize(state)
    initial_step = int(state.step) + 1

    # Log overview of parameters.
    parameter_overview.log_parameter_overview(state.params)

    # Log metrics after initialization.
    logits = compute_logits(state, graph)
    metrics_after_init = compute_metrics(logits, labels, masks)
    metrics_after_init['epsilon'] = 0
    log_metrics(0, metrics_after_init, summary_writer, postfix='_after_init')

    # Train model.
    rng, train_rng = jax.random.split(rng)
    max_training_epsilon = get_max_training_epsilon(config)

    # Hooks called periodically during training.
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_training_steps, writer=summary_writer)
    profiler = periodic_actions.Profile(num_profile_steps=5, logdir=workdir)
    hooks = [report_progress, profiler]

    for step in range(initial_step, config.num_training_steps):

        # Perform one step of training.
        with jax.profiler.StepTraceAnnotation('train', step_num=step):
            # Sample batch.
            step_rng = jax.random.fold_in(train_rng, step)
            indices = jax.random.choice(step_rng, num_training_nodes,
                                        (config.batch_size, ))

            # Compute gradients.
            if config.differentially_private_training:
                grads = compute_updates_for_dp(state, graph, train_labels,
                                               train_subgraphs, indices,
                                               config.adjacency_normalization)
            else:
                grads = compute_updates(state, graph, train_labels, indices)

            # Update parameters.
            state = update_model(state, grads)

        # Quick indication that training is happening.
        logging.log_first_n(logging.INFO, 'Finished training step %d.', 10,
                            step)
        for hook in hooks:
            hook(step)

        # Evaluate, if required.
        is_last_step = (step == config.num_training_steps - 1)
        if step % config.evaluate_every_steps == 0 or is_last_step:
            with report_progress.timed('eval'):
                # Check if privacy budget exhausted.
                training_epsilon = training_privacy_accountant(step + 1)
                if max_training_epsilon is not None and training_epsilon >= max_training_epsilon:
                    break

                # Compute metrics.
                logits = compute_logits(state, graph)
                metrics_during_training = compute_metrics(
                    logits, labels, masks)
                metrics_during_training['epsilon'] = training_epsilon
                log_metrics(step, metrics_during_training, summary_writer)

        # Checkpoint, if required.
        if step % config.checkpoint_every_steps == 0 or is_last_step:
            with report_progress.timed('checkpoint'):
                ckpt.save(state)

    return state
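
Distilled from the loop above, the checkpointing pattern itself is small. A minimal sketch, assuming `checkpoint` is clu.checkpoint, with the state constructor and step function injected as parameters (hypothetical stand-ins for create_train_state and the training step):

import os
from clu import checkpoint

def train_loop(workdir, make_initial_state, train_step,
               num_training_steps, checkpoint_every_steps):
    ckpt = checkpoint.Checkpoint(os.path.join(workdir, 'checkpoints'),
                                 max_to_keep=2)
    # Fresh run: writes an initial checkpoint. After preemption: restores
    # the latest state so training resumes where it left off.
    state = ckpt.restore_or_initialize(make_initial_state())
    initial_step = int(state.step) + 1
    for step in range(initial_step, num_training_steps):
        state = train_step(state)
        if step % checkpoint_every_steps == 0 or step == num_training_steps - 1:
            ckpt.save(state)
    return state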
Code example #18
def train(*,
          workdir,
          compute_phi,
          compute_psi,
          params,
          optimal_subspace,
          num_epochs,
          learning_rate,
          key,
          method,
          lissa_kappa,
          optimizer,
          covariance_batch_size,
          main_batch_size,
          weight_batch_size,
          d,
          num_tasks,
          compute_feature_norm_on_oracle_states,
          sample_states,
          eval_states,
          use_tabular_gradient=True):
    """Training function.

  For lissa, the total number of samples is
  2 x covariance_batch_size + main_batch_size + 2 x weight_batch_size.

  Args:
    workdir: Work directory, where we'll save logs.
    compute_phi: A function that takes params and states and returns
      a matrix of phis.
    compute_psi: A function that takes an array of states and an array
      of tasks and returns Psi[states, tasks].
    params: Parameters used as the first argument for compute_phi.
    optimal_subspace: Top-d left singular vectors of Psi.
    num_epochs: How many gradient steps to perform. (Not really epochs)
    learning_rate: The step size parameter for sgd.
    key: The jax prng key.
    method: 'naive', 'lissa', 'oracle', or 'explicit'.
    lissa_kappa: The parameter of the lissa method, if used.
    optimizer: Which optimizer to use. Only 'sgd' is supported.
    covariance_batch_size: the 'J' parameter. For the naive method, this is how
      many states we sample to construct the inverse. For the lissa method,
      ditto -- these are also "iterations".
    main_batch_size: How many states to update at once.
    weight_batch_size: How many states to construct the weight vector.
    d: The dimension of the representation.
    num_tasks: The total number of tasks.
    compute_feature_norm_on_oracle_states: If True, computes the feature norm
      using the oracle states (all the states in synthetic experiments).
      Otherwise, computes the norm using the sampled batch.
      Only applies to LISSA.
    sample_states: A function that takes an rng key and a number of states
      to sample, and returns a tuple containing
      (a vector of sampled states, an updated rng key).
    eval_states: An array of states to use to compute metrics on.
      This will be used to compute Phi = compute_phi(params, eval_states).
    use_tabular_gradient: If true, the train step will calculate the
      gradient using the tabular calculation. Otherwise, it will use a
      jax.vjp to backpropagate the gradient.
  """
    # Create an explicit weight vector (needed for explicit method only).
    if method == 'explicit':
        key, weight_key = jax.random.split(key)
        explicit_weight_matrix = jax.random.normal(weight_key, (d, num_tasks),
                                                   dtype=jnp.float32)
        params['explicit_weight_matrix'] = explicit_weight_matrix

    if optimizer == 'sgd':
        optimizer = optax.sgd(learning_rate)
    elif optimizer == 'adam':
        optimizer = optax.adam(learning_rate)
    else:
        raise ValueError(f'Unknown optimizer {optimizer}.')
    optimizer_state = optimizer.init(params)

    chkpt_manager = checkpoint.Checkpoint(base_directory=_WORKDIR.value)
    initial_step, params, optimizer_state = chkpt_manager.restore_or_initialize(
        (0, params, optimizer_state))

    writer = metric_writers.create_default_writer(logdir=str(workdir))

    # Checkpointing and logging too much can use a lot of disk space.
    # Therefore, we don't want to checkpoint more than 10 times an experiment,
    # or keep more than 1k Phis per experiment.
    checkpoint_period = max(num_epochs // 10, 100_000)
    log_period = max(1_000, num_epochs // 1_000)

    def _checkpoint_callback(step, t, params, optimizer_state):
        del t  # Unused.
        chkpt_manager.save((step, params, optimizer_state))

    hooks = [
        periodic_actions.PeriodicCallback(every_steps=checkpoint_period,
                                          callback_fn=_checkpoint_callback)
    ]

    fixed_train_kwargs = {
        'compute_phi': compute_phi,
        'compute_psi': compute_psi,
        'optimizer': optimizer,
        'method': method,
        # In the tabular case, the eval_states are all the states.
        'oracle_states': eval_states,
        'lissa_kappa': lissa_kappa,
        'main_batch_size': main_batch_size,
        'covariance_batch_size': covariance_batch_size,
        'weight_batch_size': weight_batch_size,
        'd': d,
        'num_tasks': num_tasks,
        'compute_feature_norm_on_oracle_states': compute_feature_norm_on_oracle_states,
        'sample_states': sample_states,
        'use_tabular_gradient': use_tabular_gradient,
    }
    variable_kwargs = {
        'params': params,
        'optimizer_state': optimizer_state,
        'key': key,
    }

    @jax.jit
    def _eval_step(phi_params):
        eval_phi = compute_phi(phi_params, eval_states)
        eval_psi = compute_psi(eval_states)  # pytype: disable=wrong-arg-count

        metrics = compute_metrics(eval_phi, optimal_subspace)
        metrics |= {'frob_norm': utils.outer_objective_mc(eval_phi, eval_psi)}
        return metrics

    # Perform num_epochs gradient steps.
    with metric_writers.ensure_flushes(writer):
        for step in etqdm.tqdm(range(initial_step + 1, num_epochs + 1),
                               initial=initial_step,
                               total=num_epochs):

            variable_kwargs = _train_step(**fixed_train_kwargs,
                                          **variable_kwargs)

            if step % log_period == 0:
                metrics = _eval_step(variable_kwargs['params']['phi_params'])
                writer.write_scalars(step, metrics)

            for hook in hooks:
                hook(step,
                     params=variable_kwargs['params'],
                     optimizer_state=variable_kwargs['optimizer_state'])

    writer.flush()
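
As a quick check of the sample count stated in the docstring, with hypothetical batch sizes:

covariance_batch_size, main_batch_size, weight_batch_size = 32, 64, 16
samples_per_step = (2 * covariance_batch_size + main_batch_size
                    + 2 * weight_batch_size)
assert samples_per_step == 160  # 2*32 + 64 + 2*16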
Code example #19
def train(base_dir, config):
    """Train function."""
    print(config)
    chkpt_manager = checkpoint.Checkpoint(str(base_dir / 'train'))

    writer = create_default_writer()

    # Initialize dataset
    key = jax.random.PRNGKey(config.seed)
    key, subkey = jax.random.split(key)
    ds = dataset.get_dataset(config, subkey, num_tasks=config.num_tasks)
    ds_iter = iter(ds)

    key, subkey = jax.random.split(key)
    encoder = MLPEncoder(**config.encoder)

    train_config = config.train.to_dict()
    train_method = train_config.pop('method')

    module_config = train_config.pop('module')
    module_class = module_config.pop('name')

    module = globals().get(module_class)(encoder, **module_config)
    train_step = globals().get(f'train_step_{train_method}')
    train_step = functools.partial(train_step, **train_config)

    params = module.init(subkey, next(ds_iter)[0])
    lr = optax.cosine_decay_schedule(config.learning_rate,
                                     config.num_train_steps)
    optim = optax.chain(
        optax.adam(lr),
        # optax.adaptive_grad_clip(0.15)
    )

    state = TrainState.create(apply_fn=module.apply, params=params, tx=optim)
    state = chkpt_manager.restore_or_initialize(state)

    # Hooks
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    hooks = [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=str(base_dir))
    ]

    def handle_preemption(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        with report_progress.timed('checkpointing'):
            chkpt_manager.save(state)
        logging.info('Finished saving checkpoint.')

    signal.signal(signal.SIGTERM, handle_preemption)

    metrics = TrainMetrics.empty()
    with metric_writers.ensure_flushes(writer):
        for step in tqdm.tqdm(range(state.step, config.num_train_steps)):
            with jax.profiler.StepTraceAnnotation('train', step_num=step):
                states, targets = next(ds_iter)
                state, metrics = train_step(state, metrics, states, targets)

            logging.log_first_n(logging.INFO, 'Finished training step %d', 5,
                                step)

            if step % config.log_metrics_every == 0:
                writer.write_scalars(step, metrics.compute())
                metrics = TrainMetrics.empty()

            # if step % config.log_eval_metrics_every == 0 and isinstance(
            #     ds, dataset.MDPDataset):
            #   eval_metrics = evaluate_mdp(state, ds.aux_task_matrix, config)
            #   writer.write_scalars(step, eval_metrics.compute())

            for hook in hooks:
                hook(step)

    chkpt_manager.save(state)
    return state
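
This train() and the evaluate() in example #14 install the same SIGTERM handler, relying on Python's late-binding closures to save whatever `state` last referred to in the enclosing scope. A minimal sketch of the pattern in isolation, with a hypothetical get_state callable that makes the captured state explicit:

import signal
from absl import logging

def install_preemption_handler(chkpt_manager, get_state):
    # On SIGTERM (e.g. preemption), save a checkpoint before exiting.
    def _handler(signal_number, _):
        logging.info('Received signal %d, saving checkpoint.', signal_number)
        chkpt_manager.save(get_state())
        logging.info('Finished saving checkpoint.')
    signal.signal(signal.SIGTERM, _handler)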