Example #1
0
 def load_data(self):
     """Create the DataLoaders configured in ``self.arg``.

     The dataset class comes from the dotted path ``self.arg.feeder``.
     In the 'train' phase a shuffled 'train' loader (incomplete last batch
     dropped) is built; a non-shuffled 'test' loader is built whenever
     ``self.arg.test_feeder_args`` is truthy. Loaders are stored in
     ``self.data_loader`` keyed by split name.
     """
     feeder_cls = import_class(self.arg.feeder)
     # Inherit the global debug flag unless the train feeder config sets one.
     self.arg.train_feeder_args.setdefault('debug', self.arg.debug)
     self.data_loader = {}
     loaders = self.data_loader

     if self.arg.phase == 'train':
         loaders['train'] = torch.utils.data.DataLoader(
             dataset=feeder_cls(**self.arg.train_feeder_args),
             batch_size=self.arg.batch_size,
             shuffle=True,
             num_workers=self.arg.num_worker * torchlight.ngpu(self.arg.device),
             drop_last=True)

     if self.arg.test_feeder_args:
         loaders['test'] = torch.utils.data.DataLoader(
             dataset=feeder_cls(**self.arg.test_feeder_args),
             batch_size=self.arg.test_batch_size,
             shuffle=False,
             num_workers=self.arg.num_worker * torchlight.ngpu(self.arg.device))
Example #2
0
import argparse
import sys

# torchlight
# import torchlight.torchlight as torchlight
# from torchlight.torchlight import import_class
import torchlight
from torchlight import import_class

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Processor collection')

    # region register processor yapf: disable
    processors = dict()
    processors['recognition'] = import_class('processor.recognition.REC_Processor')
    processors['densenet'] = import_class('processor.skeleton.SK_Processor')
    processors['demo'] = import_class('processor.demo.Demo')
    #endregion yapf: enable

    # add sub-parser: one sub-command per processor, inheriting its options
    subparsers = parser.add_subparsers(dest='processor')
    for k, p in processors.items():
        subparsers.add_parser(k, parents=[p.get_parser()])

    # read arguments
    arg = parser.parse_args()
    # argparse sub-commands are optional by default in Python 3; report a
    # usage error instead of failing with KeyError: None below.
    if arg.processor is None:
        parser.error('a processor sub-command is required: ' +
                     ', '.join(processors))

    # start: the processor receives every argument after the sub-command name
    Processor = processors[arg.processor]
    p = Processor(sys.argv[2:])
    p.start()
Example #3
0
#!/usr/bin/env python
import argparse
import sys

# torchlight
import torchlight
from torchlight import import_class

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Processor collection')

    # region register processor yapf: disable
    processors = dict()
    processors['preprocessing'] = import_class('processor.sl.video_preprocessor.Video_Preprocessor')
    #endregion yapf: enable

    # add sub-parser: one sub-command per processor, inheriting its options
    subparsers = parser.add_subparsers(dest='processor')
    for k, p in processors.items():
        subparsers.add_parser(k, parents=[p.get_parser()])

    # read arguments
    arg = parser.parse_args()
    # argparse sub-commands are optional by default in Python 3; report a
    # usage error instead of failing with KeyError: None below.
    if arg.processor is None:
        parser.error('a processor sub-command is required: ' +
                     ', '.join(processors))

    # start: the processor receives every argument after the sub-command name
    Processor = processors[arg.processor]
    p = Processor(sys.argv[2:])

    p.start()
Example #4
0
#!/usr/bin/env python
import argparse
import sys

# torchlight
import torchlight
from torchlight import import_class

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Processor collection')

    # region register processor yapf: disable
    processors = dict()
    processors['generation_gcn'] = import_class('processor.generation.GEN_gcn_base_Processor')
    processors['generation_attention'] = import_class('processor.generation.GEN_gcn_attention_Processor')
    # endregion yapf: enable

    # TODO: next big step --> add the args of different similarity

    # add sub-parser: one sub-command per processor, inheriting its options
    subparsers = parser.add_subparsers(dest='processor')
    for k, p in processors.items():
        subparsers.add_parser(k, parents=[p.get_parser()])

    # read arguments
    arg = parser.parse_args()
    # argparse sub-commands are optional by default in Python 3; report a
    # usage error instead of failing with KeyError: None below.
    if arg.processor is None:
        parser.error('a processor sub-command is required: ' +
                     ', '.join(processors))

    # start: the processor receives every argument after the sub-command name
    Processor = processors[arg.processor]
    p = Processor(sys.argv[2:])
    p.start()
