def __init__(self, config, subset_name, data_path=None, extra_data_dir=None):
        """Load one subset of motion clips and build its normalized features.

        NOTE(review): this class body defines ``__init__`` twice. The later
        definition (the one that adds a ``panda`` argument) rebinds the name,
        so this version is never executed -- it is dead code kept only for
        reference. Consider deleting it.

        Args:
            config: settings object; this method reads ``data_path``,
                ``extra_data_dir``, ``dataset_norm_config`` and ``device``.
            subset_name: key into the archive returned by ``np.load``; the
                entry is unwrapped with ``.item()`` and must provide
                ``"motion"``, ``"style"`` and ``"meta"`` entries.
            data_path: overrides ``config.data_path`` when given.
            extra_data_dir: overrides ``config.extra_data_dir`` when given;
                forwarded to every ``NormData``.
        """
        super(MotionNorm, self).__init__()

        # Seeds NumPy's global RNG. Presumably this makes the
        # random_view_angle / random_scale draws below reproducible --
        # TODO confirm those helpers use np.random.
        np.random.seed(2020)
        self.skel = Skel()  # TD: add config

        if data_path is None:
            data_path = config.data_path
        dataset = np.load(data_path, allow_pickle=True)[subset_name].item()
        motions, labels, metas = dataset["motion"], dataset["style"], dataset["meta"]

        self.label_i = labels
        self.len = len(self.label_i)
        # One dict per clip, holding that clip's entry from every meta field.
        self.metas = [{key: metas[key][i] for key in metas.keys()} for i in range(self.len)]
        self.motion_i, self.foot_i = [], []
        content, style3d, style2d = [], [], []

        # Distinct style labels in first-seen order.
        self.labels = []
        # Style label -> list of clip indices carrying that label.
        self.data_dict = {}
        # Style label -> every other style label.
        self.diff_labels_dict = {}

        for i, motion in enumerate(motions):
            label = labels[i]
            anim = AnimationData(motion, skel=self.skel)
            if label not in self.labels:
                self.labels.append(label)
                self.data_dict[label] = []
            self.data_dict[label].append(i)
            self.motion_i.append(anim)
            self.foot_i.append(anim.get_foot_contact(transpose=True))  # [4, T]
            content.append(anim.get_content_input())
            style3d.append(anim.get_style3d_input())
            # Ten random (view angle, scale) pairs per clip; the draws are
            # interleaved, so their order affects the RNG sequence.
            view_angles, scales = [], []
            for v in range(10):
                view_angles.append(self.random_view_angle())
                scales.append(self.random_scale())
            style2d.append(anim.get_projections(view_angles, scales))

        # calc diff labels
        for x in self.labels:
            self.diff_labels_dict[x] = [y for y in self.labels if y != x]

        if extra_data_dir is None:
            extra_data_dir = config.extra_data_dir

        norm_cfg = config.dataset_norm_config
        norm_data = []
        for key, raw in zip(["content", "style3d", "style2d"], [content, style3d, style2d]):
            # A configured prefix marks the feature as pre_computed (presumably
            # stored normalization data -- confirm in NormData); None falls
            # back to a name derived from this subset.
            prefix = norm_cfg[subset_name][key]
            pre_computed = prefix is not None
            if prefix is None:
                prefix = subset_name
            # style2d is built with keep_raw=False; every other feature keeps
            # its raw data.
            norm_data.append(NormData(prefix + "_" + key, pre_computed, raw,
                                      config, extra_data_dir, keep_raw=(key != "style2d")))
        self.content, self.style3d, self.style2d = norm_data
        self.device = config.device
        # OS-entropy RNG: not affected by np.random.seed above.
        self.rand = random.SystemRandom()
    def __init__(self,
                 config,
                 subset_name,
                 data_path=None,
                 extra_data_dir=None,
                 panda=False):
        """Load one subset of motion clips and build its normalized features.

        Args:
            config: settings object; this method reads ``data_path``,
                ``extra_data_dir``, ``dataset_norm_config`` and ``device``.
            subset_name: key into the archive returned by ``np.load``. The
                entry is unwrapped with ``.item()`` into a dict with:

                - ``"motion"``: per-clip arrays (T x 132 per the original notes)
                - ``"style"``: integer style label per clip
                - ``"meta"``: per-field arrays, e.g. ``style`` strings
                  ('angry', 'childlike'), ``content`` strings ('walk') and
                  ``phase`` floats
            data_path: overrides ``config.data_path`` when given.
            extra_data_dir: overrides ``config.extra_data_dir`` when given;
                forwarded to every ``NormData``.
            panda: use ``PandaSkel`` instead of ``Skel``. Root rotation is not
                yet computed for that skeleton, so foot contacts and 2D style
                projections are skipped and ``self.foot_i`` / ``self.style2d``
                are set to ``None``. With the default ``False``, both are
                built as before.
        """
        super(MotionNorm, self).__init__()

        # Seeds NumPy's global RNG. Presumably this makes the
        # random_view_angle / random_scale draws below reproducible --
        # TODO confirm those helpers use np.random.
        np.random.seed(2020)
        # Build only the skeleton that is actually used.
        self.skel = PandaSkel() if panda else Skel()  # TD: add config

        if data_path is None:
            data_path = config.data_path
        dataset = np.load(data_path, allow_pickle=True)[subset_name].item()
        motions, labels, metas = dataset["motion"], dataset["style"], dataset["meta"]

        self.label_i = labels
        self.len = len(self.label_i)
        # One dict per clip, holding that clip's entry from every meta field.
        self.metas = [{key: metas[key][i] for key in metas}
                      for i in range(self.len)]

        self.motion_i = []
        # Foot contacts are not computed for the panda skeleton.
        self.foot_i = None if panda else []
        content, style3d, style2d = [], [], []

        # Distinct style labels in first-seen order, and
        # style label -> list of clip indices carrying that label.
        self.labels = []
        self.data_dict = {}

        num_views = 10  # random (view angle, scale) pairs drawn per clip
        for i, motion in enumerate(motions):
            label = labels[i]
            anim = AnimationData(motion, skel=self.skel, panda=panda)
            # data_dict keys mirror self.labels, so this is an O(1) check.
            if label not in self.data_dict:
                self.labels.append(label)
                self.data_dict[label] = []
            self.data_dict[label].append(i)
            self.motion_i.append(anim)
            if not panda:
                self.foot_i.append(anim.get_foot_contact(transpose=True))  # [4, T]
            content.append(anim.get_content_input())
            style3d.append(anim.get_style3d_input())
            if not panda:
                # TODO: calculate root rotation for the panda skeleton so
                # style2d can be built there as well.
                # The angle/scale draws are interleaved on purpose so the
                # RNG sequence matches the original implementation.
                view_angles, scales = [], []
                for _ in range(num_views):
                    view_angles.append(self.random_view_angle())
                    scales.append(self.random_scale())
                style2d.append(anim.get_projections(view_angles, scales))

        # Style label -> every other style label.
        self.diff_labels_dict = {
            x: [y for y in self.labels if y != x] for x in self.labels
        }

        if extra_data_dir is None:
            extra_data_dir = config.extra_data_dir

        subset_norm_cfg = config.dataset_norm_config[subset_name]
        features = {"content": content, "style3d": style3d}
        if not panda:
            features["style2d"] = style2d
        norm_data = {}
        for key, raw in features.items():
            # A configured prefix marks the feature as pre_computed (presumably
            # stored normalization data -- confirm in NormData); None falls
            # back to a name derived from this subset.
            prefix = subset_norm_cfg[key]
            pre_computed = prefix is not None
            if prefix is None:
                prefix = subset_name
            norm_data[key] = NormData(prefix + "_" + key,
                                      pre_computed,
                                      raw,
                                      config,
                                      extra_data_dir,
                                      keep_raw=(key != "style2d"))
        self.content = norm_data["content"]
        self.style3d = norm_data["style3d"]
        # None when the 2D projections were skipped (panda skeleton).
        self.style2d = norm_data.get("style2d")

        self.device = config.device
        # OS-entropy RNG: not affected by np.random.seed above.
        self.rand = random.SystemRandom()