Esempio n. 1
0
    def forward(self, x, memory, xbatch, mbatch):
        """Update the memory slots with one transformer block over memory + inputs.

        Parameters
        ----------
        x : torch.Tensor
            Input vectors; they are first passed through ``self.input_proj``.
        memory : torch.Tensor
            Current memory slots.
        xbatch, mbatch : torch.Tensor
            Batch index of every row of ``x`` and of ``memory`` respectively;
            used to build the attention edges with ``gnns.utils.get_ei_from``.

        Returns
        -------
        output, memory, log_time
            ``output`` is currently the very same tensor as the updated
            ``memory``; ``log_time`` collects the timing of every stage.
        """
        projected = self.input_proj(x)

        # Self-attention runs over the memory slots stacked on top of the
        # projected inputs, together with the matching batch indices.
        nodes = torch.cat([memory, projected], 0)
        node_batch = torch.cat([mbatch, xbatch], 0)

        t_edge_compute, (edges, edge_details) = timer(gnns.utils.get_ei_from)(mbatch, xbatch)

        # Transformer block: attention -> residual + norm -> MLP -> residual + norm.
        # NOTE(review): the residual adds the attention output to ``memory``,
        # so ``self.mhsa`` presumably returns one row per memory slot — confirm.
        t_mhsa, (attended, mhsa_details) = timer(self.mhsa)(nodes, node_batch, edges)
        t_norm1, hidden = timer(self.norm1)(memory + attended)
        t_mlp, ff_out = timer(self.mlp)(hidden)
        t_norm2, candidate = timer(self.norm2)(hidden + ff_out)

        # Gated update: a forget gate scales the old memory, an input gate
        # scales the candidate (this mechanism may be refined).
        gate_logits = self.proj(torch.cat([memory, candidate], -1))
        forget_logits, input_logits = gate_logits.chunk(2, -1)
        new_memory = memory * torch.sigmoid(forget_logits) + candidate * torch.sigmoid(input_logits)

        log_time = {
            't_edge_compute': t_edge_compute,
            'details_edge_compute': edge_details,
            't_mhsa': t_mhsa,
            'details_mhsa': mhsa_details,
            't_norm1': t_norm1,
            't_mlp': t_mlp,
            't_norm2': t_norm2,
        }
        # For now the output is simply the updated memory.
        return new_memory, new_memory, log_time
Esempio n. 2
0
    def forward(self, obs, memory, obs_batch, m_batch, instr_embedding=None):
        """Run the slot-memory model, pool its output and apply actor/critic heads.

        Parameters
        ----------
        obs : torch.Tensor
            Observation rows, forwarded unchanged to ``self.slot_memory_model``.
        memory : torch.Tensor
            Memory slots, forwarded to ``self.slot_memory_model``.
        obs_batch : torch.Tensor
            Batch index of every row of ``obs``.
        m_batch : torch.Tensor
            Batch index of every row of ``memory``; also used to pool the model
            output into one embedding per batch index with ``scatter_sum``.
        instr_embedding : optional
            Currently unused (instruction conditioning is commented out below);
            kept for interface compatibility.

        Returns
        -------
        dict
            ``dist`` (``Categorical`` over the actor logits), ``value`` (critic
            output), ``memory`` (updated memory), ``extra_predictions``
            (auxiliary-head outputs, empty unless ``self.aux_info`` is truthy)
            and ``log_time`` (timing breakdown).
        """
        # if self.use_instr and instr_embedding is None:
        #     instr_embedding = self._get_instr_embedding(obs.instr)
        # if self.use_instr and self.lang_model == "attgru":
        #     # outputs: B x L x D
        #     # memory: B x M
        #     mask = (obs.instr != 0).float()
        #     # The mask tensor has the same length as obs.instr, and
        #     # thus can be both shorter and longer than instr_embedding.
        #     # It can be longer if instr_embedding is computed
        #     # for a subbatch of obs.instr.
        #     # It can be shorter if obs.instr is a subbatch of
        #     # the batch that instr_embeddings was computed for.
        #     # Here, we make sure that mask and instr_embeddings
        #     # have equal length along dimension 1.
        #     mask = mask[:, :instr_embedding.shape[1]]
        #     instr_embedding = instr_embedding[:, :mask.shape[1]]
        #
        #     keys = self.memory2key(memory)
        #     pre_softmax = (keys[:, None, :] * instr_embedding).sum(2) + 1000 * mask
        #     attention = F.softmax(pre_softmax, dim=1)
        #     instr_embedding = (instr_embedding * attention[:, :, None]).sum(1)

        t_slot_memory_model, (output, memory, log_time_slot_memory_model) = timer(
            self.slot_memory_model)(obs, memory, obs_batch, m_batch)

        # Pool the per-slot outputs by batch index. ``.long()`` keeps the index
        # on the same device as ``m_batch``; ``.type(torch.LongTensor)`` always
        # produced a CPU tensor, which fails when the model runs on a GPU.
        t_scatter_sum, embedding = timer(scatter_sum)(output, m_batch.long())

        # if self.use_instr and not "filmcnn" in self.arch:
        #     embedding = torch.cat((embedding, instr_embedding), dim=1)

        if hasattr(self, 'aux_info') and self.aux_info:
            extra_predictions = {info: self.extra_heads[info](embedding) for info in self.extra_heads}
        else:
            extra_predictions = dict()

        t_actor, x = timer(self.actor)(embedding)
        # The action logits live on the last axis; the implicit-dim form of
        # log_softmax is deprecated.
        dist = Categorical(logits=F.log_softmax(x, dim=-1))

        t_critic, value = timer(self.critic)(embedding)

        log_time = {'t_slot_memory_model': t_slot_memory_model,
                    'details_memory_model': log_time_slot_memory_model, 't_scatter_sum': t_scatter_sum,
                    't_actor': t_actor, 't_critic': t_critic}

        return {'dist': dist, 'value': value, 'memory': memory, 'extra_predictions': extra_predictions,
                'log_time': log_time}
