예제 #1
0
def slidingTopK(h, K, M, mask=None, stride=1):
    """Perform a K-nearest-neighbour search for each pixel within an MxM window.

    Every pixel is compared against the entries of the MxM window centred on
    it (the input is reflect-padded through ``utils.conv_pad``), and the K
    entries with the smallest value in the matrix returned by ``graphAdj``
    are kept. Window-local positions are translated to positions in the
    flattened image by pushing an index map through the same padding and
    stacking pipeline as ``h``.

    h: input feature map; the last two dims are treated as the image dims
        (presumably (B, C, H, W) -- confirm against utils.stack).
    K: number of neighbours kept per pixel.
    M: window side-length.
    mask: forwarded unchanged to ``graphAdj``.
    stride: ONLY stride == 1 WORKS FOR NOW.
    output: integer (long) tensor of indices into the flattened H*W image,
        as produced by ``utils.unstack`` (presumably (B, K, H, W)).

    Raises NotImplementedError if stride != 1.
    """
    if stride != 1:
        raise NotImplementedError("slidingTopK only supports stride == 1")

    hp = utils.conv_pad(h, M, mode='reflect')
    hs = utils.stack(hp, M, stride)  # (B,I,J,C,M,M)
    B, I, J = hs.shape[:3]
    hbs = utils.batch_stack(hs)  # (BIJ, C, M, M)

    # Form the index set that follows the reflection padding of the input.
    # It is created on the same device as `h`, otherwise the index_select
    # below mixes CPU and non-CPU tensors.
    # NOTE: the map is stored as float32 (as in the original code), so flat
    # indices above 2**24 are not exactly representable.
    height, width = h.shape[-2], h.shape[-1]
    index = torch.arange(height * width, device=h.device).reshape(
        1, 1, height, width).float()
    # Replicate the map once per batch element so it goes through exactly
    # the same stacking as `h`; with a single copy the reshape to
    # (B*I*J, 1, M*M) further down cannot succeed when B > 1.
    index = index.repeat(B, 1, 1, 1)
    index = utils.conv_pad(index, M, mode='reflect')
    ibs = utils.batch_stack(utils.stack(index, M, stride))

    # Centre patch of every window (a single pixel when stride == 1).
    cpx = (M - 1) // 2
    pad = ((stride - 1) // 2, stride // 2)  # (floor, ceil) of (stride-1)/2
    v = hbs[..., (cpx - pad[0]):(cpx + pad[1] + 1),
            (cpx - pad[0]):(cpx + pad[1] + 1)]
    S = v.shape[-1]

    print("forming adjacency matrix...")
    G = graphAdj(v, hbs, mask)  # (BIJ, SS, MM)

    ibs = ibs.reshape(B * I * J, 1, M * M)
    # K smallest entries along the last (window) dimension.
    edge = torch.topk(G, K, largest=False).indices
    # Offset the window-local positions so they address the flattened `ibs`.
    edge = edge + torch.arange(0, B * I * J, device=h.device).reshape(
        -1, 1, 1) * M * M
    # Look up the flattened-image index of every selected window entry.
    edge = torch.index_select(ibs.reshape(-1, 1), 0, edge.flatten())
    edge = edge.reshape(B * I * J, S * S, K).permute(0, 2,
                                                     1).reshape(-1, K, S, S)
    edge = utils.unbatch_stack(edge, (I, J))
    edge = utils.unstack(edge)
    return edge.long()
예제 #2
0
def collate_lep(batch):
    """
    Collates LEP datapoints into the batch format for Cormorant.

    Every property of the datapoints is combined with ``batch_stack``. The
    ``*_active`` structure fields are stored under keys ending in ``1`` and
    the ``*_inactive`` fields under keys ending in ``2``, each after being
    passed through ``drop_zeros``. Atom and edge masks are derived from the
    resulting charges.

    :param batch: The data to be collated.
    :type batch: list of datapoints

    :return: The collated data, with keys ``label``, ``charges1/2``,
        ``positions1/2``, ``one_hot1/2``, ``atom_mask1/2``, ``edge_mask1/2``.
    :rtype: dict of Pytorch tensors

    """
    # Stack each property over all datapoints in the list.
    stacked = {}
    for prop in batch[0].keys():
        stacked[prop] = batch_stack([mol[prop] for mol in batch])

    # Positions whose summed charge over dim 0 is positive are kept.
    keep_active = stacked['charges_active'].sum(0) > 0
    keep_inactive = stacked['charges_inactive'].sum(0) > 0

    # The label is copied over unchanged.
    collated = {'label': stacked['label']}

    # Split the structural data into the two structures and drop zeros.
    variants = (('_active', '1', keep_active),
                ('_inactive', '2', keep_inactive))
    for field in ('charges', 'positions', 'one_hot'):
        for suffix, tag, keep in variants:
            source = field + suffix
            collated[field + tag] = drop_zeros(stacked[source], source, keep)

    # Atom masks: True wherever the charge is positive.
    for tag in ('1', '2'):
        collated['atom_mask' + tag] = collated['charges' + tag] > 0

    # Edge masks: product of the atom mask expanded along dims 1 and 2.
    for tag in ('1', '2'):
        atoms = collated['atom_mask' + tag]
        collated['edge_mask' + tag] = atoms.unsqueeze(1) * atoms.unsqueeze(2)

    return collated
예제 #3
0
def collate_lba(batch):
    """
    Collates LBA datapoints into the batch format for Cormorant.

    Every property of the datapoints is combined with ``batch_stack``; the
    structure fields are then passed through ``drop_zeros`` and an atom mask
    and an edge mask are derived from the resulting charges.

    :param batch: The data to be collated.
    :type batch: list of datapoints

    :return: The collated data, with keys ``neglog_aff``, ``charges``,
        ``positions``, ``one_hot``, ``atom_mask`` and ``edge_mask``.
    :rtype: dict of Pytorch tensors

    """
    # Stack each property over all datapoints in the list.
    stacked = {}
    for prop in batch[0].keys():
        stacked[prop] = batch_stack([mol[prop] for mol in batch])

    # Positions whose summed charge over dim 0 is positive are kept.
    keep = stacked['charges'].sum(0) > 0

    # The label is copied over unchanged.
    collated = {'neglog_aff': stacked['neglog_aff']}

    # Structural data with zeros dropped.
    for field in ('charges', 'positions', 'one_hot'):
        collated[field] = drop_zeros(stacked[field], field, keep)

    # Atom mask: True wherever the charge is positive.
    atoms = collated['charges'] > 0
    collated['atom_mask'] = atoms

    # Edge mask: product of the atom mask expanded along dims 1 and 2.
    collated['edge_mask'] = atoms.unsqueeze(1) * atoms.unsqueeze(2)
    return collated
예제 #4
0
def windowedTopK(h, K, M, mask):
    """Return, for each pixel, the K edge indices chosen inside its MxM window.

    The image is split into windows with ``utils.stack`` (window size M,
    stride M), a graph adjacency matrix is built per window with
    ``graphAdj``, and the K smallest entries of each row are kept.

    h: (B, C, H, W) input feature
    M: window side-length
    mask: (H*W, H*W) Graph mask, forwarded to ``graphAdj``.
    output: (B, K, H, W) K edge indices (of flattened image) for each pixel
    """
    # Cut the image into windows: (B,I,J,C,M,M).
    windows = utils.stack(h, M, M)
    n_rows, n_cols = windows.shape[1:3]

    # Fold the window grid into the batch dim so one adjacency matrix is
    # built per window: (B*I*J,C,M,M) -> (B*I*J, M*M, M*M).
    flat_windows = utils.batch_stack(windows)
    adjacency = graphAdj(flat_windows, flat_windows, mask)

    # K smallest entries along the last dim: (B*I*J, M*M, K).
    _, nearest = torch.topk(adjacency, K, largest=False)
    # (B*I*J,M*M,K) -> (B*I*J,K,M*M) -> (B*I*J,K,M,M)
    nearest = nearest.permute(0, 2, 1)
    nearest = nearest.reshape(-1, K, M, M)

    # Restore the window grid, then translate window-index to tile index.
    nearest = utils.unbatch_stack(nearest, (n_rows, n_cols))  # (B,I,J,K,M,M)
    return utils.indexTranslate(nearest, M)  # (B,K,H,W)