Example #1
def run_task(vv, log_dir=None, exp_name=None):
    if log_dir or logger.get_dir() is None:
        logger.configure(dir=log_dir, exp_name=exp_name, format_strs=['csv'])
    logdir = logger.get_dir()
    assert logdir is not None
    os.makedirs(logdir, exist_ok=True)
    updated_vv = copy.copy(DEFAULT_CONFIG)
    updated_vv.update(**vv)
    main(vv_to_args(updated_vv))
Example #2
def run_task(vv, log_dir=None, exp_name=None):
    if log_dir or logger.get_dir() is None:
        logger.configure(dir=log_dir, exp_name=exp_name, format_strs=['csv'])
    logdir = logger.get_dir()
    assert logdir is not None
    os.makedirs(logdir, exist_ok=True)

    default_cfg = yaml.safe_load(open('drq/config.yml', 'r'))
    cfg = update_config(default_cfg, vv)
    cfg = update_env_kwargs(cfg)
    workspace = Workspace(vv_to_args(cfg))
    workspace.run()
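For reference, a minimal sketch of what a dictionary-merge helper like update_config could look like; this is an assumption for illustration only, not the project's actual implementation:

def update_config(default_cfg, variant):
    # Recursively overlay the variant dict onto the default config (illustrative only).
    merged = dict(default_cfg)
    for key, value in variant.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = update_config(merged[key], value)
        else:
            merged[key] = value
    return merged

print(update_config({'lr': 1e-3, 'env': {'horizon': 100}}, {'env': {'horizon': 75}}))
# {'lr': 0.001, 'env': {'horizon': 75}}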
Example #3
def run_task(arg_vv, log_dir, exp_name):
    if arg_vv['algorithm'] == 'planet':
        from planet.config import DEFAULT_PARAMS
    elif arg_vv['algorithm'] == 'dreamer':
        from dreamer.config import DEFAULT_PARAMS
    else:
        raise NotImplementedError

    vv = DEFAULT_PARAMS
    vv.update(**arg_vv)
    vv = update_env_kwargs(vv)
    vv['max_episode_length'] = vv['env_kwargs']['horizon']

    # Configure logger
    logger.configure(dir=log_dir, exp_name=exp_name)
    logdir = logger.get_dir()
    assert logdir is not None
    os.makedirs(logdir, exist_ok=True)

    # Configure torch
    if torch.cuda.is_available():
        device = torch.device('cuda:1') if torch.cuda.device_count() > 1 else torch.device('cuda:0')
        torch.cuda.manual_seed(vv['seed'])
    else:
        device = torch.device('cpu')

    # Dump parameters
    with open(osp.join(logger.get_dir(), 'variant.json'), 'w') as f:
        json.dump(vv, f, indent=2, sort_keys=True)
    env = Env(vv['env_name'],
              vv['symbolic_env'],
              vv['seed'],
              vv['max_episode_length'],
              vv['action_repeat'],
              vv['bit_depth'],
              vv['image_dim'],
              env_kwargs=vv['env_kwargs'])

    if vv['algorithm'] == 'planet':
        from planet.planet_agent import PlaNetAgent
        agent = PlaNetAgent(env, vv, device)
        agent.train(train_epoch=vv['train_epoch'])
        env.close()
    elif vv['algorithm'] == 'dreamer':
        from dreamer.dreamer_agent import DreamerAgent
        agent = DreamerAgent(env, vv, device)
        agent.train(train_episode=vv['train_episode'])
        env.close()
Example #4
    def __init__(self, cfg):
        self.work_dir = logger.get_dir()
        print(f'workspace: {self.work_dir}')

        self.cfg = cfg

        self.logger = Logger(self.work_dir,
                             save_tb=cfg.log_save_tb,
                             log_frequency=cfg.log_frequency_step,
                             agent='drq',
                             action_repeat=1,
                             chester_log=logger)

        utils.set_seed_everywhere(cfg.seed)
        self.device = torch.device(cfg.device)
        self.env = make_env(cfg)

        obs_shape = self.env.observation_space.shape
        new_obs_shape = np.zeros_like(obs_shape)
        new_obs_shape[0] = obs_shape[-1]
        new_obs_shape[1] = obs_shape[0]
        new_obs_shape[2] = obs_shape[1]
        cfg.agent['obs_shape'] = cfg.encoder['obs_shape'] = new_obs_shape
        cfg.agent['action_shape'] = self.env.action_space.shape
        cfg.actor['action_shape'] = self.env.action_space.shape
        cfg.critic['action_shape'] = self.env.action_space.shape
        cfg.actor['encoder_cfg'] = cfg.encoder
        cfg.critic['encoder_cfg'] = cfg.encoder
        cfg.agent['action_range'] = [
            float(self.env.action_space.low.min()),
            float(self.env.action_space.high.max())
        ]
        cfg.agent['encoder_cfg'] = cfg.encoder
        cfg.agent['critic_cfg'] = cfg.critic
        cfg.agent['actor_cfg'] = cfg.actor

        self.agent = DRQAgent(**cfg.agent)

        self.replay_buffer = ReplayBuffer(new_obs_shape,
                                          self.env.action_space.shape,
                                          cfg.replay_buffer_capacity,
                                          self.cfg.image_pad, self.device)

        # self.video_recorder = VideoRecorder(
        #     self.work_dir if cfg.save_video else None)
        self.step = 0
        self.video_dir = os.path.join(self.work_dir, 'video')
        self.model_dir = os.path.join(self.work_dir, 'model')
        if not os.path.exists(self.video_dir):
            os.makedirs(self.video_dir, exist_ok=True)
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir, exist_ok=True)
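The shape juggling above moves the channel axis first (HWC to CHW) for the image encoder; a standalone illustration with a hypothetical observation shape:

import numpy as np

obs_shape = (84, 84, 9)                  # hypothetical stacked RGB frames, channels last
new_obs_shape = np.zeros_like(obs_shape) # integer array of the same length
new_obs_shape[0] = obs_shape[-1]         # channels
new_obs_shape[1] = obs_shape[0]          # height
new_obs_shape[2] = obs_shape[1]          # width
print(new_obs_shape.tolist())            # [9, 84, 84]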
Example #5
def vv_to_args(vv):
    class VArgs(object):
        def __init__(self, vv):
            for key, val in vv.items():
                setattr(self, key, val)

    args = VArgs(vv)

    # Dump parameters
    with open(os.path.join(logger.get_dir(), 'variant.json'), 'w') as f:
        json.dump(vv, f, indent=2, sort_keys=True)

    return args
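The VArgs wrapper simply exposes dictionary entries as attributes; leaving aside the variant.json dump, the same effect can be had with the standard library, shown here as a small illustrative sketch with a hypothetical variant dict:

from types import SimpleNamespace

vv = {'seed': 100, 'env_name': 'ClothFold'}  # hypothetical variant dict
args = SimpleNamespace(**vv)                 # attribute access, like VArgs
print(args.seed, args.env_name)              # 100 ClothFold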
Example #6
def main(args):
    if args.seed == -1:
        args.__dict__["seed"] = np.random.randint(1, 1000000)
    utils.set_seed_everywhere(args.seed)

    args.__dict__ = update_env_kwargs(args.__dict__)  # Update env_kwargs

    symbolic = args.env_kwargs['observation_mode'] != 'cam_rgb'
    args.encoder_type = 'identity' if symbolic else 'pixel'

    env = Env(args.env_name,
              symbolic,
              args.seed,
              200,
              1,
              8,
              args.pre_transform_image_size,
              env_kwargs=args.env_kwargs,
              normalize_observation=False,
              scale_reward=args.scale_reward,
              clip_obs=args.clip_obs)
    env.seed(args.seed)

    # make directory
    ts = time.gmtime()
    ts = time.strftime("%m-%d", ts)

    args.work_dir = logger.get_dir()

    video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
    model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
    buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    action_shape = env.action_space.shape

    if args.encoder_type == 'pixel':
        obs_shape = (3, args.image_size, args.image_size)
        pre_aug_obs_shape = (3, args.pre_transform_image_size,
                             args.pre_transform_image_size)
    else:
        obs_shape = env.observation_space.shape
        pre_aug_obs_shape = obs_shape

    replay_buffer = utils.ReplayBuffer(
        obs_shape=pre_aug_obs_shape,
        action_shape=action_shape,
        capacity=args.replay_buffer_capacity,
        batch_size=args.batch_size,
        device=device,
        image_size=args.image_size,
    )

    agent = make_agent(obs_shape=obs_shape,
                       action_shape=action_shape,
                       args=args,
                       device=device)

    L = Logger(args.work_dir, use_tb=args.save_tb, chester_logger=logger)

    episode, episode_reward, done, ep_info = 0, 0, True, []
    start_time = time.time()
    for step in range(args.num_train_steps):
        # evaluate agent periodically

        if step % args.eval_freq == 0:
            L.log('eval/episode', episode, step)
            evaluate(env, agent, video_dir, args.num_eval_episodes, L, step,
                     args)
            if args.save_model and (step % (args.eval_freq * 5) == 0):
                agent.save(model_dir, step)
            if args.save_buffer:
                replay_buffer.save(buffer_dir)
        if done:
            if step > 0:
                if step % args.log_interval == 0:
                    L.log('train/duration', time.time() - start_time, step)
                    for key, val in get_info_stats([ep_info]).items():
                        L.log('train/info_' + key, val, step)
                    L.dump(step)
                start_time = time.time()
            if step % args.log_interval == 0:
                L.log('train/episode_reward', episode_reward, step)

            obs = env.reset()
            done = False
            ep_info = []
            episode_reward = 0
            episode_step = 0
            episode += 1
            if step % args.log_interval == 0:
                L.log('train/episode', episode, step)

        # sample action for data collection
        if step < args.init_steps:
            action = env.action_space.sample()
        else:
            with utils.eval_mode(agent):
                action = agent.sample_action(obs)

        # run training update
        if step >= args.init_steps:
            num_updates = 1
            for _ in range(num_updates):
                agent.update(replay_buffer, L, step)
        next_obs, reward, done, info = env.step(action)

        # allow infinite bootstrap: a time-limit cutoff is not treated as a true terminal
        ep_info.append(info)
        done_bool = 0 if episode_step + 1 == env.horizon else float(done)
        episode_reward += reward
        replay_buffer.add(obs, action, reward, next_obs, done_bool)

        obs = next_obs
        episode_step += 1
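The done_bool line above implements time-limit ('infinite') bootstrapping: an episode cut off purely by the horizon is stored as non-terminal so the critic keeps bootstrapping from the next state. A standalone sketch of the same rule with hypothetical values:

def bootstrap_done(done, episode_step, horizon):
    # Time-limit cutoffs are stored as non-terminal (0.0); true terminals keep done = 1.0.
    return 0.0 if episode_step + 1 == horizon else float(done)

print(bootstrap_done(True, episode_step=199, horizon=200))  # 0.0 (time-limit cutoff)
print(bootstrap_done(True, episode_step=57, horizon=200))   # 1.0 (true terminal)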
Example #7
def DDPG_train(vv, env, agent):
    '''
    Given the training args, the training/testing environment, and the DDPG agent, this function
    interacts the agent with the environment, collects transition samples, trains the agent,
    and records the relevant statistics in the log files.
    '''

    a_lr = vv['actor_lr']
    c_lr = vv['critic_lr']
    ### training begins
    for train_iter in range(vv['train_epoch']):
        dx = np.random.choice(vv['dx'])
        if dx <= 0.02:
            weno_freq = vv['weno_freq']
        else:
            weno_freq = 0

        pre_state = env.reset(dx=dx)
        env_batch_size = env.num_x
        horizon = env.horizon

        # decay learning rate
        if train_iter > 0 and vv['lr_decay_interval'] > 0 \
                and train_iter % vv['lr_decay_interval'] == 0:
            a_lr = max(vv['final_actor_lr'], a_lr / 2)
            c_lr = max(vv['final_critic_lr'], c_lr / 2)
            c_optimizers = [agent.critic_optimizer, agent.critic_optimizer2]
            for optim in c_optimizers:
                for param_group in optim.param_groups:
                    param_group['lr'] = c_lr

            for param_group in agent.actor_optimizer.param_groups:
                param_group['lr'] = a_lr

        ret = 0
        for t in range(1, horizon):
            p = np.random.rand()
            if p < weno_freq:
                action = agent.action(pre_state, mode='weno')
            else:
                action = agent.action(pre_state)

            ### next_state, reward, done, all batches
            next_state, reward, done, _ = env.step(action)
            ret += np.mean(reward)

            # TODO: change this to support storing a batch of transitions
            for i in range(env_batch_size):
                agent.store(pre_state[i], action[i], reward[i], next_state[i],
                            done[i])

            agent.train()
            pre_state = next_state

        error, relative_error = env.error('euler')
        solution_idx = env.solution_idx
        logger.record_tabular('Train/{}-error'.format(solution_idx), error)
        logger.record_tabular('Train/{}-ret'.format(solution_idx), ret)
        logger.record_tabular('Train/{}-relative-error'.format(solution_idx),
                              relative_error)

        ### decrease the exploration noise of the agent
        if train_iter > 0 and train_iter % vv['noise_dec_every'] == 0:
            agent.action_noise.decrease(vv['noise_dec'])

        # TODO: implement possible learning rate decay
        # if train_iter > 0 and train_iter % vv['decay_learning_rate'] == 0:
        #     # update_linear_schedule(self.critic_optimizer, self.global_step, self.vv[train_epoch, self.args.c_lr, self.args.final_c_lr)
        #     # update_linear_schedule(self.critic_optimizer2, self.global_step, self.args.train_epoch, self.args.c_lr, self.args.final_c_lr)
        #     # update_linear_schedule(self.actor_optimizer, self.global_step, self.args.train_epoch, self.args.a_lr, self.args.final_a_lr)
        #     pass

        ### test the agent on the held-out solutions with the rk4 time scheme.
        if train_iter % vv['test_interval'] == 0:
            env.train_flag = False
            agent.train_mode(False)

            for dx in vv['dx']:
                print("test begin")
                errors, relative_errors = [], []
                for solution_idx in range(
                        len(env.solution_path_list) // 2,
                        len(env.solution_path_list)):
                    pre_state = env.reset(solution_idx=solution_idx,
                                          num_t=200,
                                          dx=dx)
                    horizon = env.num_t
                    for t in range(1, horizon):
                        action = agent.action(
                            pre_state, deterministic=True
                        )  # action: (state_dim -2, 1) batch
                        next_state, reward, done, _ = env.step(action,
                                                               Tscheme='rk4')
                        pre_state = next_state
                    error, relative_error = env.error('rk4')
                    errors.append(error)
                    relative_errors.append(relative_error)

                names = ['error', 'relative_error']
                all_errors = [errors, relative_errors]
                for i in range(len(names)):
                    name = names[i]
                    errors = all_errors[i]
                    logger.record_tabular(f'Test/{dx}_{name}_mean',
                                          np.mean(errors))
                    logger.record_tabular(f'Test/{dx}_{name}_max',
                                          np.max(errors))
                    logger.record_tabular(f'Test/{dx}_{name}_min',
                                          np.min(errors))
                    logger.record_tabular(f'Test/{dx}_{name}_median',
                                          np.median(errors))
                    logger.record_tabular(f'Test/{dx}_{name}_std',
                                          np.std(errors))
                print("test end")

            logger.dump_tabular()
            env.train_flag = True
            agent.train_mode(True)

        if train_iter % vv['save_interval'] == 0 and train_iter > 0:
            agent.save(osp.join(logger.get_dir(), str(train_iter)))

    return agent
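The learning-rate decay inside the loop halves both rates every lr_decay_interval epochs, clamped to a configured floor; a minimal standalone version of that schedule, assuming ordinary torch optimizers and hypothetical values:

import torch

def halve_lr(optimizer, current_lr, final_lr):
    # Halve the learning rate, but never drop below the floor.
    new_lr = max(final_lr, current_lr / 2)
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    return new_lr

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr = 1e-3
for epoch in range(1, 5):
    lr = halve_lr(optimizer, lr, final_lr=1e-5)
    print(epoch, lr)  # 0.0005, 0.00025, 0.000125, 6.25e-05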
Example #8
def run_task(vv, log_dir, exp_name):
    mp.set_start_method('spawn')
    env_name = vv['env_name']
    vv['algorithm'] = 'CEM'
    vv['env_kwargs'] = env_arg_dict[env_name]  # Default env parameters
    vv['plan_horizon'] = cem_plan_horizon[env_name]  # Planning horizon

    vv['population_size'] = vv['timestep_per_decision'] // vv['max_iters']
    if vv['use_mpc']:
        vv['population_size'] = vv['population_size'] // vv['plan_horizon']
    vv['num_elites'] = vv['population_size'] // 10
    vv = update_env_kwargs(vv)

    # Configure logger
    logger.configure(dir=log_dir, exp_name=exp_name)
    logdir = logger.get_dir()
    assert logdir is not None
    os.makedirs(logdir, exist_ok=True)

    # Configure torch
    if torch.cuda.is_available():
        torch.cuda.manual_seed(vv['seed'])

    # Dump parameters
    with open(osp.join(logger.get_dir(), 'variant.json'), 'w') as f:
        json.dump(vv, f, indent=2, sort_keys=True)

    env_symbolic = vv['env_kwargs']['observation_mode'] != 'cam_rgb'

    env_class = Env
    env_kwargs = {
        'env': vv['env_name'],
        'symbolic': env_symbolic,
        'seed': vv['seed'],
        'max_episode_length': 200,
        'action_repeat': 1,  # Action repeat for the env wrapper is 1, as it is already handled inside the env
        'bit_depth': 8,
        'image_dim': None,
        'env_kwargs': vv['env_kwargs']
    }
    env = env_class(**env_kwargs)

    env_kwargs_render = copy.deepcopy(env_kwargs)
    env_kwargs_render['env_kwargs']['render'] = True
    env_render = env_class(**env_kwargs_render)

    policy = CEMPolicy(env,
                       env_class,
                       env_kwargs,
                       vv['use_mpc'],
                       plan_horizon=vv['plan_horizon'],
                       max_iters=vv['max_iters'],
                       population_size=vv['population_size'],
                       num_elites=vv['num_elites'])
    # Run policy
    initial_states, action_trajs, configs, all_infos = [], [], [], []
    for i in range(vv['test_episodes']):
        logger.log('episode ' + str(i))
        obs = env.reset()
        policy.reset()
        initial_state = env.get_state()
        action_traj = []
        infos = []
        for j in range(env.horizon):
            logger.log('episode {}, step {}'.format(i, j))
            action = policy.get_action(obs)
            action_traj.append(copy.copy(action))
            obs, reward, _, info = env.step(action)
            infos.append(info)
        all_infos.append(infos)
        initial_states.append(initial_state.copy())
        action_trajs.append(action_traj.copy())
        configs.append(env.get_current_config().copy())

        # Log for each episode
        transformed_info = transform_info([infos])
        for info_name in transformed_info:
            logger.record_tabular('info_' + 'final_' + info_name,
                                  transformed_info[info_name][0, -1])
            logger.record_tabular('info_' + 'avarage_' + info_name,
                                  np.mean(transformed_info[info_name][0, :]))
            logger.record_tabular(
                'info_' + 'sum_' + info_name,
                np.sum(transformed_info[info_name][0, :], axis=-1))
        logger.dump_tabular()

    # Dump trajectories
    traj_dict = {
        'initial_states': initial_states,
        'action_trajs': action_trajs,
        'configs': configs
    }
    with open(osp.join(log_dir, 'cem_traj.pkl'), 'wb') as f:
        pickle.dump(traj_dict, f)

    # Dump video
    cem_make_gif(env_render, initial_states, action_trajs, configs,
                 logger.get_dir(), vv['env_name'] + '.gif')
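The population-size arithmetic at the top of this example spreads a fixed sampling budget over the CEM iterations (and over the planning horizon when MPC is used); a worked example with hypothetical numbers:

timestep_per_decision = 21000  # hypothetical sampling budget per decision
max_iters = 10                 # CEM refinement iterations
plan_horizon = 7               # MPC planning horizon
use_mpc = True

population_size = timestep_per_decision // max_iters  # 2100 candidates per CEM iteration
if use_mpc:
    population_size //= plan_horizon                   # 300 once each candidate spans the horizon
num_elites = population_size // 10                     # 30 elites refit the sampling distribution
print(population_size, num_elites)                     # 300 30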
Example #9
def run_task(vv, log_dir, exp_name):
    import torch
    import numpy as np
    import copy
    import os, sys
    import time
    import math
    import random
    import json

    from get_args import get_args
    from DDPG.train_util import DDPG_train, DDPG_test
    from DDPG.DDPG_new import DDPG
    from DDPG.util import GaussNoise
    from chester import logger
    from BurgersEnv.Burgers import Burgers
    import utils.ptu as ptu

    if torch.cuda.is_available():
        ptu.set_gpu_mode(True)

    ### dump vv
    logger.configure(dir=log_dir, exp_name=exp_name)
    with open(os.path.join(logger.get_dir(), 'variant.json'), 'w') as f:
        json.dump(vv, f, indent=2, sort_keys=True)

    ### load vv
    ddpg_load_epoch = None
    if vv['load_path'] is not None:
        solution_data_path = vv['solution_data_path']
        dx = vv['dx']
        test_interval = vv['test_interval']
        load_path = os.path.join('data/local', vv['load_path'])
        ddpg_load_epoch = str(vv['load_epoch'])
        with open(os.path.join(load_path, 'variant.json'), 'r') as f:
            vv = json.load(f)
        vv['noise_beg'] = 0.1
        vv['solution_data_path'] = solution_data_path
        vv['test_interval'] = test_interval
        if vv.get('dx') is None:
            vv['dx'] = dx

    ### Important: fix numpy and torch seed!
    seed = vv['seed']
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

    ### Initialize RL agents
    ddpg = DDPG(
        vv, GaussNoise(initial_sig=vv['noise_beg'], final_sig=vv['noise_end']))
    agent = ddpg
    if ddpg_load_epoch is not None:
        print("load ddpg models from {}".format(
            os.path.join(load_path, ddpg_load_epoch)))
        agent.load(os.path.join(load_path, ddpg_load_epoch))

    ### Initialize training and testing environments
    env = Burgers(vv, agent=agent)

    ### train models
    print('beginning training!')
    DDPG_train(vv, env, agent)
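The inline seeding block above mirrors the set_seed_everywhere helper used in the other examples; a self-contained version of the same idea (this exact implementation is an assumption, not the project's utils code):

import random
import numpy as np
import torch

def set_seed_everywhere(seed):
    # Seed torch (CPU and all GPUs), numpy and the stdlib RNG, and make cuDNN deterministic.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

set_seed_everywhere(100)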
Example #10
    def evaluate_model(self, eval_on_held_out):
        """
        :param eval_on_train: If test on training variations of the environments or the test variations of the environment
        :return:
        """
        if eval_on_held_out:
            prefix = 'train_'
        else:
            prefix = 'eval_'

        self.set_model_eval()
        if eval_on_held_out:
            self.env.eval_flag = True
        # Initialise parallelised test environments
        with torch.no_grad():
            all_total_rewards = []
            all_infos = []
            all_frames, all_frames_reconstr = [], []
            for _ in range(self.vv['test_episodes']):
                frames, frames_reconstr, infos = [], [], []
                observation, total_reward, observation_reconstr = self.env.reset(), 0, []
                belief, posterior_state, action = torch.zeros(1, self.vv['belief_size'], device=self.device), \
                                                  torch.zeros(1, self.vv['state_size'], device=self.device), \
                                                  torch.zeros(1, self.env.action_size, device=self.device)
                pbar = tqdm(range(self.env.horizon))
                for t in pbar:
                    belief, posterior_state, action, next_observation, reward, done, info = \
                        self.update_belief_and_act(self.env, belief, posterior_state, action, observation.to(device=self.device),
                                                   explore=True)
                    total_reward += reward
                    infos.append(info)
                    if not self.vv['symbolic_env']:  # Collect real vs. predicted frames for video
                        frames.append(observation)
                        frames_reconstr.append(
                            self.observation_model(belief,
                                                   posterior_state).cpu())
                    observation = next_observation
                    if done:
                        pbar.close()
                        break
                all_frames.append(frames)
                all_frames_reconstr.append(frames_reconstr)
                all_total_rewards.append(total_reward)
                all_infos.append(infos)

            all_frames = all_frames[:8]  # Only take the first 8 episodes to visualize
            all_frames_reconstr = all_frames_reconstr[:8]
            video_frames = []
            for i in range(len(all_frames[0])):
                frame = torch.cat([x[i] for x in all_frames])
                frame_reconstr = torch.cat([x[i] for x in all_frames_reconstr])
                video_frames.append(
                    make_grid(torch.cat([frame, frame_reconstr], dim=3) + 0.5,
                              nrow=4).numpy())

        # Update and plot reward metrics (and write video if applicable) and save metrics
        self.test_episodes += self.vv['test_episodes']

        logger.record_tabular(prefix + 'test_episodes', self.test_episodes)
        logger.record_tabular(prefix + 'test_rewards',
                              np.mean(all_total_rewards))
        transformed_info = transform_info(all_infos)
        for info_name in transformed_info:
            logger.record_tabular(prefix + 'info_' + 'final_' + info_name,
                                  np.mean(transformed_info[info_name][:, -1]))
            logger.record_tabular(prefix + 'info_' + 'avarage_' + info_name,
                                  np.mean(transformed_info[info_name][:, :]))
            logger.record_tabular(
                prefix + 'info_' + 'sum_' + info_name,
                np.mean(np.sum(transformed_info[info_name][:, :], axis=-1)))
        if not self.vv['symbolic_env']:
            episode_str = str(self.train_episodes).zfill(
                len(str(self.train_episodes)))
            write_video(video_frames, prefix + 'test_episode_%s' % episode_str,
                        logger.get_dir())  # Lossy compression
            save_image(
                torch.as_tensor(video_frames[-1]),
                os.path.join(logger.get_dir(),
                             prefix + 'test_episode_%s.png' % episode_str))
        self.set_model_train()
        if eval_on_held_out:
            self.env.eval_flag = False
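The logging loop indexes transformed_info[info_name] as an (episodes, timesteps) array; a minimal sketch of what a transform_info helper might do here (an assumption for illustration, not the project's actual implementation):

import numpy as np

def transform_info(all_infos):
    # Stack per-episode lists of info dicts into one (episodes, timesteps) array per key;
    # assumes every episode has the same length and the same info keys.
    keys = all_infos[0][0].keys()
    return {k: np.array([[step[k] for step in episode] for episode in all_infos])
            for k in keys}

infos = [[{'performance': 0.1}, {'performance': 0.4}],
         [{'performance': 0.2}, {'performance': 0.5}]]
print(transform_info(infos)['performance'][:, -1])  # [0.4 0.5]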
Example #11
    def train(self, train_epoch, experience_replay_path=None):
        logger.info('Warming up ...')
        self._init_replay_buffer(experience_replay_path)
        logger.info('Start training ...')
        for epoch in tqdm(range(train_epoch)):
            # Model fitting
            losses = []
            for _ in tqdm(range(self.vv['collect_interval'])):
                # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags)
                if self.value_model is not None:
                    observations, actions, rewards, values, nonterminals = \
                        self.D.sample(self.vv['batch_size'], self.vv['chunk_size'])  # Transitions start at time t = 0
                else:
                    observations, actions, rewards, nonterminals = \
                        self.D.sample(self.vv['batch_size'], self.vv['chunk_size'])  # Transitions start at time t = 0

                # Create initial belief and state for time t = 0
                init_belief, init_state = torch.zeros(self.vv['batch_size'], self.vv['belief_size'], device=self.device), \
                                          torch.zeros(self.vv['batch_size'], self.vv['state_size'], device=self.device)
                # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once)
                beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = self.transition_model(
                    init_state, actions[:-1], init_belief,
                    bottle(self.encoder, (observations[1:], )),
                    nonterminals[:-1])
                # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
                observation_loss = F.mse_loss(
                    bottle(self.observation_model,
                           (beliefs, posterior_states)),
                    observations[1:],
                    reduction='none').sum(
                        dim=2 if self.vv['symbolic_env'] else (2, 3, 4)).mean(
                            dim=(0, 1))
                reward_loss = F.mse_loss(bottle(self.reward_model,
                                                (beliefs, posterior_states)),
                                         rewards[:-1],
                                         reduction='none').mean(dim=(0, 1))
                # TODO: check why the last reward is not used!
                if self.value_model is not None:
                    prev_beliefs = torch.cat(
                        [init_belief.unsqueeze(dim=0), beliefs[:-1, :, :]])
                    prev_states = torch.cat([
                        init_state.unsqueeze(dim=0),
                        posterior_states[:-1, :, :]
                    ])
                    target = (rewards[:-1] +
                              bottle(self.value_model,
                                     (beliefs, posterior_states)).detach()
                              ) * nonterminals[:-1].squeeze(dim=2)
                    value_loss = F.mse_loss(bottle(
                        self.value_model, (prev_beliefs, prev_states)),
                                            target,
                                            reduction='none').mean(dim=(0, 1))
                    value_target_average = target.mean().item()

                # Note that normalisation by overshooting distance and weighting by overshooting distance cancel out
                kl_loss = torch.max(
                    kl_divergence(Normal(posterior_means, posterior_std_devs),
                                  Normal(prior_means,
                                         prior_std_devs)).sum(dim=2),
                    self.free_nats).mean(dim=(0, 1))

                if self.vv['global_kl_beta'] != 0:
                    kl_loss += self.vv['global_kl_beta'] * kl_divergence(
                        Normal(posterior_means, posterior_std_devs),
                        self.global_prior).sum(dim=2).mean(dim=(0, 1))
                # Calculate latent overshooting objective for t > 0
                if self.vv['overshooting_kl_beta'] != 0:
                    raise NotImplementedError  # Need to deal with value function
                    overshooting_vars = []  # Collect variables for overshooting to process in batch
                    for t in range(1, self.vv['chunk_size'] - 1):
                        d = min(t + self.vv['overshooting_distance'],
                                self.vv['chunk_size'] -
                                1)  # Overshooting distance
                        t_, d_ = t - 1, d - 1  # Use t_ and d_ to deal with different time indexing for latent states
                        # Calculate sequence padding so overshooting terms can be calculated in one batch
                        seq_pad = (0, 0, 0, 0, 0,
                                   t - d + self.vv['overshooting_distance'])
                        # Store (0) actions, (1) nonterminals, (2) rewards, (3) beliefs, (4) prior states, (5) posterior means, (6) posterior standard deviations and (7) sequence masks
                        overshooting_vars.append(
                            (F.pad(actions[t:d],
                                   seq_pad), F.pad(nonterminals[t:d], seq_pad),
                             F.pad(rewards[t:d],
                                   seq_pad[2:]), beliefs[t_], prior_states[t_],
                             F.pad(posterior_means[t_ + 1:d_ + 1].detach(),
                                   seq_pad),
                             F.pad(posterior_std_devs[t_ + 1:d_ + 1].detach(),
                                   seq_pad,
                                   value=1),
                             F.pad(
                                 torch.ones(d - t,
                                            self.vv['batch_size'],
                                            self.vv['state_size'],
                                            device=self.device), seq_pad))
                        )  # Posterior standard deviations must be padded with > 0 to prevent infinite KL divergences
                    overshooting_vars = tuple(zip(*overshooting_vars))
                    # Update belief/state using prior from previous belief/state and previous action (over entire sequence at once)
                    beliefs, prior_states, prior_means, prior_std_devs = self.transition_model(
                        torch.cat(overshooting_vars[4], dim=0),
                        torch.cat(overshooting_vars[0], dim=1),
                        torch.cat(overshooting_vars[3], dim=0), None,
                        torch.cat(overshooting_vars[1], dim=1))
                    seq_mask = torch.cat(overshooting_vars[7], dim=1)
                    # Calculate overshooting KL loss with sequence mask
                    # Update KL loss (compensating for extra average over each overshooting/open loop sequence)
                    kl_loss += (
                        1 / self.vv['overshooting_distance']
                    ) * self.vv['overshooting_kl_beta'] * torch.max(
                        (kl_divergence(
                            Normal(torch.cat(overshooting_vars[5], dim=1),
                                   torch.cat(overshooting_vars[6], dim=1)),
                            Normal(prior_means, prior_std_devs)) *
                         seq_mask).sum(dim=2), self.free_nats).mean(
                             dim=(0, 1)) * (self.vv['chunk_size'] - 1)
                    # Calculate overshooting reward prediction loss with sequence mask
                    if self.vv['overshooting_reward_scale'] != 0:
                        reward_loss += (
                            1 / self.vv['overshooting_distance']
                        ) * self.vv['overshooting_reward_scale'] * F.mse_loss(
                            bottle(self.reward_model,
                                   (beliefs, prior_states)) *
                            seq_mask[:, :, 0],
                            torch.cat(overshooting_vars[2], dim=1),
                            reduction='none'
                        ).mean(dim=(0, 1)) * (
                            self.vv['chunk_size'] - 1
                        )  # Update reward loss (compensating for extra average over each overshooting/open loop sequence)

                # Apply linearly ramping learning rate schedule
                if self.vv['learning_rate_schedule'] != 0:
                    for group in self.optimiser.param_groups:
                        group['lr'] = min(
                            group['lr'] + self.vv['learning_rate'] /
                            self.vv['learning_rate_schedule'],
                            self.vv['learning_rate'])
                # Update model parameters
                self.optimiser.zero_grad()
                if self.value_model is not None:
                    (observation_loss + reward_loss + value_loss +
                     kl_loss).backward()
                else:
                    (observation_loss + reward_loss + kl_loss).backward()
                nn.utils.clip_grad_norm_(self.param_list,
                                         self.vv['grad_clip_norm'],
                                         norm_type=2)
                self.optimiser.step()
                # Store (0) observation loss (1) reward loss (2) KL loss
                losses.append([
                    observation_loss.item(),
                    reward_loss.item(),
                    kl_loss.item()
                ])
                if self.value_model is not None:
                    losses[-1].append(value_loss.item())

            # Data collection
            with torch.no_grad():
                all_total_rewards = []  # Average across all episodes
                for i in range(self.vv['episodes_per_loop']):
                    observation, total_reward = self.env.reset(), 0
                    observations, actions, rewards, dones = [], [], [], []
                    belief, posterior_state, action = torch.zeros(1, self.vv['belief_size'], device=self.device), \
                                                      torch.zeros(1, self.vv['state_size'], device=self.device), \
                                                      torch.zeros(1, self.env.action_size, device=self.device)
                    pbar = tqdm(range(self.env.horizon))
                    for t in pbar:
                        belief, posterior_state, action, next_observation, reward, done, info = \
                            self.update_belief_and_act(self.env, belief, posterior_state, action, observation.to(device=self.device), explore=True)
                        observations.append(observation)
                        actions.append(action.cpu())
                        rewards.append(reward)
                        dones.append(done)
                        total_reward += reward
                        observation = next_observation
                        if done:
                            pbar.close()
                            break
                    self.D.append_episode(observations, actions, rewards,
                                          dones)
                    self.train_episodes += 1
                    self.train_steps += t
                    all_total_rewards.append(total_reward)

            # Log
            losses = np.array(losses)
            logger.record_tabular('observation_loss', np.mean(losses[:, 0]))
            logger.record_tabular('reward_loss', np.mean(losses[:, 1]))
            logger.record_tabular('kl_loss', np.mean(losses[:, 2]))
            if self.value_model is not None:
                logger.record_tabular('value_loss', np.mean(losses[:, 3]))
                logger.record_tabular('value_target_average',
                                      value_target_average)
            logger.record_tabular('train_rewards', np.mean(all_total_rewards))
            logger.record_tabular('num_episodes', self.train_episodes)
            logger.record_tabular('num_steps', self.train_steps)

            # Test model
            if epoch % self.vv['test_interval'] == 0:
                self.evaluate_model(eval_on_held_out=True)
                self.evaluate_model(eval_on_held_out=False)

            # Checkpoint models
            if epoch % self.vv['checkpoint_interval'] == 0:
                torch.save(
                    {
                        'transition_model':
                        self.transition_model.state_dict(),
                        'observation_model':
                        self.observation_model.state_dict(),
                        'reward_model':
                        self.reward_model.state_dict(),
                        'value_model':
                        self.value_model.state_dict()
                        if self.value_model is not None else None,
                        'encoder':
                        self.encoder.state_dict(),
                        'optimiser':
                        self.optimiser.state_dict()
                    }, os.path.join(logger.get_dir(), 'models_%d.pth' % epoch))
                if self.vv['checkpoint_experience']:
                    torch.save(
                        self.D, os.path.join(logger.get_dir(),
                                             'experience.pth')
                    )  # Warning: will fail with MemoryError with large memory sizes
            logger.dump_tabular()
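The KL term above uses a PlaNet-style 'free nats' clamp: divergence below a fixed threshold is not penalised, which keeps the posterior from collapsing onto the prior early in training. A standalone sketch of the clamp with hypothetical shapes:

import torch
from torch.distributions import Normal, kl_divergence

posterior = Normal(torch.zeros(2, 3, 4), torch.ones(2, 3, 4))   # toy (chunk, batch, state) sizes
prior = Normal(torch.full((2, 3, 4), 0.1), torch.ones(2, 3, 4))
free_nats = torch.full((1,), 3.0)
kl = kl_divergence(posterior, prior).sum(dim=2)       # sum over state dims -> (chunk, batch)
kl_loss = torch.max(kl, free_nats).mean(dim=(0, 1))   # clamp from below, then average
print(kl_loss.item())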