Example #1
    def loss_infomax(self, x_local, x, edge_index, batch):
        '''Local-global InfoMax loss over a batch of graphs.

        x_local: node-level embeddings of shape (num_nodes, dim).
        x: graph-level embeddings of shape (num_graphs, dim).
        batch: maps each node index to its graph index (edge_index is unused here).
        '''
        l_enc = x_local
        g_enc = x
        measure = 'JSD'

        num_graphs = g_enc.shape[0]
        num_nodes = l_enc.shape[0]

        # Positive pairs: (node, graph it belongs to); all other pairs are negatives.
        pos_mask = torch.zeros((num_nodes, num_graphs), device=l_enc.device)
        neg_mask = torch.ones((num_nodes, num_graphs), device=l_enc.device)
        for nodeidx, graphidx in enumerate(batch):
            pos_mask[nodeidx][graphidx] = 1.
            neg_mask[nodeidx][graphidx] = 0.

        # Score every node embedding against every graph embedding.
        res = torch.mm(l_enc, g_enc.t())
        E_pos = get_positive_expectation(res * pos_mask,
                                         measure,
                                         average=False).sum()
        E_pos = E_pos / num_nodes
        E_neg = get_negative_expectation(res * neg_mask,
                                         measure,
                                         average=False).sum()
        E_neg = E_neg / (num_nodes * (num_graphs - 1))
        loss = E_neg - E_pos

        return loss
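All of the examples on this page rely on `get_positive_expectation` and `get_negative_expectation`, which are not shown here. As a rough orientation, here is a minimal sketch of what such helpers typically compute for the `JSD` measure (the actual implementations support more measures and may differ in detail):

import math
import torch.nn.functional as F

def get_positive_expectation(p_samples, measure, average=True):
    # Sketch: positive-sample term of the Fenchel-dual bound (JSD case only).
    if measure == 'JSD':
        Ep = math.log(2.) - F.softplus(-p_samples)   # E_P[log 2 - sp(-D(x))]
    else:
        raise NotImplementedError(measure)
    return Ep.mean() if average else Ep

def get_negative_expectation(q_samples, measure, average=True):
    # Sketch: negative-sample term of the Fenchel-dual bound (JSD case only).
    if measure == 'JSD':
        Eq = F.softplus(-q_samples) + q_samples - math.log(2.)   # E_Q[sp(D(x)) - log 2]
    else:
        raise NotImplementedError(measure)
    return Eq.mean() if average else Eq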
Example #2
    def routine(self, Z_P, measure: str = 'KL'):
        '''Scores unshuffled representations against a shuffled copy (NDM).

        Args:
            Z_P: Input unshuffled tensor.
            measure: Measure to compare representations with shuffled versions.

        '''
        # The shuffled copy serves as the 'fake' (Q) samples.
        Z_Q = random_permute(Z_P)
        E_pos, E_neg, P_samples, Q_samples = self.score(
            Z_P.detach(), Z_Q.detach(), measure)
        difference = E_pos - E_neg

        if measure == 'DV':
            ndm = E_pos - E_neg
        else:
            # Report NDM on the Donsker-Varadhan scale even when the loss
            # uses a different measure.
            ndm = get_positive_expectation(
                P_samples, 'DV') - get_negative_expectation(Q_samples, 'DV')
            self.add_results(**{measure: difference.detach().item()})

        self.add_results(Scores={
            'E_P[D(x)]': P_samples.mean().detach().item(),
            'max(D(x))': P_samples.max().detach().item(),
            'E_Q[D(x)]': Q_samples.mean().detach().item()
        },
                         NDM=ndm.detach().item())

        self.add_losses(ndm=-difference)
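`random_permute` is likewise assumed by this example; it produces the shuffled tensor Z_Q that plays the role of the 'fake' samples. A minimal sketch under the assumption of a 2-D (batch, features) input where each feature column is permuted independently across the batch (the project's actual helper may differ):

import torch

def random_permute(x):
    # Independently shuffle each feature column across the batch so the result
    # approximates samples from the product of the per-feature marginals.
    perm = torch.argsort(torch.rand_like(x), dim=0)
    return torch.gather(x, 0, perm)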
Example #3
def local_global_loss_(l_enc, g_enc, edge_index, batch, measure):
    '''
    Args:
        l_enc: Local (node-level) feature map.
        g_enc: Global (graph-level) features.
        edge_index: Unused here.
        batch: Maps each node index to its graph index.
        measure: Type of f-divergence to use (e.g. Jensen-Shannon `JSD`).
    Returns:
        torch.Tensor: Loss.
    '''
    num_graphs = g_enc.shape[0]
    num_nodes = l_enc.shape[0]

    # Positive pairs: (node, graph it belongs to); all other pairs are negatives.
    pos_mask = torch.zeros((num_nodes, num_graphs), device=l_enc.device)
    neg_mask = torch.ones((num_nodes, num_graphs), device=l_enc.device)
    for nodeidx, graphidx in enumerate(batch):
        pos_mask[nodeidx][graphidx] = 1.
        neg_mask[nodeidx][graphidx] = 0.

    # Score every node embedding against every graph embedding.
    res = torch.mm(l_enc, g_enc.t())

    E_pos = get_positive_expectation(res * pos_mask, measure,
                                     average=False).sum()
    E_pos = E_pos / num_nodes
    E_neg = get_negative_expectation(res * neg_mask, measure,
                                     average=False).sum()
    E_neg = E_neg / (num_nodes * (num_graphs - 1))

    return E_neg - E_pos
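A hypothetical call to `local_global_loss_`, only to make the expected shapes concrete (the sizes, and passing `None` for the unused `edge_index`, are assumptions):

import torch

# Three graphs with 2, 3 and 1 nodes; 16-dimensional embeddings.
l_enc = torch.randn(6, 16)                  # node-level (local) embeddings
g_enc = torch.randn(3, 16)                  # graph-level (global) embeddings
batch = torch.tensor([0, 0, 1, 1, 1, 2])    # node -> graph assignment
loss = local_global_loss_(l_enc, g_enc, None, batch, 'JSD')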
Example #4
def fenchel_dual_loss(l, g, measure=None):
    '''Computes the f-divergence distance between positive and negative joint distributions.

    Divergences supported are Jensen-Shannon `JSD`, `GAN` (equivalent to JSD),
    squared Hellinger `H2`, chi-squared `X2`, `KL`, and reverse KL `RKL`.

    Args:
        l: Local feature map.
        g: Global features.
        measure: f-divergence measure.

    Returns:
        torch.Tensor: Loss.

    '''
    N, local_units, n_locs = l.size()
    # Flatten the spatial dimension: (N, local_units, n_locs) -> (N * n_locs, local_units).
    l = l.permute(0, 2, 1)
    l = l.reshape(-1, local_units)

    # Score every global vector against every local location of every sample.
    u = torch.mm(g, l.t())
    u = u.reshape(N, N, -1)

    # Diagonal entries pair a sample's global vector with its own locations (positives);
    # off-diagonal entries are negatives.
    mask = torch.eye(N).to(l.device)
    n_mask = 1 - mask

    E_pos = get_positive_expectation(u, measure, average=False).mean(2)
    E_neg = get_negative_expectation(u, measure, average=False).mean(2)
    E_pos = (E_pos * mask).sum() / mask.sum()
    E_neg = (E_neg * n_mask).sum() / n_mask.sum()
    loss = E_neg - E_pos
    return loss
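A hypothetical usage of this variant, only to make the expected shapes explicit (the sizes are assumptions): `l` holds one feature vector per spatial location, `g` one global vector per sample.

import torch

N, units, n_locs = 8, 64, 49         # batch size, feature width, spatial locations
l = torch.randn(N, units, n_locs)    # local feature map
g = torch.randn(N, units)            # global feature vector per sample
loss = fenchel_dual_loss(l, g, measure='JSD')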
Example #5
def fenchel_dual_loss(l, m, measure=None, neg=False):
    '''Computes the f-divergence distance between positive and negative joint distributions.

    Note that vectors should be sent as 1x1.

    Divergences supported are Jensen-Shannon `JSD`, `GAN` (equivalent to JSD),
    squared Hellinger `H2`, chi-squared `X2`, `KL`, and reverse KL `RKL`.

    Args:
        l: Local feature map.
        m: Multiple globals feature map.
        measure: f-divergence measure.

    Returns:
        torch.Tensor: Loss.

    '''
    N, units, n_locals = l.size()
    n_multis = m.size(2)

    # First we make the input tensors the right shape.
    l = l.view(N, units, n_locals)
    l = l.permute(0, 2, 1)
    l = l.reshape(-1, units)

    m = m.view(N, units, n_multis)
    m = m.permute(0, 2, 1)
    m = m.reshape(-1, units)

    # Outer product, we want a N x N x n_local x n_multi tensor.
    u = torch.mm(m, l.t())
    u = u.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)

    # Since we have a big tensor with both positive and negative samples, we need to mask.
    mask = torch.eye(N).to(l.device)
    n_mask = 1 - mask

    # Compute the positive and negative score. Average the spatial locations.
    E_pos = get_positive_expectation(u, measure, average=False).mean(2).mean(2)
    E_neg = get_negative_expectation(u, measure, average=False).mean(2).mean(2)

    # Mask positive and negative terms for positive and negative parts of loss
    E_pos = (E_pos * mask).sum() / mask.sum()
    E_neg = (E_neg * n_mask).sum() / n_mask.sum()
    loss = E_neg - E_pos

    if neg:
        return - loss
    else:
        return loss
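This variant expects a trailing "globals" dimension on `m` as well. A hypothetical call with assumed sizes:

import torch

N, units, n_locals, n_multis = 8, 64, 49, 4
l = torch.randn(N, units, n_locals)   # local feature map
m = torch.randn(N, units, n_multis)   # several global vectors per sample
loss = fenchel_dual_loss(l, m, measure='JSD')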
Example #6
    def score(self, X_P, X_Q, measure):
        '''Score real and fake.

        Args:
            X_P: Real tensor (unshuffled).
            X_Q: Fake tensor (shuffled).
            measure: Comparison measure.

        Returns:
            tuple of torch.Tensor: (real expectation, fake expectation, real samples, fake samples)

        '''
        ndm = self.nets.ndm  # scoring (discriminator) network
        P_samples = ndm(X_P)
        Q_samples = ndm(X_Q)

        E_pos = get_positive_expectation(P_samples, measure)
        E_neg = get_negative_expectation(Q_samples, measure)

        return E_pos, E_neg, P_samples, Q_samples
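For orientation, the difference `E_pos - E_neg` formed from these scores corresponds to the standard lower-bound estimators, assuming the helpers implement the usual terms (this is the textbook formulation, not taken from this page): for the Donsker-Varadhan measure `DV`

$$\hat I_{\mathrm{DV}} = \mathbb{E}_P[D(x)] - \log \mathbb{E}_Q\!\left[e^{D(x)}\right],$$

and for `JSD`, up to additive $\log 2$ constants,

$$\hat I_{\mathrm{JSD}} = \mathbb{E}_P\!\left[-\mathrm{sp}(-D(x))\right] - \mathbb{E}_Q\!\left[\mathrm{sp}(D(x))\right], \qquad \mathrm{sp}(z) = \log(1 + e^{z}).$$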