示例#1
0
    def test_on_steps(self):
        """PeriodicCallback with `on_steps` fires exactly once, at the listed step."""
        spy = mock.Mock()
        periodic = periodic_actions.PeriodicCallback(
            on_steps=[8], callback_fn=spy)

        # Feed steps 1..9; only step 8 appears in `on_steps`.
        for current_step in range(1, 10):
            periodic(current_step, remainder=current_step % 3)

        # 8 % 3 == 2. No explicit t is passed, so the time value is not pinned.
        spy.assert_called_once_with(remainder=2, step=8, t=mock.ANY)
示例#2
0
    def test_function_without_step_and_time(self):
        """A zero-argument callback works when step/time are not forwarded."""

        # The callback accepts no arguments, so the hook has to be created
        # with pass_step_and_time=False.
        def constant_five():
            return 5

        periodic = periodic_actions.PeriodicCallback(
            every_steps=1,
            callback_fn=constant_five,
            pass_step_and_time=False,
        )
        for current_step in (0, 1):
            periodic(current_step)
        self.assertEqual(periodic.get_last_callback_result(), 5)
示例#3
0
    def test_error_async_is_forwarded(self):
        """An error raised inside an async callback surfaces on a later call."""

        def failing_callback(step, t):
            del step, t  # Unused.
            raise Exception

        periodic = periodic_actions.PeriodicCallback(
            every_steps=1,
            callback_fn=failing_callback,
            execute_async=True,
        )

        # With execute_async=True the first call is expected not to raise itself.
        periodic(0)

        # The earlier failure is expected to be re-raised by the next call.
        with self.assertRaises(Exception):
            periodic(1)
示例#4
0
    def test_every_secs(self, mock_time):
        """`every_secs` triggers on elapsed (mocked) time rather than steps.

        `mock_time` is presumably injected by a patch decorator that is not
        visible here (TODO confirm which clock function it replaces).
        """
        spy = mock.Mock()
        periodic = periodic_actions.PeriodicCallback(
            every_secs=2, callback_fn=spy)

        for current_step in range(1, 10):
            # The mocked clock reads exactly `current_step` seconds.
            mock_time.return_value = float(current_step)
            periodic(current_step, remainder=current_step % 5)

        # Note: time will be initialized at 1 so hook runs at steps 4 & 7.
        self.assertListEqual(
            [
                mock.call(remainder=4, step=4, t=4.0),
                mock.call(remainder=2, step=7, t=7.0),
            ],
            spy.call_args_list,
        )
示例#5
0
    def test_every_steps(self):
        """`every_steps=2` fires on even steps and forwards t and extra kwargs."""
        spy = mock.Mock()
        periodic = periodic_actions.PeriodicCallback(
            every_steps=2, callback_fn=spy)

        # A fixed t=3 is passed positionally so the recorded calls are exact.
        for current_step in range(1, 10):
            periodic(current_step, 3, remainder=current_step % 3)

        # Even steps in 1..9 are 2, 4, 6, 8.
        expected_calls = [
            mock.call(remainder=s % 3, step=s, t=3) for s in (2, 4, 6, 8)
        ]
        self.assertListEqual(expected_calls, spy.call_args_list)
示例#6
0
    def test_async_execution(self):
        """Async callbacks run in the order they were submitted."""
        seen_steps = []

        def record_step(step, t):
            del t  # Unused.
            seen_steps.append(step)

        periodic = periodic_actions.PeriodicCallback(
            every_steps=1,
            callback_fn=record_step,
            execute_async=True,
        )
        for current_step in range(4):
            periodic(current_step)
        # In async mode the stored result is future-like; waiting on the most
        # recent one blocks until all hooks have finished.
        periodic.get_last_callback_result().result()
        # Execution order must match submission order.
        self.assertListEqual(seen_steps, [0, 1, 2, 3])
