コード例 #1
0
ファイル: opt.py プロジェクト: cxiang26/mccp
class LinearCapsPro(nn.Module):
    """Capsule-projection layer.

    Every one of the ``num_C`` classes owns ``num_D`` basis vectors of length
    ``in_features`` (stored in ``weight``).  ``forward`` returns, per class,
    ``sqrt(x^T W^T (W W^T + eps * I)^-1 W x)``.
    """

    def __init__(self, in_features, num_C, num_D, eps=0.0001):
        super().__init__()
        self.in_features = in_features
        self.num_C, self.num_D = num_C, num_D
        self.eps = eps
        # One (num_D x in_features) basis per class; filled in reset_parameters().
        self.weight = Parameter(torch.Tensor(num_C, num_D, self.in_features))
        # Frozen identity used to regularise the Gram matrix before inversion.
        self.eye = Parameter(torch.eye(num_D), requires_grad=False)
        self.count = 0
        self.reset_parameters()

    def reset_parameters(self):
        """Fill ``weight`` uniformly in ``[-1/sqrt(in_features), 1/sqrt(in_features)]``."""
        bound = 1. / math.sqrt(self.weight.size(2))
        self.weight.data.uniform_(-bound, bound)

    def forward(self, x):
        """Return one value per class for every row of ``x``.

        ``x`` is used as a 2-D ``(batch, in_features)`` tensor; the result has
        shape ``(batch, num_C)``.
        """
        basis = self.weight
        gram = torch.matmul(basis, basis.permute(0, 2, 1))
        gram_inv = torch.inverse(gram + self.eps * self.eye)

        # Coordinates of x w.r.t. every basis vector: (batch, num_C * num_D).
        flat_basis = basis.view(-1, x.size(-1))
        coords = torch.matmul(x, torch.t(flat_basis))
        coords = coords.view(coords.shape[0], self.num_C, 1, self.num_D)

        # (batch, num_C, 1, num_D) -> (batch, num_C, in_features)
        projected = torch.matmul(torch.matmul(coords, gram_inv), basis)
        projected = torch.squeeze(projected, dim=2)

        # Inner product of the projection with x gives the squared length.
        sq_len = torch.matmul(projected, x.unsqueeze(dim=2)).squeeze(dim=2)
        return torch.sqrt(sq_len)
コード例 #2
0
ファイル: net.py プロジェクト: small-yellow-duck/titanic-ede
class EmbeddingToIndex(nn.Module):
    """Map embedding vectors back to the index of the nearest table row.

    ``weight`` is an ``(n_tokens, embdim)`` table.  ``forward`` returns, for
    every vector along the last axis of the input, the index of the row with
    the smallest squared Euclidean distance.

    Args:
        n_tokens: number of rows in the table.
        embdim: size of each embedding vector.
        _weight: optional initial table of shape ``(n_tokens, embdim)``.  When
            omitted, the table is drawn from a standard normal distribution.
    """

    def __init__(self, n_tokens, embdim, _weight=None):
        super(EmbeddingToIndex, self).__init__()
        if _weight is None:
            # The original crashed on the default argument; fall back to a
            # random table so the default is usable.
            _weight = torch.randn(n_tokens, embdim)

        assert list(_weight.shape) == [n_tokens, embdim], \
            'Shape of weight does not match num_embeddings and embedding_dim'

        self.weight = Parameter(_weight)
        self.n_tokens = self.weight.size(0)
        self.embdim = embdim

    def forward(self, X):
        """Return nearest-row indices for ``X`` of shape ``(..., embdim)``.

        The result has shape ``X.shape[:-1]``.
        """
        # ||x - w||^2 = ||x||^2 - 2 x.w + ||w||^2, evaluated with broadcasting
        # so no repeated copies of the norms are materialised.
        x_sq = X.pow(2).sum(dim=-1, keepdim=True)          # (..., 1)
        w_sq = self.weight.pow(2).sum(dim=-1)              # (n_tokens,)
        cross = torch.matmul(X, self.weight.permute(1, 0))  # (..., n_tokens)

        dist = x_sq - 2 * cross + w_sq

        return torch.min(dist, dim=-1)[1]
