Example #1
0
def _make_representation_models(env: gym.Env) -> nn.ModuleDict:
    """Build the state/action/observation/history representation models.

    State, action and observation are each embedded with a 64-dimensional
    ``EmbeddingRepresentation``. The history model is a GRU (hidden size 128)
    over the action and observation embeddings. Depending on the global
    config, the history and state models are then wrapped in a
    ``ResizeRepresentation`` (when ``hs_features_dim`` is truthy) and/or a
    ``NormalizationRepresentation`` (when ``normalize_hs_features`` is set).

    Args:
        env: environment whose ``state_space``, ``action_space`` and
            ``observation_space`` expose a size ``n``.

    Returns:
        ``nn.ModuleDict`` with keys ``state_model``, ``action_model``,
        ``observation_model`` and ``history_model``.
    """
    cfg = get_config()
    resize_dim: int = cfg.hs_features_dim
    do_normalize: bool = cfg.normalize_hs_features

    embed_dim = 64
    state_repr = EmbeddingRepresentation(env.state_space.n, embed_dim)
    action_repr = EmbeddingRepresentation(env.action_space.n, embed_dim)
    observation_repr = EmbeddingRepresentation(
        env.observation_space.n, embed_dim, padding_idx=-1
    )
    history_repr = GRUHistoryRepresentation(
        action_repr, observation_repr, hidden_size=128
    )

    # Only the history and state features are post-processed; the action
    # and observation embeddings are consumed as-is by the GRU above.
    if resize_dim:
        history_repr = ResizeRepresentation(history_repr, resize_dim)
        state_repr = ResizeRepresentation(state_repr, resize_dim)

    if do_normalize:
        history_repr = NormalizationRepresentation(history_repr)
        state_repr = NormalizationRepresentation(state_repr)

    models = dict(
        state_model=state_repr,
        action_model=action_repr,
        observation_model=observation_repr,
        history_model=history_repr,
    )
    return nn.ModuleDict(models)
Example #2
0
def setup() -> RunState:
    """Create a fresh ``RunState`` for actor-critic training from the config.

    Builds, in order: the environment, the A2C algorithm, an Adam optimizer
    for the actor update and one for the critic update, the wandb logger,
    statistics counters, a timer, running averages of returns, and the
    dispensers that gate target-model updates and wandb logging.

    Returns:
        ``RunState`` holding the objects above (positional order matters).
    """
    cfg = get_config()

    env = make_env(cfg.env, max_episode_timesteps=cfg.max_episode_timesteps)
    algo = make_a2c_algorithm(
        cfg.algo,
        env,
        truncated_histories=cfg.truncated_histories,
        truncated_histories_n=cfg.truncated_histories_n,
    )

    def make_adam(lr, eps):
        # NOTE(review): both optimizers are built over *all* parameters of
        # algo.models, so shared parameters are stepped by both — confirm
        # this is intended.
        return torch.optim.Adam(algo.models.parameters(), lr=lr, eps=eps)

    optimizer_actor = make_adam(cfg.optim_lr_actor, cfg.optim_eps_actor)
    optimizer_critic = make_adam(cfg.optim_lr_critic, cfg.optim_eps_critic)

    wandb_logger = WandbLogger()
    xstats = XStats()
    timer = Timer()

    running_averages = dict(
        avg_target_returns=InfiniteRunningAverage(),
        avg_behavior_returns=InfiniteRunningAverage(),
        avg100_behavior_returns=WindowRunningAverage(100),
    )

    # Log period: the total simulation budget split into num_wandb_logs parts.
    log_every = cfg.max_simulation_timesteps // cfg.num_wandb_logs
    dispensers = dict(
        target_update_dispenser=Dispenser(cfg.target_update_period),
        wandb_log_dispenser=Dispenser(log_every),
    )

    state_parts = (
        env,
        algo,
        optimizer_actor,
        optimizer_critic,
        wandb_logger,
        xstats,
        timer,
        running_averages,
        dispensers,
    )
    return RunState(*state_parts)
