def conv1d_weight(input, weight_size, grad_output, stride=1, padding=0,
                  dilation=1, groups=1):
    r"""Computes the gradient of conv1d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iW)
        weight_size: Shape of the weight gradient tensor
        grad_output: output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1

    Examples::

        >>> input = torch.randn(1, 1, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv1d_weight(input, weight.shape, grad_output)
    """
    stride = _single(stride)
    padding = _single(padding)
    dilation = _single(dilation)

    # Recompute the forward pass on a fresh weight tensor and let autograd
    # produce the weight gradient.
    input = input.detach()
    weight = torch.empty(weight_size, dtype=input.dtype, device=input.device,
                         requires_grad=True)
    with torch.enable_grad():
        result = torch.conv1d(input, weight, None, stride, padding, dilation,
                              groups)
        result.backward(grad_output)
    return weight.grad
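# A minimal sanity check for conv1d_weight above (a sketch, assuming the
# function and its `_single` helper are in scope): the autograd-based result
# should match the gradient that torch.autograd.grad computes directly.
import torch
import torch.nn.functional as F

inp = torch.randn(2, 3, 8, requires_grad=True)
w = torch.randn(4, 3, 3, requires_grad=True)
out = F.conv1d(inp, w, stride=1, padding=1)
g_out = torch.randn_like(out)

expected = torch.autograd.grad(out, w, g_out)[0]
actual = conv1d_weight(inp, w.shape, g_out, stride=1, padding=1)
assert torch.allclose(expected, actual, atol=1e-5)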
def forward(self, x):
    # Sine/cosine analysis front end: two strided convolutions whose outputs
    # are summed before the rest of the network.
    x1 = torch.conv1d(x, self.wsin_var, stride=2)
    x2 = torch.conv1d(x, self.wcos_var, stride=2)
    x = x1 + x2
    x = F.relu(x)
    x = F.relu(self.conv2(x))
    x = self.conv2_drop(x)
    x = F.relu(self.conv3(x))
    x = self.conv3_drop(x)
    x = x.view(-1, 64)
    x = F.relu(self.lin1(x))
    x = self.lin1_drop(x)
    x = self.lin2(x)
    x = F.relu(x)
    return x
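# The wsin_var / wcos_var banks in the forward pass above are presumably
# Fourier-basis kernels; this is a hedged sketch (an assumption, not the
# original code) of how such kernel banks might be constructed and applied.
import math
import torch

n_bins, k_len = 64, 256
t = torch.arange(k_len, dtype=torch.float32)
freqs = torch.arange(n_bins, dtype=torch.float32).unsqueeze(1)
wsin = torch.sin(2 * math.pi * freqs * t / k_len).unsqueeze(1)  # (n_bins, 1, k_len)
wcos = torch.cos(2 * math.pi * freqs * t / k_len).unsqueeze(1)

x = torch.randn(1, 1, 4096)
spec = torch.conv1d(x, wsin, stride=2) + torch.conv1d(x, wcos, stride=2)
print(spec.shape)  # (1, n_bins, output length)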
def conv1d_weight(input, weight_size, grad_output, stride=1, padding=0,
                  dilation=1, groups=1, bias=None):
    r"""Computes the gradient of conv1d with respect to the weight of the convolution.

    Args:
        input: input tensor of shape (minibatch x in_channels x iW)
        weight_size: Shape of the weight gradient tensor
        grad_output: output gradient tensor (minibatch x out_channels x oW)
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias: optional bias tensor (out_channels). Default: None

    Examples::

        >>> input = torch.randn(1, 1, 3, requires_grad=True)
        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
        >>> output = F.conv1d(input, weight)
        >>> grad_output = torch.randn(output.shape)
        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
        >>> F.grad.conv1d_weight(input, weight.shape, grad_output)
    """
    stride = _single(stride)
    padding = _single(padding)
    dilation = _single(dilation)
    in_channels = input.shape[1]
    out_channels = grad_output.shape[1]
    min_batch = input.shape[0]

    # Express the weight gradient as a convolution of the input with the
    # output gradient: fold the batch into the channel dimension and use one
    # group per (batch, channel) pair. Note that stride and dilation swap
    # roles in this formulation.
    grad_output = grad_output.contiguous().repeat(1, in_channels // groups, 1)
    grad_output = grad_output.contiguous().view(
        grad_output.shape[0] * grad_output.shape[1], 1, grad_output.shape[2])

    input = input.contiguous().view(1, input.shape[0] * input.shape[1],
                                    input.shape[2])

    grad_weight = torch.conv1d(input, grad_output, bias, dilation, padding,
                               stride, in_channels * min_batch)

    grad_weight = grad_weight.contiguous().view(
        min_batch, grad_weight.shape[1] // min_batch, grad_weight.shape[2])

    return grad_weight.sum(dim=0).view(
        in_channels // groups, out_channels,
        grad_weight.shape[2]).transpose(0, 1).narrow(2, 0, weight_size[2])
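# A quick check of the direct (convolution-based) formulation above — a
# sketch, assuming this conv1d_weight and _single are in scope. With the
# default arguments it should reproduce the autograd weight gradient.
import torch
import torch.nn.functional as F

inp = torch.randn(2, 4, 10, requires_grad=True)
w = torch.randn(5, 4, 3, requires_grad=True)
out = F.conv1d(inp, w)
g_out = torch.randn_like(out)

expected = torch.autograd.grad(out, w, g_out)[0]
actual = conv1d_weight(inp, w.shape, g_out)
assert torch.allclose(expected, actual, atol=1e-4)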
def circular_convolve(weight, shift):
    r"""Perform a circular convolution. Taken from [1].

    Args:
        weight (torch.Tensor): weight vector.
        shift (torch.Tensor): shift vector.

    Returns:
        torch.Tensor: convolved weight vector.

    References:
        [1] https://github.com/loudinthecloud/pytorch-ntm/blob/master/ntm/memory.py
    """
    # Wrap one element from each end so the valid convolution behaves circularly.
    aug_weight = torch.cat([weight[-1:], weight, weight[:1]])
    conv = torch.conv1d(aug_weight.view(1, 1, -1),
                        shift.view(1, 1, -1)).view(-1)
    return conv
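# Usage sketch for circular_convolve: a one-hot 3-element shift kernel
# rotates an attention weighting by one position, as in an NTM memory head.
import torch

w = torch.tensor([0.1, 0.7, 0.1, 0.1])
s = torch.tensor([0.0, 0.0, 1.0])  # one-hot shift
shifted = circular_convolve(w, s)  # same length as w, rotated by one
assert torch.allclose(shifted, torch.tensor([0.7, 0.1, 0.1, 0.1]))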
def classicalLoss(pred, targets):
    # Transform targets into a one-hot distribution.
    target_distribution = torch.cuda.FloatTensor(
        targets.size()[0], targets.size()[1], pred.size()[2]).fill_(0)
    for ii in range(target_distribution.size()[2]):
        target_distribution[:, :, ii] = (targets == ii).type(torch.LongTensor)

    # Loss computation: smooth the one-hot targets with `kernel`, then take
    # the mean squared error against the predictions.
    target_smooth = torch.conv1d(
        target_distribution.permute(0, 2, 1).reshape(-1, 1, 64),
        kernel,
        padding=21 // 2).reshape(pred.size()[0], pred.size()[2],
                                 pred.size()[1]).permute(0, 2, 1)
    loss_smooth = torch.mean(torch.sum((pred - target_smooth)**2, 2))
    return loss_smooth, 0 * loss_smooth
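# `kernel` above is defined elsewhere; a plausible construction (an
# assumption, not the original) is a normalized Gaussian of width 21, shaped
# (out_channels=1, in_channels=1, kW=21) as torch.conv1d expects. With
# padding=21 // 2 the smoothed output keeps the original length.
import torch

ks, sigma = 21, 2.0
xs = torch.arange(ks, dtype=torch.float32) - ks // 2
kernel = torch.exp(-xs**2 / (2 * sigma**2))
kernel = (kernel / kernel.sum()).view(1, 1, ks).cuda()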
def forward(self, x):
    # Construct kernel
    x_shape = x.shape
    rel_pos = self.handle_rel_positions(x)
    conv_kernel = self.Kernel(rel_pos).view(-1, x_shape[1], *x_shape[2:])

    # ---- Different sampling rates --------
    # If the sampling rate at test time differs from that at training time,
    # smooth out high-frequency elements of the kernel with a Gaussian.
    if self.sigma is not None:
        from math import pi, sqrt, exp
        n = int(1 / self.sr_change) * 2 + 1
        h = n // 2
        G = (lambda x: 1 / (self.sigma * sqrt(2 * pi)) * exp(
            -float(x)**2 / (2 * self.sigma**2)))
        smoothing_ker = [G(x) for x in range(-h, h + 1)]
        smoothing_ker = torch.Tensor(smoothing_ker).cuda().unsqueeze(
            0).unsqueeze(0)
        conv_kernel[:, :, h:-h] = torch.conv1d(
            conv_kernel.view(-1, 1, *x_shape[2:]), smoothing_ker,
            padding=0).view(*conv_kernel.shape[:-1], -1)

    # Multiply by sr_train / sr_test.
    if self.sr_change != 1.0:
        conv_kernel *= self.sr_change
    # ------------------------------------

    # For computation of "weight_decay"
    self.conv_kernel = conv_kernel

    # We have noticed that the results of fftconv become very noisy when the
    # length of the input is very small (< 50 samples). As this might occur
    # when we use subsampling, we replace causal_fftconv by causal_conv in
    # settings where this occurs.
    if x_shape[-1] < self.train_length.item():
        # Use spatial convolution:
        return ckconv_f.causal_conv(x, conv_kernel, self.bias)
    else:
        # Otherwise use fft convolution:
        return ckconv_f.causal_fftconv(x, conv_kernel, self.bias)
def MonotonicAttention(self, dataInput, seqInput, hiddenNoduleNumbers):
    attentionNumeratorWeight = self.attentionWeightNumeratorLayer(
        input=dataInput).tanh()
    attentionDenominatorRawWeight = self.attentionWeightDenominatorLayer(
        input=dataInput).exp()

    # Left-pad the denominator so a sliding-window sum over the attention
    # scope can be computed with a 1D convolution against sumKernel.
    padDenominatorZero = torch.zeros(size=[
        attentionDenominatorRawWeight.size()[0], self.attentionScope - 1,
        attentionDenominatorRawWeight.size()[2]
    ])
    if self.cudaFlag:
        padDenominatorZero = padDenominatorZero.cuda()
        self.sumKernel = self.sumKernel.float().cuda()
    attentionDenominatorSupplementWeight = torch.cat(
        [padDenominatorZero, attentionDenominatorRawWeight], dim=1)
    attentionDenominatorWeight = torch.conv1d(
        input=attentionDenominatorSupplementWeight.permute(0, 2, 1),
        weight=self.sumKernel, stride=1)
    attentionOriginWeight = torch.div(attentionNumeratorWeight.squeeze(),
                                      attentionDenominatorWeight.squeeze())

    #########################################################

    if seqInput is not None:
        attentionMaskWeight = attentionOriginWeight.min(
            self.AttentionMask(seqInput=seqInput))
    else:
        attentionMaskWeight = attentionOriginWeight
    attentionWeight = torch.nn.functional.softmax(
        attentionMaskWeight, dim=-1).view([len(dataInput), -1, 1])
    attentionSupplementWeight = attentionWeight.repeat(
        [1, 1, hiddenNoduleNumbers])
    attentionSeparateResult = torch.mul(dataInput, attentionSupplementWeight)
    attentionResult = attentionSeparateResult.sum(dim=1)
    return attentionResult, attentionWeight
def forward(self, hidden_states, attention_mask):
    batch_size, seq_len, hidden_size = hidden_states.size()

    output_states = self.in_projection(hidden_states)
    output_states = output_states.permute(0, 2, 1)

    # Normalize the convolution weights over the kernel dimension.
    weight = torch.softmax(self.weight, dim=-1)
    weight = self.conv_weight_dropout(weight)

    if attention_mask is not None:
        # Zero out future taps of the kernel so the convolution is causal.
        pivot = self.kernel_size // 2 + 1
        weight[:, :, pivot:] = 0

    output_states = output_states.contiguous().view(-1, self.num_heads,
                                                    seq_len)
    output_states = torch.conv1d(output_states, weight,
                                 padding=self.kernel_size // 2,
                                 groups=self.num_heads)
    output_states = output_states.view(batch_size, hidden_size, seq_len)
    output_states = output_states.permute(0, 2, 1)

    # output projection
    output_states = self.out_projection(output_states)
    output_states = self.conv_layer_dropout(output_states)
    output_states = self.layer_norm(hidden_states + output_states)

    return output_states
def forward(self, data, edge_dropout=None, penalty_coefficient=0.25):
    x = data.x
    edge_index = data.edge_index
    batch = data.batch
    num_graphs = batch.max().item() + 1
    row, col = edge_index
    total_num_edges = edge_index.shape[1]
    N_size = x.shape[0]

    if edge_dropout is not None:
        edge_index = dropout_adj(
            edge_index,
            edge_attr=(torch.ones(edge_index.shape[1], device=device)).long(),
            p=edge_dropout,
            force_undirected=True)[0]
        edge_index = add_remaining_self_loops(edge_index,
                                              num_nodes=batch.shape[0])[0]

    reduced_num_edges = edge_index.shape[1]
    current_edge_percentage = (reduced_num_edges / total_num_edges)
    no_loop_index, _ = remove_self_loops(edge_index)
    no_loop_row, no_loop_col = no_loop_index

    xinit = x.clone()
    x = x.unsqueeze(-1)
    mask = get_mask(x, edge_index, 1).to(x.dtype)
    x = F.leaky_relu(self.conv1(x, edge_index))  # +x
    x = x * mask
    x = self.gnorm(x)
    x = self.bn1(x)

    for conv, bn in zip(self.convs, self.bns):
        if x.dim() > 1:
            x = x + F.leaky_relu(conv(x, edge_index))
            mask = get_mask(mask, edge_index, 1).to(x.dtype)
            x = x * mask
            x = self.gnorm(x)
            x = bn(x)

    xpostconvs = x.detach()

    x = F.leaky_relu(self.lin1(x))
    x = x * mask

    xpostlin1 = x.detach()
    x = F.leaky_relu(self.lin2(x))
    x = x * mask

    # calculate min and max
    batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
    batch_max = torch.index_select(batch_max, 0, batch)
    batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
    batch_min = torch.index_select(batch_min, 0, batch)

    # min-max normalize
    x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
    probs = x

    # calculating the terms for the expected distance between clique and graph
    pairwise_prodsums = torch.zeros(num_graphs, device=device)
    for graph in range(num_graphs):
        batch_graph = (batch == graph)
        pairwise_prodsums[graph] = (torch.conv1d(
            probs[batch_graph].unsqueeze(-1),
            probs[batch_graph].unsqueeze(-1))).sum() / 2

    ### calculate loss terms
    self_sums = scatter_add((probs * probs), batch, 0, dim_size=num_graphs)
    expected_weight_G = scatter_add(
        probs[no_loop_row] * probs[no_loop_col], batch[no_loop_row], 0,
        dim_size=num_graphs) / 2.
    expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) - self_sums) / 1.
    expected_distance = (expected_clique_weight - expected_weight_G)

    ### calculate loss
    expected_loss = (penalty_coefficient
                     ) * expected_distance * 0.5 - 0.5 * expected_weight_G

    loss = expected_loss

    retdict = {}
    retdict["output"] = [probs.squeeze(-1), "hist"]  # output
    retdict["losses histogram"] = [loss.squeeze(-1), "hist"]
    retdict["Expected weight(G)"] = [expected_weight_G.mean(), "sequence"]
    retdict["Expected maximum weight"] = [
        expected_clique_weight.mean(), "sequence"
    ]
    retdict["Expected distance"] = [expected_distance.mean(), "sequence"]
    retdict["loss"] = [loss.mean().squeeze(), "sequence"]  # final loss

    return retdict
def update(self, x):
    # Prepare the inputs
    y = self.similarity(x, self.weight)
    t = self.teacher_signal
    if t is not None:
        t = t.unsqueeze(2).unsqueeze(3) * torch.ones_like(y, device=y.device)
    y = y.permute(0, 2, 3, 1).contiguous().view(-1, self.weight.size(0))
    if t is not None:
        t = t.permute(0, 2, 3, 1).contiguous().view(-1, self.weight.size(0))
    x_unf = unfold_map2d(x, self.weight.size(2), self.weight.size(3))
    x_unf = x_unf.permute(0, 2, 3, 1, 4).contiguous().view(y.size(0), 1, -1)

    # Random abstention
    if self.random_abstention:
        abst_prob = self.victories_count / (
            self.victories_count.max() + y.size(0) / y.size(1)).clamp(1)
        scores = y * (torch.rand_like(abst_prob, device=y.device) >=
                      abst_prob).float().unsqueeze(0)
    else:
        scores = y

    # Competition. The returned winner_mask is a bitmap telling where a
    # neuron won and where one lost.
    if self.competitive:
        if t is not None:
            scores *= t
        winner_mask = (scores == scores.max(1, keepdim=True)[0]).float()
        if self.random_abstention:  # Update statistics if using random abstention
            winner_mask_sum = winner_mask.sum(0)  # Number of inputs over which a neuron won
            self.victories_count += winner_mask_sum
            self.victories_count -= self.victories_count.min().item()
    else:
        winner_mask = torch.ones_like(y, device=y.device)

    # Lateral feedback
    if self.lfb_on:
        lfb_kernel = self.lfb_kernel
        if self.lfb_value == self.LFB_DoG or self.lfb_value == self.LFB_DoE:
            # Difference of Gaussians/Exponentials (mexican-hat-shaped function)
            lfb_kernel = 2 * lfb_kernel - lfb_kernel.pow(0.5)
        lfb_in = F.pad(winner_mask.view(-1, *self.out_size), self.pad)
        if self.out_size.size(0) == 1:
            lfb_out = torch.conv1d(lfb_in.unsqueeze(1),
                                   lfb_kernel.unsqueeze(0).unsqueeze(1))
        elif self.out_size.size(0) == 2:
            lfb_out = torch.conv2d(lfb_in.unsqueeze(1),
                                   lfb_kernel.unsqueeze(0).unsqueeze(1))
        else:
            lfb_out = torch.conv3d(lfb_in.unsqueeze(1),
                                   lfb_kernel.unsqueeze(0).unsqueeze(1))
        lfb_out = lfb_out.clamp(-1, 1).view_as(y)
    else:
        lfb_out = winner_mask
        if self.competitive:
            lfb_out[lfb_out == 0] = self.lfb_value
        elif t is not None:
            lfb_out = t

    # Compute step modulation coefficient
    r = lfb_out  # RULE_BASE
    if self.weight_upd_rule == self.RULE_HEBB:
        r *= y

    # Compute delta
    r_abs = r.abs()
    r_sign = r.sign()
    delta_w = r_abs.unsqueeze(2) * (
        r_sign.unsqueeze(2) * x_unf -
        self.weight.view(1, self.weight.size(0), -1))

    # Since we use batches of inputs, we need to aggregate the different
    # update steps of each kernel in a unique update. We do this by taking
    # the weighted average of the steps, the weights being the r coefficients
    # that determine the length of each step.
    r_sum = r_abs.sum(0)
    r_sum += (r_sum == 0).float()  # Prevent divisions by zero
    delta_w_avg = (delta_w * r_abs.unsqueeze(2)).sum(0) / r_sum.unsqueeze(1)

    # Apply delta
    self.weight += self.eta * delta_w_avg.view_as(self.weight)

    # LFB kernel shrinking and LR schedule
    if self.lfb_on:
        self.lfb_kernel = self.lfb_kernel.pow(self.alpha)
    if self.lr_schedule is not None:
        self.eta = self.lr_schedule(self.eta)
def respond_to(model, sequences, training_run=True, extra_steps=0):
    responses = [[] for _ in range(len(sequences))]
    loss = 0
    convolver, enc, dec, deconvolver = model
    hann_w = hann() if not config.use_gpu else hann().cuda()
    ihann_w = ihann() if not config.use_gpu else ihann().cuda()

    for i, sequence in enumerate(sequences):
        # Analysis: Hann-windowed, strided convolution of the raw sequence.
        sequence = conv1d(sequence, (convolver[0].w * hann_w).unsqueeze(1),
                          stride=config.frame_stride)
        sequence = transpose(sequence, 1, 2)
        sequence /= config.frame_len

        # Predict each frame from all preceding frames plus position features.
        for t in range(sequence.size(1) - 1):
            prev_inps = sequence[:, :t + 1, :]
            lbl = sequence[:, t + 1:t + 2, :]
            # NB: 't + 1 / config.max_T' may be missing parentheses
            # around (t + 1); kept as in the original.
            positions = Tensor([[t + 1 / config.max_T, i / config.max_T]
                                for i in range(t + 1)]).view(1, -1, 2)
            if config.use_gpu:
                positions = positions.cuda()

            inp = cat([prev_inps, positions], -1)

            enced = prop_model(enc, inp)
            attn_inp = (softmax(enced, 1) * prev_inps).sum(1)
            deced = prop_model(dec, attn_inp)

            loss += sequence_loss(lbl, deced)
            responses[-1].append(deced)

    if training_run:
        loss.backward()
        return float(loss)
    else:
        # Free-running generation for extra_steps beyond the given sequence.
        if len(sequences) == 1:
            for t_extra in range(extra_steps):
                t = sequence.size(1) + t_extra - 1
                curr_inp = responses[-1][t - 1]
                prev_inps = cat([
                    sequence[:, :-1, :],
                    stack(responses[-1][sequence.size(1) - 1 - 1:], 1)
                ], 1)
                inp = cat([prev_inps, curr_inp.repeat(1, t + 1, 1)], -1)
                enced = prop_model(enc, inp)
                attn_inp = (softmax(enced, 1) * prev_inps).sum(1)
                deced = prop_model(dec, attn_inp)
                responses[-1].append(deced)

        # Synthesis: project each response back to the time domain.
        responses = responses[-1]
        responses = [(deconvolver[0].w * resp).sum(1) for resp in responses]
        responses = [resp * ihann_w for resp in responses]
        hm_windows = (len(sequence) - config.frame_len // config.frame_stride) + 1
        # TODO: overlap-add ("stitch together") the windowed responses here;
        # until then the returned waveform is empty.
        responses = []
        responses = Tensor(responses).view(1, 1, -1)
        return float(loss), responses
def create_mask(n, d, f=torch.eq):
    # Coordinates of every cell of an [n]*d grid, as a (n**d, 1, d) float tensor.
    a = torch.nonzero(torch.ones([n] * d)).view(-1, 1, d).float()
    # Finite-difference kernel [1, -1]: conv1d yields successive coordinate
    # differences, so f(..., 0) compares neighbouring coordinates.
    w = torch.tensor([1, -1]).view(1, 1, -1).float()
    sel = (f(torch.conv1d(a, w, None, 1, 0, 1, 1), 0)).all(2).view([n] * d)
    return sel.byte()
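# Usage sketch for create_mask: with the default f=torch.eq, cells whose
# neighbouring coordinates are equal are selected, i.e. the grid diagonal.
import torch

mask = create_mask(3, 2)  # 3x3 grid, 2 dimensions
assert torch.equal(mask, torch.eye(3).byte())  # i == j -> identity mask

lower = create_mask(3, 2, torch.ge)  # i >= j -> lower-triangular mask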
import torch

from onlinefy.marked_tensor import MarkedTensor
from onlinefy.inject import marked_prop_wrapper

# Wrap the torch ops so they propagate the marked dimension.
torch.conv1d = marked_prop_wrapper(torch.conv1d)
torch.sum = marked_prop_wrapper(torch.sum)

a = torch.ones(1, 3, 4, requires_grad=True)
b = MarkedTensor(a, marked_dim=1)
w = torch.ones(3, 3, 3, requires_grad=True)

c = torch.conv1d(b, w)
d = MarkedTensor(c, marked_dim=1)
e = torch.sum(d)
e.backward()
def decode_clique_final(data, probabilities, draw=False, weight_factor=0.0,
                        clique_number_bounds=None, fig=None, device='cpu'):
    row, col = data.edge_index
    sets = probabilities.detach().unsqueeze(-1)
    batch = data.batch
    no_loop_index, _ = remove_self_loops(data.edge_index)
    no_loop_row, no_loop_col = no_loop_index
    num_graphs = batch.max().item() + 1

    total_index = 0
    for graph in range(num_graphs):
        mark_edges = batch[no_loop_row] == graph
        nlr_graph, nlc_graph = no_loop_index[:, mark_edges]
        nlr_graph = nlr_graph - total_index
        nlc_graph = nlc_graph - total_index

        batch_graph = (batch == graph)
        graph_probs = sets[batch_graph].detach()
        sorted_inds = torch.argsort(graph_probs.squeeze(-1), descending=True)

        pairwise_prodsums = (torch.conv1d(graph_probs.unsqueeze(-1),
                                          graph_probs.unsqueeze(-1))).sum() / 2
        self_sums = (graph_probs * graph_probs).sum()
        num_nodes = batch_graph.float().sum().item()

        current_set_cardinality = 0
        # Greedily visit nodes in order of decreasing probability and decide,
        # for each, whether clamping its probability to 0 or 1 yields the
        # better clique distance.
        for node in range(int(num_nodes)):
            ind_i = total_index + sorted_inds[node]

            graph_probs_0 = sets[batch_graph].detach()
            graph_probs_1 = sets[batch_graph].detach()
            graph_probs_0[sorted_inds[node]] = 0
            graph_probs_1[sorted_inds[node]] = 1

            pairwise_prodsums_0 = (torch.conv1d(
                graph_probs_0.unsqueeze(-1),
                graph_probs_0.unsqueeze(-1))).sum() / 2
            self_sums_0 = (graph_probs_0 * graph_probs_0).sum()
            expected_weight_G_0 = (graph_probs_0[nlr_graph] *
                                   graph_probs_0[nlc_graph]).sum() / 2
            expected_clique_weight_0 = (pairwise_prodsums_0 - self_sums_0)
            clique_dist_0 = weight_factor * 0.5 * (
                expected_clique_weight_0 -
                expected_weight_G_0) - expected_weight_G_0

            pairwise_prodsums_1 = (torch.conv1d(
                graph_probs_1.unsqueeze(-1),
                graph_probs_1.unsqueeze(-1))).sum() / 2
            self_sums_1 = (graph_probs_1 * graph_probs_1).sum()
            expected_weight_G_1 = (graph_probs_1[nlr_graph] *
                                   graph_probs_1[nlc_graph]).sum() / 2
            expected_clique_weight_1 = (pairwise_prodsums_1 - self_sums_1)
            clique_dist_1 = weight_factor * 0.5 * (
                expected_clique_weight_1 -
                expected_weight_G_1) - expected_weight_G_1

            if clique_dist_0 >= clique_dist_1:
                decided = (graph_probs_1 == 1).float()
                current_set_cardinality = decided.sum().item()
                current_set_max_edges = (current_set_cardinality *
                                         (current_set_cardinality - 1)) / 2
                current_set_edges = (decided[nlr_graph] *
                                     decided[nlc_graph]).sum() / 2
                if current_set_edges != current_set_max_edges:
                    sets[ind_i] = 0  # not a clique
                else:
                    sets[ind_i] = 1  # a clique
            else:
                sets[ind_i] = 0

        if draw:
            dirac = data.locations[graph].item() - total_index
            f1 = plt.figure(graph, figsize=(16, 9)) if fig is None else fig
            ax1 = f1.add_subplot(121)
            g1, g2 = drawGraphFromData(data.to('cpu'), graph,
                                       vals=sets.squeeze(-1).detach().cpu(),
                                       dense=False, seed=dirac,
                                       nodecolor=True, edgecolor=False,
                                       seedhops=True, hoplabels=True,
                                       binarycut=False)
            ax2 = f1.add_subplot(122)
            g1, g2 = drawGraphFromData(data.to('cpu'), graph,
                                       vals=probabilities.detach().cpu(),
                                       dense=False, seed=dirac,
                                       nodecolor=True, edgecolor=False,
                                       seedhops=True, hoplabels=True,
                                       binarycut=False, clique=True)

        total_index += num_nodes

    expected_weight_G = scatter_add(sets[no_loop_col] * sets[no_loop_row],
                                    batch[no_loop_row], 0,
                                    dim_size=num_graphs)
    set_cardinality = scatter_add(sets, batch, 0, dim_size=num_graphs)

    return sets, expected_weight_G.detach(), set_cardinality
def forward(self, data, edge_dropout=None, penalty_coefficient=0.25):
    x = data.x
    edge_index = data.edge_index
    batch = data.batch
    num_graphs = batch.max().item() + 1
    row, col = edge_index
    total_num_edges = edge_index.shape[1]
    N_size = x.shape[0]

    if edge_dropout is not None:
        edge_index = dropout_adj(
            edge_index,
            edge_attr=(torch.ones(edge_index.shape[1], device=device)).long(),
            p=edge_dropout,
            force_undirected=True)[0]
        edge_index = add_remaining_self_loops(edge_index,
                                              num_nodes=batch.shape[0])[0]

    reduced_num_edges = edge_index.shape[1]
    current_edge_percentage = (reduced_num_edges / total_num_edges)
    no_loop_index, _ = remove_self_loops(edge_index)
    no_loop_row, no_loop_col = no_loop_index

    xinit = x.clone()
    x = x.unsqueeze(-1)
    mask = get_mask(x, edge_index, 1).to(x.dtype)
    x = F.leaky_relu(self.conv1(x, edge_index))  # +x
    x = x * mask
    x = self.gnorm(x)
    x = self.bn1(x)

    for conv, bn in zip(self.convs, self.bns):
        if x.dim() > 1:
            x = x + F.leaky_relu(conv(x, edge_index))
            mask = get_mask(mask, edge_index, 1).to(x.dtype)
            x = x * mask
            x = self.gnorm(x)
            x = bn(x)

    xpostconvs = x.detach()

    x = F.leaky_relu(self.lin1(x))
    x = x * mask

    xpostlin1 = x.detach()
    x = F.leaky_relu(self.lin2(x))
    x = x * mask

    # calculate min and max
    batch_max = scatter_max(x, batch, 0, dim_size=N_size)[0]
    batch_max = torch.index_select(batch_max, 0, batch)
    batch_min = scatter_min(x, batch, 0, dim_size=N_size)[0]
    batch_min = torch.index_select(batch_min, 0, batch)

    # min-max normalize
    x = (x - batch_min) / (batch_max + 1e-6 - batch_min)
    probs = x

    # hard (rounded) sets and volume statistics
    x2 = x.detach()
    deg = degree(row).unsqueeze(-1)
    totalvol = scatter_add(deg.detach() * torch.ones_like(x, device=device),
                           batch, 0) + 1e-6
    totalcard = scatter_add(torch.ones_like(x, device=device), batch, 0) + 1e-6
    x2 = ((x2 - torch.rand_like(x, device=device)) > 0).float()
    vol_1 = scatter_add(probs * deg, batch, 0) + 1e-6
    card_1 = scatter_add(probs, batch, 0)
    set_size = scatter_add(x2, batch, 0)
    vol_hard = scatter_add(deg * x2, batch, 0,
                           dim_size=batch.max().item() + 1) + 1e-6
    total_vol_ratio = vol_hard / totalvol

    # calculating the terms for the expected distance between clique and graph
    pairwise_prodsums = torch.zeros(num_graphs, device=device)
    for graph in range(num_graphs):
        batch_graph = (batch == graph)
        pairwise_prodsums[graph] = (torch.conv1d(
            probs[batch_graph].unsqueeze(-1),
            probs[batch_graph].unsqueeze(-1))).sum() / 2

    ### calculate loss terms
    self_sums = scatter_add((probs * probs), batch, 0, dim_size=num_graphs)
    expected_weight_G = scatter_add(
        probs[no_loop_row] * probs[no_loop_col], batch[no_loop_row], 0,
        dim_size=num_graphs) / 2.
    expected_clique_weight = (pairwise_prodsums.unsqueeze(-1) - self_sums) / 1.
    expected_distance = (expected_clique_weight - expected_weight_G)

    ### useful numbers
    max_set_weight = (scatter_add(torch.ones_like(x)[no_loop_row],
                                  batch[no_loop_row], 0,
                                  dim_size=num_graphs) / 2).squeeze(-1)
    set_weight = (scatter_add(x2[no_loop_row] * x2[no_loop_col],
                              batch[no_loop_row], 0,
                              dim_size=num_graphs) / 2) + 1e-6
    clique_edges_hard = (set_size * (set_size - 1) / 2) + 1e-6
    clique_dist_hard = set_weight / clique_edges_hard
    clique_check = (clique_edges_hard != clique_edges_hard)
    setedge_check = (set_weight != set_weight)

    assert ((clique_dist_hard >= 1.1).sum()) <= 1e-6, \
        "Invalid set vol/clique vol ratio."

    ### calculate loss
    expected_loss = (penalty_coefficient
                     ) * expected_distance * 0.5 - 0.5 * expected_weight_G

    loss = expected_loss

    retdict = {}
    retdict["output"] = [probs.squeeze(-1), "hist"]  # output
    retdict["Expected_cardinality"] = [card_1.mean(), "sequence"]
    retdict["Expected_cardinality_hist"] = [card_1, "hist"]
    retdict["losses histogram"] = [loss.squeeze(-1), "hist"]
    retdict["Set sizes"] = [set_size.squeeze(-1), "hist"]
    retdict["volume_hard"] = [vol_hard.mean(), "aux"]
    retdict["cardinality_hard"] = [set_size[0], "sequence"]
    retdict["Expected weight(G)"] = [expected_weight_G.mean(), "sequence"]
    retdict["Expected maximum weight"] = [expected_clique_weight.mean(),
                                          "sequence"]
    retdict["Expected distance"] = [expected_distance.mean(), "sequence"]
    retdict["Currvol/Cliquevol"] = [clique_dist_hard.mean(), "sequence"]
    retdict["Currvol/Cliquevol all graphs in batch"] = [
        clique_dist_hard.squeeze(-1), "hist"
    ]
    retdict["Average ratio of total volume"] = [total_vol_ratio.mean(),
                                                "sequence"]
    # NOTE: the original referenced an undefined 'cardinalities'; card_1 is
    # assumed here.
    retdict["cardinalities"] = [card_1.squeeze(-1), "hist"]
    retdict["Current edge percentage"] = [
        torch.tensor(current_edge_percentage), "sequence"
    ]
    retdict["loss"] = [loss.mean().squeeze(), "sequence"]  # final loss

    return retdict
import torch

from causal_conv1d import CausalConv1d

if __name__ == '__main__':
    # How does one implement such a signal split in PyTorch?
    # nvm.
    # Sketch of a WaveNet-style gated residual block; x and W are undefined
    # placeholders, and the second conv is missing its weight argument.
    res1 = x
    x1 = torch.tanh(torch.conv1d(x, W)) * torch.sigmoid(torch.conv1d(x, W))
    x2 = torch.conv1d(x1)
    x3 = x2 + x  # residual connection
    # incremental forward?
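# A runnable version of the gated block sketched above (my own sketch, not
# the original): two convolutions produce the tanh and sigmoid branches, a
# 1x1 convolution maps the gated signal back, and the input is added as a
# residual. Left-padding plus trimming the right edge keeps the convolution
# causal.
import torch
import torch.nn as nn

x = torch.randn(1, 8, 100)  # (batch, channels, time)
conv_f = nn.Conv1d(8, 8, kernel_size=3, padding=2)  # filter branch
conv_g = nn.Conv1d(8, 8, kernel_size=3, padding=2)  # gate branch
conv_out = nn.Conv1d(8, 8, kernel_size=1)

# Trim the last 2 samples so output[t] depends only on x[t-2..t].
x1 = torch.tanh(conv_f(x)[..., :-2]) * torch.sigmoid(conv_g(x)[..., :-2])
x2 = conv_out(x1)
x3 = x2 + x  # residual connection
print(x3.shape)  # torch.Size([1, 8, 100])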
def conv1d_side_effect(x, weights, bias, stride, **kwargs):
    # Remaining keyword arguments are intentionally ignored.
    return torch.conv1d(x, weights, bias, stride)
def conv1d(input, *args, **kwargs):
    # Dispatch to torch.conv1d on the wrapped tensor's underlying data (input.q).
    return torch.conv1d(input.q, *args, **kwargs)