Esempio n. 3
0
    def forward(self, x, batch, ei):
        """Multi-head attention restricted to the edges of a graph.

        Parameters
        ----------
        x : torch.Tensor
            Node features; ``self.proj`` maps them to queries, keys and values
            of ``self.Fqk``, ``self.Fqk`` and ``self.Fv`` features.
        batch : torch.Tensor
            Batch index of every node. Not used by the computation; kept for
            interface compatibility.
        ei : pair of index tensors
            ``(src, dest)`` edge index: along each edge the query of node
            ``src`` is matched with the key/value of node ``dest``.

        Returns
        -------
        out : torch.Tensor
            The attention-weighted values aggregated over ``src`` with
            ``scatter_sum``, reshaped to ``H * (self.Fv // H)`` features per
            row (presumably one row per source node — depends on how
            ``scatter_sum`` sizes its output).
        log_time : dict
            Timing of the score product and of the two scatter operations.
        """
        # TODO: implement for source and target tensors ?
        # that way we get rid of excess computation
        src, dest = ei

        H = self.nheads
        Fh = self.Fqk // H
        Fhv = self.Fv // H

        scaling = float(Fh) ** -0.5
        q, k, v = self.proj(x).split([self.Fqk, self.Fqk, self.Fv], dim=-1)

        q = q * scaling
        q = q.reshape(-1, H, Fh)
        k = k.reshape(-1, H, Fh)
        v = v.reshape(-1, H, Fhv)

        # Per-edge queries (from src) and keys/values (from dest)
        qs, ks, vs = q[src], k[dest], v[dest]
        # Dot product per edge and head: (E, H, 1, Fh) @ (E, H, Fh, 1) -> (E, H, 1, 1).
        # NOTE: time.time() does not synchronize CUDA, so on a GPU this mostly
        # measures kernel launch time.
        t0 = time.time()
        aw = qs.view(-1, H, 1, Fh) @ ks.view(-1, H, Fh, 1)
        t_mhsa_aw = time.time() - t0
        # (E, H). Unlike a bare squeeze(), this keeps the edge axis when there
        # is a single edge and the head axis when there is a single head.
        aw = aw.reshape(-1, H)
        # Softmax reduction, grouped by the src index of each edge
        t_mhsa_scatter_softmax, aw = timer(scatter_softmax)(aw, src)
        out = aw.view([-1, H, 1]) * vs
        # Sum the weighted values, grouped by src
        t_mhsa_scatter_sum, out = timer(scatter_sum)(out, src)
        out = out.reshape([-1, H * Fhv])

        log_time = {'t_mhsa_aw': t_mhsa_aw, 't_mhsa_scatter_softmax': t_mhsa_scatter_softmax,
                    't_mhsa_scatter_sum': t_mhsa_scatter_sum}

        return out, log_time