Example #3
0
def _make_representation_models(env: gym.Env) -> nn.ModuleDict:
    """Build representation models for environments with structured spaces.

    State and observation are encoded with ``GV_Representation`` (each
    followed by a configurable number of 512-unit layers). Actions use a
    1-dimensional ``EmbeddingRepresentation``. The history model is a GRU
    (hidden size 64) over action and observation features. History and
    state models are optionally resized (``hs_features_dim``) and
    normalized (``normalize_hs_features``).

    Args:
        env: environment exposing ``state_space``, ``action_space`` (with
            ``n``) and ``observation_space``.

    Returns:
        ``nn.ModuleDict`` with keys ``state_model``, ``action_model``,
        ``observation_model`` and ``history_model``.
    """
    cfg = get_config()
    grid_type = cfg.gv_state_grid_model_type

    state_repr = GV_Representation(
        env.state_space,
        [f'agent-grid-{grid_type}', 'agent', 'item'],
        embedding_size=1,
        layers=[512] * cfg.gv_state_representation_layers,
    )
    action_repr = EmbeddingRepresentation(env.action_space.n, 1)
    # NOTE(review): the observation model reuses the *state* grid model type
    # (gv_state_grid_model_type) — confirm this is intended.
    observation_repr = GV_Representation(
        env.observation_space,
        [f'grid-{grid_type}', 'item'],
        embedding_size=8,
        layers=[512] * cfg.gv_observation_representation_layers,
    )
    history_repr = GRUHistoryRepresentation(
        action_repr, observation_repr, hidden_size=64
    )

    resize_dim: int = cfg.hs_features_dim
    if resize_dim:
        history_repr = ResizeRepresentation(history_repr, resize_dim)
        state_repr = ResizeRepresentation(state_repr, resize_dim)

    if cfg.normalize_hs_features:
        history_repr = NormalizationRepresentation(history_repr)
        state_repr = NormalizationRepresentation(state_repr)

    return nn.ModuleDict(
        dict(
            state_model=state_repr,
            action_model=action_repr,
            observation_model=observation_repr,
            history_model=history_repr,
        )
    )
Example #4
0
def main():
    """Program entry point: train, resuming from and/or writing a checkpoint.

    Parses CLI arguments and tries to load ``args.checkpoint``. If it loads,
    the wandb run recorded in it is resumed (``resume='must'``). Inside the
    wandb run, the global config is synced from ``wandb.config``, the run
    state is built and, when resuming, restored after verifying the stored
    config matches. Then ``run`` is executed.

    Afterwards, if ``config.checkpoint`` is set: an unfinished run
    (``run`` returned False) is checkpointed. A finished run has its
    checkpoint file removed, if present.

    Raises:
        RuntimeError: if the checkpoint's config differs from the program's.
    """
    args = parse_args()
    wandb_kwargs = dict(
        project=args.wandb_project,
        entity=args.wandb_entity,
        group=args.wandb_group,
        tags=args.wandb_tags,
        mode=args.wandb_mode,
        config=args,
    )

    # Either error means "no checkpoint to resume from" (presumably
    # TypeError arises when args.checkpoint is None — confirm in load_data).
    try:
        checkpoint = load_data(args.checkpoint)
    except (TypeError, FileNotFoundError):
        checkpoint = None
    else:
        resumed_id = checkpoint['metadata']['wandb_id']
        wandb_kwargs['resume'] = 'must'
        wandb_kwargs['id'] = resumed_id

    with wandb.init(**wandb_kwargs):
        config = get_config()
        config._update(dict(wandb.config))

        logger.info('setup of runstate...')
        runstate = setup()
        logger.info('setup DONE')

        if checkpoint is not None:
            stored_config = checkpoint['metadata']['config']
            if stored_config != config._as_dict():
                raise RuntimeError(
                    'checkpoint config inconsistent with program config')
            logger.debug('updating runstate from checkpoint')
            runstate.load_state_dict(checkpoint['data'])

        logger.info('run...')
        done = run(runstate)
        logger.info('run DONE')

        wandb_run_id = wandb.run.id

    if config.checkpoint is None:
        return

    if done:
        # Run completed: discard any existing checkpoint file.
        try:
            os.remove(config.checkpoint)
        except FileNotFoundError:
            pass
        return

    logger.info('checkpointing...')
    metadata = {
        'config': config._as_dict(),
        'wandb_id': wandb_run_id,
    }
    save_data(
        config.checkpoint,
        {'metadata': metadata, 'data': runstate.state_dict()},
    )
    logger.info('checkpointing DONE')
