Example 1
def contract(indices, values, size, x, cuda=None):
    """
    Performs a contraction (generalized matrix multiplication) of a sparse tensor with an input x.

    The contraction is defined so that every element of the output is the sum of every element of the input multiplied
    once by a unique element from the tensor (that is, like a fully connected neural network layer). See the paper for
    details.

    :param indices: (b, k, r)-tensor describing the indices of b sparse tensors of rank r
    :param values: (b, k)-tensor with the corresponding values
    :param size: total size of the sparse tensors: the output dimensions followed by the input dimensions (x.size()[1:])
    :param x: batch of inputs to contract with; the first dimension is the batch dimension
    :param cuda: whether to use CUDA; inferred from `indices` if None
    :return: a (b, *out_size) tensor containing the result of the contraction
    """
    # translate tensor indices to matrix indices
    if cuda is None:
        cuda = indices.is_cuda

    b, k, r = indices.size()

    # size is equal to out_size + x.size()[1:]
    in_size = x.size()[1:]
    out_size = size[:-len(in_size)]

    assert len(out_size) + len(in_size) == r

    # Flatten into a matrix multiplication
    mindices, flat_size = flatten_indices_mat(indices, x.size()[1:], out_size)
    x_flat = x.view(b, -1, 1)

    # Prevent segfault
    assert mindices.min() >= 0, (
        'negative index in flattened indices: {} \n {} \n Original indices {} \n {}'.format(
            mindices.size(), mindices, indices.size(), indices))
    assert not util.contains_nan(values.data), (
        'NaN in values:\n {}'.format(values))

    y_flat = batchmm(mindices, values, flat_size, x_flat, cuda)

    return y_flat.view(b, *out_size)  # reshape y into a tensor
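
Below is a minimal standalone sketch of the same idea for the rank-2 case, with made-up shapes and a per-example loop over plain torch.sparse in place of the batched batchmm path used above; it illustrates the contraction, not the function's actual implementation.

import torch

b, k = 2, 3                          # batch size, nonzeros per sparse tensor
out_size, in_size = 4, 5

out_idx = torch.randint(0, out_size, (b, k, 1))
in_idx = torch.randint(0, in_size, (b, k, 1))
indices = torch.cat([out_idx, in_idx], dim=2)    # (b, k, r) with r = 2
values = torch.randn(b, k)
x = torch.randn(b, in_size)

ys = []
for i in range(b):  # per-example sparse matmul; contract() batches this instead
    w = torch.sparse_coo_tensor(indices[i].t(), values[i], (out_size, in_size))
    ys.append(torch.sparse.mm(w, x[i].unsqueeze(1)).squeeze(1))
y = torch.stack(ys)                  # (b, out_size)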
Example 2
    def forward_inner(self, input, means, sigmas, values, bias):

        t0total = time.time()

        batchsize = input.size()[0]

        # NB: due to batching, real_indices has shape batchsize x K x rank(W)
        #     real_values has shape batchsize x K

        # print('--------------------------------')
        # for i in range(util.prod(sigmas.size())):
        #     print(sigmas.view(-1)[i].data[0])

        # turn the real values into integers in a differentiable way
        t0 = time.time()

        # max values allowed for each column in the index matrix
        fullrange = self.out_size + input.size()[1:]
        subrange = [fullrange[r] for r in self.learn_cols]

        if self.subsample is None:
            indices, props, values = self.discretize(
                means,
                sigmas,
                values,
                rng=subrange,
                use_cuda=self.use_cuda,
                relative_range=self.region)
            b, l, r = indices.size()

            # pr = indices.view(-1, r)
            # if torch.sum(pr > torch.cuda.LongTensor(subrange).unsqueeze(0).expand_as(pr)) > 0:
            #     for i in range(b*l):
            #         print(pr[i, :])

            h, w = self.temp_indices.size()
            template = self.temp_indices.unsqueeze(0).unsqueeze(2).expand(
                b, h, (2**r + self.gadditional + self.radditional), w)
            template = template.contiguous().view(b, l, w)

            template[:, :, self.learn_cols] = indices
            indices = template

            values = values * props

        else:  # select a small proportion of the indices to learn over
            raise Exception('Not supported yet.')

            # b, k, r = means.size()
            #
            # prop = torch.cuda.FloatTensor([self.subsample]) if self.use_cuda else torch.FloatTensor([self.subsample])
            #
            # selection = None
            # while (selection is None) or (float(selection.sum()) < 1):
            #     selection = torch.bernoulli(prop.expand(k)).byte()
            #
            # mselection = selection.unsqueeze(0).unsqueeze(2).expand_as(means)
            # sselection = selection.unsqueeze(0).unsqueeze(2).expand_as(sigmas)
            # vselection = selection.unsqueeze(0).expand_as(values)
            #
            # means_in, means_out = means[mselection].view(b, -1, r), means[~ mselection].view(b, -1, r)
            # sigmas_in, sigmas_out = sigmas[sselection].view(b, -1, r), sigmas[~ sselection].view(b, -1, r)
            # values_in, values_out = values[vselection].view(b, -1), values[~ vselection].view(b, -1)
            #
            # means_out = means_out.detach()
            # values_out = values_out.detach()
            #
            # indices_in, props, values_in = self.discretize(means_in, sigmas_in, values_in, rng=rng, additional=self.additional, use_cuda=self.use_cuda)
            # values_in = values_in * props
            #
            # indices_out = means_out.data.round().long()
            #
            # indices = torch.cat([indices_in, indices_out], dim=1)
            # values = torch.cat([values_in, values_out], dim=1)

        logging.info('discretize: {} seconds'.format(time.time() - t0))

        if self.use_cuda:
            indices = indices.cuda()

        # translate tensor indices to matrix indices
        t0 = time.time()

        # mindices, flat_size = flatten_indices(indices, input.size()[1:], self.out_shape, self.use_cuda)
        mindices, flat_size = flatten_indices_mat(indices,
                                                  input.size()[1:],
                                                  self.out_size)

        logging.info('flatten: {} seconds'.format(time.time() - t0))

        # NB: mindices is not an autograd Variable. The error-signal for the indices passes to the hypernetwork
        #     through 'values', which are a function of both the real_indices and the real_values.

        ### Create the sparse weight tensor

        x_flat = input.view(batchsize, -1)

        sparsemult = util.sparsemult(self.use_cuda)

        t0 = time.time()

        # Prevent segfault
        assert mindices.min() >= 0
        assert not util.contains_nan(values.data)

        # Then we flatten the batch dimension as well
        bm = util.bmult(flat_size[1], flat_size[0],
                        mindices.size()[1], batchsize, self.use_cuda)
        bfsize = Variable(flat_size * batchsize)

        bfindices = mindices + bm
        bfindices = bfindices.view(1, -1, 2).squeeze(0)
        vindices = Variable(bfindices.t())

        #- bfindices is now a sparse representation of a big block-diagonal matrix (nxb times mxb), with the batches along the
        #  diagonal (and the rest zero). We flatten x over all batches, and multiply by this to get a flattened y.

        # print(bfindices.size(), flat_size)
        # print(bfindices)

        bfvalues = values.view(1, -1).squeeze(0)
        bfx = x_flat.view(1, -1).squeeze(0)

        bfy = sparsemult(vindices, bfvalues, bfsize, bfx)

        y_flat = bfy.unsqueeze(0).view(batchsize, -1)

        y_shape = [batchsize]
        y_shape.extend(self.out_size)

        y = y_flat.view(y_shape)  # reshape y into a tensor

        ### Handle the bias
        if self.bias_type == Bias.DENSE:
            y = y + bias
        if self.bias_type == Bias.SPARSE:
            raise Exception('Not implemented yet.')

        return y
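
The bm/bfindices offset trick above amounts to building one big block-diagonal sparse matrix so that a single sparse multiplication covers the whole batch. A rough standalone sketch of that idea, with made-up shapes and plain torch.sparse instead of util.sparsemult:

import torch

b, n, m = 2, 3, 4                    # batch size, rows and columns per block
mats = torch.randn(b, n, m)          # one small (dense) matrix per batch element
xs = torch.randn(b, m)

# Offset the indices of batch i by (i*n, i*m): the result is a (b*n, b*m)
# block-diagonal sparse matrix, so one sparse matmul handles the whole batch.
rows = torch.arange(n).repeat_interleave(m)
cols = torch.arange(m).repeat(n)
idx, val = [], []
for i in range(b):
    idx.append(torch.stack([rows + i * n, cols + i * m]))
    val.append(mats[i].flatten())
idx, val = torch.cat(idx, dim=1), torch.cat(val)

big = torch.sparse_coo_tensor(idx, val, (b * n, b * m))
y = torch.sparse.mm(big, xs.reshape(-1, 1)).reshape(b, n)
# equals torch.bmm(mats, xs.unsqueeze(2)).squeeze(2)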
Example 3
def finish_episode(policy, update=True):
    policy_loss = []
    all_cum_rewards = []
    for i in range(args.n_rollouts):
        rewards = []
        R = np.zeros(len(policy.rewards[i][0]))
        for r in policy.rewards[i][::-1]:
            R = r + args.gamma * R
            rewards.insert(0, R)
        all_cum_rewards.extend(rewards)
        rewards = torch.Tensor(rewards)  # (length, batch_size)
        # logger.warning(f'original {rewards}')
        if args.baseline == 'avg':
            rewards -= policy.baseline_reward
        elif args.baseline == 'greedy':
            rewards_greedy = []
            R = np.zeros(len(policy.rewards_greedy[0]))
            for r in policy.rewards_greedy[::-1]:
                R = r + args.gamma * R
                rewards_greedy.insert(0, R)
            rewards_greedy = torch.Tensor(rewards_greedy)
            rewards -= rewards_greedy
        # logger.warning(f'after baseline {rewards}')
        if args.avg_reward_mode == 'batch':
            rewards = Variable(
                (rewards - rewards.mean()) /
                (rewards.std() + float(np.finfo(np.float32).eps)))
        elif args.avg_reward_mode == 'each':
            # mean/std is separate for each in the batch
            rewards = Variable(
                (rewards - rewards.mean(dim=0)) /
                (rewards.std(dim=0) + float(np.finfo(np.float32).eps)))
        else:
            rewards = Variable(rewards)
        if args.gpu:
            rewards = rewards.cuda()
        for log_prob, reward in zip(policy.saved_log_probs[i], rewards):
            policy_loss.append(-log_prob * reward)
        # logger.warning(f'after mean_std {rewards}')
    if update:
        tree.n_update += 1
        try:
            policy_loss = torch.cat(policy_loss).mean()
        except Exception as e:
            logger.error(e)
        entropy = torch.cat(policy.entropy_l).mean()
        writer.add_scalar('data/policy_loss', policy_loss, tree.n_update)
        writer.add_scalar('data/entropy_loss', policy.beta * entropy.data[0],
                          tree.n_update)
        policy_loss += policy.beta * entropy
        if args.sl_ratio > 0:
            policy_loss += args.sl_ratio * policy.sl_loss
        writer.add_scalar('data/sl_loss', args.sl_ratio * policy.sl_loss,
                          tree.n_update)
        writer.add_scalar('data/total_loss', policy_loss.data[0],
                          tree.n_update)
        optimizer.zero_grad()
        policy_loss.backward()
        if contains_nan(policy.class_embed.weight.grad):
            logger.error('nan in class_embed.weight.grad!')
        else:
            optimizer.step()
    policy.update_baseline(np.mean(np.concatenate(all_cum_rewards)))
    policy.finish_episode()
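
A small standalone sketch of the discounted-return loop above, with hypothetical per-step rewards and the same mean/std normalization as the 'batch' mode:

import numpy as np

gamma = 0.99
rewards_per_step = [1.0, 0.0, 2.0]      # hypothetical per-step rewards

returns = []
R = 0.0
for r in reversed(rewards_per_step):    # accumulate from the last step backwards
    R = r + gamma * R
    returns.insert(0, R)

returns = np.array(returns)
# normalize; eps avoids division by zero when all returns are equal
returns = (returns - returns.mean()) / (returns.std() + np.finfo(np.float32).eps)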
Example 4
    def forward_inner(self,
                      input,
                      means,
                      sigmas,
                      values,
                      bias,
                      mrange=None,
                      seed=None,
                      train=True):

        t0total = time.time()

        rng = tuple(self.out_size) + tuple(input.size()[1:])

        batchsize = input.size()[0]

        # NB: due to batching, real_indices has shape batchsize x K x rank(W)
        #     real_values has shape batchsize x K

        # turn the real values into integers in a differentiable way
        t0 = time.time()

        if train:
            if self.subsample is None:
                indices = self.generate_integer_tuples(
                    means,
                    rng=rng,
                    use_cuda=self.use_cuda,
                    relative_range=self.region)
                indfl = indices.float()

                # Mask for duplicate indices
                dups = self.duplicates(indices)

                props = densities(indfl, means, sigmas).clone()
                # result has size (b, indices.size(1), means.size(1))
                props[dups] = 0
                props = props / props.sum(dim=1, keepdim=True)

                values = values.unsqueeze(1).expand(batchsize, indices.size(1),
                                                    means.size(1))

                values = props * values
                values = values.sum(dim=2)

            else:
                # For large matrices we need to subsample the means we backpropagate for
                b, nm, rank = means.size()
                fr, to = mrange

                # sample = random.sample(range(nm), self.subsample) # the means we will learn for
                sample = range(fr, to)
                ids = torch.zeros((nm, ),
                                  dtype=torch.uint8,
                                  device='cuda' if self.use_cuda else 'cpu')
                ids[sample] = 1

                means_in, means_out = means[:, ids, :], means[:, ~ids, :]
                sigmas_in, sigmas_out = sigmas[:, ids, :], sigmas[:, ~ids, :]
                values_in, values_out = values[:, ids], values[:, ~ids]

                # These should not get a gradient, since their means aren't being sampled for
                # (their gradient will be computed in the next pass)
                means_out = means_out.detach()
                sigmas_out = sigmas_out.detach()
                values_out = values_out.detach()

                indices = self.generate_integer_tuples(
                    means,
                    rng=rng,
                    use_cuda=self.use_cuda,
                    relative_range=self.region,
                    seed=seed)
                indfl = indices.float()

                dups = self.duplicates(indices)

                props = densities(indfl, means_in, sigmas_in).clone()
                # result has size (b, indices.size(1), means_in.size(1))
                props[dups] = 0
                props = props / props.sum(dim=1, keepdim=True)

                values_in = values_in.unsqueeze(1).expand(
                    batchsize, indices.size(1), means_in.size(1))

                values_in = props * values_in
                values_in = values_in.sum(dim=2)

                means_out = means_out.detach()
                values_out = values_out.detach()

                indices_out = means_out.data.round().long()

                indices = torch.cat([indices, indices_out], dim=1)
                values = torch.cat([values_in, values_out], dim=1)
        else:  # not train, just use the nearest indices
            indices = means.round().long()

        if self.use_cuda:
            indices = indices.cuda()

        # translate tensor indices to matrix indices so we can use matrix multiplication to perform the tensor contraction
        mindices, flat_size = gaussian.flatten_indices_mat(
            indices,
            input.size()[1:], self.out_size)

        ### Create the sparse weight tensor

        x_flat = input.view(batchsize, -1)

        sparsemult = util.sparsemult(self.use_cuda)

        # Prevent segfault
        assert not util.contains_nan(values.data)

        bm = self.bmult(flat_size[1], flat_size[0],
                        mindices.size()[1], batchsize, self.use_cuda)
        bfsize = Variable(flat_size * batchsize)

        bfindices = mindices + bm
        bfindices = bfindices.view(1, -1, 2).squeeze(0)
        vindices = Variable(bfindices.t())

        bfvalues = values.view(1, -1).squeeze(0)
        bfx = x_flat.view(1, -1).squeeze(0)

        # print(vindices.size(), bfvalues.size(), bfsize, bfx.size())
        bfy = sparsemult(vindices, bfvalues, bfsize, bfx)

        y_flat = bfy.unsqueeze(0).view(batchsize, -1)

        y_shape = [batchsize]
        y_shape.extend(self.out_size)

        y = y_flat.view(y_shape)  # reshape y into a tensor

        ### Handle the bias
        if self.bias_type == Bias.DENSE:
            y = y + bias
        if self.bias_type == Bias.SPARSE:  # untested!
            pass

        return y
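
A rough standalone sketch of the density-weighting step in the train branch above, with made-up shapes and an unnormalized isotropic Gaussian written out in place of the densities() helper (which may handle normalization constants and duplicate masking differently):

import torch

b, l, nm, r = 2, 6, 4, 3            # batch, sampled index tuples, Gaussians, rank
indfl = torch.randint(0, 8, (b, l, r)).float()   # sampled integer index tuples
means = torch.rand(b, nm, r) * 8
sigmas = torch.ones(b, nm, r)
values = torch.randn(b, nm)

# density of every tuple under every Gaussian: result has size (b, l, nm)
diff = indfl[:, :, None, :] - means[:, None, :, :]
props = torch.exp(-0.5 * (diff / sigmas[:, None, :, :]).pow(2).sum(dim=3))

# normalize over the sampled tuples and spread each value over its tuples
props = props / props.sum(dim=1, keepdim=True)
weighted = (props * values[:, None, :]).sum(dim=2)   # (b, l): one value per tuple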
Example 5
    def forward_inner(self, input, means, sigmas, values, bias, train=True):

        t0total = time.time()

        rng = tuple(self.out_size) + tuple(input.size()[1:])

        batchsize = input.size()[0]

        # NB: due to batching, real_indices has shape batchsize x K x rank(W)
        #     real_values has shape batchsize x K

        # print('--------------------------------')
        # for i in range(util.prod(sigmas.size())):
        #     print(sigmas.view(-1)[i].data[0])

        # turn the real values into integers in a differentiable way
        t0 = time.time()

        if train:
            if not self.reinforce:
                if self.subsample is None:
                    indices, props, values = self.discretize(
                        means,
                        sigmas,
                        values,
                        rng=rng,
                        additional=self.additional,
                        use_cuda=self.use_cuda,
                        relative_range=self.relative_range)

                    values = values * props
                else:  # select a small proportion of the indices to learn over

                    b, k, r = means.size()

                    prop = torch.cuda.FloatTensor([self.subsample]) if self.use_cuda \
                        else torch.FloatTensor([self.subsample])

                    selection = None
                    while (selection is None) or (float(selection.sum()) < 1):
                        selection = torch.bernoulli(prop.expand(k)).byte()

                    mselection = selection.unsqueeze(0).unsqueeze(2).expand_as(
                        means)
                    sselection = selection.unsqueeze(0).unsqueeze(2).expand_as(
                        sigmas)
                    vselection = selection.unsqueeze(0).expand_as(values)

                    means_in, means_out = means[mselection].view(
                        b, -1, r), means[~mselection].view(b, -1, r)
                    sigmas_in, sigmas_out = sigmas[sselection].view(
                        b, -1, r), sigmas[~sselection].view(b, -1, r)
                    values_in, values_out = values[vselection].view(
                        b, -1), values[~vselection].view(b, -1)

                    means_out = means_out.detach()
                    values_out = values_out.detach()

                    indices_in, props, values_in = self.discretize(
                        means_in,
                        sigmas_in,
                        values_in,
                        rng=rng,
                        additional=self.additional,
                        use_cuda=self.use_cuda)
                    values_in = values_in * props

                    indices_out = means_out.data.round().long()

                    indices = torch.cat([indices_in, indices_out], dim=1)
                    values = torch.cat([values_in, values_out], dim=1)

            else:  # reinforce approach
                dists = torch.distributions.Normal(means, sigmas)
                samples = dists.sample()

                indices = samples.data.round().long()

                # if the sampling puts the indices out of bounds, we just clip to the min and max values
                indices[indices < 0] = 0

                rngt = torch.tensor(data=rng,
                                    device='cuda' if self.use_cuda else 'cpu')

                maxes = rngt.unsqueeze(0).unsqueeze(0).expand_as(means) - 1
                indices[indices > maxes] = maxes[indices > maxes]

        else:  # not train, just use the nearest indices
            indices = means.round().long()

        if self.use_cuda:
            indices = indices.cuda()

        # # Create bias for permutation matrices
        # TAU = 1
        # if SINKHORN_ITS is not None:
        #     values = values / TAU
        #     for _ in range(SINKHORN_ITS):
        #         values = util.normalize(indices, values, rng, row=True)
        #         values = util.normalize(indices, values, rng, row=False)

        # translate tensor indices to matrix indices
        t0 = time.time()

        # mindices, flat_size = flatten_indices(indices, input.size()[1:], self.out_shape, self.use_cuda)
        mindices, flat_size = flatten_indices_mat(indices,
                                                  input.size()[1:],
                                                  self.out_size)

        # NB: mindices is not an autograd Variable. The error-signal for the indices passes to the hypernetwork
        #     through 'values', which are a function of both the real_indices and the real_values.

        ### Create the sparse weight tensor

        # -- Turns out we don't have autograd over sparse tensors yet (let alone over the constructor arguments). For
        #    now, we'll do a slow, naive multiplication.

        x_flat = input.view(batchsize, -1)
        ly = prod(self.out_size)

        y_flat = torch.cuda.FloatTensor(
            batchsize, ly) if self.use_cuda else FloatTensor(batchsize, ly)
        y_flat.fill_(0.0)

        sparsemult = util.sparsemult(self.use_cuda)

        t0 = time.time()

        # Prevent segfault
        assert not util.contains_nan(values.data)

        bm = self.bmult(flat_size[1], flat_size[0],
                        mindices.size()[1], batchsize, self.use_cuda)
        bfsize = Variable(flat_size * batchsize)

        bfindices = mindices + bm
        bfindices = bfindices.view(1, -1, 2).squeeze(0)
        vindices = Variable(bfindices.t())

        bfvalues = values.view(1, -1).squeeze(0)
        bfx = x_flat.view(1, -1).squeeze(0)

        # print(vindices.size(), bfvalues.size(), bfsize, bfx.size())
        bfy = sparsemult(vindices, bfvalues, bfsize, bfx)

        y_flat = bfy.unsqueeze(0).view(batchsize, -1)

        y_shape = [batchsize]
        y_shape.extend(self.out_size)

        y = y_flat.view(y_shape)  # reshape y into a tensor

        ### Handle the bias
        if self.bias_type == Bias.DENSE:
            y = y + bias
        if self.bias_type == Bias.SPARSE:  # untested!
            pass

        if self.reinforce and train:
            return y, dists, samples
        else:
            return y
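
A minimal standalone sketch of the reinforce branch above, with made-up shapes and ranges: indices are sampled from a per-dimension Normal and clamped to the valid range, and the log-probabilities of the samples would carry the gradient back to the means and sigmas.

import torch

b, k, r = 2, 5, 3
rng = (4, 6, 8)                                # size of each index dimension
means = torch.rand(b, k, r) * torch.tensor(rng, dtype=torch.float)
sigmas = torch.ones(b, k, r) * 0.5

dists = torch.distributions.Normal(means, sigmas)
samples = dists.sample()

indices = samples.round().long()
indices[indices < 0] = 0                       # clip below
maxes = (torch.tensor(rng) - 1).expand_as(indices)
indices[indices > maxes] = maxes[indices > maxes]   # clip above

# the REINFORCE signal would flow through these log-probabilities
log_probs = dists.log_prob(samples)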
Example 6
    def forward(self, input, train=True):

        ### Compute and unpack output of hypernetwork

        means, sigmas, values = self.hyper(input)
        nm = means.size(0)
        c = nm // self.k

        means = means.view(c, self.k, 1)
        sigmas = sigmas.view(c, self.k, 1)
        values = values.view(c, self.k)

        rng = (self.in_num, )

        assert input.size(0) == self.in_num

        if train:
            indices = self.generate_integer_tuples(means,
                                                   rng=rng,
                                                   use_cuda=self.use_cuda)
            indfl = indices.float()

            # Mask for duplicate indices
            dups = self.duplicates(indices)

            props = densities(indfl, means, sigmas).clone()
            # result has size (c, indices.size(1), means.size(1))
            props[dups] = 0
            props = props / props.sum(dim=1, keepdim=True)

            values = values.unsqueeze(1).expand(c, indices.size(1),
                                                means.size(1))

            values = props * values
            values = values.sum(dim=2)

            # unroll the batch dimension
            indices = indices.view(-1, 1)
            values = values.view(-1)

            indices = torch.cat([self.outs, indices.long()], dim=1)
        else:
            indices = means.round().long().view(-1, 1)
            values = values.squeeze().view(-1)

            indices = torch.cat([self.outs_inf, indices.long()], dim=1)

        if self.use_cuda:
            indices = indices.cuda()

        # Kill anything on the diagonal
        values[indices[:, 0] == indices[:, 1]] = 0.0

        # if self.symmetric:
        #     # Add reverse direction automatically
        #     flipped_indices = torch.cat([indices[:, 1].unsqueeze(1), indices[:, 0].unsqueeze(1)], dim=1)
        #     indices         = torch.cat([indices, flipped_indices], dim=0)
        #     values          = torch.cat([values, values], dim=0)

        ### Create the sparse weight tensor

        # Prevent segfault
        assert not util.contains_nan(values.data)

        vindices = Variable(indices.t())
        sz = Variable(torch.tensor((self.out_num, self.in_num)))

        spmm = sparsemm(self.use_cuda)
        output = spmm(vindices, values, sz, input)

        return output
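
A small standalone sketch of the final step above, with hypothetical sizes and plain torch.sparse in place of the sparsemm wrapper: values landing on the diagonal are zeroed out before the sparse matrix is applied to the input.

import torch

out_num, in_num, nnz = 4, 4, 6
indices = torch.randint(0, 4, (nnz, 2))   # (row, column) pairs
values = torch.randn(nnz)
x = torch.randn(in_num, 3)                # three input vectors, as columns

# kill anything on the diagonal, as in the forward() above
values[indices[:, 0] == indices[:, 1]] = 0.0

w = torch.sparse_coo_tensor(indices.t(), values, (out_num, in_num))
output = torch.sparse.mm(w, x)            # (out_num, 3)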
Example 7
    def forward_inner(self, input, means, sigmas, values, bias):

        b, n, r = means.size()

        k = n // self.chunk_size
        c = self.chunk_size
        means, sigmas, values = means.view(b, k, c, r), sigmas.view(
            b, k, c, r), values.view(b, k, c)

        batchsize = input.size()[0]

        # turn the real values into integers in a differentiable way
        # max values allowed for each column in the index matrix
        fullrange = self.out_size + input.size()[1:]
        subrange = [fullrange[r] for r in self.learn_cols]

        indices = self.generate_integer_tuples(means,
                                               rng=subrange,
                                               use_cuda=self.use_cuda,
                                               relative_range=self.region)
        indfl = indices.float()

        # Mask for duplicate indices
        dups = self.duplicates(indices)

        props = densities(indfl, means, sigmas).clone()
        # result has size (b, k, l, c), with l = indices.size(2)
        props[dups, :] = 0
        props = props / props.sum(dim=2, keepdim=True)

        values = values[:, :, None, :].expand_as(props)

        values = props * values
        values = values.sum(dim=3)

        indices, values = indices.view(b, -1, r), values.view(b, -1)

        # stitch it into the template
        b, l, r = indices.size()
        h, w = self.temp_indices.size()
        template = self.temp_indices[None, :, None, :].expand(b, h, l // h, w)
        template = template.contiguous().view(b, l, w)

        template[:, :, self.learn_cols] = indices
        indices = template

        if self.use_cuda:
            indices = indices.cuda()

        # translate tensor indices to matrix indices

        # mindices, flat_size = flatten_indices(indices, input.size()[1:], self.out_shape, self.use_cuda)
        mindices, flat_size = flatten_indices_mat(indices,
                                                  input.size()[1:],
                                                  self.out_size)

        # NB: mindices is not an autograd Variable. The error-signal for the indices passes to the hypernetwork
        #     through 'values', which are a function of both the real_indices and the real_values.

        ### Create the sparse weight tensor

        x_flat = input.view(batchsize, -1)

        sparsemult = util.sparsemult(self.use_cuda)

        # Prevent segfault
        try:
            assert mindices.min() >= 0
            assert not util.contains_nan(values.data)
        except AssertionError as ae:
            print('Nan in values or negative index in mindices.')
            print('means', means)
            print('sigmas', sigmas)
            print('props', props)
            print('values', values)
            print('indices', indices)
            print('mindices', mindices)

            raise ae

        # Then we flatten the batch dimension as well
        bm = util.bmult(flat_size[1], flat_size[0],
                        mindices.size()[1], batchsize, self.use_cuda)
        bfsize = Variable(flat_size * batchsize)

        bfindices = mindices + bm
        bfindices = bfindices.view(1, -1, 2).squeeze(0)
        vindices = Variable(bfindices.t())

        #- bfindices is now a sparse representation of a big block-diagonal matrix (nxb times mxb), with the batches along the
        #  diagonal (and the rest zero). We flatten x over all batches, and multiply by this to get a flattened y.

        # print(bfindices.size(), flat_size)
        # print(bfindices)

        bfvalues = values.view(1, -1).squeeze(0)
        bfx = x_flat.view(1, -1).squeeze(0)

        bfy = sparsemult(vindices, bfvalues, bfsize, bfx)

        y_flat = bfy.unsqueeze(0).view(batchsize, -1)

        y_shape = [batchsize]
        y_shape.extend(self.out_size)

        y = y_flat.view(y_shape)  # reshape y into a tensor

        ### Handle the bias
        if self.bias_type == Bias.DENSE:
            y = y + bias
        if self.bias_type == Bias.SPARSE:
            raise Exception('Not implemented yet.')

        return y
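
A minimal standalone sketch of the template-stitching step above, with made-up shapes and a hypothetical learn_cols: a fixed index template provides most columns, and only the learned columns are overwritten.

import torch

b, l, w = 2, 6, 4                                 # batch, index tuples, index columns
learn_cols = [1, 3]                               # hypothetical learned columns
temp_indices = torch.randint(0, 5, (l, w))        # fixed template, one row per tuple
learned = torch.randint(0, 5, (b, l, len(learn_cols)))

template = temp_indices[None, :, :].expand(b, l, w).contiguous()
template[:, :, learn_cols] = learned              # overwrite only the learned columns
indices = template                                # full index tuples for the sparse layer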
Example 8
    def forward(self, x, conditional=None):
        """
        :param x: N by E matrix of node embeddings.
        :param conditional: optional N by E matrix that is added to x before the projections.
        :return:
        """

        n, e = x.size()
        h, r = self.heads, self.relations
        s = e // h
        ed, _ = self.mindices.size() # nr of edges total

        if conditional is not None:
            x = x + conditional

        x = x[:, None, :].expand(n, r, e) # expand for relations
        x = x.view(n, r, h, s)            # cut for attention heads

        # multiply so that we have a length s vector for every head in every relation for every node
        keys    = torch.einsum('rhij, nrhj -> rnhi', self.tokeys, x)
        queries = torch.einsum('rhij, nrhj -> rnhi', self.toqueries, x)

        values  = torch.einsum('rhij, nrhj -> hrni', self.tovals, x).contiguous() # note order of indices
        # - h functions as batch dimension here

        # Select from r and n dimensions
        #      Fold h into i, and extract later.

        keys = keys.view(r, n, -1)
        queries = queries.view(r, n, -1)

        # select keys and queries
        skeys    = keys   [self.indices[:, 1], self.indices[:, 0], :]
        squeries = queries[self.indices[:, 1], self.indices[:, 2], :]

        skeys = skeys.view(-1, h, s)
        squeries = squeries.view(-1, h, s)

        # compute raw dot product
        dot = torch.einsum('ehi, ehi -> he', skeys, squeries) # e = nr of edges

        # row normalize dot products
        # print(dot.size(), self.indices.size(), self.mindices.size())

        mindices = self.mindices[None, :, :].expand(h, ed, 2).contiguous()

        assert not util.contains_inf(dot), f'dot contains inf (before softmax) {dot.min()}, {dot.mean()}, {dot.max()}'
        assert not util.contains_nan(dot), f'dot contains nan (before softmax) {dot.min()}, {dot.mean()}, {dot.max()}'

        if self.norm_method == 'softmax':
            dot = util.logsoftmax(mindices, dot, self.msize).exp()
        else:
            dot = util.simple_normalize(mindices, dot, self.msize, method=self.norm_method)

        assert not util.contains_inf(dot), f'dot contains inf (after softmax) {dot.min()}, {dot.mean()}, {dot.max()}'
        assert not util.contains_nan(dot), f'dot contains nan (after softmax) {dot.min()}, {dot.mean()}, {dot.max()}'

        if self.dropin:
            bern = ds.Bernoulli(dot)
            mask = bern.sample()
            dot = dot * mask

        values = values.view(h, r*n, s)
        output = util.batchmm(mindices, dot, self.msize, values)

        assert output.size() == (h, r*n, s)

        output = output.view(h, r, n, s).permute(2, 1, 0, 3).contiguous().view(n, r, h*s)

        # unify the heads
        output = torch.einsum('rij, nrj -> nri', self.unify, output)

        # unify the relations
        return self.unify_rels(output)
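
A rough standalone sketch of how the per-edge attention scores above are computed, with random projections and a hypothetical edge list; the sparse row-normalization (logsoftmax / simple_normalize) and the value aggregation are omitted.

import torch

n, r, h, s = 5, 2, 3, 4                 # nodes, relations, heads, size per head
e = h * s
x = torch.randn(n, e)                   # node embeddings
tokeys = torch.randn(r, h, s, s)        # per-relation, per-head projections
toqueries = torch.randn(r, h, s, s)

xr = x[:, None, :].expand(n, r, e).reshape(n, r, h, s)

keys = torch.einsum('rhij, nrhj -> rnhi', tokeys, xr).reshape(r, n, -1)
queries = torch.einsum('rhij, nrhj -> rnhi', toqueries, xr).reshape(r, n, -1)

# hypothetical edge list: each row is (source node, relation, target node)
edges = torch.tensor([[0, 1, 2], [3, 0, 4]])

skeys = keys[edges[:, 1], edges[:, 0], :].view(-1, h, s)
squeries = queries[edges[:, 1], edges[:, 2], :].view(-1, h, s)

dot = torch.einsum('ehi, ehi -> he', skeys, squeries)   # one raw score per head per edge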