Example #5
0
 def load_model(self):
     """Instantiate the model class named by ``self.arg.model``.

     The class receives ``self.classes``, ``pretrained=True`` and every
     attribute of ``self.arg`` as keyword arguments; the instance is stored
     on ``self.model`` and its ``create_architecture()`` is then called.
     """
     model_cls = import_class(self.arg.model)
     model = model_cls(self.classes, pretrained=True, **vars(self.arg))
     self.model = model
     model.create_architecture()
Example #6
0
#!/usr/bin/env python
import argparse
import sys

# torchlight
import torchlight
from torchlight import import_class

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Processor collection')

    # region register processor yapf: disable
    processors = dict()
    processors['recognition'] = import_class('processor.recognition.REC_Processor')
    processors['volleyball_demo'] = import_class('processor.volleyball_demo.Volleyball_Demo')
    #endregion yapf: enable

    # add sub-parser: one sub-command per processor, inheriting its options
    subparsers = parser.add_subparsers(dest='processor')
    for k, p in processors.items():
        subparsers.add_parser(k, parents=[p.get_parser()])

    # read arguments
    arg = parser.parse_args()
    # argparse sub-commands are optional by default in Python 3; report a
    # usage error instead of failing with KeyError: None below.
    if arg.processor is None:
        parser.error('a processor sub-command is required: ' +
                     ', '.join(processors))

    # start: the processor receives every argument after the sub-command name
    Processor = processors[arg.processor]
    p = Processor(sys.argv[2:])

    p.start()
Example #7
0
# torchlight
import torchlight # lightweight helper library used alongside torch
from torchlight import import_class

# Example usage: python main.py recognition -c config/st_gcn/kinetics-skeleton/test.yaml
if __name__ == '__main__':
    # os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    # Create the top-level command-line parser (an ArgumentParser instance).
    parser = argparse.ArgumentParser(description='Processor collection')

    # region register processor yapf: disable
    processors = dict() # registry: sub-command name -> processor class
    # Each entry maps a sub-command name (the key) to the processor class
    # (the value) that handles it.
    processors['recognition'] = import_class('processor.recognition.REC_Processor') # import_class loads a class from its dotted import path
    processors['demo'] = import_class('processor.demo.Demo')
    #endregion yapf: enable

    # add sub-parser
    # arg0 = parser.parse_args() # parse the command-line arguments
    # Create the sub-command action (an argparse _SubParsersAction). argparse
    # allows only one add_subparsers() call per parser; each sub-command's own
    # parser is then created through subparsers.add_parser().

    # Unlike dest= on an ordinary argument, dest= here names the attribute of
    # the parsed args that receives the chosen sub-command name
    # (e.g. 'recognition' or 'demo').
    subparsers = parser.add_subparsers(dest='processor' ) # without dest, the parsed args would not record which sub-command was chosen
    print(subparsers) # debug output of the sub-command action object

    # NOTE(review): the loop body and the rest of this script are missing from
    # this excerpt.
    for k, p in processors.items(): # .items() yields (key, value) pairs: k is the name, p the processor class
 def load_model(self):
     """Instantiate the detector class named by ``self.arg.model``.

     Besides ``self.classes`` and ``pretrained=True``, the class receives
     ``class_agnostic`` from ``self.arg.mix_args`` and every attribute of
     ``self.arg`` as keyword arguments. The instance is stored on
     ``self.model`` and ``create_architecture()`` is then called on it.
     """
     detector_cls = import_class(self.arg.model)
     self.model = detector_cls(
         self.classes,
         pretrained=True,
         class_agnostic=self.arg.mix_args['class_agnostic'],
         **vars(self.arg))
     self.model.create_architecture()
Example #9
0
#!/usr/bin/env python
import argparse
import sys

# torchlight
import torchlight
from torchlight import import_class

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Processor collection')

    # region register processor yapf: disable
    processors = dict()
    processors['processor_siamese_gcn'] = import_class('utils.processor_siamese_gcn.SGCN_Processor')
    processors['processor_siamese_gcn_triplet'] = import_class('utils.processor_siamese_gcn_triplet.SGCN_Processor')
    processors['processor_siamese_naive'] = import_class('utils.processor_siamese_naive.Naive_Processor')
    #endregion yapf: enable

    # add sub-parser: one sub-command per processor, inheriting its options
    subparsers = parser.add_subparsers(dest='processor')
    for k, p in processors.items():
        subparsers.add_parser(k, parents=[p.get_parser()])

    # read arguments
    arg = parser.parse_args()
    # argparse sub-commands are optional by default in Python 3; report a
    # usage error instead of failing with KeyError: None below.
    if arg.processor is None:
        parser.error('a processor sub-command is required: ' +
                     ', '.join(processors))

    # start: the processor receives every argument after the sub-command name
    Processor = processors[arg.processor]
    p = Processor(sys.argv[2:])
    p.start()