Example #5
0
def run(runstate: RunState) -> bool:
    """Run the actor-critic training loop with separate actor/critic steps.

    Each epoch of the main loop:
      1. optionally evaluates ``evaluation_policy`` (every
         ``config.evaluation_period`` epochs) and logs target returns;
      2. samples ``config.simulation_num_episodes`` episodes with the
         behavior policy and updates the behavior-return running averages;
      3. possibly copies ``algo.models`` into ``algo.target_models``;
      4. takes one critic optimizer step, then one actor optimizer step
         (actor loss + scheduled negentropy weight * negentropy loss).

    The loop stops once ``config.max_simulation_timesteps`` simulation
    timesteps have been collected, or early when SIGUSR1 is received.

    Args:
        runstate: state unpacked as (env, algo, optimizer_actor,
            optimizer_critic, wandb_logger, xstats, timer, running_averages,
            dispensers); ``xstats`` counters are mutated in place.

    Returns:
        True if the loop ran to completion, False if interrupted.
    """
    config = get_config()
    logger.info('run %s %s', config.env_label, config.algo_label)

    (
        env,
        algo,
        optimizer_actor,
        optimizer_critic,
        wandb_logger,
        xstats,
        timer,
        running_averages,
        dispensers,
    ) = runstate

    avg_target_returns = running_averages['avg_target_returns']
    avg_behavior_returns = running_averages['avg_behavior_returns']
    avg100_behavior_returns = running_averages['avg100_behavior_returns']
    target_update_dispenser = dispensers['target_update_dispenser']
    wandb_log_dispenser = dispensers['wandb_log_dispenser']

    device = get_device(config.device)
    algo.to(device)

    # reproducibility
    # Seeds python, numpy, torch, the gv RNG, the env and all its spaces.
    if config.seed is not None:
        random.seed(config.seed)
        np.random.seed(config.seed)
        torch.manual_seed(config.seed)
        reset_gv_rng(config.seed)
        env.seed(config.seed)
        env.state_space.seed(config.seed)
        env.action_space.seed(config.seed)
        env.observation_space.seed(config.seed)

    if config.deterministic:
        torch.use_deterministic_algorithms(True)

    # initialize return type
    q_estimator = q_estimator_factory(
        config.q_estimator,
        n=config.q_estimator_n,
        lambda_=config.q_estimator_lambda,
    )

    behavior_policy = algo.behavior_policy()
    evaluation_policy = algo.evaluation_policy()
    # NOTE(review): hard-coded evaluation epsilon; consider making it a
    # config option.
    evaluation_policy.epsilon = 0.1

    negentropy_schedule = make_schedule(
        config.negentropy_schedule,
        value_from=config.negentropy_value_from,
        value_to=config.negentropy_value_to,
        nsteps=config.negentropy_nsteps,
        halflife=config.negentropy_halflife,
    )
    # Computed from the current counter so a resumed run continues the
    # schedule; refreshed after every simulation phase below.
    weight_negentropy = negentropy_schedule(xstats.simulation_timesteps)

    # setup interrupt flag via signal
    # SIGUSR1 only sets a flag; it is checked at the start of each epoch, so
    # an in-progress epoch always completes before the loop exits.
    interrupt = False

    def set_interrupt_flag():
        nonlocal interrupt
        interrupt = True
        logger.debug('signal received, setting interrupt=True')

    signal.signal(signal.SIGUSR1, lambda signal, frame: set_interrupt_flag())

    # main learning loop
    wandb.watch(algo.models)
    while xstats.simulation_timesteps < config.max_simulation_timesteps:
        if interrupt:
            break

        # evaluate policy
        # Models stay in eval mode for both evaluation and behavior sampling.
        algo.models.eval()

        if config.evaluation and xstats.epoch % config.evaluation_period == 0:
            if config.render:
                sample_episodes(
                    env,
                    evaluation_policy,
                    num_episodes=1,
                    render=True,
                )

            episodes = sample_episodes(
                env,
                evaluation_policy,
                num_episodes=config.evaluation_num_episodes,
            )
            mean_length = sum(map(len, episodes)) / len(episodes)
            returns = evaluate_returns(episodes,
                                       discount=config.evaluation_discount)
            avg_target_returns.extend(returns.tolist())
            logger.info(f'EVALUATE epoch {xstats.epoch}'
                        f' simulation_timestep {xstats.simulation_timesteps}'
                        f' return {returns.mean():.3f}')
            wandb_logger.log({
                **xstats.asdict(),
                'hours':
                timer.hours,
                'diagnostics/target_mean_episode_length':
                mean_length,
                'performance/target_mean_return':
                returns.mean(),
                'performance/avg_target_mean_return':
                avg_target_returns.value(),
            })

        episodes = sample_episodes(
            env,
            behavior_policy,
            num_episodes=config.simulation_num_episodes,
        )

        mean_length = sum(map(len, episodes)) / len(episodes)
        returns = evaluate_returns(episodes,
                                   discount=config.evaluation_discount)
        avg_behavior_returns.extend(returns.tolist())
        avg100_behavior_returns.extend(returns.tolist())

        # Decided once per epoch; reused for the training log below so the
        # behavior and training logs are emitted in the same epochs.
        wandb_log = wandb_log_dispenser.dispense(xstats.simulation_timesteps)

        if wandb_log:
            logger.info(
                'behavior log - simulation_step %d return %.3f',
                xstats.simulation_timesteps,
                returns.mean(),
            )
            wandb_logger.log({
                **xstats.asdict(),
                'hours':
                timer.hours,
                'diagnostics/behavior_mean_episode_length':
                mean_length,
                'performance/behavior_mean_return':
                returns.mean(),
                'performance/avg_behavior_mean_return':
                avg_behavior_returns.value(),
                'performance/avg100_behavior_mean_return':
                avg100_behavior_returns.value(),
            })

        # storing torch data directly
        episodes = [episode.torch().to(device) for episode in episodes]
        xstats.simulation_episodes += len(episodes)
        xstats.simulation_timesteps += sum(
            len(episode) for episode in episodes)
        weight_negentropy = negentropy_schedule(xstats.simulation_timesteps)

        # target model update
        if target_update_dispenser.dispense(xstats.simulation_timesteps):
            # Update the target network
            algo.target_models.load_state_dict(algo.models.state_dict())

        algo.models.train()

        # critic
        # One step on the per-episode critic losses, averaged via average().
        optimizer_critic.zero_grad()
        losses = [
            algo.critic_loss(
                episode,
                discount=config.training_discount,
                q_estimator=q_estimator,
            ) for episode in episodes
        ]
        critic_loss = average(losses)
        critic_loss.backward()
        critic_gradient_norm = nn.utils.clip_grad_norm_(
            algo.models.parameters(), max_norm=config.optim_max_norm)
        optimizer_critic.step()

        # actor
        # actor_losses returns (actor_loss, negentropy_loss) per episode; the
        # negentropy term is weighted by the current schedule value.
        optimizer_actor.zero_grad()
        losses = [
            algo.actor_losses(
                episode,
                discount=config.training_discount,
                q_estimator=q_estimator,
            ) for episode in episodes
        ]

        actor_losses, negentropy_losses = zip(*losses)
        actor_loss = average(actor_losses)
        negentropy_loss = average(negentropy_losses)

        loss = actor_loss + weight_negentropy * negentropy_loss
        loss.backward()
        actor_gradient_norm = nn.utils.clip_grad_norm_(
            algo.models.parameters(), max_norm=config.optim_max_norm)
        optimizer_actor.step()

        if wandb_log:
            logger.info(
                'training log - simulation_step %d losses %.3f %.3f %.3f',
                xstats.simulation_timesteps,
                actor_loss,
                critic_loss,
                negentropy_loss,
            )
            wandb_logger.log({
                **xstats.asdict(),
                'hours':
                timer.hours,
                'training/losses/actor':
                actor_loss,
                'training/losses/critic':
                critic_loss,
                'training/losses/negentropy':
                negentropy_loss,
                'training/weights/negentropy':
                weight_negentropy,
                'training/gradient_norms/actor':
                actor_gradient_norm,
                'training/gradient_norms/critic':
                critic_gradient_norm,
            })

            # Model snapshots are only written on logging epochs; the
            # filename template is formatted with the simulation timestep.
            if config.save_modelseq and config.modelseq_filename is not None:
                data = {
                    'metadata': {
                        'config': config._as_dict()
                    },
                    'data': {
                        'timestep': xstats.simulation_timesteps,
                        'model.state_dict': algo.models.state_dict(),
                    },
                }
                filename = config.modelseq_filename.format(
                    xstats.simulation_timesteps)
                save_data(filename, data)

        xstats.epoch += 1
        xstats.optimizer_steps += 1
        xstats.training_episodes += len(episodes)
        xstats.training_timesteps += sum(len(episode) for episode in episodes)

    done = not interrupt

    # The final model is saved only for runs that were not interrupted.
    if done and config.save_model and config.model_filename is not None:
        data = {
            'metadata': {
                'config': config._as_dict()
            },
            'data': {
                'models.state_dict': algo.models.state_dict()
            },
        }
        save_data(config.model_filename, data)

    return done
