Example #1
    def get_z(self, data, return_mask=False):
        patch_len = self.cpc.module.cfg.patch_len if self.parallel else self.cpc.cfg.patch_len

        if isinstance(data, dict):
            x = data['primary'].to(self.device)  # (N, max_L, H_enc)
            prot_len = data['protein_length'].to(self.device)
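            # one mask entry per full patch; patches beyond each protein's
            # true length are zeroed out by the padding below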
            num_patches = (prot_len / patch_len).floor()
            mask = torch.tensor(
                pad_sequences([np.ones(int(i)) for i in num_patches]))
        elif isinstance(data, torch.Tensor):
            x = data.to(self.device)
            # not masking anything out in this case
            mask = torch.ones((x.shape[0], x.shape[1] // patch_len))
        else:
            raise TypeError(
                "Input to CPCProt model must be a torch.Tensor or the dictionary returned by training dataloaders."
            )

        z = self.cpc(x, return_early='z')
        mask = mask.to(dtype=torch.int, device=self.device)

        if return_mask:
            return (z, mask)
        else:
            return z
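
A minimal usage sketch of get_z, assuming a model object exposing this method and a loader that yields the dictionary batches described above (model and loader are placeholder names, not part of the original snippet). It shows one way the returned mask can be used to pool only the non-padded patch embeddings:

# Hypothetical downstream usage of get_z; `model` and `loader` are assumed names.
for batch in loader:
    z, mask = model.get_z(batch, return_mask=True)  # z: (N, num_patches, z_dim)
    # zero out padded patches, then mean-pool over the patch dimension
    pooled = (z * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)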
Example #2
    def collate_fn(self, batch: List[Tuple[Any, ...]]) -> Dict[str, torch.Tensor]:
        """ Define a collate_fn to convert the variable length sequences into
            a batch of torch tensors. token ids and mask should be padded with
            zeros. Labels for classification should be padded with -1.
            This takes in a list of outputs from the dataset's __getitem__
            method. You can use the `pad_sequences` helper function to pad
            a list of numpy arrays.
        """
        input_ids, input_mask, scl_label = tuple(zip(*batch))
        input_ids = torch.from_numpy(pad_sequences(input_ids, 0))
        input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
        scl_label = torch.LongTensor(scl_label)

        output = {'input_ids': input_ids,
                  'input_mask': input_mask,
                  'targets': scl_label}

        return output
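
A short sketch of how this collate_fn would typically be wired into a torch DataLoader (dataset and batch_size are assumptions for illustration, not part of the original snippet):

from torch.utils.data import DataLoader

# Hypothetical wiring: __getitem__ is assumed to return
# (input_ids, input_mask, scl_label) tuples of numpy arrays / ints.
loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)
for batch in loader:
    input_ids = batch['input_ids']    # zero-padded token ids
    input_mask = batch['input_mask']  # zero-padded attention mask
    targets = batch['targets']        # LongTensor of class labels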