Example #10
0
                srczip.write(
                    os.path.join(root, filename).replace(code_root, '.'))
    srczip.close()
    save_path = os.path.join(
        target_path,
        'src_%s.zip' % time.strftime("%Y-%m-%d_%H_%M_%S", time.localtime()))
    shutil.copy('./src.zip', save_path)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Processor collection')

    # region register processor yapf: disable
    processors = dict()
    processors['linear_evaluation'] = import_class('processor.linear_evaluation.LE_Processor')
    processors['pretrain_crossclr_3views'] = import_class('processor.pretrain_crossclr_3views.CrosSCLR_3views_Processor')
    processors['pretrain_crossclr'] = import_class('processor.pretrain_crossclr.CrosSCLR_Processor')
    processors['pretrain_skeletonclr'] = import_class('processor.pretrain_skeletonclr.SkeletonCLR_Processor')
    # endregion yapf: enable

    # add sub-parser: one sub-command per processor, inheriting its options
    subparsers = parser.add_subparsers(dest='processor')
    for k, p in processors.items():
        subparsers.add_parser(k, parents=[p.get_parser()])

    # read arguments
    arg = parser.parse_args()
    # argparse sub-commands are optional by default in Python 3; report a
    # usage error instead of failing with KeyError: None below.
    if arg.processor is None:
        parser.error('a processor sub-command is required: ' +
                     ', '.join(processors))

    # start: the processor receives every argument after the sub-command name
    Processor = processors[arg.processor]
    p = Processor(sys.argv[2:])
    p.start()
Example #11
0
#!/usr/bin/env python
import argparse
import sys

# torchlight
import torchlight
from torchlight import import_class

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Processor collection')

    # region register processor yapf: disable
    processors = dict()
    processors['recognition'] = import_class('processor.recognition.REC_Processor')
    processors['demo'] = import_class('processor.demo.Demo')
    #endregion yapf: enable

    # add sub-parser: one sub-command per processor, inheriting its options
    subparsers = parser.add_subparsers(dest='processor')
    for k, p in processors.items():
        subparsers.add_parser(k, parents=[p.get_parser()])

    # read arguments
    arg = parser.parse_args()
    # argparse sub-commands are optional by default in Python 3; report a
    # usage error instead of failing with KeyError: None below.
    if arg.processor is None:
        parser.error('a processor sub-command is required: ' +
                     ', '.join(processors))

    # start: the processor receives every argument after the sub-command name
    Processor = processors[arg.processor]
    p = Processor(sys.argv[2:])

    p.start()
Example #12
0
#!/usr/bin/env python
import argparse
import sys

# torchlight
import torchlight
from torchlight import import_class

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Processor collection')

    # region register processor yapf: disable
    processors = dict()
    processors['pxa573rec'] = import_class('processor.pxa573_rec.REC_Processor')
    processors['myrec'] = import_class('processor.myrec.REC_Processor')
    #endregion yapf: enable

    # add sub-parser: one sub-command per processor, inheriting its options
    subparsers = parser.add_subparsers(dest='processor')
    for k, p in processors.items():
        subparsers.add_parser(k, parents=[p.get_parser()])

    # read arguments
    arg = parser.parse_args()
    # argparse sub-commands are optional by default in Python 3; report a
    # usage error instead of failing with KeyError: None below.
    if arg.processor is None:
        parser.error('a processor sub-command is required: ' +
                     ', '.join(processors))

    # start: the processor receives every argument after the sub-command name
    Processor = processors[arg.processor]
    p = Processor(sys.argv[2:])

    p.start()
