Example #1
0
    def run_lightning(config):
        """Build the replay buffers and Lightning module, then train or evaluate.

        Uses names from the enclosing scope: ``env_params``, ``game_dir``,
        ``args``, ``logger``, ``epochs``, ``every_checkpoint_callback`` and
        ``callback_list``.

        Args:
            config: Hyperparameter mapping. Keys read here: ``use_RCP_buffer``,
                ``max_buffer_size``, ``batch_size``, ``desire_advantage`` and
                ``grad_clip_val``. A shallow copy is taken before
                ``max_buffer_size`` is rescaled, so the caller's mapping is not
                mutated; the (possibly rescaled) copy is what the model sees.
        """
        # Work on a shallow copy: the SortedBuffer path rescales
        # max_buffer_size, and doing that in place on the caller's dict would
        # compound the scaling each time the same config is reused.
        config = dict(config)

        if config['use_RCP_buffer']:
            buffer_cls = RingBuffer
        else:
            buffer_cls = SortedBuffer
            # Presumably converts a size given in episodes into a size in
            # stored transitions for the SortedBuffer — confirm.
            config['max_buffer_size'] *= env_params['avg_episode_length']

        def make_buffer(size):
            # Train and test buffers differ only in their capacity.
            return buffer_cls(
                obs_dim=env_params['STORED_STATE_SIZE'],
                act_dim=env_params['STORED_ACTION_SIZE'],
                size=size,
                use_td_lambda_buf=config['desire_advantage'])

        train_buffer = make_buffer(config['max_buffer_size'])
        test_buffer = make_buffer(config['batch_size'] * 10)

        model = LightningTemplate(game_dir, config, train_buffer, test_buffer)

        if args.reload or args.eval_agent:
            # Load a trained model; --reload takes precedence over --eval_agent.
            load_name = args.reload or args.eval_agent
            # NOTE(review): torch.load unpickles the file, so only load trusted
            # checkpoints. map_location='cpu' lets GPU-saved checkpoints load on
            # CPU-only hosts; load_state_dict copies the values into the
            # model's existing parameters either way.
            state_dict = torch.load(load_name, map_location='cpu')['state_dict']

            def sub_state_dict(prefix):
                # Entries belonging to the submodule `prefix`, prefix removed.
                return {
                    k[len(prefix):]: v
                    for k, v in state_dict.items()
                    if k.startswith(prefix)
                }

            if args.implementation == 'RCP-A':
                # RCP-A checkpoints additionally hold the advantage model.
                model.advantage_model.load_state_dict(
                    sub_state_dict('advantage_model.'))
            model.model.load_state_dict(sub_state_dict('model.'))
            print("Loaded in Model!")

        if args.eval_agent:
            print(
                'Ensure the desires for your agent (approx line 76 of '
                'lightning_trainer.py) correspond to those your agent learned.')
            # Evaluation is implemented on the LightningTemplate itself.
            model.eval_agent()

        else:
            # NOTE(review): `checkpoint_callback=` and
            # `progress_bar_refresh_rate=` are older PyTorch Lightning Trainer
            # arguments; confirm they match the pinned Lightning version.
            trainer = Trainer(deterministic=True,
                              logger=logger,
                              default_root_dir=game_dir,
                              max_epochs=epochs,
                              profiler=False,
                              checkpoint_callback=every_checkpoint_callback,
                              callbacks=callback_list,
                              gradient_clip_val=config['grad_clip_val'],
                              progress_bar_refresh_rate=0)
            trainer.fit(model)