def test_on_steps(self):
  callback = mock.Mock()
  hook = periodic_actions.PeriodicCallback(on_steps=[8], callback_fn=callback)
  for step in range(1, 10):
    hook(step, remainder=step % 3)
  callback.assert_called_once_with(remainder=2, step=8, t=mock.ANY)

def test_function_without_step_and_time(self):
  # A callback that takes no arguments must be used with
  # pass_step_and_time=False.
  def cb():
    return 5

  hook = periodic_actions.PeriodicCallback(
      every_steps=1, callback_fn=cb, pass_step_and_time=False)
  hook(0)
  hook(1)
  self.assertEqual(hook.get_last_callback_result(), 5)

def test_error_async_is_forwarded(self):
  def cb(step, t):
    del step
    del t
    raise Exception

  hook = periodic_actions.PeriodicCallback(
      every_steps=1, callback_fn=cb, execute_async=True)
  hook(0)
  with self.assertRaises(Exception):
    hook(1)

def test_every_secs(self, mock_time):
  # `mock_time` is supplied by a mock.patch decorator on this test
  # (not shown in this excerpt).
  callback = mock.Mock()
  hook = periodic_actions.PeriodicCallback(every_secs=2, callback_fn=callback)
  for step in range(1, 10):
    mock_time.return_value = float(step)
    hook(step, remainder=step % 5)
  # Note: time is initialized at 1, so the hook fires at steps 4 and 7.
  expected_calls = [
      mock.call(remainder=4, step=4, t=4.0),
      mock.call(remainder=2, step=7, t=7.0),
  ]
  self.assertListEqual(expected_calls, callback.call_args_list)

def test_every_steps(self):
  callback = mock.Mock()
  hook = periodic_actions.PeriodicCallback(every_steps=2, callback_fn=callback)
  for step in range(1, 10):
    hook(step, 3, remainder=step % 3)
  expected_calls = [
      mock.call(remainder=2, step=2, t=3),
      mock.call(remainder=1, step=4, t=3),
      mock.call(remainder=0, step=6, t=3),
      mock.call(remainder=2, step=8, t=3),
  ]
  self.assertListEqual(expected_calls, callback.call_args_list)

def test_async_execution(self):
  out = []

  def cb(step, t):
    del t
    out.append(step)

  hook = periodic_actions.PeriodicCallback(
      every_steps=1, callback_fn=cb, execute_async=True)
  hook(0)
  hook(1)
  hook(2)
  hook(3)
  # Block until all the hooks have finished.
  hook.get_last_callback_result().result()
  # Check that the order of execution is preserved.
  self.assertListEqual(out, [0, 1, 2, 3])

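# A minimal usage sketch (not part of the test suite) of how a training loop
# might wire up PeriodicCallback; `train_step` and `num_steps` are
# hypothetical placeholders:
#
#   def log_loss(step, t, loss):
#     print(f'step={step} t={t:.1f} loss={loss:.3f}')
#
#   hook = periodic_actions.PeriodicCallback(
#       every_steps=100, callback_fn=log_loss)
#   for step in range(1, num_steps + 1):
#     loss = train_step()
#     # Extra keyword arguments (here `loss`) are forwarded to callback_fn
#     # along with step and t, as the tests above verify.
#     hook(step, loss=loss)
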
def train(*,
          workdir,
          compute_phi,
          compute_psi,
          params,
          optimal_subspace,
          num_epochs,
          learning_rate,
          key,
          method,
          lissa_kappa,
          optimizer,
          covariance_batch_size,
          main_batch_size,
          weight_batch_size,
          d,
          num_tasks,
          compute_feature_norm_on_oracle_states,
          sample_states,
          eval_states,
          use_tabular_gradient=True):
  """Training function.

  For lissa, the total number of samples is
  2 x covariance_batch_size + main_batch_size + 2 x weight_batch_size.

  Args:
    workdir: Work directory, where we'll save logs.
    compute_phi: A function that takes params and states and returns a matrix
      of phis.
    compute_psi: A function that takes an array of states and an array of
      tasks and returns Psi[states, tasks].
    params: Parameters used as the first argument for compute_phi.
    optimal_subspace: Top-d left singular vectors of Psi.
    num_epochs: How many gradient steps to perform. (Not really epochs.)
    learning_rate: The step size parameter for sgd.
    key: The jax prng key.
    method: 'naive', 'lissa', or 'oracle'.
    lissa_kappa: The parameter of the lissa method, if used.
    optimizer: Which optimizer to use. 'sgd' and 'adam' are supported.
    covariance_batch_size: The 'J' parameter. For the naive method, this is
      how many states we sample to construct the inverse. For the lissa
      method, ditto -- these are also "iterations".
    main_batch_size: How many states to update at once.
    weight_batch_size: How many states to use to construct the weight vector.
    d: The dimension of the representation.
    num_tasks: The total number of tasks.
    compute_feature_norm_on_oracle_states: If True, computes the feature norm
      using the oracle states (all the states in synthetic experiments).
      Otherwise, computes the norm using the sampled batch. Only applies to
      LISSA.
    sample_states: A function that takes an rng key and a number of states to
      sample, and returns a tuple containing (a vector of sampled states, an
      updated rng key).
    eval_states: An array of states used to compute metrics. This will be
      used to compute Phi = compute_phi(params, eval_states).
    use_tabular_gradient: If True, the train step will calculate the gradient
      using the tabular calculation. Otherwise, it will use a jax.vjp to
      backpropagate the gradient.
  """
  # Create an explicit weight vector (needed for the explicit method only).
  if method == 'explicit':
    key, weight_key = jax.random.split(key)
    explicit_weight_matrix = jax.random.normal(
        weight_key, (d, num_tasks), dtype=jnp.float32)
    params['explicit_weight_matrix'] = explicit_weight_matrix

  if optimizer == 'sgd':
    optimizer = optax.sgd(learning_rate)
  elif optimizer == 'adam':
    optimizer = optax.adam(learning_rate)
  else:
    raise ValueError(f'Unknown optimizer {optimizer}.')
  optimizer_state = optimizer.init(params)

  chkpt_manager = checkpoint.Checkpoint(base_directory=_WORKDIR.value)
  initial_step, params, optimizer_state = chkpt_manager.restore_or_initialize(
      (0, params, optimizer_state))

  writer = metric_writers.create_default_writer(logdir=str(workdir))

  # Checkpointing and logging too much can use a lot of disk space.
  # Therefore, we don't want to checkpoint more than 10 times per experiment,
  # or keep more than 1k Phis per experiment.
  checkpoint_period = max(num_epochs // 10, 100_000)
  log_period = max(1_000, num_epochs // 1_000)

  def _checkpoint_callback(step, t, params, optimizer_state):
    del t  # Unused.
    chkpt_manager.save((step, params, optimizer_state))

  hooks = [
      periodic_actions.PeriodicCallback(
          every_steps=checkpoint_period, callback_fn=_checkpoint_callback)
  ]

  fixed_train_kwargs = {
      'compute_phi': compute_phi,
      'compute_psi': compute_psi,
      'optimizer': optimizer,
      'method': method,
      # In the tabular case, the eval_states are all the states.
      'oracle_states': eval_states,
      'lissa_kappa': lissa_kappa,
      'main_batch_size': main_batch_size,
      'covariance_batch_size': covariance_batch_size,
      'weight_batch_size': weight_batch_size,
      'd': d,
      'num_tasks': num_tasks,
      'compute_feature_norm_on_oracle_states':
          compute_feature_norm_on_oracle_states,
      'sample_states': sample_states,
      'use_tabular_gradient': use_tabular_gradient,
  }
  variable_kwargs = {
      'params': params,
      'optimizer_state': optimizer_state,
      'key': key,
  }

  @jax.jit
  def _eval_step(phi_params):
    eval_phi = compute_phi(phi_params, eval_states)
    eval_psi = compute_psi(eval_states)  # pytype: disable=wrong-arg-count
    metrics = compute_metrics(eval_phi, optimal_subspace)
    metrics |= {'frob_norm': utils.outer_objective_mc(eval_phi, eval_psi)}
    return metrics

  # Perform num_epochs gradient steps.
  with metric_writers.ensure_flushes(writer):
    for step in etqdm.tqdm(
        range(initial_step + 1, num_epochs + 1),
        initial=initial_step,
        total=num_epochs):
      variable_kwargs = _train_step(**fixed_train_kwargs, **variable_kwargs)

      if step % log_period == 0:
        metrics = _eval_step(variable_kwargs['params']['phi_params'])
        writer.write_scalars(step, metrics)

      for hook in hooks:
        hook(
            step,
            params=variable_kwargs['params'],
            optimizer_state=variable_kwargs['optimizer_state'])

  writer.flush()

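# A minimal sketch (hypothetical helper, not called above) making the sample
# budget from the train() docstring concrete: per gradient step, lissa draws
# 2 * covariance_batch_size + main_batch_size + 2 * weight_batch_size states,
# e.g. (32, 64, 32) -> 2 * 32 + 64 + 2 * 32 = 192.
def _lissa_samples_per_step(covariance_batch_size, main_batch_size,
                            weight_batch_size):
  return (2 * covariance_batch_size + main_batch_size +
          2 * weight_batch_size)
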
def train(*,
          workdir,
          initial_step,
          chkpt_manager,
          Phi,
          Psi,
          optimal_subspace,
          num_epochs,
          learning_rate,
          key,
          method,
          lissa_kappa,
          optimizer,
          covariance_batch_size,
          main_batch_size,
          weight_batch_size,
          estimate_feature_norm=True):
  """Training function.

  For lissa, the total number of samples is
  2 x covariance_batch_size + main_batch_size + 2 x weight_batch_size.

  Args:
    workdir: Work directory, where we'll save logs.
    initial_step: Initial step.
    chkpt_manager: Checkpoint manager.
    Phi: The initial feature matrix.
    Psi: The target matrix whose PCA is to be determined.
    optimal_subspace: Top-d left singular vectors of Psi.
    num_epochs: How many gradient steps to perform. (Not really epochs.)
    learning_rate: The step size parameter for sgd.
    key: The jax prng key.
    method: 'naive', 'lissa', or 'oracle'.
    lissa_kappa: The parameter of the lissa method, if used.
    optimizer: Which optimizer to use. Only 'sgd' is supported.
    covariance_batch_size: The 'J' parameter. For the naive method, this is
      how many states we sample to construct the inverse. For the lissa
      method, ditto -- these are also "iterations".
    main_batch_size: How many states to update at once.
    weight_batch_size: How many states to use to construct the weight vector.
    estimate_feature_norm: Whether to use a running average of the max
      feature norm rather than the real maximum.

  Returns:
    tuple: representation and gradient arrays.
  """
  # Don't overwrite Phi.
  Phi = np.copy(Phi)
  Phis = [np.copy(Phi)]

  num_states, d = Phi.shape
  _, num_tasks = Psi.shape

  # Keep a running average of the max norm of a feature vector. None means:
  # don't do it.
  if estimate_feature_norm:
    estimated_feature_norm = utils.compute_max_feature_norm(Phi)
  else:
    estimated_feature_norm = None

  # Create an explicit weight vector (needed for the explicit method).
  key, weight_key = jax.random.split(key)
  explicit_weight_matrix = np.array(
      jax.random.normal(  # TODO(charlinel): What is the benefit of np here?
          weight_key, (d, num_tasks), dtype=jnp.float64))

  assert optimizer == 'sgd', 'Non-sgd not yet supported.'

  writer = metric_writers.create_default_writer(logdir=str(workdir))

  hooks = [
      periodic_actions.PeriodicCallback(
          every_steps=5_000,
          callback_fn=lambda step, t: chkpt_manager.save((step, Phi)))
  ]