Example #13
0
#!/usr/bin/env python
import argparse
import sys
# torchlight
import torchlight
from torchlight import import_class
import os
import torch

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Processor collection')
    # region register processor yapf: disable
    processors = dict()
    processors['det_frcnn'] = import_class('processor.detection.frcnn.frcnn_Processor')
    processors['det_yolov2'] = import_class('processor.detection.yolov2.yolov2_Processor')
    processors['det_yolov2_v2'] = import_class('processor.detection.yolov2_v2.yolov2_Processor_v2')
    # processors['sg_imp'] = import_class('processor.sg.imp.imp_Porcessor')
    # processors['demo'] = import_class('processor.demo.Demo')
    # endregion yapf: enable

    # add sub-parser: one sub-command per processor, inheriting its options
    subparsers = parser.add_subparsers(dest='processor')
    for k, p in processors.items():
        subparsers.add_parser(k, parents=[p.get_parser()])

    # read arguments
    arg = parser.parse_args()
    # argparse sub-commands are optional by default in Python 3; report a
    # usage error instead of failing with KeyError: None below.
    if arg.processor is None:
        parser.error('a processor sub-command is required: ' +
                     ', '.join(processors))

    # start: the processor receives every argument after the sub-command name
    Processor = processors[arg.processor]
    p = Processor(sys.argv[2:])
    p.start()
Example #14
0
def test(data_path, label_path, vid=None, graph=None, is_3d=False):
    '''
    Visualize one skeleton sample frame by frame with matplotlib.

    Nothing is drawn unless ``vid`` is given; otherwise the function only
    builds the loader and prints one sample name.

    :param data_path: data file, passed to ``Feeder`` as its first argument
    :param label_path: label file, passed to ``Feeder`` as its second argument
    :param vid: id of the sample to show; compared with the part of each
        ``sample_name`` before the first '.'. ``list.index`` raises
        ValueError when no sample matches.
    :param graph: dotted import path of a graph class whose instances have an
        ``inward`` sequence of (v1, v2) joint-index pairs; when None, joints
        are drawn as unconnected points
    :param is_3d: draw on a 3D axis and use channel 2 as z; set True when
        visualizing NTU
    :return: None
    '''
    import matplotlib.pyplot as plt
    loader = torch.utils.data.DataLoader(dataset=Feeder(data_path, label_path),
                                         batch_size=64,
                                         shuffle=False,
                                         num_workers=2)
    # Debug output; raises IndexError if the dataset has fewer than 2 samples.
    print(loader.dataset.sample_name[1])
    if vid is not None:
        sample_name = loader.dataset.sample_name
        # Sample id = file name up to its first '.'.
        sample_id = [name.split('.')[0] for name in sample_name]
        print(sample_id[0])
        index = sample_id.index(vid)
        # Dataset items unpack as (data, label, index); `index` is rebound here.
        data, label, index = loader.dataset[index]
        # Prepend a batch axis of size 1.
        data = data.reshape((1, ) + data.shape)

        # for batch_idx, (data, label) in enumerate(loader):
        # Axes are unpacked as N, C, T, V, M. Below, T indexes frames, V joints,
        # M bodies, and channels 0/1/2 are plotted as x/y/z.
        N, C, T, V, M = data.shape

        plt.ion()  # interactive mode so each frame can be redrawn in place
        fig = plt.figure()
        if is_3d:
            # Imported for its side effect of registering the '3d' projection
            # (needed on older matplotlib versions).
            from mpl_toolkits.mplot3d import Axes3D
            ax = fig.add_subplot(111, projection='3d')
        else:
            ax = fig.add_subplot(111)

        if graph is None:
            # One marker style per body; indexing limits this to M <= 10.
            p_type = [
                'b.', 'g.', 'r.', 'c.', 'm.', 'y.', 'k.', 'k.', 'k.', 'k.'
            ]
            # One point-cloud artist per body, updated in place every frame.
            pose = [
                ax.plot(np.zeros(V), np.zeros(V), p_type[m])[0]
                for m in range(M)
            ]
            # Fixed [-1, 1] view; presumably coordinates are normalized — confirm.
            ax.axis([-1, 1, -1, 1])
            for t in range(T):
                for m in range(M):
                    pose[m].set_xdata(data[0, 0, t, :, m])
                    pose[m].set_ydata(data[0, 1, t, :, m])
                fig.canvas.draw()
                plt.pause(0.001)
        else:
            # One line style per body; indexing limits this to M <= 10.
            p_type = [
                'b-', 'g-', 'r-', 'c-', 'm-', 'y-', 'k-', 'k-', 'k-', 'k-'
            ]
            import sys
            from os import path
            # Make the directory three levels above this file importable so
            # that import_class(graph) can resolve the graph module.
            sys.path.append(
                path.dirname(path.dirname(path.dirname(
                    path.abspath(__file__)))))
            G = import_class(graph)()
            edge = G.inward
            # pose[m][i] is the line artist for edge i of body m.
            pose = []
            for m in range(M):
                a = []
                for i in range(len(edge)):
                    if is_3d:
                        a.append(
                            ax.plot(np.zeros(3), np.zeros(3), p_type[m])[0])
                    else:
                        a.append(
                            ax.plot(np.zeros(2), np.zeros(2), p_type[m])[0])
                pose.append(a)
            # Fixed [-1, 1] view; presumably coordinates are normalized — confirm.
            ax.axis([-1, 1, -1, 1])
            if is_3d:
                ax.set_zlim3d(-1, 1)
            for t in range(T):
                for m in range(M):
                    for i, (v1, v2) in enumerate(edge):
                        x1 = data[0, :2, t, v1, m]
                        x2 = data[0, :2, t, v2, m]
                        # Skip edges with an endpoint whose x+y sum is 0
                        # (presumably a missing joint). Edges touching joint 1
                        # are always updated; reason not visible here.
                        if (x1.sum() != 0
                                and x2.sum() != 0) or v1 == 1 or v2 == 1:
                            pose[m][i].set_xdata(data[0, 0, t, [v1, v2], m])
                            pose[m][i].set_ydata(data[0, 1, t, [v1, v2], m])
                            if is_3d:
                                pose[m][i].set_3d_properties(data[0, 2, t,
                                                                  [v1, v2], m])
                fig.canvas.draw()
                # plt.savefig('/home/lshi/Desktop/skeleton_sequence/' + str(t) + '.jpg')
                plt.pause(0.01)
