def _make_representation_models(env: gym.Env) -> nn.ModuleDict:
    """Creates embedding-based representation models for a tabular env.

    Builds state/action/observation embeddings and a GRU-based history
    representation, optionally resized to a common feature dimension and
    normalized.
    """
    config = get_config()
    hs_features_dim: int = config.hs_features_dim
    normalize_hs_features: bool = config.normalize_hs_features

    # agent
    state_model = EmbeddingRepresentation(env.state_space.n, 64)
    action_model = EmbeddingRepresentation(env.action_space.n, 64)
    observation_model = EmbeddingRepresentation(
        env.observation_space.n, 64, padding_idx=-1
    )
    history_model = GRUHistoryRepresentation(
        action_model,
        observation_model,
        hidden_size=128,
    )

    # resize history and state models
    if hs_features_dim:
        history_model = ResizeRepresentation(history_model, hs_features_dim)
        state_model = ResizeRepresentation(state_model, hs_features_dim)

    # normalize history and state models
    if normalize_hs_features:
        history_model = NormalizationRepresentation(history_model)
        state_model = NormalizationRepresentation(state_model)

    return nn.ModuleDict(
        {
            'state_model': state_model,
            'action_model': action_model,
            'observation_model': observation_model,
            'history_model': history_model,
        }
    )
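
# Illustrative sketches (not part of the original module) of the two wrappers
# used above, inferred purely from how they are called here; the repo's actual
# ResizeRepresentation and NormalizationRepresentation may differ.  Both
# assume a representation exposes a `dim` attribute and returns feature
# tensors with features on the last axis, and rely on this module's existing
# `torch.nn as nn` import.
class _ResizeRepresentationSketch(nn.Module):
    """Linear projection of the wrapped representation to `features_dim`."""

    def __init__(self, representation: nn.Module, features_dim: int):
        super().__init__()
        self.representation = representation
        self.resize = nn.Linear(representation.dim, features_dim)
        self.dim = features_dim

    def forward(self, *args, **kwargs):
        return self.resize(self.representation(*args, **kwargs))


class _NormalizationRepresentationSketch(nn.Module):
    """L2-normalization of the wrapped representation's features."""

    def __init__(self, representation: nn.Module):
        super().__init__()
        self.representation = representation
        self.dim = representation.dim

    def forward(self, *args, **kwargs):
        features = self.representation(*args, **kwargs)
        return nn.functional.normalize(features, dim=-1)
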
def setup() -> RunState:
    """Instantiates the environment, algorithm, optimizers, and run utilities."""
    config = get_config()

    env = make_env(
        config.env,
        max_episode_timesteps=config.max_episode_timesteps,
    )
    algo = make_a2c_algorithm(
        config.algo,
        env,
        truncated_histories=config.truncated_histories,
        truncated_histories_n=config.truncated_histories_n,
    )

    # separate optimizers (and hyperparameters) for actor and critic
    optimizer_actor = torch.optim.Adam(
        algo.models.parameters(),
        lr=config.optim_lr_actor,
        eps=config.optim_eps_actor,
    )
    optimizer_critic = torch.optim.Adam(
        algo.models.parameters(),
        lr=config.optim_lr_critic,
        eps=config.optim_eps_critic,
    )

    wandb_logger = WandbLogger()
    xstats = XStats()
    timer = Timer()
    running_averages = {
        'avg_target_returns': InfiniteRunningAverage(),
        'avg_behavior_returns': InfiniteRunningAverage(),
        'avg100_behavior_returns': WindowRunningAverage(100),
    }

    wandb_log_period = config.max_simulation_timesteps // config.num_wandb_logs
    dispensers = {
        'target_update_dispenser': Dispenser(config.target_update_period),
        'wandb_log_dispenser': Dispenser(wandb_log_period),
    }

    return RunState(
        env,
        algo,
        optimizer_actor,
        optimizer_critic,
        wandb_logger,
        xstats,
        timer,
        running_averages,
        dispensers,
    )
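
# Illustrative sketches (not part of the original module) of the bookkeeping
# helpers instantiated above, inferred from how they are used in this file;
# the repo's actual Dispenser, InfiniteRunningAverage, and
# WindowRunningAverage may be implemented differently.
from collections import deque


class _DispenserSketch:
    """dispense(t) returns True at most once every `period` timesteps."""

    def __init__(self, period: int):
        self.period = period
        self.next_t = 0

    def dispense(self, t: int) -> bool:
        if t >= self.next_t:
            self.next_t = t + self.period
            return True
        return False


class _InfiniteRunningAverageSketch:
    """Mean over every value ever passed to extend()."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def extend(self, values):
        self.total += sum(values)
        self.count += len(values)

    def value(self) -> float:
        return self.total / self.count


class _WindowRunningAverageSketch:
    """Mean over the most recent `size` values."""

    def __init__(self, size: int):
        self.values = deque(maxlen=size)

    def extend(self, values):
        self.values.extend(values)

    def value(self) -> float:
        return sum(self.values) / len(self.values)
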
def _make_representation_models(env: gym.Env) -> nn.ModuleDict:
    """Creates gridverse (GV) representation models for the agent.

    Builds feature-based state/observation representations and a GRU-based
    history representation, optionally resized to a common feature dimension
    and normalized.
    """
    config = get_config()

    state_model = GV_Representation(
        env.state_space,
        [f'agent-grid-{config.gv_state_grid_model_type}', 'agent', 'item'],
        embedding_size=1,
        layers=[512] * config.gv_state_representation_layers,
    )
    action_model = EmbeddingRepresentation(env.action_space.n, 1)
    observation_model = GV_Representation(
        env.observation_space,
        [f'grid-{config.gv_state_grid_model_type}', 'item'],
        embedding_size=8,
        layers=[512] * config.gv_observation_representation_layers,
    )
    history_model = GRUHistoryRepresentation(
        action_model,
        observation_model,
        hidden_size=64,
    )

    # resize history and state models
    hs_features_dim: int = config.hs_features_dim
    if hs_features_dim:
        history_model = ResizeRepresentation(history_model, hs_features_dim)
        state_model = ResizeRepresentation(state_model, hs_features_dim)

    # normalize history and state models
    if config.normalize_hs_features:
        history_model = NormalizationRepresentation(history_model)
        state_model = NormalizationRepresentation(state_model)

    return nn.ModuleDict({
        'state_model': state_model,
        'action_model': action_model,
        'observation_model': observation_model,
        'history_model': history_model,
    })
def main():
    """Entry point: initializes wandb, restores an optional checkpoint, runs
    training, and re-checkpoints if the run was interrupted."""
    args = parse_args()
    wandb_kwargs = {
        'project': args.wandb_project,
        'entity': args.wandb_entity,
        'group': args.wandb_group,
        'tags': args.wandb_tags,
        'mode': args.wandb_mode,
        'config': args,
    }

    # TypeError covers args.checkpoint being None
    try:
        checkpoint = load_data(args.checkpoint)
    except (TypeError, FileNotFoundError):
        checkpoint = None
    else:
        # resume the exact wandb run recorded in the checkpoint
        wandb_kwargs.update({
            'resume': 'must',
            'id': checkpoint['metadata']['wandb_id'],
        })

    with wandb.init(**wandb_kwargs):
        config = get_config()
        config._update(dict(wandb.config))

        logger.info('setup of runstate...')
        runstate = setup()
        logger.info('setup DONE')

        if checkpoint is not None:
            if checkpoint['metadata']['config'] != config._as_dict():
                raise RuntimeError(
                    'checkpoint config inconsistent with program config'
                )

            logger.debug('updating runstate from checkpoint')
            runstate.load_state_dict(checkpoint['data'])

        logger.info('run...')
        done = run(runstate)
        logger.info('run DONE')

        wandb_run_id = wandb.run.id

    if config.checkpoint is not None:
        if not done:
            # interrupted: save a checkpoint so the run can be resumed
            logger.info('checkpointing...')
            checkpoint = {
                'metadata': {
                    'config': config._as_dict(),
                    'wandb_id': wandb_run_id,
                },
                'data': runstate.state_dict(),
            }
            save_data(config.checkpoint, checkpoint)
            logger.info('checkpointing DONE')
        else:
            # completed: remove any stale checkpoint
            try:
                os.remove(config.checkpoint)
            except FileNotFoundError:
                pass
def run(runstate: RunState) -> bool:
    """Runs the main A2C learning loop; returns True if it ran to completion
    (i.e., was not interrupted)."""
    config = get_config()
    logger.info('run %s %s', config.env_label, config.algo_label)

    (
        env,
        algo,
        optimizer_actor,
        optimizer_critic,
        wandb_logger,
        xstats,
        timer,
        running_averages,
        dispensers,
    ) = runstate

    avg_target_returns = running_averages['avg_target_returns']
    avg_behavior_returns = running_averages['avg_behavior_returns']
    avg100_behavior_returns = running_averages['avg100_behavior_returns']
    target_update_dispenser = dispensers['target_update_dispenser']
    wandb_log_dispenser = dispensers['wandb_log_dispenser']

    device = get_device(config.device)
    algo.to(device)

    # reproducibility
    if config.seed is not None:
        random.seed(config.seed)
        np.random.seed(config.seed)
        torch.manual_seed(config.seed)
        reset_gv_rng(config.seed)
        env.seed(config.seed)
        env.state_space.seed(config.seed)
        env.action_space.seed(config.seed)
        env.observation_space.seed(config.seed)

    if config.deterministic:
        torch.use_deterministic_algorithms(True)

    # initialize return estimator
    q_estimator = q_estimator_factory(
        config.q_estimator,
        n=config.q_estimator_n,
        lambda_=config.q_estimator_lambda,
    )

    behavior_policy = algo.behavior_policy()
    evaluation_policy = algo.evaluation_policy()
    evaluation_policy.epsilon = 0.1

    negentropy_schedule = make_schedule(
        config.negentropy_schedule,
        value_from=config.negentropy_value_from,
        value_to=config.negentropy_value_to,
        nsteps=config.negentropy_nsteps,
        halflife=config.negentropy_halflife,
    )
    weight_negentropy = negentropy_schedule(xstats.simulation_timesteps)

    # setup interrupt flag via signal
    interrupt = False

    def set_interrupt_flag():
        nonlocal interrupt
        interrupt = True
        logger.debug('signal received, setting interrupt=True')

    signal.signal(signal.SIGUSR1, lambda signal, frame: set_interrupt_flag())

    # main learning loop
    wandb.watch(algo.models)
    while xstats.simulation_timesteps < config.max_simulation_timesteps:
        if interrupt:
            break

        # evaluate policy
        algo.models.eval()
        if config.evaluation and xstats.epoch % config.evaluation_period == 0:
            if config.render:
                sample_episodes(
                    env,
                    evaluation_policy,
                    num_episodes=1,
                    render=True,
                )

            episodes = sample_episodes(
                env,
                evaluation_policy,
                num_episodes=config.evaluation_num_episodes,
            )
            mean_length = sum(map(len, episodes)) / len(episodes)
            returns = evaluate_returns(
                episodes, discount=config.evaluation_discount
            )
            avg_target_returns.extend(returns.tolist())
            logger.info(
                f'EVALUATE epoch {xstats.epoch}'
                f' simulation_timestep {xstats.simulation_timesteps}'
                f' return {returns.mean():.3f}'
            )
            wandb_logger.log({
                **xstats.asdict(),
                'hours': timer.hours,
                'diagnostics/target_mean_episode_length': mean_length,
                'performance/target_mean_return': returns.mean(),
                'performance/avg_target_mean_return': avg_target_returns.value(),
            })

        # sample episodes with the behavior policy
        episodes = sample_episodes(
            env,
            behavior_policy,
            num_episodes=config.simulation_num_episodes,
        )
        mean_length = sum(map(len, episodes)) / len(episodes)
        returns = evaluate_returns(
            episodes, discount=config.evaluation_discount
        )
        avg_behavior_returns.extend(returns.tolist())
        avg100_behavior_returns.extend(returns.tolist())

        wandb_log = wandb_log_dispenser.dispense(xstats.simulation_timesteps)

        if wandb_log:
            logger.info(
                'behavior log - simulation_step %d return %.3f',
                xstats.simulation_timesteps,
                returns.mean(),
            )
            wandb_logger.log({
                **xstats.asdict(),
                'hours': timer.hours,
                'diagnostics/behavior_mean_episode_length': mean_length,
                'performance/behavior_mean_return': returns.mean(),
                'performance/avg_behavior_mean_return': avg_behavior_returns.value(),
                'performance/avg100_behavior_mean_return': avg100_behavior_returns.value(),
            })

        # storing torch data directly
        episodes = [episode.torch().to(device) for episode in episodes]

        xstats.simulation_episodes += len(episodes)
        xstats.simulation_timesteps += sum(
            len(episode) for episode in episodes
        )
        weight_negentropy = negentropy_schedule(xstats.simulation_timesteps)

        # target model update
        if target_update_dispenser.dispense(xstats.simulation_timesteps):
            # Update the target network
            algo.target_models.load_state_dict(algo.models.state_dict())

        algo.models.train()

        # critic
        optimizer_critic.zero_grad()
        losses = [
            algo.critic_loss(
                episode,
                discount=config.training_discount,
                q_estimator=q_estimator,
            )
            for episode in episodes
        ]
        critic_loss = average(losses)
        critic_loss.backward()
        critic_gradient_norm = nn.utils.clip_grad_norm_(
            algo.models.parameters(), max_norm=config.optim_max_norm
        )
        optimizer_critic.step()

        # actor, with scheduled negentropy regularization
        optimizer_actor.zero_grad()
        losses = [
            algo.actor_losses(
                episode,
                discount=config.training_discount,
                q_estimator=q_estimator,
            )
            for episode in episodes
        ]
        actor_losses, negentropy_losses = zip(*losses)
        actor_loss = average(actor_losses)
        negentropy_loss = average(negentropy_losses)
        loss = actor_loss + weight_negentropy * negentropy_loss
        loss.backward()
        actor_gradient_norm = nn.utils.clip_grad_norm_(
            algo.models.parameters(), max_norm=config.optim_max_norm
        )
        optimizer_actor.step()

        if wandb_log:
            logger.info(
                'training log - simulation_step %d losses %.3f %.3f %.3f',
                xstats.simulation_timesteps,
                actor_loss,
                critic_loss,
                negentropy_loss,
            )
            wandb_logger.log({
                **xstats.asdict(),
                'hours': timer.hours,
                'training/losses/actor': actor_loss,
                'training/losses/critic': critic_loss,
                'training/losses/negentropy': negentropy_loss,
                'training/weights/negentropy': weight_negentropy,
                'training/gradient_norms/actor': actor_gradient_norm,
                'training/gradient_norms/critic': critic_gradient_norm,
            })

        if config.save_modelseq and config.modelseq_filename is not None:
            data = {
                'metadata': {'config': config._as_dict()},
                'data': {
                    'timestep': xstats.simulation_timesteps,
                    'model.state_dict': algo.models.state_dict(),
                },
            }
            filename = config.modelseq_filename.format(
                xstats.simulation_timesteps
            )
            save_data(filename, data)

        xstats.epoch += 1
        xstats.optimizer_steps += 1
        xstats.training_episodes += len(episodes)
        xstats.training_timesteps += sum(len(episode) for episode in episodes)

    done = not interrupt
    if done and config.save_model and config.model_filename is not None:
        data = {
            'metadata': {'config': config._as_dict()},
            'data': {'models.state_dict': algo.models.state_dict()},
        }
        save_data(config.model_filename, data)

    return done
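
# Note (added comment, not in the original source): the SIGUSR1 handler above
# only sets a flag, so the loop exits cleanly at the next epoch boundary and
# run() returns done=False; main() then writes a checkpoint that a later
# invocation can resume from.  On a POSIX system the interrupt can be
# triggered externally with, e.g.:
#
#   kill -USR1 <pid>
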
def run(runstate: RunState) -> bool:
    """Runs the main value-based learning loop with an episode buffer;
    returns True if it ran to completion (i.e., was not interrupted)."""
    config = get_config()
    logger.info('run %s %s', config.env_label, config.algo_label)

    (
        env,
        algo,
        optimizer,
        wandb_logger,
        xstats,
        timer,
        running_averages,
        dispensers,
    ) = runstate

    avg_target_returns = running_averages['avg_target_returns']
    avg_behavior_returns = running_averages['avg_behavior_returns']
    avg100_behavior_returns = running_averages['avg100_behavior_returns']
    target_update_dispenser = dispensers['target_update_dispenser']
    wandb_log_dispenser = dispensers['wandb_log_dispenser']

    device = get_device(config.device)
    algo.to(device)

    # reproducibility
    if config.seed is not None:
        random.seed(config.seed)
        np.random.seed(config.seed)
        torch.manual_seed(config.seed)
        reset_gv_rng(config.seed)
        env.seed(config.seed)
        env.state_space.seed(config.seed)
        env.action_space.seed(config.seed)
        env.observation_space.seed(config.seed)

    if config.deterministic:
        torch.use_deterministic_algorithms(True)

    epsilon_schedule = make_schedule(
        config.epsilon_schedule,
        value_from=config.epsilon_value_from,
        value_to=config.epsilon_value_to,
        nsteps=config.epsilon_nsteps,
    )
    behavior_policy = algo.behavior_policy(env.action_space)
    target_policy = algo.target_policy()

    logger.info(
        f'setting prepopulating policy:'
        f' {config.episode_buffer_prepopulate_policy}'
    )
    prepopulate_policy: Policy
    if config.episode_buffer_prepopulate_policy == 'random':
        prepopulate_policy = RandomPolicy(env.action_space)
    elif config.episode_buffer_prepopulate_policy == 'behavior':
        prepopulate_policy = behavior_policy
    elif config.episode_buffer_prepopulate_policy == 'target':
        prepopulate_policy = target_policy
    else:
        assert False

    if xstats.simulation_timesteps == 0:
        prepopulate_timesteps = config.episode_buffer_prepopulate_timesteps
    else:
        # resuming a run: restore the scheduled epsilon and refill the buffer
        # to its previous size
        prepopulate_policy.epsilon = epsilon_schedule(
            xstats.simulation_timesteps
            - config.episode_buffer_prepopulate_timesteps
        )
        prepopulate_timesteps = xstats.simulation_timesteps

    # instantiate and prepopulate buffer
    logger.info(
        f'prepopulating episode buffer'
        f' ({prepopulate_timesteps:_} timesteps)...'
    )
    episode_buffer = EpisodeBuffer(config.episode_buffer_max_timesteps)
    while episode_buffer.num_interactions() < prepopulate_timesteps:
        (episode,) = sample_episodes(env, prepopulate_policy, num_episodes=1)
        episode_buffer.append_episode(episode.torch())
    logger.info('prepopulating DONE')

    if xstats.simulation_timesteps == 0:
        xstats.simulation_episodes = episode_buffer.num_episodes()
        xstats.simulation_timesteps = episode_buffer.num_interactions()

    # setup interrupt flag via signal
    interrupt = False

    def set_interrupt_flag():
        nonlocal interrupt
        logger.debug('signal received, setting interrupt=True')
        interrupt = True

    signal.signal(signal.SIGUSR1, lambda signal, frame: set_interrupt_flag())

    # main learning loop
    wandb.watch(algo.models)
    while xstats.simulation_timesteps < config.max_simulation_timesteps:
        if interrupt:
            break

        algo.models.eval()

        # evaluate target policy
        if config.evaluation and xstats.epoch % config.evaluation_period == 0:
            if config.render:
                sample_episodes(
                    env,
                    target_policy,
                    num_episodes=1,
                    render=True,
                )

            episodes = sample_episodes(
                env,
                target_policy,
                num_episodes=config.evaluation_num_episodes,
            )
            mean_length = sum(map(len, episodes)) / len(episodes)
            returns = evaluate_returns(
                episodes, discount=config.evaluation_discount
            )
            avg_target_returns.extend(returns.tolist())
            logger.info(
                'EVALUATE epoch %d simulation_step %d return %.3f',
                xstats.epoch,
                xstats.simulation_timesteps,
                returns.mean(),
            )
            wandb_logger.log(
                {
                    **xstats.asdict(),
                    'hours': timer.hours,
                    'diagnostics/target_mean_episode_length': mean_length,
                    'performance/target_mean_return': returns.mean(),
                    'performance/avg_target_mean_return': avg_target_returns.value(),
                }
            )

        # populate episode buffer
        behavior_policy.epsilon = epsilon_schedule(
            xstats.simulation_timesteps
            - config.episode_buffer_prepopulate_timesteps
        )
        episodes = sample_episodes(
            env,
            behavior_policy,
            num_episodes=config.simulation_num_episodes,
        )
        mean_length = sum(map(len, episodes)) / len(episodes)
        returns = evaluate_returns(
            episodes, discount=config.evaluation_discount
        )
        avg_behavior_returns.extend(returns.tolist())
        avg100_behavior_returns.extend(returns.tolist())

        wandb_log = wandb_log_dispenser.dispense(xstats.simulation_timesteps)

        if wandb_log:
            logger.info(
                'behavior log - simulation_step %d return %.3f',
                xstats.simulation_timesteps,
                returns.mean(),
            )
            wandb_logger.log(
                {
                    **xstats.asdict(),
                    'hours': timer.hours,
                    'diagnostics/epsilon': behavior_policy.epsilon,
                    'diagnostics/behavior_mean_episode_length': mean_length,
                    'performance/behavior_mean_return': returns.mean(),
                    'performance/avg_behavior_mean_return': avg_behavior_returns.value(),
                    'performance/avg100_behavior_mean_return': avg100_behavior_returns.value(),
                }
            )

        # storing torch data directly
        episodes = [episode.torch().to(device) for episode in episodes]
        episode_buffer.append_episodes(episodes)
        xstats.simulation_episodes += len(episodes)
        xstats.simulation_timesteps += sum(
            len(episode) for episode in episodes
        )

        # target model update
        if target_update_dispenser.dispense(xstats.simulation_timesteps):
            # Update the target network
            algo.target_models.load_state_dict(algo.models.state_dict())

        # train based on episode buffer
        algo.models.train()
        while (
            xstats.training_timesteps
            < (
                xstats.simulation_timesteps
                - config.episode_buffer_prepopulate_timesteps
            )
            * config.training_timesteps_per_simulation_timestep
        ):
            optimizer.zero_grad()

            if algo.episodic_training:
                episodes = episode_buffer.sample_episodes(
                    num_samples=config.training_num_episodes,
                    replacement=True,
                )
                episodes = [episode.to(device) for episode in episodes]
                loss = algo.episodic_loss(
                    episodes, discount=config.training_discount
                )
            else:
                batch = episode_buffer.sample_batch(
                    batch_size=config.training_batch_size
                )
                batch = batch.to(device)
                loss = algo.batched_loss(
                    batch, discount=config.training_discount
                )

            loss.backward()
            gradient_norm = nn.utils.clip_grad_norm_(
                algo.models.parameters(), max_norm=config.optim_max_norm
            )
            optimizer.step()

            if wandb_log:
                logger.debug(
                    'training log - simulation_step %d loss %.3f',
                    xstats.simulation_timesteps,
                    loss,
                )
                wandb_logger.log(
                    {
                        **xstats.asdict(),
                        'hours': timer.hours,
                        'training/loss': loss,
                        'training/gradient_norm': gradient_norm,
                    }
                )

            if config.save_modelseq and config.modelseq_filename is not None:
                data = {
                    'metadata': {'config': config._as_dict()},
                    'data': {
                        'timestep': xstats.simulation_timesteps,
                        'model.state_dict': algo.models.state_dict(),
                    },
                }
                filename = config.modelseq_filename.format(
                    xstats.simulation_timesteps
                )
                save_data(filename, data)

            xstats.optimizer_steps += 1
            if algo.episodic_training:
                xstats.training_episodes += len(episodes)
                xstats.training_timesteps += sum(
                    len(episode) for episode in episodes
                )
            else:
                xstats.training_timesteps += len(batch)

        xstats.epoch += 1

    done = not interrupt
    if done and config.save_model and config.model_filename is not None:
        data = {
            'metadata': {'config': config._as_dict()},
            'data': {'models.state_dict': algo.models.state_dict()},
        }
        save_data(config.model_filename, data)

    return done
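
# A minimal sketch (not part of the original module) of what make_schedule()
# is assumed to return for the 'linear' case: a callable mapping a timestep to
# a scheduled value, interpolating from value_from to value_to over nsteps and
# staying flat afterwards.  The repo's actual schedule factory may support
# other kinds as well (e.g. an exponential schedule parameterized by
# `halflife`, as used for the negentropy weight).
def _linear_schedule_sketch(value_from: float, value_to: float, nsteps: int):
    def schedule(timestep: int) -> float:
        fraction = min(max(timestep, 0) / max(nsteps - 1, 1), 1.0)
        return value_from + fraction * (value_to - value_from)

    return schedule


# usage sketch: epsilon decays linearly from 1.0 to 0.05 over 100_000 steps
# _epsilon = _linear_schedule_sketch(1.0, 0.05, 100_000)
# _epsilon(0)        -> 1.0
# _epsilon(200_000)  -> 0.05
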
def execute_before_tests():
    config = get_config()
    config._update({
        'hs_features_dim': 0,
        'normalize_hs_features': False,
    })