Example #1
def learn(
        flags,
        actor_model,
        model,
        batch,
        initial_agent_state,
        optimizer,
        scheduler,
        lock=threading.Lock(),  # noqa: B008
):
    """Performs a learning (optimization) step."""
    with lock:
        """
        put a lock on the central learner,
        send the trajectories to it.
        Update the parameters of the central learner,
        copy the parameters of the central learner back to the actors
        """
        # print('RUNNING MAIN MODEL')
        # print('MODEL OUTOUT: ', model(batch, initial_agent_state))
        # print('batch size : ', batch['frame'].size())

        # TODO Next Step: think about chopping the sequences up for attending to long sequences
        # TODO Next Step: make the actors put only one sequence in one rollout, mask the remaining rollout
        #       positions, and then change the triu in line 461 in TXL
        # Here the entire 81-step sequence is used as the query and is attended autoregressively
        # over keys and values of length 81. Once this sequence length becomes very large,
        # memory will need to kick in.
        learner_outputs, unused_state, unused_mems, mem_padding, ind_first_done = model(
            batch, initial_agent_state, mems=None, mem_padding=None)
        # Here mem_padding is the same as the batch padding for this iteration,
        # so it can be used to mask the loss.

        # Take final value function slice for bootstrapping.
        # this is the final value from this trajectory
        if ind_first_done is not None:
            # B dimensional tensor
            bootstrap_value = learner_outputs["baseline"][
                ind_first_done, range(flags.batch_size)]
        else:
            bootstrap_value = learner_outputs["baseline"][-1]

        # Move from obs[t] -> action[t] to action[t] -> obs[t].
        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1]
            for key, tensor in learner_outputs.items()
        }

        # learner_outputs is used to predict batch, since batch is now one step ahead of learner_outputs.

        rewards = batch["reward"]
        if flags.reward_clipping == "abs_one":
            clipped_rewards = torch.clamp(rewards, -1, 1)
        elif flags.reward_clipping == "none":
            clipped_rewards = rewards

        discounts = (~batch["done"]).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch["policy_logits"],
            # The learner's current policy is the V-trace target policy; the
            # actors' logits stored in batch are the behavior policy.
            target_policy_logits=learner_outputs["policy_logits"],
            actions=batch["action"],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs["baseline"],
            bootstrap_value=bootstrap_value,
        )

        # TODO Next Step: the losses also have to be computed with the padding, think on a structure of mask
        #                   to do this efficiently
        # Advantages are [rollout_len, batch_size]

        # Mask out padded positions (from mem_padding) so they do not contribute
        # to the policy-gradient, baseline, and entropy losses.
        pad_mask = None
        if mem_padding is not None:
            pad_mask = (~mem_padding.squeeze(0)[1:]).float()

        pg_loss = compute_policy_gradient_loss(
            learner_outputs["policy_logits"], batch["action"],
            vtrace_returns.pg_advantages, pad_mask)
        baseline_loss = flags.baseline_cost * compute_baseline_loss(
            vtrace_returns.vs - learner_outputs["baseline"], pad_mask)
        entropy_loss = flags.entropy_cost * compute_entropy_loss(
            learner_outputs["policy_logits"], pad_mask)

        total_loss = pg_loss + baseline_loss + entropy_loss

        episode_returns = batch["episode_return"][batch["done"]]
        stats = {
            "episode_returns": tuple(episode_returns.cpu().numpy()),
            "mean_episode_return": torch.mean(episode_returns).item(),
            "total_loss": total_loss.item(),
            "pg_loss": pg_loss.item(),
            "baseline_loss": baseline_loss.item(),
            "entropy_loss": entropy_loss.item(),
        }

        optimizer.zero_grad()
        total_loss.backward()
        if flags.fp16:
            optimizer.clip_master_grads(flags.grad_norm_clipping)
        else:
            nn.utils.clip_grad_norm_(model.parameters(),
                                     flags.grad_norm_clipping)
        optimizer.step()
        # The scheduler is stepped inside the lock in batch_and_learn itself.
        # scheduler.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
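Example #1 above (and Example #3 below) passes an optional pad_mask into compute_policy_gradient_loss, compute_baseline_loss, and compute_entropy_loss so that padded timesteps do not contribute to the gradients. Those helpers are not shown on this page; the sketch below is one way they could look, assuming monobeast-style unmasked versions as the starting point, [T, B]-shaped advantages, logits with a trailing action dimension, and a pad_mask of 1.0 for valid steps and 0.0 for padding.

import torch
import torch.nn.functional as F


def compute_baseline_loss(advantages, pad_mask=None):
    # 0.5 * squared error between the V-trace value targets and the baseline,
    # with padded steps zeroed out before summing.
    loss = advantages ** 2
    if pad_mask is not None:
        loss = loss * pad_mask
    return 0.5 * torch.sum(loss)


def compute_entropy_loss(logits, pad_mask=None):
    # Sum of negative policy entropies; minimizing it encourages exploration.
    policy = F.softmax(logits, dim=-1)
    log_policy = F.log_softmax(logits, dim=-1)
    neg_entropy_per_step = torch.sum(policy * log_policy, dim=-1)  # [T, B]
    if pad_mask is not None:
        neg_entropy_per_step = neg_entropy_per_step * pad_mask
    return torch.sum(neg_entropy_per_step)


def compute_policy_gradient_loss(logits, actions, advantages, pad_mask=None):
    # Cross-entropy of the actions actually taken, weighted by the (detached)
    # V-trace advantages; padded steps are zeroed out before summing.
    cross_entropy = F.nll_loss(
        F.log_softmax(torch.flatten(logits, 0, 1), dim=-1),
        target=torch.flatten(actions, 0, 1),
        reduction="none",
    ).view_as(advantages)
    loss = cross_entropy * advantages.detach()
    if pad_mask is not None:
        loss = loss * pad_mask
    return torch.sum(loss)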
Example #2
def learn(
        flags,
        actor_model,
        model,
        batch,
        initial_agent_state,
        optimizer,
        scheduler,
        lock=threading.Lock(),  # noqa: B008
):
    """Performs a learning (optimization) step."""
    with lock:
        learner_outputs, unused_state = model(batch, initial_agent_state)

        # Take final value function slice for bootstrapping.
        bootstrap_value = learner_outputs["baseline"][-1]

        # Move from obs[t] -> action[t] to action[t] -> obs[t].
        batch = {key: tensor[1:] for key, tensor in batch.items()}
        learner_outputs = {
            key: tensor[:-1]
            for key, tensor in learner_outputs.items()
        }

        rewards = batch["reward"]
        if flags.reward_clipping == "abs_one":
            clipped_rewards = torch.clamp(rewards, -1, 1)
        elif flags.reward_clipping == "none":
            clipped_rewards = rewards

        discounts = (~batch["done"]).float() * flags.discounting

        vtrace_returns = vtrace.from_logits(
            behavior_policy_logits=batch["policy_logits"],
            target_policy_logits=learner_outputs["policy_logits"],
            actions=batch["action"],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs["baseline"],
            bootstrap_value=bootstrap_value,
        )

        pg_loss = compute_policy_gradient_loss(
            learner_outputs["policy_logits"],
            batch["action"],
            vtrace_returns.pg_advantages,
        )
        baseline_loss = flags.baseline_cost * compute_baseline_loss(
            vtrace_returns.vs - learner_outputs["baseline"])
        entropy_loss = flags.entropy_cost * compute_entropy_loss(
            learner_outputs["policy_logits"])

        total_loss = pg_loss + baseline_loss + entropy_loss

        episode_returns = batch["episode_return"][batch["done"]]
        stats = {
            "episode_returns": tuple(episode_returns.cpu().numpy()),
            "mean_episode_return": torch.mean(episode_returns).item(),
            "total_loss": total_loss.item(),
            "pg_loss": pg_loss.item(),
            "baseline_loss": baseline_loss.item(),
            "entropy_loss": entropy_loss.item(),
        }

        optimizer.zero_grad()
        total_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), flags.grad_norm_clipping)
        optimizer.step()
        scheduler.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
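In all of these examples, batch["policy_logits"] were recorded by the actors when the trajectory was generated (the behavior policy), while learner_outputs["policy_logits"] come from the learner's current parameters (the target policy); V-trace uses the ratio between the two to correct for that lag, which is why the learner side is passed as target_policy_logits. The vtrace module itself is not shown on this page; as a rough, illustrative sketch (function names are made up, and the clipping that produces vs and pg_advantages is omitted), the log importance ratios a from_logits wrapper works with could be computed like this:

import torch
import torch.nn.functional as F


def action_log_probs(policy_logits, actions):
    # log pi(a_t | s_t) of the actions actually taken, shape [T, B].
    return -F.nll_loss(
        F.log_softmax(torch.flatten(policy_logits, 0, 1), dim=-1),
        target=torch.flatten(actions, 0, 1),
        reduction="none",
    ).view_as(actions)


def compute_log_rhos(behavior_policy_logits, target_policy_logits, actions):
    # log(pi_target(a|s) / pi_behavior(a|s)): the importance ratios that V-trace
    # clips and uses to weight the value targets and policy-gradient advantages.
    target_log_probs = action_log_probs(target_policy_logits, actions)
    behavior_log_probs = action_log_probs(behavior_policy_logits, actions)
    return target_log_probs - behavior_log_probs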
Example #3
def learn(
        flags,
        actor_model,
        model,
        batch,
        initial_agent_state,
        optimizer,
        scheduler,
        lock=threading.Lock(),  # noqa: B008
):
    """Performs a learning (optimization) step."""
    with lock:
        """
        put a lock on the central learner,
        send the trajectories to it.
        Update the parameters of the central learner,
        copy the parameters of the central learner back to the actors
        """

        # TODO: Chop up batch into smaller pieces to run through TXL one at a time (caching previous as memory)
        # TODO: Change batch function to look for trajectories of similar lengths
        # TODO: Add in adaptive attention (and think of how things change, e.g. with no memory)
        # print({key: batch[key].shape for key in batch})
        mems, mem_padding = None, None
        for i in range(0, flags.unroll_length + 1, flags.chunk_size):
            mini_batch = {
                key: batch[key][i:i + flags.chunk_size]
                for key in batch if key != 'len_traj'
            }
            # Note that initial_agent_state isn't used by the transformer (I think this
            # is the recurrent hidden state); this will need to change to use an LSTM here.

            # TODO: need to change batch -> mini_batch (the batch name gets overwritten)
            tmp_mask = torch.zeros_like(mini_batch["done"]).bool()

            learner_outputs, unused_state, mems, mem_padding, ind_first_done = model(
                mini_batch,
                initial_agent_state,
                mems=mems,
                mem_padding=mem_padding)
            # Here mem_padding is the same as the batch padding for this iteration,
            # so it can be used to mask the loss.

            #if mini_batch["done"].any().item():
            #    print('Indfirstdone: ',ind_first_done)
            #    print('miniBATCH DONE: ', mini_batch["done"])
            #    print('Mem padding: ', mem_padding)

            # Take final value function slice for bootstrapping.
            # this is the final value from this trajectory
            if ind_first_done is not None:
                # B dimensional tensor
                bootstrap_value = learner_outputs["baseline"][
                    ind_first_done, range(flags.batch_size)]
            else:
                bootstrap_value = learner_outputs["baseline"][-1]

            # Move from obs[t] -> action[t] to action[t] -> obs[t].
            mini_batch = {
                key: tensor[1:]
                for key, tensor in mini_batch.items()
            }
            learner_outputs = {
                key: tensor[:-1]
                for key, tensor in learner_outputs.items()
            }

            # learner_outputs is used to predict mini_batch, since mini_batch is now one step ahead of learner_outputs.

            rewards = mini_batch["reward"]
            if flags.reward_clipping == "abs_one":
                clipped_rewards = torch.clamp(rewards, -1, 1)
            elif flags.reward_clipping == "none":
                clipped_rewards = rewards

            discounts = (~mini_batch["done"]).float() * flags.discounting

            vtrace_returns = vtrace.from_logits(
                behavior_policy_logits=mini_batch["policy_logits"],
                # The learner's current policy is the V-trace target policy; the
                # actors' logits stored in mini_batch are the behavior policy.
                target_policy_logits=learner_outputs["policy_logits"],
                actions=mini_batch["action"],
                discounts=discounts,
                rewards=clipped_rewards,
                values=learner_outputs["baseline"],
                bootstrap_value=bootstrap_value,
            )

            # TODO Next Step: the losses also have to be computed with the padding, think on a structure of mask
            #                   to do this efficiently
            # Advantages are [rollout_len, batch_size]

            # Mask out padded positions (from mem_padding) so they do not contribute
            # to the policy-gradient, baseline, and entropy losses.
            pad_mask = None
            if mem_padding is not None:
                pad_mask = (~mem_padding.squeeze(0)[1:]).float()

            pg_loss = compute_policy_gradient_loss(
                learner_outputs["policy_logits"], mini_batch["action"],
                vtrace_returns.pg_advantages, pad_mask)
            baseline_loss = flags.baseline_cost * compute_baseline_loss(
                vtrace_returns.vs - learner_outputs["baseline"], pad_mask)
            entropy_loss = flags.entropy_cost * compute_entropy_loss(
                learner_outputs["policy_logits"], pad_mask)

            total_loss = pg_loss + baseline_loss + entropy_loss

            # tmp_mask (defined above) marks, for each batch column, the first
            # timestep at which its episode finished.
            if ind_first_done is not None:
                rows_to_use = []
                cols_to_use = []
                for col, val in enumerate(ind_first_done):
                    if val != -1:
                        rows_to_use.append(val)
                        cols_to_use.append(col)

                tmp_mask[
                    rows_to_use,
                    cols_to_use] = True  # NOT RIGHT FOR COLUMNS THAT DIDN'T FINISH
                # if mini_batch["done"].any().item():
                #     print('TMP MASK: ', tmp_mask)
                #     print('BATCH DONE: ', mini_batch["done"])
                #     print('shape1: {}, shape2: {}'.format(tmp_mask.shape, mini_batch['done'].shape))
            # Apply the same [1:] time shift as mini_batch so the shapes match,
            # even when no episode finished in this chunk.
            tmp_mask = tmp_mask[1:]

            #episode_returns = mini_batch["episode_return"][mini_batch["done"]]
            episode_returns = mini_batch["episode_return"][tmp_mask]

            stats = {
                "episode_returns": tuple(episode_returns.cpu().numpy()),
                "mean_episode_return": torch.mean(episode_returns).item(),
                "total_loss": total_loss.item(),
                "pg_loss": pg_loss.item(),
                "baseline_loss": baseline_loss.item(),
                "entropy_loss": entropy_loss.item(),
            }

            optimizer.zero_grad()
            total_loss.backward()
            if flags.fp16:
                optimizer.clip_master_grads(flags.grad_norm_clipping)
            else:
                nn.utils.clip_grad_norm_(model.parameters(),
                                         flags.grad_norm_clipping)
            optimizer.step()
            # The scheduler is stepped inside the lock in batch_and_learn itself.
            # scheduler.step()

        actor_model.load_state_dict(model.state_dict())
        return stats
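For reference, the chunked variant in Example #3 walks the [unroll_length + 1, batch_size] rollout through the transformer chunk_size steps at a time, carrying mems and mem_padding across calls. The standalone snippet below (with made-up sizes and only a few of the rollout keys) simply illustrates how that slicing behaves, including the short final chunk:

import torch

unroll_length, batch_size, chunk_size = 80, 4, 20  # illustrative values only

# A fake time-major rollout: [T + 1, B, ...] tensors, as produced by the actors.
batch = {
    "frame": torch.zeros(unroll_length + 1, batch_size, 4, 84, 84),
    "reward": torch.zeros(unroll_length + 1, batch_size),
    "done": torch.zeros(unroll_length + 1, batch_size, dtype=torch.bool),
}

for i in range(0, unroll_length + 1, chunk_size):
    mini_batch = {key: t[i:i + chunk_size] for key, t in batch.items()}
    # With these sizes: chunks of 20, 20, 20, 20 and a final chunk of 1 timestep.
    print(i, {key: tuple(t.shape) for key, t in mini_batch.items()})

Note that the final chunk can be much shorter than chunk_size (a single timestep with these sizes), and the [1:] shift inside learn shrinks it further, which is worth keeping in mind when choosing unroll_length and chunk_size.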