Example #15
0
#!/usr/bin/env python
import argparse
import sys

# torchlight
import torchlight
from torchlight import import_class

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Processor collection')

    # region register processor yapf: disable
    processors = dict()
    processors['recognition'] = import_class('processor.recognition.REC_Processor')
    processors['demo_old'] = import_class('processor.demo_old.Demo')
    processors['demo'] = import_class('processor.demo_realtime.DemoRealtime')
    processors['demo_offline'] = import_class('processor.demo_offline.DemoOffline')
    #endregion yapf: enable

    # add sub-parser: one sub-command per processor, inheriting its options
    subparsers = parser.add_subparsers(dest='processor')
    for k, p in processors.items():
        subparsers.add_parser(k, parents=[p.get_parser()])

    # read arguments
    arg = parser.parse_args()
    # argparse sub-commands are optional by default in Python 3; report a
    # usage error instead of failing with KeyError: None below.
    if arg.processor is None:
        parser.error('a processor sub-command is required: ' +
                     ', '.join(processors))

    # start: the processor receives every argument after the sub-command name
    Processor = processors[arg.processor]
    p = Processor(sys.argv[2:])
    p.start()
#!/usr/bin/env python
import argparse
import sys

import torchlight
from torchlight import import_class

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Processor collection')

    processors = dict()
    processors['recognition'] = import_class(
        'processor.recognition.REC_Processor')
    processors['demo'] = import_class('processor.demo.Demo')

    # One sub-command per processor, inheriting that processor's options.
    subparsers = parser.add_subparsers(dest='processor')
    for k, p in processors.items():
        subparsers.add_parser(k, parents=[p.get_parser()])

    arg = parser.parse_args()
    # argparse sub-commands are optional by default in Python 3; report a
    # usage error instead of failing with KeyError: None below.
    if arg.processor is None:
        parser.error('a processor sub-command is required: ' +
                     ', '.join(processors))

    # The processor receives every argument after the sub-command name.
    Processor = processors[arg.processor]
    p = Processor(sys.argv[2:])

    p.start()
#!/usr/bin/env python
import argparse
import sys

# torchlight
from torchlight import import_class

if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Processor collection')

    # region register processor yapf: disable
    processors = dict()
    processors['processor_siamese_gcn'] = import_class('gcn_utils.processor_siamese_gcn.SGCN_Processor')
    # processors['processor_siamese_gcn_triplet'] = import_class('utils.processor_siamese_gcn_triplet.SGCN_Processor')
    # processors['processor_siamese_naive'] = import_class('utils.processor_siamese_naive.Naive_Processor')
    # endregion yapf: enable

    # add sub-parser: one sub-command per processor, inheriting its options
    subparsers = parser.add_subparsers(dest='processor')
    for k, p in processors.items():
        subparsers.add_parser(k, parents=[p.get_parser()])

    # read arguments
    arg = parser.parse_args()
    # argparse sub-commands are optional by default in Python 3; report a
    # usage error instead of failing with KeyError: None below.
    if arg.processor is None:
        parser.error('a processor sub-command is required: ' +
                     ', '.join(processors))

    # start: the processor receives every argument after the sub-command name
    Processor = processors[arg.processor]
    p = Processor(sys.argv[2:])

    p.start()
