Example #1
    def collect_random_statistics(self, num_timesteps):
        #  initialize observation normalization with data from random agent

        self.obs_rms = RunningMeanStd(shape=(1, 7, 7, 3))

        curr_obs = self.obs
        collected_obss = [None] * (self.num_frames_per_proc * num_timesteps)
        for i in range(self.num_frames_per_proc * num_timesteps):
            # Do one agent-environment interaction

            action = torch.randint(
                0, self.env.action_space.n,
                (self.num_procs, ))  # sample uniform actions
            obs, reward, done, _ = self.env.step(action.cpu().numpy())

            # Update experiences values
            collected_obss[i] = curr_obs
            curr_obs = obs

        self.obs = curr_obs
        exps = DictList()
        exps.obs = [
            collected_obss[i][j] for j in range(self.num_procs)
            for i in range(self.num_frames_per_proc * num_timesteps)
        ]

        images = [obs["image"] for obs in exps.obs]
        images = numpy.array(images)
        images = torch.tensor(images, dtype=torch.float)

        self.obs_rms.update(images)
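
Every example on this page relies on a RunningMeanStd helper to maintain running normalization statistics. The class below is only a minimal sketch of such a helper, assuming the parallel mean/variance update used by the OpenAI Baselines implementation; the classes actually imported in these examples may differ in detail (for instance, the PyTorch examples pass tensors rather than NumPy arrays).

import numpy as np

class RunningMeanStd:
    """Minimal sketch: tracks a running mean and variance over batches of samples."""
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = epsilon  # avoids division by zero on the first update

    def update(self, x):
        x = np.asarray(x, dtype=np.float64)
        batch_mean = x.mean(axis=0)
        batch_var = x.var(axis=0)
        batch_count = x.shape[0]

        # parallel update of mean and variance (Chan et al.)
        delta = batch_mean - self.mean
        total = self.count + batch_count
        new_mean = self.mean + delta * batch_count / total
        m2 = self.var * self.count + batch_var * batch_count \
            + delta ** 2 * self.count * batch_count / total

        self.mean, self.var, self.count = new_mean, m2 / total, total
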
Example #2
    def __init__(self, action_dim, state_dim, buffer_size=1000000, action_samples=10,
                 mode='linear', beta=1, tau=5e-3, q_normalization=0.01, gamma=0.99,
                 normalize_obs=False, normalize_rewards=False, batch_size=64, actor='AIQN',
                 *args, **kwargs):
        """
        Agent class to generate a stochastic policy.

        Args:
            action_dim (int): action dimension
            state_dim (int): state dimension
            buffer_size (int): maximum number of transitions stored in the replay buffer
            action_samples (int): labelled K in the paper; the number of candidate
                actions sampled per state from the target policy and a uniform distribution
            mode (string): which weighting distribution to use when training the actor
            beta (float): temperature parameter of the Boltzmann distribution
            tau (float): soft target-network update rate
            batch_size (int): batch size
            q_normalization (float): q value normalization rate
            gamma (float): discount factor used in critic training
            normalize_obs (boolean): whether to normalize observations
            normalize_rewards (boolean): whether to normalize returns
                (usually done for numerical stability)
            actor (string): type of actor to use ('IQN' or 'AIQN')
        """
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.buffer_size = buffer_size
        self.gamma = gamma
        self.action_samples = action_samples
        self.mode = mode
        self.beta = beta
        self.tau = tau
        self.batch_size = batch_size

        # normalization
        self.normalize_observations = normalize_obs
        self.q_normalization = q_normalization
        self.normalize_rewards = normalize_rewards

        # type of actor being used
        if actor == 'IQN':
            self.actor = StochasticActor(self.state_dim, self.action_dim)
            self.target_actor = StochasticActor(self.state_dim, self.action_dim)
            self.actor_perturbed = StochasticActor(self.state_dim, self.action_dim)
        elif actor == 'AIQN':
            self.actor = AutoRegressiveStochasticActor(self.state_dim, self.action_dim)
            self.target_actor = AutoRegressiveStochasticActor(self.state_dim, self.action_dim)
            self.actor_perturbed = AutoRegressiveStochasticActor(self.state_dim, self.action_dim)

        if self.normalize_observations:
            self.obs_rms = RunningMeanStd(shape=self.state_dim)
        else:
            self.obs_rms = None

        if self.normalize_rewards:
            self.ret_rms = RunningMeanStd(shape=1)
            self.ret = 0
        else:
            self.ret_rms = None

        # initialize trainable variables
        self.actor(
            tf.zeros([self.batch_size, self.state_dim]),
            tf.zeros([self.batch_size, self.action_dim])
        )
        self.target_actor(
            tf.zeros([self.batch_size, self.state_dim]),
            tf.zeros([self.batch_size, self.action_dim])
        )

        self.critics = Critic(self.state_dim, self.action_dim)
        self.target_critics = Critic(self.state_dim, self.action_dim)

        # initialize trainable variables for critics
        self.critics(
            tf.zeros([self.batch_size, self.state_dim]),
            tf.zeros([self.batch_size, self.action_dim])
        )
        self.target_critics(
            tf.zeros([self.batch_size, self.state_dim]),
            tf.zeros([self.batch_size, self.action_dim])
        )

        self.value = Value(self.state_dim)
        self.target_value = Value(self.state_dim)

        # initialize value and target value trainable variables
        self.value(tf.zeros([self.batch_size, self.state_dim]))
        self.target_value(tf.zeros([self.batch_size, self.state_dim]))

        # initialize the target networks.
        update(self.target_actor, self.actor, 1.0)
        update(self.target_critics, self.critics, 1.0)
        update(self.target_value, self.value, 1.0)

        self.replay = ReplayBuffer(self.state_dim, self.action_dim, self.buffer_size)
        self.action_sampler = ActionSampler(self.actor.action_dim)
Example #3
class GACAgent:
    """
    GAC agent.
    Action is always from -1 to 1 in each dimension.
    """
    def __init__(self, action_dim, state_dim, buffer_size=1000000, action_samples=10,
                 mode='linear', beta=1, tau=5e-3, q_normalization=0.01, gamma=0.99,
                 normalize_obs=False, normalize_rewards=False, batch_size=64, actor='AIQN',
                 *args, **kwargs):
        """
        Agent class to generate a stochastic policy.

        Args:
            action_dim (int): action dimension
            state_dim (int): state dimension
            buffer_size (int): maximum number of transitions stored in the replay buffer
            action_samples (int): labelled K in the paper; the number of candidate
                actions sampled per state from the target policy and a uniform distribution
            mode (string): which weighting distribution to use when training the actor
            beta (float): temperature parameter of the Boltzmann distribution
            tau (float): soft target-network update rate
            batch_size (int): batch size
            q_normalization (float): q value normalization rate
            gamma (float): discount factor used in critic training
            normalize_obs (boolean): whether to normalize observations
            normalize_rewards (boolean): whether to normalize returns
                (usually done for numerical stability)
            actor (string): type of actor to use ('IQN' or 'AIQN')
        """
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.buffer_size = buffer_size
        self.gamma = gamma
        self.action_samples = action_samples
        self.mode = mode
        self.beta = beta
        self.tau = tau
        self.batch_size = batch_size

        # normalization
        self.normalize_observations = normalize_obs
        self.q_normalization = q_normalization
        self.normalize_rewards = normalize_rewards

        # type of actor being used
        if actor == 'IQN':
            self.actor = StochasticActor(self.state_dim, self.action_dim)
            self.target_actor = StochasticActor(self.state_dim, self.action_dim)
            self.actor_perturbed = StochasticActor(self.state_dim, self.action_dim)
        elif actor == 'AIQN':
            self.actor = AutoRegressiveStochasticActor(self.state_dim, self.action_dim)
            self.target_actor = AutoRegressiveStochasticActor(self.state_dim, self.action_dim)
            self.actor_perturbed = AutoRegressiveStochasticActor(self.state_dim, self.action_dim)

        if self.normalize_observations:
            self.obs_rms = RunningMeanStd(shape=self.state_dim)
        else:
            self.obs_rms = None

        if self.normalize_rewards:
            self.ret_rms = RunningMeanStd(shape=1)
            self.ret = 0
        else:
            self.ret_rms = None

        # initialize trainable variables
        self.actor(
            tf.zeros([self.batch_size, self.state_dim]),
            tf.zeros([self.batch_size, self.action_dim])
        )
        self.target_actor(
            tf.zeros([self.batch_size, self.state_dim]),
            tf.zeros([self.batch_size, self.action_dim])
        )

        self.critics = Critic(self.state_dim, self.action_dim)
        self.target_critics = Critic(self.state_dim, self.action_dim)

        # initialize trainable variables for critics
        self.critics(
            tf.zeros([self.batch_size, self.state_dim]),
            tf.zeros([self.batch_size, self.action_dim])
        )
        self.target_critics(
            tf.zeros([self.batch_size, self.state_dim]),
            tf.zeros([self.batch_size, self.action_dim])
        )

        self.value = Value(self.state_dim)
        self.target_value = Value(self.state_dim)

        # initialize value and target value trainable variables
        self.value(tf.zeros([self.batch_size, self.state_dim]))
        self.target_value(tf.zeros([self.batch_size, self.state_dim]))

        # initialize the target networks.
        update(self.target_actor, self.actor, 1.0)
        update(self.target_critics, self.critics, 1.0)
        update(self.target_value, self.value, 1.0)

        self.replay = ReplayBuffer(self.state_dim, self.action_dim, self.buffer_size)
        self.action_sampler = ActionSampler(self.actor.action_dim)

    def train_one_step(self):
        """
        Execute one update for each of the networks. Note that if no positive advantage elements
        are returned the algorithm doesn't update the actor parameters.

        Args:
            None

        Returns:
            None
        """
        # transitions is sampled from replay buffer
        transitions = self.replay.sample_batch(self.batch_size)
        state_batch = normalize(transitions.s, self.obs_rms)
        action_batch = transitions.a
        reward_batch = normalize(transitions.r, self.ret_rms)
        next_state_batch = normalize(transitions.sp, self.obs_rms)
        terminal_mask = transitions.it
        # train the critics on the sampled batch
        self.critics.train(
            state_batch,
            action_batch,
            reward_batch,
            next_state_batch,
            terminal_mask,
            self.target_value,
            self.gamma,
            self.q_normalization
        )
        self.value.train(
            state_batch,
            self.target_actor,
            self.target_critics,
            self.action_samples
        )
        # note that transitions.s represents the sampled states from the memory buffer
        states, actions, advantages = self._sample_positive_advantage_actions(state_batch)
        if advantages.shape[0]:
            self.actor.train(
                states,
                actions,
                advantages,
                self.mode,
                self.beta
            )
        update(self.target_actor, self.actor, self.tau)
        update(self.target_critics, self.critics, self.tau)
        update(self.target_value, self.value, self.tau)

    def _sample_positive_advantage_actions(self, states):
        """
        Sample from the target network and a uniform distribution.
        Then only keep the actions with positive advantage.
        One action is returned per state; if more are needed, repeat the same
        state multiple times in `states`.

        Args:
            states (tf.Variable): states of dimension (batch_size, state_dim)

        Returns:
            good_states (list): Set of positive advantage states (batch_size, state_dim)
            good_actions (list): Set of positive advantage actions
            advantages (list[float]): set of positive advantage values (Q - V)
        """
        # tile states to be of dimension (batch_size * K, state_dim)
        tiled_states = tf.tile(states, [self.action_samples, 1])
        # Sample actions with noise for regularization
        target_actions = self.action_sampler.get_actions(self.target_actor, tiled_states)
        target_actions += tf.random.normal(target_actions.shape) * 0.01
        target_actions = tf.clip_by_value(target_actions, -1, 1)
        target_q = self.target_critics(tiled_states, target_actions)
        # Sample multiple actions both from the target policy and from a uniform distribution
        # over the action space. These will be used to determine the target distribution
        random_actions = tf.random.uniform(target_actions.shape, minval=-1.0, maxval=1.0)
        random_q = self.target_critics(tiled_states, random_actions)
        # build the candidate action set, consisting of noisy policy actions and purely
        # random actions, for the sake of exploration
        target_actions = tf.concat([target_actions, random_actions], 0)
        # compute Q and V values with dimensions (2 * batch_size * K, 1)
        q = tf.concat([target_q, random_q], 0)
        # determine the estimated value of a given state
        v = self.target_value(tiled_states)
        v = tf.concat([v, v], 0)
        # expand tiled states to allow for indexing later on
        tiled_states = tf.concat([tiled_states, tiled_states], 0)
        # remove unused dimensions
        q_squeezed = tf.squeeze(q)
        v_squeezed = tf.squeeze(v)
        # select (s, a) pairs with positive advantage
        squeezed_indices = tf.where(q_squeezed > v_squeezed)
        # collect all advantageous states and actions
        good_states = tf.gather_nd(tiled_states, squeezed_indices)
        good_actions = tf.gather_nd(target_actions, squeezed_indices)
        # retrieve advantage values
        advantages = tf.gather_nd(q - v, squeezed_indices)
        return good_states, good_actions, advantages

    def get_action(self, states):
        """
        Get a set of actions for a batch of states

        Args:
            states (tf.Variable): dimensions (batch_size, state_dim)

        Returns:
            sampled actions for the given state with dimension (batch_size, action_dim)
        """
        return self.action_sampler.get_actions(self.actor, states)

    def select_perturbed_action(self, state, action_noise=None, param_noise=None):
        """
        Select actions from the perturbed actor using action noise and parameter noise

        Args:
            state (tf.Variable): tf variable containing the state vector
            action_noise (function): action noise function which will construct noise from some
                distribution
            param_noise (boolean): boolean indicating that parameter noise is necessary

        Returns:
            action vector of dimension (batch_size, action_dim). Note that if both action noise and
                param noise are None, this function is the same as get_action.
        """
        state = normalize(tf.Variable(state, dtype=tf.float32), self.obs_rms)
        if param_noise is not None:
            action = self.action_sampler.get_actions(self.actor_perturbed, state)
        else:
            action = self.action_sampler.get_actions(self.actor, state)
        if action_noise is not None:
            action += tf.Variable(action_noise(), dtype=tf.float32)
        action = tf.clip_by_value(action, -1, 1)
        return action

    def perturb_actor_parameters(self, param_noise):
        """
        Apply parameter noise to actor model, for exploration

        Args:
            param_noise (AdaptiveParamNoiseSpec): Object containing adaptive parameter noise
                specifications
        """
        update(self.actor_perturbed, self.actor, 1)
        params = self.actor_perturbed.trainable_variables
        for variable in params:
            variable.assign(variable + tf.random.normal(variable.shape) * param_noise.current_stddev)

    def store_transition(self, state, action, reward, next_state, is_done):
        """
        Store the transition in the replay buffer and update the normalization statistics
        if normalization is enabled.

        Args:
            state (tf.Variable): (batch_size, state_size) state vector
            action (tf.Variable): (batch_size, action_size) action vector
            reward (float): reward value determined by the environment (batch_size, 1)
            next_state (tf.Variable): (batch_size, state_size) next state vector
            is_done (boolean): value to indicate that the state is terminal
        """
        self.replay.store(state, action, reward, next_state, is_done)
        if self.normalize_observations:
            self.obs_rms.update(state)
        if self.normalize_rewards:
            self.ret = self.ret * self.gamma + reward
            self.ret_rms.update(np.array([self.ret]))
            if is_done:
                self.ret = 0
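
Example #3 above is the full GACAgent class, so a rough usage sketch may be helpful. The environment choice, the classic Gym API (4-tuple step, observation-only reset), and the warm-up condition below are assumptions for illustration, not part of the example.

import numpy as np
import tensorflow as tf
import gym

# MountainCarContinuous-v0 already has actions in [-1, 1], matching the agent's convention.
env = gym.make("MountainCarContinuous-v0")
agent = GACAgent(action_dim=env.action_space.shape[0],
                 state_dim=env.observation_space.shape[0])

state = env.reset()
for step in range(5000):
    # get_action expects a (batch_size, state_dim) tensor
    action = agent.get_action(tf.convert_to_tensor(state[None], dtype=tf.float32))
    action_np = np.squeeze(action.numpy(), axis=0)
    next_state, reward, done, _ = env.step(action_np)
    agent.store_transition(state, action_np, reward, next_state, done)
    if step >= agent.batch_size:  # wait until the replay buffer can fill a batch
        agent.train_one_step()
    state = env.reset() if done else next_state
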
Example #4
else:
    raise NotImplementedError

if args.agent == 'ppo':
    algorithm = PPO(device, state_dim, action_dim, agent_args)
elif args.agent == 'sac':
    algorithm = SAC(device, state_dim, action_dim, agent_args)
else:
    raise NotImplementedError
agent = Agent(algorithm, writer, device, state_dim, action_dim, agent_args,
              demonstrations_location_args)
if device == 'cuda':
    agent = agent.cuda()
    discriminator = discriminator.cuda()

state_rms = RunningMeanStd(state_dim)

score_lst = []
discriminator_score_lst = []
score = 0.0
discriminator_score = 0
if agent_args.on_policy:
    state_lst = []
    state_ = env.reset()
    state = np.clip((state_ - state_rms.mean) / (state_rms.var**0.5 + 1e-8),
                    -5, 5)
    for n_epi in range(args.epochs):
        for t in range(agent_args.traj_length):
            if args.render:
                env.render()
            state_lst.append(state_)
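
Example #4 normalizes each observation with the running statistics before handing it to the agent, clipping the z-scored values to [-5, 5]. A small helper performing the same transform makes the intent explicit (the function name is mine, not from the snippet):

import numpy as np

def normalize_state(state, rms, clip=5.0, eps=1e-8):
    """Z-score the state with running mean/variance statistics and clip extreme values."""
    return np.clip((state - rms.mean) / (rms.var ** 0.5 + eps), -clip, clip)
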
Example #5
    def __init__(self, cfg, envs, acmodel, agent_data, **kwargs):
        num_frames_per_proc = getattr(cfg, "frames_per_proc", 128)
        discount = getattr(cfg, "discount", 0.99)
        gae_lambda = getattr(cfg, "gae_lambda", 0.95)
        entropy_coef = getattr(cfg, "entropy_coef", 0.01)
        value_loss_coef = getattr(cfg, "value_loss_coef", 0.5)
        max_grad_norm = getattr(cfg, "max_grad_norm", 0.5)
        recurrence = getattr(cfg, "recurrence", 4)

        clip_eps = getattr(cfg, "clip_eps", 0.)
        epochs = getattr(cfg, "epochs", 4)
        batch_size = getattr(cfg, "batch_size", 256)

        optimizer = getattr(cfg, "optimizer", "Adam")
        optimizer_args = getattr(cfg, "optimizer_args", {})

        exp_used_pred = getattr(cfg, "exp_used_pred", 0.25)
        preprocess_obss = kwargs.get("preprocess_obss", None)
        reshape_reward = kwargs.get("reshape_reward", None)

        self.running_norm_obs = getattr(cfg, "running_norm_obs", False)

        self.nminibatches = getattr(cfg, "nminibatches", 4)

        super().__init__(envs, acmodel, num_frames_per_proc, discount,
                         gae_lambda, entropy_coef, value_loss_coef,
                         max_grad_norm, recurrence, preprocess_obss,
                         reshape_reward, exp_used_pred)

        self.clip_eps = clip_eps
        self.epochs = epochs
        self.batch_size = batch_size
        self.int_coeff = cfg.int_coeff
        self.ext_coeff = cfg.ext_coeff

        assert self.batch_size % self.recurrence == 0

        optimizer_args = vars(optimizer_args)

        self.optimizer_policy = getattr(torch.optim, optimizer)(
            self.acmodel.policy_model.parameters(), **optimizer_args)

        # Prepare intrinsic generators
        self.acmodel.random_target.eval()
        self.predictor_rms = RunningMeanStd()
        self.predictor_rff = RewardForwardFilter(gamma=self.discount)

        self.optimizer_predictor = getattr(torch.optim, optimizer)(
            self.acmodel.predictor_network.parameters(), **optimizer_args)

        if "optimizer_policy" in agent_data:
            self.optimizer_policy.load_state_dict(
                agent_data["optimizer_policy"])
            self.optimizer_predictor.load_state_dict(
                agent_data["optimizer_predictor"])
            self.predictor_rms = agent_data[
                "predictor_rms"]  # type: RunningMeanStd

        self.batch_num = 0

        if self.running_norm_obs:
            self.collect_random_statistics(50)
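
The constructor above reads most of its hyperparameters from cfg via getattr with fallback defaults, but cfg.int_coeff and cfg.ext_coeff are accessed directly and must be present, and optimizer_args is passed through vars(), so it needs to be a namespace rather than a plain dict. A minimal config sketch under those assumptions (the intrinsic/extrinsic coefficients and the learning rate are illustrative values, not taken from the snippet):

from types import SimpleNamespace

cfg = SimpleNamespace(
    frames_per_proc=128,          # values below mirror the getattr fallbacks above
    discount=0.99,
    gae_lambda=0.95,
    entropy_coef=0.01,
    value_loss_coef=0.5,
    max_grad_norm=0.5,
    recurrence=4,
    clip_eps=0.2,                 # the getattr fallback is 0.0; PPO typically uses 0.1-0.3
    epochs=4,
    batch_size=256,
    optimizer="Adam",
    optimizer_args=SimpleNamespace(lr=2.5e-4),  # must support vars(); a plain dict would not
    exp_used_pred=0.25,
    running_norm_obs=False,
    nminibatches=4,
    int_coeff=1.0,                # required; no getattr default in the snippet (assumed value)
    ext_coeff=2.0,                # required; no getattr default in the snippet (assumed value)
)
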
Example #6
class PPORND(TwoValueHeadsBase):
    """The class for the Proximal Policy Optimization algorithm
    ([Schulman et al., 2015](https://arxiv.org/abs/1707.06347))."""
    def __init__(self, cfg, envs, acmodel, agent_data, **kwargs):
        num_frames_per_proc = getattr(cfg, "frames_per_proc", 128)
        discount = getattr(cfg, "discount", 0.99)
        gae_lambda = getattr(cfg, "gae_lambda", 0.95)
        entropy_coef = getattr(cfg, "entropy_coef", 0.01)
        value_loss_coef = getattr(cfg, "value_loss_coef", 0.5)
        max_grad_norm = getattr(cfg, "max_grad_norm", 0.5)
        recurrence = getattr(cfg, "recurrence", 4)

        clip_eps = getattr(cfg, "clip_eps", 0.)
        epochs = getattr(cfg, "epochs", 4)
        batch_size = getattr(cfg, "batch_size", 256)

        optimizer = getattr(cfg, "optimizer", "Adam")
        optimizer_args = getattr(cfg, "optimizer_args", {})

        exp_used_pred = getattr(cfg, "exp_used_pred", 0.25)
        preprocess_obss = kwargs.get("preprocess_obss", None)
        reshape_reward = kwargs.get("reshape_reward", None)

        self.running_norm_obs = getattr(cfg, "running_norm_obs", False)

        self.nminibatches = getattr(cfg, "nminibatches", 4)

        super().__init__(envs, acmodel, num_frames_per_proc, discount,
                         gae_lambda, entropy_coef, value_loss_coef,
                         max_grad_norm, recurrence, preprocess_obss,
                         reshape_reward, exp_used_pred)

        self.clip_eps = clip_eps
        self.epochs = epochs
        self.batch_size = batch_size
        self.int_coeff = cfg.int_coeff
        self.ext_coeff = cfg.ext_coeff

        assert self.batch_size % self.recurrence == 0

        optimizer_args = vars(optimizer_args)

        self.optimizer_policy = getattr(torch.optim, optimizer)(
            self.acmodel.policy_model.parameters(), **optimizer_args)

        # Prepare intrinsic generators
        self.acmodel.random_target.eval()
        self.predictor_rms = RunningMeanStd()
        self.predictor_rff = RewardForwardFilter(gamma=self.discount)

        self.optimizer_predictor = getattr(torch.optim, optimizer)(
            self.acmodel.predictor_network.parameters(), **optimizer_args)

        if "optimizer_policy" in agent_data:
            self.optimizer_policy.load_state_dict(
                agent_data["optimizer_policy"])
            self.optimizer_predictor.load_state_dict(
                agent_data["optimizer_predictor"])
            self.predictor_rms = agent_data[
                "predictor_rms"]  # type: RunningMeanStd

        self.batch_num = 0

        if self.running_norm_obs:
            self.collect_random_statistics(50)

    def update_parameters(self):
        # Collect experiences

        exps, logs = self.collect_experiences()

        for epoch_no in range(self.epochs):
            # Initialize log values

            log_entropies = []
            log_values_ext = []
            log_values_int = []
            log_policy_losses = []
            log_value_ext_losses = []
            log_value_int_losses = []
            log_grad_norms = []
            log_ret_int = []
            log_rew_int = []
            for inds in self._get_batches_starting_indexes():
                # Initialize batch values

                batch_entropy = 0
                batch_ext_value = 0
                batch_int_value = 0
                batch_policy_loss = 0
                batch_value_ext_loss = 0
                batch_value_int_loss = 0
                batch_loss = 0
                batch_ret_int = 0
                batch_rew_int = 0

                # Initialize memory

                if self.acmodel.recurrent:
                    memory = exps.memory[inds]

                for i in range(self.recurrence):
                    # Create a sub-batch of experience

                    sb = exps[inds + i]
                    # Compute loss

                    if self.acmodel.recurrent:
                        dist, vvalue, memory = self.acmodel.policy_model(
                            sb.obs, memory * sb.mask)
                    else:
                        dist, vvalue = self.acmodel.policy_model(sb.obs)

                    entropy = dist.entropy().mean()

                    ratio = torch.exp(dist.log_prob(sb.action) - sb.log_prob)
                    adv = (self.int_coeff * sb.advantage_int +
                           self.ext_coeff * sb.advantage_ext)
                    surr1 = ratio * adv
                    surr2 = torch.clamp(ratio, 1.0 - self.clip_eps,
                                        1.0 + self.clip_eps) * adv
                    policy_loss = -torch.min(surr1, surr2).mean()

                    # Value losses
                    value_ext, value_int = vvalue

                    value_ext_clipped = sb.value_ext + torch.clamp(
                        value_ext - sb.value_ext, -self.clip_eps,
                        self.clip_eps)
                    surr1 = (value_ext - sb.returnn_ext).pow(2)
                    surr2 = (value_ext_clipped - sb.returnn_ext).pow(2)
                    value_ext_loss = torch.max(surr1, surr2).mean()

                    value_int_clipped = sb.value_int + torch.clamp(
                        value_int - sb.value_int, -self.clip_eps,
                        self.clip_eps)
                    surr1 = (value_int - sb.returnn_int).pow(2)
                    surr2 = (value_int_clipped - sb.returnn_int).pow(2)
                    value_int_loss = torch.max(surr1, surr2).mean()

                    loss = policy_loss - self.entropy_coef * entropy + \
                           (0.5 * self.value_loss_coef) * value_int_loss + \
                           (0.5 * self.value_loss_coef) * value_ext_loss

                    # Update batch values

                    batch_entropy += entropy.item()
                    batch_ext_value += value_ext.mean().item()
                    batch_int_value += value_int.mean().item()
                    batch_policy_loss += policy_loss.item()
                    batch_value_ext_loss += value_ext_loss.item()
                    batch_value_int_loss += value_int_loss.item()
                    batch_loss += loss
                    batch_ret_int += sb.returnn_int.mean().item()
                    batch_rew_int += sb.reward_int.mean().item()

                    # Update memories for next epoch

                    if self.acmodel.recurrent and i < self.recurrence - 1:
                        exps.memory[inds + i + 1] = memory.detach()

                    # Update Predictor loss

                    # reshape observations to channel-first for the random target and predictor
                    norm_obs = sb.obs.image
                    obs = torch.transpose(torch.transpose(norm_obs, 1, 3), 2, 3)

                    with torch.no_grad():
                        target = self.acmodel.random_target(obs)

                    pred = self.acmodel.predictor_network(obs)
                    diff_pred = (pred - target).pow_(2)

                    # Optimize intrinsic reward generator using only a percentage of experiences
                    loss_pred = diff_pred.mean(1)
                    mask = torch.rand(loss_pred.shape[0])
                    mask = (mask < self.exp_used_pred).type(
                        torch.FloatTensor).to(loss_pred.device)
                    loss_pred = (loss_pred * mask).sum() / torch.max(
                        mask.sum(),
                        torch.Tensor([1]).to(loss_pred.device))

                    self.optimizer_predictor.zero_grad()
                    loss_pred.backward()
                    grad_norm = sum(
                        p.grad.data.norm(2).item()**2
                        for p in self.acmodel.predictor_network.parameters()
                        if p.grad is not None)**0.5

                    torch.nn.utils.clip_grad_norm_(
                        self.acmodel.predictor_network.parameters(),
                        self.max_grad_norm)
                    self.optimizer_predictor.step()

                # Update batch values

                batch_entropy /= self.recurrence
                batch_ext_value /= self.recurrence
                batch_int_value /= self.recurrence
                batch_policy_loss /= self.recurrence
                batch_value_ext_loss /= self.recurrence
                batch_value_int_loss /= self.recurrence
                batch_loss /= self.recurrence
                batch_rew_int /= self.recurrence
                batch_ret_int /= self.recurrence

                # Update actor-critic

                self.optimizer_policy.zero_grad()
                batch_loss.backward()
                grad_norm = sum(
                    p.grad.data.norm(2).item()**2
                    for p in self.acmodel.policy_model.parameters()
                    if p.grad is not None)**0.5
                torch.nn.utils.clip_grad_norm_(
                    self.acmodel.policy_model.parameters(), self.max_grad_norm)
                self.optimizer_policy.step()

                # Update log values

                log_entropies.append(batch_entropy)
                log_values_ext.append(batch_ext_value)
                log_values_int.append(batch_int_value)
                log_policy_losses.append(batch_policy_loss)
                log_value_ext_losses.append(batch_value_ext_loss)
                log_value_int_losses.append(batch_value_int_loss)
                log_grad_norms.append(grad_norm)
                log_ret_int.append(batch_ret_int)
                log_rew_int.append(batch_rew_int)

        # Log some values

        logs["entropy"] = numpy.mean(log_entropies)
        logs["value_ext"] = numpy.mean(log_values_ext)
        logs["value_int"] = numpy.mean(log_values_int)
        logs["value"] = logs["value_ext"] + logs["value_int"]
        logs["policy_loss"] = numpy.mean(log_policy_losses)
        logs["value_ext_loss"] = numpy.mean(log_value_ext_losses)
        logs["value_int_loss"] = numpy.mean(log_value_int_losses)
        logs["value_loss"] = logs["value_int_loss"] + logs["value_ext_loss"]
        logs["grad_norm"] = numpy.mean(log_grad_norms)
        logs["return_int"] = numpy.mean(log_ret_int)
        logs["reward_int"] = numpy.mean(log_rew_int)

        return logs
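
For reference, the policy term assembled in the inner loop above is the standard PPO clipped surrogate objective. Written as a standalone function it is simply the following (a restatement of the lines above, not a method of this class):

import torch

def ppo_policy_loss(new_log_prob, old_log_prob, advantage, clip_eps):
    """Clipped surrogate objective: -E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)]."""
    ratio = torch.exp(new_log_prob - old_log_prob)
    surr1 = ratio * advantage
    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantage
    return -torch.min(surr1, surr2).mean()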

    def _get_batches_starting_indexes(self):
        """Gives, for each batch, the indexes of the observations given to
        the model and the experiences used to compute the loss at first.

        First, the indexes are the integers from 0 to `self.num_frames` with a step of
        `self.recurrence`, shifted by `self.recurrence//2` one time in two for having
        more diverse batches. Then, the indexes are split into the different batches.

        Returns
        -------
        batches_starting_indexes : list of list of int
            the indexes of the experiences to be used at first for each batch
        """

        indexes = numpy.arange(0, self.num_frames, self.recurrence)
        indexes = numpy.random.permutation(indexes)

        # Shift starting indexes by self.recurrence//2 half the time
        self.batch_num += 1

        num_indexes = self.batch_size // self.recurrence
        batches_starting_indexes = [
            indexes[i:i + num_indexes]
            for i in range(0, len(indexes), num_indexes)
        ]

        return batches_starting_indexes

    def get_save_data(self):
        return dict({
            "optimizer_policy": self.optimizer_policy.state_dict(),
            "optimizer_predictor": self.optimizer_predictor.state_dict(),
            "predictor_rms": self.predictor_rms,
        })

    def collect_random_statistics(self, num_timesteps):
        #  initialize observation normalization with data from random agent

        self.obs_rms = RunningMeanStd(shape=(1, 7, 7, 3))

        curr_obs = self.obs
        collected_obss = [None] * (self.num_frames_per_proc * num_timesteps)
        for i in range(self.num_frames_per_proc * num_timesteps):
            # Do one agent-environment interaction

            action = torch.randint(
                0, self.env.action_space.n,
                (self.num_procs, ))  # sample uniform actions
            obs, reward, done, _ = self.env.step(action.cpu().numpy())

            # Update experiences values
            collected_obss[i] = curr_obs
            curr_obs = obs

        self.obs = curr_obs
        exps = DictList()
        exps.obs = [
            collected_obss[i][j] for j in range(self.num_procs)
            for i in range(self.num_frames_per_proc * num_timesteps)
        ]

        images = [obs["image"] for obs in exps.obs]
        images = numpy.array(images)
        images = torch.tensor(images, dtype=torch.float)

        self.obs_rms.update(images)

    def calculate_intrinsic_reward(self, exps: DictList,
                                   dst_intrinsic_r: torch.Tensor):
        """
        replicate (should normalize by a running mean):
            X_r -- random target
            X_r_hat -- predictor

            self.feat_var = tf.reduce_mean(tf.nn.moments(X_r, axes=[0])[1])
            self.max_feat = tf.reduce_max(tf.abs(X_r))
            self.int_rew = tf.reduce_mean(tf.square(tf.stop_gradient(X_r) - X_r_hat), axis=-1, keep_dims=True)
            self.int_rew = tf.reshape(self.int_rew, (self.sy_nenvs, self.sy_nsteps - 1))

            targets = tf.stop_gradient(X_r)
            # self.aux_loss = tf.reduce_mean(tf.square(noisy_targets-X_r_hat))
            self.aux_loss = tf.reduce_mean(tf.square(targets - X_r_hat), -1)
            mask = tf.random_uniform(shape=tf.shape(self.aux_loss), minval=0., maxval=1., dtype=tf.float32)
            mask = tf.cast(mask < self.proportion_of_exp_used_for_predictor_update, tf.float32)
            self.aux_loss = tf.reduce_sum(mask * self.aux_loss) / tf.maximum(tf.reduce_sum(mask), 1.)

        """
        if self.running_norm_obs:
            obs = exps.obs.image * 15.0  # horrible hardcoded normalization factor
            # normalize the observations for predictor and target networks
            norm_obs = torch.clamp(
                torch.div(
                    (obs - self.obs_rms.mean.to(exps.obs.image.device)),
                    torch.sqrt(self.obs_rms.var).to(exps.obs.image.device)),
                -5.0, 5.0)

            self.obs_rms.update(obs.cpu())  # update running mean
        else:
            # Without norm
            norm_obs = exps.obs.image

        obs = torch.transpose(torch.transpose(norm_obs, 1, 3), 2, 3)

        with torch.no_grad():
            target = self.acmodel.random_target(obs)
            pred = self.acmodel.predictor_network(obs)

        diff_pred = (pred - target).pow_(2)

        # -- Calculate intrinsic & Normalize intrinsic rewards
        int_rew = diff_pred.detach().mean(1)
        # TODO BIG BIG BUG previously  - :( int_rew  is (self.num_procs, self.num_frames_per_proc,)
        # TODO this should fix it but should double check
        int_rew = int_rew.view(
            (self.num_procs, self.num_frames_per_proc)).transpose(0, 1)

        dst_intrinsic_r.copy_(int_rew)

        # Normalize intrinsic reward
        self.predictor_rff.reset()
        int_rff = torch.zeros((self.num_frames_per_proc, self.num_procs),
                              device=self.device)

        for i in reversed(range(self.num_frames_per_proc)):
            int_rff[i] = self.predictor_rff.update(dst_intrinsic_r[i])

        self.predictor_rms.update(int_rff.view(-1))
        # dst_intrinsic_r.sub_(self.predictor_rms.mean.to(dst_intrinsic_r.device))
        dst_intrinsic_r.div_(
            torch.sqrt(self.predictor_rms.var).to(dst_intrinsic_r.device))
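
calculate_intrinsic_reward above divides the RND bonus by the running standard deviation of a per-process discounted sum of intrinsic rewards, which predictor_rff accumulates frame by frame. The class below is a minimal sketch of such a RewardForwardFilter, based on how it is called here (update per frame, reset per rollout) and on common RND implementations; the class actually imported by this example may differ.

class RewardForwardFilter:
    """Minimal sketch: per-process discounted running sum of (intrinsic) rewards."""
    def __init__(self, gamma):
        self.gamma = gamma
        self.rewems = None  # running discounted return, one entry per process

    def update(self, rews):
        if self.rewems is None:
            self.rewems = rews * 1.0  # copy; works for torch tensors and numpy arrays
        else:
            self.rewems = self.rewems * self.gamma + rews
        return self.rewems

    def reset(self):
        self.rewems = None

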
class PPOPE(TwoValueHeadsBaseGeneral):
    """The class for the Proximal Policy Optimization algorithm
    ([Schulman et al., 2015](https://arxiv.org/abs/1707.06347))."""
    def __init__(self, cfg, envs, acmodel, agent_data, **kwargs):
        num_frames_per_proc = getattr(cfg, "num_frames_per_proc", 128)
        discount = getattr(cfg, "discount", 0.99)
        gae_lambda = getattr(cfg, "gae_lambda", 0.95)
        entropy_coef = getattr(cfg, "entropy_coef", 0.01)
        value_loss_coef = getattr(cfg, "value_loss_coef", 0.5)
        max_grad_norm = getattr(cfg, "max_grad_norm", 0.5)
        recurrence = getattr(cfg, "recurrence", 4)
        clip_eps = getattr(cfg, "clip_eps", 0.)
        epochs = getattr(cfg, "epochs", 4)
        batch_size = getattr(cfg, "batch_size", 256)

        optimizer = getattr(cfg, "optimizer", "Adam")
        optimizer_args = getattr(cfg, "optimizer_args", {})

        exp_used_pred = getattr(cfg, "exp_used_pred", 0.25)
        preprocess_obss = kwargs.get("preprocess_obss", None)
        reshape_reward = kwargs.get("reshape_reward", None)
        eval_envs = kwargs.get("eval_envs", [])

        self.recurrence_worlds = getattr(cfg, "recurrence_worlds", 16)
        self.running_norm_obs = getattr(cfg, "running_norm_obs", False)
        self.nminibatches = getattr(cfg, "nminibatches", 4)
        self.out_dir = getattr(cfg, "out_dir", None)
        self.pre_fill_memories = pre_fill_memories = getattr(
            cfg, "pre_fill_memories", 1)

        self.save_experience_batch = getattr(cfg, "save_experience_batch", 5)

        super().__init__(envs, acmodel, num_frames_per_proc, discount,
                         gae_lambda, entropy_coef, value_loss_coef,
                         max_grad_norm, recurrence, preprocess_obss,
                         reshape_reward, exp_used_pred)

        self.clip_eps = clip_eps
        self.epochs = epochs
        self.batch_size = batch_size
        self.int_coeff = cfg.int_coeff
        self.ext_coeff = cfg.ext_coeff

        assert self.batch_size % self.recurrence == 0

        # -- Prepare intrinsic generators
        # self.acmodel.random_target.eval()
        self.predictor_rms = RunningMeanStd()
        self.predictor_rff = RewardForwardFilter(gamma=self.discount)

        # -- Prepare optimizers
        optimizer_args = vars(optimizer_args)

        self.optimizer_policy = getattr(torch.optim, optimizer)(
            self.acmodel.policy_model.parameters(), **optimizer_args)

        self.optimizer_agworld = getattr(torch.optim, optimizer)(
            self.acmodel.curiosity_model.parameters(), **optimizer_args)

        if "optimizer_policy" in agent_data:
            self.optimizer_policy.load_state_dict(
                agent_data["optimizer_policy"])
            self.optimizer_agworld.load_state_dict(
                agent_data["optimizer_agworld"])
            self.predictor_rms = agent_data[
                "predictor_rms"]  # type: RunningMeanStd

        self.batch_num = 0
        self.updates_cnt = 0

        # get width and height of the observation space for position normalization
        self.env_width = envs[0][0].unwrapped.width
        self.env_height = envs[0][0].unwrapped.height

        if self.running_norm_obs:
            self.collect_random_statistics(50)

        # -- Previous batch of experiences last frame
        self.prev_frame_exps = None

        # -- Init evaluator envs
        self.eval_envs = None
        self.eval_memory = None
        self.eval_mask = None
        self.eval_icm_memory = None
        self.eval_dir = None

        if len(eval_envs) > 0:
            self.eval_envs = self.init_evaluator(eval_envs)
            self.eval_dir = os.path.join(self.out_dir, "eval")
            if not os.path.isdir(self.eval_dir):
                os.mkdir(self.eval_dir)

        # remember some log values from intrinsic rewards computation
        self.aux_logs = {}

    def init_evaluator(self, envs):
        from torch_rl.utils import ParallelEnv
        device = self.device
        acmodel = self.acmodel

        eval_envs = ParallelEnv(envs)
        obs = eval_envs.reset()

        if self.acmodel.recurrent:
            self.eval_memory = torch.zeros(len(obs),
                                           acmodel.memory_size,
                                           device=device)

        self.eval_agworld_memory = torch.zeros(
            len(obs), acmodel.curiosity_model.memory_size, device=device)
        self.eval_mask = torch.ones(len(obs), device=device)
        return eval_envs

    def augment_exp(self, exps):

        # from exp (P * T , ** ) -> (T, P, **)
        num_procs = self.num_procs
        num_frames_per_proc = self.num_frames_per_proc
        device = self.device
        env = self.env
        agworld_network = self.acmodel.curiosity_model

        shape = torch.Size([num_procs, num_frames_per_proc])
        frame_exp = Namespace()

        # ------------------------------------------------------------------------------------------
        # Redo in format T x P

        for k, v in exps.items():
            if k == "obs":
                continue
            setattr(frame_exp, k,
                    v.view(shape + v.size()[1:]).transpose(0, 1).contiguous())

        def inverse_img(t, ii):
            return torch.transpose(torch.transpose(t, ii, ii + 2), ii + 1,
                                   ii + 2).contiguous()

        frame_exp.obs_image = inverse_img(frame_exp.obs_image, 2)

        #frame_exp.states = inverse_img(frame_exp.states, 2)

        def gen_memory(ss):
            return torch.zeros(num_frames_per_proc,
                               num_procs,
                               ss,
                               device=device)

        frame_exp.agworld_mems = gen_memory(agworld_network.memory_size)
        frame_exp.agworld_embs = gen_memory(agworld_network.embedding_size)

        frame_exp.actions_onehot = gen_memory(env.action_space.n)
        frame_exp.actions_onehot.scatter_(2,
                                          frame_exp.action.unsqueeze(2).long(),
                                          1.)

        # ------------------------------------------------------------------------------------------
        # Save last frame exp

        last_frame_exp = Namespace()
        for k, v in frame_exp.__dict__.items():
            if k == "obs":
                continue
            setattr(last_frame_exp, k, v[-1].clone())

        prev_frame_exps = self.prev_frame_exps
        if self.prev_frame_exps is None:
            prev_frame_exps = deepcopy(last_frame_exp)
            for k, v in prev_frame_exps.__dict__.items():
                v.zero_()

        self.prev_frame_exps = last_frame_exp

        # ------------------------------------------------------------------------------------------
        # Fill memories with past

        frame_exp.agworld_mems[0] = prev_frame_exps.agworld_mems
        frame_exp.agworld_embs[0] = prev_frame_exps.agworld_embs

        return frame_exp, prev_frame_exps

    @staticmethod
    def flip_back_experience(exp):
        # for all tensors below, T x P -> P x T -> P * T
        for k, v in exp.__dict__.items():
            setattr(exp, k,
                    v.transpose(0, 1).reshape(-1, *v.shape[2:]).contiguous())
        return exp

    def update_parameters(self):
        # Collect experiences

        exps, logs = self.collect_experiences()

        log_entropies = []
        log_values_ext = []
        log_values_int = []
        log_policy_losses = []
        log_value_ext_losses = []
        log_value_int_losses = []
        log_grad_norms = []
        log_ret_int = []
        log_rew_int = []
        batch_ret_int = 0
        batch_rew_int = 0

        for epoch_no in range(self.epochs):
            # Initialize log values

            # Loop for Policy

            for inds in self._get_batches_starting_indexes():
                # Initialize batch values

                batch_entropy = 0
                batch_ext_value = 0
                batch_int_value = 0
                batch_policy_loss = 0
                batch_value_ext_loss = 0
                batch_value_int_loss = 0
                batch_loss = 0

                # Initialize memory

                if self.acmodel.recurrent:
                    memory = exps.memory[inds]

                for i in range(self.recurrence):
                    # Create a sub-batch of experience

                    sb = exps[inds + i]
                    # Compute loss

                    if self.acmodel.recurrent:
                        dist, vvalue, memory = self.acmodel.policy_model(
                            sb.obs, memory * sb.mask)
                    else:
                        dist, vvalue = self.acmodel.policy_model(sb.obs)

                    entropy = dist.entropy().mean()

                    ratio = torch.exp(dist.log_prob(sb.action) - sb.log_prob)
                    adv = (self.int_coeff * sb.advantage_int +
                           self.ext_coeff * sb.advantage_ext)
                    surr1 = ratio * adv
                    surr2 = torch.clamp(ratio, 1.0 - self.clip_eps,
                                        1.0 + self.clip_eps) * adv
                    policy_loss = -torch.min(surr1, surr2).mean()

                    # Value losses
                    value_ext, value_int = vvalue

                    value_ext_clipped = sb.value_ext + torch.clamp(
                        value_ext - sb.value_ext, -self.clip_eps,
                        self.clip_eps)
                    surr1 = (value_ext - sb.returnn_ext).pow(2)
                    surr2 = (value_ext_clipped - sb.returnn_ext).pow(2)
                    value_ext_loss = torch.max(surr1, surr2).mean()

                    value_int_clipped = sb.value_int + torch.clamp(
                        value_int - sb.value_int, -self.clip_eps,
                        self.clip_eps)
                    surr1 = (value_int - sb.returnn_int).pow(2)
                    surr2 = (value_int_clipped - sb.returnn_int).pow(2)
                    value_int_loss = torch.max(surr1, surr2).mean()

                    loss = policy_loss - self.entropy_coef * entropy + \
                           (0.5 * self.value_loss_coef) * value_int_loss + \
                           (0.5 * self.value_loss_coef) * value_ext_loss

                    # Update batch values

                    batch_entropy += entropy.item()
                    batch_ext_value += value_ext.mean().item()
                    batch_int_value += value_int.mean().item()
                    batch_policy_loss += policy_loss.item()
                    batch_value_ext_loss += value_ext_loss.item()
                    batch_value_int_loss += value_int_loss.item()
                    batch_loss += loss
                    batch_ret_int += sb.returnn_int.mean().item()
                    batch_rew_int += sb.reward_int.mean().item()

                    # Update memories for next epoch

                    if self.acmodel.recurrent and i < self.recurrence - 1:
                        exps.memory[inds + i + 1] = memory.detach()

                # Update batch values

                batch_entropy /= self.recurrence
                batch_ext_value /= self.recurrence
                batch_int_value /= self.recurrence
                batch_policy_loss /= self.recurrence
                batch_value_ext_loss /= self.recurrence
                batch_value_int_loss /= self.recurrence
                batch_loss /= self.recurrence
                batch_rew_int /= self.recurrence
                batch_ret_int /= self.recurrence

                # Update actor-critic
                self.optimizer_policy.zero_grad()
                batch_loss.backward()
                grad_norm = sum(
                    p.grad.data.norm(2).item()**2
                    for p in self.acmodel.policy_model.parameters()
                    if p.grad is not None)**0.5
                torch.nn.utils.clip_grad_norm_(
                    self.acmodel.policy_model.parameters(), self.max_grad_norm)
                self.optimizer_policy.step()

                # Update log values

                log_entropies.append(batch_entropy)
                log_values_ext.append(batch_ext_value)
                log_values_int.append(batch_int_value)
                log_policy_losses.append(batch_policy_loss)
                log_value_ext_losses.append(batch_value_ext_loss)
                log_value_int_losses.append(batch_value_int_loss)
                log_grad_norms.append(grad_norm)
                log_ret_int.append(batch_ret_int)
                log_rew_int.append(batch_rew_int)

        # Log some values

        logs["entropy"] = np.mean(log_entropies)
        logs["value_ext"] = np.mean(log_values_ext)
        logs["value_int"] = np.mean(log_values_int)
        logs["value"] = logs["value_ext"] + logs["value_int"]
        logs["policy_loss"] = np.mean(log_policy_losses)
        logs["value_ext_loss"] = np.mean(log_value_ext_losses)
        logs["value_int_loss"] = np.mean(log_value_int_losses)
        logs["value_loss"] = logs["value_int_loss"] + logs["value_ext_loss"]
        logs["grad_norm"] = np.mean(log_grad_norms)
        logs["return_int"] = np.mean(log_ret_int)
        logs["reward_int"] = np.mean(log_rew_int)

        # add extra logs from intrinsic rewards
        for k in self.aux_logs:
            logs[k] = self.aux_logs[k]

        self.updates_cnt += 1
        return logs

    def _get_batches_starting_indexes(self, recurrence=None, padding=0):
        """Gives, for each batch, the indexes of the observations given to
        the model and the experiences used to compute the loss at first.

        First, the indexes are the integers from 0 to `self.num_frames` with a step of
        `self.recurrence`, shifted by `self.recurrence//2` one time in two for having
        more diverse batches. Then, the indexes are split into the different batches.

        Returns
        -------
        batches_starting_indexes : list of list of int
            the indexes of the experiences to be used at first for each batch

        """
        num_frames_per_proc = self.num_frames_per_proc
        num_procs = self.num_procs

        if recurrence is None:
            recurrence = self.recurrence

        # Consider Num frames list ordered P * T
        if padding == 0:
            indexes = np.arange(0, self.num_frames, recurrence)
        else:
            # Consider Num frames list ordered P * T
            # Do not index step[:padding] and step[-padding:]
            frame_index = np.arange(
                padding, num_frames_per_proc - padding + 1 - recurrence,
                recurrence)
            indexes = np.resize(frame_index.reshape((1, -1)),
                                (num_procs, len(frame_index)))
            indexes = indexes + np.arange(0, num_procs).reshape(
                -1, 1) * num_frames_per_proc
            indexes = indexes.reshape(-1)

        indexes = np.random.permutation(indexes)

        # Shift starting indexes by recurrence//2 half the time
        # TODO Check this ; Bad fix
        if recurrence is None:
            self.batch_num += 1

        num_indexes = self.batch_size // recurrence
        batches_starting_indexes = [
            indexes[i:i + num_indexes]
            for i in range(0, len(indexes), num_indexes)
        ]

        return batches_starting_indexes

    def get_save_data(self):
        return dict({
            "optimizer_policy": self.optimizer_policy.state_dict(),
            "optimizer_agworld": self.optimizer_agworld.state_dict(),
            "predictor_rms": self.predictor_rms,
        })

    def collect_random_statistics(self, num_timesteps):
        #  initialize observation normalization with data from random agent

        self.obs_rms = RunningMeanStd(shape=(1, 7, 7, 3))

        curr_obs = self.obs
        collected_obss = [None] * (self.num_frames_per_proc * num_timesteps)
        for i in range(self.num_frames_per_proc * num_timesteps):
            # Do one agent-environment interaction

            action = torch.randint(
                0, self.env.action_space.n,
                (self.num_procs, ))  # sample uniform actions
            obs, reward, done, _ = self.env.step(action.cpu().numpy())

            # Update experiences values
            collected_obss[i] = curr_obs
            curr_obs = obs

        self.obs = curr_obs
        exps = DictList()
        exps.obs = [
            collected_obss[i][j] for j in range(self.num_procs)
            for i in range(self.num_frames_per_proc * num_timesteps)
        ]

        images = [obs["image"] for obs in exps.obs]
        images = np.array(images)
        images = torch.tensor(images, dtype=torch.float)

        self.obs_rms.update(images)

    def calculate_intrinsic_reward(self, exps: DictList,
                                   dst_intrinsic_r: torch.Tensor):

        # ------------------------------------------------------------------------------------------
        # Run worlds models & generate memories

        agworld_network = self.acmodel.curiosity_model

        num_procs = self.num_procs
        num_frames_per_proc = self.num_frames_per_proc
        device = self.device

        # ------------------------------------------------------------------------------------------
        # Get observations and full states
        f, prev_frame_exps = self.augment_exp(exps)

        # Decide whether this update's experience batch should be saved

        out_dir = self.eval_dir
        updates_cnt = self.updates_cnt
        save_experience_batch = self.save_experience_batch
        save = save_experience_batch > 0 and (updates_cnt +
                                              1) % save_experience_batch == 0

        # ------------------------------------------------------------------------------------------
        if self.pre_fill_memories:
            prev_actions = prev_frame_exps.actions_onehot
            for i in range(num_frames_per_proc - 1):
                obs = f.obs_image[i]
                masks = f.mask[i]

                # Run the world model one step over the stored rollout
                with torch.no_grad():
                    _, f.agworld_mems[i + 1], f.agworld_embs[i] = \
                        agworld_network(obs, f.agworld_mems[i] * masks, prev_actions)
                    prev_actions = f.actions_onehot[i]

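        # When pre_fill_memories is set, the pass above runs the world model over
        # the rollout without gradients so that f.agworld_mems / f.agworld_embs hold
        # per-step values; the memories are later reused as detached initial states
        # when optimizing the world model below.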
        # ------------------------------------------------------------------------------------------
        # -- Compute intrinsic rewards

        # Initialize each observation channel
        pnext_obs1 = torch.zeros(num_frames_per_proc,
                                 num_procs,
                                 agworld_network.obs_channel1[0],
                                 agworld_network.obs_channel1[1],
                                 agworld_network.obs_channel1[2],
                                 device=device)

        pnext_obs2 = torch.zeros(num_frames_per_proc,
                                 num_procs,
                                 agworld_network.obs_channel2[0],
                                 agworld_network.obs_channel2[1],
                                 agworld_network.obs_channel2[2],
                                 device=device)

        pnext_obs3 = torch.zeros(num_frames_per_proc,
                                 num_procs,
                                 agworld_network.obs_channel3[0],
                                 agworld_network.obs_channel3[1],
                                 agworld_network.obs_channel3[2],
                                 device=device)

        prev_actions = prev_frame_exps.actions_onehot
        prev_memory = prev_frame_exps.agworld_mems
        for i in range(num_frames_per_proc - 1):
            obs = f.obs_image[i]
            next_obs = (f.obs_image[i + 1] * 15).type(
                torch.int64)  # TODO: recover integer class labels (assumes images were scaled by 1/15 upstream)
            masks = f.mask[i]
            actions = f.actions_onehot[i]

            # Run the world model one step and predict the next observation
            with torch.no_grad():
                _, next_state, embs = agworld_network(obs, prev_memory * masks,
                                                      prev_actions)

                pnext_obs1[i], pnext_obs2[i], pnext_obs3[
                    i] = agworld_network.forward_state(next_state, actions)

                prev_actions = actions
                prev_memory = next_state

            dst_intrinsic_r[i] = (1 / 3) * (F.cross_entropy(
                pnext_obs1[i], next_obs[:, 0],
                reduction='none') + F.cross_entropy(
                    pnext_obs2[i], next_obs[:, 1],
                    reduction='none') + F.cross_entropy(
                        pnext_obs3[i], next_obs[:, 2],
                        reduction='none')).detach().mean((1, 2))

        # TODO fix the last step's intrinsic reward: there is no next observation to
        # predict, so it is approximated by the per-process mean over the rollout
        dst_intrinsic_r[-1] = dst_intrinsic_r.mean(0)
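        # Shape note on the loop above (as implied by the tensors allocated earlier):
        # each pnext_obsK[i] is (num_procs, C_K, H, W) of logits, next_obs[:, k] is
        # (num_procs, H, W) of integer labels, so F.cross_entropy with
        # reduction='none' yields (num_procs, H, W); averaging the three channels and
        # taking .mean((1, 2)) over H and W gives one intrinsic reward per process.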

        # -- Normalize intrinsic reward
        self.predictor_rff.reset()  # reset the reward forward filter each rollout (TODO: confirm this is needed every update)
        int_rff = torch.zeros((self.num_frames_per_proc, self.num_procs),
                              device=self.device)

        for i in reversed(range(self.num_frames_per_proc)):
            int_rff[i] = self.predictor_rff.update(dst_intrinsic_r[i])

        self.predictor_rms.update(int_rff.view(-1))  # update running mean/std statistics
        dst_intrinsic_r.div_(
            torch.sqrt(self.predictor_rms.var).to(dst_intrinsic_r.device))
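        # The normalization above mirrors the RND-style scheme: `predictor_rff` is
        # assumed to keep a per-process discounted running return of intrinsic
        # rewards (roughly rewems = rewems * gamma + r_int), `predictor_rms` tracks
        # the variance of those returns, and dividing by the running std keeps the
        # reward scale roughly constant across training.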
        #if save:
        #    f.dst_intrinsic = dst_intrinsic_r.clone()
        #    torch.save(f, f"{out_dir}/f_{updates_cnt}")
        #    delattr(f, "dst_intrinsic")

        # ------------------------------------------------------------------------------------------
        # -- Optimize Prediction error module
        optimizer_agworld = self.optimizer_agworld
        recurrence_worlds = self.recurrence_worlds

        max_grad_norm = self.max_grad_norm

        # ------------------------------------------------------------------------------------------
        # _________ for all tensors below, T x P -> P x T -> P * T _______________________
        f = self.flip_back_experience(f)
        # ------------------------------------------------------------------------------------------

        log_obs_loss = []
        log_obs_loss_same = []
        log_obs_loss_diffs = []
        # accumulate grad-norm logs across all batches of this update
        log_grad_agworld_norm = []
        log_grad_eval_norm = []

        for inds in self._get_batches_starting_indexes(
                recurrence=recurrence_worlds, padding=1):

            agworld_mem = f.agworld_mems[inds].detach()
            new_agworld_emb = [None] * recurrence_worlds
            new_agworld_mem = [None] * recurrence_worlds

            obs_batch_loss = torch.zeros(1, device=self.device)[0]

            obs_batch_loss_same = torch.zeros(1, device=self.device)[0]
            obs_batch_loss_diffs = torch.zeros(1, device=self.device)[0]

            # -- Agent world
            for i in range(recurrence_worlds):

                obs = f.obs_image[inds + i].detach()
                mask = f.mask[inds + i]
                prev_actions_one = f.actions_onehot[inds + i - 1].detach()

                # Forward pass Agent Net for memory
                _, agworld_mem, new_agworld_emb[i] = agworld_network(
                    obs, agworld_mem * mask, prev_actions_one)
                new_agworld_mem[i] = agworld_mem
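            # Unlike the torch.no_grad() pass earlier, the unroll above keeps the
            # computation graph across `recurrence_worlds` steps so the loss below
            # can backpropagate through the recurrent world-model memory.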

            # Go back and predict obs(t + 1) given the world-model state at t and
            # action(t); the action and same/diff state losses below are currently
            # commented out
            for i in range(recurrence_worlds - 1):

                obs = f.obs_image[inds + i].detach()
                next_obs = f.obs_image[inds + i + 1].detach()

                # take masks and convert them to 1D tensor for indexing
                # use next masks because done gives you the new game obs
                next_mask = f.mask[inds + i + 1].long().detach()
                next_mask = next_mask.squeeze(1).bool()  # boolean mask for indexing

                crt_actions = f.action[inds + i].long().detach()
                crt_actions_one = f.actions_onehot[inds + i].detach()

                pred_obs1, pred_obs2, pred_obs3 = \
                    agworld_network.forward_state(new_agworld_mem[i], crt_actions_one)

                next_obsl = (f.obs_image[inds + i + 1] * 15).type(torch.int64)
                obs_batch_loss += (1 / 3) * \
                                   (F.cross_entropy(pred_obs1, next_obsl[:, 0]) +
                                    F.cross_entropy(pred_obs2, next_obsl[:, 1]) +
                                    F.cross_entropy(pred_obs3, next_obsl[:, 2]))
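                # The loss just accumulated is the same three-channel cross-entropy
                # used as the intrinsic reward, but computed with gradients so the
                # prediction-error module can be trained to reduce it.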

                # if all episodes end at once, the same/diff losses cannot be computed
                #if next_mask.sum() == 0:
                #    continue

                #same = (obs[next_mask] == next_obs[next_mask]).all(1).all(1).all(1)

                #s_pred_state = pred_state[next_mask]
                #s_crt_state = (new_agworld_mem[i + 1].detach())[next_mask]

                # if all are same/diff, take care of empty tensors
                #if same.sum() == same.shape[0]:
                #    act_batch_loss_same += loss_m_act(s_pred_act[same], s_crt_act[same])
                #    state_batch_loss_same += loss_m_state(s_pred_state[same], s_crt_state[same])

                #elif same.sum() == 0:
                #    act_batch_loss_diff += loss_m_act(s_pred_act[~same], s_crt_act[~same])
                #    state_batch_loss_diffs += loss_m_state(s_pred_state[~same], s_crt_state[~same])

                #else:
                #    act_batch_loss_same += loss_m_act(s_pred_act[same], s_crt_act[same])
                #    act_batch_loss_diff += loss_m_act(s_pred_act[~same], s_crt_act[~same])

                #    state_batch_loss_same += loss_m_state(s_pred_state[same], s_crt_state[same])
                #    state_batch_loss_diffs += loss_m_state(s_pred_state[~same], s_crt_state[~same])

            # -- Optimize models
            obs_batch_loss /= (recurrence_worlds - 1)

            #obs_batch_loss_same /= (recurrence_worlds - 1)
            #obs_batch_loss_diffs /= (recurrence_worlds - 1)

            optimizer_agworld.zero_grad()
            obs_batch_loss.backward()

            grad_agworld_norm = sum(
                p.grad.data.norm(2).item()**2
                for p in agworld_network.parameters()
                if p.grad is not None)**0.5

            torch.nn.utils.clip_grad_norm_(agworld_network.parameters(),
                                           max_grad_norm)
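            # Note: clip_grad_norm_ also returns the pre-clipping total norm, so the
            # manual accumulation above could reuse its return value instead.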

            # log batch statistics

            log_obs_loss.append(obs_batch_loss.item())

            log_grad_agworld_norm.append(grad_agworld_norm)

            #log_state_loss_same.append(state_batch_loss_same.item())
            #log_state_loss_diffs.append(state_batch_loss_diffs.item())

            optimizer_agworld.step()

        # ------------------------------------------------------------------------------------------
        # Log some values
        self.aux_logs['next_obs_loss'] = np.mean(log_obs_loss)
        self.aux_logs['grad_norm_icm'] = np.mean(log_grad_agworld_norm)

        #self.aux_logs['next_state_loss_same'] = np.mean(log_state_loss_same)
        #self.aux_logs['next_state_loss_diffs'] = np.mean(log_state_loss_diffs)

        return dst_intrinsic_r

    def add_extra_experience(self, exps: DictList):
        # Process
        full_positions = [
            self.obss[i][j]["position"] for j in range(self.num_procs)
            for i in range(self.num_frames_per_proc)
        ]
        # Process
        full_states = [
            self.obss[i][j]["state"] for j in range(self.num_procs)
            for i in range(self.num_frames_per_proc)
        ]

        exps.states = preprocess_images(full_states, device=self.device)
        max_pos_value = max(self.env_height, self.env_width)
        exps.position = preprocess_images(full_positions,
                                          device=self.device,
                                          max_image_value=max_pos_value,
                                          normalize=False)
        exps.obs_image = exps.obs.image

    def evaluate(self):

        # set networks in eval mode

        self.acmodel.eval()

        out_dir = self.eval_dir
        env = self.eval_envs
        preprocess_obss = self.preprocess_obss
        device = self.device
        recurrent = self.acmodel.recurrent
        acmodel = self.acmodel
        agworld_network = acmodel.curiosity_model
        updates_cnt = self.updates_cnt

        obs = env.reset()
        if recurrent:
            memory = self.eval_memory

        mask = self.eval_mask.fill_(1).unsqueeze(1)
        eval_ag_memory = self.eval_agworld_memory.zero_()

        prev_actions_a = torch.zeros((len(obs), env.action_space.n),
                                     device=device)
        crt_actions = torch.zeros((len(obs), env.action_space.n),
                                  device=device)

        prev_agworld_mem = None
        obs_batch = None

        transitions = []
        steps = 400

        for i in range(steps):

            prev_obs = obs_batch
            preprocessed_obs = preprocess_obss(obs, device=device)
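            # The double transpose below converts the preprocessed images from NHWC
            # to NCHW (equivalent to .permute(0, 3, 1, 2)), presumably the layout the
            # curiosity model expects.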
            obs_batch = torch.transpose(
                torch.transpose(preprocessed_obs.image, 1, 3), 2, 3)
            pred_obs = torch.zeros_like(obs_batch, device=device)

            full_state_batch = preprocess_images(
                [obs[i]['state'] for i in range(len(obs))],
                device=device,
            )

            with torch.no_grad():
                if recurrent:
                    dist, value, memory = acmodel(preprocessed_obs,
                                                  memory * mask)
                else:
                    dist, value = acmodel(preprocessed_obs)

                action = dist.sample()
                crt_actions.zero_()
                crt_actions.scatter_(1,
                                     action.long().unsqueeze(1),
                                     1.)  # transform to one_hot
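                # The scatter_ above builds the one-hot encoding in place; e.g. with
                # 4 actions, action = [2, 0] becomes [[0, 0, 1, 0], [1, 0, 0, 0]].
                # The same result could be obtained with
                # F.one_hot(action.long(), num_actions).float(), at the cost of a new
                # tensor per step.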

                # Agent world
                _, eval_ag_memory, new_agworld_emb = agworld_network(
                    obs_batch, eval_ag_memory * mask, prev_actions_a)

                pnext_obs1, pnext_obs2, pnext_obs3 = agworld_network.forward_state(
                    eval_ag_memory, crt_actions)

                pred_obs[:, 0] = pnext_obs1.argmax(1)
                pred_obs[:, 1] = pnext_obs2.argmax(1)
                pred_obs[:, 2] = pnext_obs3.argmax(1)

                pred_obs = torch.transpose(torch.transpose(pred_obs, 1, 3), 1,
                                           2)

            next_obs, reward, done, _ = env.step(action.cpu().numpy())

            mask = (1 - torch.tensor(done, device=device,
                                     dtype=torch.float)).unsqueeze(1)

            prev_actions_a.copy_(crt_actions)

            transitions.append(
                (obs, action.cpu(), reward, done, next_obs, dist.probs.cpu(),
                 pred_obs.cpu(), eval_ag_memory.cpu(), new_agworld_emb.cpu(),
                 obs_batch.cpu(), full_state_batch.cpu()))
            obs = next_obs

        if out_dir is not None:
            np.save(
                f"{out_dir}/eval_{updates_cnt}", {
                    "transitions":
                    transitions,
                    "columns": [
                        "obs", "action", "reward", "done", "next_obs", "probs",
                        "pred_obs", "eval_ag_memory", "new_agworld_emb",
                        "obs_batch", "full_state_batch"
                    ]
                })
        self.acmodel.train()

        return None