import math
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn import Parameter, init

# Helpers referenced below (hard_shrink_relu, gumbel_softmax, dlrm, make_rotations,
# make_permutation, plot_kernels, use_cuda) are assumed to be provided elsewhere.


class LinearCapsPro(nn.Module):
    def __init__(self, in_features, num_C, num_D, eps=0.0001):
        super(LinearCapsPro, self).__init__()
        self.in_features = in_features
        self.num_C = num_C
        self.num_D = num_D
        self.eps = eps
        self.weight = Parameter(torch.Tensor(num_C, num_D, self.in_features))
        self.eye = Parameter(torch.eye(num_D), requires_grad=False)
        self.count = 0
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(2))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x):
        # W W^T per capsule, regularized with eps*I before inversion
        weight_caps = torch.matmul(self.weight, self.weight.permute(0, 2, 1))
        sigma = torch.inverse(weight_caps + self.eps * self.eye)
        # project the input onto every capsule subspace
        out = torch.matmul(x, torch.t(self.weight.view(-1, x.size(-1))))
        out = out.view(out.shape[0], self.num_C, 1, self.num_D)
        out = torch.matmul(out, sigma)
        out = torch.matmul(out, self.weight)
        out = torch.squeeze(out, dim=2)
        # squared projection length per capsule, then its square root
        out = torch.matmul(out, x.unsqueeze(dim=2)).squeeze(dim=2)
        out = torch.sqrt(out)
        return out
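# A quick shape check of the capsule projection layer above -- a minimal sketch
# (sizes are illustrative assumptions, not from the original code): a batch of
# feature vectors maps to one projection length per capsule.
caps = LinearCapsPro(in_features=128, num_C=10, num_D=16)
feats = torch.randn(4, 128)   # batch of 4 feature vectors
lengths = caps(feats)         # projection length onto each capsule subspace
print(lengths.shape)          # expected: torch.Size([4, 10])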
class EmbeddingToIndex(nn.Module):
    """Maps continuous embeddings back to the index of the nearest vocabulary embedding."""

    def __init__(self, n_tokens, embdim, _weight=None):
        super(EmbeddingToIndex, self).__init__()
        self.weight = Parameter(_weight)
        self.n_tokens = self.weight.size(0)
        self.embdim = embdim
        assert list(_weight.shape) == [n_tokens, embdim], \
            'Shape of weight does not match num_embeddings and embedding_dim'

    def forward(self, X):
        mb_size = X.size(0)
        # terms of the squared distance ||a - b||^2 = a.a - 2 a.b + b.b
        adotb = torch.matmul(X, self.weight.permute(1, 0))
        if len(X.size()) == 3:
            seqlen = X.size(1)
            adota = torch.matmul(X.view(-1, seqlen, 1, self.embdim),
                                 X.view(-1, seqlen, self.embdim, 1))
            adota = adota.view(-1, seqlen, 1).repeat(1, 1, self.n_tokens)
        else:  # len(X.size()) == 2
            adota = torch.matmul(X.view(-1, 1, self.embdim),
                                 X.view(-1, self.embdim, 1))
            adota = adota.view(-1, 1).repeat(1, self.n_tokens)
        bdotb = torch.bmm(self.weight.unsqueeze(-1).permute(0, 2, 1),
                          self.weight.unsqueeze(-1)).permute(1, 2, 0)
        if len(X.size()) == 3:
            bdotb = bdotb.repeat(mb_size, seqlen, 1)
        else:
            bdotb = bdotb.reshape(1, self.n_tokens).repeat(mb_size, 1)
        dist = adota - 2 * adotb + bdotb
        # index of the closest vocabulary embedding along the last dimension
        return torch.min(dist, dim=len(dist.size()) - 1)[1]
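# A small usage sketch for EmbeddingToIndex, assuming a toy vocabulary built from a
# standard nn.Embedding (names and sizes are illustrative only): the exact embedding
# of a token is mapped back to that token's index.
vocab = nn.Embedding(100, 32)                          # 100 tokens, 32-dim embeddings
to_index = EmbeddingToIndex(100, 32, _weight=vocab.weight.data.clone())
tokens = torch.tensor([[3, 17, 42]])                   # (batch=1, seqlen=3)
emb = vocab(tokens)                                    # (1, 3, 32)
recovered = to_index(emb)                              # nearest-neighbour lookup
print(recovered)                                       # expected: tensor([[ 3, 17, 42]])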
class Maxout(nn.Module):
    __constants__ = ['bias']

    def __init__(self, in_features, out_features, pieces, bias=True):
        super(Maxout, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.pieces = pieces
        self.weight = Parameter(torch.Tensor(pieces, out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(pieces, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        output = input.matmul(self.weight.permute(0, 2, 1)).permute((1, 0, 2)) + self.bias
        output = torch.max(output, dim=1)[0]
        return output
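# A maxout unit computes max_i(W_i x + b_i) over the `pieces` affine maps.
# Minimal usage sketch of the layer above (sizes are arbitrary assumptions).
maxout = Maxout(in_features=64, out_features=32, pieces=4)
xm = torch.randn(8, 64)        # batch of 8
ym = maxout(xm)                # element-wise max over the 4 affine pieces
print(ym.shape)                # expected: torch.Size([8, 32])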
class SymLinear(nn.Module):
    '''Linear with symmetric weight matrices'''

    def __init__(self, in_features, out_features, bias=True):
        super(SymLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        weight = self.weight + self.weight.permute(1, 0)
        return F.linear(input, weight, self.bias)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)
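# Since the effective weight is W + W^T, SymLinear is only well defined when
# in_features == out_features. A short sketch (sizes assumed) confirming the
# effective weight matrix is symmetric.
sym = SymLinear(16, 16)                        # square, so W + W^T is well defined
W_eff = sym.weight + sym.weight.permute(1, 0)
print(torch.allclose(W_eff, W_eff.t()))        # expected: True
print(sym(torch.randn(2, 16)).shape)           # expected: torch.Size([2, 16])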
class MemoryUnit(nn.Module):
    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025):
        super(MemoryUnit, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        # Memory matrix: M rows (memory items) x C columns (feature dimension).
        self.weight = Parameter(torch.Tensor(self.mem_dim, self.fea_dim))
        self.bias = None
        self.shrink_thres = shrink_thres
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        # Features x Mem^T: (N x C) x (C x M) -> attention logits of size N x M.
        att_weight = F.linear(input, self.weight)
        # Softmax along the memory dimension, so each of the N rows sums to 1.
        att_weight = F.softmax(att_weight, dim=1)
        # ReLU-based hard shrinkage for positive values, then re-normalize.
        if self.shrink_thres > 0:
            att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres)
            att_weight = F.normalize(att_weight, p=1, dim=1)
        # Mem^T: C x M.
        mem_trans = self.weight.permute(1, 0)
        # AttWeight x (Mem^T)^T = AttWeight x Mem: (N x M) x (M x C) -> N x C.
        output = F.linear(att_weight, mem_trans)
        return {'output': output, 'att': att_weight}  # output (N x C), att_weight (N x M)

    def extra_repr(self):
        return 'mem_dim={}, fea_dim={}'.format(self.mem_dim, self.fea_dim)
class MemoryUnit(nn.Module):
    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025):
        super(MemoryUnit, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.weight = Parameter(torch.Tensor(self.mem_dim, self.fea_dim))  # M x C
        self.bias = None
        self.shrink_thres = shrink_thres
        # self.hard_sparse_shrink_opt = nn.Hardshrink(lambd=shrink_thres)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        att_weight = F.linear(input, self.weight)  # Fea x Mem^T, (T x C) x (C x M) = T x M
        att_weight = F.softmax(att_weight, dim=1)  # T x M
        # ReLU-based shrinkage (hard shrinkage for positive values), then re-normalize.
        if self.shrink_thres > 0:
            att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres)
            # att_weight = F.softshrink(att_weight, lambd=self.shrink_thres)
            att_weight = F.normalize(att_weight, p=1, dim=1)
            # print("att_weight", torch.sum(att_weight[0]))
            # att_weight = F.softmax(att_weight, dim=1)
            # att_weight = self.hard_sparse_shrink_opt(att_weight)
        mem_trans = self.weight.permute(1, 0)  # Mem^T, C x M
        output = F.linear(att_weight, mem_trans)  # AttWeight x Mem: (T x M) x (M x C) = T x C
        return {'output': output, 'att': att_weight}  # output, att_weight

    def extra_repr(self):
        return 'mem_dim={}, fea_dim={}'.format(self.mem_dim, self.fea_dim)
class MemoryModule(nn.Module):
    ''' Memory Module '''

    def __init__(self, mem_dim, fea_dim, shrink_thres):
        super().__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        # attention
        self.weight = Parameter(torch.Tensor(self.mem_dim, self.fea_dim))  # [M, C]
        self.bias = None
        self.shrink_thres = shrink_thres
        self.reset_parameters()

    def reset_parameters(self):
        ''' init memory elements: very important! '''
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, x):
        ''' x [B, C, H, W]: latent code Z '''
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).flatten(end_dim=2)  # Fea: [N, C], N = B*H*W

        # calculate attention weights
        att_weight = F.linear(x, self.weight)  # Fea * Mem^T: [N, C] x [C, M] = [N, M]
        att_weight = F.softmax(att_weight, dim=1)  # [N, M]

        if self.shrink_thres > 0:
            # hard shrink
            att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres)
            # re-normalize
            att_weight = F.normalize(att_weight, p=1, dim=1)  # [N, M]

        # generate code z'
        mem_T = self.weight.permute(1, 0)
        output = F.linear(att_weight, mem_T)  # Fea * Mem: [N, M] x [M, C] = [N, C]
        output = output.view(B, H, W, C).permute(0, 3, 1, 2)  # [B, C, H, W]
        return att_weight, output
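# Minimal sketch of running the memory module over a dummy latent map. shrink_thres=0
# is used here so the external hard_shrink_relu helper (defined elsewhere in the
# codebase) is never called; real usage would pass a small positive threshold.
memmod = MemoryModule(mem_dim=50, fea_dim=256, shrink_thres=0)
z = torch.randn(2, 256, 8, 8)       # latent code [B, C, H, W]
att, z_hat = memmod(z)
print(att.shape)                    # expected: torch.Size([128, 50])  (N = B*H*W = 128)
print(z_hat.shape)                  # expected: torch.Size([2, 256, 8, 8])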
class FilterStripe(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(FilterStripe, self).__init__(in_channels, out_channels, kernel_size, stride, kernel_size // 2, groups=1, bias=False)
        self.BrokenTarget = None
        self.FilterSkeleton = Parameter(torch.ones(self.out_channels, self.kernel_size[0], self.kernel_size[1]), requires_grad=True)

    def forward(self, x):
        if self.BrokenTarget is not None:
            out = torch.zeros(x.shape[0], self.FilterSkeleton.shape[0],
                              int(np.ceil(x.shape[2] / self.stride[0])),
                              int(np.ceil(x.shape[3] / self.stride[1])))
            if x.is_cuda:
                out = out.cuda()
            x = F.conv2d(x, self.weight)
            l, h = 0, 0
            for i in range(self.BrokenTarget.shape[0]):
                for j in range(self.BrokenTarget.shape[1]):
                    h += self.FilterSkeleton[:, i, j].sum().item()
                    out[:, self.FilterSkeleton[:, i, j]] += self.shift(x[:, l:h], i, j)[:, :, ::self.stride[0], ::self.stride[1]]
                    l += self.FilterSkeleton[:, i, j].sum().item()
            return out
        else:
            return F.conv2d(x, self.weight * self.FilterSkeleton.unsqueeze(1), stride=self.stride, padding=self.padding, groups=self.groups)

    def prune_in(self, in_mask=None):
        self.weight = Parameter(self.weight[:, in_mask])
        self.in_channels = in_mask.sum().item()

    def prune_out(self, threshold):
        out_mask = (self.FilterSkeleton.abs() > threshold).sum(dim=(1, 2)) != 0
        if out_mask.sum() == 0:
            out_mask[0] = True
        self.weight = Parameter(self.weight[out_mask])
        self.FilterSkeleton = Parameter(self.FilterSkeleton[out_mask], requires_grad=True)
        self.out_channels = out_mask.sum().item()
        return out_mask

    def _break(self, threshold):
        self.weight = Parameter(self.weight * self.FilterSkeleton.unsqueeze(1))
        self.FilterSkeleton = Parameter((self.FilterSkeleton.abs() > threshold), requires_grad=False)
        if self.FilterSkeleton.sum() == 0:
            self.FilterSkeleton.data[0][0][0] = True
        self.out_channels = self.FilterSkeleton.sum().item()
        self.BrokenTarget = self.FilterSkeleton.sum(dim=0)
        self.kernel_size = (1, 1)
        self.weight = Parameter(self.weight.permute(2, 3, 0, 1).reshape(-1, self.in_channels, 1, 1)[self.FilterSkeleton.permute(1, 2, 0).reshape(-1)])

    def update_skeleton(self, sr, threshold):
        self.FilterSkeleton.grad.data.add_(sr * torch.sign(self.FilterSkeleton.data))
        mask = self.FilterSkeleton.data.abs() > threshold
        self.FilterSkeleton.data.mul_(mask)
        self.FilterSkeleton.grad.data.mul_(mask)
        out_mask = mask.sum(dim=(1, 2)) != 0
        return out_mask

    def shift(self, x, i, j):
        return F.pad(x, (self.BrokenTarget.shape[0] // 2 - j, j - self.BrokenTarget.shape[0] // 2,
                         self.BrokenTarget.shape[0] // 2 - i, i - self.BrokenTarget.shape[1] // 2), 'constant', 0)

    def extra_repr(self):
        s = ('{BrokenTarget},{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        return s.format(**self.__dict__)
class MCCP(nn.Module):
    def __init__(self, in_features, num_C, num_D, reciptive=512, strides=256, eps=0.0001):
        super(MCCP, self).__init__()
        self.in_features = in_features
        print('the dimension of input vector is:', in_features)
        print('mccp reciptive is:', reciptive)
        print('mccp strides is:', strides)
        self.reciptive = reciptive
        self.strides = strides
        self.num_C = num_C
        self.num_D = num_D
        self.eps = eps
        self.eye = Parameter(torch.eye(num_D), requires_grad=False)
        self.weight = Parameter(torch.Tensor(num_C, num_D, self.reciptive))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(2))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x):
        weight_caps = torch.matmul(self.weight, self.weight.permute(0, 2, 1))
        sigma = torch.inverse(weight_caps + self.eps * self.eye)

        # implemented as a convolutional procedure
        # vec2mat = nn.functional.unfold(torch.cat((x, x[:, :self.reciptive]), dim=-1).unsqueeze(1).unsqueeze(1), (1, self.reciptive), stride=self.strides)
        # matrix = vec2mat.permute(0, 2, 1)
        # b, n, d = matrix.size()
        # out1 = torch.matmul(matrix, torch.t(self.weight.view(-1, d)))
        # out1 = out1.view(b, n, self.num_C, 1, self.num_D)
        # out1 = torch.matmul(out1, sigma)  # (b, n, num_C, 1, num_D)
        # out2 = torch.matmul(self.weight.unsqueeze(dim=0).unsqueeze(dim=0), matrix.unsqueeze(dim=2).unsqueeze(dim=-1))
        # out2 = torch.matmul(out1, out2)  # (b, n, num_C, 1, 1)
        # return out2.sum(dim=1).sqrt().squeeze()

        # implemented column by column
        results = []
        for i in range(0, x.shape[1], self.strides):
            # vec2mat
            if i + self.reciptive > x.shape[1]:
                inputs = torch.cat((x[:, i:], x[:, :(self.reciptive - x.shape[1] + i)]), dim=-1)
            else:
                inputs = x[:, i:i + self.reciptive]
            # project
            out = torch.matmul(inputs, torch.t(self.weight.view(-1, inputs.size(-1))))
            out = out.view(out.shape[0], self.num_C, 1, self.num_D)
            out = torch.matmul(out, sigma)
            out = torch.matmul(out, self.weight.view(self.num_C, self.num_D, self.reciptive))
            out = torch.squeeze(out, dim=2)
            out = torch.matmul(out, torch.unsqueeze(inputs, dim=2))
            results.append(torch.sqrt(out))
        results = torch.cat(results, dim=-1)
        return torch.mean(results, dim=-1)
class TSL_Net(nn.Module):
    def __init__(self, arch_interaction_op='dot', arch_attention_mechanism='mlp',
                 ln=None, model_type="tsl", tsl_inner="def",
                 mha_num_heads=8, ln_top=""):
        super(TSL_Net, self).__init__()
        # save arguments
        self.arch_interaction_op = arch_interaction_op
        self.arch_attention_mechanism = arch_attention_mechanism
        self.model_type = model_type
        self.tsl_inner = tsl_inner
        # setup for mechanism type
        if self.arch_attention_mechanism == 'mlp':
            self.mlp = dlrm.DLRM_Net().create_mlp(ln, len(ln) - 2)
        # setup extra parameters for some of the models
        if self.model_type == "tsl" and self.tsl_inner in ["def", "ind"]:
            m = ln_top[-1]  # dim of dlrm output
            mean = 0.0
            std_dev = np.sqrt(2 / (m + m))
            W = np.random.normal(mean, std_dev, size=(1, m, m)).astype(np.float32)
            self.A = Parameter(torch.tensor(W), requires_grad=True)
        elif self.model_type == "mha":
            m = ln_top[-1]  # dlrm output dim
            self.nheads = mha_num_heads
            self.emb_m = self.nheads * m  # mha emb dim
            mean = 0.0
            std_dev = np.sqrt(2 / (m + m))  # np.sqrt(1 / m)  # np.sqrt(1 / n)
            qm = np.random.normal(mean, std_dev, size=(1, m, self.emb_m)).astype(np.float32)
            self.Q = Parameter(torch.tensor(qm), requires_grad=True)
            km = np.random.normal(mean, std_dev, size=(1, m, self.emb_m)).astype(np.float32)
            self.K = Parameter(torch.tensor(km), requires_grad=True)
            vm = np.random.normal(mean, std_dev, size=(1, m, self.emb_m)).astype(np.float32)
            self.V = Parameter(torch.tensor(vm), requires_grad=True)

    def forward(self, x=None, H=None):
        # adjust input shape
        (batchSize, vector_dim) = x.shape
        x = torch.reshape(x, (batchSize, 1, -1))
        x = torch.transpose(x, 1, 2)
        # debug prints
        # print("shapes: ", self.A.shape, x.shape)

        # perform mode operation
        if self.model_type == "tsl":
            if self.tsl_inner == "def":
                ax = torch.matmul(self.A, x)
                x = torch.matmul(self.A.permute(0, 2, 1), ax)
                # debug prints
                # print("shapes: ", H.shape, ax.shape, x.shape)
            elif self.tsl_inner == "ind":
                x = torch.matmul(self.A, x)

            # perform interaction operation
            if self.arch_interaction_op == 'dot':
                if self.arch_attention_mechanism == 'mul':
                    # coefficients
                    a = torch.transpose(torch.bmm(H, x), 1, 2)
                    # context
                    c = torch.bmm(a, H)
                elif self.arch_attention_mechanism == 'mlp':
                    # coefficients
                    a = torch.transpose(torch.bmm(H, x), 1, 2)
                    # MLP first/last layer dims are automatically adjusted to ts_length
                    y = dlrm.DLRM_Net().apply_mlp(a, self.mlp)
                    # context, y = mlp(a)
                    c = torch.bmm(torch.reshape(y, (batchSize, 1, -1)), H)
                else:
                    sys.exit('ERROR: --arch-attention-mechanism=' + self.arch_attention_mechanism + ' is not supported')
            else:
                sys.exit('ERROR: --arch-interaction-op=' + self.arch_interaction_op + ' is not supported')
        elif self.model_type == "mha":
            x = torch.transpose(x, 1, 2)
            Qx = torch.transpose(torch.matmul(x, self.Q), 0, 1)
            HK = torch.transpose(torch.matmul(H, self.K), 0, 1)
            HV = torch.transpose(torch.matmul(H, self.V), 0, 1)
            # multi-head attention (mha)
            multihead_attn = nn.MultiheadAttention(self.emb_m, self.nheads).to(x.device)
            attn_output, _ = multihead_attn(Qx, HK, HV)
            # context
            c = torch.squeeze(attn_output, dim=0)
            # debug prints
            # print("shapes:", c.shape, Qx.shape)

        return c
class FilterStripe(nn.Conv2d):  # convolution layer + Filter Skeleton (FS) layer
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(FilterStripe, self).__init__(in_channels, out_channels, kernel_size, stride, kernel_size // 2, groups=1, bias=False)
        self.BrokenTarget = None
        # initialize the FS layer
        self.FilterSkeleton = Parameter(torch.ones(self.out_channels, self.kernel_size[0], self.kernel_size[1]), requires_grad=True)

    def forward(self, x):  # called automatically; x: [N, channels, width, height]
        if self.BrokenTarget is not None:
            # out: [N, channels, width, height]; np.ceil rounds up to the nearest integer
            out = torch.zeros(x.shape[0], self.FilterSkeleton.shape[0], int(np.ceil(x.shape[2] / self.stride[0])), int(np.ceil(x.shape[3] / self.stride[1])))
            if x.is_cuda:
                out = out.cuda()
            x = F.conv2d(x, self.weight)  # convolution output
            l, h = 0, 0
            for i in range(self.BrokenTarget.shape[0]):
                for j in range(self.BrokenTarget.shape[1]):
                    h += self.FilterSkeleton[:, i, j].sum().item()  # count the stripes kept at this kernel position
                    # accumulate the shifted output into the channels selected by the FS mask
                    out[:, self.FilterSkeleton[:, i, j]] += self.shift(x[:, l:h], i, j)[:, :, ::self.stride[0], ::self.stride[1]]
                    l += self.FilterSkeleton[:, i, j].sum().item()
            return out  # output
        else:
            # unsqueeze(1) adds a dimension at position 1 so the FS mask broadcasts over input channels
            return F.conv2d(x, self.weight * self.FilterSkeleton.unsqueeze(1), stride=self.stride, padding=self.padding, groups=self.groups)

    def prune_in(self, in_mask=None):  # in_mask: input-channel mask
        # self.weight.shape: [out_channels, in_channels, k, k]
        # print(self.weight.shape)
        self.weight = Parameter(self.weight[:, in_mask])  # keep only the surviving input channels
        # print(self.weight)
        self.in_channels = in_mask.sum().item()

    def prune_out(self, threshold):  # threshold: pruning threshold
        out_mask = (self.FilterSkeleton.abs() > threshold).sum(dim=(1, 2)) != 0  # compute the output-channel mask
        if out_mask.sum() == 0:
            # print(out_mask.sum())
            out_mask[0] = True
        self.weight = Parameter(self.weight[out_mask])  # mask the filters
        self.FilterSkeleton = Parameter(self.FilterSkeleton[out_mask], requires_grad=True)  # mask the FS layer
        self.out_channels = out_mask.sum().item()  # update the number of output channels
        return out_mask  # return the mask

    def _break(self, threshold):
        self.weight = Parameter(self.weight * self.FilterSkeleton.unsqueeze(1))  # multiply the filters by the FS layer
        self.FilterSkeleton = Parameter((self.FilterSkeleton.abs() > threshold), requires_grad=False)  # FS entries above the threshold become True
        if self.FilterSkeleton.sum() == 0:
            self.FilterSkeleton.data[0][0][0] = True
        self.out_channels = self.FilterSkeleton.sum().item()
        self.BrokenTarget = self.FilterSkeleton.sum(dim=0)
        self.kernel_size = (1, 1)
        # permute() reorders tensor dimensions
        # print(self.FilterSkeleton.permute(1, 2, 0).reshape(-1))
        # print(self.weight.permute(2, 3, 0, 1).reshape(-1, self.in_channels, 1, 1))
        self.weight = Parameter(self.weight.permute(2, 3, 0, 1).reshape(-1, self.in_channels, 1, 1)[self.FilterSkeleton.permute(1, 2, 0).reshape(-1)])  # keep only the stripes selected by the FS mask
        # print(self.weight)

    def update_skeleton(self, sr, threshold):
        self.FilterSkeleton.grad.data.add_(sr * torch.sign(self.FilterSkeleton.data))  # add the derivative of the L1 penalty to the FS gradient
        mask = self.FilterSkeleton.data.abs() > threshold
        self.FilterSkeleton.data.mul_(mask)   # mask the FS values
        self.FilterSkeleton.grad.data.mul_(mask)  # mask the FS gradient
        out_mask = mask.sum(dim=(1, 2)) != 0  # channels whose whole skeleton is zero will be pruned
        return out_mask

    def shift(self, x, i, j):
        return F.pad(x, (self.BrokenTarget.shape[0] // 2 - j, j - self.BrokenTarget.shape[0] // 2, self.BrokenTarget.shape[0] // 2 - i, i - self.BrokenTarget.shape[1] // 2), 'constant', 0)

    def extra_repr(self):
        s = ('{BrokenTarget},{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        return s.format(**self.__dict__)
class RTN_CORE(nn.Module):
    def __init__(self, M, N, G, alpha, filtersize=5, padding=0, init_kernels=None):
        super(RTN_CORE, self).__init__()
        # M: number of input channels
        # N: number of output channels
        # G: number of rotations to use. If odd, the original filter will also be used. If even, not.
        # alpha: there will be G filters rotated from -alpha to alpha.
        self.M = M
        self.N = N
        self.G = G
        self.a = alpha
        self.k = filtersize
        # self.conv1 = nn.Conv2d(self.M, self.M*self.N*self.G, filtersize, groups=self.M, padding=padding)  # in, out, kernel size, groups as in paper

        # Let's follow the vocabulary of the paper and call:
        #   w:  the unrotated filters / "templates" (the ones we want to learn)
        #   gw: the rotated filters / "transformed templates" (the G rotated versions of those filters)
        # There will be a total of:
        #   M*N templates
        #   M*N*G transformed templates
        # Create the "templates" as torch.nn.Parameter as described here:
        #   http://pytorch.org/tutorials/advanced/numpy_extensions_tutorial.html
        #   http://pytorch.org/docs/0.3.1/nn.html#parameters
        # Randomly initialize the weights:
        # TODO: Maybe there is a better way to initialize
        # maybe: w = Parameter(torch.randn(M*N, k, k))  # size(w) = (M*N x k x k)
        if init_kernels is None:
            self.w = Parameter(torch.randn(M * N, 1, self.k, self.k))  # size(w) = (M*N x 1 x k x k)
        else:
            self.w = init_kernels

        # From every "template" derive G "transformed templates" by applying rotation.
        # Step 1 - create the G transformation matrices we wish to rotate each "template" by:
        # g = get_rotated_kernels(w, self.G, -alpha, alpha)  # Variable(torch.randn(G, 2, 3)), size(g) = (G x 2 x 3)  # TODO
        # g = Variable(torch.randn(G, 2, 3))  # size(g) = (G x 2 x 3)  # TODO
        g1 = make_rotations(-alpha, alpha, G)  # size(g) = (G x 2 x 3)
        g2 = [torch.Tensor(f) for f in g1]
        g3 = torch.stack(g2)
        g = Variable(torch.Tensor(g3))

        # Step 2 - apply the transformations.
        # Step 2.1 - create the flow fields describing the transformations.
        s = torch.Size((G, M * N, self.k, self.k))  # desired output size
        # size(flow_field) = (G, k, k, 2): one sampling grid per rotation matrix.
        flow_field = torch.nn.functional.affine_grid(g, s)
        self.register_buffer("flow_field", flow_field)

        # These two layers are the same as in the vanilla NPTN by ??? et al. 20??.
        self.maxpool3d = nn.MaxPool3d((self.G, 1, 1))
        self.meanpool3d = nn.AvgPool3d((self.M, 1, 1))  # Is that the right pooling? - AvgPool3d?
        self.permutation = make_permutation(self.M, self.N)  # TODO: Should this also go on the GPU as a buffer?

    def forward(self, x):
        # Start from w every time and create the rotated versions from it.
        # print('\nShape of x ', x.size())
        # Step 2.2 - apply the flow fields. Each flow field is a sampling grid; there is
        # one per rotation, and it is applied to every channel of the input / every "template".
        # Permute (M*N, 1, k, k) to (1, M*N, k, k), then expand the first dimension G times.
        w_rep = self.w.permute(1, 0, 2, 3)
        w_rep2 = w_rep.expand(self.G, -1, -1, -1)  # size(w_rep2) = (G x M*N x k x k)
        w_rot = torch.nn.functional.grid_sample(w_rep2, self.flow_field)  # size(w_rot) = (G x M*N x k x k)
        # Go from (G x M*N x k x k) to (M*N x G x k x k),
        # i.e. (angle, template, x, y) to (template, angle, x, y).
        w_perm = w_rot.permute(1, 0, 2, 3)
        # But actually we need this unrolled:
        MNG = self.M * self.N * self.G
        # size(w_unrolled) = (M*N*G x 1 x k x k); conv2d expects filters of shape
        # (out_channels x in_channels/groups x kH x kW). (The original used the deprecated .resize().)
        w_unrolled = w_perm.reshape(MNG, 1, self.k, self.k)
        # plot_kernels(w_unrolled, num_cols=G, title='RTN')

        # Convolution: use the functional (torch.nn.functional.conv2d) instead of the Module
        # (torch.nn.Conv2d), because we don't want trainable parameters except the
        # template tensor defined above.
        x = F.conv2d(x, w_unrolled, groups=self.M)  # TODO: Add padding?
        # print('Shape after convolution', x.size())
        x = self.maxpool3d(x)
        # print("Shape after MaxPool3d: ", x.size())  # dimension should be M*N
        x = x[:, self.permutation]  # reorder channels
        # print("Shape after Channel reordering: ", x.size())
        x = self.meanpool3d(x)
        # print('Shape after Mean Pooling: ', x.size())
        return x
class Gumbel_Generator_Old(nn.Module):
    def __init__(self, sz=10, temp=10, temp_drop_frac=0.9999):
        super(Gumbel_Generator_Old, self).__init__()  # initialize the nn.Module base class
        self.sz = sz
        # Wrapping a plain Tensor in Parameter registers it as a trainable part of the
        # model, so it is updated during training.
        # torch.rand() samples uniformly from the interval (0, 1).
        self.gen_matrix = Parameter(torch.rand(sz, sz, 2))
        self.new_matrix = Parameter(torch.zeros(5, 5, 2))
        # gen_matrix holds the (unnormalized) probabilities of the adjacency matrix
        self.temperature = temp
        self.temp_drop_frac = temp_drop_frac

    def symmetry(self):
        matrix = self.gen_matrix.permute(2, 0, 1)
        temp_matrix = torch.triu(matrix, 1) + torch.triu(matrix, 1).permute(0, 2, 1)
        self.gen_matrix.data = temp_matrix.permute(1, 2, 0)

    def drop_temp(self):
        # anneal the temperature
        self.temperature = self.temperature * self.temp_drop_frac

    def sample_all(self, hard=False, epoch=1):
        # sample a full adjacency matrix
        # self.symmetry()
        self.logp = self.gen_matrix.view(-1, 2)  # view flattens, then reshapes to the requested dimensions
        out = gumbel_softmax(self.logp, self.temperature, hard)
        if hard:
            hh = torch.zeros(self.gen_matrix.size()[0] ** 2, 2)
            for i in range(out.size()[0]):
                hh[i, out[i]] = 1
            out = hh
        if use_cuda:
            out = out.cuda()
        out_matrix = out[:, 0].view(self.gen_matrix.size()[0], self.gen_matrix.size()[0])
        return out_matrix

    def sample_small(self, list, hard=False):
        indices = np.ix_(list, list)
        self.logp = self.gen_matrix[indices].view(-1, 2)
        out = gumbel_softmax(self.logp, self.temperature, hard)  # what is `hard` used for?
        if hard:
            hh = torch.zeros(self.gen_matrix[indices].size()[0] ** 2, 2)
            for i in range(out.size()[0]):
                hh[i, out[i]] = 1
            out = hh
        if use_cuda:
            out = out.cuda()
        out_matrix = out[:, 0].view(len(list), len(list))
        return out_matrix

    def sample_adj_ij(self, list, j, hard=False, sample_time=1):
        # self.logp = self.gen_matrix[:, i]
        self.logp = self.gen_matrix[list, j]
        out = gumbel_softmax(self.logp, self.temperature, hard=hard)
        if use_cuda:
            out = out.cuda()
        # print(out)
        if hard:
            out_matrix = out.float()
        else:
            out_matrix = out[:, 0]
        return out_matrix

    def sample_adj_i(self, i, hard=False, sample_time=1):
        # self.symmetry()
        self.logp = self.gen_matrix[:, i]
        out = gumbel_softmax(self.logp, self.temperature, hard=hard)
        if use_cuda:
            out = out.cuda()
        # print(out)
        if hard:
            out_matrix = out.float()
        else:
            out_matrix = out[:, 0]
        return out_matrix

    def get_temperature(self):
        return self.temperature

    def get_cross_entropy(self, obj_matrix):
        # cross-entropy distance to the target matrix
        logps = F.softmax(self.gen_matrix, 2)
        logps = torch.log(logps[:, :, 0] + 1e-10) * obj_matrix \
            + torch.log(logps[:, :, 1] + 1e-10) * (1 - obj_matrix)
        result = -torch.sum(logps)
        result = result.cpu() if use_cuda else result
        return result.data.numpy()

    def get_entropy(self):
        logps = F.softmax(self.gen_matrix, 2)
        result = torch.mean(torch.sum(logps * torch.log(logps + 1e-10), 1))
        result = result.cpu() if use_cuda else result
        return (-result.data.numpy())

    def randomization(self, fraction):
        # re-randomize gen_matrix; fraction is the proportion of entries to reset
        sz = self.gen_matrix.size()[0]
        numbers = int(fraction * sz * sz)
        original = self.gen_matrix.cpu().data.numpy()
        for i in range(numbers):
            ii = np.random.choice(range(sz), (2, 1))
            z = torch.rand(2).cuda() if use_cuda else torch.rand(2)
            self.gen_matrix.data[ii[0], ii[1], :] = z

    def init(self, mean, var):
        init.normal_(self.gen_matrix, mean=mean, std=var)
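# A brief CPU-only sketch of the generator's probability bookkeeping. gumbel_softmax
# and use_cuda are external globals assumed by the class; setting use_cuda = False here
# makes the entropy/cross-entropy helpers runnable without them (values are illustrative).
use_cuda = False
gen = Gumbel_Generator_Old(sz=4, temp=10, temp_drop_frac=0.99)
gen.init(mean=0.0, var=0.1)             # Gaussian re-initialization of gen_matrix
target = torch.eye(4)                   # toy target adjacency matrix
print(gen.get_cross_entropy(target))    # cross-entropy of the soft adjacency to the target
gen.drop_temp()
print(gen.get_temperature())            # expected: 9.9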