Example #18
0
import argparse
import sys
import torchlight
from torchlight import import_class

if __name__ == '__main__':
    # Permutation-test driver: runs the 'recognition' processor once per
    # generated config file under config_path.

    parser = argparse.ArgumentParser(description='Processor collection')

    processors = dict()
    processors['recognition'] = import_class(
        'processor.recognition.REC_Processor')
    # processors['demo'] = import_class('processor.demo.Demo')

    # One sub-command per processor, inheriting that processor's options.
    subparsers = parser.add_subparsers(dest='processor')
    for k, p in processors.items():
        subparsers.add_parser(k, parents=[p.get_parser()])

    arg = parser.parse_args()

    config_path = 'config/as_gcn/ntu-xsub/permutation_test/'
    for method in ['permutate_by_frame', 'permutate_by_clip']:
        if method == 'permutate_by_frame':
            for num_frames_per_clip in [10, 20, 30, 40, 50]:
                print(method, num_frames_per_clip)
                # Overwrite argv[3] with this setting's config file; the
                # processor below is then given sys.argv[2:].
                # NOTE(review): assumes an invocation like
                # `<script> recognition -c <config>` so argv[3] is the config
                # path (IndexError with fewer arguments) — confirm.
                sys.argv[3] = config_path + method + '_' + str(
                    num_frames_per_clip) + '.yaml'
                Processor = processors[arg.processor]
                p = Processor(sys.argv[2:])
                p.start()
        elif method == 'permutate_by_clip':
            # NOTE(review): the body of this branch is cut off in this excerpt.
Example #19
0
    def load_data(self):
        """Build the detection DataLoaders for the current phase.

        Phase 'train': builds ``self.data_loader['train']`` from
        ``self.arg.train_dataset`` and, when ``self.arg.val_dataset`` is set,
        ``self.data_loader['val']`` as well. Any other phase: builds
        ``self.data_loader['test']`` from ``self.arg.test_dataset``.
        Also records ``self.classes``, the imdb/roidb objects and the dataset
        sizes on ``self``.

        Raises:
            AssertionError: if the dataset name required by the phase is None.
        """
        if 'debug' not in self.arg.train_feeder_args:
            self.arg.train_feeder_args['debug'] = self.arg.debug
        self.need_val = False
        # False for models whose name contains 'yolo'; in that case the first
        # class entry is dropped below (presumably the background class).
        self.need_bg = 'yolo' not in self.arg.model
        imdb_name = self.arg.train_dataset
        imdbval_name = self.arg.val_dataset
        imdbtest_name = self.arg.test_dataset
        self.train_size = 0
        self.val_size = 0
        self.test_size = 0
        self.imdb, self.imdb_val, self.imdb_test = None, None, None

        Feeder = import_class(self.arg.feeder)
        # NOTE(review): the sampler class is resolved but no loader below uses
        # it; kept so an invalid ``self.arg.sampler`` still fails here.
        sampler = import_class(self.arg.sampler) if self.arg.sampler else None
        self.data_loader = dict()
        self.classes = None

        if self.arg.phase == 'train':
            # The message must be a string: an expression like print(...)
            # would print to stdout and leave the AssertionError message None.
            assert imdb_name is not None, 'Training data is not provided.'
            imdb, roidb, _, _ = combined_roidb(imdb_name, **vars(self.arg))
            self.classes = imdb.classes if self.need_bg else imdb.classes[1:]
            num_classes = imdb.num_classes if self.need_bg else imdb.num_classes - 1
            self.imdb = imdb
            self.roidb = roidb
            self.train_size = len(roidb)
            self.data_loader['train'] = torch.utils.data.DataLoader(
                dataset=Feeder(roidb,
                               num_classes,
                               need_bg=self.need_bg,
                               **vars(self.arg)),
                batch_size=self.arg.train_args['ims_per_batch'],
                num_workers=self.arg.num_worker *
                torchlight.ngpu(self.arg.device),
                drop_last=True)

            if imdbval_name:
                imdb_val, roidb_val, _, _ = combined_roidb(
                    imdbval_name, training=False, **vars(self.arg))
                self.need_val = True
                self.val_size = len(roidb_val)
                self.imdb_val = imdb_val
                self.roidb_val = roidb_val
                self.data_loader['val'] = torch.utils.data.DataLoader(
                    dataset=Feeder(roidb_val,
                                   num_classes,
                                   need_bg=self.need_bg,
                                   training=False,
                                   **vars(self.arg)),
                    batch_size=self.arg.test_args['ims_per_batch'],
                    shuffle=False,
                    num_workers=self.arg.num_worker *
                    torchlight.ngpu(self.arg.device))
        else:
            assert imdbtest_name is not None, 'Test data is not provided.'
            imdb_test, roidb_test, _, _ = combined_roidb(
                imdbtest_name, training=False, **vars(self.arg))
            self.classes = imdb_test.classes if self.need_bg else imdb_test.classes[
                1:]
            num_classes = imdb_test.num_classes if self.need_bg else imdb_test.num_classes - 1
            self.test_size = len(roidb_test)
            self.imdb_test = imdb_test
            self.roidb_test = roidb_test
            self.data_loader['test'] = torch.utils.data.DataLoader(
                dataset=Feeder(roidb_test,
                               num_classes,
                               need_bg=self.need_bg,
                               training=False,
                               **vars(self.arg)),
                batch_size=self.arg.test_args['ims_per_batch'],
                shuffle=False,
                num_workers=0,
                pin_memory=True)
