Example #1
    def __getitem__(self, idx):

        mol = self.mols[idx]
        NUM_ATOMS = mol.GetNumAtoms()

        # each entry of pred_vals is a dict mapping atom indices to chemical shift values;
        # pred_val is the dict for this molecule
        pred_val = self.pred_vals[idx]

        # randomly pick one of the molecule's conformers (0 .. GetNumConformers() - 1)
        conf_idx = np.random.randint(mol.GetNumConformers())

        # f_vect is a 2-D tensor of per-atom features with shape (num_atoms, num_features);
        # atom_types identifies the type of each atom
        f_vect, atom_types = atom_features.feat_tensor_atom(
            mol, conf_idx=conf_idx, **self.feat_vert_args)

        edge_index, edge_attr = molecule_features.get_edge_attr_and_ind(mol)

        # per-atom targets and a mask marking which atoms have an observed shift;
        # sized by self.PRED_N so multiple prediction channels index correctly
        target = np.zeros((NUM_ATOMS, self.PRED_N), dtype=np.float32)
        mask = np.zeros((NUM_ATOMS, self.PRED_N), dtype=np.float32)
        for pn in range(self.PRED_N):
            for k, v in pred_val[pn].items():
                target[int(k), pn] = v
                mask[int(k), pn] = 1.0

        mask = torch.FloatTensor(mask).flatten()
        target = torch.FloatTensor(target).flatten()

        v = (f_vect, atom_types, edge_index, edge_attr, mask, target)

        return v
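
Because this __getitem__ returns variable-size per-atom tensors (no padding), the samples cannot be stacked by PyTorch's default collate. A minimal consumption sketch, assuming PyTorch Geometric is the downstream graph library and that f_vect, edge_index, and edge_attr are already torch tensors; the wrapper name to_pyg and the loader setup are illustrative, not part of the source:

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

def to_pyg(sample):
    # sample is the 6-tuple returned by __getitem__ above
    f_vect, atom_types, edge_index, edge_attr, mask, target = sample
    return Data(x=f_vect.float(),
                edge_index=edge_index.long(),
                edge_attr=edge_attr.float(),
                y=target,
                mask=mask)

# e.g. loader = DataLoader([to_pyg(dataset[i]) for i in range(len(dataset))],
#                          batch_size=32, shuffle=True)
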
Example #2
    def __getitem__(self, idx):

        mol = self.mols[idx]
        pred_val = self.pred_vals[idx]

        conf_idx = np.random.randint(mol.GetNumConformers())

        if self.cache_key(idx, conf_idx) in self.cache:
            return self.mask_sel(self.cache[self.cache_key(idx, conf_idx)])

        f_vect = atom_features.feat_tensor_atom(mol,
                                                conf_idx=conf_idx,
                                                **self.feat_vert_args)

        DATA_N = f_vect.shape[0]

        # zero-pad per-atom features up to MAX_N so every sample has a fixed size
        vect_feat = np.zeros((self.MAX_N, f_vect.shape[1]), dtype=np.float32)
        vect_feat[:DATA_N] = f_vect

        f_mat = molecule_features.feat_tensor_mol(mol,
                                                  conf_idx=conf_idx,
                                                  **self.feat_edge_args)

        if self.combine_mat_vect:
            MAT_CHAN = f_mat.shape[2] + vect_feat.shape[1]
        else:
            MAT_CHAN = f_mat.shape[2]

        mat_feat = np.zeros((self.MAX_N, self.MAX_N, MAT_CHAN),
                            dtype=np.float32)
        # copy the unpadded pairwise features into the top-left DATA_N x DATA_N block
        mat_feat[:DATA_N, :DATA_N, :f_mat.shape[2]] = f_mat

        if self.combine_mat_vect == 'row':
            # row-major
            for i in range(DATA_N):
                mat_feat[i, :DATA_N, f_mat.shape[2]:] = f_vect
        elif self.combine_mat_vect == 'col':
            # col-major
            for i in range(DATA_N):
                mat_feat[:DATA_N, i, f_mat.shape[2]:] = f_vect

        # pad the (channels, N, N) adjacency tensor out to MAX_N x MAX_N
        adj_nopad = molecule_features.feat_mol_adj(mol, **self.adj_args)
        adj = torch.zeros((adj_nopad.shape[0], self.MAX_N, self.MAX_N))
        adj[:, :adj_nopad.shape[1], :adj_nopad.shape[2]] = adj_nopad

        # mask marks which atoms have an observed value; vals holds the target shifts

        mask = np.zeros(self.MAX_N, dtype=np.float32)
        vals = np.zeros(self.MAX_N, dtype=np.float32)
        for k, v in pred_val.items():
            mask[k] = 1.0
            vals[k] = v

        v = (
            adj,
            vect_feat,
            mat_feat,
            vals,
            mask,
        )

        self.cache[self.cache_key(idx, conf_idx)] = v
        return self.mask_sel(v)
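
The payoff of zero-padding everything to self.MAX_N is that every sample has identical shapes, so the default collate in torch.utils.data.DataLoader can stack a batch with no custom logic. A self-contained toy illustrating that contract (the toy class and sizes are assumptions, not the source's dataset):

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

MAX_N = 64  # assumed maximum atom count used for padding

class _FixedSizeToy(Dataset):
    """Stand-in emitting the same fixed-size (adj, vect_feat, mat_feat, vals, mask) tuple."""
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        adj = torch.zeros((4, MAX_N, MAX_N))                      # padded adjacency channels
        vect_feat = np.zeros((MAX_N, 32), dtype=np.float32)       # padded per-atom features
        mat_feat = np.zeros((MAX_N, MAX_N, 4), dtype=np.float32)  # padded pairwise features
        vals = np.zeros(MAX_N, dtype=np.float32)
        mask = np.zeros(MAX_N, dtype=np.float32)
        return adj, vect_feat, mat_feat, vals, mask

loader = DataLoader(_FixedSizeToy(), batch_size=4)
adj, vect_feat, mat_feat, vals, mask = next(iter(loader))
# default collate stacks along a new batch dimension, e.g. adj.shape == (4, 4, MAX_N, MAX_N)
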
Example #3
    def __getitem__(self, idx):

        mol = self.mols[idx]
        vect_pred_val = self.pred_vec_vals[idx]
        mat_pred_val = self.pred_mat_vals[idx]

        conf_idx = np.random.randint(mol.GetNumConformers())

        if self.cache_key(idx, conf_idx) in self.cache:
            return self.mask_sel(self.cache[self.cache_key(idx, conf_idx)])
        
        f_vect = atom_features.feat_tensor_atom(mol, conf_idx=conf_idx, 
                                                **self.feat_vert_args)
                                                
        DATA_N = f_vect.shape[0]
        
        vect_feat = np.zeros((self.MAX_N, f_vect.shape[1]), dtype=np.float32)
        vect_feat[:DATA_N] = f_vect

        f_mat = molecule_features.feat_tensor_mol(mol, conf_idx=conf_idx,
                                                  **self.feat_edge_args) 

        if self.combine_mat_vect:
            MAT_CHAN = f_mat.shape[2] + vect_feat.shape[1]
        else:
            MAT_CHAN = f_mat.shape[2]
        if MAT_CHAN == 0: # Dataloader can't handle tensors with empty dimensions
            MAT_CHAN = 1
        mat_feat = np.zeros((self.MAX_N, self.MAX_N, MAT_CHAN), dtype=np.float32)
        # do the padding
        mat_feat[:DATA_N, :DATA_N, :f_mat.shape[2]] = f_mat  
        
        if self.combine_mat_vect == 'row':
            # row-major
            for i in range(DATA_N):
                mat_feat[i, :DATA_N, f_mat.shape[2]:] = f_vect
        elif self.combine_mat_vect == 'col':
            # col-major
            for i in range(DATA_N):
                mat_feat[:DATA_N, i, f_mat.shape[2]:] = f_vect

        adj_nopad = molecule_features.feat_mol_adj(mol, **self.adj_args)
        adj = torch.zeros((adj_nopad.shape[0], self.MAX_N, self.MAX_N))
        adj[:, :adj_nopad.shape[1], :adj_nopad.shape[2]] = adj_nopad
                        
        # create vect_mask and preds 
        
        vect_mask = np.zeros((self.MAX_N, self.PRED_N), dtype=np.float32)
        vect_vals = np.zeros((self.MAX_N, self.PRED_N), dtype=np.float32)
        for pn in range(self.PRED_N):
            for k, v in vect_pred_val[pn].items():
                vect_mask[int(k), pn] = 1.0
                vect_vals[int(k), pn] = v
        # create matrix mask and preds
        mat_mask = np.zeros((self.MAX_N, self.MAX_N, self.PRED_N), dtype=np.float32)
        mat_vals = np.zeros((self.MAX_N, self.MAX_N, self.PRED_N), dtype=np.float32)
        # pairwise targets are symmetric, so fill both (k1, k2) and (k2, k1)
        for pn in range(self.PRED_N):
            for (k1, k2), v in mat_pred_val[pn].items():
                mat_mask[int(k1), int(k2), pn] = 1.0
                mat_vals[int(k1), int(k2), pn] = v

                mat_mask[int(k2), int(k1), pn] = 1.0
                mat_vals[int(k2), int(k1), pn] = v
                
        # reorder adjacency from (channels, N, N) to (N, N, channels)
        adj = np.transpose(adj, axes=(1, 2, 0))
                

        v = (adj, vect_feat, mat_feat,
             vect_vals, vect_mask,
             mat_vals, mat_mask)

        self.cache[self.cache_key(idx, conf_idx)] = v

        return self.mask_sel(v)
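
Both vect_mask/vect_vals and mat_mask/mat_vals exist so that a loss is computed only over observed entries. A hedged sketch of that pattern (the function name and the choice of mean squared error are assumptions, not taken from the source):

import torch

def masked_mse(pred, vals, mask):
    # pred, vals, and mask share a shape; mask is 1.0 exactly where a target was observed
    pred = torch.as_tensor(pred)
    vals = torch.as_tensor(vals)
    mask = torch.as_tensor(mask)
    sq_err = (pred - vals) ** 2 * mask
    return sq_err.sum() / mask.sum().clamp(min=1.0)

# e.g. total = masked_mse(vect_pred, vect_vals, vect_mask) \
#            + masked_mse(mat_pred, mat_vals, mat_mask)
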
Example #4
    def __getitem__(self, idx):
        if self.frac_per_epoch < 1.0:
            # subsample: draw a random index so each epoch sees only a fraction of the data
            idx = np.random.randint(len(self.mols))

        mol = self.mols[idx]
        pred_val = self.pred_vals[idx]
        whole_record = self.whole_records[idx]

        conf_idx = 0  # always use the first conformer

        if self.cache is not None and self.cache_key(idx,
                                                     conf_idx) in self.cache:
            return self.cache[self.cache_key(idx, conf_idx)]

        # mol features
        f_mol = molecule_features.whole_molecule_features(
            whole_record, **self.mol_args)

        f_vect = atom_features.feat_tensor_atom(mol,
                                                conf_idx=conf_idx,
                                                **self.feat_vert_args)
        if self.combine_mol_vect:
            f_vect = torch.cat(
                [f_vect,
                 f_mol.reshape(1, -1).expand(f_vect.shape[0], -1)], -1)
        # process extra data arguments
        for extra_data_config in self.extra_npy_filenames:
            filename = extra_data_config['filenames'][idx]
            combine_with = extra_data_config.get('combine_with', None)
            if combine_with == 'vert':
                npy_data = np.load(filename)
                npy_data_flatter = npy_data.reshape(f_vect.shape[0], -1)
                f_vect = torch.cat(
                    [f_vect, torch.Tensor(npy_data_flatter)], dim=-1)
            elif combine_with is None:
                continue
            else:
                raise NotImplementedError(
                    f"combine_with={combine_with} is not supported yet")
        DATA_N = f_vect.shape[0]

        vect_feat = np.zeros((self.MAX_N, f_vect.shape[1]), dtype=np.float32)
        vect_feat[:DATA_N] = f_vect

        f_mat = molecule_features.feat_tensor_mol(mol,
                                                  conf_idx=conf_idx,
                                                  **self.feat_edge_args)

        if self.combine_mat_vect:
            MAT_CHAN = f_mat.shape[2] + vect_feat.shape[1]
        else:
            MAT_CHAN = f_mat.shape[2]
        if MAT_CHAN == 0:  # Dataloader can't handle tensors with empty dimensions
            MAT_CHAN = 1
        mat_feat = np.zeros((self.MAX_N, self.MAX_N, MAT_CHAN),
                            dtype=np.float32)
        # do the padding
        mat_feat[:DATA_N, :DATA_N, :f_mat.shape[2]] = f_mat

        if self.combine_mat_vect == 'row':
            # row-major
            for i in range(DATA_N):
                mat_feat[i, :DATA_N, f_mat.shape[2]:] = f_vect
        elif self.combine_mat_vect == 'col':
            # col-major
            for i in range(DATA_N):
                mat_feat[:DATA_N, i, f_mat.shape[2]:] = f_vect

        adj_nopad = molecule_features.feat_mol_adj(mol, **self.adj_args)
        adj = torch.zeros((adj_nopad.shape[0], self.MAX_N, self.MAX_N))
        adj[:, :adj_nopad.shape[1], :adj_nopad.shape[2]] = adj_nopad

        if self.combine_mat_feat_adj:
            adj = torch.cat([adj, torch.Tensor(mat_feat).permute(2, 0, 1)], 0)

        ### Simple one-hot encoding for reconstruction
        adj_oh_nopad = molecule_features.feat_mol_adj(
            mol,
            split_weights=[1.0, 1.5, 2.0, 3.0],
            edge_weighted=False,
            norm_adj=False,
            add_identity=False)

        adj_oh = torch.zeros((adj_oh_nopad.shape[0], self.MAX_N, self.MAX_N))
        adj_oh[:, :adj_oh_nopad.shape[1], :adj_oh_nopad.shape[2]] = adj_oh_nopad

        ## per-edge features
        feat_edge_dict = edge_features.feat_edges(mol)

        # pad each of these to MAX_N; the copy-in assignments below are commented out
        # in this example, so edge_edge / edge_feat / edge_vert are returned as all zeros
        edge_edge_nopad = feat_edge_dict['edge_edge']
        edge_edge = torch.zeros(
            (edge_edge_nopad.shape[0], self.MAX_N, self.MAX_N))

        # edge_edge[:, :edge_edge_nopad.shape[1],
        #              :edge_edge_nopad.shape[2]] = torch.Tensor(edge_edge_nopad)

        edge_feat_nopad = feat_edge_dict['edge_feat']
        edge_feat = torch.zeros((self.MAX_N, edge_feat_nopad.shape[1]))
        # edge_feat[:edge_feat_nopad.shape[0]] = torch.Tensor(edge_feat_nopad)

        edge_vert_nopad = feat_edge_dict['edge_vert']
        edge_vert = torch.zeros(
            (edge_vert_nopad.shape[0], self.MAX_N, self.MAX_N))
        # edge_vert[:, :edge_vert_nopad.shape[1],
        #              :edge_vert_nopad.shape[2]] = torch.Tensor(edge_vert_nopad)

        atomicnos, coords = util.get_nos_coords(mol, conf_idx)
        coords_t = torch.zeros((self.MAX_N, 3))
        coords_t[:len(coords), :] = torch.Tensor(coords)

        # create mask and preds; vals starts at the PERM_MISSING_VALUE sentinel so
        # atoms without an observed shift remain flagged as missing

        pred_mask = np.zeros((self.MAX_N, self.PRED_N), dtype=np.float32)
        vals = np.ones((self.MAX_N, self.PRED_N),
                       dtype=np.float32) * util.PERM_MISSING_VALUE
        if self.spect_assign:
            for pn in range(self.PRED_N):
                if len(pred_val) > 0:  # when empty, there's nothing to predict
                    atom_idx = [int(k) for k in pred_val[pn].keys()]
                    obs_vals = [pred_val[pn][i] for i in atom_idx]
                    # if self.shuffle_observations:
                    #     obs_vals = np.random.permutation(obs_vals)
                    for k, v in zip(atom_idx, obs_vals):
                        pred_mask[k, pn] = 1.0
                        vals[k, pn] = v
        else:
            if self.PRED_N > 1:
                raise NotImplementedError()
            vals[:] = util.PERM_MISSING_VALUE  # sentinel value
            for k in pred_val[0][0]:
                pred_mask[k, 0] = 1
            for vi, v in enumerate(pred_val[0][1]):
                vals[vi] = v

        # input mask
        input_mask = torch.zeros(self.MAX_N)
        input_mask[:DATA_N] = 1.0

        v = {
            'adj': adj,
            'vect_feat': vect_feat,
            'mat_feat': mat_feat,
            'mol_feat': f_mol,
            'vals': vals,
            'adj_oh': adj_oh,
            'pred_mask': pred_mask,
            'coords': coords_t,
            'input_mask': input_mask,
            'input_idx': idx,
            'edge_edge': edge_edge,
            'edge_vert': edge_vert,
            'edge_feat': edge_feat
        }

        ## add on extra args
        for ei, extra_data_config in enumerate(self.extra_npy_filenames):
            filename = extra_data_config['filenames'][idx]
            combine_with = extra_data_config.get('combine_with', None)
            if combine_with is None:
                ## this is an extra arg
                npy_data = np.load(filename)

                ## Zero pad
                npy_shape = list(npy_data.shape)
                npy_shape[0] = self.MAX_N
                t_pad = torch.zeros(npy_shape)
                t_pad[:npy_data.shape[0]] = torch.Tensor(npy_data)

                v[f'extra_data_{ei}'] = t_pad

        # sanity check: nothing non-finite should reach the model
        for k, kv in v.items():
            assert np.isfinite(kv).all(), f"non-finite values in {k}"
        if self.cache is not None:
            self.cache[self.cache_key(idx, conf_idx)] = v

        return v
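
This variant is driven by a self.extra_npy_filenames configuration that is read twice above: entries with combine_with == 'vert' are concatenated onto the per-atom features, and entries with no combine_with are zero-padded to MAX_N and returned as extra_data_<i>. A hypothetical sketch of what such a configuration could look like, inferred purely from the access pattern (the file names are placeholders):

# Inferred, illustrative configuration -- the real schema is not shown in the source.
extra_npy_filenames = [
    # per-atom arrays concatenated onto vect_feat ('vert')
    {'filenames': [f"atom_extra_{i}.npy" for i in range(1000)],
     'combine_with': 'vert'},
    # no combine_with: zero-padded to MAX_N and returned as v['extra_data_1']
    {'filenames': [f"per_mol_extra_{i}.npy" for i in range(1000)],
     'combine_with': None},
]
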