示例#1
0
    def __init__(self,
                 in_channel,
                 out_channel_list,
                 branch_list=None,
                 bias=True,
                 drop=None,
                 bn=False,
                 agg_first=True,
                 attn=False,
                 init='kaiming_uniform'
                 ):  # init = 'default', 'kaiming_uniform', 'xavier_uniform'
        """Create one GCN branch per requested adjacency pattern.

        Five adjacency matrices over 17 keypoints are prepared (identity,
        ``adj1``, ``adj2``, ``adj_all`` and ``adj_part``), and a ``GCN``
        sub-module is created for every branch id found in ``branch_list``.

        Args:
            in_channel: Input channel count forwarded to every ``GCN``.
            out_channel_list: Indexable collection of output channel counts;
                entry ``i`` is used for branch id ``i``, so it must be long
                enough to cover every id in ``branch_list``.
            branch_list: Branch ids to build. ``0`` -> ``branch_0``,
                ``1`` -> ``branch_1``, ``2`` -> ``branch_2``,
                ``3`` -> ``branch_all``, ``4`` -> ``branch_part``. Ids outside
                0-4 are ignored. ``None`` (the default) selects
                ``[0, 1, 2, 3]``.
            bias: Forwarded to ``GCN`` as ``bias``.
            drop: Forwarded to ``GCN`` as ``drop``.
            bn: Forwarded to ``GCN`` as ``bn``.
            agg_first: Forwarded to ``GCN`` as ``agg_first``.
            attn: Forwarded to ``GCN`` as ``attn`` for every branch except
                ``branch_0``.
            init: Forwarded to ``GCN`` as ``init``.

        Raises:
            IndexError: If ``out_channel_list`` is a sequence that has no
                entry for a requested branch id.
        """
        super(PGception_Layer, self).__init__()

        # A fresh list per instance: a list literal used as the default value
        # would be a single object shared (and mutable) across all instances.
        if branch_list is None:
            branch_list = [0, 1, 2, 3]

        # prepare adjacency matrices
        # NOTE(review): the index layout looks like the COCO 17-keypoint
        # skeleton -- confirm against the dataset definition.
        keypoint_num = 17
        adj0 = torch.eye(keypoint_num)
        adj1_inds = {
            0: [0, 1, 2, 5, 6],
            1: [0, 1, 3],
            2: [0, 2, 4],
            3: [1, 3],
            4: [2, 4],
            5: [0, 5, 6, 7, 11],
            6: [0, 5, 6, 8, 12],
            7: [5, 7, 9],
            8: [6, 8, 10],
            9: [7, 9],
            10: [8, 10],
            11: [5, 11, 12, 13],
            12: [6, 11, 12, 14],
            13: [11, 13, 15],
            14: [12, 14, 16],
            15: [13, 15],
            16: [14, 16],
        }
        adj1 = adj_construction(inds=adj1_inds,
                                keypoint_num=keypoint_num,
                                symmetric=True)
        adj2_inds = {
            0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12],
            1: [0, 1, 2, 3, 5, 6],
            2: [0, 1, 2, 4, 5, 6],
            3: [0, 1, 3],
            4: [0, 2, 4],
            5: [0, 1, 2, 5, 6, 7, 8, 9, 11, 12, 13],
            6: [0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 14],
            7: [0, 5, 6, 7, 9, 11],
            8: [0, 5, 6, 8, 10, 12],
            9: [5, 7, 9],
            10: [6, 8, 10],
            11: [0, 5, 6, 7, 11, 12, 13, 14, 15],
            12: [0, 5, 6, 8, 11, 12, 13, 14, 16],
            13: [5, 11, 12, 13, 15],
            14: [6, 11, 12, 14, 16],
            15: [11, 13, 15],
            16: [12, 14, 16],
        }
        adj2 = adj_construction(inds=adj2_inds,
                                keypoint_num=keypoint_num,
                                symmetric=True)
        # Called with inds=None only, so keypoint_num and symmetric fall back
        # to adj_construction's own defaults.
        # NOTE(review): presumably this yields a fully connected adjacency
        # over 17 keypoints -- confirm against adj_construction.
        adj_all = adj_construction(inds=None)

        adj_part_inds = {
            0: [0, 1, 2, 3, 4],
            1: [1],
            2: [2],
            3: [3],
            4: [4],
            5: [5, 7, 9],
            6: [6, 8, 10],
            7: [7],
            8: [8],
            9: [9],
            10: [10],
            11: [11, 13, 15],
            12: [12, 14, 16],
            13: [13],
            14: [14],
            15: [15],
            16: [16],
        }
        adj_part = adj_construction(inds=adj_part_inds,
                                    keypoint_num=keypoint_num,
                                    symmetric=True)

        self.branch_list = branch_list

        # (branch id, attribute name, adjacency, branch-specific GCN kwargs).
        # Kept in ascending id order so sub-modules are created and assigned
        # in a fixed order. The identity branch is deliberately not passed
        # `attn`, so GCN's own default applies there.
        branch_specs = (
            (0, 'branch_0', adj0, {}),
            (1, 'branch_1', adj1, {'attn': attn}),
            (2, 'branch_2', adj2, {'attn': attn}),
            (3, 'branch_all', adj_all, {'attn': attn}),
            (4, 'branch_part', adj_part, {'attn': attn, 'part': True}),
        )
        for branch_id, attr_name, adj, extra_kwargs in branch_specs:
            if branch_id not in self.branch_list:
                continue
            setattr(self, attr_name,
                    GCN(adj,
                        in_channel,
                        out_channel_list[branch_id],
                        bias=bias,
                        drop=drop,
                        bn=bn,
                        init=init,
                        agg_first=agg_first,
                        **extra_kwargs))