示例#7
0
def train(*,
          workdir,
          compute_phi,
          compute_psi,
          params,
          optimal_subspace,
          num_epochs,
          learning_rate,
          key,
          method,
          lissa_kappa,
          optimizer,
          covariance_batch_size,
          main_batch_size,
          weight_batch_size,
          d,
          num_tasks,
          compute_feature_norm_on_oracle_states,
          sample_states,
          eval_states,
          use_tabular_gradient=True):
    """Trains a representation, logging metrics and checkpointing to `workdir`.

    For lissa, the total number of samples is
    2 x covariance_batch_size + main_batch_size + 2 x weight_batch_size.

    If a checkpoint already exists in `workdir`, training resumes from its
    saved step, params and optimizer state.

    Args:
      workdir: Work directory where logs and checkpoints are written.
      compute_phi: A function that takes params and states and returns
        a matrix of phis.
      compute_psi: A function that takes an array of states and an array
        of tasks and returns Psi[states, tasks].
      params: Parameters used as the first argument for compute_phi. The
        'phi_params' entry is what gets evaluated. The caller's dict is not
        modified.
      optimal_subspace: Top-d left singular vectors of Psi.
      num_epochs: How many gradient steps to perform. (Not really epochs)
      learning_rate: The step size parameter for the optimizer.
      key: The jax prng key.
      method: Training method forwarded to `_train_step`, e.g. 'naive',
        'lissa', 'oracle' or 'explicit'. For 'explicit', a random normal
        float32 matrix of shape (d, num_tasks) is added to the params under
        'explicit_weight_matrix'.
      lissa_kappa: The parameter of the lissa method, if used.
      optimizer: Which optimizer to use: 'sgd' or 'adam'.
      covariance_batch_size: the 'J' parameter. For the naive method, this is how
        many states we sample to construct the inverse. For the lissa method,
        ditto -- these are also "iterations".
      main_batch_size: How many states to update at once.
      weight_batch_size: How many states to construct the weight vector.
      d: The dimension of the representation.
      num_tasks: The total number of tasks.
      compute_feature_norm_on_oracle_states: If True, computes the feature norm
        using the oracle states (all the states in synthetic experiments).
        Otherwise, computes the norm using the sampled batch.
        Only applies to LISSA.
      sample_states: A function that takes an rng key and a number of states
        to sample, and returns a tuple containing
        (a vector of sampled states, an updated rng key).
      eval_states: An array of states to use to compute metrics on.
        This will be used to compute Phi = compute_phi(params, eval_states).
      use_tabular_gradient: If true, the train step will calculate the
        gradient using the tabular calculation. Otherwise, it will use a
        jax.vjp to backpropagate the gradient.

    Raises:
      ValueError: If `optimizer` is neither 'sgd' nor 'adam'.
    """
    # Create an explicit weight matrix (needed for the explicit method only).
    # Build a new dict so the caller's `params` is left untouched.
    if method == 'explicit':
        key, weight_key = jax.random.split(key)
        explicit_weight_matrix = jax.random.normal(weight_key, (d, num_tasks),
                                                   dtype=jnp.float32)
        params = {**params, 'explicit_weight_matrix': explicit_weight_matrix}

    if optimizer == 'sgd':
        optimizer = optax.sgd(learning_rate)
    elif optimizer == 'adam':
        optimizer = optax.adam(learning_rate)
    else:
        raise ValueError(f'Unknown optimizer {optimizer}.')
    optimizer_state = optimizer.init(params)

    # Checkpoints go to the same `workdir` as the logs, so this function
    # depends only on its arguments and not on a module-level flag.
    chkpt_manager = checkpoint.Checkpoint(base_directory=str(workdir))
    # Returns either the restored (step, params, optimizer_state) or, when no
    # checkpoint exists yet, the tuple passed in.
    # NOTE(review): the PRNG key is not part of the checkpoint, so a resumed
    # run replays the random stream starting from the `key` argument.
    initial_step, params, optimizer_state = chkpt_manager.restore_or_initialize(
        (0, params, optimizer_state))

    writer = metric_writers.create_default_writer(logdir=str(workdir))

    # Checkpointing and logging too much can use a lot of disk space.
    # Therefore, we don't want to checkpoint more than 10 times an experiment,
    # or keep more than 1k Phis per experiment.
    checkpoint_period = max(num_epochs // 10, 100_000)
    log_period = max(1_000, num_epochs // 1_000)

    def _checkpoint_callback(step, t, params, optimizer_state):
        del t  # Unused.
        chkpt_manager.save((step, params, optimizer_state))

    hooks = [
        periodic_actions.PeriodicCallback(every_steps=checkpoint_period,
                                          callback_fn=_checkpoint_callback)
    ]

    # Arguments to `_train_step` that stay constant across steps.
    fixed_train_kwargs = {
        'compute_phi': compute_phi,
        'compute_psi': compute_psi,
        'optimizer': optimizer,
        'method': method,
        # In the tabular case, the eval_states are all the states.
        'oracle_states': eval_states,
        'lissa_kappa': lissa_kappa,
        'main_batch_size': main_batch_size,
        'covariance_batch_size': covariance_batch_size,
        'weight_batch_size': weight_batch_size,
        'd': d,
        'num_tasks': num_tasks,
        'compute_feature_norm_on_oracle_states':
            compute_feature_norm_on_oracle_states,
        'sample_states': sample_states,
        'use_tabular_gradient': use_tabular_gradient,
    }
    # State threaded through `_train_step`: its return value replaces this dict.
    variable_kwargs = {
        'params': params,
        'optimizer_state': optimizer_state,
        'key': key,
    }

    @jax.jit
    def _eval_step(phi_params):
        """Computes subspace metrics plus 'frob_norm' on `eval_states`."""
        eval_phi = compute_phi(phi_params, eval_states)
        eval_psi = compute_psi(eval_states)  # pytype: disable=wrong-arg-count

        metrics = compute_metrics(eval_phi, optimal_subspace)
        metrics |= {'frob_norm': utils.outer_objective_mc(eval_phi, eval_psi)}
        return metrics

    # Perform num_epochs gradient steps, resuming after `initial_step`.
    with metric_writers.ensure_flushes(writer):
        for step in etqdm.tqdm(range(initial_step + 1, num_epochs + 1),
                               initial=initial_step,
                               total=num_epochs):

            variable_kwargs = _train_step(**fixed_train_kwargs,
                                          **variable_kwargs)

            if step % log_period == 0:
                metrics = _eval_step(variable_kwargs['params']['phi_params'])
                writer.write_scalars(step, metrics)

            for hook in hooks:
                hook(step,
                     params=variable_kwargs['params'],
                     optimizer_state=variable_kwargs['optimizer_state'])

    writer.flush()
def train(*,
          workdir,
          initial_step,
          chkpt_manager,
          Phi,
          Psi,
          optimal_subspace,
          num_epochs,
          learning_rate,
          key,
          method,
          lissa_kappa,
          optimizer,
          covariance_batch_size,
          main_batch_size,
          weight_batch_size,
          estimate_feature_norm=True):
    """Training function.

  For lissa, the total number of samples is
  2 x covariance_batch_size + main_batch_size + 2 x weight_batch_size.

  Args:
    workdir: Work directory, where we'll save logs.
    initial_step: Initial step
    chkpt_manager: Checkpoint manager.
    Phi: The initial feature matrix.
    Psi: The target matrix whose PCA is to be determined.
    optimal_subspace: Top-d left singular vectors of Psi.
    num_epochs: How many gradient steps to perform. (Not really epochs)
    learning_rate: The step size parameter for sgd.
    key: The jax prng key.
    method: 'naive', 'lissa', or 'oracle'.
    lissa_kappa: The parameter of the lissa method, if used.
    optimizer: Which optimizer to use. Only 'sgd' is supported.
    covariance_batch_size: the 'J' parameter. For the naive method, this is how
      many states we sample to construct the inverse. For the lissa method,
      ditto -- these are also "iterations".
    main_batch_size: How many states to update at once.
    weight_batch_size: How many states to construct the weight vector.
    estimate_feature_norm: Whether to use a running average of the max feature
      norm rather than the real maximum.

  Returns:
    tuple: representation and gradient arrays
  """
    # Don't overwrite Phi.
    Phi = np.copy(Phi)
    Phis = [np.copy(Phi)]

    num_states, d = Phi.shape
    _, num_tasks = Psi.shape

    # Keep a running average of the max norm of a feature vector. None means:
    # don't do it.
    if estimate_feature_norm:
        estimated_feature_norm = utils.compute_max_feature_norm(Phi)
    else:
        estimated_feature_norm = None

    # Create an explicit weight vector (needed for explicit method).
    key, weight_key = jax.random.split(key)
    explicit_weight_matrix = np.array(
        jax.random.normal(  # charlinel(why benefit of np?)
            weight_key, (d, num_tasks),
            dtype=jnp.float64))

    assert optimizer == 'sgd', 'Non-sgd not yet supported.'

    writer = metric_writers.create_default_writer(logdir=str(workdir), )

    hooks = [
        periodic_actions.PeriodicCallback(
            every_steps=5_000,
            callback_fn=lambda step, t: chkpt_manager.save((step, Phi)))
    ]