Example #1
def learn(actor_model,
          model,
          batch,
          initial_agent_state,
          optimizer,
          scheduler,
          flags,
          lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:
        # Intrinsic reward precomputed by the actors from state-visitation counts.
        intrinsic_rewards = batch['train_state_count'][1:].float().to(
            device=flags.device)

        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef

        learner_outputs, unused_state = model(batch, initial_agent_state)

        bootstrap_value = learner_outputs['baseline'][-1]

        # Align the sequences for V-trace: drop the first step of the batch and the
        # last step of the learner outputs (its value was kept above as the bootstrap).
        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1]
            for key, tensor in learner_outputs.items()
        }

        rewards = batch['reward']
        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)

        # Zero the discount at terminal steps so returns are not bootstrapped
        # across episode boundaries.
        discounts = (~batch['done']).float() * flags.discounting

        # V-trace (IMPALA) off-policy correction between the actors' behaviour
        # policy and the current learner policy.
        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)

        pg_loss = losses.compute_policy_gradient_loss(
            learner_outputs['policy_logits'], batch['action'],
            vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss

        episode_returns = batch['episode_return'][batch['done']]
        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
        }

        scheduler.step()
        optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        optimizer.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
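For orientation, here is a minimal, hypothetical sketch of a learner loop that drives the function above. The optimizer/scheduler construction, the flags attributes (learning_rate, total_updates) and the next_batch() helper are illustrative placeholders rather than this project's actual API; only the learn(...) call mirrors the example.

import logging
import torch

def run_learner(actor_model, model, next_batch, flags):
    # Hypothetical driver: next_batch() is assumed to return
    # (batch, initial_agent_state) pairs assembled from the actors' rollouts.
    optimizer = torch.optim.RMSprop(model.parameters(), lr=flags.learning_rate,
                                    alpha=0.99, eps=0.01)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda step: max(0.0, 1.0 - step / flags.total_updates))
    for _ in range(flags.total_updates):
        batch, initial_agent_state = next_batch()
        stats = learn(actor_model, model, batch, initial_agent_state,
                      optimizer, scheduler, flags)
        logging.info('total_loss=%f mean_episode_return=%f',
                     stats['total_loss'], stats['mean_episode_return'])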
Example #2
def learn(actor_model,
          model,
          batch,
          initial_agent_state,
          optimizer,
          scheduler,
          flags,
          action_hist,
          position_count,
          lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:

        # Accumulate per-position visit counts (used later for heatmaps)
        position_coord, position_counts = torch.unique(
            batch["agent_position"].view(-1, 2), return_counts=True, dim=0)

        for i, coord in enumerate(position_coord):
            coord = tuple(coord[:].cpu().numpy())
            position_count[coord] = position_count.get(
                coord, 0) + position_counts[i].item()

        # Save the action histogram for visualization only; it does not affect the loss
        action_id, count_action = torch.unique(batch["action"].flatten(),
                                               return_counts=True)
        acted_id, action_acted = torch.unique(
            (batch["action"] + 1) * batch["action_acted"] - 1,
            return_counts=True)
        action_hist.add(action_id,
                        count_action.cpu().float(), acted_id[1:].cpu(),
                        action_acted[1:].cpu().float())

        intrinsic_rewards = batch['train_state_count'][1:].float().to(
            device=flags.device)

        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef

        learner_outputs, unused_state = model(batch, initial_agent_state)

        bootstrap_value = learner_outputs['baseline'][-1]

        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1]
            for key, tensor in learner_outputs.items()
        }

        rewards = batch['reward']
        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)

        discounts = (~batch['done']).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)

        pg_loss = losses.compute_policy_gradient_loss(
            learner_outputs['policy_logits'], batch['action'],
            vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss

        episode_returns = batch['episode_return'][batch['done']]
        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
        }

        scheduler.step()
        optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        optimizer.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
Example #3
def learn(actor_model,
          model,
          state_embedding_model,
          forward_dynamics_model,
          inverse_dynamics_model,
          batch,
          initial_agent_state, 
          optimizer,
          state_embedding_optimizer, 
          forward_dynamics_optimizer, 
          inverse_dynamics_optimizer, 
          scheduler,
          flags,
          frames=None,
          lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:
        state_emb = state_embedding_model(batch['frame'][:-1].to(device=flags.device))
        next_state_emb = state_embedding_model(batch['frame'][1:].to(device=flags.device))

        pred_next_state_emb = forward_dynamics_model(
            state_emb, batch['action'][1:].to(device=flags.device))
        pred_actions = inverse_dynamics_model(state_emb, next_state_emb)
        # Entropy of the predicted actions (computed for inspection only).
        entropy_emb_actions = losses.compute_entropy_loss(pred_actions)

        # Curiosity-style intrinsic reward: L2 error of the predicted next-state
        # embedding against the actual next-state embedding.
        intrinsic_rewards = torch.norm(pred_next_state_emb - next_state_emb, dim=2, p=2)
        
        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef 
        
        forward_dynamics_loss = flags.forward_loss_coef * \
            losses.compute_forward_dynamics_loss(pred_next_state_emb, next_state_emb)

        inverse_dynamics_loss = flags.inverse_loss_coef * \
            losses.compute_inverse_dynamics_loss(pred_actions, batch['action'][1:])

        # Flattened actions and intrinsic rewards (for inspection; not used further
        # in this function).
        num_samples = flags.unroll_length * flags.batch_size
        actions_flat = batch['action'][1:].reshape(num_samples).cpu().detach().numpy()
        intrinsic_rewards_flat = intrinsic_rewards.reshape(num_samples).cpu().detach().numpy()

        learner_outputs, unused_state = model(batch, initial_agent_state)

        bootstrap_value = learner_outputs['baseline'][-1]

        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1]
            for key, tensor in learner_outputs.items()
        }
        
        # Fraction of the batch taken by each action (for inspection; not used below).
        actions = batch['action'].reshape(flags.unroll_length * flags.batch_size).cpu().numpy()
        action_percentage = np.bincount(actions, minlength=model.num_actions) / len(actions)
        
        rewards = batch['reward']
            
        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:            
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)
        
        discounts = (~batch['done']).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)

        pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'],
                                               batch['action'],
                                               vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss \
                + forward_dynamics_loss  + inverse_dynamics_loss
        
        episode_returns = batch['episode_return'][batch['done']]
        episode_lengths = batch['episode_step'][batch['done']]
        episode_wins = batch['episode_win'][batch['done']]
        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'forward_dynamics_loss': forward_dynamics_loss.item(),
            'inverse_dynamics_loss': inverse_dynamics_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
        }
        
        scheduler.step()
        optimizer.zero_grad()
        state_embedding_optimizer.zero_grad()
        forward_dynamics_optimizer.zero_grad()
        inverse_dynamics_optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        nn.utils.clip_grad_norm_(state_embedding_model.parameters(), flags.max_grad_norm)
        nn.utils.clip_grad_norm_(forward_dynamics_model.parameters(), flags.max_grad_norm)
        nn.utils.clip_grad_norm_(inverse_dynamics_model.parameters(), flags.max_grad_norm)
        optimizer.step()
        state_embedding_optimizer.step()
        forward_dynamics_optimizer.step()
        inverse_dynamics_optimizer.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
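For reference, one plausible shape for the dynamics losses used above; this is a sketch of common ICM-style implementations, not necessarily this project's losses module (the norm, reduction or scaling may differ).

import torch
import torch.nn.functional as F

def compute_forward_dynamics_loss(pred_next_emb, next_emb):
    # Error between predicted and actual next-state embeddings.
    return torch.norm(pred_next_emb - next_emb, dim=2, p=2).mean()

def compute_inverse_dynamics_loss(pred_action_logits, actions):
    # Cross-entropy of the inverse model's action logits against the taken actions.
    return F.nll_loss(
        F.log_softmax(torch.flatten(pred_action_logits, 0, 1), dim=-1),
        target=torch.flatten(actions, 0, 1),
        reduction='mean')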
Example #4
def learn(actor_model,
          model,
          random_target_network,
          predictor_network,
          batch,
          initial_agent_state,
          optimizer,
          predictor_optimizer,
          scheduler,
          flags,
          action_hist,
          position_count,
          frames=None,
          lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:
        if flags.use_fullobs_intrinsic:
            random_embedding = random_target_network(batch, next_state=True)\
                    .reshape(flags.unroll_length, flags.batch_size, 128)
            predicted_embedding = predictor_network(batch, next_state=True)\
                    .reshape(flags.unroll_length, flags.batch_size, 128)
        else:
            random_embedding = random_target_network(
                batch['partial_obs'][1:].to(device=flags.device))
            predicted_embedding = predictor_network(
                batch['partial_obs'][1:].to(device=flags.device))

        # Save per-position visit counts, to produce heatmaps later
        position_coord, position_counts = torch.unique(
            batch["agent_position"].view(-1, 2), return_counts=True, dim=0)
        for i, coord in enumerate(position_coord):
            coord = tuple(coord[:].cpu().numpy())
            position_count[coord] = position_count.get(
                coord, 0) + position_counts[i].item()

        # Save the action histogram for visualization only; it does not affect the loss
        action_id, count_action = torch.unique(batch["action"].flatten(),
                                               return_counts=True)
        acted_id, action_acted = torch.unique(
            (batch["action"] + 1) * batch["action_acted"] - 1,
            return_counts=True)
        action_hist.add(action_id,
                        count_action.cpu().float(), acted_id[1:].cpu(),
                        action_acted[1:].cpu().float())

        # RND intrinsic reward: prediction error of the trained predictor network
        # against the fixed, randomly initialized target network.
        intrinsic_rewards = torch.norm(predicted_embedding.detach() -
                                       random_embedding.detach(),
                                       dim=2,
                                       p=2)

        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef

        # Flattened actions and intrinsic rewards (for inspection; not used further
        # in this function).
        num_samples = flags.unroll_length * flags.batch_size
        actions_flat = batch['action'][1:].reshape(
            num_samples).cpu().detach().numpy()
        intrinsic_rewards_flat = intrinsic_rewards.reshape(
            num_samples).cpu().detach().numpy()

        rnd_loss = flags.rnd_loss_coef * \
                losses.compute_forward_dynamics_loss(predicted_embedding, random_embedding.detach())

        learner_outputs, unused_state = model(batch, initial_agent_state)

        bootstrap_value = learner_outputs['baseline'][-1]

        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1]
            for key, tensor in learner_outputs.items()
        }

        rewards = batch['reward']

        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)

        discounts = (~batch['done']).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)

        pg_loss = losses.compute_policy_gradient_loss(
            learner_outputs['policy_logits'], batch['action'],
            vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss + rnd_loss

        episode_returns = batch['episode_return'][batch['done']]
        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'rnd_loss': rnd_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
        }

        scheduler.step()
        optimizer.zero_grad()
        predictor_optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        nn.utils.clip_grad_norm_(predictor_network.parameters(),
                                 flags.max_grad_norm)
        optimizer.step()
        predictor_optimizer.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
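To illustrate the RND setup these arguments assume, a minimal embedding network could look like the sketch below; the actual random_target_network and predictor_network architectures are not shown on this page and may differ (layer sizes, observation layout, etc.).

import torch
from torch import nn

class RNDEmbeddingNet(nn.Module):
    # Hypothetical embedding network producing [T, B, 128] outputs, matching the
    # shape the learn() examples above expect.
    def __init__(self, in_channels, emb_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2), nn.ELU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2), nn.ELU(),
            nn.Flatten(),
            nn.LazyLinear(emb_dim))

    def forward(self, obs):
        # obs: [T, B, C, H, W] observation tensor.
        T, B = obs.shape[:2]
        emb = self.net(obs.reshape(T * B, *obs.shape[2:]).float())
        return emb.view(T, B, -1)

# The target is frozen and never trained; only the predictor is updated
# through predictor_optimizer:
#   random_target_network = RNDEmbeddingNet(in_channels).requires_grad_(False)
#   predictor_network = RNDEmbeddingNet(in_channels)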
Example #5
def learn(actor_model,
          model,
          random_target_network,
          predictor_network,
          batch,
          initial_agent_state, 
          optimizer,
          predictor_optimizer,
          scheduler,
          flags,
          frames=None,
          lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:
        random_embedding = random_target_network(batch['frame'][1:].to(device=flags.device))
        predicted_embedding = predictor_network(batch['frame'][1:].to(device=flags.device))

        intrinsic_rewards = torch.norm(predicted_embedding.detach() - random_embedding.detach(), dim=2, p=2)

        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef 
        
        # Flattened actions and intrinsic rewards (for inspection; not used further
        # in this function).
        num_samples = flags.unroll_length * flags.batch_size
        actions_flat = batch['action'][1:].reshape(num_samples).cpu().detach().numpy()
        intrinsic_rewards_flat = intrinsic_rewards.reshape(num_samples).cpu().detach().numpy()

        rnd_loss = flags.rnd_loss_coef * \
                losses.compute_forward_dynamics_loss(predicted_embedding, random_embedding.detach()) 
            
        learner_outputs, unused_state = model(batch, initial_agent_state)

        bootstrap_value = learner_outputs['baseline'][-1]

        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1]
            for key, tensor in learner_outputs.items()
        }
        
        rewards = batch['reward']
            
        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:            
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)
        
        discounts = (~batch['done']).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)

        pg_loss = losses.compute_policy_gradient_loss(learner_outputs['policy_logits'],
                                               batch['action'],
                                               vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss + rnd_loss

        episode_returns = batch['episode_return'][batch['done']]
        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'rnd_loss': rnd_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
        }
        
        scheduler.step()
        optimizer.zero_grad()
        predictor_optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        nn.utils.clip_grad_norm_(predictor_network.parameters(), flags.max_grad_norm)
        optimizer.step()
        predictor_optimizer.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
Example #6
def learn(actor_model,
          model,
          state_embedding_model,
          forward_dynamics_model,
          inverse_dynamics_model,
          batch,
          initial_agent_state,
          optimizer,
          state_embedding_optimizer,
          forward_dynamics_optimizer,
          inverse_dynamics_optimizer,
          scheduler,
          flags,
          frames=None,
          lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:
        # Episodic state-visitation counts precomputed by the actors.
        count_rewards = batch['episode_state_count'][1:].float().to(
            device=flags.device)

        if flags.use_fullobs_intrinsic:
            state_emb = state_embedding_model(batch, next_state=False)\
                    .reshape(flags.unroll_length, flags.batch_size, 128)
            next_state_emb = state_embedding_model(batch, next_state=True)\
                    .reshape(flags.unroll_length, flags.batch_size, 128)
        else:
            state_emb = state_embedding_model(
                batch['partial_obs'][:-1].to(device=flags.device))
            next_state_emb = state_embedding_model(
                batch['partial_obs'][1:].to(device=flags.device))

        pred_next_state_emb = forward_dynamics_model(
            state_emb, batch['action'][1:].to(device=flags.device))
        pred_actions = inverse_dynamics_model(state_emb, next_state_emb)

        # "Control" reward: how much the state embedding changes across the transition;
        # it is scaled below by the episodic counts to form the intrinsic reward.
        control_rewards = torch.norm(next_state_emb - state_emb, dim=2, p=2)

        intrinsic_rewards = count_rewards * control_rewards

        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef

        forward_dynamics_loss = flags.forward_loss_coef * \
            losses.compute_forward_dynamics_loss(pred_next_state_emb, next_state_emb)

        inverse_dynamics_loss = flags.inverse_loss_coef * \
            losses.compute_inverse_dynamics_loss(pred_actions, batch['action'][1:])

        learner_outputs, unused_state = model(batch, initial_agent_state)

        bootstrap_value = learner_outputs['baseline'][-1]

        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1]
            for key, tensor in learner_outputs.items()
        }

        rewards = batch['reward']
        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)

        discounts = (~batch['done']).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)

        pg_loss = losses.compute_policy_gradient_loss(
            learner_outputs['policy_logits'], batch['action'],
            vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss + \
                    forward_dynamics_loss + inverse_dynamics_loss

        episode_returns = batch['episode_return'][batch['done']]
        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
            'mean_control_rewards': torch.mean(control_rewards).item(),
            'mean_count_rewards': torch.mean(count_rewards).item(),
            'forward_dynamics_loss': forward_dynamics_loss.item(),
            'inverse_dynamics_loss': inverse_dynamics_loss.item(),
        }

        scheduler.step()
        optimizer.zero_grad()
        state_embedding_optimizer.zero_grad()
        forward_dynamics_optimizer.zero_grad()
        inverse_dynamics_optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)
        nn.utils.clip_grad_norm_(state_embedding_model.parameters(),
                                 flags.max_grad_norm)
        nn.utils.clip_grad_norm_(forward_dynamics_model.parameters(),
                                 flags.max_grad_norm)
        nn.utils.clip_grad_norm_(inverse_dynamics_model.parameters(),
                                 flags.max_grad_norm)
        optimizer.step()
        state_embedding_optimizer.step()
        forward_dynamics_optimizer.step()
        inverse_dynamics_optimizer.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
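The episode_state_count values consumed above are produced on the actor side. As a hedged illustration only (the project's key construction and scaling may differ), an episodic count bonus is often maintained like this:

from collections import defaultdict
import torch

def episodic_count_bonus(state_keys, episode_counts):
    # Hypothetical actor-side helper: given hashable keys for the states visited in
    # one episode, update visit counts and return a per-step bonus, here 1/sqrt(count).
    bonuses = []
    for key in state_keys:
        episode_counts[key] += 1
        bonuses.append(1.0 / (episode_counts[key] ** 0.5))
    return torch.tensor(bonuses)

# usage (reset the counter at the start of every episode):
#   counts = defaultdict(int)
#   bonus = episodic_count_bonus(keys, counts)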
Example #7
def learn(
    actor_model,
    model,
    action_hist,
    batch,
    initial_agent_state,
    optimizer,
    # action_distribution_optimizer,
    scheduler,
    flags,
    frames=None,
    position_count=None,
    state_embedding_model=None,
    forward_dynamics_model=None,
    inverse_dynamics_model=None,
    state_embedding_optimizer=None,
    forward_dynamics_optimizer=None,
    inverse_dynamics_optimizer=None,
    lock=threading.Lock()):
    """Performs a learning (optimization) step."""
    with lock:
        # Episodic state-visitation counts precomputed by the actors.
        count_rewards = batch['episode_state_count'][1:].float().to(
            device=flags.device)

        action_id, count_action = torch.unique(batch["action"].flatten(),
                                               return_counts=True)

        # Accumulate per-position visit counts (used later for heatmaps)
        position_coord, position_counts = torch.unique(
            batch["agent_position"].view(-1, 2), return_counts=True, dim=0)

        for i, coord in enumerate(position_coord):
            coord = tuple(coord[:].cpu().numpy())
            position_count[coord] = position_count.get(
                coord, 0) + position_counts[i].item()

        if state_embedding_model:
            current_state_embedding = state_embedding_model(
                batch['partial_obs'][:-1].to(device=flags.device))
            next_state_embedding = state_embedding_model(
                batch['partial_obs'][1:].to(device=flags.device))

            batch_action_acted = torch.abs(current_state_embedding -
                                           next_state_embedding).sum(
                                               dim=2) > flags.change_treshold
            acted_id, action_acted = torch.unique(
                (batch["action"][:-1] + 1) * batch_action_acted.long() - 1,
                return_counts=True)

            pred_next_state_emb = forward_dynamics_model(
                current_state_embedding,
                batch['action'][1:].to(device=flags.device))
            pred_actions = inverse_dynamics_model(current_state_embedding,
                                                  next_state_embedding)

            forward_dynamics_loss = flags.forward_loss_coef * \
                                    losses.compute_forward_dynamics_loss(pred_next_state_emb, next_state_embedding)

            inverse_dynamics_loss = flags.inverse_loss_coef * \
                                    losses.compute_inverse_dynamics_loss(pred_actions, batch['action'][1:])

        else:
            acted_id, action_acted = torch.unique(
                (batch["action"] + 1) * batch["action_acted"] - 1,
                return_counts=True)
            batch_action_acted = batch["action_acted"][:-1].bool()

            forward_dynamics_loss = torch.zeros(1)
            inverse_dynamics_loss = torch.zeros(1)

        action_hist.add(action_id,
                        count_action.cpu().float(), acted_id[1:].cpu(),
                        action_acted[1:].cpu().float())

        # Per-action bonus derived from the action histogram: actions with a higher
        # usage ratio receive a smaller bonus (linear or exponential decay below).
        action_rewards = torch.zeros_like(batch["action"][:-1]).float()
        acted_ratio = action_hist.usage_ratio()

        if flags.action_dist_decay_coef == 0:
            reward_for_an_action = 1 - acted_ratio
        else:
            reward_for_an_action = torch.exp(-acted_ratio *
                                             flags.action_dist_decay_coef)
        reward_for_an_action[acted_ratio == 1] = 0

        reward_for_an_action[torch.isnan(reward_for_an_action)] = 0
        assert torch.all(reward_for_an_action >= 0), \
            "Action rewards must be non-negative"

        for act_id in action_id:
            action_rewards[(batch["action"][:-1] == act_id.item())
                           & batch_action_acted] = reward_for_an_action[act_id]

        intrinsic_rewards = count_rewards * action_rewards

        intrinsic_reward_coef = flags.intrinsic_reward_coef
        intrinsic_rewards *= intrinsic_reward_coef

        learner_outputs, unused_state = model(batch, initial_agent_state)

        bootstrap_value = learner_outputs['baseline'][-1]

        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1]
            for key, tensor in learner_outputs.items()
        }

        rewards = batch['reward']
        if flags.no_reward:
            total_rewards = intrinsic_rewards
        else:
            total_rewards = rewards + intrinsic_rewards
        clipped_rewards = torch.clamp(total_rewards, -1, 1)

        discounts = (~batch['done']).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch['policy_logits'],
            target_policy_logits=learner_outputs['policy_logits'],
            actions=batch['action'],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs['baseline'],
            bootstrap_value=bootstrap_value)

        pg_loss = losses.compute_policy_gradient_loss(
            learner_outputs['policy_logits'], batch['action'],
            vtrace_returns.pg_advantages)
        baseline_loss = flags.baseline_cost * losses.compute_baseline_loss(
            vtrace_returns.vs - learner_outputs['baseline'])
        entropy_loss = flags.entropy_cost * losses.compute_entropy_loss(
            learner_outputs['policy_logits'])

        total_loss = pg_loss + baseline_loss + entropy_loss  #+ act_distrib_loss

        episode_returns = batch['episode_return'][batch['done']]
        stats = {
            'mean_episode_return': torch.mean(episode_returns).item(),
            'total_loss': total_loss.item(),
            'pg_loss': pg_loss.item(),
            'baseline_loss': baseline_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'mean_rewards': torch.mean(rewards).item(),
            'mean_intrinsic_rewards': torch.mean(intrinsic_rewards).item(),
            'mean_total_rewards': torch.mean(total_rewards).item(),
            'mean_action_rewards': torch.mean(action_rewards).item(),
            'mean_count_rewards': torch.mean(count_rewards).item(),
            'forward_dynamics_loss': forward_dynamics_loss.item(),
            'inverse_dynamics_loss': inverse_dynamics_loss.item(),
        }

        scheduler.step()
        optimizer.zero_grad()

        if state_embedding_model:
            state_embedding_optimizer.zero_grad()
            forward_dynamics_optimizer.zero_grad()
            inverse_dynamics_optimizer.zero_grad()

        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.max_grad_norm)

        if state_embedding_model:
            nn.utils.clip_grad_norm_(state_embedding_model.parameters(),
                                     flags.max_grad_norm)
            nn.utils.clip_grad_norm_(forward_dynamics_model.parameters(),
                                     flags.max_grad_norm)
            nn.utils.clip_grad_norm_(inverse_dynamics_model.parameters(),
                                     flags.max_grad_norm)

        optimizer.step()

        if state_embedding_model:
            state_embedding_optimizer.step()
            forward_dynamics_optimizer.step()
            inverse_dynamics_optimizer.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
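Finally, the policy-gradient, baseline and entropy helpers used in every example typically look like the following in TorchBeast-style IMPALA code; this is a sketch of common implementations, not necessarily this project's exact losses module (reductions and scaling may differ).

import torch
import torch.nn.functional as F

def compute_baseline_loss(advantages):
    # Squared error between the V-trace targets and the value predictions.
    return 0.5 * torch.sum(advantages ** 2)

def compute_entropy_loss(logits):
    # Negative entropy, so that minimizing the loss maximizes policy entropy.
    policy = F.softmax(logits, dim=-1)
    log_policy = F.log_softmax(logits, dim=-1)
    return torch.sum(policy * log_policy)

def compute_policy_gradient_loss(logits, actions, advantages):
    # Cross-entropy of the taken actions, weighted by detached V-trace advantages.
    cross_entropy = F.nll_loss(
        F.log_softmax(torch.flatten(logits, 0, 1), dim=-1),
        target=torch.flatten(actions, 0, 1),
        reduction='none')
    cross_entropy = cross_entropy.view_as(advantages)
    return torch.sum(cross_entropy * advantages.detach())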