Example #6
0
def run(runstate: RunState) -> bool:
    """Run the episode-buffer training loop with a single optimizer.

    The episode buffer is (re)filled using the configured prepopulation
    policy. Each epoch of the main loop then:
      1. optionally evaluates ``target_policy`` (every
         ``config.evaluation_period`` epochs);
      2. samples episodes with the epsilon-scheduled behavior policy and
         appends them to the buffer;
      3. possibly copies ``algo.models`` into ``algo.target_models``;
      4. takes optimizer steps on buffer samples (whole episodes or batches,
         depending on ``algo.episodic_training``). Steps continue until
         ``training_timesteps`` reaches ``(simulation_timesteps -
         episode_buffer_prepopulate_timesteps) *
         training_timesteps_per_simulation_timestep``.

    The loop stops once ``config.max_simulation_timesteps`` simulation
    timesteps have been collected, or early when SIGUSR1 is received.

    Args:
        runstate: state unpacked as (env, algo, optimizer, wandb_logger,
            xstats, timer, running_averages, dispensers); ``xstats`` counters
            are mutated in place.

    Returns:
        True if the loop ran to completion, False if interrupted.

    Raises:
        ValueError: if ``config.episode_buffer_prepopulate_policy`` is not
            one of ``'random'``, ``'behavior'`` or ``'target'``.
    """
    config = get_config()
    logger.info('run %s %s', config.env_label, config.algo_label)

    (
        env,
        algo,
        optimizer,
        wandb_logger,
        xstats,
        timer,
        running_averages,
        dispensers,
    ) = runstate

    avg_target_returns = running_averages['avg_target_returns']
    avg_behavior_returns = running_averages['avg_behavior_returns']
    avg100_behavior_returns = running_averages['avg100_behavior_returns']
    target_update_dispenser = dispensers['target_update_dispenser']
    wandb_log_dispenser = dispensers['wandb_log_dispenser']

    device = get_device(config.device)
    algo.to(device)

    # reproducibility: seed python, numpy, torch, the gv RNG, env and spaces
    if config.seed is not None:
        random.seed(config.seed)
        np.random.seed(config.seed)
        torch.manual_seed(config.seed)
        reset_gv_rng(config.seed)
        env.seed(config.seed)
        env.state_space.seed(config.seed)
        env.action_space.seed(config.seed)
        env.observation_space.seed(config.seed)

    if config.deterministic:
        torch.use_deterministic_algorithms(True)

    # The epsilon schedule is indexed by timesteps simulated *after* the
    # initial buffer prepopulation.
    epsilon_schedule = make_schedule(
        config.epsilon_schedule,
        value_from=config.epsilon_value_from,
        value_to=config.epsilon_value_to,
        nsteps=config.epsilon_nsteps,
    )

    behavior_policy = algo.behavior_policy(env.action_space)
    target_policy = algo.target_policy()

    logger.info(
        f'setting prepopulating policy:'
        f' {config.episode_buffer_prepopulate_policy}'
    )
    prepopulate_policy: Policy
    if config.episode_buffer_prepopulate_policy == 'random':
        prepopulate_policy = RandomPolicy(env.action_space)
    elif config.episode_buffer_prepopulate_policy == 'behavior':
        prepopulate_policy = behavior_policy
    elif config.episode_buffer_prepopulate_policy == 'target':
        prepopulate_policy = target_policy
    else:
        # Explicit raise instead of `assert False`: asserts are stripped
        # under `python -O`, which would leave prepopulate_policy unbound.
        raise ValueError(
            'invalid episode_buffer_prepopulate_policy:'
            f' {config.episode_buffer_prepopulate_policy!r}'
        )

    if xstats.simulation_timesteps == 0:
        # Fresh run: prepopulate the configured number of timesteps.
        prepopulate_timesteps = config.episode_buffer_prepopulate_timesteps
    else:
        # Resumed run: refill the (non-checkpointed) buffer with as many
        # timesteps as were already simulated, at the current epsilon.
        # NOTE(review): if simulation_timesteps exceeds what EpisodeBuffer
        # can hold (episode_buffer_max_timesteps), the refill loop below may
        # never terminate — confirm EpisodeBuffer's eviction semantics.
        prepopulate_policy.epsilon = epsilon_schedule(
            xstats.simulation_timesteps
            - config.episode_buffer_prepopulate_timesteps
        )
        prepopulate_timesteps = xstats.simulation_timesteps

    # instantiate and prepopulate buffer
    logger.info(
        f'prepopulating episode buffer'
        f' ({prepopulate_timesteps:_} timesteps)...'
    )
    episode_buffer = EpisodeBuffer(config.episode_buffer_max_timesteps)
    while episode_buffer.num_interactions() < prepopulate_timesteps:
        (episode,) = sample_episodes(env, prepopulate_policy, num_episodes=1)
        episode_buffer.append_episode(episode.torch())
    logger.info('prepopulating DONE')

    # Only a fresh run counts the prepopulation as simulated experience.
    if xstats.simulation_timesteps == 0:
        xstats.simulation_episodes = episode_buffer.num_episodes()
        xstats.simulation_timesteps = episode_buffer.num_interactions()

    # setup interrupt flag via signal: SIGUSR1 only sets a flag, checked at
    # the start of each epoch, so an in-progress epoch always completes.
    interrupt = False

    def set_interrupt_flag():
        nonlocal interrupt
        logger.debug('signal received, setting interrupt=True')
        interrupt = True

    signal.signal(signal.SIGUSR1, lambda signal, frame: set_interrupt_flag())

    # main learning loop
    wandb.watch(algo.models)
    while xstats.simulation_timesteps < config.max_simulation_timesteps:
        if interrupt:
            break

        algo.models.eval()

        # evaluate target policy
        if config.evaluation and xstats.epoch % config.evaluation_period == 0:
            if config.render:
                sample_episodes(
                    env,
                    target_policy,
                    num_episodes=1,
                    render=True,
                )

            episodes = sample_episodes(
                env,
                target_policy,
                num_episodes=config.evaluation_num_episodes,
            )
            mean_length = sum(map(len, episodes)) / len(episodes)
            returns = evaluate_returns(
                episodes, discount=config.evaluation_discount
            )
            avg_target_returns.extend(returns.tolist())
            logger.info(
                'EVALUATE epoch %d simulation_step %d return %.3f',
                xstats.epoch,
                xstats.simulation_timesteps,
                returns.mean(),
            )
            wandb_logger.log(
                {
                    **xstats.asdict(),
                    'hours': timer.hours,
                    'diagnostics/target_mean_episode_length': mean_length,
                    'performance/target_mean_return': returns.mean(),
                    'performance/avg_target_mean_return': avg_target_returns.value(),
                }
            )

        # populate episode buffer
        behavior_policy.epsilon = epsilon_schedule(
            xstats.simulation_timesteps
            - config.episode_buffer_prepopulate_timesteps
        )
        episodes = sample_episodes(
            env,
            behavior_policy,
            num_episodes=config.simulation_num_episodes,
        )

        mean_length = sum(map(len, episodes)) / len(episodes)
        returns = evaluate_returns(
            episodes, discount=config.evaluation_discount
        )
        avg_behavior_returns.extend(returns.tolist())
        avg100_behavior_returns.extend(returns.tolist())

        # Decided once per epoch and reused for every training step below.
        wandb_log = wandb_log_dispenser.dispense(xstats.simulation_timesteps)

        if wandb_log:
            logger.info(
                'behavior log - simulation_step %d return %.3f',
                xstats.simulation_timesteps,
                returns.mean(),
            )
            wandb_logger.log(
                {
                    **xstats.asdict(),
                    'hours': timer.hours,
                    'diagnostics/epsilon': behavior_policy.epsilon,
                    'diagnostics/behavior_mean_episode_length': mean_length,
                    'performance/behavior_mean_return': returns.mean(),
                    'performance/avg_behavior_mean_return': avg_behavior_returns.value(),
                    'performance/avg100_behavior_mean_return': avg100_behavior_returns.value(),
                }
            )

        # storing torch data directly
        episodes = [episode.torch().to(device) for episode in episodes]
        episode_buffer.append_episodes(episodes)
        xstats.simulation_episodes += len(episodes)
        xstats.simulation_timesteps += sum(len(episode) for episode in episodes)

        # target model update
        if target_update_dispenser.dispense(xstats.simulation_timesteps):
            # Update the target network
            algo.target_models.load_state_dict(algo.models.state_dict())

        # train based on episode buffer until the training/simulation
        # timestep ratio (excluding prepopulation) catches up
        algo.models.train()
        while (
            xstats.training_timesteps
            < (
                xstats.simulation_timesteps
                - config.episode_buffer_prepopulate_timesteps
            )
            * config.training_timesteps_per_simulation_timestep
        ):
            optimizer.zero_grad()

            if algo.episodic_training:
                episodes = episode_buffer.sample_episodes(
                    num_samples=config.training_num_episodes,
                    replacement=True,
                )
                episodes = [episode.to(device) for episode in episodes]
                loss = algo.episodic_loss(
                    episodes, discount=config.training_discount
                )

            else:
                batch = episode_buffer.sample_batch(
                    batch_size=config.training_batch_size
                )
                batch = batch.to(device)
                loss = algo.batched_loss(
                    batch, discount=config.training_discount
                )

            loss.backward()
            gradient_norm = nn.utils.clip_grad_norm_(
                algo.models.parameters(), max_norm=config.optim_max_norm
            )
            optimizer.step()

            if wandb_log:
                logger.debug(
                    'training log - simulation_step %d loss %.3f',
                    xstats.simulation_timesteps,
                    loss,
                )
                wandb_logger.log(
                    {
                        **xstats.asdict(),
                        'hours': timer.hours,
                        'training/loss': loss,
                        'training/gradient_norm': gradient_norm,
                    }
                )

            # NOTE(review): this runs after *every* optimizer step, and the
            # filename depends only on simulation_timesteps (constant within
            # this inner loop), so the same file is rewritten repeatedly —
            # confirm whether it should be gated (e.g. on wandb_log).
            if config.save_modelseq and config.modelseq_filename is not None:
                data = {
                    'metadata': {'config': config._as_dict()},
                    'data': {
                        'timestep': xstats.simulation_timesteps,
                        'model.state_dict': algo.models.state_dict(),
                    },
                }
                filename = config.modelseq_filename.format(
                    xstats.simulation_timesteps
                )
                save_data(filename, data)

            xstats.optimizer_steps += 1
            if algo.episodic_training:
                xstats.training_episodes += len(episodes)
                xstats.training_timesteps += sum(
                    len(episode) for episode in episodes
                )
            else:
                xstats.training_timesteps += len(batch)

        xstats.epoch += 1

    done = not interrupt

    # The final model is saved only for runs that were not interrupted.
    if done and config.save_model and config.model_filename is not None:
        data = {
            'metadata': {'config': config._as_dict()},
            'data': {'models.state_dict': algo.models.state_dict()},
        }
        save_data(config.model_filename, data)

    return done
Example #7
0
def execute_before_tests():
    """Reset config so history/state features are neither resized nor normalized.

    Sets ``hs_features_dim`` to 0 (falsy, so no resize wrapper is applied) and
    ``normalize_hs_features`` to False on the global config.
    """
    overrides = dict(
        hs_features_dim=0,
        normalize_hs_features=False,
    )
    get_config()._update(overrides)