Example #1
    def __init__(self,
                 args,
                 data_loader,
                 C,
                 F,
                 num_classes,
                 graph_dict,
                 device='cuda:0'):

        self.args = args
        self.data_loader = data_loader
        self.num_classes = num_classes
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.device = device
        self.io = IO(self.args.work_dir,
                     save_log=self.args.save_log,
                     print_log=self.args.print_log)

        # model
        self.model = classifier.Classifier(C, F, num_classes, graph_dict)
        self.model.to(self.device)
        self.model.apply(weights_init)
        self.loss = nn.CrossEntropyLoss()
        self.best_loss = math.inf
        self.step_epochs = [
            math.ceil(float(self.args.num_epoch * x)) for x in self.args.step
        ]
        self.best_epoch = None
        self.best_accuracy = np.zeros((1, np.max(self.args.topk)))
        self.accuracy_updated = False

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.args.base_lr,
                                       momentum=0.9,
                                       nesterov=self.args.nesterov,
                                       weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=self.args.base_lr,
                                        weight_decay=self.args.weight_decay)
        else:
            raise ValueError('Unsupported optimizer: {}'.format(self.args.optimizer))
        self.lr = self.args.base_lr
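
Below is a minimal, hypothetical sketch of how this constructor might be invoked, assuming the class is a Processor as in Example #2. The argparse field names mirror the attributes read above; the parser and dataset objects are assumptions, not part of the source project.

# Hypothetical usage sketch -- parser and dataset objects are assumptions.
args = parser.parse_args()  # must provide work_dir, save_log, print_log, optimizer,
                            # base_lr, weight_decay, nesterov, num_epoch, step, topk
data_loader = {
    'train': torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True),
    'test': torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False),
}
processor = Processor(args, data_loader, C=3, F=64,
                      num_classes=num_classes, graph_dict=graph_dict)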
Example #2
File: processor.py Project: Tanmay-r/STEP
class Processor(object):
    """
        Processor for gait generation
    """

    def __init__(self, args, data_loader, C, num_classes, graph_dict, device='cuda:0', verbose=True):

        self.args = args
        self.data_loader = data_loader
        self.num_classes = num_classes
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.device = device
        self.verbose = verbose
        self.io = IO(
            self.args.work_dir,
            save_log=self.args.save_log,
            print_log=self.args.print_log)

        # model
        if not os.path.isdir(self.args.work_dir):
            os.mkdir(self.args.work_dir)
        self.model = classifier.Classifier(C, num_classes, graph_dict)
        self.model.to(self.device)
        self.model.apply(weights_init)
        self.loss = nn.CrossEntropyLoss()
        self.best_loss = math.inf
        self.step_epochs = [math.ceil(float(self.args.num_epoch * x)) for x in self.args.step]
        self.best_epoch = None
        self.best_accuracy = np.zeros((1, np.max(self.args.topk)))
        self.accuracy_updated = False

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.args.base_lr,
                momentum=0.9,
                nesterov=self.args.nesterov,
                weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.args.base_lr,
                weight_decay=self.args.weight_decay)
        else:
            raise ValueError('Unsupported optimizer: {}'.format(self.args.optimizer))
        self.lr = self.args.base_lr

    def adjust_lr(self):

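        # Worked example (assumed values): with base_lr = 0.1, num_epoch = 100 and
        # step = [0.5, 0.75], step_epochs = [50, 75]; the lr drops to 0.01 at
        # epoch 50 and to 0.001 at epoch 75 (0.1 ** number of boundaries passed).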
        if self.meta_info['epoch'] in self.step_epochs:
            lr = self.args.base_lr * (
                    0.1 ** np.sum(self.meta_info['epoch'] >= np.array(self.step_epochs)))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            self.lr = lr

    def show_epoch_info(self):

        if self.verbose:
            for k, v in self.epoch_info.items():
                self.io.print_log('\t{}: {}'.format(k, v))
        if self.args.pavi_log and self.verbose:
            self.io.log('train', self.meta_info['iter'], self.epoch_info)

    def show_iter_info(self):

        if self.meta_info['iter'] % self.args.log_interval == 0:
            info = '\tIter {} Done.'.format(self.meta_info['iter'])
            for k, v in self.iter_info.items():
                if isinstance(v, float):
                    info = info + ' | {}: {:.4f}'.format(k, v)
                else:
                    info = info + ' | {}: {}'.format(k, v)
            if self.verbose:
                self.io.print_log(info)

            if self.args.pavi_log:
                self.io.log('train', self.meta_info['iter'], self.iter_info)

    def show_topk(self, k):

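        # self.result holds per-sample class scores; rank[i, -k:] are the indices
        # of the k highest-scoring classes for sample i, and sample i is a hit if
        # its ground-truth label appears among them.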
        rank = self.result.argsort()
        hit_top_k = [l in rank[i, -k:] for i, l in enumerate(self.label)]
        accuracy = 100. * sum(hit_top_k) / len(hit_top_k)
        if accuracy > self.best_accuracy[0, k-1]:
            self.best_accuracy[0, k-1] = accuracy
            self.accuracy_updated = True
        else:
            self.accuracy_updated = False
        if self.verbose:
            print_epoch = self.best_epoch if self.best_epoch is not None else 0
            self.io.print_log('\tTop{}: {:.2f}%. Best so far: {:.2f}% (epoch: {:d}).'.
                              format(k, accuracy, self.best_accuracy[0, k-1], print_epoch))

    def per_train(self):

        self.model.train()
        self.adjust_lr()
        loader = self.data_loader['train']
        loss_value = []

        for data, label in loader:
            # get data
            data = data.float().to(self.device)
            label = label.long().to(self.device)

            # forward
            output, _ = self.model(data)
            loss = self.loss(output, label)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # statistics
            self.iter_info['loss'] = loss.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            loss_value.append(self.iter_info['loss'])
            self.show_iter_info()
            self.meta_info['iter'] += 1

        self.epoch_info['mean_loss'] = np.mean(loss_value)
        self.show_epoch_info()
        if self.verbose:
            self.io.print_timer()
        # for k in self.args.topk:
        #     self.calculate_topk(k, show=False)
        # if self.accuracy_updated:
        #     self.model.extract_feature()

    def per_test(self, evaluation=True):

        self.model.eval()
        loader = self.data_loader['test']
        loss_value = []
        result_frag = []
        label_frag = []

        for data, label in loader:

            # get data
            data = data.float().to(self.device)
            label = label.long().to(self.device)

            # inference
            with torch.no_grad():
                output, _ = self.model(data)
            result_frag.append(output.data.cpu().numpy())

            # get loss
            if evaluation:
                loss = self.loss(output, label)
                loss_value.append(loss.item())
                label_frag.append(label.data.cpu().numpy())

        self.result = np.concatenate(result_frag)
        if evaluation:
            self.label = np.concatenate(label_frag)
            self.epoch_info['mean_loss'] = np.mean(loss_value)
            self.show_epoch_info()

            # show top-k accuracy
            for k in self.args.topk:
                self.show_topk(k)

    def train(self):

        for epoch in range(self.args.start_epoch, self.args.num_epoch):
            self.meta_info['epoch'] = epoch

            # training
            if self.verbose:
                self.io.print_log('Training epoch: {}'.format(epoch))
            self.per_train()
            if self.verbose:
                self.io.print_log('Done.')

            # evaluation
            if (epoch % self.args.eval_interval == 0) or (
                    epoch + 1 == self.args.num_epoch):
                if self.verbose:
                    self.io.print_log('Eval epoch: {}'.format(epoch))
                self.per_test()
                if self.verbose:
                    self.io.print_log('Done.')

            # save model and weights
            if self.accuracy_updated:
                torch.save(self.model.state_dict(),
                           os.path.join(self.args.work_dir,
                                        'epoch{}_acc{:.2f}_model.pth.tar'.format(epoch, self.best_accuracy.item())))
                if self.epoch_info['mean_loss'] < self.best_loss:
                    self.best_loss = self.epoch_info['mean_loss']
                self.best_epoch = epoch

    def test(self):

        # the path of the weights must be specified
        if self.args.weights is None:
            raise ValueError('Please specify --weights.')
        if self.verbose:
            self.io.print_log('Model:   {}.'.format(self.args.model))
            self.io.print_log('Weights: {}.'.format(self.args.weights))

        # evaluation
        if self.verbose:
            self.io.print_log('Evaluation Start:')
        self.per_test()
        if self.verbose:
            self.io.print_log('Done.\n')

        # save the output of model
        if self.args.save_result:
            result_dict = dict(
                zip(self.data_loader['test'].dataset.sample_name,
                    self.result))
            self.io.save_pkl(result_dict, 'test_result.pkl')

    def smap(self):
        # self.model.eval()
        loader = self.data_loader['test']

        for data, label in loader:

            # get data
            data = data.float().to(self.device)
            label = label.long().to(self.device)

            GBP = GuidedBackprop(self.model)
            guided_grads = GBP.generate_gradients(data, label)

    def load_best_model(self):
        if self.best_epoch is None:
            self.best_epoch, best_accuracy = get_best_epoch_and_accuracy(self.args.work_dir)
        else:
            best_accuracy = self.best_accuracy.item()

        filename = os.path.join(self.args.work_dir,
                                'epoch{}_acc{:.2f}_model.pth.tar'.format(self.best_epoch, best_accuracy))
        self.model.load_state_dict(torch.load(filename))

    def generate_predictions(self, data, num_classes, joints, coords):
        # fin = h5py.File('../data/features'+ftype+'.h5', 'r')
        # fkeys = fin.keys()
        labels_pred = np.zeros(data.shape[0])
        output = np.zeros((data.shape[0], num_classes))
        for i, each_data in enumerate(data):
            # get data
            each_data = np.reshape(each_data, (1, each_data.shape[0], joints, coords, 1))
            each_data = np.moveaxis(each_data, [1, 2, 3], [2, 3, 1])
            each_data = torch.from_numpy(each_data).float().to(self.device)
            # get label
            with torch.no_grad():
                out, _ = self.model(each_data)
                output[i] = out.cpu().numpy().squeeze()
                labels_pred[i] = np.argmax(output[i])
        return labels_pred, output

    def generate_confusion_matrix(self, ftype, data, labels, num_classes, joints, coords):
        self.load_best_model()
        labels_pred, _ = self.generate_predictions(data, num_classes, joints, coords)

        hit = np.nonzero(labels_pred == labels)
        miss = np.nonzero(labels_pred != labels)
        confusion_matrix = np.zeros((num_classes, num_classes))
        for hidx in np.arange(len(hit[0])):
            confusion_matrix[int(labels[hit[0][hidx]]), int(labels_pred[hit[0][hidx]])] += 1
        for midx in np.arange(len(miss[0])):
            confusion_matrix[int(labels[miss[0][midx]]), int(labels_pred[miss[0][midx]])] += 1
        confusion_matrix = confusion_matrix.transpose()
        plot_confusion_matrix(confusion_matrix)

    def save_best_feature(self, ftype, data, joints, coords):
        if self.best_epoch is None:
            self.best_epoch, best_accuracy = get_best_epoch_and_accuracy(self.args.work_dir)
        else:
            best_accuracy = self.best_accuracy.item()
        filename = os.path.join(self.args.work_dir,
                                'epoch{}_acc{:.2f}_model.pth.tar'.format(self.best_epoch, best_accuracy))
        self.model.load_state_dict(torch.load(filename))
        features = np.empty((0, 64))
        fCombined = h5py.File('../data/features'+ftype+'.h5', 'r')
        fkeys = fCombined.keys()
        dfCombined = h5py.File('../data/deepFeatures'+ftype+'.h5', 'w')
        for i, (each_data, each_key) in enumerate(zip(data, fkeys)):

            # get data
            each_data = np.reshape(each_data, (1, each_data.shape[0], joints, coords, 1))
            each_data = np.moveaxis(each_data, [1, 2, 3], [2, 3, 1])
            each_data = torch.from_numpy(each_data).float().to(self.device)

            # get feature
            with torch.no_grad():
                _, feature = self.model(each_data)
                feature = feature.cpu().numpy()
                dfCombined.create_dataset(each_key, data=feature)
                features = np.append(features, feature.reshape((1, -1)), axis=0)
        dfCombined.close()
        return features
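
For orientation, a hedged end-to-end driver for this Processor; the phase flag is an assumption, while the methods are the ones defined above.

# Hypothetical driver -- args.phase and the loader contents are assumptions.
processor = Processor(args, data_loader, C=3, num_classes=num_classes,
                      graph_dict=graph_dict, device='cuda:0')
if args.phase == 'train':
    processor.train()  # trains, evaluates every eval_interval epochs, and
                       # checkpoints whenever the top-k accuracy improves
else:
    processor.test()   # requires args.weights to point at a saved checkpoint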
Example #3
    def __init__(self,
                 args,
                 data_path,
                 data_loader,
                 Z,
                 T,
                 A,
                 V,
                 C,
                 D,
                 tag_cats,
                 IE,
                 IP,
                 AT,
                 G,
                 AGE,
                 H,
                 NT,
                 joint_names,
                 joint_parents,
                 word2idx,
                 embedding_table,
                 lower_body_start=15,
                 fill=6,
                 min_train_epochs=20,
                 generate_while_train=False,
                 save_path=None,
                 device='cuda:0'):
        def get_quats_sos_and_eos():
            quats_sos_and_eos_file = os.path.join(data_path,
                                                  'quats_sos_and_eos.npz')
            keys = list(self.data_loader['train'].keys())
            num_samples = len(self.data_loader['train'])
            try:
                quats_npz = np.load(quats_sos_and_eos_file, allow_pickle=True)
                mean_quats_sos = quats_npz['quats_sos']
                mean_quats_eos = quats_npz['quats_eos']
            except FileNotFoundError:
                mean_quats_sos = np.zeros((self.V, self.D))
                mean_quats_eos = np.zeros((self.V, self.D))
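                # Per joint, gather the first/last-frame quaternions of all
                # training clips and take an eigenvector of their outer-product
                # matrix as a representative "mean" SOS/EOS rotation
                # (an eigenvector-based quaternion-averaging scheme).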
                for j in range(self.V):
                    quats_sos = np.zeros((self.D, num_samples))
                    quats_eos = np.zeros((self.D, num_samples))
                    for s in range(num_samples):
                        quats_sos[:, s] = self.data_loader['train'][
                            keys[s]]['rotations'][0, j]
                        quats_eos[:, s] = self.data_loader['train'][
                            keys[s]]['rotations'][-1, j]
                    _, sos_eig_vectors = np.linalg.eig(
                        np.dot(quats_sos, quats_sos.T))
                    mean_quats_sos[j] = sos_eig_vectors[:, 0]
                    _, eos_eig_vectors = np.linalg.eig(
                        np.dot(quats_eos, quats_eos.T))
                    mean_quats_eos[j] = eos_eig_vectors[:, 0]
                np.savez_compressed(quats_sos_and_eos_file,
                                    quats_sos=mean_quats_sos,
                                    quats_eos=mean_quats_eos)
            mean_quats_sos = torch.from_numpy(mean_quats_sos).unsqueeze(0)
            mean_quats_eos = torch.from_numpy(mean_quats_eos).unsqueeze(0)
            for s in range(num_samples):
                pos_sos = \
                    MocapDataset.forward_kinematics(mean_quats_sos.unsqueeze(0),
                                                    torch.from_numpy(self.data_loader['train'][keys[s]]
                                                    ['positions'][0:1, 0]).double().unsqueeze(0),
                                                    self.joint_parents,
                                                    torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                                                    ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
                affs_sos = MocapDataset.get_mpi_affective_features(pos_sos)
                pos_eos = \
                    MocapDataset.forward_kinematics(mean_quats_eos.unsqueeze(0),
                                                    torch.from_numpy(self.data_loader['train'][keys[s]]
                                                    ['positions'][-1:, 0]).double().unsqueeze(0),
                                                    self.joint_parents,
                                                    torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                                                    ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
                affs_eos = MocapDataset.get_mpi_affective_features(pos_eos)
                self.data_loader['train'][keys[s]]['positions'] = \
                    np.concatenate((pos_sos, self.data_loader['train'][keys[s]]['positions'], pos_eos), axis=0)
                self.data_loader['train'][keys[s]]['affective_features'] = \
                    np.concatenate((affs_sos, self.data_loader['train'][keys[s]]['affective_features'], affs_eos),
                                   axis=0)
            return mean_quats_sos, mean_quats_eos

        self.args = args
        self.dataset = args.dataset
        self.channel_map = {
            'Xrotation': 'x',
            'Yrotation': 'y',
            'Zrotation': 'z'
        }
        self.device = device
        self.data_loader = data_loader
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.io = IO(self.args.work_dir,
                     save_log=self.args.save_log,
                     print_log=self.args.print_log)

        # model
        self.T = T + 2
        self.A = A
        self.V = V
        self.C = C
        self.D = D
        self.O = 1
        self.tag_cats = tag_cats
        self.IE = IE
        self.IP = IP
        self.AT = AT
        self.G = G
        self.AGE = AGE
        self.H = H
        self.NT = NT
        self.joint_names = joint_names
        self.joint_parents = joint_parents
        self.lower_body_start = lower_body_start
        self.quats_sos, self.quats_eos = get_quats_sos_and_eos()
        self.recons_loss_func = nn.L1Loss()
        self.affs_loss_func = nn.L1Loss()
        self.best_loss = np.inf
        self.loss_updated = False
        self.step_epochs = [
            math.ceil(float(self.args.num_epoch * x)) for x in self.args.step
        ]
        self.best_loss_epoch = None
        self.min_train_epochs = min_train_epochs
        self.zfill = fill
        self.word2idx = word2idx
        self.text_sos = np.int64(self.word2idx['<SOS>'])
        self.text_eos = np.int64(self.word2idx['<EOS>'])
        num_tokens = len(self.word2idx)  # the size of vocabulary
        self.Z = Z  # embedding dimension
        num_hidden_units = 200  # the dimension of the feedforward network model in nn.TransformerEncoder
        num_layers = 2  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
        num_heads = 2  # the number of heads in the multiheadattention models
        dropout = 0.2  # the dropout value
        self.model = T2GNet(num_tokens,
                            torch.from_numpy(embedding_table).cuda(),
                            self.T - 1, self.Z, self.V * self.D, self.D,
                            self.V - 1, self.IE, self.IP, self.AT, self.G,
                            self.AGE, self.H, self.NT, num_heads,
                            num_hidden_units, num_layers, dropout).to(device)

        # generate
        self.generate_while_train = generate_while_train
        self.save_path = save_path

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.args.base_lr,
                                       momentum=0.9,
                                       nesterov=self.args.nesterov,
                                       weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=self.args.base_lr)
            # weight_decay=self.args.weight_decay)
        else:
            raise ValueError('Unsupported optimizer: {}'.format(self.args.optimizer))
        self.lr = self.args.base_lr
        self.tf = self.args.base_tr
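
The get_quats_sos_and_eos helper above averages the start- and end-frame quaternions of the training clips through an eigen-decomposition of their outer-product matrix. A self-contained sketch of that averaging step follows; picking the eigenvector with the largest eigenvalue, as done here, is the textbook (Markley-style) variant, whereas the snippet above takes the first returned column without sorting.

import numpy as np

def mean_quaternion(quats):
    """quats: (4, N) array of N unit quaternions. Returns an average unit
    quaternion as the leading eigenvector of the accumulator matrix."""
    M = quats @ quats.T                      # (4, 4) symmetric accumulator
    eig_vals, eig_vecs = np.linalg.eigh(M)   # eigh is safe: M is symmetric PSD
    return eig_vecs[:, np.argmax(eig_vals)]  # eigenvector of largest eigenvalue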
Example #4
class Processor(object):
    """
        Processor for gait generation
    """
    def __init__(self,
                 args,
                 data_path,
                 data_loader,
                 Z,
                 T,
                 A,
                 V,
                 C,
                 D,
                 tag_cats,
                 IE,
                 IP,
                 AT,
                 G,
                 AGE,
                 H,
                 NT,
                 joint_names,
                 joint_parents,
                 word2idx,
                 embedding_table,
                 lower_body_start=15,
                 fill=6,
                 min_train_epochs=20,
                 generate_while_train=False,
                 save_path=None,
                 device='cuda:0'):
        def get_quats_sos_and_eos():
            quats_sos_and_eos_file = os.path.join(data_path,
                                                  'quats_sos_and_eos.npz')
            keys = list(self.data_loader['train'].keys())
            num_samples = len(self.data_loader['train'])
            try:
                quats_npz = np.load(quats_sos_and_eos_file, allow_pickle=True)
                mean_quats_sos = quats_npz['quats_sos']
                mean_quats_eos = quats_npz['quats_eos']
            except FileNotFoundError:
                mean_quats_sos = np.zeros((self.V, self.D))
                mean_quats_eos = np.zeros((self.V, self.D))
                for j in range(self.V):
                    quats_sos = np.zeros((self.D, num_samples))
                    quats_eos = np.zeros((self.D, num_samples))
                    for s in range(num_samples):
                        quats_sos[:, s] = self.data_loader['train'][
                            keys[s]]['rotations'][0, j]
                        quats_eos[:, s] = self.data_loader['train'][
                            keys[s]]['rotations'][-1, j]
                    _, sos_eig_vectors = np.linalg.eig(
                        np.dot(quats_sos, quats_sos.T))
                    mean_quats_sos[j] = sos_eig_vectors[:, 0]
                    _, eos_eig_vectors = np.linalg.eig(
                        np.dot(quats_eos, quats_eos.T))
                    mean_quats_eos[j] = eos_eig_vectors[:, 0]
                np.savez_compressed(quats_sos_and_eos_file,
                                    quats_sos=mean_quats_sos,
                                    quats_eos=mean_quats_eos)
            mean_quats_sos = torch.from_numpy(mean_quats_sos).unsqueeze(0)
            mean_quats_eos = torch.from_numpy(mean_quats_eos).unsqueeze(0)
            for s in range(num_samples):
                pos_sos = \
                    MocapDataset.forward_kinematics(mean_quats_sos.unsqueeze(0),
                                                    torch.from_numpy(self.data_loader['train'][keys[s]]
                                                    ['positions'][0:1, 0]).double().unsqueeze(0),
                                                    self.joint_parents,
                                                    torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                                                    ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
                affs_sos = MocapDataset.get_mpi_affective_features(pos_sos)
                pos_eos = \
                    MocapDataset.forward_kinematics(mean_quats_eos.unsqueeze(0),
                                                    torch.from_numpy(self.data_loader['train'][keys[s]]
                                                    ['positions'][-1:, 0]).double().unsqueeze(0),
                                                    self.joint_parents,
                                                    torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                                                    ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
                affs_eos = MocapDataset.get_mpi_affective_features(pos_eos)
                self.data_loader['train'][keys[s]]['positions'] = \
                    np.concatenate((pos_sos, self.data_loader['train'][keys[s]]['positions'], pos_eos), axis=0)
                self.data_loader['train'][keys[s]]['affective_features'] = \
                    np.concatenate((affs_sos, self.data_loader['train'][keys[s]]['affective_features'], affs_eos),
                                   axis=0)
            return mean_quats_sos, mean_quats_eos

        self.args = args
        self.dataset = args.dataset
        self.channel_map = {
            'Xrotation': 'x',
            'Yrotation': 'y',
            'Zrotation': 'z'
        }
        self.device = device
        self.data_loader = data_loader
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.io = IO(self.args.work_dir,
                     save_log=self.args.save_log,
                     print_log=self.args.print_log)

        # model
        self.T = T + 2
        self.A = A
        self.V = V
        self.C = C
        self.D = D
        self.O = 1
        self.tag_cats = tag_cats
        self.IE = IE
        self.IP = IP
        self.AT = AT
        self.G = G
        self.AGE = AGE
        self.H = H
        self.NT = NT
        self.joint_names = joint_names
        self.joint_parents = joint_parents
        self.lower_body_start = lower_body_start
        self.quats_sos, self.quats_eos = get_quats_sos_and_eos()
        self.recons_loss_func = nn.L1Loss()
        self.affs_loss_func = nn.L1Loss()
        self.best_loss = np.inf
        self.loss_updated = False
        self.step_epochs = [
            math.ceil(float(self.args.num_epoch * x)) for x in self.args.step
        ]
        self.best_loss_epoch = None
        self.min_train_epochs = min_train_epochs
        self.zfill = fill
        self.word2idx = word2idx
        self.text_sos = np.int64(self.word2idx['<SOS>'])
        self.text_eos = np.int64(self.word2idx['<EOS>'])
        num_tokens = len(self.word2idx)  # the size of vocabulary
        self.Z = Z  # embedding dimension
        num_hidden_units = 200  # the dimension of the feedforward network model in nn.TransformerEncoder
        num_layers = 2  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
        num_heads = 2  # the number of heads in the multiheadattention models
        dropout = 0.2  # the dropout value
        self.model = T2GNet(num_tokens,
                            torch.from_numpy(embedding_table).cuda(),
                            self.T - 1, self.Z, self.V * self.D, self.D,
                            self.V - 1, self.IE, self.IP, self.AT, self.G,
                            self.AGE, self.H, self.NT, num_heads,
                            num_hidden_units, num_layers, dropout).to(device)

        # generate
        self.generate_while_train = generate_while_train
        self.save_path = save_path

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.args.base_lr,
                                       momentum=0.9,
                                       nesterov=self.args.nesterov,
                                       weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=self.args.base_lr)
            # weight_decay=self.args.weight_decay)
        else:
            raise ValueError('Unsupported optimizer: {}'.format(self.args.optimizer))
        self.lr = self.args.base_lr
        self.tf = self.args.base_tr

    def process_data(self, data, poses, quat, trans, affs):
        data = data.float().to(self.device)
        poses = poses.float().to(self.device)
        quat = quat.float().to(self.device)
        trans = trans.float().to(self.device)
        affs = affs.float().to(self.device)
        return data, poses, quat, trans, affs

    def load_model_at_epoch(self, epoch='best'):
        model_name, self.best_loss_epoch, self.best_loss =\
            get_epoch_and_loss(self.args.work_dir, epoch=epoch)
        model_found = False
        try:
            loaded_vars = torch.load(
                os.path.join(self.args.work_dir, model_name))
            self.model.load_state_dict(loaded_vars['model_dict'])
            model_found = True
        except (FileNotFoundError, IsADirectoryError):
            if epoch == 'best':
                print('Warning! No saved model found.')
            else:
                print('Warning! No saved model found at epoch {:d}.'.format(
                    epoch))
        return model_found

    def adjust_lr(self):
        self.lr = self.lr * self.args.lr_decay
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def adjust_tf(self):
        if self.meta_info['epoch'] > 20:
            self.tf = self.tf * self.args.tf_decay

    def show_epoch_info(self):

        print_epochs = [
            self.best_loss_epoch if self.best_loss_epoch is not None else 0
        ]
        best_metrics = [self.best_loss]
        i = 0
        for k, v in self.epoch_info.items():
            self.io.print_log(
                '\t{}: {}. Best so far: {} (epoch: {:d}).'.format(
                    k, v, best_metrics[i], print_epochs[i]))
            i += 1
        if self.args.pavi_log:
            self.io.log('train', self.meta_info['iter'], self.epoch_info)

    def show_iter_info(self):

        if self.meta_info['iter'] % self.args.log_interval == 0:
            info = '\tIter {} Done.'.format(self.meta_info['iter'])
            for k, v in self.iter_info.items():
                if isinstance(v, float):
                    info = info + ' | {}: {:.4f}'.format(k, v)
                else:
                    info = info + ' | {}: {}'.format(k, v)

            self.io.print_log(info)

            if self.args.pavi_log:
                self.io.log('train', self.meta_info['iter'], self.iter_info)

    def yield_batch(self, batch_size, dataset):
        batch_joint_offsets = torch.zeros(
            (batch_size, self.V - 1, self.C)).cuda()
        batch_pos = torch.zeros((batch_size, self.T, self.V, self.C)).cuda()
        batch_affs = torch.zeros((batch_size, self.T, self.A)).cuda()
        batch_quat = torch.zeros((batch_size, self.T, self.V * self.D)).cuda()
        batch_quat_valid_idx = torch.zeros((batch_size, self.T)).cuda()
        batch_text = torch.zeros((batch_size, self.Z)).cuda().long()
        batch_text_valid_idx = torch.zeros((batch_size, self.Z)).cuda()
        batch_intended_emotion = torch.zeros((batch_size, self.IE)).cuda()
        batch_intended_polarity = torch.zeros((batch_size, self.IP)).cuda()
        batch_acting_task = torch.zeros((batch_size, self.AT)).cuda()
        batch_gender = torch.zeros((batch_size, self.G)).cuda()
        batch_age = torch.zeros((batch_size, self.AGE)).cuda()
        batch_handedness = torch.zeros((batch_size, self.H)).cuda()
        batch_native_tongue = torch.zeros((batch_size, self.NT)).cuda()

        pseudo_passes = (len(dataset) + batch_size - 1) // batch_size

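        # Sampling probabilities are proportional to each clip's frame count, so
        # longer sequences are drawn more often within a pseudo-pass.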
        probs = []
        for k in dataset.keys():
            probs.append(dataset[k]['positions'].shape[0])
        probs = np.array(probs) / np.sum(probs)

        for p in range(pseudo_passes):
            rand_keys = np.random.choice(len(dataset),
                                         size=batch_size,
                                         replace=True,
                                         p=probs)
            for i, k in enumerate(rand_keys):
                joint_offsets = torch.from_numpy(dataset[str(k).zfill(
                    self.zfill)]['joints_dict']['joints_offsets_all'][1:])
                pos = torch.from_numpy(dataset[str(k).zfill(
                    self.zfill)]['positions'])
                affs = torch.from_numpy(dataset[str(k).zfill(
                    self.zfill)]['affective_features'])
                quat = torch.cat(
                    (self.quats_sos,
                     torch.from_numpy(dataset[str(k).zfill(
                         self.zfill)]['rotations']), self.quats_eos),
                    dim=0)
                quat_length = quat.shape[0]
                quat_valid_idx = torch.zeros(self.T)
                quat_valid_idx[:quat_length] = 1
                text = torch.cat((torch.tensor([
                    self.word2idx[x] for x in [
                        e for e in str.split(dataset[str(k).zfill(self.zfill)]
                                             ['Text']) if e.isalnum()
                    ]
                ]), torch.from_numpy(np.array([self.text_eos]))))
                if text[0] != self.text_sos:
                    text = torch.cat(
                        (torch.from_numpy(np.array([self.text_sos])), text))
                text_length = text.shape[0]
                text_valid_idx = torch.zeros(self.Z)
                text_valid_idx[:text_length] = 1

                batch_joint_offsets[i] = joint_offsets
                batch_pos[i, :pos.shape[0]] = pos
                batch_pos[i, pos.shape[0]:] = pos[-1:].clone()
                batch_affs[i, :affs.shape[0]] = affs
                batch_affs[i, affs.shape[0]:] = affs[-1:].clone()
                batch_quat[i, :quat_length] = quat.view(quat_length, -1)
                batch_quat[i, quat_length:] = quat[-1:].view(1, -1).clone()
                batch_quat_valid_idx[i] = quat_valid_idx
                batch_text[i, :text_length] = text
                batch_text_valid_idx[i] = text_valid_idx
                batch_intended_emotion[i] = torch.from_numpy(
                    dataset[str(k).zfill(self.zfill)]['Intended emotion'])
                batch_intended_polarity[i] = torch.from_numpy(
                    dataset[str(k).zfill(self.zfill)]['Intended polarity'])
                batch_acting_task[i] = torch.from_numpy(dataset[str(k).zfill(
                    self.zfill)]['Acting task'])
                batch_gender[i] = torch.from_numpy(dataset[str(k).zfill(
                    self.zfill)]['Gender'])
                batch_age[i] = torch.tensor(dataset[str(k).zfill(
                    self.zfill)]['Age'])
                batch_handedness[i] = torch.from_numpy(dataset[str(k).zfill(
                    self.zfill)]['Handedness'])
                batch_native_tongue[i] = torch.from_numpy(dataset[str(k).zfill(
                    self.zfill)]['Native tongue'])

            yield batch_joint_offsets, batch_pos, batch_affs, batch_quat, \
                batch_quat_valid_idx, batch_text, batch_text_valid_idx, \
                batch_intended_emotion, batch_intended_polarity, batch_acting_task, \
                batch_gender, batch_age, batch_handedness, batch_native_tongue

    def return_batch(self, batch_size, dataset, randomized=True):
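        # batch_size is a list: with more than one element it is treated as an
        # explicit array of sample keys; with a single element it gives the
        # number of samples to draw (randomly, or the first N keys).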
        if len(batch_size) > 1:
            rand_keys = np.copy(batch_size)
            batch_size = len(batch_size)
        else:
            batch_size = batch_size[0]
            probs = []
            for k in dataset.keys():
                probs.append(dataset[k]['positions'].shape[0])
            probs = np.array(probs) / np.sum(probs)
            if randomized:
                rand_keys = np.random.choice(len(dataset),
                                             size=batch_size,
                                             replace=False,
                                             p=probs)
            else:
                rand_keys = np.arange(batch_size)

        batch_joint_offsets = torch.zeros(
            (batch_size, self.V - 1, self.C)).cuda()
        batch_pos = torch.zeros((batch_size, self.T, self.V, self.C)).cuda()
        batch_affs = torch.zeros((batch_size, self.T, self.A)).cuda()
        batch_quat = torch.zeros((batch_size, self.T, self.V * self.D)).cuda()
        batch_quat_valid_idx = torch.zeros((batch_size, self.T)).cuda()
        batch_text = torch.zeros((batch_size, self.Z)).cuda().long()
        batch_text_valid_idx = torch.zeros((batch_size, self.Z)).cuda()
        batch_intended_emotion = torch.zeros((batch_size, self.IE)).cuda()
        batch_intended_polarity = torch.zeros((batch_size, self.IP)).cuda()
        batch_acting_task = torch.zeros((batch_size, self.AT)).cuda()
        batch_gender = torch.zeros((batch_size, self.G)).cuda()
        batch_age = torch.zeros((batch_size, self.AGE)).cuda()
        batch_handedness = torch.zeros((batch_size, self.H)).cuda()
        batch_native_tongue = torch.zeros((batch_size, self.NT)).cuda()

        for i, k in enumerate(rand_keys):
            joint_offsets = torch.from_numpy(dataset[str(k).zfill(
                self.zfill)]['joints_dict']['joints_offsets_all'][1:])
            pos = torch.from_numpy(dataset[str(k).zfill(
                self.zfill)]['positions'])
            affs = torch.from_numpy(dataset[str(k).zfill(
                self.zfill)]['affective_features'])
            quat = torch.cat((self.quats_sos,
                              torch.from_numpy(dataset[str(k).zfill(
                                  self.zfill)]['rotations']), self.quats_eos),
                             dim=0)
            quat_length = quat.shape[0]
            quat_valid_idx = torch.zeros(self.T)
            quat_valid_idx[:quat_length] = 1
            text = torch.cat((torch.tensor([
                self.word2idx[x] for x in [
                    e for e in str.split(dataset[str(k).zfill(self.zfill)]
                                         ['Text']) if e.isalnum()
                ]
            ]), torch.from_numpy(np.array([self.text_eos]))))
            if text[0] != self.text_sos:
                text = torch.cat(
                    (torch.from_numpy(np.array([self.text_sos])), text))
            text_length = text.shape[0]
            text_valid_idx = torch.zeros(self.Z)
            text_valid_idx[:text_length] = 1

            batch_joint_offsets[i] = joint_offsets
            batch_pos[i, :pos.shape[0]] = pos
            batch_pos[i, pos.shape[0]:] = pos[-1:].clone()
            batch_affs[i, :affs.shape[0]] = affs
            batch_affs[i, affs.shape[0]:] = affs[-1:].clone()
            batch_quat[i, :quat_length] = quat.view(quat_length, -1)
            batch_quat[i, quat_length:] = quat[-1:].view(1, -1).clone()
            batch_quat_valid_idx[i] = quat_valid_idx
            batch_text[i, :text_length] = text
            batch_text_valid_idx[i] = text_valid_idx
            batch_intended_emotion[i] = torch.from_numpy(dataset[str(k).zfill(
                self.zfill)]['Intended emotion'])
            batch_intended_polarity[i] = torch.from_numpy(dataset[str(k).zfill(
                self.zfill)]['Intended polarity'])
            batch_acting_task[i] = torch.from_numpy(dataset[str(k).zfill(
                self.zfill)]['Acting task'])
            batch_gender[i] = torch.from_numpy(dataset[str(k).zfill(
                self.zfill)]['Gender'])
            batch_age[i] = torch.tensor(dataset[str(k).zfill(
                self.zfill)]['Age'])
            batch_handedness[i] = torch.from_numpy(dataset[str(k).zfill(
                self.zfill)]['Handedness'])
            batch_native_tongue[i] = torch.from_numpy(dataset[str(k).zfill(
                self.zfill)]['Native tongue'])

        return batch_joint_offsets, batch_pos, batch_affs, batch_quat, \
                batch_quat_valid_idx, batch_text, batch_text_valid_idx, \
                batch_intended_emotion, batch_intended_polarity, batch_acting_task, \
                batch_gender, batch_age, batch_handedness, batch_native_tongue

    def per_train(self):

        self.model.train()
        train_loader = self.data_loader['train']
        batch_loss = 0.
        N = 0.

        for joint_offsets, pos, affs, quat, quat_valid_idx,\
            text, text_valid_idx, intended_emotion, intended_polarity,\
                acting_task, gender, age, handedness,\
                native_tongue in self.yield_batch(self.args.batch_size, train_loader):
            quat_prelude = self.quats_eos.view(1, -1).cuda() \
                .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
            quat_prelude[:, -1] = quat[:, 0].clone()

            self.optimizer.zero_grad()
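            # detect_anomaly() traces NaNs/Infs back to the op that produced them
            # during backward, at a significant runtime cost; it is usually
            # disabled once training is stable.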
            with torch.autograd.detect_anomaly():
                joint_lengths = torch.norm(joint_offsets, dim=-1)
                scales, _ = torch.max(joint_lengths, dim=-1)
                quat_pred, quat_pred_pre_norm = self.model(
                    text, intended_emotion, intended_polarity, acting_task,
                    gender, age, handedness, native_tongue, quat_prelude,
                    joint_lengths / scales[..., None])

                quat_pred_pre_norm = quat_pred_pre_norm.view(
                    quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1],
                    -1, self.D)
                quat_norm_loss = self.args.quat_norm_reg *\
                                 torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

                quat_loss, quat_derv_loss = losses.quat_angle_loss(
                    quat_pred, quat[:, 1:], quat_valid_idx[:, 1:], self.V,
                    self.D, self.lower_body_start, self.args.upper_body_weight)
                quat_loss *= self.args.quat_reg

                root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                       self.C).cuda()
                pos_pred = MocapDataset.forward_kinematics(
                    quat_pred.contiguous().view(quat_pred.shape[0],
                                                quat_pred.shape[1], -1,
                                                self.D), root_pos,
                    self.joint_parents,
                    torch.cat((root_pos[:, 0:1], joint_offsets),
                              dim=1).unsqueeze(1))
                affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

                # row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
                # row_sums[row_sums == 0.] = 1.

                shifted_pos = pos - pos[:, :, 0:1]
                shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]

                recons_loss = self.recons_loss_func(shifted_pos_pred,
                                                    shifted_pos[:, 1:])
                # recons_loss = torch.abs(shifted_pos_pred - shifted_pos[:, 1:]).sum(-1)
                # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) +\
                #               recons_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_loss = self.args.recons_reg *\
                #               torch.mean((recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                #
                # recons_derv_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos_pred[:, :-1] -
                #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
                # recons_derv_loss = self.args.upper_body_weight *\
                #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) +\
                #                    recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_derv_loss = 2. * self.args.recons_reg *\
                #                    torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
                #
                # affs_loss = torch.abs(affs[:, 1:] - affs_pred).sum(-1)
                # affs_loss = self.args.affs_reg * torch.mean((affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])

                train_loss = quat_norm_loss + quat_loss + recons_loss + affs_loss
                # train_loss = quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
                train_loss.backward()
                # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
                self.optimizer.step()

            # animation_pred = {
            #     'joint_names': self.joint_names,
            #     'joint_offsets': joint_offsets,
            #     'joint_parents': self.joint_parents,
            #     'positions': pos_pred,
            #     'rotations': quat_pred
            # }
            # MocapDataset.save_as_bvh(animation_pred,
            #                          dataset_name=self.dataset,
            #                          subset_name='test')
            # animation = {
            #     'joint_names': self.joint_names,
            #     'joint_offsets': joint_offsets,
            #     'joint_parents': self.joint_parents,
            #     'positions': pos,
            #     'rotations': quat
            # }
            # MocapDataset.save_as_bvh(animation,
            #                          dataset_name=self.dataset,
            #                          subset_name='gt')

            # Compute statistics
            batch_loss += train_loss.item()
            N += quat.shape[0]

            # statistics
            self.iter_info['loss'] = train_loss.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            self.iter_info['tf'] = '{:.6f}'.format(self.tf)
            self.show_iter_info()
            self.meta_info['iter'] += 1

        batch_loss = batch_loss / N
        self.epoch_info['mean_loss'] = batch_loss
        self.show_epoch_info()
        self.io.print_timer()
        self.adjust_lr()
        self.adjust_tf()

        # pos_pred_np = np.swapaxes(np.reshape(pos_pred.detach().cpu().numpy(),
        #                                      (pos_pred.shape[0], self.T - 1, -1)), 2, 1)
        # display_animations(pos_pred_np, self.joint_parents,
        #                    save=True, dataset_name=self.dataset, subset_name='test', overwrite=True)

    def per_eval(self):

        self.model.eval()
        test_loader = self.data_loader['test']
        eval_loss = 0.
        N = 0.

        for joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, intended_emotion, intended_polarity, \
            acting_task, gender, age, handedness, \
                native_tongue in self.yield_batch(self.args.batch_size, test_loader):
            with torch.no_grad():
                joint_lengths = torch.norm(joint_offsets, dim=-1)
                scales, _ = torch.max(joint_lengths, dim=-1)
                quat_prelude = self.quats_eos.view(1, -1).cuda() \
                    .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
                quat_prelude[:, -1] = quat[:, 0].clone()
                quat_pred, quat_pred_pre_norm = self.model(
                    text, intended_emotion, intended_polarity, acting_task,
                    gender, age, handedness, native_tongue, quat_prelude,
                    joint_lengths / scales[..., None])
                quat_pred_pre_norm = quat_pred_pre_norm.view(
                    quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1],
                    -1, self.D)
                quat_norm_loss = self.args.quat_norm_reg *\
                                 torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

                quat_loss, quat_derv_loss = losses.quat_angle_loss(
                    quat_pred, quat[:, 1:], quat_valid_idx[:, 1:], self.V,
                    self.D, self.lower_body_start, self.args.upper_body_weight)
                quat_loss *= self.args.quat_reg

                root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                       self.C).cuda()
                pos_pred = MocapDataset.forward_kinematics(
                    quat_pred.contiguous().view(quat_pred.shape[0],
                                                quat_pred.shape[1], -1,
                                                self.D), root_pos,
                    self.joint_parents,
                    torch.cat((root_pos[:, 0:1], joint_offsets),
                              dim=1).unsqueeze(1))
                affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

                row_sums = quat_valid_idx.sum(1,
                                              keepdim=True) * self.D * self.V
                row_sums[row_sums == 0.] = 1.

                shifted_pos = pos - pos[:, :, 0:1]
                shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]

                recons_loss = self.recons_loss_func(shifted_pos_pred,
                                                    shifted_pos[:, 1:])
                # recons_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos[:, 1:]).sum(-1)
                # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) + \
                #               recons_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_loss = self.args.recons_reg * torch.mean(
                #     (recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                #
                # recons_derv_loss = torch.abs(shifted_pos_pred[:, 2:] - shifted_pos_pred[:, 1:-1] -
                #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
                # recons_derv_loss = self.args.upper_body_weight * \
                #                    (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) + \
                #                    recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_derv_loss = 2. * self.args.recons_reg * \
                #                    torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
                #
                # affs_loss = torch.abs(affs[:, 1:] - affs_pred[:, 1:]).sum(-1)
                # affs_loss = self.args.affs_reg * torch.mean((affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])

                eval_loss += quat_norm_loss + quat_loss + recons_loss + affs_loss
                # eval_loss += quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
                N += quat.shape[0]

        eval_loss /= N
        self.epoch_info['mean_loss'] = eval_loss
        if self.epoch_info['mean_loss'] < self.best_loss and self.meta_info[
                'epoch'] > self.min_train_epochs:
            self.best_loss = self.epoch_info['mean_loss']
            self.best_loss_epoch = self.meta_info['epoch']
            self.loss_updated = True
        else:
            self.loss_updated = False
        self.show_epoch_info()

    def train(self):

        if self.args.load_last_best or (self.args.load_at_epoch is not None):
            model_found = self.load_model_at_epoch(
                epoch='best' if self.args.load_last_best
                else self.args.load_at_epoch)
            if not model_found and self.args.start_epoch != 'best':
                print('Warning! Model not found at the requested epoch ({}). '
                      'Trying to load best known model: '.format(
                          self.args.start_epoch), end='')
                model_found = self.load_model_at_epoch(epoch='best')
                if model_found:
                    self.args.start_epoch = self.best_loss_epoch
                    print('loaded.')
                else:
                    print('not found. Starting at epoch 0.')
                    self.args.start_epoch = 0
            else:
                self.args.start_epoch = self.best_loss_epoch
        else:
            self.args.start_epoch = 0
        for epoch in range(self.args.start_epoch, self.args.num_epoch):
            self.meta_info['epoch'] = epoch

            # training
            self.io.print_log('Training epoch: {}'.format(epoch))
            self.per_train()
            self.io.print_log('Done.')

            # evaluation
            if (epoch % self.args.eval_interval
                    == 0) or (epoch + 1 == self.args.num_epoch):
                self.io.print_log('Eval epoch: {}'.format(epoch))
                self.per_eval()
                self.io.print_log('Done.')

            # save model and weights
            if epoch > self.args.min_train_epochs and (
                    self.loss_updated or epoch % self.args.save_interval == 0):
                torch.save({'model_dict': self.model.state_dict()},
                           os.path.join(
                               self.args.work_dir,
                               'epoch_{}_loss_{:.4f}_model.pth.tar'.format(
                                   epoch, self.epoch_info['mean_loss'])))

                if self.generate_while_train:
                    self.generate_motion(load_saved_model=False,
                                         samples_to_generate=1,
                                         epoch=epoch)

    def copy_prefix(self, var, prefix_length=None):
        if prefix_length is None:
            prefix_length = self.prefix_length
        return [
            var[s, :prefix_length].unsqueeze(0) for s in range(var.shape[0])
        ]

    def generate_motion(self,
                        load_saved_model=True,
                        samples_to_generate=10,
                        epoch=None,
                        randomized=True,
                        animations_as_videos=False):

        if epoch is None:
            epoch = 'best'
        if load_saved_model:
            self.load_model_at_epoch(epoch=epoch)
        self.model.eval()
        test_loader = self.data_loader['test']

        joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, intended_emotion, intended_polarity, \
            acting_task, gender, age, handedness, \
            native_tongue = self.return_batch([samples_to_generate], test_loader, randomized=randomized)
        with torch.no_grad():
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            quat_prelude = self.quats_eos.view(1, -1).cuda() \
                .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
            quat_prelude[:, -1] = quat[:, 0].clone()
            quat_pred, quat_pred_pre_norm = self.model(
                text, intended_emotion, intended_polarity, acting_task, gender,
                age, handedness, native_tongue, quat_prelude,
                joint_lengths / scales[..., None])
            for s in range(len(quat_pred)):
                quat_pred[s] = qfix(
                    quat_pred[s].view(quat_pred[s].shape[0], self.V,
                                      -1)).view(quat_pred[s].shape[0], -1)

            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                   self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(quat_pred.shape[0],
                                            quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets),
                          dim=1).unsqueeze(1))

        animation_pred = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': pos_pred,
            'rotations': quat_pred
        }
        MocapDataset.save_as_bvh(animation_pred,
                                 dataset_name=self.dataset + '_glove',
                                 subset_name='test')
        animation = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': pos,
            'rotations': quat
        }
        MocapDataset.save_as_bvh(animation,
                                 dataset_name=self.dataset + '_glove',
                                 subset_name='gt')

        if animations_as_videos:
            pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0],
                                                     pos_pred.shape[1], -1).permute(0, 2, 1).\
                detach().cpu().numpy()
            display_animations(pos_pred_np,
                               self.joint_parents,
                               save=True,
                               dataset_name=self.dataset,
                               subset_name='epoch_' +
                               str(self.best_loss_epoch),
                               overwrite=True)
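
A minimal driver for the generation path above might look like the following sketch; `proc` stands in for an already constructed processor (the surrounding class is not shown in this example) and the argument values are illustrative only:

proc.generate_motion(load_saved_model=True,
                     samples_to_generate=5,
                     randomized=False)
# MocapDataset.save_as_bvh then writes the predictions under the
# '<dataset>_glove' dataset with subset 'test', and the ground truth
# with subset 'gt'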
Example #5
    def __init__(self,
                 args,
                 data_path,
                 data_loader,
                 Z,
                 T,
                 A,
                 V,
                 C,
                 D,
                 tag_cats,
                 IE,
                 IP,
                 AT,
                 G,
                 AGE,
                 H,
                 NT,
                 joint_names,
                 joint_parents,
                 lower_body_start=15,
                 fill=6,
                 min_train_epochs=20,
                 generate_while_train=False,
                 save_path=None):
        def get_quats_sos_and_eos():
            quats_sos_and_eos_file = os.path.join(data_path,
                                                  'quats_sos_and_eos.npz')
            keys = list(self.data_loader['train'].keys())
            num_samples = len(self.data_loader['train'])
            try:
                quats_data = np.load(quats_sos_and_eos_file,
                                     allow_pickle=True)
                mean_quats_sos = quats_data['quats_sos']
                mean_quats_eos = quats_data['quats_eos']
            except FileNotFoundError:
                mean_quats_sos = np.zeros((self.V, self.D))
                mean_quats_eos = np.zeros((self.V, self.D))
                for j in range(self.V):
                    quats_sos = np.zeros((self.D, num_samples))
                    quats_eos = np.zeros((self.D, num_samples))
                    for s in range(num_samples):
                        quats_sos[:, s] = self.data_loader['train'][
                            keys[s]]['rotations'][0, j]
                        quats_eos[:, s] = self.data_loader['train'][
                            keys[s]]['rotations'][-1, j]
                    # the average quaternion is the eigenvector of the
                    # outer-product matrix Q.Q^T with the largest eigenvalue;
                    # np.linalg.eig does not order its output, so select that
                    # eigenvector explicitly instead of assuming column 0
                    sos_eig_values, sos_eig_vectors = np.linalg.eig(
                        np.dot(quats_sos, quats_sos.T))
                    mean_quats_sos[j] = np.real(
                        sos_eig_vectors[:, np.argmax(sos_eig_values)])
                    eos_eig_values, eos_eig_vectors = np.linalg.eig(
                        np.dot(quats_eos, quats_eos.T))
                    mean_quats_eos[j] = np.real(
                        eos_eig_vectors[:, np.argmax(eos_eig_values)])
                np.savez_compressed(quats_sos_and_eos_file,
                                    quats_sos=mean_quats_sos,
                                    quats_eos=mean_quats_eos)
            mean_quats_sos = torch.from_numpy(mean_quats_sos).unsqueeze(0)
            mean_quats_eos = torch.from_numpy(mean_quats_eos).unsqueeze(0)
            for s in range(num_samples):
                pos_sos = \
                    MocapDataset.forward_kinematics(mean_quats_sos.unsqueeze(0),
                                                    torch.from_numpy(self.data_loader['train'][keys[s]]
                                                    ['positions'][0:1, 0]).double().unsqueeze(0),
                                                    self.joint_parents,
                                                    torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                                                    ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
                affs_sos = MocapDataset.get_mpi_affective_features(pos_sos)
                pos_eos = \
                    MocapDataset.forward_kinematics(mean_quats_eos.unsqueeze(0),
                                                    torch.from_numpy(self.data_loader['train'][keys[s]]
                                                    ['positions'][-1:, 0]).double().unsqueeze(0),
                                                    self.joint_parents,
                                                    torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                                                    ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
                affs_eos = MocapDataset.get_mpi_affective_features(pos_eos)
                self.data_loader['train'][keys[s]]['positions'] = \
                    np.concatenate((pos_sos, self.data_loader['train'][keys[s]]['positions'], pos_eos), axis=0)
                self.data_loader['train'][keys[s]]['affective_features'] = \
                    np.concatenate((affs_sos, self.data_loader['train'][keys[s]]['affective_features'], affs_eos),
                                   axis=0)
            return mean_quats_sos, mean_quats_eos

        self.args = args
        self.dataset = args.dataset
        self.channel_map = {
            'Xrotation': 'x',
            'Yrotation': 'y',
            'Zrotation': 'z'
        }
        self.data_loader = data_loader
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.io = IO(self.args.work_dir,
                     save_log=self.args.save_log,
                     print_log=self.args.print_log)

        # model
        self.T = T + 2
        self.T_steps = 120
        self.A = A
        self.V = V
        self.C = C
        self.D = D
        self.O = 1
        self.tag_cats = tag_cats
        self.IE = IE
        self.IP = IP
        self.AT = AT
        self.G = G
        self.AGE = AGE
        self.H = H
        self.NT = NT
        self.joint_names = joint_names
        self.joint_parents = joint_parents
        self.lower_body_start = lower_body_start
        self.quats_sos, self.quats_eos = get_quats_sos_and_eos()
        # self.quats_sos = torch.from_numpy(Quaternions.id(self.V).qs).unsqueeze(0)
        # self.quats_eos = torch.from_numpy(Quaternions.from_euler(
        #     np.tile([np.pi / 2., 0, 0], (self.V, 1))).qs).unsqueeze(0)
        self.recons_loss_func = nn.L1Loss()
        self.affs_loss_func = nn.MSELoss()
        self.best_loss = np.inf
        self.loss_updated = False
        self.step_epochs = [
            math.ceil(float(self.args.num_epoch * x)) for x in self.args.step
        ]
        self.best_loss_epoch = None
        self.min_train_epochs = min_train_epochs
        self.zfill = fill
        try:
            self.text_processor = torch.load('text_processor.pt')
        except FileNotFoundError:
            self.text_processor = tt.data.Field(
                tokenize=get_tokenizer("basic_english"),
                init_token='<sos>',
                eos_token='<eos>',
                lower=True)
            train_text, eval_text, test_text = tt.datasets.WikiText2.splits(
                self.text_processor)
            self.text_processor.build_vocab(train_text, eval_text, test_text)
        self.text_sos = np.int64(self.text_processor.vocab.stoi['<sos>'])
        self.text_eos = np.int64(self.text_processor.vocab.stoi['<eos>'])
        num_tokens = len(
            self.text_processor.vocab.stoi)  # the size of vocabulary
        self.Z = Z + 2  # embedding dimension
        num_hidden_units_enc = 200  # the dimension of the feedforward network model in nn.TransformerEncoder
        num_hidden_units_dec = 200  # the dimension of the feedforward network model in nn.TransformerDecoder
        num_layers_enc = 2  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
        num_layers_dec = 2  # the number of nn.TransformerEncoderLayer in nn.TransformerDecoder
        num_heads_enc = 2  # the number of heads in the multi-head attention in nn.TransformerEncoder
        num_heads_dec = 2  # the number of heads in the multi-head attention in nn.TransformerDecoder
        # num_hidden_units_enc = 216  # the dimension of the feedforward network model in nn.TransformerEncoder
        # num_hidden_units_dec = 512  # the dimension of the feedforward network model in nn.TransformerDecoder
        # num_layers_enc = 2  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
        # num_layers_dec = 4  # the number of nn.TransformerEncoderLayer in nn.TransformerDecoder
        # num_heads_enc = 6  # the number of heads in the multi-head attention in nn.TransformerEncoder
        # num_heads_dec = self.V  # the number of heads in the multi-head attention in nn.TransformerDecoder
        dropout = 0.2  # the dropout value
        self.model = T2GNet(num_tokens, self.T - 1, self.Z, self.V * self.D,
                            self.D, self.V - 1, self.IE, self.IP, self.AT,
                            self.G, self.AGE, self.H, self.NT, num_heads_enc,
                            num_heads_dec, num_hidden_units_enc,
                            num_hidden_units_dec, num_layers_enc,
                            num_layers_dec, dropout)
        if self.args.use_multiple_gpus and torch.cuda.device_count() > 1:
            self.args.batch_size *= torch.cuda.device_count()
            self.model = nn.DataParallel(self.model)
        self.model.to(torch.cuda.current_device())
        print('Total training data:\t\t{}'.format(
            len(self.data_loader['train'])))
        print('Total validation data:\t\t{}'.format(
            len(self.data_loader['test'])))
        print('Training with batch size:\t{}'.format(self.args.batch_size))

        # generate
        self.generate_while_train = generate_while_train
        self.save_path = save_path

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.args.base_lr,
                                       momentum=0.9,
                                       nesterov=self.args.nesterov,
                                       weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=self.args.base_lr)
            # weight_decay=self.args.weight_decay)
        else:
            raise ValueError(
                'Unsupported optimizer: {}'.format(self.args.optimizer))
        self.lr = self.args.base_lr
        self.tf = self.args.base_tr
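
The text pipeline above relies on the legacy torchtext Field API. A rough sketch of how a caption would be mapped to token ids with the vocabulary built in the constructor (the sentence is made up for illustration):

tokens = self.text_processor.preprocess('a person walks forward slowly')
token_ids = np.array([self.text_sos]
                     + [self.text_processor.vocab.stoi[t] for t in tokens]
                     + [self.text_eos], dtype=np.int64)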
Example #6
    def __init__(self,
                 args,
                 dataset,
                 data_loader,
                 T,
                 V,
                 C,
                 D,
                 A,
                 T_out,
                 joint_parents,
                 num_classes,
                 affs_max,
                 affs_min,
                 min_train_epochs=-1,
                 label_weights=None,
                 generate_while_train=False,
                 poses_mean=None,
                 poses_std=None,
                 save_path=None,
                 device='cuda:0'):

        self.args = args
        self.device = device
        self.dataset = dataset
        self.data_loader = data_loader
        self.num_classes = num_classes
        self.affs_max = affs_max
        self.affs_min = affs_min
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.io = IO(self.args.work_dir,
                     save_log=self.args.save_log,
                     print_log=self.args.print_log)

        # model
        self.T = T
        self.V = V
        self.C = C
        self.D = D
        self.A = A
        self.T_out = T_out
        self.P = int(0.9 * T)
        self.joint_parents = joint_parents
        self.model = hap.HAPPY(self.dataset,
                               T,
                               V,
                               C,
                               A,
                               T_out,
                               num_classes,
                               residual=self.args.residual)
        self.model.cuda(device)
        self.model.apply(weights_init)
        self.model_GRU_h_enc = None
        self.model_GRU_h_dec1 = None
        self.model_GRU_h_dec = None
        self.label_weights = torch.from_numpy(label_weights).cuda().float()
        self.loss = semisup_loss
        self.best_loss = math.inf
        self.best_mean_ap = 0.
        self.loss_updated = False
        self.mean_ap_updated = False
        self.step_epochs = [
            math.ceil(float(self.args.num_epoch * x)) for x in self.args.step
        ]
        self.best_loss_epoch = None
        self.best_acc_epoch = None
        self.min_train_epochs = min_train_epochs
        self.beta = 0.1

        # generate
        self.generate_while_train = generate_while_train
        self.poses_mean = poses_mean
        self.poses_std = poses_std
        self.save_path = save_path

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.args.base_lr,
                                       momentum=0.9,
                                       nesterov=self.args.nesterov,
                                       weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=self.args.base_lr)
            # weight_decay=self.args.weight_decay)
        else:
            raise ValueError(
                'Unsupported optimizer: {}'.format(self.args.optimizer))
        self.lr = self.args.base_lr
        self.tr = self.args.base_tr
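
`label_weights` must arrive as a NumPy array, since the constructor pushes it to the GPU with torch.from_numpy. One plausible way to build it, assuming multi-hot training labels (an illustration, not code from the project):

counts = train_labels.sum(axis=0)                    # positives per class
label_weights = counts.sum() / (len(counts) * np.maximum(counts, 1.0))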
Example #7
class Processor(object):
    """
        Processor for gait generation
    """
    def __init__(self,
                 args,
                 dataset,
                 data_loader,
                 T,
                 V,
                 C,
                 D,
                 A,
                 T_out,
                 joint_parents,
                 num_classes,
                 affs_max,
                 affs_min,
                 min_train_epochs=-1,
                 label_weights=None,
                 generate_while_train=False,
                 poses_mean=None,
                 poses_std=None,
                 save_path=None,
                 device='cuda:0'):

        self.args = args
        self.device = device
        self.dataset = dataset
        self.data_loader = data_loader
        self.num_classes = num_classes
        self.affs_max = affs_max
        self.affs_min = affs_min
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.io = IO(self.args.work_dir,
                     save_log=self.args.save_log,
                     print_log=self.args.print_log)

        # model
        self.T = T
        self.V = V
        self.C = C
        self.D = D
        self.A = A
        self.T_out = T_out
        self.P = int(0.9 * T)
        self.joint_parents = joint_parents
        self.model = hap.HAPPY(self.dataset,
                               T,
                               V,
                               C,
                               A,
                               T_out,
                               num_classes,
                               residual=self.args.residual)
        self.model.cuda(device)
        self.model.apply(weights_init)
        self.model_GRU_h_enc = None
        self.model_GRU_h_dec1 = None
        self.model_GRU_h_dec = None
        self.label_weights = torch.from_numpy(label_weights).cuda().float()
        self.loss = semisup_loss
        self.best_loss = math.inf
        self.best_mean_ap = 0.
        self.loss_updated = False
        self.mean_ap_updated = False
        self.step_epochs = [
            math.ceil(float(self.args.num_epoch * x)) for x in self.args.step
        ]
        self.best_loss_epoch = None
        self.best_acc_epoch = None
        self.min_train_epochs = min_train_epochs
        self.beta = 0.1

        # generate
        self.generate_while_train = generate_while_train
        self.poses_mean = poses_mean
        self.poses_std = poses_std
        self.save_path = save_path

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.args.base_lr,
                                       momentum=0.9,
                                       nesterov=self.args.nesterov,
                                       weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=self.args.base_lr)
            # weight_decay=self.args.weight_decay)
        else:
            raise ValueError(
                'Unsupported optimizer: {}'.format(self.args.optimizer))
        self.lr = self.args.base_lr
        self.tr = self.args.base_tr

    def process_data(self, data, poses, diffs, affs):
        data = data.float().to(self.device)
        poses = poses.float().to(self.device)
        diffs = diffs.float().to(self.device)
        affs = affs.float().to(self.device)
        return data, poses, diffs, affs

    def load_best_model(self):
        loaded_vars = torch.load(
            os.path.join(self.args.work_dir, 'taew_weights.pth.tar'))
        self.model.load_state_dict(loaded_vars['model_dict'])
        self.model_GRU_h_enc = loaded_vars['h_enc']
        self.model_GRU_h_dec1 = loaded_vars['h_dec1']
        self.model_GRU_h_dec = loaded_vars['h_dec']

    def adjust_lr(self):
        self.lr = self.lr * self.args.lr_decay
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def adjust_tr(self):
        self.tr = self.tr * self.args.tr_decay
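        # both schedules decay exponentially: per_train calls these once per
        # epoch, so after k epochs lr = base_lr * lr_decay ** k and
        # tr = base_tr * tr_decay ** k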

    def show_epoch_info(self, show_best=True):

        print_epochs = [
            self.best_loss_epoch if self.best_loss_epoch is not None else 0,
            self.best_acc_epoch if self.best_acc_epoch is not None else 0,
            self.best_acc_epoch if self.best_acc_epoch is not None else 0
        ]
        best_metrics = [self.best_loss, 0, self.best_mean_ap]
        i = 0
        for k, v in self.epoch_info.items():
            if show_best:
                self.io.print_log(
                    '\t{}: {}. Best so far: {} (epoch: {:d}).'.format(
                        k, v, best_metrics[i], print_epochs[i]))
            else:
                self.io.print_log('\t{}: {}.'.format(k, v))
            i += 1
        if self.args.pavi_log:
            self.io.log('train', self.meta_info['iter'], self.epoch_info)

    def show_iter_info(self):

        if self.meta_info['iter'] % self.args.log_interval == 0:
            info = '\tIter {} Done.'.format(self.meta_info['iter'])
            for k, v in self.iter_info.items():
                if isinstance(v, float):
                    info = info + ' | {}: {:.4f}'.format(k, v)
                else:
                    info = info + ' | {}: {}'.format(k, v)

            self.io.print_log(info)

            if self.args.pavi_log:
                self.io.log('train', self.meta_info['iter'], self.iter_info)

    def per_train(self):

        self.model.train()
        train_loader = self.data_loader['train']
        loss_value = []
        ap_values = []
        mean_ap_value = []

        for data, poses, rots, affs, num_frames, labels in train_loader:
            # get data
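            # clips are sorted by true length (descending) -- presumably the
            # layout packed-sequence RNN utilities expect downstream (the
            # consuming model is not shown here)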
            num_frames_actual, sort_idx = torch.sort(
                (num_frames * self.T).type(torch.IntTensor).cuda(),
                descending=True)
            seq_lens = num_frames_actual - 1
            data = data[sort_idx, :, :]
            poses = poses[sort_idx, :]
            rots = rots[sort_idx, :, :]
            affs = affs[sort_idx, :]
            data, poses, rots, affs = self.process_data(
                data, poses, rots, affs)

            # forward
            labels_pred, diffs_recons, diffs_recons_pre_norm, affs_pred,\
            self.model_GRU_h_enc, self.model_GRU_h_dec1, self.model_GRU_h_dec =\
                self.model(poses, rots[:, :, 1:], affs, teacher_steps=int(self.T * self.tr))
            loss = self.loss(labels,
                             labels_pred,
                             rots,
                             diffs_recons,
                             diffs_recons_pre_norm,
                             self.meta_info['epoch'],
                             affs,
                             affs_pred,
                             self.V,
                             self.D,
                             self.num_classes[0],
                             label_weights=self.label_weights)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip)
            self.optimizer.step()

            # statistics
            self.iter_info['loss'] = loss.data.item()
            self.iter_info['aps'], self.iter_info['mean_ap'], self.iter_info['f1'], _ =\
                calculate_metrics(labels[0].detach().cpu().numpy(),
                                  labels_pred[0].detach().cpu().numpy())
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            self.iter_info['tr'] = '{:.6f}'.format(self.tr)
            loss_value.append(self.iter_info['loss'])
            ap_values.append(self.iter_info['aps'])
            mean_ap_value.append(self.iter_info['mean_ap'])
            self.show_iter_info()
            self.meta_info['iter'] += 1

        self.epoch_info['mean_loss'] = np.mean(loss_value)
        self.epoch_info['mean_aps'] = np.mean(ap_values, axis=0)
        self.epoch_info['mean_mean_ap'] = np.mean(mean_ap_value)
        self.show_epoch_info()
        self.io.print_timer()
        self.adjust_lr()
        self.adjust_tr()

    def per_test(self, epoch=None, evaluation=True):

        self.model.eval()
        test_loader = self.data_loader['test']
        loss_value = []
        mean_ap_value = []
        ap_values = []
        label_frag = []

        for data, poses, diffs, affs, num_frames, labels in test_loader:
            # get data
            num_frames_actual, sort_idx = torch.sort(
                (num_frames * self.T).type(torch.IntTensor).cuda(),
                descending=True)
            seq_lens = num_frames_actual - 1
            data = data[sort_idx, :]
            poses = poses[sort_idx, :]
            diffs = diffs[sort_idx, :, :]
            affs = affs[sort_idx, :]
            data, poses, diffs, affs = self.process_data(
                data, poses, diffs, affs)

            # inference
            with torch.no_grad():
                labels_pred, diffs_recons, diffs_recons_pre_norm, affs_pred, _, _, _ = \
                    self.model(poses, diffs[:, :, 1:], affs,
                               teacher_steps=int(self.T * self.tr))

            # get loss
            if evaluation:
                loss = self.loss(labels,
                                 labels_pred,
                                 diffs,
                                 diffs_recons,
                                 diffs_recons_pre_norm,
                                 self.meta_info['epoch'],
                                 affs,
                                 affs_pred,
                                 self.V,
                                 self.D,
                                 self.num_classes[0],
                                 label_weights=self.label_weights,
                                 eval_time=True)
                loss_value.append(loss.item())
                ap, mean_ap, _, _ = calculate_metrics(
                    labels[0].detach().cpu().numpy(),
                    labels_pred[0].detach().cpu().numpy(),
                    eval_time=True)
                ap_values.append(ap)
                mean_ap_value.append(mean_ap)

                label_frag.append(labels[0].data.cpu().numpy())

        if evaluation:
            self.epoch_info['mean_loss'] = np.mean(loss_value)
            self.epoch_info['mean_aps'] = np.mean(ap_values, axis=0)
            self.epoch_info['mean_mean_ap'] = np.mean(mean_ap_value)
            if (self.epoch_info['mean_loss'] < self.best_loss
                    and epoch is not None and epoch > self.min_train_epochs):
                self.best_loss = self.epoch_info['mean_loss']
                self.best_loss_epoch = self.meta_info['epoch']
                self.loss_updated = True
            else:
                self.loss_updated = False
            if (self.epoch_info['mean_mean_ap'] > self.best_mean_ap
                    and epoch is not None and epoch > self.min_train_epochs):
                self.best_mean_ap = self.epoch_info['mean_mean_ap']
                self.best_acc_epoch = self.meta_info['epoch']
                self.mean_ap_updated = True
            else:
                self.mean_ap_updated = False
            self.show_epoch_info()

    def train(self):

        if self.args.load_last_best:
            self.load_best_model()
            self.args.start_epoch = self.best_loss_epoch
        for epoch in range(self.args.start_epoch, self.args.num_epoch):
            self.meta_info['epoch'] = epoch

            # training
            self.io.print_log('Training epoch: {}'.format(epoch))
            self.per_train()
            self.io.print_log('Done.')

            # evaluation
            if (epoch % self.args.eval_interval == 0
                    or epoch + 1 == self.args.num_epoch):
                self.io.print_log('Eval epoch: {}'.format(epoch))
                self.per_test(epoch)
                self.io.print_log('Done.')

            # save model and weights
            if self.loss_updated or self.mean_ap_updated:
                torch.save(
                    {
                        'model_dict': self.model.state_dict(),
                        'h_enc': self.model_GRU_h_enc,
                        'h_dec1': self.model_GRU_h_dec1,
                        'h_dec': self.model_GRU_h_dec
                    },
                    os.path.join(
                        self.args.work_dir,
                        'epoch_{}_loss_{:.4f}_acc_{:.2f}_model.pth.tar'.format(
                            epoch, self.best_loss, self.best_mean_ap * 100.)))

                if self.generate_while_train:
                    self.generate(load_saved_model=False,
                                  samples_to_generate=1)

    def test(self):

        # the path to the weights file must be specified
        if self.args.weights is None:
            raise ValueError('Please specify --weights.')
        self.io.print_log('Model:   {}.'.format(self.args.model))
        self.io.print_log('Weights: {}.'.format(self.args.weights))

        # evaluation
        self.io.print_log('Evaluation Start:')
        self.per_test()
        self.io.print_log('Done.\n')

        # save the output of model
        if self.args.save_result:
            result_dict = dict(
                zip(self.data_loader['test'].dataset.sample_name, self.result))
            self.io.save_pkl(result_dict, 'test_result.pkl')

    def evaluate_model(self, load_saved_model=True):
        if load_saved_model:
            self.load_best_model()
        self.model.eval()
        test_loader = self.data_loader['test']
        loss_value = []
        mean_ap_value = []
        ap_values = []
        label_frag = []

        for data, poses, diffs, affs, num_frames, labels in test_loader:
            # get data
            num_frames_actual, sort_idx = torch.sort(
                (num_frames * self.T).type(torch.IntTensor).cuda(),
                descending=True)
            seq_lens = num_frames_actual - 1
            data = data[sort_idx, :]
            poses = poses[sort_idx, :]
            diffs = diffs[sort_idx, :, :]
            affs = affs[sort_idx, :]
            data, poses, diffs, affs = self.process_data(
                data, poses, diffs, affs)

            # inference
            with torch.no_grad():
                labels_pred, diffs_recons, diffs_recons_pre_norm, affs_pred, _, _, _ = \
                    self.model(poses, diffs[:, :, 1:], affs,
                               teacher_steps=int(self.T * self.tr))

            # get loss
            loss = self.loss(labels,
                             labels_pred,
                             diffs,
                             diffs_recons,
                             diffs_recons_pre_norm,
                             self.meta_info['epoch'],
                             affs,
                             affs_pred,
                             self.V,
                             self.D,
                             self.num_classes[0],
                             label_weights=self.label_weights,
                             eval_time=True)
            loss_value.append(loss.item())
            ap, mean_ap, _, _ = calculate_metrics(
                labels[0].detach().cpu().numpy(),
                labels_pred[0].detach().cpu().numpy(),
                eval_time=True)
            ap_values.append(ap)
            mean_ap_value.append(mean_ap)

            label_frag.append(labels[0].data.cpu().numpy())

        self.epoch_info['mean_loss'] = np.mean(loss_value)
        self.epoch_info['aps'] = np.mean(ap_values, axis=0)
        self.epoch_info['mean_ap'] = np.mean(mean_ap_value)
        self.show_epoch_info(show_best=False)
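
A typical session with this Processor, assuming it has been constructed with the project's data loaders as in __init__ above (a sketch, not verbatim project code):

processor.train()           # checkpoints whenever loss or mean AP improves
processor.evaluate_model()  # reloads taew_weights.pth.tar, then reports metrics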
Example #8
    def __init__(self, args, dataset, data_loader, T, V, C, D, A, S,
                 joints_dict, joint_names, joint_offsets, joint_parents,
                 num_labels, prefix_length, target_length,
                 min_train_epochs=20, generate_while_train=False,
                 save_path=None, device='cuda:0'):

        self.args = args
        self.dataset = dataset
        self.mocap = MocapDataset(V, C, np.arange(V), joints_dict)
        self.joint_names = joint_names
        self.joint_offsets = joint_offsets
        self.joint_parents = joint_parents
        self.device = device
        self.data_loader = data_loader
        self.num_labels = num_labels
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.io = IO(
            self.args.work_dir,
            save_log=self.args.save_log,
            print_log=self.args.print_log)

        # model
        self.T = T
        self.V = V
        self.C = C
        self.D = D
        self.A = A
        self.S = S
        self.O = 4
        self.Z = 1
        self.RS = 1
        self.o_scale = 10.
        self.prefix_length = prefix_length
        self.target_length = target_length
        self.model = quater_emonet.QuaterEmoNet(V, D, S, A, self.O, self.Z, self.RS, num_labels[0])
        self.model.cuda(device)
        self.orient_h = None
        self.quat_h = None
        self.z_rs_loss_func = nn.L1Loss()
        self.affs_loss_func = nn.L1Loss()
        self.spline_loss_func = nn.L1Loss()
        self.best_loss = math.inf
        self.loss_updated = False
        self.mean_ap_updated = False
        self.step_epochs = [math.ceil(float(self.args.num_epoch * x)) for x in self.args.step]
        self.best_loss_epoch = None
        self.min_train_epochs = min_train_epochs

        # generate
        self.generate_while_train = generate_while_train
        self.save_path = save_path

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.args.base_lr,
                momentum=0.9,
                nesterov=self.args.nesterov,
                weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.args.base_lr)
                # weight_decay=self.args.weight_decay)
        else:
            raise ValueError(
                'Unsupported optimizer: {}'.format(self.args.optimizer))
        self.lr = self.args.base_lr
        self.tf = self.args.base_tr
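
A reading of the single-letter dimensions above, inferred from how the batch builders in the next example consume them (educated guesses, not project documentation):

# T: frames per clip        V: joints            C: position channels
# D: quaternion channels    A: affective dims    S: spline feature dims
# O = 4: root-orientation quaternion   Z = 1: root height deviation
# RS = 1: root speed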
Example #9
class Processor(object):
    """
        Processor for emotive gait generation
    """

    def __init__(self, args, dataset, data_loader, T, V, C, D, A, S,
                 joints_dict, joint_names, joint_offsets, joint_parents,
                 num_labels, prefix_length, target_length,
                 min_train_epochs=20, generate_while_train=False,
                 save_path=None, device='cuda:0'):

        self.args = args
        self.dataset = dataset
        self.mocap = MocapDataset(V, C, np.arange(V), joints_dict)
        self.joint_names = joint_names
        self.joint_offsets = joint_offsets
        self.joint_parents = joint_parents
        self.device = device
        self.data_loader = data_loader
        self.num_labels = num_labels
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.io = IO(
            self.args.work_dir,
            save_log=self.args.save_log,
            print_log=self.args.print_log)

        # model
        self.T = T
        self.V = V
        self.C = C
        self.D = D
        self.A = A
        self.S = S
        self.O = 4
        self.Z = 1
        self.RS = 1
        self.o_scale = 10.
        self.prefix_length = prefix_length
        self.target_length = target_length
        self.model = quater_emonet.QuaterEmoNet(V, D, S, A, self.O, self.Z, self.RS, num_labels[0])
        self.model.cuda(device)
        self.orient_h = None
        self.quat_h = None
        self.z_rs_loss_func = nn.L1Loss()
        self.affs_loss_func = nn.L1Loss()
        self.spline_loss_func = nn.L1Loss()
        self.best_loss = math.inf
        self.loss_updated = False
        self.mean_ap_updated = False
        self.step_epochs = [math.ceil(float(self.args.num_epoch * x)) for x in self.args.step]
        self.best_loss_epoch = None
        self.min_train_epochs = min_train_epochs

        # generate
        self.generate_while_train = generate_while_train
        self.save_path = save_path

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.args.base_lr,
                momentum=0.9,
                nesterov=self.args.nesterov,
                weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.args.base_lr)
                # weight_decay=self.args.weight_decay)
        else:
            raise ValueError(
                'Unsupported optimizer: {}'.format(self.args.optimizer))
        self.lr = self.args.base_lr
        self.tf = self.args.base_tr

    def process_data(self, data, poses, quat, trans, affs):
        data = data.float().to(self.device)
        poses = poses.float().to(self.device)
        quat = quat.float().to(self.device)
        trans = trans.float().to(self.device)
        affs = affs.float().to(self.device)
        return data, poses, quat, trans, affs

    def load_best_model(self):
        model_name, self.best_loss_epoch, self.best_loss =\
            get_best_epoch_and_loss(self.args.work_dir)
        best_model_found = False
        try:
            loaded_vars = torch.load(os.path.join(self.args.work_dir, model_name))
            self.model.load_state_dict(loaded_vars['model_dict'])
            self.orient_h = loaded_vars['orient_h']
            self.quat_h = loaded_vars['quat_h']
            best_model_found = True
        except (FileNotFoundError, IsADirectoryError):
            print('No saved model found.')
        return best_model_found

    def adjust_lr(self):
        self.lr = self.lr * self.args.lr_decay
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def adjust_tf(self):
        if self.meta_info['epoch'] > 20:
            self.tf = self.tf * self.args.tf_decay

    def show_epoch_info(self):

        print_epochs = [self.best_loss_epoch if self.best_loss_epoch is not None else 0]
        best_metrics = [self.best_loss]
        i = 0
        for k, v in self.epoch_info.items():
            self.io.print_log('\t{}: {}. Best so far: {} (epoch: {:d}).'.
                              format(k, v, best_metrics[i], print_epochs[i]))
            i += 1
        if self.args.pavi_log:
            self.io.log('train', self.meta_info['iter'], self.epoch_info)

    def show_iter_info(self):

        if self.meta_info['iter'] % self.args.log_interval == 0:
            info = '\tIter {} Done.'.format(self.meta_info['iter'])
            for k, v in self.iter_info.items():
                if isinstance(v, float):
                    info = info + ' | {}: {:.4f}'.format(k, v)
                else:
                    info = info + ' | {}: {}'.format(k, v)

            self.io.print_log(info)

            if self.args.pavi_log:
                self.io.log('train', self.meta_info['iter'], self.iter_info)

    def yield_batch(self, batch_size, dataset):
        batch_pos = np.zeros((batch_size, self.T, self.V, self.C), dtype='float32')
        batch_quat = np.zeros((batch_size, self.T, (self.V - 1) * self.D), dtype='float32')
        batch_orient = np.zeros((batch_size, self.T, self.O), dtype='float32')
        batch_z_mean = np.zeros((batch_size, self.Z), dtype='float32')
        batch_z_dev = np.zeros((batch_size, self.T, self.Z), dtype='float32')
        batch_root_speed = np.zeros((batch_size, self.T, self.RS), dtype='float32')
        batch_affs = np.zeros((batch_size, self.T, self.A), dtype='float32')
        batch_spline = np.zeros((batch_size, self.T, self.S), dtype='float32')
        batch_labels = np.zeros((batch_size, 1, self.num_labels[0]), dtype='float32')
        pseudo_passes = (len(dataset) + batch_size - 1) // batch_size

        probs = []
        for k in dataset.keys():
            if 'spline' not in dataset[k]:
                raise KeyError('No splines found. Perhaps you forgot to compute them?')
            probs.append(dataset[k]['spline'].size())
        probs = np.array(probs) / np.sum(probs)

        for p in range(pseudo_passes):
            rand_keys = np.random.choice(len(dataset), size=batch_size, replace=True, p=probs)
            for i, k in enumerate(rand_keys):
                pos = dataset[str(k)]['positions'][:self.T]
                quat = dataset[str(k)]['rotations'][:self.T, 1:]
                orient = dataset[str(k)]['rotations'][:self.T, 0]
                affs = dataset[str(k)]['affective_features'][:self.T]
                spline, phase = Spline.extract_spline_features(dataset[str(k)]['spline'])
                spline = spline[:self.T]
                phase = phase[:self.T]
                z = dataset[str(k)]['trans_and_controls'][:, 1][:self.T]
                z_mean = np.mean(z[:self.prefix_length])
                z_dev = z - z_mean
                root_speed = dataset[str(k)]['trans_and_controls'][:, -1][:self.T]
                labels = dataset[str(k)]['labels'][:self.num_labels[0]]

                batch_pos[i] = pos
                batch_quat[i] = quat.reshape(self.T, -1)
                batch_orient[i] = orient.reshape(self.T, -1)
                batch_z_mean[i] = z_mean.reshape(-1, 1)
                batch_z_dev[i] = z_dev.reshape(self.T, -1)
                batch_root_speed[i] = root_speed.reshape(self.T, 1)
                batch_affs[i] = affs
                batch_spline[i] = spline
                batch_labels[i] = np.expand_dims(labels, axis=0)
            yield batch_pos, batch_quat, batch_orient, batch_z_mean, batch_z_dev,\
                  batch_root_speed, batch_affs, batch_spline, batch_labels

    def return_batch(self, batch_size, dataset, randomized=True):
        if len(batch_size) > 1:
            rand_keys = np.copy(batch_size)
            batch_size = len(batch_size)
        else:
            batch_size = batch_size[0]
            probs = []
            for k in dataset.keys():
                if 'spline' not in dataset[k]:
                    raise KeyError('No splines found. Perhaps you forgot to compute them?')
                probs.append(dataset[k]['spline'].size())
            probs = np.array(probs) / np.sum(probs)
            if randomized:
                rand_keys = np.random.choice(len(dataset), size=batch_size, replace=False, p=probs)
            else:
                rand_keys = np.arange(batch_size)

        batch_pos = np.zeros((batch_size, self.T, self.V, self.C), dtype='float32')
        batch_quat = np.zeros((batch_size, self.T, (self.V - 1) * self.D), dtype='float32')
        batch_orient = np.zeros((batch_size, self.T, self.O), dtype='float32')
        batch_z_mean = np.zeros((batch_size, self.Z), dtype='float32')
        batch_z_dev = np.zeros((batch_size, self.T, self.Z), dtype='float32')
        batch_root_speed = np.zeros((batch_size, self.T, self.RS), dtype='float32')
        batch_affs = np.zeros((batch_size, self.T, self.A), dtype='float32')
        batch_spline = np.zeros((batch_size, self.T, self.S), dtype='float32')
        batch_labels = np.zeros((batch_size, 1, self.num_labels[0]), dtype='float32')

        for i, k in enumerate(rand_keys):
            pos = dataset[str(k)]['positions'][:self.T]
            quat = dataset[str(k)]['rotations'][:self.T, 1:]
            orient = dataset[str(k)]['rotations'][:self.T, 0]
            affs = dataset[str(k)]['affective_features'][:self.T]
            spline, phase = Spline.extract_spline_features(dataset[str(k)]['spline'])
            spline = spline[:self.T]
            phase = phase[:self.T]
            z = dataset[str(k)]['trans_and_controls'][:, 1][:self.T]
            z_mean = np.mean(z[:self.prefix_length])
            z_dev = z - z_mean
            root_speed = dataset[str(k)]['trans_and_controls'][:, -1][:self.T]
            labels = dataset[str(k)]['labels'][:self.num_labels[0]]

            batch_pos[i] = pos
            batch_quat[i] = quat.reshape(self.T, -1)
            batch_orient[i] = orient.reshape(self.T, -1)
            batch_z_mean[i] = z_mean.reshape(-1, 1)
            batch_z_dev[i] = z_dev.reshape(self.T, -1)
            batch_root_speed[i] = root_speed.reshape(self.T, 1)
            batch_affs[i] = affs
            batch_spline[i] = spline
            batch_labels[i] = np.expand_dims(labels, axis=0)

        return batch_pos, batch_quat, batch_orient, batch_z_mean, batch_z_dev,\
            batch_root_speed, batch_affs, batch_spline, batch_labels
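    # Note the calling convention of return_batch: a length-1 list such as
    # [4] draws that many samples (random keys when randomized=True), while
    # a longer list is interpreted as the explicit sample keys to fetch:
    #     batch = self.return_batch([4], self.data_loader['test'])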

    def per_train(self):

        self.model.train()
        train_loader = self.data_loader['train']
        batch_loss = 0.
        N = 0.

        for pos, quat, orient, z_mean, z_dev,\
                root_speed, affs, spline, labels in self.yield_batch(self.args.batch_size, train_loader):

            pos = torch.from_numpy(pos).cuda()
            orient = torch.from_numpy(orient).cuda()
            quat = torch.from_numpy(quat).cuda()
            z_mean = torch.from_numpy(z_mean).cuda()
            z_dev = torch.from_numpy(z_dev).cuda()
            root_speed = torch.from_numpy(root_speed).cuda()
            affs = torch.from_numpy(affs).cuda()
            spline = torch.from_numpy(spline).cuda()
            labels = torch.from_numpy(labels).cuda().repeat(1, quat.shape[1], 1)
            z_rs = torch.cat((z_dev, root_speed), dim=-1)
            quat_all = torch.cat((orient[:, self.prefix_length - 1:], quat[:, self.prefix_length - 1:]), dim=-1)

            pos_pred = pos.clone()
            orient_pred = orient.clone()
            quat_pred = quat.clone()
            z_rs_pred = z_rs.clone()
            affs_pred = affs.clone()
            spline_pred = spline.clone()
            pos_pred_all = pos.clone()
            orient_pred_all = orient.clone()
            quat_pred_all = quat.clone()
            z_rs_pred_all = z_rs.clone()
            affs_pred_all = affs.clone()
            spline_pred_all = spline.clone()
            orient_prenorm_terms = torch.zeros_like(orient_pred)
            quat_prenorm_terms = torch.zeros_like(quat_pred)

            # forward
            self.optimizer.zero_grad()
            for t in range(self.target_length):
                orient_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    z_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    self.orient_h, self.quat_h,\
                    orient_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1],\
                    quat_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1] = \
                    self.model(
                        orient_pred[:, t:self.prefix_length + t],
                        quat_pred[:, t:self.prefix_length + t],
                        z_rs_pred[:, t:self.prefix_length + t],
                        affs_pred[:, t:self.prefix_length + t],
                        spline_pred[:, t:self.prefix_length + t],
                        labels[:, t:self.prefix_length + t],
                        orient_h=None if t == 0 else self.orient_h,
                        quat_h=None if t == 0 else self.quat_h, return_prenorm=True)
                pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    spline_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    self.mocap.get_predicted_features(
                        pos_pred[:, :self.prefix_length + t],
                        pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]],
                        z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean,
                        orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1],
                        quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1])
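                # scheduled sampling: with probability (1 - tf), feed the
                # model's own prediction back as the next-step input in
                # place of the ground-truth frame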
                if np.random.uniform(size=1)[0] > self.tf:
                    pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        orient_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        z_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                    spline_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        spline_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()

            prenorm_terms = torch.cat((orient_prenorm_terms, quat_prenorm_terms), dim=-1)
            prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0], prenorm_terms.shape[1], -1, self.D)
            quat_norm_loss = self.args.quat_norm_reg * torch.mean((torch.sum(prenorm_terms ** 2, dim=-1) - 1) ** 2)

            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                torch.cat((orient_pred_all[:, self.prefix_length - 1:],
                           quat_pred_all[:, self.prefix_length - 1:]), dim=-1),
                quat_all, self.V, self.D)
            quat_loss *= self.args.quat_reg

            z_rs_loss = self.z_rs_loss_func(z_rs_pred_all[:, self.prefix_length:],
                                            z_rs[:, self.prefix_length:])
            spline_loss = self.spline_loss_func(spline_pred_all[:, self.prefix_length:],
                                                spline[:, self.prefix_length:])
            fs_loss = losses.foot_speed_loss(pos_pred, pos)
            loss_total = quat_norm_loss + quat_loss + quat_derv_loss + z_rs_loss + spline_loss + fs_loss
            loss_total.backward()
            # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
            self.optimizer.step()

            # animation_pred = {
            #     'joint_names': self.joint_names,
            #     'joint_offsets': torch.from_numpy(self.joint_offsets[1:]).
            #         float().unsqueeze(0).repeat(pos_pred_all.shape[0], 1, 1),
            #     'joint_parents': self.joint_parents,
            #     'positions': pos_pred_all,
            #     'rotations': torch.cat((orient_pred_all, quat_pred_all), dim=-1)
            # }
            # MocapDataset.save_as_bvh(animation_pred,
            #                          dataset_name=self.dataset,
            #                          subset_name='test')

            # Compute statistics
            batch_loss += loss_total.item()
            N += quat.shape[0]

            # statistics
            self.iter_info['loss'] = loss_total.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            self.iter_info['tf'] = '{:.6f}'.format(self.tf)
            self.show_iter_info()
            self.meta_info['iter'] += 1

        batch_loss = batch_loss / N
        self.epoch_info['mean_loss'] = batch_loss
        self.show_epoch_info()
        self.io.print_timer()
        self.adjust_lr()
        self.adjust_tf()

    def per_test(self):

        self.model.eval()
        test_loader = self.data_loader['test']
        valid_loss = 0.
        N = 0.

        for pos, quat, orient, z_mean, z_dev,\
                root_speed, affs, spline, labels in self.yield_batch(self.args.batch_size, test_loader):
            with torch.no_grad():
                pos = torch.from_numpy(pos).cuda()
                orient = torch.from_numpy(orient).cuda()
                quat = torch.from_numpy(quat).cuda()
                z_mean = torch.from_numpy(z_mean).cuda()
                z_dev = torch.from_numpy(z_dev).cuda()
                root_speed = torch.from_numpy(root_speed).cuda()
                affs = torch.from_numpy(affs).cuda()
                spline = torch.from_numpy(spline).cuda()
                labels = torch.from_numpy(labels).cuda().repeat(1, quat.shape[1], 1)
                z_rs = torch.cat((z_dev, root_speed), dim=-1)
                quat_all = torch.cat((orient[:, self.prefix_length - 1:], quat[:, self.prefix_length - 1:]), dim=-1)

                pos_pred = pos.clone()
                orient_pred = orient.clone()
                quat_pred = quat.clone()
                z_rs_pred = z_rs.clone()
                affs_pred = affs.clone()
                spline_pred = spline.clone()
                orient_prenorm_terms = torch.zeros_like(orient_pred)
                quat_prenorm_terms = torch.zeros_like(quat_pred)

                # forward
                for t in range(self.target_length):
                    orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\
                        quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\
                        z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\
                        self.orient_h, self.quat_h,\
                        orient_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1],\
                        quat_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1] = \
                        self.model(
                            orient_pred[:, t:self.prefix_length + t],
                            quat_pred[:, t:self.prefix_length + t],
                            z_rs_pred[:, t:self.prefix_length + t],
                            affs_pred[:, t:self.prefix_length + t],
                            spline[:, t:self.prefix_length + t],
                            labels[:, t:self.prefix_length + t],
                            orient_h=None if t == 0 else self.orient_h,
                            quat_h=None if t == 0 else self.quat_h, return_prenorm=True)
                    pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                        affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\
                        spline_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        self.mocap.get_predicted_features(
                            pos_pred[:, :self.prefix_length + t],
                            pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]],
                            z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean,
                            orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1],
                            quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1])

                prenorm_terms = torch.cat((orient_prenorm_terms, quat_prenorm_terms), dim=-1)
                prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0], prenorm_terms.shape[1], -1, self.D)
                quat_norm_loss = self.args.quat_norm_reg *\
                    torch.mean((torch.sum(prenorm_terms ** 2, dim=-1) - 1) ** 2)

                quat_loss, quat_derv_loss = losses.quat_angle_loss(
                    torch.cat((orient_pred[:, self.prefix_length - 1:],
                               quat_pred[:, self.prefix_length - 1:]), dim=-1),
                    quat_all, self.V, self.D)
                quat_loss *= self.args.quat_reg

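                # root-relative position error: subtracting the root joint
                # (index 0) removes global translation before comparing poses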
                recons_loss = self.args.recons_reg *\
                              (pos_pred[:, self.prefix_length:] - pos_pred[:, self.prefix_length:, 0:1] -
                               pos[:, self.prefix_length:] + pos[:, self.prefix_length:, 0:1]).norm()
                valid_loss += recons_loss
                N += quat.shape[0]

        valid_loss /= N
        # if self.meta_info['epoch'] > 5 and self.loss_updated:
        #     pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0], pos_pred.shape[1], -1).permute(0, 2, 1).\
        #         detach().cpu().numpy()
        #     display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True,
        #                        dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch),
        #                        overwrite=True)
        #     pos_in_np = pos_in.contiguous().view(pos_in.shape[0], pos_in.shape[1], -1).permute(0, 2, 1).\
        #         detach().cpu().numpy()
        #     display_animations(pos_in_np, self.V, self.C, self.mocap.joint_parents, save=True,
        #                        dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch) +
        #                                                               '_gt',
        #                        overwrite=True)

        self.epoch_info['mean_loss'] = valid_loss
        if self.epoch_info['mean_loss'] < self.best_loss and self.meta_info['epoch'] > self.min_train_epochs:
            self.best_loss = self.epoch_info['mean_loss']
            self.best_loss_epoch = self.meta_info['epoch']
            self.loss_updated = True
        else:
            self.loss_updated = False
        self.show_epoch_info()

    def train(self):

        if self.args.load_last_best:
            best_model_found = self.load_best_model()
            self.args.start_epoch = self.best_loss_epoch if best_model_found else 0
        for epoch in range(self.args.start_epoch, self.args.num_epoch):
            self.meta_info['epoch'] = epoch

            # training
            self.io.print_log('Training epoch: {}'.format(epoch))
            self.per_train()
            self.io.print_log('Done.')

            # evaluation
            if (epoch % self.args.eval_interval == 0) or (
                    epoch + 1 == self.args.num_epoch):
                self.io.print_log('Eval epoch: {}'.format(epoch))
                self.per_test()
                self.io.print_log('Done.')

            # save model and weights
            if self.loss_updated:
                torch.save({'model_dict': self.model.state_dict(),
                            'orient_h': self.orient_h,
                            'quat_h': self.quat_h},
                           os.path.join(self.args.work_dir, 'epoch_{}_loss_{:.4f}_model.pth.tar'.
                                        format(epoch, self.best_loss)))

                if self.generate_while_train:
                    self.generate_motion(load_saved_model=False, samples_to_generate=1)

    def copy_prefix(self, var, prefix_length=None):
        if prefix_length is None:
            prefix_length = self.prefix_length
        return [var[s, :prefix_length].unsqueeze(0) for s in range(var.shape[0])]

    def generate_linear_trajectory(self, traj, alpha=0.001):
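        # Extrapolate the last prefix step far ahead of the character:
        # dividing the final step vector by a small alpha (default 0.001)
        # places a single distant marker along the current heading.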
        traj_markers = (traj[:, self.prefix_length - 2] +
                        (traj[:, self.prefix_length - 1] - traj[:, self.prefix_length - 2]) / alpha).unsqueeze(1)
        return traj_markers

    def generate_circular_trajectory(self, traj, alpha=5., num_segments=10):
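        # Trace an approximately circular loop: rotate a scaled version of
        # the last prefix step about the vertical (y) axis by
        # 2*pi/num_segments at each step, accumulating one marker per segment.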
        last_segment = alpha * traj[:, self.prefix_length - 1:self.prefix_length] -\
                       traj[:, self.prefix_length - 2:self.prefix_length - 1]
        last_marker = traj[:, self.prefix_length - 1:self.prefix_length]
        traj_markers = last_marker.clone()
        angle_per_segment = 2. * np.pi / num_segments
        for _ in range(num_segments):
            next_segment = qrot(expmap_to_quaternion(
                torch.tensor([0, -angle_per_segment, 0]).cuda().float().repeat(
                    last_segment.shape[0], last_segment.shape[1], 1)), torch.cat((
                last_segment[..., 0:1],
                torch.zeros_like(last_segment[..., 0:1]),
                last_segment[..., 1:]), dim=-1))[..., [0, 2]]
            next_marker = next_segment + last_marker
            traj_markers = torch.cat((traj_markers, next_marker), dim=1)
            last_segment = next_segment.clone()
            last_marker = next_marker.clone()
        traj_markers = traj_markers[:, 1:]
        return traj_markers

    def compute_next_traj_point(self, traj, traj_marker, rs_pred):
        tangent = traj_marker - traj
        # keep the last dim when normalizing so the division broadcasts over
        # the (x, z) components, matching the other normalizations below
        tangent /= (torch.norm(tangent, dim=-1)[..., None] + 1e-9)
        # advance along the unit tangent by the predicted root speed
        return tangent * rs_pred + traj

    def compute_next_traj_point_sans_markers(self, pos_last, quat_next, z_pred, rs_pred):
        # pos_next = torch.zeros_like(pos_last)
        offsets = torch.from_numpy(self.mocap.joint_offsets).cuda().float(). \
            unsqueeze(0).unsqueeze(0).repeat(pos_last.shape[0], pos_last.shape[1], 1, 1)
        pos_next = MocapDataset.forward_kinematics(quat_next.contiguous().view(quat_next.shape[0],
                                                                               quat_next.shape[1], -1, self.D),
                                                   pos_last[:, :, 0],
                                                   self.joint_parents,
                                                   torch.from_numpy(self.joint_offsets).float().cuda())
        # for joint in range(1, self.V):
        #     pos_next[:, :, joint] = qrot(quat_copy[:, :, joint - 1], offsets[:, :, joint]) \
        #                             + pos_next[:, :, self.mocap.joint_parents[joint]]
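        # Estimate the facing direction on the ground plane: the cross
        # product of the two root->shoulder vectors points along the body's
        # forward axis; keep only its (x, z) components and normalize.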
        root = pos_next[:, :, 0]
        l_shoulder = pos_next[:, :, 18]
        r_shoulder = pos_next[:, :, 25]
        facing = torch.cross(l_shoulder - root, r_shoulder - root, dim=-1)[..., [0, 2]]
        facing /= (torch.norm(facing, dim=-1)[..., None] + 1e-9)
        return rs_pred * facing + pos_last[:, :, 0, [0, 2]]

    def get_diff_from_traj(self, pos_pred, traj_pred, s):
        root = pos_pred[s][:, :, 0]
        l_shoulder = pos_pred[s][:, :, 18]
        r_shoulder = pos_pred[s][:, :, 25]
        facing = torch.cross(l_shoulder - root, r_shoulder - root, dim=-1)[..., [0, 2]]
        facing /= (torch.norm(facing, dim=-1)[..., None] + 1e-9)
        tangents = traj_pred[s][:, 1:] - traj_pred[s][:, :-1]
        tangent_norms = torch.norm(tangents, dim=-1)
        tangents /= (tangent_norms[..., None] + 1e-9)
        tangents = torch.cat((torch.zeros_like(tangents[:, 0:1]), tangents), dim=1)
        tangent_norms = torch.cat((torch.zeros_like(tangent_norms[:, 0:1]), tangent_norms), dim=1)
        axis_diff = torch.cross(torch.cat((facing[..., 0:1],
                                           torch.zeros_like(facing[..., 0:1]),
                                           facing[..., 1:]), dim=-1),
                                torch.cat((tangents[..., 0:1],
                                           torch.zeros_like(tangents[..., 0:1]),
                                           tangents[..., 1:]), dim=-1))
        axis_diff_norms = torch.norm(axis_diff, dim=-1)
        axis_diff /= (axis_diff_norms[..., None] + 1e-9)
        angle_diff = torch.acos(torch.einsum('ijk,ijk->ij', facing, tangents).clamp(min=-1., max=1.))
        angle_diff[tangent_norms < 1e-6] = 0.
        return axis_diff, angle_diff

    def rotate_gaits(self, orient_pred, quat_pred, quat_diff, head_tilt, l_shoulder_slouch, r_shoulder_slouch):
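        # Apply posture offsets by quaternion multiplication: head_tilt to
        # joint 14, the shoulder slouch rotations (and their inverses) to
        # joint pairs 16/17 and 23/24, and quat_diff to the root orientation.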
        quat_reshape = quat_pred.contiguous().view(quat_pred.shape[0], quat_pred.shape[1], -1, self.D).clone()
        quat_reshape[..., 14, :] = qmul(torch.from_numpy(head_tilt).cuda().float(),
                                        quat_reshape[..., 14, :])
        quat_reshape[..., 16, :] = qmul(torch.from_numpy(l_shoulder_slouch).cuda().float(),
                                        quat_reshape[..., 16, :])
        quat_reshape[..., 17, :] = qmul(torch.from_numpy(qinv(l_shoulder_slouch)).cuda().float(),
                                        quat_reshape[..., 17, :])
        quat_reshape[..., 23, :] = qmul(torch.from_numpy(r_shoulder_slouch).cuda().float(),
                                        quat_reshape[..., 23, :])
        quat_reshape[..., 24, :] = qmul(torch.from_numpy(qinv(r_shoulder_slouch)).cuda().float(),
                                        quat_reshape[..., 24, :])
        return qmul(quat_diff, orient_pred), quat_reshape.contiguous().view(quat_reshape.shape[0],
                                                                            quat_reshape.shape[1], -1)

    def generate_motion(self, load_saved_model=True, samples_to_generate=1534, max_steps=300, randomized=True):

        if load_saved_model:
            self.load_best_model()
        self.model.eval()
        test_loader = self.data_loader['test']

        pos, quat, orient, z_mean, z_dev, \
        root_speed, affs, spline, labels = self.return_batch([samples_to_generate], test_loader, randomized=randomized)
        pos = torch.from_numpy(pos).cuda()
        traj = pos[:, :, 0, [0, 2]].clone()
        orient = torch.from_numpy(orient).cuda()
        quat = torch.from_numpy(quat).cuda()
        z_mean = torch.from_numpy(z_mean).cuda()
        z_dev = torch.from_numpy(z_dev).cuda()
        root_speed = torch.from_numpy(root_speed).cuda()
        affs = torch.from_numpy(affs).cuda()
        spline = torch.from_numpy(spline).cuda()
        z_rs = torch.cat((z_dev, root_speed), dim=-1)
        quat_all = torch.cat((orient[:, self.prefix_length - 1:], quat[:, self.prefix_length - 1:]), dim=-1)
        labels = np.tile(labels, (1, max_steps + self.prefix_length, 1))

        # Begin for transition
        # traj[:, self.prefix_length - 2] = torch.tensor([-0.208, 4.8]).cuda().float()
        # traj[:, self.prefix_length - 1] = torch.tensor([-0.204, 5.1]).cuda().float()
        # final_emo_idx = int(max_steps/2)
        # labels[:, final_emo_idx:] = np.array([1., 0., 0., 0.])
        # labels[:, :final_emo_idx + 1] = np.linspace(labels[:, 0], labels[:, final_emo_idx],
        #                                             num=final_emo_idx + 1, axis=1)
        # End for transition
        labels = torch.from_numpy(labels).cuda()

        # traj_np = traj_markers.detach().cpu().numpy()
        # import matplotlib.pyplot as plt
        # plt.plot(traj_np[6, :, 0], traj_np[6, :, 1])
        # plt.show()

        happy_idx = [25, 295, 390, 667, 1196]
        sad_idx = [169, 184, 258, 948, 974]
        angry_idx = [89, 93, 96, 112, 289, 290, 978]
        neutral_idx = [72, 106, 143, 237, 532, 747, 1177]
        sample_idx = np.squeeze(np.concatenate((happy_idx, sad_idx, angry_idx, neutral_idx)))

        ## CHANGE HERE
        # scene_corners = torch.tensor([[149.862, 50.833],
        #                               [149.862, 36.81],
        #                               [161.599, 36.81],
        #                               [161.599, 50.833]]).cuda().float()
        # character_heights = torch.tensor([0.95, 0.88, 0.86, 0.90, 0.95, 0.82]).cuda().float()
        # num_characters_per_side = torch.tensor([2, 3, 2, 3]).cuda().int()
        # traj_markers, traj_offsets, character_scale =\
        #     generate_trajectories(scene_corners, z_mean, character_heights,
        #                           num_characters_per_side, traj[:, :self.prefix_length])
        # num_characters_per_side = torch.tensor([4, 0, 0, 0]).cuda().int()
        # traj_markers, traj_offsets, character_scale =\
        #     generate_simple_trajectories(scene_corners, z_mean[:4], z_mean[:4],
        #                                  num_characters_per_side, traj[sample_idx, :self.prefix_length])
        # traj_markers, traj_offsets, character_scale =\
        #     generate_rvo_trajectories(scene_corners, z_mean[:4], z_mean[:4],
        #                               num_characters_per_side, traj[sample_idx, :self.prefix_length])

        # traj[sample_idx, :self.prefix_length] += traj_offsets
        # pos_sampled = pos[sample_idx].clone()
        # pos_sampled[:, :self.prefix_length, :, [0, 2]] += traj_offsets.unsqueeze(2).repeat(1, 1, self.V, 1)
        # pos[sample_idx] = pos_sampled
        # traj_markers = self.generate_linear_trajectory(traj)

        pos_pred = self.copy_prefix(pos)
        traj_pred = self.copy_prefix(traj)
        orient_pred = self.copy_prefix(orient)
        quat_pred = self.copy_prefix(quat)
        z_rs_pred = self.copy_prefix(z_rs)
        affs_pred = self.copy_prefix(affs)
        spline_pred = self.copy_prefix(spline)
        labels_pred = self.copy_prefix(labels, prefix_length=max_steps + self.prefix_length)

        # forward
        elapsed_time = np.zeros(len(sample_idx))
        for counter, s in enumerate(sample_idx):  # range(samples_to_generate):
            start_time = time.time()
            orient_h_copy = self.orient_h.clone()
            quat_h_copy = self.quat_h.clone()
            ## CHANGE HERE
            num_markers = max_steps + self.prefix_length + 1
            # num_markers = traj_markers[s].shape[0]
            marker_idx = 0
            t = -1
            with torch.no_grad():
                while marker_idx < num_markers:
                    t += 1
                    if t > max_steps:
                        print('Sample: {}. Did not reach end in {} steps.'.format(s, max_steps), end='')
                        break
                    pos_pred[s] = torch.cat((pos_pred[s], torch.zeros_like(pos_pred[s][:, -1:])), dim=1)
                    traj_pred[s] = torch.cat((traj_pred[s], torch.zeros_like(traj_pred[s][:, -1:])), dim=1)
                    orient_pred[s] = torch.cat((orient_pred[s], torch.zeros_like(orient_pred[s][:, -1:])), dim=1)
                    quat_pred[s] = torch.cat((quat_pred[s], torch.zeros_like(quat_pred[s][:, -1:])), dim=1)
                    z_rs_pred[s] = torch.cat((z_rs_pred[s], torch.zeros_like(z_rs_pred[s][:, -1:])), dim=1)
                    affs_pred[s] = torch.cat((affs_pred[s], torch.zeros_like(affs_pred[s][:, -1:])), dim=1)
                    spline_pred[s] = torch.cat((spline_pred[s], torch.zeros_like(spline_pred[s][:, -1:])), dim=1)

                    orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    orient_h_copy, quat_h_copy = \
                        self.model(
                            orient_pred[s][:, t:self.prefix_length + t],
                            quat_pred[s][:, t:self.prefix_length + t],
                            z_rs_pred[s][:, t:self.prefix_length + t],
                            affs_pred[s][:, t:self.prefix_length + t],
                            spline_pred[s][:, t:self.prefix_length + t],
                            labels_pred[s][:, t:self.prefix_length + t],
                            orient_h=None if t == 0 else orient_h_copy,
                            quat_h=None if t == 0 else quat_h_copy, return_prenorm=False)

                    traj_curr = traj_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t].clone()
                    # root_speed = torch.norm(
                    #     pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] - \
                    #     pos_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t, 0], dim=-1)

                    ## CHANGE HERE
                    # traj_next = \
                    #     self.compute_next_traj_point(
                    #         traj_curr,
                    #         traj_markers[s, marker_idx],
                    #         o_z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 2])
                    try:
                        traj_next = traj[s, self.prefix_length + t]
                    except IndexError:
                        traj_next = \
                            self.compute_next_traj_point_sans_markers(
                                pos_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t],
                                quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1],
                                z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0],
                                z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 1])

                    pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    affs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    spline_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        self.mocap.get_predicted_features(
                            pos_pred[s][:, :self.prefix_length + t],
                            traj_next,
                            z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean[s:s + 1],
                            orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1],
                            quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1])

                    # min_speed_pred = torch.min(torch.cat((lf_speed_pred.unsqueeze(-1),
                    #                                        rf_speed_pred.unsqueeze(-1)), dim=-1), dim=-1)[0]
                    # if min_speed_pred - diff_speeds_mean[s] - diff_speeds_std[s] < 0.:
                    #     root_speed_pred = 0.
                    # else:
                    #     root_speed_pred = o_z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 2]
                    #

                    ## CHANGE HERE
                    # traj_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    #     self.compute_next_traj_point(
                    #         traj_curr,
                    #         traj_markers[s, marker_idx],
                    #         root_speed_pred)
                    # if torch.norm(traj_next - traj_curr, dim=-1).squeeze() >= \
                    #         torch.norm(traj_markers[s, marker_idx] - traj_curr, dim=-1).squeeze():
                    #     marker_idx += 1
                    traj_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = traj_next
                    marker_idx += 1
                    pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    affs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \
                    spline_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        self.mocap.get_predicted_features(
                            pos_pred[s][:, :self.prefix_length + t],
                            pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]],
                            z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean[s:s + 1],
                            orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1],
                            quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1])
                    print('Sample: {}. Steps: {}'.format(s, t), end='\r')
            print()

            # shift = torch.zeros((1, scene_corners.shape[1] + 1)).cuda().float()
            # shift[..., [0, 2]] = scene_corners[0]
            # pos_pred[s] = (pos_pred[s] - shift) / character_scale + shift
            # pos_pred_np = pos_pred[s].contiguous().view(pos_pred[s].shape[0],
            #                                             pos_pred[s].shape[1], -1).permute(0, 2, 1).\
            #     detach().cpu().numpy()
            # display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True,
            #                    dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch),
            #                    save_file_names=[str(s).zfill(6)],
            #                    overwrite=True)

            # plt.cla()
            # fig, (ax1, ax2) = plt.subplots(2, 1)
            # ax1.plot(root_speeds[s])
            # ax1.plot(lf_speeds[s])
            # ax1.plot(rf_speeds[s])
            # ax1.plot(min_speeds[s] - root_speeds[s])
            # ax1.legend(['root', 'left', 'right', 'diff'])
            # ax2.plot(root_speeds_pred)
            # ax2.plot(lf_speeds_pred)
            # ax2.plot(rf_speeds_pred)
            # ax2.plot(min_speeds_pred - root_speeds_pred)
            # ax2.legend(['root', 'left', 'right', 'diff'])
            # plt.show()

            head_tilt = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1))
            l_shoulder_slouch = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1))
            r_shoulder_slouch = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1))
            head_tilt = Quaternions.from_euler(head_tilt, order='xyz').qs
            l_shoulder_slouch = Quaternions.from_euler(l_shoulder_slouch, order='xyz').qs
            r_shoulder_slouch = Quaternions.from_euler(r_shoulder_slouch, order='xyz').qs

            # Begin for aligning facing direction to trajectory
            axis_diff, angle_diff = self.get_diff_from_traj(pos_pred, traj_pred, s)
            angle_thres = 0.3
            # angle_thres = torch.max(angle_diff[:, 1:self.prefix_length])
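            # Deviations below angle_thres are zeroed so only noticeable
            # drift from the trajectory triggers re-orientation.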
            angle_diff[angle_diff <= angle_thres] = 0.
            angle_diff[:, self.prefix_length] = 0.
            # End for aligning facing direction to trajectory
            # pos_copy, quat_copy = self.rotate_gaits(pos_pred, quat_pred, quat_diff,
            #                                         head_tilt, l_shoulder_slouch, r_shoulder_slouch, s)
            # pos_pred[s] = pos_copy.clone()
            # angle_diff_intermediate = self.get_diff_from_traj(pos_pred, traj_pred, s)
            # if torch.max(angle_diff_intermediate[:, self.prefix_length:]) > np.pi / 2.:
            #     quat_diff = Quaternions.from_angle_axis(-angle_diff.cpu().numpy(), np.array([0, 1, 0])).qs
            #     pos_copy, quat_copy = self.rotate_gaits(pos_pred, quat_pred, quat_diff,
            #                                         head_tilt, l_shoulder_slouch, r_shoulder_slouch, s)
            # pos_pred[s] = pos_copy.clone()
            # axis_diff = torch.zeros_like(axis_diff)
            # axis_diff[..., 1] = 1.
            # angle_diff = torch.zeros_like(angle_diff)
            quat_diff = torch.from_numpy(Quaternions.from_angle_axis(
                angle_diff.cpu().numpy(), axis_diff.cpu().numpy()).qs).cuda().float()
            orient_pred[s], quat_pred[s] = self.rotate_gaits(orient_pred[s], quat_pred[s],
                                                             quat_diff, head_tilt,
                                                             l_shoulder_slouch, r_shoulder_slouch)

            if labels_pred[s][:, 0, 0] > 0.5:
                label_dir = 'happy'
            elif labels_pred[s][:, 0, 1] > 0.5:
                label_dir = 'sad'
            elif labels_pred[s][:, 0, 2] > 0.5:
                label_dir = 'angry'
            else:
                label_dir = 'neutral'

            ## CHANGE HERE
            # pos_pred[s] = pos_pred[s][:, self.prefix_length + 5:]
            # o_z_rs_pred[s] = o_z_rs_pred[s][:, self.prefix_length + 5:]
            # quat_pred[s] = quat_pred[s][:, self.prefix_length + 5:]

            traj_pred_np = pos_pred[s][0, :, 0].cpu().numpy()

            save_file_name = '{:06}_{:.2f}_{:.2f}_{:.2f}_{:.2f}'.format(s,
                                                                        labels_pred[s][0, 0, 0],
                                                                        labels_pred[s][0, 0, 1],
                                                                        labels_pred[s][0, 0, 2],
                                                                        labels_pred[s][0, 0, 3])

            animation_pred = {
                'joint_names': self.joint_names,
                'joint_offsets': torch.from_numpy(self.joint_offsets[1:]).float().unsqueeze(0).repeat(
                    len(pos_pred), 1, 1),
                'joint_parents': self.joint_parents,
                'positions': pos_pred[s],
                'rotations': torch.cat((orient_pred[s], quat_pred[s]), dim=-1)
            }
            self.mocap.save_as_bvh(animation_pred,
                                   dataset_name=self.dataset,
                                   # subset_name='epoch_' + str(self.best_loss_epoch),
                                   # save_file_names=[str(s).zfill(6)])
                                   subset_name=os.path.join('no_aff_epoch_' + str(self.best_loss_epoch),
                                                            str(counter).zfill(2) + '_' + label_dir),
                                   save_file_names=['root'])
            end_time = time.time()
            elapsed_time[counter] = end_time - start_time
            print('Elapsed Time: {}'.format(elapsed_time[counter]))

            # display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True,
            #                    dataset_name=self.dataset,
            #                    # subset_name='epoch_' + str(self.best_loss_epoch),
            #                    # save_file_names=[str(s).zfill(6)],
            #                    subset_name=os.path.join('epoch_' + str(self.best_loss_epoch), label_dir),
            #                    save_file_names=[save_file_name],
            #                    overwrite=True)
        print('Mean Elapsed Time: {}'.format(np.mean(elapsed_time)))
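
The generation loop above extends each sample one frame at a time: every step feeds the most recent prefix_length frames back into the model, appends the newly predicted frame, and slides the window forward. Below is a minimal sketch of that autoregressive pattern; TinyPredictor, its GRU, and the feature sizes are hypothetical stand-ins, not the actual QuaterEmoNet interface.

import torch
import torch.nn as nn

class TinyPredictor(nn.Module):
    """Hypothetical stand-in for the recurrent pose model."""
    def __init__(self, feat_dim=8, hidden_dim=16):
        super().__init__()
        self.gru = nn.GRU(feat_dim, hidden_dim, batch_first=True)
        self.out = nn.Linear(hidden_dim, feat_dim)

    def forward(self, window, h=None):
        y, h = self.gru(window, h)
        return self.out(y[:, -1:]), h  # one predicted frame + hidden state

prefix_length, max_steps, feat_dim = 4, 10, 8
seq = torch.randn(1, prefix_length, feat_dim)  # ground-truth prefix
model, h = TinyPredictor(feat_dim), None
with torch.no_grad():
    for t in range(max_steps):
        # feed the last prefix_length frames, append the prediction
        frame, h = model(seq[:, t:prefix_length + t], h)
        seq = torch.cat((seq, frame), dim=1)
print(seq.shape)  # torch.Size([1, 14, 8])
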
class Processor(object):
    """
        Processor for gait generation
    """
    def __init__(self,
                 args,
                 dataset,
                 data_loader,
                 T,
                 V,
                 C,
                 D,
                 A,
                 S,
                 joint_parents,
                 num_labels,
                 prefix_length,
                 target_length,
                 min_train_epochs=-1,
                 generate_while_train=False,
                 save_path=None,
                 device='cuda:0'):

        self.args = args
        self.dataset = dataset
        self.mocap = MocapDataset(V, C, joint_parents)
        self.device = device
        self.data_loader = data_loader
        self.num_labels = num_labels
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.io = IO(self.args.work_dir,
                     save_log=self.args.save_log,
                     print_log=self.args.print_log)

        # model
        self.T = T
        self.V = V
        self.C = C
        self.D = D
        self.A = A
        self.S = S
        self.O = 1
        self.PRS = 2
        self.prefix_length = prefix_length
        self.target_length = target_length
        self.joint_parents = joint_parents
        self.model = quater_emonet.QuaterEmoNet(V, D, S, A, self.O,
                                                num_labels[0], self.PRS)
        self.model.cuda(device)
        self.quat_h = None
        self.p_rs_loss_func = nn.L1Loss()
        self.affs_loss_func = nn.L1Loss()
        self.best_loss = math.inf
        self.best_mean_ap = 0.
        self.loss_updated = False
        self.mean_ap_updated = False
        self.step_epochs = [
            math.ceil(float(self.args.num_epoch * x)) for x in self.args.step
        ]
        self.best_loss_epoch = None
        self.best_acc_epoch = None
        self.min_train_epochs = min_train_epochs

        # generate
        self.generate_while_train = generate_while_train
        self.save_path = save_path

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.args.base_lr,
                                       momentum=0.9,
                                       nesterov=self.args.nesterov,
                                       weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=self.args.base_lr)
            # weight_decay=self.args.weight_decay)
        else:
            raise ValueError()
        self.lr = self.args.base_lr
        self.tf = self.args.base_tr

    def process_data(self, data, poses, quat, trans, affs):
        data = data.float().to(self.device)
        poses = poses.float().to(self.device)
        quat = quat.float().to(self.device)
        trans = trans.float().to(self.device)
        affs = affs.float().to(self.device)
        return data, poses, quat, trans, affs

    def load_best_model(self):
        if self.best_loss_epoch is None:
            model_name, self.best_loss_epoch, self.best_loss, self.best_mean_ap =\
                get_best_epoch_and_loss(self.args.work_dir)
        else:
            # reconstruct the checkpoint name from the format used in train()
            model_name = 'epoch_{}_loss_{:.4f}_acc_{:.2f}_model.pth.tar'.format(
                self.best_loss_epoch, self.best_loss, self.best_mean_ap * 100.)
        # load model
        loaded_vars = torch.load(os.path.join(self.args.work_dir, model_name))

    def adjust_lr(self):
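        # Exponential decay: shrink the learning rate by lr_decay on every
        # call (invoked once per training epoch in per_train)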
        self.lr = self.lr * self.args.lr_decay
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def adjust_tf(self):
        if self.meta_info['epoch'] > 20:
            self.tf = self.tf * self.args.tf_decay

    def show_epoch_info(self):

        print_epochs = [
            self.best_loss_epoch if self.best_loss_epoch is not None else 0,
            self.best_acc_epoch if self.best_acc_epoch is not None else 0,
            self.best_acc_epoch if self.best_acc_epoch is not None else 0
        ]
        best_metrics = [self.best_loss, 0, self.best_mean_ap]
        i = 0
        for k, v in self.epoch_info.items():
            self.io.print_log(
                '\t{}: {}. Best so far: {} (epoch: {:d}).'.format(
                    k, v, best_metrics[i], print_epochs[i]))
            i += 1
        if self.args.pavi_log:
            self.io.log('train', self.meta_info['iter'], self.epoch_info)

    def show_iter_info(self):

        if self.meta_info['iter'] % self.args.log_interval == 0:
            info = '\tIter {} Done.'.format(self.meta_info['iter'])
            for k, v in self.iter_info.items():
                if isinstance(v, float):
                    info = info + ' | {}: {:.4f}'.format(k, v)
                else:
                    info = info + ' | {}: {}'.format(k, v)

            self.io.print_log(info)

            if self.args.pavi_log:
                self.io.log('train', self.meta_info['iter'], self.iter_info)

    def yield_batch(self, batch_size, dataset):
        batch_pos = np.zeros((batch_size, self.T, self.V, self.C),
                             dtype='float32')
        batch_quat = np.zeros((batch_size, self.T, (self.V - 1) * self.D),
                              dtype='float32')
        batch_orient = np.zeros((batch_size, self.T, self.O), dtype='float32')
        batch_affs = np.zeros((batch_size, self.T, self.A), dtype='float32')
        batch_spline = np.zeros((batch_size, self.T, self.S), dtype='float32')
        batch_phase_and_root_speed = np.zeros((batch_size, self.T, self.PRS),
                                              dtype='float32')
        batch_labels = np.zeros((batch_size, 1, self.num_labels[0]),
                                dtype='float32')
        pseudo_passes = (len(dataset) + batch_size - 1) // batch_size
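        # Yield ceil(len(dataset) / batch_size) batches per epoch; each batch
        # samples keys with replacement, weighted by spline size (probs below)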

        probs = []
        for k in dataset.keys():
            if 'spline' not in dataset[k]:
                raise KeyError(
                    'No splines found. Perhaps you forgot to compute them?')
            probs.append(dataset[k]['spline'].size())
        probs = np.array(probs) / np.sum(probs)

        for p in range(pseudo_passes):
            rand_keys = np.random.choice(len(dataset),
                                         size=batch_size,
                                         replace=True,
                                         p=probs)
            for i, k in enumerate(rand_keys):
                pos = dataset[str(k)]['positions_world']
                quat = dataset[str(k)]['rotations']
                orient = dataset[str(k)]['orientations']
                affs = dataset[str(k)]['affective_features']
                spline, phase = Spline.extract_spline_features(
                    dataset[str(k)]['spline'])
                root_speed = dataset[str(k)]['trans_and_controls'][:,
                                                                   -1].reshape(
                                                                       -1, 1)
                labels = dataset[str(k)]['labels'][:self.num_labels[0]]

                batch_pos[i] = pos
                batch_quat[i] = quat.reshape(self.T, -1)
                batch_orient[i] = orient.reshape(self.T, -1)
                batch_affs[i] = affs
                batch_spline[i] = spline
                batch_phase_and_root_speed[i] = np.concatenate(
                    (phase, root_speed), axis=-1)
                batch_labels[i] = np.expand_dims(labels, axis=0)
            yield batch_pos, batch_quat, batch_orient, batch_affs, batch_spline,\
                  batch_phase_and_root_speed / np.pi, batch_labels

    def return_batch(self, batch_size, dataset):
        if len(batch_size) > 1:
            rand_keys = np.copy(batch_size)
            batch_size = len(batch_size)
        else:
            batch_size = batch_size[0]
            probs = []
            for k in dataset.keys():
                if 'spline' not in dataset[k]:
                    raise KeyError(
                        'No splines found. Perhaps you forgot to compute them?'
                    )
                probs.append(dataset[k]['spline'].size())
            probs = np.array(probs) / np.sum(probs)
            rand_keys = np.random.choice(len(dataset),
                                         size=batch_size,
                                         replace=False,
                                         p=probs)

        batch_pos = np.zeros((batch_size, self.T, self.V, self.C),
                             dtype='float32')
        batch_traj = np.zeros((batch_size, self.T, self.C), dtype='float32')
        batch_quat = np.zeros((batch_size, self.T, (self.V - 1) * self.D),
                              dtype='float32')
        batch_orient = np.zeros((batch_size, self.T, self.O), dtype='float32')
        batch_affs = np.zeros((batch_size, self.T, self.A), dtype='float32')
        batch_spline = np.zeros((batch_size, self.T, self.S), dtype='float32')
        batch_phase_and_root_speed = np.zeros((batch_size, self.T, self.PRS),
                                              dtype='float32')
        batch_labels = np.zeros((batch_size, 1, self.num_labels[0]),
                                dtype='float32')

        for i, k in enumerate(rand_keys):
            pos = dataset[str(k)]['positions_world']
            traj = dataset[str(k)]['trajectory']
            quat = dataset[str(k)]['rotations']
            orient = dataset[str(k)]['orientations']
            affs = dataset[str(k)]['affective_features']
            spline, phase = Spline.extract_spline_features(
                dataset[str(k)]['spline'])
            root_speed = dataset[str(k)]['trans_and_controls'][:, -1].reshape(
                -1, 1)
            labels = dataset[str(k)]['labels'][:self.num_labels[0]]

            batch_pos[i] = pos
            batch_traj[i] = traj
            batch_quat[i] = quat.reshape(self.T, -1)
            batch_orient[i] = orient.reshape(self.T, -1)
            batch_affs[i] = affs
            batch_spline[i] = spline
            batch_phase_and_root_speed[i] = np.concatenate((phase, root_speed),
                                                           axis=-1)
            batch_labels[i] = np.expand_dims(labels, axis=0)

        return batch_pos, batch_traj, batch_quat, batch_orient, batch_affs, batch_spline,\
               batch_phase_and_root_speed, batch_labels

    def per_train(self):

        self.model.train()
        train_loader = self.data_loader['train']
        batch_loss = 0.
        N = 0.

        for pos, quat, orient, affs, spline, p_rs, labels in self.yield_batch(
                self.args.batch_size, train_loader):

            pos = torch.from_numpy(pos).cuda()
            quat = torch.from_numpy(quat).cuda()
            orient = torch.from_numpy(orient).cuda()
            affs = torch.from_numpy(affs).cuda()
            spline = torch.from_numpy(spline).cuda()
            p_rs = torch.from_numpy(p_rs).cuda()
            labels = torch.from_numpy(labels).cuda()

            pos_pred = pos.clone()
            quat_pred = quat.clone()
            p_rs_pred = p_rs.clone()
            affs_pred = affs.clone()
            pos_pred_all = pos.clone()
            quat_pred_all = quat.clone()
            p_rs_pred_all = p_rs.clone()
            affs_pred_all = affs.clone()
            prenorm_terms = torch.zeros_like(quat_pred)

            # forward
            self.optimizer.zero_grad()
            for t in range(self.target_length):
                quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    p_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    self.quat_h, prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1] = \
                    self.model(
                        quat_pred[:, t:self.prefix_length + t],
                        p_rs_pred[:, t:self.prefix_length + t],
                        affs_pred[:, t:self.prefix_length + t],
                        spline[:, t:self.prefix_length + t],
                        orient[:, t:self.prefix_length + t],
                        labels,
                        quat_h=None if t == 0 else self.quat_h, return_prenorm=True)
                pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\
                    affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    self.mocap.get_predicted_features(
                        pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0],
                        quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],
                        orient[:, self.prefix_length + t:self.prefix_length + t + 1])
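                # Scheduled teacher forcing: with probability (1 - tf), feed
                # the model's own predictions back in for the next step
                # instead of the ground truth; tf decays across epochs in
                # adjust_tf()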
                if np.random.uniform(size=1)[0] > self.tf:
                    pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]
                    quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]
                    p_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        p_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]
                    affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]

            prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0],
                                               prenorm_terms.shape[1], -1,
                                               self.D)
            quat_norm_loss = self.args.quat_norm_reg * torch.mean(
                (torch.sum(prenorm_terms**2, dim=-1) - 1)**2)

            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                quat_pred_all[:, self.prefix_length - 1:],
                quat[:, self.prefix_length - 1:], self.V, self.D)
            quat_loss *= self.args.quat_reg

            p_rs_loss = self.p_rs_loss_func(
                p_rs_pred_all[:, self.prefix_length:],
                p_rs[:, self.prefix_length:])
            affs_loss = self.affs_loss_func(
                affs_pred_all[:, self.prefix_length:],
                affs[:, self.prefix_length:])
            # recons_loss = self.args.recons_reg *\
            #               (pos_pred_all[:, self.prefix_length:] - pos_pred_all[:, self.prefix_length:, 0:1] -
            #                 pos[:, self.prefix_length:] + pos[:, self.prefix_length:, 0:1]).norm()

            loss_total = quat_norm_loss + quat_loss + quat_derv_loss + p_rs_loss + affs_loss  # + recons_loss
            loss_total.backward()
            # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
            self.optimizer.step()

            # Compute statistics
            batch_loss += loss_total.item()
            N += quat.shape[0]

            # statistics
            self.iter_info['loss'] = loss_total.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            self.iter_info['tf'] = '{:.6f}'.format(self.tf)
            self.show_iter_info()
            self.meta_info['iter'] += 1

        batch_loss = batch_loss / N
        self.epoch_info['mean_loss'] = batch_loss
        self.show_epoch_info()
        self.io.print_timer()
        self.adjust_lr()
        self.adjust_tf()

    def per_test(self):

        self.model.eval()
        test_loader = self.data_loader['test']
        valid_loss = 0.
        N = 0.

        for pos, quat, orient, affs, spline, p_rs, labels in self.yield_batch(
                self.args.batch_size, test_loader):
            pos = torch.from_numpy(pos).cuda()
            quat = torch.from_numpy(quat).cuda()
            orient = torch.from_numpy(orient).cuda()
            affs = torch.from_numpy(affs).cuda()
            spline = torch.from_numpy(spline).cuda()
            p_rs = torch.from_numpy(p_rs).cuda()
            labels = torch.from_numpy(labels).cuda()

            pos_pred = pos.clone()
            quat_pred = quat.clone()
            p_rs_pred = p_rs.clone()
            affs_pred = affs.clone()
            prenorm_terms = torch.zeros_like(quat_pred)

            # forward
            self.optimizer.zero_grad()
            for t in range(self.target_length):
                quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    p_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    self.quat_h, prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1] = \
                    self.model(
                        quat_pred[:, t:self.prefix_length + t],
                        p_rs_pred[:, t:self.prefix_length + t],
                        affs_pred[:, t:self.prefix_length + t],
                        spline[:, t:self.prefix_length + t],
                        orient[:, t:self.prefix_length + t],
                        labels,
                        quat_h=None if t == 0 else self.quat_h, return_prenorm=True)
                pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    self.mocap.get_predicted_features(
                        pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0],
                        quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1],
                        orient[:, self.prefix_length + t:self.prefix_length + t + 1])

            prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0],
                                               prenorm_terms.shape[1], -1,
                                               self.D)
            quat_norm_loss = self.args.quat_norm_reg * torch.mean(
                (torch.sum(prenorm_terms**2, dim=-1) - 1)**2)

            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                quat_pred[:, self.prefix_length - 1:],
                quat[:, self.prefix_length - 1:], self.V, self.D)
            quat_loss *= self.args.quat_reg

            recons_loss = self.args.recons_reg *\
                          (pos_pred[:, self.prefix_length:] - pos_pred[:, self.prefix_length:, 0:1] -
                           pos[:, self.prefix_length:] + pos[:, self.prefix_length:, 0:1]).norm()
            valid_loss += recons_loss
            N += quat.shape[0]

        valid_loss /= N
        # if self.meta_info['epoch'] > 5 and self.loss_updated:
        #     pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0], pos_pred.shape[1], -1).permute(0, 2, 1).\
        #         detach().cpu().numpy()
        #     display_animations(pos_pred_np, self.V, self.C, self.joint_parents, save=True,
        #                        dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch),
        #                        overwrite=True)
        #     pos_in_np = pos_in.contiguous().view(pos_in.shape[0], pos_in.shape[1], -1).permute(0, 2, 1).\
        #         detach().cpu().numpy()
        #     display_animations(pos_in_np, self.V, self.C, self.joint_parents, save=True,
        #                        dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch) +
        #                                                               '_gt',
        #                        overwrite=True)

        self.epoch_info['mean_loss'] = valid_loss
        if self.epoch_info['mean_loss'] < self.best_loss and self.meta_info[
                'epoch'] > self.min_train_epochs:
            self.best_loss = self.epoch_info['mean_loss']
            self.best_loss_epoch = self.meta_info['epoch']
            self.loss_updated = True
        else:
            self.loss_updated = False
        self.show_epoch_info()

    def train(self):

        if self.args.load_last_best:
            self.load_best_model()
            # fall back to epoch 0 if no saved checkpoint was found
            self.args.start_epoch = self.best_loss_epoch if self.best_loss_epoch is not None else 0
        for epoch in range(self.args.start_epoch, self.args.num_epoch):
            self.meta_info['epoch'] = epoch

            # training
            self.io.print_log('Training epoch: {}'.format(epoch))
            self.per_train()
            self.io.print_log('Done.')

            # evaluation
            if (epoch % self.args.eval_interval
                    == 0) or (epoch + 1 == self.args.num_epoch):
                self.io.print_log('Eval epoch: {}'.format(epoch))
                self.per_test()
                self.io.print_log('Done.')

            # save model and weights
            if self.loss_updated:
                torch.save(
                    {
                        'model_dict': self.model.state_dict(),
                        'quat_h': self.quat_h
                    },
                    os.path.join(
                        self.args.work_dir,
                        'epoch_{}_loss_{:.4f}_acc_{:.2f}_model.pth.tar'.format(
                            epoch, self.best_loss, self.best_mean_ap * 100.)))

                if self.generate_while_train:
                    self.generate_motion(load_saved_model=False,
                                         samples_to_generate=1)

    def copy_prefix(self, var, target_length):
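        # Allocate the full-length output and copy in the ground-truth
        # prefix; the remaining target_length frames are filled in by the
        # autoregressive loop in generate_motion()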
        shape = list(var.shape)
        shape[1] = self.prefix_length + target_length
        var_pred = torch.zeros(torch.Size(shape)).cuda().float()
        var_pred[:, :self.prefix_length] = var[:, :self.prefix_length]
        return var_pred

    def flip_trajectory(self, traj, target_length):
        traj_flipped = traj[:, -(target_length - self.target_length):].flip(
            dims=[1])
        orient_flipped = torch.zeros(
            (traj_flipped.shape[0], traj_flipped.shape[1], 1)).cuda().float()
        # orient_flipped[:, 0] = np.pi
        # traj_diff = traj_flipped[:, 1:, [0, 2]] - traj_flipped[:, :-1, [0, 2]]
        # traj_diff /= torch.norm(traj_diff, dim=-1)[..., None]
        # orient_flipped[:, 1:, 0] = torch.atan2(traj_diff[:, :, 1], traj_diff[:, :, 0])
        return traj_flipped, orient_flipped

    def generate_motion(self,
                        load_saved_model=True,
                        target_length=100,
                        samples_to_generate=10):

        if load_saved_model:
            self.load_best_model()
        self.model.eval()
        test_loader = self.data_loader['test']

        pos, traj, quat, orient, affs, spline, p_rs, labels = self.return_batch(
            [samples_to_generate], test_loader)
        pos = torch.from_numpy(pos).cuda()
        traj = torch.from_numpy(traj).cuda()
        quat = torch.from_numpy(quat).cuda()
        orient = torch.from_numpy(orient).cuda()
        affs = torch.from_numpy(affs).cuda()
        spline = torch.from_numpy(spline).cuda()
        p_rs = torch.from_numpy(p_rs).cuda()
        labels = torch.from_numpy(labels).cuda()

        traj_flipped, orient_flipped = self.flip_trajectory(
            traj, target_length)
        traj = torch.cat((traj, traj_flipped), dim=1)
        orient = torch.cat((orient, orient_flipped), dim=1)

        pos_pred = self.copy_prefix(pos, target_length)
        quat_pred = self.copy_prefix(quat, target_length)
        p_rs_pred = self.copy_prefix(p_rs, target_length)
        affs_pred = self.copy_prefix(affs, target_length)
        spline_pred = self.copy_prefix(spline, target_length)

        # forward
        with torch.no_grad():
            for t in range(target_length):
                quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    p_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    self.quat_h = \
                    self.model(
                        quat_pred[:, t:self.prefix_length + t],
                        p_rs_pred[:, t:self.prefix_length + t],
                        affs_pred[:, t:self.prefix_length + t],
                        spline_pred[:, t:self.prefix_length + t],
                        orient[:, t:self.prefix_length + t],
                        labels,
                        quat_h=None if t == 0 else self.quat_h, return_prenorm=False)
                data_pred = \
                    self.mocap.get_predicted_features(
                        pos_pred[:, :self.prefix_length + t],
                        orient[:, :self.prefix_length + t],
                        traj[:, self.prefix_length + t:self.prefix_length + t + 1],
                        quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1],
                        orient[:, self.prefix_length + t:self.prefix_length + t + 1])
                pos_pred[:, self.prefix_length + t:self.prefix_length + t +
                         1] = data_pred['positions_world']
                affs_pred[:, self.prefix_length + t:self.prefix_length + t +
                          1] = data_pred['affective_features']
                spline_pred[:, self.prefix_length + t:self.prefix_length + t +
                            1] = data_pred['spline']


            recons_loss = self.args.recons_reg *\
                          (pos_pred[:, self.prefix_length:self.T] - pos_pred[:, self.prefix_length:self.T, 0:1] -
                           pos[:, self.prefix_length:self.T] + pos[:, self.prefix_length:self.T, 0:1]).norm()

        pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0], pos_pred.shape[1], -1).permute(0, 2, 1).\
            detach().cpu().numpy()
        pos_np = pos.contiguous().view(pos.shape[0], pos.shape[1], -1).permute(0, 2, 1).\
            detach().cpu().numpy()
        display_animations(pos_pred_np,
                           self.V,
                           self.C,
                           self.joint_parents,
                           save=True,
                           dataset_name=self.dataset,
                           subset_name='epoch_' + str(self.best_loss_epoch),
                           overwrite=True)
        display_animations(pos_np,
                           self.V,
                           self.C,
                           self.joint_parents,
                           save=True,
                           dataset_name=self.dataset,
                           subset_name='epoch_' + str(self.best_loss_epoch) +
                           '_gt',
                           overwrite=True)
        self.mocap.save_as_bvh(
            traj.detach().cpu().numpy(),
            orient.detach().cpu().numpy(),
            np.reshape(quat_pred.detach().cpu().numpy(),
                       (quat_pred.shape[0], quat_pred.shape[1], -1, self.D)),
            'render/bvh')
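
Both per_train and per_test above regularize the model's pre-normalization quaternion outputs toward unit length before the angle losses are applied. A minimal sketch of that penalty term, with hypothetical shapes and regularizer weight:

import torch

quat_norm_reg = 0.1                    # hypothetical weight (args.quat_norm_reg)
prenorm = torch.randn(2, 5, 3, 4)      # (batch, time, joints, quaternion dim)
sq_norms = torch.sum(prenorm ** 2, dim=-1)
quat_norm_loss = quat_norm_reg * torch.mean((sq_norms - 1) ** 2)
print(float(quat_norm_loss))           # zero only for exactly unit quaternions
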
Example #11
class Processor(object):
    """
        Processor for gait generation
    """
    def __init__(self,
                 args,
                 data_loader,
                 C,
                 F,
                 num_classes,
                 graph_dict,
                 device='cuda:0'):

        self.args = args
        self.data_loader = data_loader
        self.num_classes = num_classes
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.device = device
        self.io = IO(self.args.work_dir,
                     save_log=self.args.save_log,
                     print_log=self.args.print_log)

        # model
        self.model = classifier.Classifier(C, F, num_classes, graph_dict)
        self.model.cuda('cuda:0')
        self.model.apply(weights_init)
        self.loss = nn.CrossEntropyLoss()
        self.best_loss = math.inf
        self.step_epochs = [
            math.ceil(float(self.args.num_epoch * x)) for x in self.args.step
        ]
        self.best_epoch = None
        self.best_accuracy = np.zeros((1, np.max(self.args.topk)))
        self.accuracy_updated = False

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.args.base_lr,
                                       momentum=0.9,
                                       nesterov=self.args.nesterov,
                                       weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=self.args.base_lr,
                                        weight_decay=self.args.weight_decay)
        else:
            raise ValueError()
        self.lr = self.args.base_lr

    def adjust_lr(self):

        # if self.args.optimizer == 'SGD' and \
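        # Step decay: multiply base_lr by 0.1 for every milestone in
        # step_epochs that the current epoch has reached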
        if self.meta_info['epoch'] in self.step_epochs:
            lr = self.args.base_lr * (0.1**np.sum(
                self.meta_info['epoch'] >= np.array(self.step_epochs)))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            self.lr = lr

    def show_epoch_info(self):

        for k, v in self.epoch_info.items():
            self.io.print_log('\t{}: {}'.format(k, v))
        if self.args.pavi_log:
            self.io.log('train', self.meta_info['iter'], self.epoch_info)

    def show_iter_info(self):

        if self.meta_info['iter'] % self.args.log_interval == 0:
            info = '\tIter {} Done.'.format(self.meta_info['iter'])
            for k, v in self.iter_info.items():
                if isinstance(v, float):
                    info = info + ' | {}: {:.4f}'.format(k, v)
                else:
                    info = info + ' | {}: {}'.format(k, v)

            self.io.print_log(info)

            if self.args.pavi_log:
                self.io.log('train', self.meta_info['iter'], self.iter_info)

    def show_topk(self, k, epoch):

        rank = self.result.argsort()
        hit_top_k = [l in rank[i, -k:] for i, l in enumerate(self.label)]
        accuracy = 100. * sum(hit_top_k) / len(hit_top_k)
        if accuracy > self.best_accuracy[0, k - 1]:
            self.best_accuracy[0, k - 1] = accuracy
            self.accuracy_updated = True
            self.best_epoch = epoch
        else:
            self.accuracy_updated = False
        print_epoch = self.best_epoch if self.best_epoch is not None else 0
        self.io.print_log(
            '\tTop{}: {:.2f}%. Best so far: {:.2f}% (epoch: {:d}).'.format(
                k, accuracy, self.best_accuracy[0, k - 1], print_epoch))

    def per_train(self):

        self.model.train()
        self.adjust_lr()
        loader = self.data_loader['train']
        loss_value = []

        for aff, gait, label in loader:
            # get data
            aff = aff.float().to(self.device)
            gait = gait.float().to(self.device)
            label = label.long().to(self.device)

            # forward
            output, _ = self.model(aff, gait)
            loss = self.loss(output, label)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # statistics
            self.iter_info['loss'] = loss.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            loss_value.append(self.iter_info['loss'])
            self.show_iter_info()
            self.meta_info['iter'] += 1

        self.epoch_info['mean_loss'] = np.mean(loss_value)
        self.show_epoch_info()
        self.io.print_timer()

    def per_test(self, epoch=None, evaluation=True):

        self.model.eval()
        loader = self.data_loader['test']
        loss_value = []
        result_frag = []
        label_frag = []

        for aff, gait, label in loader:

            # get data
            aff = aff.float().to(self.device)
            gait = gait.float().to(self.device)
            label = label.long().to(self.device)

            # inference
            with torch.no_grad():
                output, _ = self.model(aff, gait)
            result_frag.append(output.data.cpu().numpy())

            # get loss
            if evaluation:
                loss = self.loss(output, label)
                loss_value.append(loss.item())
                label_frag.append(label.data.cpu().numpy())

        self.result = np.concatenate(result_frag)
        if evaluation:
            self.label = np.concatenate(label_frag)
            self.epoch_info['mean_loss'] = np.mean(loss_value)
            self.show_epoch_info()

            # show top-k accuracy
            for k in self.args.topk:
                self.show_topk(k, epoch)

    def train(self):

        for epoch in range(self.args.start_epoch, self.args.num_epoch):
            self.meta_info['epoch'] = epoch

            # training
            self.io.print_log('Training epoch: {}'.format(epoch))
            self.per_train()
            self.io.print_log('Done.')

            # evaluation
            if (epoch % self.args.eval_interval
                    == 0) or (epoch + 1 == self.args.num_epoch):
                self.io.print_log('Eval epoch: {}'.format(epoch))
                self.per_test(epoch=epoch)
                self.io.print_log('Done.')

            # save model and weights
            if self.accuracy_updated:
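                # note: best_accuracy.item() assumes a single stored top-k
                # value; with multiple topk entries, index the desired column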
                torch.save(
                    self.model.state_dict(),
                    os.path.join(
                        self.args.work_dir,
                        'epoch_{}_acc_{:.2f}_model.pth.tar'.format(
                            epoch, self.best_accuracy.item())))

    def test(self):

        # the path of the weights must be specified
        if self.args.weights is None:
            raise ValueError('Please specify --weights.')
        self.io.print_log('Model:   {}.'.format(self.args.model))
        self.io.print_log('Weights: {}.'.format(self.args.weights))

        # evaluation
        self.io.print_log('Evaluation Start:')
        self.per_test()
        self.io.print_log('Done.\n')

        # save the output of model
        if self.args.save_result:
            result_dict = dict(
                zip(self.data_loader['test'].dataset.sample_name, self.result))
            self.io.save_pkl(result_dict, 'test_result.pkl')

    def extract_best_feature(self, data, joints, coords):
        if self.best_epoch is None:
            self.best_epoch, best_accuracy = get_best_epoch_and_accuracy(
                self.args.work_dir)
        else:
            best_accuracy = self.best_accuracy.item()
        if self.best_epoch is not None:
            filename = os.path.join(
                self.args.work_dir, 'epoch_{}_acc_{:.2f}_model.pth.tar'.format(
                    self.best_epoch, best_accuracy))
            self.model.load_state_dict(torch.load(filename))
        label_preds = np.empty(len(data))
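        # note: assumes the extracted feature vector is 4-dimensional, to
        # match np.append along axis 0 below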
        features = np.empty((0, 4))

        for i, each_data in enumerate(data):
            # get data
            aff = each_data[0]
            aff = torch.from_numpy(aff).float().to(self.device)
            aff = aff.unsqueeze(0)
            gait = each_data[1]
            gait = np.reshape(gait, (1, gait.shape[0], joints, coords, 1))
            gait = np.moveaxis(gait, [1, 2, 3], [2, 3, 1])
            gait = torch.from_numpy(gait).float().to(self.device)

            # get feature
            self.model.eval()
            with torch.no_grad():
                output, feature = self.model(aff, gait)
                label_preds[i] = np.argmax(output.cpu().numpy())
                features = np.append(features,
                                     feature.cpu().numpy().reshape(
                                         (1, feature.shape[1])),
                                     axis=0)

        return features, label_preds
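
For reference, the top-k computation in show_topk above can be exercised in isolation; the scores and labels here are random placeholders:

import numpy as np

scores = np.random.rand(6, 4)            # (num_samples, num_classes)
labels = np.random.randint(0, 4, size=6)
k = 3
rank = scores.argsort()                  # class indices, ascending by score
hit_top_k = [l in rank[i, -k:] for i, l in enumerate(labels)]
print('Top{}: {:.2f}%'.format(k, 100. * sum(hit_top_k) / len(hit_top_k)))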