Esempio n. 4
0
    def update_parameters(self):
        """Collect a rollout and optimise the actor-critic on it with PPO.

        Runs ``self.epochs`` passes over the collected experiences. Each pass
        iterates over the batches returned by
        ``self._get_batches_starting_indexes()``; every batch is unrolled for
        ``self.recurrence`` consecutive frames from each starting index,
        threading the model memory through the unroll, and one optimiser step
        (with gradient-norm clipping to ``self.max_grad_norm``) is taken per
        batch.

        Returns
        -------
        logs : dict
            The logs of ``collect_experiences`` plus ``t_collect``, the means of
            ``entropy``, ``value``, ``policy_loss``, ``value_loss``,
            ``grad_norm`` and ``loss`` over the batches of the last epoch, and
            the timing entries ``t_train``, ``t_details_train_forward_mordel``
            and ``t_backward``.
        """
        # Collect experiences
        t_collect, (exps, logs) = timer(self.collect_experiences)()
        logs['t_collect'] = t_collect
        # exps is a DictList with the keys 'obs', 'memory', 'mask', 'action',
        # 'value', 'reward', 'advantage', 'returnn', 'log_prob', plus whatever
        # aux_info_collector.end_collection adds when aux_info is set.
        # exps.obs[0].image holds the observation rows and exps.obs[1].image
        # the index each row is matched against (compared with frame indices
        # below).
        # exps.memory / exps.mask hold one row per (frame, memory slot): the
        # rows of frame f are f * M ... f * M + M - 1, with
        # M = self.acmodel.memory_dim[0] (assumed equal to the memory_size[0]
        # used by collect_experiences — TODO confirm).
        t0_train = time.time()
        t_details_train_forward_model = {}
        t_train_backward = 0

        mem_slots = self.acmodel.memory_dim[0]
        image_dim = self.acmodel.image_dim
        obs_rows = exps.obs[0].image
        all_obs_inds = exps.obs[1].image

        for _ in range(self.epochs):
            # Initialize log values (reset every epoch, so the returned logs
            # describe the last epoch only)
            log_entropies = []
            log_values = []
            log_policy_losses = []
            log_value_losses = []
            log_grad_norms = []
            log_losses = []

            # For each epoch, we create int(total_frames / batch_size + 1)
            # batches, each of size batch_size (except maybe the last one).
            # Each batch is divided into sub-batches of size recurrence (frames
            # are contiguous in a sub-batch), but the position of each
            # sub-batch in a batch and the position of each batch in the whole
            # list of frames is random thanks to
            # self._get_batches_starting_indexes().
            for inds in self._get_batches_starting_indexes():
                # inds: numpy array of the first frame index of every sub-batch
                batch_entropy = 0
                batch_value = 0
                batch_policy_loss = 0
                batch_value_loss = 0
                batch_loss = 0

                # Memory rows of the first frame of every sub-batch
                inds_mem = [mem_slots * int(start) + slot
                            for start in inds for slot in range(mem_slots)]
                inds_mem_arr = numpy.array(inds_mem)
                memory = exps.memory[inds_mem]

                sb = DictList()

                for i in range(self.recurrence):
                    frames = inds + i

                    # Extract scene level quantities:
                    sb.action = exps.action[frames]
                    sb.log_prob = exps.log_prob[frames]
                    sb.advantage = exps.advantage[frames]
                    sb.value = exps.value[frames]
                    sb.returnn = exps.returnn[frames]

                    # NOTE(review): m_batch (and sb.obs_batch below) carry the
                    # absolute frame indices inds + i rather than
                    # 0 .. len(inds) - 1. If the model pools with a scatter
                    # sized by the largest index, the pooled batch has
                    # max(frames) + 1 rows instead of len(inds) — confirm this
                    # matches the indexing of collect_experiences' m_batch.
                    m_batch = torch.IntTensor([
                        j + i for j in inds for _ in range(mem_slots)
                    ])

                    # Extract the observation rows of every frame of the
                    # sub-batch (in sub-batch order) and their batch indices;
                    # concatenate once instead of re-concatenating per frame.
                    obs_parts = [torch.zeros((0, image_dim), device=obs_rows.device)]
                    obs_batch_parts = [torch.zeros(0, dtype=torch.int, device=all_obs_inds.device)]
                    for j in frames:
                        idx_j = all_obs_inds == j
                        obs_parts.append(obs_rows[idx_j])
                        obs_batch_parts.append(all_obs_inds[idx_j])
                    sb.obs = torch.cat(obs_parts, dim=0)
                    sb.obs_batch = torch.cat(obs_batch_parts, dim=0)

                    # TODO rename obs[0] and obs[1] into obs.obs and obs.obs_batch

                    # Mask rows of the current frame of every sub-batch
                    sb.mask = exps.mask[list(inds_mem_arr + mem_slots * i)].flatten()

                    # Compute loss
                    model_results = self.acmodel(sb.obs,
                                                 sb.mask.unsqueeze(1) * memory,
                                                 sb.obs_batch, m_batch)
                    dist = model_results['dist']
                    value = model_results['value']
                    memory = model_results['memory']

                    entropy = dist.entropy().mean()

                    t_details_train_forward_model = cumulate_value(
                        t_details_train_forward_model,
                        model_results['log_time'])

                    # Clipped PPO policy objective
                    ratio = torch.exp(dist.log_prob(sb.action) - sb.log_prob)
                    surr1 = ratio * sb.advantage
                    surr2 = torch.clamp(ratio, 1.0 - self.clip_eps,
                                        1.0 + self.clip_eps) * sb.advantage
                    policy_loss = -torch.min(surr1, surr2).mean()

                    # Clipped value loss
                    value_clipped = sb.value + torch.clamp(
                        value - sb.value, -self.clip_eps, self.clip_eps)
                    surr1 = (value - sb.returnn).pow(2)
                    surr2 = (value_clipped - sb.returnn).pow(2)
                    value_loss = torch.max(surr1, surr2).mean()

                    loss = policy_loss - self.entropy_coef * entropy + self.value_loss_coef * value_loss

                    # Update batch values
                    batch_entropy += entropy.item()
                    batch_value += value.mean().item()
                    batch_policy_loss += policy_loss.item()
                    batch_value_loss += value_loss.item()
                    batch_loss += loss

                    # Store the (detached) memory as the starting memory of the
                    # next frame, used by later batches/epochs
                    if i < self.recurrence - 1:
                        exps.memory[list(inds_mem_arr + mem_slots * (i + 1))] = memory.detach()

                # Update batch values
                batch_entropy /= self.recurrence
                batch_value /= self.recurrence
                batch_policy_loss /= self.recurrence
                batch_value_loss /= self.recurrence
                batch_loss /= self.recurrence

                # Update actor-critic
                t0_train_backward = time.time()
                self.optimizer.zero_grad()
                batch_loss.backward()
                # Gradient norm is measured before clipping
                grad_norm = sum(
                    p.grad.data.norm(2)**2 for p in self.acmodel.parameters()
                    if p.grad is not None)**0.5
                torch.nn.utils.clip_grad_norm_(self.acmodel.parameters(),
                                               self.max_grad_norm)
                self.optimizer.step()
                t_train_backward += time.time() - t0_train_backward

                # Update log values
                log_entropies.append(batch_entropy)
                log_values.append(batch_value)
                log_policy_losses.append(batch_policy_loss)
                log_value_losses.append(batch_value_loss)
                log_grad_norms.append(grad_norm.item())
                log_losses.append(batch_loss.item())

        t_train = time.time() - t0_train

        # Log some values
        logs["entropy"] = numpy.mean(log_entropies)
        logs["value"] = numpy.mean(log_values)
        logs["policy_loss"] = numpy.mean(log_policy_losses)
        logs["value_loss"] = numpy.mean(log_value_losses)
        logs["grad_norm"] = numpy.mean(log_grad_norms)
        logs["loss"] = numpy.mean(log_losses)
        logs['t_train'] = t_train
        # The key keeps its historical spelling so existing log consumers work.
        logs['t_details_train_forward_mordel'] = t_details_train_forward_model
        logs['t_backward'] = t_train_backward
        return logs
