def conv1d_weight(input,
                  weight_size,
                  grad_output,
                  stride=1,
                  padding=0,
                  dilation=1,
                  groups=1):
    r"""
    Computes the gradient of conv1d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1,1,3, requires_grad=True)
        >>> weight = torch.randn(1,1,1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv1d_weight(input, weight.shape, grad_output)

    """
    stride = _single(stride)
    padding = _single(padding)
    dilation = _single(dilation)
    in_channels = input.shape[1]
    out_channels = grad_output.shape[1]
    min_batch = input.shape[0]

    input = input.detach()
    weight = torch.empty(weight_size,
                         dtype=input.dtype,
                         device=input.device,
                         requires_grad=True)

    with torch.enable_grad():
        result = torch.conv1d(input, weight, None, stride, padding, dilation,
                              groups)

    result.backward(grad_output)

    return weight.grad
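
This variant recovers the weight gradient by re-running the forward pass on a fresh weight tensor and backpropagating grad_output through it, rather than computing the gradient in closed form. A quick sanity check against torch.autograd.grad (a sketch only; it assumes conv1d_weight above is in scope, along with _single from torch.nn.modules.utils):

import torch
import torch.nn.functional as F

inp = torch.randn(2, 3, 8)
weight = torch.randn(4, 3, 3, requires_grad=True)
output = F.conv1d(inp, weight, padding=1)
grad_output = torch.randn_like(output)

# Reference gradient from autograd vs. the helper above.
(ref,) = torch.autograd.grad(output, weight, grad_output)
est = conv1d_weight(inp, weight.shape, grad_output, padding=1)
assert torch.allclose(ref, est, atol=1e-5)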
Example #2
    def forward(self, x):
        # bs = x.shape[0]

        x1 = torch.conv1d(x, self.wsin_var, stride=2)  # .pow(2)
        x2 = torch.conv1d(x, self.wcos_var, stride=2)  # .pow(2)
        x = x1 + x2
        x = F.relu(x)

        x = F.relu(self.conv2(x))
        x = self.conv2_drop(x)

        x = F.relu(self.conv3(x))
        x = self.conv3_drop(x)

        x = x.view(-1, 64)

        x = F.relu(self.lin1(x))
        x = self.lin1_drop(x)

        x = self.lin2(x)
        x = F.relu(x)
        # x = F.tanh(x)

        return x
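
wsin_var and wcos_var are not defined in this snippet; they are the model's sine/cosine filterbanks. A plausible construction (a sketch only; n_bins and window are assumed hyperparameters, not values from the original model):

import math
import torch

n_bins, window = 64, 256  # assumed sizes
k = torch.arange(window, dtype=torch.float32)
freqs = torch.arange(n_bins, dtype=torch.float32).view(-1, 1)
# One (1 x window) filter per frequency bin, shaped for torch.conv1d.
wsin_var = torch.sin(2 * math.pi * freqs * k / window).view(n_bins, 1, window)
wcos_var = torch.cos(2 * math.pi * freqs * k / window).view(n_bins, 1, window)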
Example #3
def conv1d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1, bias=None):
    r"""
    Computes the gradient of conv1d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iW)
        weight_size : Shape of the weight gradient tensor
        grad_output : output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias: optional bias tensor (out_channels). Default: None

    Examples::

        >>> input = torch.randn(1,1,3, requires_grad=True)
        >>> weight = torch.randn(1,1,1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv1d_weight(input, weight.shape, grad_output)

    """
    stride = _single(stride)
    padding = _single(padding)
    dilation = _single(dilation)
    in_channels = input.shape[1]
    out_channels = grad_output.shape[1]
    min_batch = input.shape[0]

    grad_output = grad_output.contiguous().repeat(1, in_channels // groups, 1)
    grad_output = grad_output.contiguous().view(
        grad_output.shape[0] * grad_output.shape[1], 1, grad_output.shape[2])

    input = input.contiguous().view(1, input.shape[0] * input.shape[1],
                                    input.shape[2])

    grad_weight = torch.conv1d(input, grad_output, bias, dilation, padding,
                               stride, in_channels * min_batch)

    grad_weight = grad_weight.contiguous().view(
        min_batch, grad_weight.shape[1] // min_batch, grad_weight.shape[2])

    return grad_weight.sum(dim=0).view(
        in_channels // groups, out_channels, grad_weight.shape[2]).transpose(
            0, 1).narrow(2, 0, weight_size[2])
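
Unlike Example #1, this version computes the weight gradient in closed form: the gradient of a strided convolution with respect to its weight is itself a convolution of the input with grad_output, which is why stride and dilation trade places in the torch.conv1d call above (the same trick used by torch.nn.grad.conv1d_weight). A quick numerical check, assuming the function above is in scope and _single is imported from torch.nn.modules.utils:

import torch
import torch.nn.functional as F

inp = torch.randn(2, 4, 16)
weight = torch.randn(8, 4, 3, requires_grad=True)
output = F.conv1d(inp, weight, stride=2, padding=1)
grad_output = torch.randn_like(output)

(ref,) = torch.autograd.grad(output, weight, grad_output)
est = conv1d_weight(inp, weight.shape, grad_output, stride=2, padding=1)
assert torch.allclose(ref, est, atol=1e-4)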
Example #4
def circular_convolve(weight, shift):
    r"""
    Perform a circular convolution. Taken from [1].

    Args:
        weight (torch.Tensor): weight vector.
        shift (torch.Tensor): shift vector.

    Returns:
        torch.Tensor: convolved weight vector.

    References:
        [1] https://github.com/loudinthecloud/pytorch-ntm/blob/master/ntm/memory.py
    """
    aug_weight = torch.cat([weight[-1:], weight, weight[:1]])
    conv = torch.conv1d(aug_weight.view(1, 1, -1),
                        shift.view(1, 1, -1)).view(-1)
    return conv
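
A small usage sketch: with a 3-element shift vector, circular_convolve rotates weight mass around the ends of the vector (torch.conv1d is cross-correlation, so the kernel [1, 0, 0] moves the peak one position to the right):

import torch

w = torch.tensor([0., 0., 1., 0., 0.])
print(circular_convolve(w, torch.tensor([0., 1., 0.])))  # unchanged
print(circular_convolve(w, torch.tensor([1., 0., 0.])))  # tensor([0., 0., 0., 1., 0.])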
Example #5
def classicalLoss(pred, targets):

    # Transform targets
    target_distribution = torch.cuda.FloatTensor(targets.size()[0],
                                                 targets.size()[1],
                                                 pred.size()[2]).fill_(0)
    for ii in range(target_distribution.size()[2]):
        target_distribution[:, :, ii] = (targets == ii).float()

    # Loss computation
    target_smooth = torch.conv1d(
        target_distribution.permute(0, 2, 1).reshape(-1, 1, 64),
        kernel,
        padding=21 // 2).reshape(pred.size()[0],
                                 pred.size()[2],
                                 pred.size()[1]).permute(0, 2, 1)
    loss_smooth = torch.mean(torch.sum((pred - target_smooth)**2, 2))

    return loss_smooth, 0 * loss_smooth
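
The global kernel used above is not defined in this snippet. Given the padding of 21 // 2 and the reshape to length-64 sequences, a plausible construction is a normalized Gaussian of width 21 (the width follows from the code; sigma here is an assumption):

import torch

ksize, sigma = 21, 2.0  # width implied by the padding above; sigma is assumed
t = torch.arange(ksize, dtype=torch.float32) - ksize // 2
kernel = torch.exp(-t ** 2 / (2 * sigma ** 2))
kernel = (kernel / kernel.sum()).view(1, 1, ksize).cuda()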
Example #6
    def forward(self, x):
        # Construct kernel
        x_shape = x.shape

        rel_pos = self.handle_rel_positions(x)
        conv_kernel = self.Kernel(rel_pos).view(-1, x_shape[1], *x_shape[2:])

        # ---- Different sampling rates ------
        # If freq test > freq train, smooth out high-freq elements.
        if self.sigma is not None:
            from math import pi, sqrt, exp

            n = int(1 / self.sr_change) * 2 + 1
            h = n // 2
            G = (lambda x: 1 / (self.sigma * sqrt(2 * pi)) *
                 exp(-float(x) ** 2 / (2 * self.sigma ** 2)))

            smoothing_ker = [G(x) for x in range(-h, h + 1)]
            smoothing_ker = torch.Tensor(smoothing_ker).cuda().unsqueeze(
                0).unsqueeze(0)
            conv_kernel[:, :, h:-h] = torch.conv1d(
                conv_kernel.view(-1, 1, *x_shape[2:]),
                smoothing_ker,
                padding=0).view(*conv_kernel.shape[:-1], -1)
        # multiply by the sr_train / sr_test
        if self.sr_change != 1.0:
            conv_kernel *= self.sr_change
        # ------------------------------------

        # For computation of "weight_decay"
        self.conv_kernel = conv_kernel

        # We have noticed that the results of fftconv become very noisy when the length of
        # the input is very small ( < 50 samples). As this might occur when we use subsampling,
        # we replace causal_fftconv by causal_conv in settings where this occurs.
        if x_shape[-1] < self.train_length.item():
            # Use spatial convolution:
            return ckconv_f.causal_conv(x, conv_kernel, self.bias)
        else:
            # Otherwise use fft convolution:
            return ckconv_f.causal_fftconv(x, conv_kernel, self.bias)
Example #7
    def MonotonicAttention(self, dataInput, seqInput, hiddenNoduleNumbers):
        attentionNumeratorWeight = self.attentionWeightNumeratorLayer(
            input=dataInput).tanh()
        attentionDenominatorRawWeight = self.attentionWeightDenominatorLayer(
            input=dataInput).exp()
        padDenominatorZero = torch.zeros(size=[
            attentionDenominatorRawWeight.size()[0], self.attentionScope - 1,
            attentionDenominatorRawWeight.size()[2]
        ])
        if self.cudaFlag:
            padDenominatorZero = padDenominatorZero.cuda()
            self.sumKernel = self.sumKernel.float().cuda()

        attentionDenominatorSupplementWeight = torch.cat(
            [padDenominatorZero, attentionDenominatorRawWeight], dim=1)

        attentionDenominatorWeight = torch.conv1d(
            input=attentionDenominatorSupplementWeight.permute(0, 2, 1),
            weight=self.sumKernel,
            stride=1)
        attentionOriginWeight = torch.div(attentionNumeratorWeight.squeeze(),
                                          attentionDenominatorWeight.squeeze())

        #########################################################

        if seqInput is not None:
            attentionMaskWeight = attentionOriginWeight.min(
                self.AttentionMask(seqInput=seqInput))
        else:
            attentionMaskWeight = attentionOriginWeight
        attentionWeight = torch.nn.functional.softmax(
            attentionMaskWeight, dim=-1).view([len(dataInput), -1, 1])
        attentionSupplementWeight = attentionWeight.repeat(
            [1, 1, hiddenNoduleNumbers])
        attentionSeparateResult = torch.mul(dataInput,
                                            attentionSupplementWeight)
        attentionResult = attentionSeparateResult.sum(dim=1)
        return attentionResult, attentionWeight
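
self.sumKernel is defined elsewhere; given the left-padding of attentionScope - 1 zeros, the conv above amounts to a causal sliding sum over the last attentionScope steps, which a ones kernel reproduces (a sketch, not the original module's definition):

import torch
import torch.nn.functional as F

scope = 3
x = torch.tensor([[[1., 2., 3., 4.]]])              # (batch, channels=1, T)
x_padded = F.pad(x, (scope - 1, 0))                 # causal left-padding
sliding = torch.conv1d(x_padded, torch.ones(1, 1, scope))
print(sliding)  # tensor([[[1., 3., 6., 9.]]])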
Example #8
    def forward(self, hidden_states, attention_mask):
        batch_size, seq_len, hidden_size = hidden_states.size()
        output_states = self.in_projection(hidden_states)
        output_states = output_states.permute(0, 2, 1)

        weight = torch.softmax(self.weight, dim=-1)
        weight = self.conv_weight_dropout(weight)

        if attention_mask is not None:
            pivot = self.kernel_size // 2 + 1
            weight[:, :, pivot:] = 0

        output_states = output_states.contiguous().view(-1, self.num_heads, seq_len)
        output_states = torch.conv1d(output_states, weight, padding=self.kernel_size // 2, groups=self.num_heads)
        output_states = output_states.view(batch_size, hidden_size, seq_len)
        output_states = output_states.permute(0, 2, 1)

        # output projection
        output_states = self.out_projection(output_states)
        output_states = self.conv_layer_dropout(output_states)
        output_states = self.layer_norm(hidden_states + output_states)

        return output_states
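
The conv1d call above relies on grouped convolution: with groups == num_heads and one input channel per group, every head is filtered by its own (1 x kernel_size) kernel. A minimal illustration with assumed sizes:

import torch

num_heads, seq_len, ksize = 4, 10, 3  # assumed sizes
x = torch.randn(2, num_heads, seq_len)
w = torch.softmax(torch.randn(num_heads, 1, ksize), dim=-1)
y = torch.conv1d(x, w, padding=ksize // 2, groups=num_heads)
print(y.shape)  # torch.Size([2, 4, 10])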
Example #9
    def forward(self, data, edge_dropout=None, penalty_coefficient=0.25):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        num_graphs = batch.max().item() + 1
        row, col = edge_index
        total_num_edges = edge_index.shape[1]
        N_size = x.shape[0]

        if edge_dropout is not None:
            edge_index = dropout_adj(
                edge_index,
                edge_attr=(torch.ones(edge_index.shape[1],
                                      device=device)).long(),
                p=edge_dropout,
                force_undirected=True)[0]
            edge_index = add_remaining_self_loops(edge_index,
                                                  num_nodes=batch.shape[0])[0]

        reduced_num_edges = edge_index.shape[1]
        current_edge_percentage = (reduced_num_edges / total_num_edges)
        no_loop_index, _ = remove_self_loops(edge_index)
        no_loop_row, no_loop_col = no_loop_index

        xinit = x.clone()
        x = x.unsqueeze(-1)
        mask = get_mask(x, edge_index, 1).to(x.dtype)
        x = F.leaky_relu(self.conv1(x, edge_index))  # +x
        x = x * mask
        x = self.gnorm(x)
        x = self.bn1(x)

        for conv, bn in zip(self.convs, self.bns):
            if (x.dim() > 1):
                x = x + F.leaky_relu(conv(x, edge_index))
                mask = get_mask(mask, edge_index, 1).to(x.dtype)
                x = x * mask
                x = self.gnorm(x)
                x = bn(x)

        xpostconvs = x.detach()
        #
        x = F.leaky_relu(self.lin1(x))
        x = x * mask

        xpostlin1 = x.detach()
        x = F.leaky_relu(self.lin2(x))
        x = x * mask

        #calculate min and max
        batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
        batch_max = torch.index_select(batch_max, 0, batch)
        batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
        batch_min = torch.index_select(batch_min, 0, batch)

        #min-max normalize
        x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
        probs = x

        #calculating the terms for the expected distance between clique and graph
        pairwise_prodsums = torch.zeros(num_graphs, device=device)
        for graph in range(num_graphs):
            batch_graph = (batch == graph)
            pairwise_prodsums[graph] = (torch.conv1d(
                probs[batch_graph].unsqueeze(-1),
                probs[batch_graph].unsqueeze(-1))).sum() / 2

        ###calculate loss terms
        self_sums = scatter_add((probs * probs), batch, 0, dim_size=num_graphs)
        expected_weight_G = scatter_add(
            probs[no_loop_row] * probs[no_loop_col],
            batch[no_loop_row],
            0,
            dim_size=num_graphs) / 2.
        expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) -
                                  self_sums) / 1.
        expected_distance = (expected_clique_weight - expected_weight_G)

        ###calculate loss
        expected_loss = (penalty_coefficient
                         ) * expected_distance * 0.5 - 0.5 * expected_weight_G

        loss = expected_loss

        retdict = {}

        retdict["output"] = [probs.squeeze(-1), "hist"]  #output
        retdict["losses histogram"] = [loss.squeeze(-1), "hist"]
        retdict["Expected weight(G)"] = [expected_weight_G.mean(), "sequence"]
        retdict["Expected maximum weight"] = [
            expected_clique_weight.mean(), "sequence"
        ]
        retdict["Expected distance"] = [expected_distance.mean(), "sequence"]
        retdict["loss"] = [loss.mean().squeeze(), "sequence"]  #final loss

        return retdict
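
The pairwise_prodsums line exploits a conv1d trick: feeding the length-1 "sequences" probs.unsqueeze(-1) as both input and weight makes conv1d produce the full outer product p_i * p_j, so its sum equals (sum_i p_i)^2, and halving gives the sum over ordered pairs counted once. A quick check:

import torch

p = torch.tensor([0.2, 0.5, 0.9]).unsqueeze(-1)         # (N, 1)
outer = torch.conv1d(p.unsqueeze(-1), p.unsqueeze(-1))  # (N, N, 1): entries p_i * p_j
assert torch.isclose(outer.sum(), p.sum() ** 2)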
Example #10
    def update(self, x):
        # Prepare the inputs
        y = self.similarity(x, self.weight)
        t = self.teacher_signal
        if t is not None:
            t = t.unsqueeze(2).unsqueeze(3) * torch.ones_like(y,
                                                              device=y.device)
        y = y.permute(0, 2, 3, 1).contiguous().view(-1, self.weight.size(0))
        if t is not None:
            t = t.permute(0, 2, 3,
                          1).contiguous().view(-1, self.weight.size(0))
        x_unf = unfold_map2d(x, self.weight.size(2), self.weight.size(3))
        x_unf = x_unf.permute(0, 2, 3, 1,
                              4).contiguous().view(y.size(0), 1, -1)

        # Random abstention
        if self.random_abstention:
            abst_prob = self.victories_count / (self.victories_count.max() +
                                                y.size(0) / y.size(1)).clamp(1)
            scores = y * (torch.rand_like(abst_prob, device=y.device) >=
                          abst_prob).float().unsqueeze(0)
        else:
            scores = y

        # Competition. The returned winner_mask is a bitmap telling where a neuron won and where one lost.
        if self.competitive:
            if t is not None: scores *= t
            winner_mask = (scores == scores.max(1, keepdim=True)[0]).float()
            if self.random_abstention:  # Update statistics if using random abstention
                winner_mask_sum = winner_mask.sum(
                    0)  # Number of inputs over which a neuron won
                self.victories_count += winner_mask_sum
                self.victories_count -= self.victories_count.min().item()
        else:
            winner_mask = torch.ones_like(y, device=y.device)

        # Lateral feedback
        if self.lfb_on:
            lfb_kernel = self.lfb_kernel
            if self.lfb_value == self.LFB_DoG or self.lfb_value == self.LFB_DoE:
                # Difference of Gaussians/Exponentials (Mexican-hat-shaped function)
                lfb_kernel = 2 * lfb_kernel - lfb_kernel.pow(0.5)
            lfb_in = F.pad(winner_mask.view(-1, *self.out_size), self.pad)
            if self.out_size.size(0) == 1:
                lfb_out = torch.conv1d(lfb_in.unsqueeze(1),
                                       lfb_kernel.unsqueeze(0).unsqueeze(1))
            elif self.out_size.size(0) == 2:
                lfb_out = torch.conv2d(lfb_in.unsqueeze(1),
                                       lfb_kernel.unsqueeze(0).unsqueeze(1))
            else:
                lfb_out = torch.conv3d(lfb_in.unsqueeze(1),
                                       lfb_kernel.unsqueeze(0).unsqueeze(1))
            lfb_out = lfb_out.clamp(-1, 1).view_as(y)
        else:
            lfb_out = winner_mask
            if self.competitive: lfb_out[lfb_out == 0] = self.lfb_value
            elif t is not None: lfb_out = t

        # Compute step modulation coefficient
        r = lfb_out  # RULE_BASE
        if self.weight_upd_rule == self.RULE_HEBB: r *= y

        # Compute delta
        r_abs = r.abs()
        r_sign = r.sign()
        delta_w = r_abs.unsqueeze(2) * (
            r_sign.unsqueeze(2) * x_unf -
            self.weight.view(1, self.weight.size(0), -1))

        # Since we use batches of inputs, we need to aggregate the different update steps of each kernel into a
        # single update. We do this by taking the weighted average of the steps, the weights being the r
        # coefficients that determine the length of each step.
        r_sum = r_abs.sum(0)
        r_sum += (r_sum == 0).float()  # Prevent divisions by zero
        delta_w_avg = (delta_w *
                       r_abs.unsqueeze(2)).sum(0) / r_sum.unsqueeze(1)

        # Apply delta
        self.weight += self.eta * delta_w_avg.view_as(self.weight)

        # LFB kernel shrinking and LR schedule
        if self.lfb_on: self.lfb_kernel = self.lfb_kernel.pow(self.alpha)
        if self.lr_schedule is not None: self.eta = self.lr_schedule(self.eta)
Example #11
def respond_to(model, sequences, training_run=True, extra_steps=0):

    responses = [[] for _ in range(len(sequences))]
    loss = 0

    convolver, enc, dec, deconvolver = model

    hann_w = hann() if not config.use_gpu else hann().cuda()
    ihann_w = ihann() if not config.use_gpu else ihann().cuda()

    # with no_grad():
    #     #print(convolver[0].w.size(), hann().size())
    #     convolver[0].w *= hann_w
    # #     deconvolver[0].w *= ihann(deconvolver[0].w)

    for i, sequence in enumerate(sequences):

        #print(f'seq{i}/{len(sequences)}')

        #print('in size:',sequence.size(),'conv_w size:',convolver[0].w.unsqueeze(1).size())

        sequence = conv1d(sequence, (convolver[0].w * hann_w).unsqueeze(1),
                          stride=config.frame_stride)
        sequence = transpose(sequence, 1, 2)
        sequence /= config.frame_len

        #print('conved size:',sequence.size())

        # make key,query from all here.. => the transformer stuff

        for t in range(sequence.size(1) - 1):

            #curr_inp = sequence[:,t:t+1,:]
            prev_inps = sequence[:, :t + 1, :]
            lbl = sequence[:, t + 1:t + 2, :]

            positions = Tensor([[(t + 1) / config.max_T, i / config.max_T]
                                for i in range(t + 1)]).view(1, -1, 2)
            if config.use_gpu: positions = positions.cuda()

            #print(f'{t}/{sequence.size(1)}')

            # print('t:',t,',prev inps size:',prev_inps.size(),'curr inp size:',curr_inp.size())

            #todo: hmmmm..
            #inp = cat([prev_inps,curr_inp.repeat(1,t+1,1)], -1)
            inp = cat([prev_inps, positions], -1)

            # if config.seq_force_ratio != 1 and t>=2:
            #     seq_force_ratio = config.seq_force_ratio**t
            #     inp *= seq_force_ratio
            #     inp +=

            #print('inp size:',inp.size())

            enced = prop_model(enc, inp)

            # print('enced size:', enced.size())

            attn_inp = (softmax(enced, 1) * prev_inps).sum(1)

            # print('attnded size:', attn_inp.size())

            deced = prop_model(dec, attn_inp)

            loss += sequence_loss(lbl, deced)

            responses[-1].append(deced)
            # input("halt here")

        #input('halt here..')

    if training_run:
        loss.backward()
        return float(loss)

    else:

        #print("seq size", sequence.size(1), 'hm resps', len(responses[-1]))

        if len(sequences) == 1:

            for t_extra in range(extra_steps):
                t = sequence.size(1) + t_extra - 1

                #print(f't extra:{t}')

                curr_inp = responses[-1][t - 1]

                # print(sequence[:,:,:].size(), stack(responses[-1][sequence.size(1)-1-1:],1).size())

                prev_inps = cat([
                    sequence[:, :-1, :],
                    stack(responses[-1][sequence.size(1) - 1 - 1:], 1)
                ], 1)

                inp = cat([prev_inps, curr_inp.repeat(1, t + 1, 1)], -1)

                #print(inp.size())

                enced = prop_model(enc, inp)

                # print('enced size:', enced.size())

                attn_inp = (softmax(enced, 1) * prev_inps).sum(1)

                # print('attnded size:', attn_inp.size())

                deced = prop_model(dec, attn_inp)

                responses[-1].append(deced)

            responses = responses[-1]
            responses = [(deconvolver[0].w * resp).sum(1)
                         for resp in responses]
            responses = [resp * ihann_w for resp in responses]
            hm_windows = (len(sequence) -
                          config.frame_len // config.frame_stride) + 1

            responses = []  # todo: stitch together responses here..
            responses = Tensor(responses).view(1, 1, -1)

        return float(loss), responses
Example #12
def create_mask(n, d, f=torch.eq):
    a = torch.nonzero(torch.ones([n] * d)).view(-1, 1, d).float()
    w = torch.tensor([1, -1]).view(1, 1, -1).float()
    sel = (f(torch.conv1d(a, w, None, 1, 0, 1, 1), 0)).all(2).view([n] * d)
    return sel.byte()
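
create_mask enumerates every index tuple of an n^d grid and convolves it with the kernel [1, -1], i.e. it computes the differences between consecutive coordinates; f then turns those differences into a predicate. With torch.eq the result is the diagonal; with torch.le it is the triangle where indices are non-decreasing:

print(create_mask(3, 2))
# tensor([[1, 0, 0],
#         [0, 1, 0],
#         [0, 0, 1]], dtype=torch.uint8)
print(create_mask(3, 2, f=torch.le))  # upper triangle: i <= j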
Example #13
import torch
from onlinefy.marked_tensor import MarkedTensor
from onlinefy.inject import marked_prop_wrapper

torch.conv1d = marked_prop_wrapper(torch.conv1d)
torch.sum = marked_prop_wrapper(torch.sum)
    

a = torch.ones(1,3,4, requires_grad=True)
b = MarkedTensor(a, marked_dim=1)
w = torch.ones(3,3,3, requires_grad=True)
c = torch.conv1d(b,w)
d = MarkedTensor(c, marked_dim=1)
e = torch.sum(d)
e.backward()
Example #14
def decode_clique_final(data,
                        probabilities,
                        draw=False,
                        weight_factor=0.0,
                        clique_number_bounds=None,
                        fig=None,
                        device='cpu'):
    row, col = data.edge_index
    sets = probabilities.detach().unsqueeze(-1)
    batch = data.batch
    no_loop_index, _ = remove_self_loops(data.edge_index)
    no_loop_row, no_loop_col = no_loop_index
    num_graphs = batch.max().item() + 1
    total_index = 0

    for graph in range(num_graphs):
        mark_edges = batch[no_loop_row] == graph
        nlr_graph, nlc_graph = no_loop_index[:, mark_edges]
        nlr_graph = nlr_graph - total_index
        nlc_graph = nlc_graph - total_index
        batch_graph = (batch == graph)
        graph_probs = sets[batch_graph].detach()
        sorted_inds = torch.argsort(graph_probs.squeeze(-1), descending=True)
        pairwise_prodsums = torch.zeros(1, device=device)
        pairwise_prodsums = (torch.conv1d(graph_probs.unsqueeze(-1),
                                          graph_probs.unsqueeze(-1))).sum() / 2
        self_sums = (graph_probs * graph_probs).sum()
        num_nodes = batch_graph.float().sum().item()

        current_set_cardinality = 0

        for node in range(int(num_nodes)):
            ind_i = total_index + sorted_inds[node]
            graph_probs_0 = sets[batch_graph].detach()
            graph_probs_1 = sets[batch_graph].detach()

            graph_probs_0[sorted_inds[node]] = 0
            graph_probs_1[sorted_inds[node]] = 1

            pairwise_prodsums_0 = torch.zeros(1, device=device)
            pairwise_prodsums_0 = (torch.conv1d(
                graph_probs_0.unsqueeze(-1),
                graph_probs_0.unsqueeze(-1))).sum() / 2

            self_sums_0 = (graph_probs_0 * graph_probs_0).sum()

            expected_weight_G_0 = (graph_probs_0[nlr_graph] *
                                   graph_probs_0[nlc_graph]).sum() / 2
            expected_clique_weight_0 = (pairwise_prodsums_0 - self_sums_0)
            clique_dist_0 = weight_factor * 0.5 * (
                expected_clique_weight_0 -
                expected_weight_G_0) - expected_weight_G_0

            pairwise_prodsums_1 = torch.zeros(1, device=device)
            pairwise_prodsums_1 = (torch.conv1d(
                graph_probs_1.unsqueeze(-1),
                graph_probs_1.unsqueeze(-1))).sum() / 2

            self_sums_1 = (graph_probs_1 * graph_probs_1).sum()

            expected_weight_G_1 = (graph_probs_1[nlr_graph] *
                                   graph_probs_1[nlc_graph]).sum() / 2
            expected_clique_weight_1 = (pairwise_prodsums_1 - self_sums_1)
            clique_dist_1 = weight_factor * 0.5 * (
                expected_clique_weight_1 -
                expected_weight_G_1) - expected_weight_G_1

            if clique_dist_0 >= clique_dist_1:
                decided = (graph_probs_1 == 1).float()
                current_set_cardinality = decided.sum().item()
                current_set_max_edges = (current_set_cardinality *
                                         (current_set_cardinality - 1)) / 2
                current_set_edges = (decided[nlr_graph] *
                                     decided[nlc_graph]).sum() / 2

                if (current_set_edges != current_set_max_edges):
                    sets[ind_i] = 0  #IF NOT A CLIQUE
                else:
                    sets[ind_i] = 1  #IF A CLIQUE

            else:
                sets[ind_i] = 0

        if draw:
            dirac = data.locations[graph].item() - total_index
            if fig is None:
                f1 = plt.figure(graph, figsize=(16, 9))
            else:
                f1 = fig
            ax1 = f1.add_subplot(121)
            g1, g2 = drawGraphFromData(data.to('cpu'),
                                       graph,
                                       vals=sets.squeeze(-1).detach().cpu(),
                                       dense=False,
                                       seed=dirac,
                                       nodecolor=True,
                                       edgecolor=False,
                                       seedhops=True,
                                       hoplabels=True,
                                       binarycut=False)
            ax2 = f1.add_subplot(122)
            g1, g2 = drawGraphFromData(data.to('cpu'),
                                       graph,
                                       vals=probabilities.detach().cpu(),
                                       dense=False,
                                       seed=dirac,
                                       nodecolor=True,
                                       edgecolor=False,
                                       seedhops=True,
                                       hoplabels=True,
                                       binarycut=False,
                                       clique=True)
        total_index += num_nodes

    expected_weight_G = scatter_add(sets[no_loop_col] * sets[no_loop_row],
                                    batch[no_loop_row],
                                    0,
                                    dim_size=num_graphs)
    set_cardinality = scatter_add(sets, batch, 0, dim_size=num_graphs)
    return sets, expected_weight_G.detach(), set_cardinality
Example #15
    def forward(self, data, edge_dropout = None, penalty_coefficient = 0.25):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch
        num_graphs = batch.max().item() + 1
        row, col = edge_index     
        total_num_edges = edge_index.shape[1]
        N_size = x.shape[0]

        
        if edge_dropout is not None:
            edge_index = dropout_adj(edge_index, edge_attr = (torch.ones(edge_index.shape[1], device=device)).long(), p = edge_dropout, force_undirected=True)[0]
            edge_index = add_remaining_self_loops(edge_index, num_nodes = batch.shape[0])[0]
                
        reduced_num_edges = edge_index.shape[1]
        current_edge_percentage = (reduced_num_edges/total_num_edges)
        no_loop_index,_ = remove_self_loops(edge_index)  
        no_loop_row, no_loop_col = no_loop_index

        xinit= x.clone()
        x = x.unsqueeze(-1)
        mask = get_mask(x,edge_index,1).to(x.dtype)
        x = F.leaky_relu(self.conv1(x, edge_index))# +x
        x = x*mask
        x = self.gnorm(x)
        x = self.bn1(x)
        
            
        for conv, bn in zip(self.convs, self.bns):
            if(x.dim()>1):
                x =  x+F.leaky_relu(conv(x, edge_index))
                mask = get_mask(mask,edge_index,1).to(x.dtype)
                x = x*mask
                x = self.gnorm(x)
                x = bn(x)

        xpostconvs = x.detach()
        #
        x = F.leaky_relu(self.lin1(x)) 
        x = x*mask


        xpostlin1 = x.detach()
        x = F.leaky_relu(self.lin2(x)) 
        x = x*mask


        #calculate min and max
        batch_max = scatter_max(x, batch, 0, dim_size= N_size)[0]
        batch_max = torch.index_select(batch_max, 0, batch)        
        batch_min = scatter_min(x, batch, 0, dim_size= N_size)[0]
        batch_min = torch.index_select(batch_min, 0, batch)

        #min-max normalize
        x = (x-batch_min)/(batch_max+1e-6-batch_min)
        probs=x
           
        x2 = x.detach()              
        deg = degree(row).unsqueeze(-1) 
        totalvol = scatter_add(deg.detach()*torch.ones_like(x, device=device), batch, 0)+1e-6
        totalcard = scatter_add(torch.ones_like(x, device=device), batch, 0)+1e-6               
        x2 =  ((x2 - torch.rand_like(x, device = device))>0).float()    
        vol_1 = scatter_add(probs*deg, batch, 0)+1e-6
        card_1 = scatter_add(probs, batch,0)            
        set_size = scatter_add(x2, batch, 0)
        vol_hard = scatter_add(deg*x2, batch, 0, dim_size = batch.max().item()+1)+1e-6 
        total_vol_ratio = vol_hard/totalvol
        
        
        #calculating the terms for the expected distance between clique and graph
        pairwise_prodsums = torch.zeros(num_graphs, device = device)
        for graph in range(num_graphs):
            batch_graph = (batch==graph)
            pairwise_prodsums[graph] = (torch.conv1d(probs[batch_graph].unsqueeze(-1), probs[batch_graph].unsqueeze(-1))).sum()/2
        
        
        ###calculate loss terms
        self_sums = scatter_add((probs*probs), batch, 0, dim_size = num_graphs)
        expected_weight_G = scatter_add(probs[no_loop_row]*probs[no_loop_col], batch[no_loop_row], 0, dim_size = num_graphs)/2.
        expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) - self_sums)/1.
        expected_distance = (expected_clique_weight - expected_weight_G)        
        
        
        ###useful numbers 
        max_set_weight = (scatter_add(torch.ones_like(x)[no_loop_row], batch[no_loop_row], 0, dim_size = num_graphs)/2).squeeze(-1)                
        set_weight = (scatter_add(x2[no_loop_row]*x2[no_loop_col], batch[no_loop_row], 0, dim_size = num_graphs)/2)+1e-6
        clique_edges_hard = (set_size*(set_size-1)/2) +1e-6
        clique_dist_hard = set_weight/clique_edges_hard
        clique_check = ((clique_edges_hard != clique_edges_hard))
        setedge_check  = ((set_weight != set_weight))      
        
        assert ((clique_dist_hard>=1.1).sum())<=1e-6, "Invalid set vol/clique vol ratio."

        ###calculate loss
        expected_loss = (penalty_coefficient)*expected_distance*0.5 - 0.5*expected_weight_G  
        

        loss = expected_loss


        retdict = {}
        
        retdict["output"] = [probs.squeeze(-1),"hist"]   #output
        retdict["Expected_cardinality"] = [card_1.mean(),"sequence"]
        retdict["Expected_cardinality_hist"] = [card_1,"hist"]
        retdict["losses histogram"] = [loss.squeeze(-1),"hist"]
        retdict["Set sizes"] = [set_size.squeeze(-1),"hist"]
        retdict["volume_hard"] = [vol_hard.mean(),"aux"] #volume2
        retdict["cardinality_hard"] = [set_size[0],"sequence"] #volumeq
        retdict["Expected weight(G)"]= [expected_weight_G.mean(), "sequence"]
        retdict["Expected maximum weight"] = [expected_clique_weight.mean(),"sequence"]
        retdict["Expected distance"]= [expected_distance.mean(), "sequence"]
        retdict["Currvol/Cliquevol"] = [clique_dist_hard.mean(),'sequence']
        retdict["Currvol/Cliquevol all graphs in batch"] = [clique_dist_hard.squeeze(-1),'hist']
        retdict["Average ratio of total volume"]= [total_vol_ratio.mean(),'sequence']
        retdict["cardinalities"] = [cardinalities.squeeze(-1),"hist"]
        retdict["Current edge percentage"] = [torch.tensor(current_edge_percentage),'sequence']
        retdict["loss"] = [loss.mean().squeeze(),"sequence"] #final loss

        return retdict
Example #16
import torch
from causal_conv1d import CausalConv1d

if __name__ == '__main__':
    # How do you implement a signal split like this in PyTorch?
    # nvm.
    # Note: x, W, and W2 are undefined placeholders; this is pseudocode.
    res1 = x
    x1 = torch.tanh(torch.conv1d(x, W)) * torch.sigmoid(torch.conv1d(x, W))
    x2 = torch.conv1d(x1, W2)
    x3 = x2 + x

    # incremental forward?
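
The snippet above is only a sketch; a runnable version of the gated residual unit it gestures at (WaveNet-style, with assumed shapes and separate filter/gate/residual kernels) might look like:

import torch

x = torch.randn(1, 8, 16)   # assumed: batch 1, 8 channels, length 16
W_f = torch.randn(8, 8, 3)  # filter kernel
W_g = torch.randn(8, 8, 3)  # gate kernel
W_r = torch.randn(8, 8, 1)  # 1x1 residual projection

z = torch.tanh(torch.conv1d(x, W_f, padding=1)) * \
    torch.sigmoid(torch.conv1d(x, W_g, padding=1))
out = torch.conv1d(z, W_r) + x  # residual connection
print(out.shape)  # torch.Size([1, 8, 16])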
Example #17
    def conv1d_side_effect(x, weights, bias, stride, **kwargs):
        return torch.conv1d(x, weights, bias, stride)
Example #18
def conv1d(input, *args, **kwargs):
    return torch.conv1d(input.q, *args, **kwargs)