コード例 #1
0
ファイル: networks.py プロジェクト: jleesdev/meshcnn
    def __init__(self,
                 norm_layer,
                 nf0,
                 conv_res,
                 nclasses,
                 input_res,
                 pool_res,
                 fc_n,
                 nresblocks=3):
        """Build a mesh-convolution encoder followed by a fully connected decoder.

        The encoder registers one ``conv{i}`` / ``norm{i}`` / ``pool{i}`` triple
        per entry of ``conv_res``. A global ``MaxPool1d`` follows the encoder.
        The decoder is three ``nn.Linear`` layers ending in ``1402 * 3`` outputs.

        Args:
            norm_layer: Callable that builds a normalization layer. It is
                called with each keyword-argument dict returned by
                ``get_norm_args``.
            nf0: Number of input feature channels (first entry of ``self.k``).
            conv_res: List of output channel counts, one per encoder stage.
            nclasses: Accepted for signature compatibility; not used in the
                visible part of this method.
            input_res: First entry of ``self.res``; presumably the input mesh
                edge count. TODO confirm.
            pool_res: List of values passed to ``MeshPool`` for each stage.
            fc_n: Sequence with at least two entries; ``fc_n[0]`` and
                ``fc_n[1]`` are the decoder's hidden widths.
            nresblocks: Residual-block count passed to every ``MResConv``.
        """
        super(MeshAutoEncoder, self).__init__()
        # Channel count per stage boundary: input channels followed by each conv output.
        self.k = [nf0] + conv_res
        # Resolution per stage boundary: input resolution followed by each pool target.
        self.res = [input_res] + pool_res
        self.fc_n = fc_n
        norm_args = get_norm_args(norm_layer, self.k[1:])

        # Register conv/norm/pool layers as named attributes (conv0, norm0, pool0, ...).
        for i, ki in enumerate(self.k[:-1]):
            setattr(self, 'conv{}'.format(i),
                    MResConv(ki, self.k[i + 1], nresblocks))
            setattr(self, 'norm{}'.format(i), norm_layer(**norm_args[i]))
            setattr(self, 'pool{}'.format(i), MeshPool(self.res[i + 1]))

        # Global pooling with kernel size equal to the final resolution.
        # self.gp = torch.nn.AvgPool1d(self.res[-1])
        self.gp = torch.nn.MaxPool1d(self.res[-1])
        self.decoder_fc1 = nn.Linear(self.k[-1], fc_n[0])
        self.decoder_fc2 = nn.Linear(fc_n[0], fc_n[1])
        # NOTE(review): 1402 * 3 looks like a fixed vertex count times xyz
        # coordinates; it is hard-coded, so it only fits meshes of that size. Confirm.
        self.decoder_fc3 = nn.Linear(fc_n[1], 1402 * 3)
        '''
コード例 #2
0
ファイル: networks.py プロジェクト: kimmctim/BrainSurfaceTK
    def __init__(self,
                 norm_layer,
                 nf0,
                 conv_res,
                 nclasses,
                 input_res,
                 pool_res,
                 fc_n,
                 opt,
                 nresblocks=3,
                 num_features=0):
        """Build a mesh-convolution classifier with an average-pooled FC head.

        Each stage registers ``conv{i}`` (MResConv), ``norm{i}`` and ``pool{i}``
        (MeshPool). A global ``AvgPool1d`` follows, then ``fc1`` and ``fc2``.

        Args:
            norm_layer: Callable that builds a normalization layer from each
                keyword dict returned by ``get_norm_args``.
            nf0: Number of input feature channels.
            conv_res: List of output channel counts, one per stage.
            nclasses: Output width of ``fc2`` unless
                ``opt.dataset_mode == 'binary_class'``, in which case it is 1.
            input_res: First entry of ``self.res``.
            pool_res: List of values passed to ``MeshPool`` per stage.
            fc_n: Hidden width of ``fc1``.
            opt: Options object. Reads ``opt.dropout`` (truthiness) and
                ``opt.dataset_mode``.
            nresblocks: Residual-block count for every ``MResConv``.
            num_features: Extra width added to ``fc1``'s input. Presumably
                extra per-sample features concatenated in ``forward``; confirm.
        """
        super(MeshConvNet, self).__init__()
        self.k = [nf0] + conv_res
        self.res = [input_res] + pool_res
        self.opt = opt
        norm_args = get_norm_args(norm_layer, self.k[1:])

        # Keep the conv -> norm -> pool creation order per stage: it fixes the
        # RNG draws for initialization and the sub-module registration order.
        n_stages = len(self.k) - 1
        for stage in range(n_stages):
            c_in, c_out = self.k[stage], self.k[stage + 1]
            setattr(self, 'conv%d' % stage, MResConv(c_in, c_out, nresblocks))
            setattr(self, 'norm%d' % stage, norm_layer(**norm_args[stage]))
            setattr(self, 'pool%d' % stage, MeshPool(self.res[stage + 1]))

        self.gp = torch.nn.AvgPool1d(self.res[-1])
        if self.opt.dropout:
            self.d = nn.Dropout()
        self.fc1 = nn.Linear(self.k[-1] + num_features, fc_n)
        # Binary classification uses a single output logit.
        out_width = 1 if self.opt.dataset_mode == 'binary_class' else nclasses
        self.fc2 = nn.Linear(fc_n, out_width)
コード例 #3
0
    def __init__(self,
                 norm_layer,
                 nf0,
                 conv_res,
                 nclasses,
                 input_res,
                 pool_res,
                 fc_n,
                 nresblocks=3,
                 stn=None):
        """Build a mesh-convolution classifier that also stores an ``stn`` module.

        Args:
            norm_layer: Callable that builds a normalization layer from each
                keyword dict returned by ``get_norm_args``.
            nf0: Number of input feature channels.
            conv_res: List of output channel counts, one per stage.
            nclasses: Output width of ``fc2``.
            input_res: First entry of ``self.res``.
            pool_res: List of values passed to ``MeshPool`` per stage.
            fc_n: Hidden width of ``fc1``.
            nresblocks: Residual-block count for every ``MResConv``.
            stn: Optional object stored as ``self.stn`` (assigned last).
                Presumably a spatial-transformer module; confirm.
        """
        super(MeshConvNet, self).__init__()
        self.k = [nf0] + conv_res
        self.res = [input_res] + pool_res
        norm_args = get_norm_args(norm_layer, self.k[1:])
        self.trans_inp = None

        # One (conv, norm, pool) triple per consecutive channel pair, created
        # in that order so initialization and registration order are unchanged.
        channel_pairs = zip(self.k[:-1], self.k[1:])
        for stage, (c_in, c_out) in enumerate(channel_pairs):
            setattr(self, 'conv%d' % stage, MResConv(c_in, c_out, nresblocks))
            setattr(self, 'norm%d' % stage, norm_layer(**norm_args[stage]))
            setattr(self, 'pool%d' % stage, MeshPool(self.res[stage + 1]))

        # Global average pooling over the final resolution.
        self.gp = torch.nn.AvgPool1d(self.res[-1])
        self.fc1 = nn.Linear(self.k[-1], fc_n)
        self.fc2 = nn.Linear(fc_n, nclasses)
        self.stn = stn
コード例 #4
0
    def __init__(
            self,
            norm_layer,
            nf0,
            conv_res,
            nclasses,
            input_res,
            pool_res,
            fc_n,
            attn_n_heads,  # d_k, d_v,
            nresblocks=3,
            attn_max_dist=None,
            attn_dropout=0.1,
            prioritize_with_attention=False,
            attn_use_values_as_is=False,
            double_attention=False,
            attn_use_positional_encoding=False,
            attn_max_relative_position=8):
        """Build a mesh-convolution classifier with an attention layer per stage.

        Each stage registers ``conv{i}`` (MResConv), ``norm{i}``,
        ``attention{i}`` (MeshAttention) and ``pool{i}`` (MeshPool), in that
        order. A global ``AvgPool1d`` follows, then ``fc1`` and ``fc2``.

        Args:
            norm_layer: Callable that builds a normalization layer from each
                keyword dict returned by ``get_norm_args``.
            nf0: Number of input feature channels.
            conv_res: List of output channel counts, one per stage. Each is
                also the attention ``d_model`` for that stage.
            nclasses: Output width of ``fc2``.
            input_res: First entry of ``self.res``.
            pool_res: List of values passed to ``MeshPool`` per stage.
            fc_n: Hidden width of ``fc1``.
            attn_n_heads: Number of attention heads. The per-head key/value
                width is ``d_model // attn_n_heads``.
            nresblocks: Residual-block count for every ``MResConv``.
            attn_max_dist, attn_dropout, attn_use_values_as_is,
            attn_use_positional_encoding, attn_max_relative_position:
                Forwarded to every ``MeshAttention``.
            prioritize_with_attention: Stored on ``self``; used outside this method.
            double_attention: Stored on ``self``. Requires
                ``attn_use_values_as_is=True``.

        Raises:
            AssertionError: If ``double_attention`` is set without
                ``attn_use_values_as_is``. Only raised while assertions are
                enabled (not under ``python -O``).
        """
        super(MeshAttentionNet, self).__init__()
        if double_attention:
            assert attn_use_values_as_is, (
                "must have attn_use_values_as_is=True if double_attention=True, "
                "since the attention layer works on its own outputs")
        self.k = [nf0] + conv_res
        self.res = [input_res] + pool_res
        self.prioritize_with_attention = prioritize_with_attention
        self.double_attention = double_attention
        self.use_values_as_is = attn_use_values_as_is
        norm_args = get_norm_args(norm_layer, self.k[1:])

        for i, ki in enumerate(self.k[:-1]):
            d_model = self.k[i + 1]
            # Per-head key/value width. Floor division keeps this an exact int
            # instead of truncating a float quotient.
            head_dim = d_model // attn_n_heads
            setattr(self, 'conv{}'.format(i),
                    MResConv(ki, d_model, nresblocks))
            setattr(self, 'norm{}'.format(i), norm_layer(**norm_args[i]))
            setattr(
                self, 'attention{}'.format(i),
                MeshAttention(
                    n_head=attn_n_heads,
                    d_model=d_model,
                    d_k=head_dim,
                    d_v=head_dim,
                    attn_max_dist=attn_max_dist,
                    dropout=attn_dropout,
                    use_values_as_is=attn_use_values_as_is,
                    use_positional_encoding=attn_use_positional_encoding,
                    max_relative_position=attn_max_relative_position))
            setattr(self, 'pool{}'.format(i), MeshPool(self.res[i + 1]))

        # Global average pooling over the final resolution.
        self.gp = torch.nn.AvgPool1d(self.res[-1])
        self.fc1 = nn.Linear(self.k[-1], fc_n)
        self.fc2 = nn.Linear(fc_n, nclasses)
コード例 #5
0
ファイル: networks.py プロジェクト: jleesdev/meshcnn
 def __init__(self, in_channels, out_channels, blocks=0, pool=0):
     """Build one down-sampling stage: MeshConv layers, norms, optional MeshPool.

     Args:
         in_channels: Input channels of the first MeshConv (``conv1``).
         out_channels: Output channels of every MeshConv and InstanceNorm2d.
         blocks: Number of extra ``MeshConv(out_channels, out_channels)``
             layers stored in ``self.conv2``. ``blocks + 1`` InstanceNorm2d
             layers are created in ``self.bn``, presumably one per conv.
         pool: If truthy, passed to ``MeshPool`` as ``self.pool``; otherwise
             ``self.pool`` is None.
     """
     super(DownConv, self).__init__()
     self.pool = None
     self.conv1 = MeshConv(in_channels, out_channels)
     # Wrap each list in nn.ModuleList exactly once, after it is built, so the
     # layers are registered as sub-modules. self.conv2 is a ModuleList even
     # when blocks == 0.
     self.conv2 = nn.ModuleList(
         [MeshConv(out_channels, out_channels) for _ in range(blocks)])
     self.bn = nn.ModuleList(
         [nn.InstanceNorm2d(out_channels) for _ in range(blocks + 1)])
     if pool:
         self.pool = MeshPool(pool)
コード例 #6
0
ファイル: networks.py プロジェクト: s183983/MeshCNN_sparse
 def __init__(self, in_channels, out_channels, blocks=0, pool=0, nl_block = 0):
     """Build one down-sampling stage with an optional non-local block.

     Args:
         in_channels: Input channels of the first MeshConv (``conv1``).
         out_channels: Output channels of every MeshConv, InstanceNorm2d and
             the optional NLBlock.
         blocks: Number of extra ``MeshConv(out_channels, out_channels)``
             layers stored in ``self.conv2``. ``blocks + 1`` InstanceNorm2d
             layers are created in ``self.bn``, presumably one per conv.
         pool: If truthy, passed to ``MeshPool`` as ``self.pool``; otherwise
             ``self.pool`` is None.
         nl_block: If truthy, used as ``block_type`` for an
             ``NLBlock(out_channels, ...)`` stored in ``self.nl_block``;
             otherwise ``self.nl_block`` is None.
     """
     super(DownConv, self).__init__()
     self.pool = None
     self.nl_block = None
     self.conv1 = MeshConv(in_channels, out_channels)
     # Wrap each list in nn.ModuleList exactly once, after it is built, so the
     # layers are registered as sub-modules. self.conv2 is a ModuleList even
     # when blocks == 0.
     self.conv2 = nn.ModuleList(
         [MeshConv(out_channels, out_channels) for _ in range(blocks)])
     self.bn = nn.ModuleList(
         [nn.InstanceNorm2d(out_channels) for _ in range(blocks + 1)])
     if pool:
         self.pool = MeshPool(pool)
     if nl_block:
         # Add non-local block before last downpool
         self.nl_block = NLBlock(out_channels, block_type = nl_block)
コード例 #7
0
ファイル: networks.py プロジェクト: jleesdev/meshcnn
    def __init__(self,
                 norm_layer,
                 nf0,
                 conv_res,
                 nclasses,
                 input_res,
                 pool_res,
                 fc_n,
                 nresblocks=3,
                 dropout_p=0):
        """Build a mesh-convolution classifier with a multi-layer FC head.

        Each stage registers ``conv{i}`` (MResConv), ``norm{i}`` and ``pool{i}``
        (MeshPool). A global ``AvgPool1d`` follows. The head is a stack of
        ``nn.Linear`` layers (``self.fcs``), each with a matching
        ``nn.BatchNorm1d`` (``self.bns``), and a final ``self.last_fc``
        producing ``nclasses`` outputs.

        Args:
            norm_layer: Callable that builds a normalization layer from each
                keyword dict returned by ``get_norm_args``.
            nf0: Number of input feature channels.
            conv_res: List of output channel counts, one per stage.
            nclasses: Output width of ``last_fc``.
            input_res: First entry of ``self.res``.
            pool_res: List of values passed to ``MeshPool`` per stage.
            fc_n: Non-empty sequence of hidden widths (``fc_n[0]`` is read
                unconditionally). One Linear + BatchNorm1d pair per entry.
            nresblocks: Residual-block count for every ``MResConv``.
            dropout_p: Stored as ``self.p``; a message is printed when it is
                non-zero. How dropout is applied is not visible here.
        """
        super(MeshConvNet, self).__init__()
        self.p = dropout_p
        self.k = [nf0] + conv_res
        self.res = [input_res] + pool_res
        norm_args = get_norm_args(norm_layer, self.k[1:])

        # Register conv/norm/pool layers as named attributes (conv0, norm0, pool0, ...).
        for i, ki in enumerate(self.k[:-1]):
            setattr(self, 'conv{}'.format(i),
                    MResConv(ki, self.k[i + 1], nresblocks))
            setattr(self, 'norm{}'.format(i), norm_layer(**norm_args[i]))
            setattr(self, 'pool{}'.format(i), MeshPool(self.res[i + 1]))

        # Global average pooling with kernel size equal to the final resolution.
        self.gp = torch.nn.AvgPool1d(self.res[-1])
        #self.gp = torch.nn.MaxPool1d(self.res[-1])

        # Hidden FC chain: k[-1] -> fc_n[0] -> fc_n[1] -> ... -> fc_n[-1].
        self.fcs = []
        self.fcs.append(nn.Linear(self.k[-1], fc_n[0]))
        for i in range(len(fc_n) - 1):
            self.fcs.append(nn.Linear(fc_n[i], fc_n[i + 1]))
        # ModuleList registers the layers as sub-modules (a plain list would not).
        self.fcs = nn.ModuleList(self.fcs)
        self.last_fc = nn.Linear(fc_n[-1], nclasses)

        # One BatchNorm1d per hidden FC width.
        self.bns = []
        for i in range(len(fc_n)):
            self.bns.append(nn.BatchNorm1d(fc_n[i]))
        self.bns = nn.ModuleList(self.bns)
        if self.p != 0:
            print('dropout layers are applied (p = %f)' % (self.p))