Example #20
0
    def __init__(self,
                 base_encoder=None,
                 pretrain=True,
                 feature_dim=128,
                 queue_size=32768,
                 momentum=0.999,
                 Temperature=0.07,
                 mlp=True,
                 in_channels=3,
                 hidden_channels=64,
                 hidden_dim=256,
                 num_class=60,
                 dropout=0.5,
                 graph_args=None,
                 edge_importance_weighting=True,
                 **kwargs):
        """Build the query encoder and, when pretraining, the key encoder and queue.

        K: queue size; number of negative keys (default: 32768)
        m: momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)

        Args:
            base_encoder: dotted import path of the encoder class. Every
                encoder is built from it with the remaining arguments and
                ``**kwargs``.
            pretrain: if False, only ``encoder_q`` is built, with a
                ``num_class``-sized output. If True, ``encoder_q`` and
                ``encoder_k`` are built with ``feature_dim``-sized outputs.
                The key encoder is initialised from the query encoder and
                frozen, and a column-normalised random ``queue`` buffer of
                shape (feature_dim, queue_size) is registered.
            mlp: when pretraining, wrap each encoder's ``fc`` layer as
                Linear -> ReLU -> original fc.
            graph_args: forwarded to the encoder. None (the default) means
                ``{'layout': 'ntu-rgb+d', 'strategy': 'spatial'}``. A fresh
                dict is built per call instead of sharing a mutable default.
        """
        super().__init__()
        if graph_args is None:
            graph_args = {'layout': 'ntu-rgb+d', 'strategy': 'spatial'}
        base_encoder = import_class(base_encoder)
        self.pretrain = pretrain

        def build_encoder(out_dim):
            # All encoders share one configuration except their output size.
            return base_encoder(
                in_channels=in_channels,
                hidden_channels=hidden_channels,
                hidden_dim=hidden_dim,
                num_class=out_dim,
                dropout=dropout,
                graph_args=graph_args,
                edge_importance_weighting=edge_importance_weighting,
                **kwargs)

        if not self.pretrain:
            self.encoder_q = build_encoder(num_class)
        else:
            self.K = queue_size
            self.m = momentum
            self.T = Temperature

            self.encoder_q = build_encoder(feature_dim)
            self.encoder_k = build_encoder(feature_dim)

            if mlp:  # hack: brute-force replacement
                dim_mlp = self.encoder_q.fc.weight.shape[1]
                self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp),
                                                  nn.ReLU(), self.encoder_q.fc)
                self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp),
                                                  nn.ReLU(), self.encoder_k.fc)

            for param_q, param_k in zip(self.encoder_q.parameters(),
                                        self.encoder_k.parameters()):
                param_k.data.copy_(param_q.data)  # initialize
                param_k.requires_grad = False  # not update by gradient

            # create the queue
            self.register_buffer("queue", torch.randn(feature_dim, queue_size))
            self.queue = F.normalize(self.queue, dim=0)
            self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