Esempio n. 5
0
    def collect_experiences(self):
        """Collects rollouts and computes advantages.

        Runs several environments concurrently. The next actions are computed
        in a batch mode for all environments at the same time. The rollouts
        and advantages from all environments are concatenated together.

        Returns
        -------
        exps : DictList
            Contains actions, rewards, advantages etc as attributes.
            Each attribute, e.g. `exps.reward` has a shape
            (self.num_frames_per_proc * num_envs, ...). k-th block
            of consecutive `self.num_frames_per_proc` frames contains
            data obtained from the k-th environment. Be careful not to mix
            data from different environments! `exps.memory` and `exps.mask`
            additionally have one row per memory slot of every frame.
        logs : dict
            Useful stats about the rollout (returns, episode lengths, frame
            and episode counts) and timing information
            (`t_collect_forward`, `t_details_forward_model`,
            `t_forward_process`, `t_forward_step`, `t_collect_backward`,
            `t_collect_organize`).

        """
        # Number of memory slots per process; the flat masks/memories hold
        # mem_slots consecutive entries per process.
        mem_slots = self.acmodel.memory_size[0]

        t0 = time.time()
        t_forward_process = 0
        t_forward_step = 0
        t_details_forward_model = {}
        for i in range(self.num_frames_per_proc):
            # Do one agent-environment interaction
            tt_process, preprocessed_obs = timer(self.preprocess_obss)(
                self.obs, device=self.device)
            t_forward_process += tt_process
            obs_flat = preprocessed_obs.image[0]
            obs_batch = preprocessed_obs.image[1]

            with torch.no_grad():
                model_results = self.acmodel(
                    obs_flat,
                    self.mask.unsqueeze(1) * self.memory, obs_batch,
                    self.m_batch)
                dist = model_results['dist']
                value = model_results['value'].flatten()
                memory = model_results['memory']
                extra_predictions = model_results['extra_predictions']

            t_details_forward_model = cumulate_value(t_details_forward_model,
                                                     model_results['log_time'])

            action = dist.sample()

            tt_step, (obs, reward, done,
                      env_info) = timer(self.env.step)(action.cpu().numpy())
            t_forward_step += tt_step

            if self.aux_info:
                env_info = self.aux_info_collector.process(env_info)

            # Update experiences values
            self.obss[i] = self.obs
            self.obs = obs

            self.memories[i] = self.memory
            self.memory = memory

            self.masks[i] = self.mask
            done_as_int = torch.tensor(done,
                                       device=self.device,
                                       dtype=torch.float).unsqueeze(1)
            # Repeat each process's mask for all of its memory slots:
            # P x 1 -> P x M -> (P * M)
            self.mask = 1 - done_as_int.expand(
                done_as_int.shape[0], mem_slots).flatten()

            self.actions[i] = action
            self.values[i] = value
            if self.reshape_reward is not None:
                self.rewards[i] = torch.tensor([
                    self.reshape_reward(obs_, action_, reward_, done_)
                    for obs_, action_, reward_, done_ in zip(
                        obs, action, reward, done)
                ], device=self.device)
            else:
                self.rewards[i] = torch.tensor(reward, device=self.device)
            self.log_probs[i] = dist.log_prob(action)

            if self.aux_info:
                self.aux_info_collector.fill_dictionaries(
                    i, env_info, extra_predictions)

            # Update log values
            self.log_episode_return += torch.tensor(reward,
                                                    device=self.device,
                                                    dtype=torch.float)
            self.log_episode_reshaped_return += self.rewards[i]
            self.log_episode_num_frames += torch.ones(self.num_procs,
                                                      device=self.device)

            # `proc` (not `i`) so the frame index of the outer loop is not shadowed
            for proc, done_ in enumerate(done):
                if done_:
                    self.log_done_counter += 1
                    self.log_return.append(self.log_episode_return[proc].item())
                    self.log_reshaped_return.append(
                        self.log_episode_reshaped_return[proc].item())
                    self.log_num_frames.append(
                        self.log_episode_num_frames[proc].item())

            # Per-process mask = the entry of each process's first memory slot
            # (index proc * mem_slots). Slicing keeps it on self.mask's device;
            # building it with torch.tensor([...]) always put it on the CPU.
            episode_mask = self.mask[::mem_slots]
            self.log_episode_return *= episode_mask
            self.log_episode_reshaped_return *= episode_mask
            self.log_episode_num_frames *= episode_mask

        t_collect_forward = time.time() - t0

        # Add advantage and return to experiences
        t0 = time.time()
        preprocessed_obs = self.preprocess_obss(self.obs, device=self.device)
        with torch.no_grad():
            # TODO: Add split obs_flat, obs_batch in preprocess_obss ?
            obs_flat = preprocessed_obs.image[0]
            obs_batch = preprocessed_obs.image[1]
            next_value = self.acmodel(obs_flat,
                                      self.mask.unsqueeze(1) * self.memory,
                                      obs_batch,
                                      self.m_batch)['value'].flatten()

        # Generalized advantage estimation, backwards in time. For the last
        # frame the bootstrap value/mask come from the current state.
        last = self.num_frames_per_proc - 1
        for i in reversed(range(self.num_frames_per_proc)):
            is_last = i == last
            next_mask = (self.mask if is_last else self.masks[i + 1])[::mem_slots]
            next_value = next_value if is_last else self.values[i + 1]
            next_advantage = 0 if is_last else self.advantages[i + 1]
            delta = self.rewards[
                i] + self.discount * next_value * next_mask - self.values[i]
            self.advantages[
                i] = delta + self.discount * self.gae_lambda * next_advantage * next_mask

        t_collect_backward = time.time() - t0

        # Flatten the data correctly, making sure that
        # each episode's data is a continuous chunk
        t0 = time.time()
        exps = DictList()
        exps.obs = [
            self.obss[i][j] for j in range(self.num_procs)
            for i in range(self.num_frames_per_proc)
        ]

        # In comments below T is self.num_frames_per_proc, P is self.num_procs,
        # D is the dimensionality and M the number of memory slots

        # T x (P * M) x D -> T x P x M x D -> P x T x M x D -> (P * T * M) x D
        exps.memory = self.memories.reshape(
            (self.num_frames_per_proc, self.num_procs,
             mem_slots,
             self.acmodel.memory_size[1])).transpose(0, 1).reshape(
                 -1, *self.memories.shape[2:])

        # T x (P * M) -> T x P x M -> P x T x M -> (P * T * M) x 1
        exps.mask = self.masks.reshape(self.num_frames_per_proc,
                                       self.num_procs,
                                       mem_slots).transpose(
                                           0, 1).reshape(-1).unsqueeze(1)

        # for all tensors below, T x P -> P x T -> P * T
        exps.action = self.actions.transpose(0, 1).reshape(-1)
        exps.value = self.values.transpose(0, 1).reshape(-1)
        exps.reward = self.rewards.transpose(0, 1).reshape(-1)
        exps.advantage = self.advantages.transpose(0, 1).reshape(-1)
        exps.returnn = exps.value + exps.advantage
        exps.log_prob = self.log_probs.transpose(0, 1).reshape(-1)

        t_organize_exp = time.time() - t0
        if self.aux_info:
            exps = self.aux_info_collector.end_collection(exps)

        # Preprocess experiences
        exps.obs = self.preprocess_obss(exps.obs, device=self.device)

        # Log some values
        keep = max(self.log_done_counter, self.num_procs)

        log = {
            "return_per_episode": self.log_return[-keep:],
            "reshaped_return_per_episode": self.log_reshaped_return[-keep:],
            "num_frames_per_episode": self.log_num_frames[-keep:],
            "num_frames": self.num_frames,
            "episodes_done": self.log_done_counter,
            "t_collect_forward": t_collect_forward,
            "t_details_forward_model": t_details_forward_model,
            "t_forward_process": t_forward_process,
            "t_forward_step": t_forward_step,
            "t_collect_backward": t_collect_backward,
            "t_collect_organize": t_organize_exp
        }

        self.log_done_counter = 0
        self.log_return = self.log_return[-self.num_procs:]
        self.log_reshaped_return = self.log_reshaped_return[-self.num_procs:]
        self.log_num_frames = self.log_num_frames[-self.num_procs:]

        return exps, log