def create_optax_optim(name, learning_rate=None, momentum=0.9, weight_decay=0, **kwargs):
    """Optimizer factory.

    Args:
        name (str): name of the optimizer to create (case-insensitive)
        learning_rate (float): specify learning rate, or leave up to scheduler / optimizer if None
        momentum (float): momentum coefficient for optimizers that support it
        weight_decay (float): weight decay to apply to all params; not applied if 0
        **kwargs: optional / optimizer-specific params that override defaults

    Regarding the kwargs, I've tried to keep the parameter names coming in via
    kwargs from the config file consistent, so there is less variation. Names of
    common args such as eps, beta1, beta2, etc. are remapped where possible
    (even if the optimizer impl uses a different name) and removed when not needed.

    Some common params to use in config files, as named:
        eps (float): default stability / regularization epsilon value
        beta1 (float): moving average / momentum coefficient for the gradient
        beta2 (float): moving average / momentum coefficient for the gradient magnitude (squared grad)
    """
    name = name.lower()
    opt_args = dict(learning_rate=learning_rate, **kwargs)
    _rename(opt_args, ('beta1', 'beta2'), ('b1', 'b2'))
    if name in ('sgd', 'momentum', 'nesterov'):
        _erase(opt_args, ('eps',))
        if name == 'momentum':
            optimizer = optax.sgd(momentum=momentum, **opt_args)
        elif name == 'nesterov':
            optimizer = optax.sgd(momentum=momentum, nesterov=True, **opt_args)
        else:
            assert name == 'sgd'
            optimizer = optax.sgd(momentum=0, **opt_args)
    elif name == 'adabelief':
        optimizer = optax.adabelief(**opt_args)
    elif name == 'adamw':
        optimizer = optax.adamw(weight_decay=weight_decay, **opt_args)
    elif name == 'adam':
        optimizer = optax.adam(**opt_args)
    elif name == 'lamb':
        optimizer = optax.lamb(weight_decay=weight_decay, **opt_args)
    elif name == 'lars':
        optimizer = lars(weight_decay=weight_decay, **opt_args)
    elif name == 'rmsprop':
        optimizer = optax.rmsprop(momentum=momentum, **opt_args)
    elif name == 'rmsproptf':
        optimizer = optax.rmsprop(momentum=momentum, initial_scale=1.0, **opt_args)
    else:
        assert False, f"Invalid optimizer name specified ({name})"
    return optimizer
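# A minimal usage sketch of the factory above, assuming optax is installed and
# the module-level `_rename` / `_erase` helpers are in scope:
import jax
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((3,))}
tx = create_optax_optim("rmsprop", learning_rate=1e-3, momentum=0.9, eps=1e-8)
opt_state = tx.init(params)

grads = jax.grad(lambda p: jnp.sum(p["w"] ** 2))(params)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)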
def __init__(
    self,
    learning_rate: float = 0.001,
    decay: float = 0.9,
    eps: float = 1e-8,
    initial_scale: float = 0.0,
    centered: bool = False,
    momentum: Optional[float] = None,
    nesterov: bool = False,
):
    super().__init__(learning_rate=learning_rate)
    self._decay = decay
    self._eps = eps
    self._initial_scale = initial_scale
    self._centered = centered
    self._momentum = momentum
    self._nesterov = nesterov
    self._optimizer = optax.rmsprop(
        learning_rate=learning_rate,
        decay=decay,
        eps=eps,
        initial_scale=initial_scale,
        centered=centered,
        momentum=momentum,
        nesterov=nesterov,
    )
    # Compile the update function once so every subsequent call is fast.
    self._optimizer_update = jit(self._optimizer.update)
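# A hedged sketch of exercising this wrapper's pre-jitted update, reaching
# into the private attributes purely for illustration (the base class behind
# `super().__init__` is not shown in the source):
import jax
import jax.numpy as jnp
import optax

opt = RMSProp(learning_rate=0.01, momentum=0.9)  # the class defined above
params = {"w": jnp.zeros((2,))}
opt_state = opt._optimizer.init(params)

grads = jax.grad(lambda p: jnp.sum((p["w"] - 1.0) ** 2))(params)
updates, opt_state = opt._optimizer_update(grads, opt_state, params)
params = optax.apply_updates(params, updates)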
def test_example_restore(self):
    class MLP(eg.Module):
        @eg.compact
        def __call__(self, x):
            x = eg.Linear(10)(x)
            x = jax.lax.stop_gradient(x)
            return x

    # This callback will stop training when there is no improvement in the
    # loss for three consecutive epochs.
    model = eg.Model(
        module=MLP(),
        loss=eg.losses.MeanSquaredError(),
        optimizer=optax.rmsprop(0.01),
    )
    history = model.fit(
        inputs=np.ones((5, 20)),
        labels=np.zeros((5, 10)),
        epochs=10,
        batch_size=1,
        callbacks=[
            eg.callbacks.EarlyStopping(monitor="loss", patience=3, restore_best_weights=True)
        ],
        verbose=0,
    )
    assert len(history.history["loss"]) == 4  # Only 4 epochs are run.
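# Why exactly four epochs: `stop_gradient` zeroes the gradients, so the loss
# never improves after the first epoch; the first epoch sets the baseline and
# three non-improving epochs exhaust patience=3. A tiny sketch of that
# bookkeeping, assuming Keras-style patience semantics:
best, wait, patience = float("inf"), 0, 3
for epoch, loss in enumerate([0.5, 0.5, 0.5, 0.5, 0.5], start=1):
    if loss < best:
        best, wait = loss, 0
    else:
        wait += 1
        if wait >= patience:
            break
print(epoch)  # 4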
def main(_): # A thunk that builds a new environment. # Substitute your environment here! build_env = catch.Catch # Construct the agent. We need a sample environment for its spec. env_for_spec = build_env() num_actions = env_for_spec.action_spec().num_values agent = agent_lib.Agent(num_actions, env_for_spec.observation_spec(), haiku_nets.CatchNet) # Construct the optimizer. max_updates = MAX_ENV_FRAMES / FRAMES_PER_ITER opt = optax.rmsprop(5e-3, decay=0.99, eps=1e-7) # Construct the learner. learner = learner_lib.Learner( agent, jax.random.PRNGKey(428), opt, BATCH_SIZE, DISCOUNT_FACTOR, FRAMES_PER_ITER, max_abs_reward=1., logger=util.AbslLogger(), # Provide your own logger here. ) # Construct the actors on different threads. # stop_signal in a list so the reference is shared. actor_threads = [] stop_signal = [False] for i in range(NUM_ACTORS): actor = actor_lib.Actor( agent, build_env(), UNROLL_LENGTH, learner, rng_seed=i, logger=util.AbslLogger(), # Provide your own logger here. ) args = (actor, stop_signal) actor_threads.append(threading.Thread(target=run_actor, args=args)) # Start the actors and learner. for t in actor_threads: t.start() learner.run(int(max_updates)) # Stop. stop_signal[0] = True for t in actor_threads: t.join()
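# The shared one-element `stop_signal` list above is a minimal cross-thread
# stop flag: every actor thread holds a reference to the same list, so
# flipping the element in the main thread is visible to all of them. A
# self-contained illustration of the pattern (not the actual `run_actor`
# from this example):
import threading
import time

def worker(stop_signal):
    while not stop_signal[0]:
        time.sleep(0.01)  # stand-in for environment interaction

stop = [False]
t = threading.Thread(target=worker, args=(stop,))
t.start()
stop[0] = True  # request shutdown
t.join()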
def get_optimizer(optimizer_name: OptimizerName,
                  learning_rate: float,
                  momentum: float = 0.0,
                  adam_beta1: float = 0.9,
                  adam_beta2: float = 0.999,
                  adam_epsilon: float = 1e-8,
                  rmsprop_decay: float = 0.9,
                  rmsprop_epsilon: float = 1e-8,
                  adagrad_init_accumulator: float = 0.1,
                  adagrad_epsilon: float = 1e-6) -> Optimizer:
  """Given parameters, returns the corresponding optimizer.

  Args:
    optimizer_name: One of SGD, MOMENTUM, ADAM, RMSPROP, ADAGRAD.
    learning_rate: Learning rate for all optimizers.
    momentum: Momentum parameter for MOMENTUM.
    adam_beta1: beta1 parameter for ADAM.
    adam_beta2: beta2 parameter for ADAM.
    adam_epsilon: epsilon parameter for ADAM.
    rmsprop_decay: decay parameter for RMSPROP.
    rmsprop_epsilon: epsilon parameter for RMSPROP.
    adagrad_init_accumulator: initial accumulator for ADAGRAD.
    adagrad_epsilon: epsilon parameter for ADAGRAD.

  Returns:
    The Optimizer with the specified properties.

  Raises:
    ValueError: if the optimizer name is not one of SGD, MOMENTUM, ADAM,
      RMSPROP, or ADAGRAD.
  """
  if optimizer_name == OptimizerName.SGD:
    return Optimizer(*optax.sgd(learning_rate))
  elif optimizer_name == OptimizerName.MOMENTUM:
    return Optimizer(*optax.sgd(learning_rate, momentum))
  elif optimizer_name == OptimizerName.ADAM:
    return Optimizer(*optax.adam(
        learning_rate, b1=adam_beta1, b2=adam_beta2, eps=adam_epsilon))
  elif optimizer_name == OptimizerName.RMSPROP:
    return Optimizer(
        *optax.rmsprop(learning_rate, decay=rmsprop_decay, eps=rmsprop_epsilon))
  elif optimizer_name == OptimizerName.ADAGRAD:
    return Optimizer(*optax.adagrad(
        learning_rate,
        initial_accumulator_value=adagrad_init_accumulator,
        eps=adagrad_epsilon))
  else:
    raise ValueError(f'Unsupported optimizer_name {optimizer_name}.')
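# The `Optimizer(*optax.sgd(...))` unpacking suggests `Optimizer` is a named
# pair wrapping optax's (init, update) functions; a minimal sketch of that
# assumption, plus usage:
from typing import Callable, NamedTuple

import jax.numpy as jnp
import optax

class Optimizer(NamedTuple):  # assumed shape, mirroring GradientTransformation
    init_fn: Callable
    update_fn: Callable

opt = Optimizer(*optax.sgd(0.1))
params = jnp.zeros(3)
state = opt.init_fn(params)
updates, state = opt.update_fn(jnp.ones(3), state)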
def make_optimizer(optimizer_type):
  """Constructs optimizer."""
  if optimizer_type == 'rmsprop':
    learning_rate = 0.00025
    epsilon = 0.01 / (32**2)
    optimizer = optax.rmsprop(
        learning_rate=learning_rate, decay=0.95, eps=epsilon, centered=True)
  elif optimizer_type == 'adam':
    learning_rate = 0.00005
    epsilon = 0.01 / 32
    optimizer = optax.adam(learning_rate=learning_rate, eps=epsilon)
  else:
    raise ValueError('Unknown optimizer "{}"'.format(optimizer_type))
  return optimizer
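# A quick smoke test of the factory above (assumes jax and optax are
# importable):
import jax.numpy as jnp

opt = make_optimizer('rmsprop')
params = {'w': jnp.zeros(4)}
state = opt.init(params)
updates, state = opt.update({'w': jnp.ones(4)}, state)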
def run(*, trajectories_per_actor, num_actors, unroll_len): """Runs the example.""" # Construct the agent network. We need a sample environment for its spec. env = catch.Catch() num_actions = env.action_spec().num_values net = hk.without_apply_rng( hk.transform(lambda ts: SimpleNet(num_actions)(ts))) # pylint: disable=unnecessary-lambda # Construct the agent and learner. agent = Agent(net.apply) opt = optax.rmsprop(5e-3, decay=0.99, eps=1e-7) learner = Learner(agent, opt.update) # Initialize the optimizer state. sample_ts = env.reset() sample_ts = preprocess_step(sample_ts) ts_with_batch = jax.tree_map(lambda t: np.expand_dims(t, 0), sample_ts) params = jax.jit(net.init)(jax.random.PRNGKey(428), ts_with_batch) opt_state = opt.init(params) # Create accessor and queueing functions. current_params = lambda: params batch_size = 2 q = queue.Queue(maxsize=batch_size) def dequeue(): batch = [] for _ in range(batch_size): batch.append(q.get()) batch = jax.tree_map(lambda *ts: np.stack(ts, axis=1), *batch) return jax.device_put(batch) # Start the actors. for i in range(num_actors): key = jax.random.PRNGKey(i) args = (agent, key, current_params, q.put, unroll_len, trajectories_per_actor) threading.Thread(target=run_actor, args=args).start() # Run the learner. num_steps = num_actors * trajectories_per_actor // batch_size for i in range(num_steps): traj = dequeue() params, opt_state = learner.update(params, opt_state, traj)
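# The `dequeue` helper above stacks `batch_size` unrolled trajectories along a
# new axis 1, turning per-actor [T, ...] trees into [T, B, ...] batches; a
# standalone illustration of that tree-stacking:
import numpy as np
import jax

t1 = {"obs": np.zeros((3, 5)), "reward": np.zeros((3,))}
t2 = {"obs": np.ones((3, 5)), "reward": np.ones((3,))}
batch = jax.tree_map(lambda *ts: np.stack(ts, axis=1), t1, t2)
print(batch["obs"].shape)     # (3, 2, 5)
print(batch["reward"].shape)  # (3, 2)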
def make(observation_spec: specs.Array, action_spec: specs.DiscreteArray, rnn_hidden_size: int = 32, encoding_hidden_size: List[int] = [256, 128, 64], buffer_length: int = 120, discount: float = .5, td_lambda: float = .9, entropy_cost: float = 1., critic_cost: float = 1., seed: int = 0): """Creates a default agent.""" initial_rnn_state = jnp.zeros((1, rnn_hidden_size), dtype=jnp.float32) def network(inputs: List[jnp.ndarray], state) -> ModelOutput: observation = hk.Flatten()(inputs[0]).reshape((1, -1)) previous_reward = inputs[1].reshape((1, 1)) previous_action = inputs[2].reshape((1, -1)) torso = hk.nets.MLP(encoding_hidden_size) gru = hk.GRU(rnn_hidden_size) policy_head = hk.Linear(action_spec.num_values) value_head = hk.Linear(1) input_embedding = jnp.concatenate( [observation, previous_reward, previous_action], -1) input_embedding = torso(input_embedding) embedding, state = gru(input_embedding, state) logits = policy_head(embedding) value = value_head(embedding) return (logits, jnp.squeeze(value, axis=-1), embedding, embedding, embedding), state return ActorCriticRNN(observation_spec=observation_spec, action_spec=action_spec, network=network, initial_rnn_state=initial_rnn_state, optimizer=optax.rmsprop(1e-3), rng=hk.PRNGSequence(seed), buffer_length=buffer_length, discount=discount, td_lambda=td_lambda, entropy_cost=entropy_cost, critic_cost=critic_cost)
def rmsprop(learning_rate: ScalarOrSchedule,
            decay: float = 0.9,
            eps: float = 1e-8,
            initial_scale: float = 0.,
            centered: bool = False,
            momentum: Optional[float] = None,
            nesterov: bool = False) -> Optimizer:
  """A flexible RMSProp optimiser.

  RMSProp is an SGD variant with learning rate adaptation. The `learning_rate`
  used for each weight is scaled by a suitable estimate of the magnitude of the
  gradients on previous steps. Several variants of RMSProp can be found in the
  literature. This alias provides an easy-to-configure RMSProp optimiser that
  can be used to switch between several of these variants.

  References:
    [Tieleman and Hinton, 2012](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
    [Graves, 2013](https://arxiv.org/abs/1308.0850)

  Args:
    learning_rate: This is a fixed global scaling factor.
    decay: The decay used to track the magnitude of previous gradients.
    eps: A small numerical constant to avoid dividing by zero when rescaling.
    initial_scale: Initialisation of accumulators tracking the magnitude of
      previous updates. PyTorch uses `0`, TF1 uses `1`. When reproducing
      results from a paper, verify the value used by the authors.
    centered: Whether the second moment or the variance of the past gradients
      is used to rescale the latest gradients.
    momentum: The decay rate used by the momentum term. When set to `None`,
      momentum is not used at all.
    nesterov: Whether Nesterov momentum is used.

  Returns:
    The corresponding `Optimizer`.
  """
  return create_optimizer_from_optax(
      optax.rmsprop(
          learning_rate=learning_rate,
          decay=decay,
          eps=eps,
          initial_scale=initial_scale,
          centered=centered,
          momentum=momentum,
          nesterov=nesterov))
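# The variants the docstring mentions map onto arguments like this (a sketch
# against optax's own `rmsprop`, which this alias wraps):
import optax

# PyTorch-style accumulator initialisation (the optax default):
opt_pt = optax.rmsprop(1e-3, initial_scale=0.0)
# TF1-style initialisation, as noted in the docstring:
opt_tf1 = optax.rmsprop(1e-3, initial_scale=1.0)
# Graves (2013) centered variant, with momentum:
opt_centered = optax.rmsprop(1e-3, centered=True, momentum=0.9)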
def create_opt(name='adamw',
               learning_rate=6.25e-5,
               beta1=0.9,
               beta2=0.999,
               eps=1.5e-4,
               weight_decay=0.0,
               centered=False):
  """Create an optimizer for training.

  Currently, only the Adam, AdamW, and RMSProp optimizers are supported.

  Args:
    name: str, name of the optimizer to create.
    learning_rate: float, learning rate to use in the optimizer.
    beta1: float, beta1 parameter for the optimizer.
    beta2: float, beta2 parameter for the optimizer.
    eps: float, epsilon parameter for the optimizer.
    weight_decay: float, weight decay for AdamW.
    centered: bool, centered parameter for RMSProp.

  Returns:
    An optax optimizer.
  """
  if name in ('adam', 'adamw'):
    logging.info(
        'Creating %s optimizer with settings lr=%f, beta1=%f, '
        'beta2=%f, eps=%f, weight decay=%f',
        name, learning_rate, beta1, beta2, eps, weight_decay)
    if name == 'adamw':
      return optax.adamw(learning_rate, b1=beta1, b2=beta2, eps=eps,
                         weight_decay=weight_decay)
    return optax.adam(learning_rate, b1=beta1, b2=beta2, eps=eps)
  elif name == 'rmsprop':
    logging.info(
        'Creating RMSProp optimizer with settings lr=%f, beta2=%f, eps=%f',
        learning_rate, beta2, eps)
    return optax.rmsprop(learning_rate, decay=beta2, eps=eps,
                         centered=centered)
  else:
    raise ValueError('Unsupported optimizer {}'.format(name))
def make(observation_spec: specs.Array, action_spec: specs.DiscreteArray, hidden_size: List[int] = [256, 128, 64], buffer_length: int = 120, discount: float = .5, td_lambda: float = .9, entropy_cost: float = 1., critic_cost: float = 1., seed: int = 0): """Creates a default agent.""" def network(observation: Observation, previous_reward: PreviousReward, previous_action: PreviousAction) -> ModelOutput: observation = hk.Flatten()(observation) previous_reward = hk.Flatten()(previous_reward) previous_action = hk.Flatten()(previous_action) torso = hk.nets.MLP(hidden_size) policy_head = hk.Linear(action_spec.num_values) value_head = hk.Linear(1) embedding = torso( jnp.concatenate([observation, previous_reward, previous_action], -1)) logits = policy_head(embedding) value = value_head(embedding) return logits, jnp.squeeze(value, axis=-1), embedding, embedding, embedding return ActorCritic(observation_spec=observation_spec, action_spec=action_spec, network=network, optimizer=optax.rmsprop(1e-3), rng=hk.PRNGSequence(seed), buffer_length=buffer_length, discount=discount, td_lambda=td_lambda, entropy_cost=entropy_cost, critic_cost=critic_cost)
def RmsProp(
    learning_rate: float = 0.001,
    beta: float = 0.9,
    epscut: float = 1.0e-7,
    centered: bool = False,
):
    r"""RMSProp optimizer.

    RMSProp is a well-known update algorithm proposed by Geoff Hinton in his
    `Neural Networks course notes
    <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_.
    It corrects the problem with AdaGrad by using an exponentially weighted
    moving average over past squared gradients instead of a cumulative sum.
    After initializing the vector :math:`\mathbf{s}` to zero, :math:`s_k` and
    the parameters :math:`p_k` are updated as

    .. math:: s^\prime_k = \beta s_k + (1-\beta) G_k(\mathbf{p})^2 \\
              p^\prime_k = p_k - \frac{\eta}{\sqrt{s^\prime_k}+\epsilon} G_k(\mathbf{p})

    Constructs a new ``RmsProp`` optimizer.

    Args:
        learning_rate: The learning rate :math:`\eta`.
        beta: Exponential decay rate.
        epscut: Small cutoff value.
        centered: Whether to center the moving average.

    Examples:
        RmsProp optimizer.

        >>> from netket.optimizer import RmsProp
        >>> op = RmsProp(learning_rate=0.02)
    """
    from optax import rmsprop

    return rmsprop(learning_rate, decay=beta, eps=epscut, centered=centered)
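# To make the update rule concrete, a plain-NumPy transcription of the two
# equations above (illustrative only; NetKet delegates the real work to
# optax.rmsprop):
import numpy as np

eta, beta, eps = 0.02, 0.9, 1e-7
p = np.array([1.0, -2.0])
s = np.zeros_like(p)

def rmsprop_step(p, s, g):
    s = beta * s + (1 - beta) * g**2          # s'_k
    p = p - eta / (np.sqrt(s) + eps) * g      # p'_k
    return p, s

g = 2 * p  # gradient of sum(p**2)
p, s = rmsprop_step(p, s, g)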
def test_example(self):
    class MLP(elegy.Module):
        def call(self, x):
            x = elegy.nn.Linear(10)(x)
            x = jax.lax.stop_gradient(x)
            return x

    # This callback will stop training when there is no improvement in the
    # loss for three consecutive epochs.
    callback = elegy.callbacks.EarlyStopping(monitor="loss", patience=3)

    model = elegy.Model(
        module=MLP(),
        loss=elegy.losses.MeanSquaredError(),
        optimizer=optax.rmsprop(0.01),
    )
    history = model.fit(
        x=np.ones((5, 20)),
        y=np.zeros((5, 10)),
        epochs=10,
        batch_size=1,
        callbacks=[callback],
        verbose=0,
    )
    assert len(history.history["loss"]) == 4  # Only 4 epochs are run.
def main(argv): """Trains Prioritized DQN agent on Atari.""" del argv logging.info('Prioritized DQN on Atari on %s.', jax.lib.xla_bridge.get_backend().platform) random_state = np.random.RandomState(FLAGS.seed) rng_key = jax.random.PRNGKey( random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64)) if FLAGS.results_csv_path: writer = parts.CsvWriter(FLAGS.results_csv_path) else: writer = parts.NullWriter() def environment_builder(): """Creates Atari environment.""" env = gym_atari.GymAtari( FLAGS.environment_name, seed=random_state.randint(1, 2**32)) return gym_atari.RandomNoopsEnvironmentWrapper( env, min_noop_steps=1, max_noop_steps=30, seed=random_state.randint(1, 2**32), ) env = environment_builder() logging.info('Environment: %s', FLAGS.environment_name) logging.info('Action spec: %s', env.action_spec()) logging.info('Observation spec: %s', env.observation_spec()) num_actions = env.action_spec().num_values network_fn = networks.double_dqn_atari_network(num_actions) network = hk.transform(network_fn) def preprocessor_builder(): return processors.atari( additional_discount=FLAGS.additional_discount, max_abs_reward=FLAGS.max_abs_reward, resize_shape=(FLAGS.environment_height, FLAGS.environment_width), num_action_repeats=FLAGS.num_action_repeats, num_pooled_frames=2, zero_discount_on_life_loss=True, num_stacked_frames=FLAGS.num_stacked_frames, grayscaling=True, ) # Create sample network input from sample preprocessor output. sample_processed_timestep = preprocessor_builder()(env.reset()) sample_processed_timestep = typing.cast(dm_env.TimeStep, sample_processed_timestep) sample_network_input = sample_processed_timestep.observation chex.assert_shape(sample_network_input, (FLAGS.environment_height, FLAGS.environment_width, FLAGS.num_stacked_frames)) exploration_epsilon_schedule = parts.LinearSchedule( begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity * FLAGS.num_action_repeats), decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction * FLAGS.num_iterations * FLAGS.num_train_frames), begin_value=FLAGS.exploration_epsilon_begin_value, end_value=FLAGS.exploration_epsilon_end_value) # Note the t in the replay is not exactly aligned with the agent t. 
importance_sampling_exponent_schedule = parts.LinearSchedule( begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity), end_t=(FLAGS.num_iterations * int(FLAGS.num_train_frames / FLAGS.num_action_repeats)), begin_value=FLAGS.importance_sampling_exponent_begin_value, end_value=FLAGS.importance_sampling_exponent_end_value) if FLAGS.compress_state: def encoder(transition): return transition._replace( s_tm1=replay_lib.compress_array(transition.s_tm1), s_t=replay_lib.compress_array(transition.s_t)) def decoder(transition): return transition._replace( s_tm1=replay_lib.uncompress_array(transition.s_tm1), s_t=replay_lib.uncompress_array(transition.s_t)) else: encoder = None decoder = None replay_structure = replay_lib.Transition( s_tm1=None, a_tm1=None, r_t=None, discount_t=None, s_t=None, ) replay = replay_lib.PrioritizedTransitionReplay( FLAGS.replay_capacity, replay_structure, FLAGS.priority_exponent, importance_sampling_exponent_schedule, FLAGS.uniform_sample_probability, FLAGS.normalize_weights, random_state, encoder, decoder) optimizer = optax.rmsprop( learning_rate=FLAGS.learning_rate, decay=0.95, eps=FLAGS.optimizer_epsilon, centered=True, ) train_rng_key, eval_rng_key = jax.random.split(rng_key) train_agent = agent.PrioritizedDqn( preprocessor=preprocessor_builder(), sample_network_input=sample_network_input, network=network, optimizer=optimizer, transition_accumulator=replay_lib.TransitionAccumulator(), replay=replay, batch_size=FLAGS.batch_size, exploration_epsilon=exploration_epsilon_schedule, min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction, learn_period=FLAGS.learn_period, target_network_update_period=FLAGS.target_network_update_period, grad_error_bound=FLAGS.grad_error_bound, rng_key=train_rng_key, ) eval_agent = parts.EpsilonGreedyActor( preprocessor=preprocessor_builder(), network=network, exploration_epsilon=FLAGS.eval_exploration_epsilon, rng_key=eval_rng_key, ) # Set up checkpointing. checkpoint = parts.NullCheckpoint() state = checkpoint.state state.iteration = 0 state.train_agent = train_agent state.eval_agent = eval_agent state.random_state = random_state state.writer = writer if checkpoint.can_be_restored(): checkpoint.restore() while state.iteration <= FLAGS.num_iterations: # New environment for each iteration to allow for determinism if preempted. env = environment_builder() logging.info('Training iteration %d.', state.iteration) train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode) num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames train_seq_truncated = itertools.islice(train_seq, num_train_frames) train_trackers = parts.make_default_trackers(train_agent) train_stats = parts.generate_statistics(train_trackers, train_seq_truncated) logging.info('Evaluation iteration %d.', state.iteration) eval_agent.network_params = train_agent.online_params eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode) eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames) eval_trackers = parts.make_default_trackers(eval_agent) eval_stats = parts.generate_statistics(eval_trackers, eval_seq_truncated) # Logging and checkpointing. 
human_normalized_score = atari_data.get_human_normalized_score( FLAGS.environment_name, eval_stats['episode_return']) capped_human_normalized_score = np.amin([1., human_normalized_score]) log_output = [ ('iteration', state.iteration, '%3d'), ('frame', state.iteration * FLAGS.num_train_frames, '%5d'), ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'), ('train_episode_return', train_stats['episode_return'], '% 2.2f'), ('eval_num_episodes', eval_stats['num_episodes'], '%3d'), ('train_num_episodes', train_stats['num_episodes'], '%3d'), ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'), ('train_frame_rate', train_stats['step_rate'], '%4.0f'), ('train_exploration_epsilon', train_agent.exploration_epsilon, '%.3f'), ('train_state_value', train_stats['state_value'], '%.3f'), ('importance_sampling_exponent', train_agent.importance_sampling_exponent, '%.3f'), ('max_seen_priority', train_agent.max_seen_priority, '%.3f'), ('normalized_return', human_normalized_score, '%.3f'), ('capped_normalized_return', capped_human_normalized_score, '%.3f'), ('human_gap', 1. - capped_human_normalized_score, '%.3f'), ] log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output) logging.info(log_output_str) writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output)) state.iteration += 1 checkpoint.save() writer.close()
def main(argv):
  """Trains DQN agent on the Key-Door environment."""
  del argv
  logging.info("DQN on Key-Door on %s.",
               jax.lib.xla_bridge.get_backend().platform)
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1))

  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Key-Door environment."""
    env = gym_key_door.GymKeyDoor(
        env_args={
            constants.MAP_ASCII_PATH: FLAGS.map_ascii_path,
            constants.MAP_YAML_PATH: FLAGS.map_yaml_path,
            constants.REPRESENTATION: constants.PIXEL,
            constants.SCALING: FLAGS.env_scaling,
            constants.EPISODE_TIMEOUT: FLAGS.max_frames_per_episode,
            constants.GRAYSCALE: False,
            constants.BATCH_DIMENSION: False,
            constants.TORCH_AXES: False,
        },
        env_shape=FLAGS.env_shape,
    )
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()
  logging.info("Environment: %s", FLAGS.environment_name)
  logging.info("Action spec: %s", env.action_spec())
  logging.info("Observation spec: %s", env.observation_spec())
  num_actions = env.action_spec().num_values
  network_fn = networks.dqn_atari_network(num_actions)
  network = hk.transform(network_fn)

  def preprocessor_builder():
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  assert sample_network_input.shape == (
      FLAGS.environment_height,
      FLAGS.environment_width,
      FLAGS.num_stacked_frames,
  )

  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value,
  )

  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t),
      )

    def decoder(transition):
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t),
      )
  else:
    encoder = None
    decoder = None

  replay_structure = replay_lib.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
  )

  replay = replay_lib.TransitionReplay(FLAGS.replay_capacity,
                                       replay_structure, random_state,
                                       encoder, decoder)

  optimizer = optax.rmsprop(
      learning_rate=FLAGS.learning_rate,
      decay=0.95,
      eps=FLAGS.optimizer_epsilon,
      centered=True,
  )

  if FLAGS.shaping_function_type == constants.NO_PENALTY:
    shaping_function = shaping.NoPenalty()
  elif FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY:
    shaping_function = shaping.HardCodedPenalty(
        penalty=FLAGS.shaping_multiplicative_factor)
  else:
    raise ValueError(
        f"Unknown shaping function type {FLAGS.shaping_function_type}.")

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.Dqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      shaping_function=shaping_function,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  checkpoint = parts.ImplementedCheckpoint(
      checkpoint_path=FLAGS.checkpoint_path)

  if checkpoint.can_be_restored():
    checkpoint.restore()
    train_agent.set_state(state=checkpoint.state.train_agent)
    eval_agent.set_state(state=checkpoint.state.eval_agent)
    writer.set_state(state=checkpoint.state.writer)

  state = checkpoint.state
  if not checkpoint.can_be_restored():
    # Only reset progress when starting fresh; a restored checkpoint already
    # carries the saved iteration count.
    state.iteration = 0
  state.train_agent = train_agent.get_state()
  state.eval_agent = eval_agent.get_state()
  state.random_state = random_state
  state.writer = writer.get_state()

  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info("Training iteration %d.", state.iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_stats = parts.generate_statistics(train_seq_truncated)

    logging.info("Evaluation iteration %d.", state.iteration)
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_stats = parts.generate_statistics(eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats["episode_return"])
    capped_human_normalized_score = np.amin([1.0, human_normalized_score])
    log_output = [
        ("iteration", state.iteration, "%3d"),
        ("frame", state.iteration * FLAGS.num_train_frames, "%5d"),
        ("eval_episode_return", eval_stats["episode_return"], "% 2.2f"),
        ("train_episode_return", train_stats["episode_return"], "% 2.2f"),
        ("eval_num_episodes", eval_stats["num_episodes"], "%3d"),
        ("train_num_episodes", train_stats["num_episodes"], "%3d"),
        ("eval_frame_rate", eval_stats["step_rate"], "%4.0f"),
        ("train_frame_rate", train_stats["step_rate"], "%4.0f"),
        ("train_exploration_epsilon", train_agent.exploration_epsilon, "%.3f"),
        ("normalized_return", human_normalized_score, "%.3f"),
        ("capped_normalized_return", capped_human_normalized_score, "%.3f"),
        ("human_gap", 1.0 - capped_human_normalized_score, "%.3f"),
    ]
    log_output_str = ", ".join(("%s: " + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()
def main(debug: bool = False, eager: bool = False, logdir: str = "runs"):
    if debug:
        import debugpy

        print("Waiting for debugger...")
        debugpy.listen(5678)
        debugpy.wait_for_client()

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    X_train, _1, X_test, _2 = dataget.image.mnist(global_cache=True).get()

    print("X_train:", X_train.shape, X_train.dtype)
    print("X_test:", X_test.shape, X_test.dtype)

    class MLP(elegy.Module):
        """Standard LeNet-300-100 MLP network."""

        def __init__(self, n1: int = 300, n2: int = 100, **kwargs):
            super().__init__(**kwargs)
            self.n1 = n1
            self.n2 = n2

        def call(self, image: jnp.ndarray):
            image = image.astype(jnp.float32) / 255.0
            x = elegy.nn.Flatten()(image)
            x = elegy.nn.sequential(
                elegy.nn.Linear(self.n1),
                jax.nn.relu,
                elegy.nn.Linear(self.n2),
                jax.nn.relu,
                elegy.nn.Linear(self.n1),
                jax.nn.relu,
                elegy.nn.Linear(x.shape[-1]),
                jax.nn.sigmoid,
            )(x)
            return x.reshape(image.shape) * 255

    class MeanSquaredError(elegy.losses.MeanSquaredError):
        # we request `x` instead of `y_true` since we don't need labels
        # for autoencoders
        def call(self, x, y_pred):
            return super().call(x, y_pred)

    model = elegy.Model(
        module=MLP(n1=256, n2=64),
        loss=MeanSquaredError(),
        optimizer=optax.rmsprop(0.001),
        run_eagerly=eager,
    )

    model.summary(X_train[:64])

    # Notice we are not passing `y`
    history = model.fit(
        x=X_train,
        epochs=20,
        batch_size=64,
        validation_data=(X_test,),
        shuffle=True,
        callbacks=[elegy.callbacks.TensorBoard(logdir=logdir, update_freq=300)],
    )

    plot_history(history)

    # get random samples
    idxs = np.random.randint(0, 10000, size=(5,))
    x_sample = X_test[idxs]

    # get predictions
    y_pred = model.predict(x=x_sample)

    # plot and save results
    with SummaryWriter(os.path.join(logdir, "val")) as tbwriter:
        figure = plt.figure(figsize=(12, 12))
        for i in range(5):
            plt.subplot(2, 5, i + 1)
            plt.imshow(x_sample[i], cmap="gray")
            plt.subplot(2, 5, 5 + i + 1)
            plt.imshow(y_pred[i], cmap="gray")
        # tbwriter.add_figure("AutoEncoder images", figure, 20)

    plt.show()

    print(
        "\n\n\nMetrics and images can be explored using tensorboard using:",
        f"\n \t\t\t tensorboard --logdir {logdir}",
    )
def update(
    self,
    gradient: Weights,
    state: GenericGradientState,
    parameters: Optional[Weights],
) -> Tuple[Weights, GenericGradientState]:
    return GenericGradientState.wrap(
        *rmsprop(**asdict(self)).update(gradient, state.data, parameters))
def init(self, parameters: Weights) -> GenericGradientState: return GenericGradientState(rmsprop(**asdict(self)).init(parameters))
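# Taken together, the `asdict(self)` calls in `init` and `update` imply these
# methods belong to a dataclass whose fields mirror optax.rmsprop's keyword
# arguments. A hedged usage sketch; `RmsPropGradient` is a hypothetical name
# for that dataclass, which is not shown in the source:
import jax.numpy as jnp
import optax

opt = RmsPropGradient(learning_rate=1e-3)       # hypothetical dataclass
params = {"w": jnp.zeros(3)}
state = opt.init(params)                        # -> GenericGradientState
updates, state = opt.update({"w": jnp.ones(3)}, state, params)
new_params = optax.apply_updates(params, updates)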
def main(): start_time = time.time() parser = argparse.ArgumentParser() add_args(parser) args = parser.parse_args() print(args) print("Is jax using @jit decorators?", not jax.config.read("jax_disable_jit")) rng_seq = hk.PRNGSequence(args.random_seed) p_log_prob = hk.transform(lambda x, z: Model( args.latent_size, args.hidden_size, MNIST_IMAGE_SHAPE)(x=x, z=z)) if args.variational == "mean-field": variational = VariationalMeanField elif args.variational == "flow": variational = VariationalFlow q_sample_and_log_prob = hk.transform(lambda x, num_samples: variational( args.latent_size, args.hidden_size)(x, num_samples)) p_params = p_log_prob.init( next(rng_seq), z=np.zeros((1, args.latent_size), dtype=np.float32), x=np.zeros((1, *MNIST_IMAGE_SHAPE), dtype=np.float32), ) q_params = q_sample_and_log_prob.init( next(rng_seq), x=np.zeros((1, *MNIST_IMAGE_SHAPE), dtype=np.float32), num_samples=1, ) optimizer = optax.rmsprop(args.learning_rate) params = (p_params, q_params) opt_state = optimizer.init(params) @jax.jit def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray: """Objective function is negative ELBO.""" x = batch["image"] p_params, q_params = params z, log_q_z = q_sample_and_log_prob.apply(q_params, rng_key, x=x, num_samples=1) log_p_x_z = p_log_prob.apply(p_params, rng_key, x=x, z=z) elbo = log_p_x_z - log_q_z # average elbo over number of samples elbo = elbo.mean(axis=0) # sum elbo over batch elbo = elbo.sum(axis=0) return -elbo @jax.jit def train_step(params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState, batch: Batch) -> Tuple[hk.Params, optax.OptState]: """Single update step to maximize the ELBO.""" grads = jax.grad(objective_fn)(params, rng_key, batch) updates, new_opt_state = optimizer.update(grads, opt_state) new_params = optax.apply_updates(params, updates) return new_params, new_opt_state @jax.jit def importance_weighted_estimate( params: hk.Params, rng_key: PRNGKey, batch: Batch) -> Tuple[jnp.ndarray, jnp.ndarray]: """Estimate marginal log p(x) using importance sampling.""" x = batch["image"] p_params, q_params = params z, log_q_z = q_sample_and_log_prob.apply( q_params, rng_key, x=x, num_samples=args.num_importance_samples) log_p_x_z = p_log_prob.apply(p_params, rng_key, x, z) elbo = log_p_x_z - log_q_z # importance sampling of approximate marginal likelihood with q(z) # as the proposal, and logsumexp in the sample dimension log_p_x = jax.nn.logsumexp(elbo, axis=0) - jnp.log(jnp.shape(elbo)[0]) # sum over the elements of the minibatch log_p_x = log_p_x.sum(0) # average elbo over number of samples elbo = elbo.mean(axis=0) # sum elbo over batch elbo = elbo.sum(axis=0) return elbo, log_p_x def evaluate( dataset: Generator[Batch, None, None], params: hk.Params, rng_seq: hk.PRNGSequence, ) -> Tuple[float, float]: total_elbo = 0.0 total_log_p_x = 0.0 dataset_size = 0 for batch in dataset: elbo, log_p_x = importance_weighted_estimate( params, next(rng_seq), batch) total_elbo += elbo total_log_p_x += log_p_x dataset_size += len(batch["image"]) return total_elbo / dataset_size, total_log_p_x / dataset_size train_ds = load_dataset(tfds.Split.TRAIN, args.batch_size, args.random_seed, repeat=True) test_ds = load_dataset(tfds.Split.TEST, args.batch_size, args.random_seed) def print_progress(step: int, examples_per_sec: float): valid_ds = load_dataset(tfds.Split.VALIDATION, args.batch_size, args.random_seed) elbo, log_p_x = evaluate(valid_ds, params, rng_seq) train_elbo = (-objective_fn(params, next(rng_seq), next(train_ds)) / args.batch_size) 
print(f"Step {step:<10d}\t" f"Train ELBO estimate: {train_elbo:<5.3f}\t" f"Validation ELBO estimate: {elbo:<5.3f}\t" f"Validation log p(x) estimate: {log_p_x:<5.3f}\t" f"Speed: {examples_per_sec:<5.2e} examples/s") t0 = time.time() for step in range(args.training_steps): if step % args.log_interval == 0: t1 = time.time() examples_per_sec = args.log_interval * args.batch_size / (t1 - t0) print_progress(step, examples_per_sec) t0 = t1 params, opt_state = train_step(params, next(rng_seq), opt_state, next(train_ds)) test_ds = load_dataset(tfds.Split.TEST, args.batch_size, args.random_seed) elbo, log_p_x = evaluate(test_ds, params, rng_seq) print(f"Step {step:<10d}\t" f"Test ELBO estimate: {elbo:<5.3f}\t" f"Test log p(x) estimate: {log_p_x:<5.3f}\t") print(f"Total time: {(time.time() - start_time) / 60:.3f} minutes")
def main(argv):
  """Trains Bootstrapped DQN agent on Atari."""
  del argv
  logging.info('Bootstrapped DQN on Atari on %s.',
               jax.lib.xla_bridge.get_backend().platform)
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1))

  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Atari environment."""
    env = gym_atari.GymAtari(
        FLAGS.environment_name, seed=random_state.randint(1, 2**32))
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()
  logging.info('Environment: %s', FLAGS.environment_name)
  logging.info('Action spec: %s', env.action_spec())
  logging.info('Observation spec: %s', env.observation_spec())
  num_actions = env.action_spec().num_values
  network_fn = networks.bootstrapped_dqn_multi_head_network(
      num_actions,
      num_heads=FLAGS.num_heads,
      mask_probability=FLAGS.mask_probability)
  network = hk.transform(network_fn)

  def preprocessor_builder():
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  assert sample_network_input.shape == (FLAGS.environment_height,
                                        FLAGS.environment_width,
                                        FLAGS.num_stacked_frames)

  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value)

  if FLAGS.compress_state:

    def encoder(transition):
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t))

    def decoder(transition):
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  replay_structure = replay_lib.MaskedTransition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
      mask_t=None,
  )

  replay = replay_lib.TransitionReplay(FLAGS.replay_capacity, replay_structure,
                                       random_state, encoder, decoder)

  optimizer = optax.rmsprop(
      learning_rate=FLAGS.learning_rate,
      decay=0.95,
      eps=FLAGS.optimizer_epsilon,
      centered=True,
  )

  if FLAGS.shaping_function_type == constants.NO_PENALTY:
    shaping_function = shaping.NoPenalty()
  elif FLAGS.shaping_function_type == constants.HARD_CODED_PENALTY:
    shaping_function = shaping.HardCodedPenalty(
        penalty=FLAGS.shaping_multiplicative_factor)
  elif FLAGS.shaping_function_type == constants.UNCERTAINTY_PENALTY:
    shaping_function = shaping.UncertaintyPenalty(
        multiplicative_factor=FLAGS.shaping_multiplicative_factor)
  elif FLAGS.shaping_function_type == constants.POLICY_ENTROPY_PENALTY:
    shaping_function = shaping.PolicyEntropyPenalty(
        multiplicative_factor=FLAGS.shaping_multiplicative_factor,
        num_actions=num_actions)
  elif FLAGS.shaping_function_type == constants.MUNCHAUSEN_PENALTY:
    shaping_function = shaping.MunchausenPenalty(
        multiplicative_factor=FLAGS.shaping_multiplicative_factor,
        num_actions=num_actions)

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent.BootstrappedDqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      transition_accumulator=replay_lib.TransitionAccumulator(),
      replay=replay,
      shaping_function=shaping_function,
      mask_probability=FLAGS.mask_probability,
      num_heads=FLAGS.num_heads,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      grad_error_bound=FLAGS.grad_error_bound,
      rng_key=train_rng_key,
  )
  eval_agent = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key,
  )

  # Set up checkpointing.
  checkpoint = parts.ImplementedCheckpoint(
      checkpoint_path=FLAGS.checkpoint_path)

  if checkpoint.can_be_restored():
    checkpoint.restore()
    iteration = checkpoint.state.iteration
    random_state = checkpoint.state.random_state
    train_agent.set_state(state=checkpoint.state.train_agent)
    eval_agent.set_state(state=checkpoint.state.eval_agent)
    writer.set_state(state=checkpoint.state.writer)
  else:
    iteration = 0

  while iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    logging.info('Training iteration %d.', iteration)
    train_seq = parts.run_loop(train_agent, env, FLAGS.max_frames_per_episode)
    num_train_frames = 0 if iteration == 0 else FLAGS.num_train_frames
    train_seq_truncated = itertools.islice(train_seq, num_train_frames)
    train_stats = parts.generate_statistics(train_seq_truncated)

    logging.info('Evaluation iteration %d.', iteration)
    eval_agent.network_params = train_agent.online_params
    eval_seq = parts.run_loop(eval_agent, env, FLAGS.max_frames_per_episode)
    eval_seq_truncated = itertools.islice(eval_seq, FLAGS.num_eval_frames)
    eval_stats = parts.generate_statistics(eval_seq_truncated)

    # Logging and checkpointing.
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    capped_human_normalized_score = np.amin([1., human_normalized_score])
    log_output = [
        ('iteration', iteration, '%3d'),
        ('frame', iteration * FLAGS.num_train_frames, '%5d'),
        ('eval_episode_return', eval_stats['episode_return'], '% 2.2f'),
        ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes', eval_stats['num_episodes'], '%3d'),
        ('train_num_episodes', train_stats['num_episodes'], '%3d'),
        ('eval_frame_rate', eval_stats['step_rate'], '%4.0f'),
        ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ('train_exploration_epsilon', train_agent.exploration_epsilon, '%.3f'),
        ('normalized_return', human_normalized_score, '%.3f'),
        ('capped_normalized_return', capped_human_normalized_score, '%.3f'),
        ('human_gap', 1. - capped_human_normalized_score, '%.3f'),
        ('train_loss', train_stats['train_loss'], '% 2.2f'),
        ('shaped_reward', train_stats['shaped_reward'], '% 2.2f'),
        ('penalties', train_stats['penalties'], '% 2.2f'),
    ]
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    iteration += 1

    # Update state before checkpointing.
    checkpoint.state.iteration = iteration
    checkpoint.state.train_agent = train_agent.get_state()
    checkpoint.state.eval_agent = eval_agent.get_state()
    checkpoint.state.random_state = random_state
    checkpoint.state.writer = writer.get_state()
    checkpoint.save()

  writer.close()
def main( debug: bool = False, eager: bool = False, logdir: str = "runs", steps_per_epoch: int = 200, epochs: int = 100, batch_size: int = 64, ): if debug: import debugpy print("Waiting for debugger...") debugpy.listen(5678) debugpy.wait_for_client() current_time = datetime.now().strftime("%b%d_%H-%M-%S") logdir = os.path.join(logdir, current_time) dataset = load_dataset("mnist") dataset.set_format("np") X_train = np.stack(dataset["train"]["image"]) X_test = np.stack(dataset["test"]["image"]) print("X_train:", X_train.shape, X_train.dtype) print("X_test:", X_test.shape, X_test.dtype) model = eg.Model( module=MLP(n1=256, n2=64), loss=MeanSquaredError(), optimizer=optax.rmsprop(0.001), eager=eager, ) model.summary(X_train[:64]) # Notice we are not passing `y` history = model.fit( inputs=X_train, epochs=epochs, steps_per_epoch=steps_per_epoch, batch_size=batch_size, validation_data=(X_test, ), shuffle=True, callbacks=[eg.callbacks.TensorBoard(logdir=logdir, update_freq=300)], ) eg.utils.plot_history(history) # get random samples idxs = np.random.randint(0, 10000, size=(5, )) x_sample = X_test[idxs] # get predictions y_pred = model.predict(x=x_sample) # plot and save results with SummaryWriter(os.path.join(logdir, "val")) as tbwriter: figure = plt.figure(figsize=(12, 12)) for i in range(5): plt.subplot(2, 5, i + 1) plt.imshow(x_sample[i], cmap="gray") plt.subplot(2, 5, 5 + i + 1) plt.imshow(y_pred[i], cmap="gray") plt.show()