コード例 #3
0
class Maxout(nn.Module):
    """Maxout layer: element-wise maximum over ``pieces`` affine maps.

    Args:
        in_features: size of each input row.
        out_features: size of each output row.
        pieces: number of affine maps the maximum is taken over.
        bias: when ``False`` no bias term is learned or added.
    """
    __constants__ = ['bias']

    def __init__(self, in_features, out_features, pieces, bias=True):
        super(Maxout, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.pieces = pieces
        self.weight = Parameter(torch.Tensor(pieces, out_features,
                                             in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(pieces, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight (Kaiming uniform) and, if present, the bias."""
        # NOTE(review): weight is 3-D, so the fan-in helper treats the last
        # axis as a receptive field (fan_in = out_features * in_features).
        # Left unchanged to keep the initialisation reproducible - confirm
        # whether in_features was intended.
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply the layer to a ``(batch, in_features)`` input.

        Returns a ``(batch, out_features)`` tensor.
        """
        # (batch, in) @ (pieces, in, out) -> (pieces, batch, out) -> (batch, pieces, out)
        output = input.matmul(self.weight.permute(0, 2, 1)).permute((1, 0, 2))
        # The bias is optional; adding None raised TypeError when bias=False.
        if self.bias is not None:
            output = output + self.bias
        return torch.max(output, dim=1)[0]
コード例 #4
0
class SymLinear(nn.Module):
    '''Linear layer whose effective weight is ``W + W^T``.

    The stored ``weight`` has shape ``(out_features, in_features)``; adding
    its transpose in ``forward`` only yields a usable matrix when both sizes
    are equal.
    '''

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        '''Kaiming-uniform weight; bias uniform in +-1/sqrt(fan_in).'''
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        limit = 1 / math.sqrt(fan_in)
        init.uniform_(self.bias, -limit, limit)

    def forward(self, input):
        '''Apply ``input @ (W + W^T)^T + bias``.'''
        symmetric = self.weight + self.weight.permute(1, 0)
        return F.linear(input, symmetric, self.bias)

    def extra_repr(self):
        has_bias = self.bias is not None
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, has_bias)
コード例 #5
0
class MemoryUnit(nn.Module):
    """Attention-addressed memory: re-express inputs as mixtures of memory rows.

    Args:
        mem_dim: number of rows M in the memory matrix.
        fea_dim: length C of each memory row (and of each input feature).
        shrink_thres: when > 0, attention weights are passed through
            ``hard_shrink_relu`` with this threshold and L1-renormalised.
    """

    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025):
        super(MemoryUnit, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        # Memory matrix, M rows of dimension C.
        self.weight = Parameter(torch.Tensor(self.mem_dim, self.fea_dim))
        self.bias = None
        self.shrink_thres = shrink_thres

        self.reset_parameters()

    def reset_parameters(self):
        """Fill the memory uniformly in ``[-1/sqrt(C), 1/sqrt(C)]``."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        """Address the memory with ``input`` of shape (N, C).

        Returns:
            dict with ``'output'`` (N x C, the attention-weighted sum of
            memory rows) and ``'att'`` (N x M, the attention weights).
        """
        # Similarity of every input row to every memory row: (N x C) x (C x M) = N x M.
        att_weight = F.linear(input, self.weight)
        # Each row of the N x M matrix sums to 1 after the softmax.
        att_weight = F.softmax(att_weight, dim=1)
        # ReLU based shrinkage (hard shrinkage for positive values), then
        # re-normalise each row to unit L1 norm.
        if (self.shrink_thres > 0):
            att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres)
            att_weight = F.normalize(att_weight, p=1, dim=1)
        # F.linear multiplies by the transpose of its second argument, so
        # passing Mem^T (C x M) yields att x Mem: (N x M) x (M x C) = N x C.
        mem_trans = self.weight.permute(1, 0)
        output = F.linear(att_weight, mem_trans)
        return {
            'output': output,
            'att': att_weight
        }

    def extra_repr(self):
        # The original formatted ``self.fea_dim is not None`` and therefore
        # always printed ``fea_dim=True``.
        return 'mem_dim={}, fea_dim={}'.format(self.mem_dim, self.fea_dim)
コード例 #6
0
class MemoryUnit(nn.Module):
    """Memory lookup layer with sparse attention over ``mem_dim`` slots.

    Args:
        mem_dim: number of memory slots (M).
        fea_dim: feature size of each slot and of each input row (C).
        shrink_thres: threshold handed to ``hard_shrink_relu``; values <= 0
            disable shrinkage.
    """

    def __init__(self, mem_dim, fea_dim, shrink_thres=0.0025):
        super(MemoryUnit, self).__init__()
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.weight = Parameter(torch.Tensor(self.mem_dim,
                                             self.fea_dim))  # M x C
        self.bias = None
        self.shrink_thres = shrink_thres

        self.reset_parameters()

    def reset_parameters(self):
        """Uniform initialisation scaled by 1/sqrt(fea_dim)."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        """Compute attention over memory and the reconstructed features.

        ``input`` is used as a (T x C) matrix.  Returns a dict with
        ``'output'`` (T x C) and ``'att'`` (T x M).
        """
        att_weight = F.linear(input,
                              self.weight)  # Fea x Mem^T, (TxC) x (CxM) = TxM
        att_weight = F.softmax(att_weight, dim=1)  # TxM
        # ReLU based shrinkage, hard shrinkage for positive value
        if (self.shrink_thres > 0):
            att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres)
            # Rows no longer sum to 1 after shrinkage; restore unit L1 norm.
            att_weight = F.normalize(att_weight, p=1, dim=1)
        mem_trans = self.weight.permute(1, 0)  # Mem^T, CxM
        output = F.linear(
            att_weight,
            mem_trans)  # AttWeight x Mem^T^T = AW x Mem, (TxM) x (MxC) = TxC
        return {'output': output, 'att': att_weight}  # output, att_weight

    def extra_repr(self):
        # Report the actual feature size (previously printed a boolean).
        return 'mem_dim={}, fea_dim={}'.format(self.mem_dim, self.fea_dim)
コード例 #7
0
class MemoryModule(nn.Module):
    ''' Memory module operating on convolutional feature maps.

    Holds an [M, C] memory (``weight``); every spatial location of the input
    is re-expressed as an attention-weighted sum of memory rows.
    '''

    def __init__(self, mem_dim, fea_dim, shrink_thres):
        super().__init__()
        self.mem_dim, self.fea_dim = mem_dim, fea_dim
        # memory items addressed by attention, [M, C]
        self.weight = Parameter(torch.Tensor(self.mem_dim, self.fea_dim))
        self.bias = None
        self.shrink_thres = shrink_thres
        self.reset_parameters()

    def reset_parameters(self):
        ''' Uniformly initialise the memory in +-1/sqrt(C). '''
        limit = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-limit, limit)
        if self.bias is not None:
            self.bias.data.uniform_(-limit, limit)

    def forward(self, x):
        ''' x [B,C,H,W] : latent code Z.

        Returns ``(att_weight, output)`` with att_weight [B*H*W, M] and
        output [B,C,H,W].
        '''
        batch, channels, height, width = x.shape
        # channels last, then one row per spatial location: [N, C], N = B*H*W
        feats = x.permute(0, 2, 3, 1).flatten(end_dim=2)

        # similarity to every memory row, normalised per location: [N, M]
        scores = F.linear(feats, self.weight)
        att_weight = F.softmax(scores, dim=1)

        if self.shrink_thres > 0:
            # drop small weights, then restore unit L1 norm per row
            att_weight = hard_shrink_relu(att_weight, lambd=self.shrink_thres)
            att_weight = F.normalize(att_weight, p=1, dim=1)

        # [N, M] x [M, C] = [N, C]  (F.linear transposes its second argument)
        recon = F.linear(att_weight, self.weight.permute(1, 0))
        # back to the input layout [B, C, H, W]
        recon = recon.view(batch, height, width, channels).permute(0, 3, 1, 2)

        return att_weight, recon
コード例 #8
0
class FilterStripe(nn.Conv2d):
    """Conv2d whose kernel is gated per (filter, row, column) by ``FilterSkeleton``.

    Before ``_break`` is called the layer is an ordinary convolution with
    weight ``weight * FilterSkeleton``.  ``_break`` converts the layer into a
    1x1 convolution holding only the surviving kernel positions ("stripes");
    ``forward`` then shifts each stripe's response to its kernel offset and
    accumulates it into the owning output channel.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(FilterStripe, self).__init__(in_channels,
                                           out_channels,
                                           kernel_size,
                                           stride,
                                           kernel_size // 2,
                                           groups=1,
                                           bias=False)
        # None until _break(); afterwards the per-position stripe counts.
        self.BrokenTarget = None
        self.FilterSkeleton = Parameter(torch.ones(self.out_channels,
                                                   self.kernel_size[0],
                                                   self.kernel_size[1]),
                                        requires_grad=True)

    def forward(self, x):
        if self.BrokenTarget is None:
            return F.conv2d(x,
                            self.weight * self.FilterSkeleton.unsqueeze(1),
                            stride=self.stride,
                            padding=self.padding,
                            groups=self.groups)

        # Integer ceil-division of the spatial size by the stride.
        out_h = -(-x.shape[2] // self.stride[0])
        out_w = -(-x.shape[3] // self.stride[1])
        # Allocate on the input's device and dtype (a bare ``.cuda()`` picks
        # the default GPU and always produced float32).
        out = x.new_zeros(x.shape[0], self.FilterSkeleton.shape[0], out_h,
                          out_w)
        x = F.conv2d(x, self.weight)
        lo = 0
        for i in range(self.BrokenTarget.shape[0]):
            for j in range(self.BrokenTarget.shape[1]):
                stripe_mask = self.FilterSkeleton[:, i, j]
                # Channels lo:hi of the 1x1 response belong to position (i, j).
                hi = lo + stripe_mask.sum().item()
                out[:, stripe_mask] += self.shift(
                    x[:, lo:hi], i,
                    j)[:, :, ::self.stride[0], ::self.stride[1]]
                lo = hi
        return out

    def prune_in(self, in_mask=None):
        """Keep only the input channels selected by the boolean ``in_mask``.

        ``None`` leaves the layer untouched (it previously corrupted
        ``weight`` before raising).
        """
        if in_mask is None:
            return
        self.weight = Parameter(self.weight[:, in_mask])
        self.in_channels = in_mask.sum().item()

    def prune_out(self, threshold):
        """Drop filters whose skeleton has no entry above ``threshold``.

        At least one filter is always kept.  Returns the boolean keep-mask.
        """
        out_mask = (self.FilterSkeleton.abs() > threshold).sum(dim=(1, 2)) != 0
        if out_mask.sum() == 0:
            out_mask[0] = True
        self.weight = Parameter(self.weight[out_mask])
        self.FilterSkeleton = Parameter(self.FilterSkeleton[out_mask],
                                        requires_grad=True)
        self.out_channels = out_mask.sum().item()
        return out_mask

    def _break(self, threshold):
        """Fold the skeleton into the weight and switch to stripe-wise 1x1 form."""
        self.weight = Parameter(self.weight * self.FilterSkeleton.unsqueeze(1))
        self.FilterSkeleton = Parameter(
            (self.FilterSkeleton.abs() > threshold), requires_grad=False)
        if self.FilterSkeleton.sum() == 0:
            self.FilterSkeleton.data[0][0][0] = True
        self.out_channels = self.FilterSkeleton.sum().item()
        self.BrokenTarget = self.FilterSkeleton.sum(dim=0)
        self.kernel_size = (1, 1)
        # Order stripes as (row, col, filter) and keep the surviving ones;
        # forward() walks them in exactly this order.
        stripes = self.weight.permute(2, 3, 0, 1).reshape(
            -1, self.in_channels, 1, 1)
        keep = self.FilterSkeleton.permute(1, 2, 0).reshape(-1)
        self.weight = Parameter(stripes[keep])

    def update_skeleton(self, sr, threshold):
        """Add ``sr * sign(skeleton)`` to the skeleton gradient and zero small entries.

        Returns a boolean mask of filters that still have a non-zero entry.
        """
        self.FilterSkeleton.grad.data.add_(
            sr * torch.sign(self.FilterSkeleton.data))
        mask = self.FilterSkeleton.data.abs() > threshold
        self.FilterSkeleton.data.mul_(mask)
        self.FilterSkeleton.grad.data.mul_(mask)
        out_mask = mask.sum(dim=(1, 2)) != 0
        return out_mask

    def shift(self, x, i, j):
        """Translate ``x`` spatially to kernel offset (i, j) using signed padding."""
        half_h = self.BrokenTarget.shape[0] // 2
        half_w = self.BrokenTarget.shape[1] // 2
        # F.pad order is (left, right, top, bottom); negative values crop.
        return F.pad(x, (half_w - j, j - half_w, half_h - i, i - half_h),
                     'constant', 0)

    def extra_repr(self):
        s = (
            '{BrokenTarget},{in_channels}, {out_channels}, kernel_size={kernel_size}'
            ', stride={stride}')
        return s.format(**self.__dict__)
コード例 #9
0
ファイル: opt.py プロジェクト: cxiang26/mccp
class MCCP(nn.Module):
    """Capsule projection applied to sliding windows of the input vector.

    Windows of width ``reciptive`` are taken every ``strides`` columns
    (wrapping around to the start of the vector when a window runs past the
    end).  For each window the per-class projection length
    ``sqrt(v^T W^T (W W^T + eps * I)^-1 W v)`` is computed, and the lengths
    are averaged over all windows.
    """

    def __init__(self,
                 in_features,
                 num_C,
                 num_D,
                 reciptive=512,
                 strides=256,
                 eps=0.0001):
        super().__init__()
        self.in_features = in_features
        print('the dimension of input vector is:', in_features)
        print('mccp reciptive is:', reciptive)
        print('mccp strides is:', strides)
        self.reciptive, self.strides = reciptive, strides
        self.num_C, self.num_D = num_C, num_D
        self.eps = eps
        # Frozen identity used to regularise the Gram matrix.
        self.eye = Parameter(torch.eye(num_D), requires_grad=False)
        self.weight = Parameter(torch.Tensor(num_C, num_D, self.reciptive))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init in +-1/sqrt(reciptive)."""
        bound = 1. / math.sqrt(self.weight.size(2))
        self.weight.data.uniform_(-bound, bound)

    def forward(self, x):
        """Return the window-averaged projection lengths, shape (batch, num_C)."""
        basis = self.weight
        gram = torch.matmul(basis, basis.permute(0, 2, 1))
        sigma = torch.inverse(gram + self.eps * self.eye)

        width = x.shape[1]
        lengths = []
        for start in range(0, width, self.strides):
            stop = start + self.reciptive
            if stop > width:
                # Window overruns the vector: wrap around to its beginning.
                window = torch.cat((x[:, start:], x[:, :stop - width]),
                                   dim=-1)
            else:
                window = x[:, start:stop]

            # Coordinates w.r.t. every basis vector, regrouped per class.
            coords = torch.matmul(
                window, torch.t(basis.view(-1, window.size(-1))))
            coords = coords.view(coords.shape[0], self.num_C, 1, self.num_D)

            # Project back into the window space: (batch, num_C, reciptive).
            recon = torch.matmul(coords, sigma)
            recon = torch.matmul(
                recon, basis.view(self.num_C, self.num_D, self.reciptive))
            recon = torch.squeeze(recon, dim=2)

            sq_len = torch.matmul(recon, torch.unsqueeze(window, dim=2))
            lengths.append(torch.sqrt(sq_len))

        stacked = torch.cat(lengths, dim=-1)
        return torch.mean(stacked, dim=-1)
コード例 #10
0
ファイル: tbsm_pytorch.py プロジェクト: ravikucb/tbsm
class TSL_Net(nn.Module):
    """Attention layer combining a query vector ``x`` with a sequence ``H``.

    Two modes are selected by ``model_type``:

    * ``"tsl"``: ``x`` is optionally transformed by a learned matrix ``A``
      (``tsl_inner`` ``"def"``: ``A^T A x``; ``"ind"``: ``A x``), dotted with
      every row of ``H`` to obtain coefficients, and the coefficients
      (optionally passed through an MLP) weight the rows of ``H``.
    * ``"mha"``: ``x`` and ``H`` are projected with ``Q``/``K``/``V`` and fed
      to ``nn.MultiheadAttention``.

    The attention module is created once in ``__init__`` and registered as
    ``multihead_attn`` so that it is trained, moved with the model and saved
    in the state dict.  Checkpoints written before this change do not
    contain ``multihead_attn.*`` entries; load them with ``strict=False``.

    Args:
        arch_interaction_op: only ``'dot'`` is supported.
        arch_attention_mechanism: ``'mlp'`` or ``'mul'``.
        ln: layer sizes handed to ``dlrm.DLRM_Net().create_mlp`` (required
            when the mechanism is ``'mlp'``).
        model_type: ``"tsl"`` or ``"mha"``.
        tsl_inner: ``"def"``, ``"ind"`` or anything else for no transform.
        mha_num_heads: number of attention heads in ``mha`` mode.
        ln_top: sequence whose last element is the feature size ``m``.
    """

    def __init__(self,
                 arch_interaction_op='dot',
                 arch_attention_mechanism='mlp',
                 ln=None,
                 model_type="tsl",
                 tsl_inner="def",
                 mha_num_heads=8,
                 ln_top=""):
        super(TSL_Net, self).__init__()

        # save arguments
        self.arch_interaction_op = arch_interaction_op
        self.arch_attention_mechanism = arch_attention_mechanism
        self.model_type = model_type
        self.tsl_inner = tsl_inner

        # setup for mechanism type
        if self.arch_attention_mechanism == 'mlp':
            self.mlp = dlrm.DLRM_Net().create_mlp(ln, len(ln) - 2)

        # setup extra parameters for some of the models
        if self.model_type == "tsl" and self.tsl_inner in ["def", "ind"]:
            m = ln_top[-1]  # dim of dlrm output
            std_dev = np.sqrt(2 / (m + m))
            self.A = self._normal_parameter((1, m, m), std_dev)
        elif self.model_type == "mha":
            m = ln_top[-1]  # dlrm output dim
            self.nheads = mha_num_heads
            self.emb_m = self.nheads * m  # mha emb dim
            std_dev = np.sqrt(2 / (m + m))
            self.Q = self._normal_parameter((1, m, self.emb_m), std_dev)
            self.K = self._normal_parameter((1, m, self.emb_m), std_dev)
            self.V = self._normal_parameter((1, m, self.emb_m), std_dev)
            # Built once here; building it inside forward() produced a new
            # randomly initialised (and never trained) module on every call.
            self.multihead_attn = nn.MultiheadAttention(self.emb_m,
                                                        self.nheads)

    @staticmethod
    def _normal_parameter(shape, std_dev):
        """Return a trainable float32 Parameter drawn from N(0, std_dev^2) via NumPy."""
        values = np.random.normal(0.0, std_dev, size=shape).astype(np.float32)
        return Parameter(torch.tensor(values), requires_grad=True)

    def forward(self, x=None, H=None):
        """Return the context computed from query ``x`` (batch, m) and sequence ``H``.

        ``H`` is used as a (batch, ts_length, m) tensor.  In ``tsl`` mode the
        result has shape (batch, 1, m); in ``mha`` mode (batch, emb_m).
        """
        # adjust input shape: (batch, m) -> (batch, m, 1)
        (batchSize, vector_dim) = x.shape
        x = torch.reshape(x, (batchSize, 1, -1))
        x = torch.transpose(x, 1, 2)

        # perform mode operation
        if self.model_type == "tsl":
            if self.tsl_inner == "def":
                ax = torch.matmul(self.A, x)
                x = torch.matmul(self.A.permute(0, 2, 1), ax)
            elif self.tsl_inner == "ind":
                x = torch.matmul(self.A, x)

            # perform interaction operation
            if self.arch_interaction_op == 'dot':
                if self.arch_attention_mechanism == 'mul':
                    # coefficients
                    a = torch.transpose(torch.bmm(H, x), 1, 2)
                    # context
                    c = torch.bmm(a, H)
                elif self.arch_attention_mechanism == 'mlp':
                    # coefficients
                    a = torch.transpose(torch.bmm(H, x), 1, 2)
                    # MLP first/last layer dims are automatically adjusted to ts_length
                    y = dlrm.DLRM_Net().apply_mlp(a, self.mlp)
                    # context, y = mlp(a)
                    c = torch.bmm(torch.reshape(y, (batchSize, 1, -1)), H)
                else:
                    sys.exit('ERROR: --arch-attention-mechanism=' +
                             self.arch_attention_mechanism +
                             ' is not supported')

            else:
                sys.exit('ERROR: --arch-interaction-op=' +
                         self.arch_interaction_op + ' is not supported')

        elif self.model_type == "mha":
            # back to (batch, 1, m), then project and move the sequence axis
            # first, as nn.MultiheadAttention expects by default
            x = torch.transpose(x, 1, 2)
            Qx = torch.transpose(torch.matmul(x, self.Q), 0, 1)
            HK = torch.transpose(torch.matmul(H, self.K), 0, 1)
            HV = torch.transpose(torch.matmul(H, self.V), 0, 1)
            # multi-head attention (mha)
            attn_output, _ = self.multihead_attn(Qx, HK, HV)
            # context
            c = torch.squeeze(attn_output, dim=0)

        return c
コード例 #11
0
class FilterStripe(nn.Conv2d):
    """Convolution layer combined with a filter-skeleton (FS) layer.

    ``FilterSkeleton`` holds one scale per (filter, kernel row, kernel
    column).  Until ``_break`` is called the layer convolves with
    ``weight * FilterSkeleton``; afterwards it keeps only the surviving
    kernel positions as 1x1 filters and reassembles the output from shifted
    responses.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(FilterStripe, self).__init__(in_channels, out_channels, kernel_size, stride, kernel_size // 2, groups=1, bias=False)
        self.BrokenTarget = None
        # FS layer, initialised to all ones
        self.FilterSkeleton = Parameter(torch.ones(self.out_channels, self.kernel_size[0], self.kernel_size[1]), requires_grad=True)

    def forward(self, x):
        """Convolve ``x`` ([N, channels, ...spatial]) with the skeleton-gated kernel."""
        if self.BrokenTarget is None:
            # unsqueeze(1) inserts the input-channel axis so the skeleton
            # broadcasts over it
            return F.conv2d(x, self.weight * self.FilterSkeleton.unsqueeze(1), stride=self.stride, padding=self.padding, groups=self.groups)

        # spatial size divided by the stride, rounded up
        rows = -(-x.shape[2] // self.stride[0])
        cols = -(-x.shape[3] // self.stride[1])
        # same device and dtype as the input (``.cuda()`` would select the
        # default GPU and the buffer used to be float32 regardless of x)
        out = x.new_zeros(x.shape[0], self.FilterSkeleton.shape[0], rows, cols)
        x = F.conv2d(x, self.weight)  # 1x1 convolution output, one channel per stripe
        start = 0
        for i in range(self.BrokenTarget.shape[0]):
            for j in range(self.BrokenTarget.shape[1]):
                active = self.FilterSkeleton[:, i, j]
                # number of filters that keep kernel position (i, j)
                end = start + active.sum().item()
                # add the shifted stripe responses to the channels owning them
                out[:, active] += self.shift(x[:, start:end], i, j)[:, :, ::self.stride[0], ::self.stride[1]]
                start = end
        return out

    def prune_in(self, in_mask=None):
        """Keep the input channels selected by the boolean mask ``in_mask``.

        Passing ``None`` is a no-op.
        """
        if in_mask is None:
            return
        self.weight = Parameter(self.weight[:, in_mask])
        self.in_channels = in_mask.sum().item()

    def prune_out(self, threshold):
        """Remove filters with no skeleton entry above ``threshold``; return the keep-mask."""
        out_mask = (self.FilterSkeleton.abs() > threshold).sum(dim=(1, 2)) != 0
        if out_mask.sum() == 0:
            # never prune every filter
            out_mask[0] = True
        self.weight = Parameter(self.weight[out_mask])  # mask the kernels
        self.FilterSkeleton = Parameter(self.FilterSkeleton[out_mask], requires_grad=True)  # mask the FS layer
        self.out_channels = out_mask.sum().item()
        return out_mask

    def _break(self, threshold):
        """Multiply the skeleton into the kernel and split it into 1x1 stripes."""
        self.weight = Parameter(self.weight * self.FilterSkeleton.unsqueeze(1))
        # entries above the threshold become True
        self.FilterSkeleton = Parameter((self.FilterSkeleton.abs() > threshold), requires_grad=False)
        if self.FilterSkeleton.sum() == 0:
            self.FilterSkeleton.data[0][0][0] = True
        self.out_channels = self.FilterSkeleton.sum().item()
        self.BrokenTarget = self.FilterSkeleton.sum(dim=0)
        self.kernel_size = (1, 1)
        # reorder to (row, col, filter, in_channel), flatten to 1x1 filters
        # and keep those whose skeleton entry survived
        flat_weight = self.weight.permute(2, 3, 0, 1).reshape(-1, self.in_channels, 1, 1)
        flat_mask = self.FilterSkeleton.permute(1, 2, 0).reshape(-1)
        self.weight = Parameter(flat_weight[flat_mask])

    def update_skeleton(self, sr, threshold):
        """Add the L1-penalty gradient ``sr * sign(FS)`` and zero entries below ``threshold``.

        Returns a boolean mask of filters that still have a non-zero entry.
        """
        self.FilterSkeleton.grad.data.add_(sr * torch.sign(self.FilterSkeleton.data))
        mask = self.FilterSkeleton.data.abs() > threshold
        self.FilterSkeleton.data.mul_(mask)
        self.FilterSkeleton.grad.data.mul_(mask)
        out_mask = mask.sum(dim=(1, 2)) != 0
        return out_mask

    def shift(self, x, i, j):
        """Shift ``x`` to kernel offset (i, j); negative padding crops."""
        centre_row = self.BrokenTarget.shape[0] // 2
        centre_col = self.BrokenTarget.shape[1] // 2
        # F.pad takes (left, right, top, bottom)
        return F.pad(x, (centre_col - j, j - centre_col, centre_row - i, i - centre_row), 'constant', 0)

    def extra_repr(self):
        s = ('{BrokenTarget},{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        return s.format(**self.__dict__)
コード例 #12
0
class RTN_CORE(nn.Module):
    """Convolution with G rotated copies of each learned template.

    Vocabulary (following the paper the original comments refer to):
      * w  - the unrotated filters / "templates" that are learned (M*N of them)
      * gw - the "transformed templates", G rotated versions of every template

    Args:
        M: number of input channels.
        N: number of output channels.
        G: number of rotations. If odd, the original filter is also used;
           if even, it is not.
        alpha: the G rotations span -alpha to alpha.
        filtersize: side length k of the square templates.
        padding: accepted for interface compatibility; not used by this class.
        init_kernels: optional pre-built templates used instead of random
            ones. NOTE(review): stored as given - it is only trained and
            registered if the caller passes an nn.Parameter.
    """

    def __init__(self, M, N, G, alpha, filtersize = 5, padding=0, init_kernels = None):
        super(RTN_CORE, self).__init__()

        self.M=M
        self.N=N
        self.G=G
        self.a=alpha
        self.k = filtersize

        if init_kernels is None:
            # Randomly initialised templates; size(w) = ( M*N x 1 x k x k )
            self.w = Parameter(torch.randn(M*N, 1, self.k, self.k))
        else:
            self.w = init_kernels

        # Step 1 - the G transformation matrices; size(g) = ( G x 2 x 3 )
        rotations = make_rotations(-alpha, alpha, G)
        g = torch.stack([torch.Tensor(f) for f in rotations])

        # Step 2.1 - flow fields describing the transformations, one sampling
        # grid of size (k, k, 2) per rotation.
        # NOTE(review): align_corners is left at the library default here and
        # in grid_sample below - confirm which convention is intended.
        s = torch.Size((G, M*N, self.k, self.k)) # Desired output size
        flow_field  = torch.nn.functional.affine_grid(g, s)
        self.register_buffer("flow_field", flow_field)

        # Max over the G rotations, then mean over groups of M channels.
        self.maxpool3d = nn.MaxPool3d((self.G, 1, 1))
        self.meanpool3d = nn.AvgPool3d((self.M, 1, 1))

        self.permutation = make_permutation(self.M, self.N) # TODO: Should this also go on the GPU as a buffer?

    def forward(self, x):
        """Convolve ``x`` with every rotated template, then pool over rotations and inputs."""
        # Start from w every time and create the others as rotations of it.

        # Permute:  (M*N, 1, k, k) to (1, M*N, k, k)
        w_rep = self.w.permute(1,0,2,3)
        # Repeat along the singular first dimension G times: ( G x M*N x k x k )
        w_rep2 = w_rep.expand(self.G, -1, -1, -1)

        # Step 2.2 - apply one flow field per rotation to every template.
        w_rot = torch.nn.functional.grid_sample(w_rep2, self.flow_field) # ( G x M*N x k x k )

        # ( angle, template, x, y ) to ( template, angle, x, y )
        w_perm = w_rot.permute(1,0,2,3)

        # Unroll into conv2d's weight layout
        # (out_channels x in_channels/groups x kH x kW) = ( M*N*G x 1 x k x k ).
        # reshape replaces the deprecated non-in-place Tensor.resize.
        MNG = self.M * self.N * self.G
        w_unrolled = w_perm.reshape(MNG, 1, self.k, self.k)

        # Functional conv2d: the only trainable parameters are the templates.
        x = F.conv2d(x, w_unrolled, groups=self.M)  # TODO: Add padding?

        x = self.maxpool3d(x)  # channel dimension becomes M*N

        x = x[:, self.permutation] # reorder channels
        x = self.meanpool3d(x)
        return x
コード例 #13
0
class Gumbel_Generator_Old(nn.Module):
    """Learnable generator of adjacency matrices sampled with Gumbel-softmax.

    ``gen_matrix`` has shape (sz, sz, 2): for every node pair it stores two
    logits; the sampling methods return the component with index 0 as the
    adjacency entry.

    Args:
        sz: number of nodes.
        temp: initial Gumbel-softmax temperature.
        temp_drop_frac: factor the temperature is multiplied by in
            ``drop_temp``.
    """

    def __init__(self, sz=10, temp=10, temp_drop_frac=0.9999):
        super(Gumbel_Generator_Old, self).__init__()
        self.sz = sz
        # Wrapping the tensors in Parameter registers them with the module so
        # they are updated during training. torch.rand draws from [0, 1).
        self.gen_matrix = Parameter(torch.rand(sz, sz, 2))
        self.new_matrix = Parameter(torch.zeros(5, 5, 2))
        self.temperature = temp
        self.temp_drop_frac = temp_drop_frac

    def symmetry(self):
        """Replace ``gen_matrix`` by its symmetrised strict upper triangle (zero diagonal)."""
        matrix = self.gen_matrix.permute(2, 0, 1)
        temp_matrix = torch.triu(matrix, 1) + torch.triu(matrix, 1).permute(
            0, 2, 1)
        self.gen_matrix.data = temp_matrix.permute(1, 2, 0)

    def drop_temp(self):
        """Cool down: multiply the temperature by ``temp_drop_frac``."""
        self.temperature = self.temperature * self.temp_drop_frac

    def sample_all(self, hard=False, epoch=1):
        """Sample a full (sz x sz) adjacency matrix."""
        # Flatten to one row of two logits per node pair.
        self.logp = self.gen_matrix.view(-1, 2)
        out = gumbel_softmax(self.logp, self.temperature, hard)
        if hard:
            # NOTE(review): treats out[i] as a class index - confirm against
            # gumbel_softmax's hard-mode return value.
            hh = torch.zeros(self.gen_matrix.size()[0]**2, 2)
            for i in range(out.size()[0]):
                hh[i, out[i]] = 1
            out = hh
        if use_cuda:
            out = out.cuda()
        out_matrix = out[:, 0].view(self.gen_matrix.size()[0],
                                    self.gen_matrix.size()[0])
        return out_matrix

    def sample_small(self, list, hard=False):
        """Sample the sub-matrix restricted to the node indices in ``list``."""
        indices = np.ix_(list, list)
        self.logp = self.gen_matrix[indices].view(-1, 2)

        out = gumbel_softmax(self.logp, self.temperature, hard)

        if hard:
            hh = torch.zeros(self.gen_matrix[indices].size()[0]**2, 2)
            for i in range(out.size()[0]):
                hh[i, out[i]] = 1
            out = hh
        if use_cuda:
            out = out.cuda()
        out_matrix = out[:, 0].view(len(list), len(list))
        return out_matrix

    def sample_adj_ij(self, list, j, hard=False, sample_time=1):
        """Sample the entries of column ``j`` for the rows in ``list``."""
        self.logp = self.gen_matrix[list, j]

        out = gumbel_softmax(self.logp, self.temperature, hard=hard)
        if use_cuda:
            out = out.cuda()
        if hard:
            out_matrix = out.float()
        else:
            out_matrix = out[:, 0]
        return out_matrix

    def sample_adj_i(self, i, hard=False, sample_time=1):
        """Sample all entries of column ``i``."""
        self.logp = self.gen_matrix[:, i]
        out = gumbel_softmax(self.logp, self.temperature, hard=hard)
        if use_cuda:
            out = out.cuda()
        if hard:
            out_matrix = out.float()
        else:
            out_matrix = out[:, 0]
        return out_matrix

    def get_temperature(self):
        return self.temperature

    def get_cross_entropy(self, obj_matrix):
        """Return the cross-entropy distance to the target matrix as a NumPy value."""
        logps = F.softmax(self.gen_matrix, 2)
        logps = torch.log(logps[:, :, 0] + 1e-10) * obj_matrix + torch.log(
            logps[:, :, 1] + 1e-10) * (1 - obj_matrix)
        result = -torch.sum(logps)
        result = result.cpu() if use_cuda else result
        return result.data.numpy()

    def get_entropy(self):
        """Return the negated mean of ``p * log(p)`` summed over axis 1, as a NumPy value."""
        logps = F.softmax(self.gen_matrix, 2)
        result = torch.mean(torch.sum(logps * torch.log(logps + 1e-10), 1))
        result = result.cpu() if use_cuda else result
        return (-result.data.numpy())

    def randomization(self, fraction):
        """Re-draw ``int(fraction * sz * sz)`` randomly chosen entries of ``gen_matrix``.

        Positions are drawn with replacement, so fewer distinct entries may
        be reset.
        """
        sz = self.gen_matrix.size()[0]
        numbers = int(fraction * sz * sz)

        for i in range(numbers):
            ii = np.random.choice(range(sz), (2, 1))
            z = torch.rand(2).cuda() if use_cuda else torch.rand(2)
            self.gen_matrix.data[ii[0], ii[1], :] = z

    def init(self, mean, var):
        """Re-initialise ``gen_matrix`` from a normal distribution.

        Note that ``var`` is passed as the standard deviation.
        """
        init.normal_(self.gen_matrix, mean=mean, std=var)