def __init__(self, in_channel, out_channel_list, branch_list=None, bias=True, drop=None, bn=False,
             agg_first=True, attn=False, init='kaiming_uniform'):
    """Build the parallel GCN branches of the layer, one per adjacency matrix.

    Five adjacency matrices over 17 keypoints are prepared, and a GCN
    sub-module is created for each branch id listed in ``branch_list``:

    * ``0`` -> ``self.branch_0``: identity adjacency (``torch.eye``).
    * ``1`` -> ``self.branch_1``: adjacency built from ``adj1_inds``.
    * ``2`` -> ``self.branch_2``: adjacency built from ``adj2_inds``.
    * ``3`` -> ``self.branch_all``: ``adj_construction(inds=None)``.
    * ``4`` -> ``self.branch_part``: adjacency built from ``adj_part_inds``,
      with ``part=True`` passed to the GCN.

    NOTE(review): the 17-keypoint index maps look like the COCO skeleton
    (1-hop / 2-hop neighbours and per-limb groups) -- confirm against the
    dataset definition.

    Args:
        in_channel: Forwarded to every GCN as its input-channel argument.
        out_channel_list: Indexed by branch id; ``out_channel_list[i]`` is
            passed to the GCN of branch ``i``. It must therefore be long
            enough for the largest enabled branch id.
        branch_list: Branch ids (0-4) to instantiate. ``None`` (the default)
            selects ``[0, 1, 2, 3]``. A list supplied by the caller is
            stored on ``self.branch_list`` as-is.
        bias: Forwarded to every GCN.
        drop: Forwarded to every GCN.
        bn: Forwarded to every GCN.
        agg_first: Forwarded to every GCN.
        attn: Forwarded to every GCN except branch 0, which is built
            without an ``attn`` argument.
        init: Forwarded to every GCN. Values listed by the original author:
            'default', 'kaiming_uniform', 'xavier_uniform'.
    """
    super(PGception_Layer, self).__init__()

    # A fresh list per instance: a literal list default would be shared by
    # every layer constructed without an explicit ``branch_list``.
    if branch_list is None:
        branch_list = [0, 1, 2, 3]

    # Prepare the adjacency matrices.
    keypoint_num = 17
    adj0 = torch.eye(keypoint_num)

    adj1_inds = {
        0: [0, 1, 2, 5, 6],
        1: [0, 1, 3],
        2: [0, 2, 4],
        3: [1, 3],
        4: [2, 4],
        5: [0, 5, 6, 7, 11],
        6: [0, 5, 6, 8, 12],
        7: [5, 7, 9],
        8: [6, 8, 10],
        9: [7, 9],
        10: [8, 10],
        11: [5, 11, 12, 13],
        12: [6, 11, 12, 14],
        13: [11, 13, 15],
        14: [12, 14, 16],
        15: [13, 15],
        16: [14, 16],
    }
    adj1 = adj_construction(inds=adj1_inds, keypoint_num=keypoint_num, symmetric=True)

    adj2_inds = {
        0: [0, 1, 2, 3, 4, 5, 6, 7, 8, 11, 12],
        1: [0, 1, 2, 3, 5, 6],
        2: [0, 1, 2, 4, 5, 6],
        3: [0, 1, 3],
        4: [0, 2, 4],
        5: [0, 1, 2, 5, 6, 7, 8, 9, 11, 12, 13],
        6: [0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 14],
        7: [0, 5, 6, 7, 9, 11],
        8: [0, 5, 6, 8, 10, 12],
        9: [5, 7, 9],
        10: [6, 8, 10],
        11: [0, 5, 6, 7, 11, 12, 13, 14, 15],
        12: [0, 5, 6, 8, 11, 12, 13, 14, 16],
        13: [5, 11, 12, 13, 15],
        14: [6, 11, 12, 14, 16],
        15: [11, 13, 15],
        16: [12, 14, 16],
    }
    adj2 = adj_construction(inds=adj2_inds, keypoint_num=keypoint_num, symmetric=True)

    # NOTE(review): keypoint_num is not passed here, so this relies on the
    # helper's own default -- confirm it matches the 17 used above.
    adj_all = adj_construction(inds=None)

    adj_part_inds = {
        0: [0, 1, 2, 3, 4],
        1: [1],
        2: [2],
        3: [3],
        4: [4],
        5: [5, 7, 9],
        6: [6, 8, 10],
        7: [7],
        8: [8],
        9: [9],
        10: [10],
        11: [11, 13, 15],
        12: [12, 14, 16],
        13: [13],
        14: [14],
        15: [15],
        16: [16],
    }
    adj_part = adj_construction(inds=adj_part_inds, keypoint_num=keypoint_num, symmetric=True)

    self.branch_list = branch_list

    # Keyword arguments shared by every branch.
    gcn_kwargs = dict(bias=bias, drop=drop, bn=bn, init=init, agg_first=agg_first)

    if 0 in self.branch_list:
        # Branch 0 is the only one constructed without ``attn``.
        self.branch_0 = GCN(adj0, in_channel, out_channel_list[0], **gcn_kwargs)
    if 1 in self.branch_list:
        self.branch_1 = GCN(adj1, in_channel, out_channel_list[1], attn=attn, **gcn_kwargs)
    if 2 in self.branch_list:
        self.branch_2 = GCN(adj2, in_channel, out_channel_list[2], attn=attn, **gcn_kwargs)
    if 3 in self.branch_list:
        self.branch_all = GCN(adj_all, in_channel, out_channel_list[3], attn=attn, **gcn_kwargs)
    if 4 in self.branch_list:
        self.branch_part = GCN(adj_part, in_channel, out_channel_list[4], attn=attn, part=True,
                               **gcn_kwargs)