Example #21
0
    def __init__(self, base_encoder=None, pretrain=True, feature_dim=128, queue_size=32768,
                 momentum=0.999, Temperature=0.07, mlp=True, in_channels=3, hidden_channels=64,
                 hidden_dim=256, num_class=60, dropout=0.5,
                 graph_args=None,
                 edge_importance_weighting=True, **kwargs):
        """Build joint, motion and bone encoders (query, plus key/queues when pretraining).

        K: queue size; number of negative keys (default: 32768)
        m: momentum of updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)

        Args:
            base_encoder: dotted import path of the encoder class. All
                encoders are built from it with the remaining arguments and
                ``**kwargs``.
            pretrain: if False, only the query encoders (``encoder_q``,
                ``encoder_q_motion``, ``encoder_q_bone``) are built, with
                ``num_class``-sized outputs. If True, query and key encoders
                are built for each of the three streams with
                ``feature_dim``-sized outputs. Each key encoder is initialised
                from its query encoder and frozen, and each stream gets a
                column-normalised random queue of shape (feature_dim, queue_size).
            mlp: when pretraining, wrap each encoder's own ``fc`` layer as
                Linear -> ReLU -> original fc.
            graph_args: forwarded to the encoders. None (the default) means
                ``{'layout': 'ntu-rgb+d', 'strategy': 'spatial'}``. A fresh
                dict is built per call instead of sharing a mutable default.
        """
        super().__init__()
        if graph_args is None:
            graph_args = {'layout': 'ntu-rgb+d', 'strategy': 'spatial'}
        base_encoder = import_class(base_encoder)
        self.pretrain = pretrain
        # 25 joint pairs with 1-based joint indices (1..25); presumably the
        # NTU RGB+D skeleton bones used to derive bone features — confirm.
        self.Bone = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), (7, 6), (8, 7), (9, 21),
                     (10, 9), (11, 10), (12, 11), (13, 1), (14, 13), (15, 14), (16, 15), (17, 1),
                     (18, 17), (19, 18), (20, 19), (21, 21), (22, 23), (23, 8), (24, 25), (25, 12)]

        def build_encoder(out_dim):
            # All encoders share one configuration except their output size.
            return base_encoder(in_channels=in_channels, hidden_channels=hidden_channels,
                                hidden_dim=hidden_dim, num_class=out_dim,
                                dropout=dropout, graph_args=graph_args,
                                edge_importance_weighting=edge_importance_weighting,
                                **kwargs)

        if not self.pretrain:
            self.encoder_q = build_encoder(num_class)
            self.encoder_q_motion = build_encoder(num_class)
            self.encoder_q_bone = build_encoder(num_class)
        else:
            self.K = queue_size
            self.m = momentum
            self.T = Temperature

            self.encoder_q = build_encoder(feature_dim)
            self.encoder_k = build_encoder(feature_dim)
            self.encoder_q_motion = build_encoder(feature_dim)
            self.encoder_k_motion = build_encoder(feature_dim)
            self.encoder_q_bone = build_encoder(feature_dim)
            self.encoder_k_bone = build_encoder(feature_dim)

            # (query, key) encoder pairs for the joint, motion and bone streams.
            stream_pairs = ((self.encoder_q, self.encoder_k),
                            (self.encoder_q_motion, self.encoder_k_motion),
                            (self.encoder_q_bone, self.encoder_k_bone))

            if mlp:  # hack: brute-force replacement
                dim_mlp = self.encoder_q.fc.weight.shape[1]
                # Each encoder must wrap its OWN fc layer. Wrapping another
                # stream's (already wrapped) head would share parameters
                # across streams and discard this encoder's fc.
                for enc_q, enc_k in stream_pairs:
                    for encoder in (enc_q, enc_k):
                        encoder.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp),
                                                   nn.ReLU(),
                                                   encoder.fc)

            for enc_q, enc_k in stream_pairs:
                for param_q, param_k in zip(enc_q.parameters(), enc_k.parameters()):
                    param_k.data.copy_(param_q.data)    # initialize
                    param_k.requires_grad = False       # not update by gradient

            # create the queue
            self.register_buffer("queue", torch.randn(feature_dim, self.K))
            self.queue = F.normalize(self.queue, dim=0)
            self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

            self.register_buffer("queue_motion", torch.randn(feature_dim, self.K))
            self.queue_motion = F.normalize(self.queue_motion, dim=0)
            self.register_buffer("queue_ptr_motion", torch.zeros(1, dtype=torch.long))

            self.register_buffer("queue_bone", torch.randn(feature_dim, self.K))
            self.queue_bone = F.normalize(self.queue_bone, dim=0)
            self.register_buffer("queue_ptr_bone", torch.zeros(1, dtype=torch.long))