# Constructor variant that additionally takes the feature dimension F
# (cf. the Processor class below, which takes C only).
def __init__(self, args, data_loader, C, F, num_classes, graph_dict, device='cuda:0'):
    self.args = args
    self.data_loader = data_loader
    self.num_classes = num_classes
    self.result = dict()
    self.iter_info = dict()
    self.epoch_info = dict()
    self.meta_info = dict(epoch=0, iter=0)
    self.device = device
    self.io = IO(self.args.work_dir,
                 save_log=self.args.save_log,
                 print_log=self.args.print_log)

    # model
    self.model = classifier.Classifier(C, F, num_classes, graph_dict)
    self.model.cuda(device)  # was hard-coded to 'cuda:0', ignoring the device argument
    self.model.apply(weights_init)
    self.loss = nn.CrossEntropyLoss()
    self.best_loss = math.inf
    self.step_epochs = [math.ceil(float(self.args.num_epoch * x))
                        for x in self.args.step]
    self.best_epoch = None
    self.best_accuracy = np.zeros((1, np.max(self.args.topk)))
    self.accuracy_updated = False

    # optimizer
    if self.args.optimizer == 'SGD':
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=self.args.base_lr,
            momentum=0.9,
            nesterov=self.args.nesterov,
            weight_decay=self.args.weight_decay)
    elif self.args.optimizer == 'Adam':
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.args.base_lr,
            weight_decay=self.args.weight_decay)
    else:
        raise ValueError()
    self.lr = self.args.base_lr
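
# --- Illustrative sketch (not part of the original file) ----------------------
# `weights_init` is applied to the model above but defined elsewhere in the
# repo. The helper below is only a hypothetical sketch of a common conv/BN
# initialization convention; it is named with a `_sketch` suffix so it does not
# shadow the real function, whose behavior may differ.
import torch.nn as nn

def weights_init_sketch(m):
    """Hypothetical example: normal init for conv weights, constant BN bias."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 and hasattr(m, 'weight'):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.constant_(m.bias, 0.0)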
class Processor(object):
    """Processor for gait generation"""

    def __init__(self, args, data_loader, C, num_classes, graph_dict,
                 device='cuda:0', verbose=True):
        self.args = args
        self.data_loader = data_loader
        self.num_classes = num_classes
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.device = device
        self.verbose = verbose
        self.io = IO(
            self.args.work_dir,
            save_log=self.args.save_log,
            print_log=self.args.print_log)

        # model
        if not os.path.isdir(self.args.work_dir):
            os.mkdir(self.args.work_dir)
        self.model = classifier.Classifier(C, num_classes, graph_dict)
        self.model.cuda(device)  # was hard-coded to 'cuda:0', ignoring the device argument
        self.model.apply(weights_init)
        self.loss = nn.CrossEntropyLoss()
        self.best_loss = math.inf
        self.step_epochs = [math.ceil(float(self.args.num_epoch * x))
                            for x in self.args.step]
        self.best_epoch = None
        self.best_accuracy = np.zeros((1, np.max(self.args.topk)))
        self.accuracy_updated = False

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(
                self.model.parameters(),
                lr=self.args.base_lr,
                momentum=0.9,
                nesterov=self.args.nesterov,
                weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(
                self.model.parameters(),
                lr=self.args.base_lr,
                weight_decay=self.args.weight_decay)
        else:
            raise ValueError()
        self.lr = self.args.base_lr

    def adjust_lr(self):
        # if self.args.optimizer == 'SGD' and\
        if self.meta_info['epoch'] in self.step_epochs:
            lr = self.args.base_lr * (
                0.1 ** np.sum(self.meta_info['epoch'] >= np.array(self.step_epochs)))
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr
            self.lr = lr

    def show_epoch_info(self):
        for k, v in self.epoch_info.items():
            if self.verbose:
                self.io.print_log('\t{}: {}'.format(k, v))
        if self.args.pavi_log:
            if self.verbose:
                self.io.log('train', self.meta_info['iter'], self.epoch_info)

    def show_iter_info(self):
        if self.meta_info['iter'] % self.args.log_interval == 0:
            info = '\tIter {} Done.'.format(self.meta_info['iter'])
            for k, v in self.iter_info.items():
                if isinstance(v, float):
                    info = info + ' | {}: {:.4f}'.format(k, v)
                else:
                    info = info + ' | {}: {}'.format(k, v)
            if self.verbose:
                self.io.print_log(info)
            if self.args.pavi_log:
                self.io.log('train', self.meta_info['iter'], self.iter_info)

    def show_topk(self, k):
        rank = self.result.argsort()
        hit_top_k = [l in rank[i, -k:] for i, l in enumerate(self.label)]
        accuracy = 100. * sum(hit_top_k) / len(hit_top_k)
        if accuracy > self.best_accuracy[0, k - 1]:
            self.best_accuracy[0, k - 1] = accuracy
            self.accuracy_updated = True
        else:
            self.accuracy_updated = False
        if self.verbose:
            print_epoch = self.best_epoch if self.best_epoch is not None else 0
            self.io.print_log(
                '\tTop{}: {:.2f}%. Best so far: {:.2f}% (epoch: {:d}).'.format(
                    k, accuracy, self.best_accuracy[0, k - 1], print_epoch))

    def per_train(self):
        self.model.train()
        self.adjust_lr()
        loader = self.data_loader['train']
        loss_value = []

        for data, label in loader:
            # get data
            data = data.float().to(self.device)
            label = label.long().to(self.device)

            # forward
            output, _ = self.model(data)
            loss = self.loss(output, label)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # statistics
            self.iter_info['loss'] = loss.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            loss_value.append(self.iter_info['loss'])
            self.show_iter_info()
            self.meta_info['iter'] += 1

        self.epoch_info['mean_loss'] = np.mean(loss_value)
        self.show_epoch_info()
        if self.verbose:
            self.io.print_timer()
        # for k in self.args.topk:
        #     self.calculate_topk(k, show=False)
        # if self.accuracy_updated:
        #     self.model.extract_feature()

    def per_test(self, evaluation=True):
        self.model.eval()
        loader = self.data_loader['test']
        loss_value = []
        result_frag = []
        label_frag = []

        for data, label in loader:
            # get data
            data = data.float().to(self.device)
            label = label.long().to(self.device)

            # inference
            with torch.no_grad():
                output, _ = self.model(data)
            result_frag.append(output.data.cpu().numpy())

            # get loss
            if evaluation:
                loss = self.loss(output, label)
                loss_value.append(loss.item())
                label_frag.append(label.data.cpu().numpy())

        self.result = np.concatenate(result_frag)
        if evaluation:
            self.label = np.concatenate(label_frag)
            self.epoch_info['mean_loss'] = np.mean(loss_value)
            self.show_epoch_info()

            # show top-k accuracy
            for k in self.args.topk:
                self.show_topk(k)

    def train(self):
        for epoch in range(self.args.start_epoch, self.args.num_epoch):
            self.meta_info['epoch'] = epoch

            # training
            if self.verbose:
                self.io.print_log('Training epoch: {}'.format(epoch))
            self.per_train()
            if self.verbose:
                self.io.print_log('Done.')

            # evaluation
            if (epoch % self.args.eval_interval == 0) or (
                    epoch + 1 == self.args.num_epoch):
                if self.verbose:
                    self.io.print_log('Eval epoch: {}'.format(epoch))
                self.per_test()
                if self.verbose:
                    self.io.print_log('Done.')

            # save model and weights
            if self.accuracy_updated:
                torch.save(
                    self.model.state_dict(),
                    os.path.join(
                        self.args.work_dir,
                        'epoch{}_acc{:.2f}_model.pth.tar'.format(
                            epoch, self.best_accuracy.item())))
                if self.epoch_info['mean_loss'] < self.best_loss:
                    self.best_loss = self.epoch_info['mean_loss']
                    self.best_epoch = epoch

    def test(self):
        # the path of weights must be specified
        if self.args.weights is None:
            raise ValueError('Please specify --weights.')
        if self.verbose:
            self.io.print_log('Model: {}.'.format(self.args.model))
            self.io.print_log('Weights: {}.'.format(self.args.weights))

        # evaluation
        if self.verbose:
            self.io.print_log('Evaluation Start:')
        self.per_test()
        if self.verbose:
            self.io.print_log('Done.\n')

        # save the output of model
        if self.args.save_result:
            result_dict = dict(
                zip(self.data_loader['test'].dataset.sample_name, self.result))
            self.io.save_pkl(result_dict, 'test_result.pkl')

    def smap(self):
        # self.model.eval()
        loader = self.data_loader['test']
        for data, label in loader:
            # get data
            data = data.float().to(self.device)
            label = label.long().to(self.device)
            GBP = GuidedBackprop(self.model)
            guided_grads = GBP.generate_gradients(data, label)

    def load_best_model(self):
        if self.best_epoch is None:
            self.best_epoch, best_accuracy = get_best_epoch_and_accuracy(self.args.work_dir)
        else:
            best_accuracy = self.best_accuracy.item()
        filename = os.path.join(
            self.args.work_dir,
            'epoch{}_acc{:.2f}_model.pth.tar'.format(self.best_epoch, best_accuracy))
        self.model.load_state_dict(torch.load(filename))

    def generate_predictions(self, data, num_classes, joints, coords):
        # fin = h5py.File('../data/features'+ftype+'.h5', 'r')
        # fkeys = fin.keys()
        labels_pred = np.zeros(data.shape[0])
        output = np.zeros((data.shape[0], num_classes))
        for i, each_data in enumerate(data):
            # get data
            each_data = np.reshape(each_data, (1, each_data.shape[0], joints, coords, 1))
            each_data = np.moveaxis(each_data, [1, 2, 3], [2, 3, 1])
            each_data = torch.from_numpy(each_data).float().to(self.device)

            # get label
            with torch.no_grad():
                pred, _ = self.model(each_data)
            # move predictions back to the CPU before storing them in numpy arrays
            output[i] = pred.squeeze(0).cpu().numpy()
            labels_pred[i] = np.argmax(output[i])
        return labels_pred, output

    def generate_confusion_matrix(self, ftype, data, labels, num_classes, joints, coords):
        self.load_best_model()
        # generate_predictions returns (labels, scores); the original assigned the
        # whole tuple to labels_pred
        labels_pred, _ = self.generate_predictions(data, num_classes, joints, coords)
        hit = np.nonzero(labels_pred == labels)
        miss = np.nonzero(labels_pred != labels)
        confusion_matrix = np.zeros((num_classes, num_classes))
        # np.int is deprecated in recent NumPy; use the builtin int
        for hidx in np.arange(len(hit[0])):
            confusion_matrix[int(labels[hit[0][hidx]]),
                             int(labels_pred[hit[0][hidx]])] += 1
        for midx in np.arange(len(miss[0])):
            confusion_matrix[int(labels[miss[0][midx]]),
                             int(labels_pred[miss[0][midx]])] += 1
        confusion_matrix = confusion_matrix.transpose()
        plot_confusion_matrix(confusion_matrix)

    def save_best_feature(self, ftype, data, joints, coords):
        if self.best_epoch is None:
            self.best_epoch, best_accuracy = get_best_epoch_and_accuracy(self.args.work_dir)
        else:
            best_accuracy = self.best_accuracy.item()
        filename = os.path.join(
            self.args.work_dir,
            'epoch{}_acc{:.2f}_model.pth.tar'.format(self.best_epoch, best_accuracy))
        self.model.load_state_dict(torch.load(filename))

        features = np.empty((0, 64))
        fCombined = h5py.File('../data/features' + ftype + '.h5', 'r')
        fkeys = fCombined.keys()
        dfCombined = h5py.File('../data/deepFeatures' + ftype + '.h5', 'w')
        for i, (each_data, each_key) in enumerate(zip(data, fkeys)):
            # get data
            each_data = np.reshape(each_data, (1, each_data.shape[0], joints, coords, 1))
            each_data = np.moveaxis(each_data, [1, 2, 3], [2, 3, 1])
            each_data = torch.from_numpy(each_data).float().to(self.device)

            # get feature
            with torch.no_grad():
                _, feature = self.model(each_data)
            feature = feature.cpu().numpy()  # move off the GPU before writing to h5py/numpy
            dfCombined.create_dataset(each_key, data=feature)
            features = np.append(features, feature.reshape((1, -1)), axis=0)
        dfCombined.close()
        return features
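
# --- Usage sketch (illustrative, not part of the original file) ---------------
# A minimal example of how this Processor is typically driven. All values below
# are hypothetical placeholders; `graph_dict`, the data loaders, and the exact
# set of fields on `args` come from the surrounding training scripts.
from argparse import Namespace

def run_classifier_example(train_loader, test_loader, graph_dict):
    """Hypothetical driver, assuming loaders that yield (data, label) batches."""
    args = Namespace(
        work_dir='./work_dir', save_log=True, print_log=True,
        optimizer='Adam', base_lr=1e-3, weight_decay=1e-4, nesterov=True,
        num_epoch=50, start_epoch=0, step=[0.5, 0.75], topk=[1],
        eval_interval=5, log_interval=100, pavi_log=False,
        save_result=False, weights=None, model=None)
    processor = Processor(args, {'train': train_loader, 'test': test_loader},
                          C=3, num_classes=4, graph_dict=graph_dict)
    processor.train()  # trains with step-decayed LR, evaluating every eval_interval epochs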
class Processor(object):
    """Processor for gait generation"""

    def __init__(self, args, data_path, data_loader, Z, T, A, V, C, D, tag_cats,
                 IE, IP, AT, G, AGE, H, NT, joint_names, joint_parents,
                 word2idx, embedding_table, lower_body_start=15, fill=6,
                 min_train_epochs=20, generate_while_train=False,
                 save_path=None, device='cuda:0'):

        def get_quats_sos_and_eos():
            quats_sos_and_eos_file = os.path.join(data_path, 'quats_sos_and_eos.npz')
            keys = list(self.data_loader['train'].keys())
            num_samples = len(self.data_loader['train'])
            try:
                mean_quats_sos = np.load(quats_sos_and_eos_file,
                                         allow_pickle=True)['quats_sos']
                mean_quats_eos = np.load(quats_sos_and_eos_file,
                                         allow_pickle=True)['quats_eos']
            except FileNotFoundError:
                mean_quats_sos = np.zeros((self.V, self.D))
                mean_quats_eos = np.zeros((self.V, self.D))
                for j in range(self.V):
                    quats_sos = np.zeros((self.D, num_samples))
                    quats_eos = np.zeros((self.D, num_samples))
                    for s in range(num_samples):
                        quats_sos[:, s] = self.data_loader['train'][
                            keys[s]]['rotations'][0, j]
                        quats_eos[:, s] = self.data_loader['train'][
                            keys[s]]['rotations'][-1, j]
                    _, sos_eig_vectors = np.linalg.eig(
                        np.dot(quats_sos, quats_sos.T))
                    mean_quats_sos[j] = sos_eig_vectors[:, 0]
                    _, eos_eig_vectors = np.linalg.eig(
                        np.dot(quats_eos, quats_eos.T))
                    mean_quats_eos[j] = eos_eig_vectors[:, 0]
                np.savez_compressed(quats_sos_and_eos_file,
                                    quats_sos=mean_quats_sos,
                                    quats_eos=mean_quats_eos)
            mean_quats_sos = torch.from_numpy(mean_quats_sos).unsqueeze(0)
            mean_quats_eos = torch.from_numpy(mean_quats_eos).unsqueeze(0)
            for s in range(num_samples):
                pos_sos = MocapDataset.forward_kinematics(
                    mean_quats_sos.unsqueeze(0),
                    torch.from_numpy(self.data_loader['train'][keys[s]]
                                     ['positions'][0:1, 0]).double().unsqueeze(0),
                    self.joint_parents,
                    torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                                     ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
                affs_sos = MocapDataset.get_mpi_affective_features(pos_sos)
                pos_eos = MocapDataset.forward_kinematics(
                    mean_quats_eos.unsqueeze(0),
                    torch.from_numpy(self.data_loader['train'][keys[s]]
                                     ['positions'][-1:, 0]).double().unsqueeze(0),
                    self.joint_parents,
                    torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                                     ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
                affs_eos = MocapDataset.get_mpi_affective_features(pos_eos)
                self.data_loader['train'][keys[s]]['positions'] = np.concatenate(
                    (pos_sos, self.data_loader['train'][keys[s]]['positions'], pos_eos),
                    axis=0)
                self.data_loader['train'][keys[s]]['affective_features'] = np.concatenate(
                    (affs_sos, self.data_loader['train'][keys[s]]['affective_features'],
                     affs_eos), axis=0)
            return mean_quats_sos, mean_quats_eos

        self.args = args
        self.dataset = args.dataset
        self.channel_map = {
            'Xrotation': 'x',
            'Yrotation': 'y',
            'Zrotation': 'z'
        }
        self.device = device
        self.data_loader = data_loader
        self.result = dict()
        self.iter_info = dict()
        self.epoch_info = dict()
        self.meta_info = dict(epoch=0, iter=0)
        self.io = IO(self.args.work_dir,
                     save_log=self.args.save_log,
                     print_log=self.args.print_log)

        # model
        self.T = T + 2
        self.A = A
        self.V = V
        self.C = C
        self.D = D
        self.O = 1
        self.tag_cats = tag_cats
        self.IE = IE
        self.IP = IP
        self.AT = AT
        self.G = G
        self.AGE = AGE
        self.H = H
        self.NT = NT
        self.joint_names = joint_names
        self.joint_parents = joint_parents
        self.lower_body_start = lower_body_start
        self.quats_sos, self.quats_eos = get_quats_sos_and_eos()
        self.recons_loss_func = nn.L1Loss()
        self.affs_loss_func = nn.L1Loss()
        self.best_loss = np.inf
        self.loss_updated = False
        self.step_epochs = [math.ceil(float(self.args.num_epoch * x))
                            for x in self.args.step]
        self.best_loss_epoch = None
        self.min_train_epochs = min_train_epochs
        self.zfill = fill
        self.word2idx = word2idx
        self.text_sos = np.int64(self.word2idx['<SOS>'])
        self.text_eos = np.int64(self.word2idx['<EOS>'])
        num_tokens = len(self.word2idx)  # the size of the vocabulary
        self.Z = Z  # embedding dimension
        num_hidden_units = 200  # the dimension of the feedforward network model in nn.TransformerEncoder
        num_layers = 2  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
        num_heads = 2  # the number of heads in the multi-head attention models
        dropout = 0.2  # the dropout value
        self.model = T2GNet(num_tokens, torch.from_numpy(embedding_table).cuda(),
                            self.T - 1, self.Z, self.V * self.D, self.D, self.V - 1,
                            self.IE, self.IP, self.AT, self.G, self.AGE, self.H,
                            self.NT, num_heads, num_hidden_units, num_layers,
                            dropout).to(device)

        # generate
        self.generate_while_train = generate_while_train
        self.save_path = save_path

        # optimizer
        if self.args.optimizer == 'SGD':
            self.optimizer = optim.SGD(self.model.parameters(),
                                       lr=self.args.base_lr,
                                       momentum=0.9,
                                       nesterov=self.args.nesterov,
                                       weight_decay=self.args.weight_decay)
        elif self.args.optimizer == 'Adam':
            self.optimizer = optim.Adam(self.model.parameters(),
                                        lr=self.args.base_lr)
            # weight_decay=self.args.weight_decay)
        else:
            raise ValueError()
        self.lr = self.args.base_lr
        self.tf = self.args.base_tr

    def process_data(self, data, poses, quat, trans, affs):
        data = data.float().to(self.device)
        poses = poses.float().to(self.device)
        quat = quat.float().to(self.device)
        trans = trans.float().to(self.device)
        affs = affs.float().to(self.device)
        return data, poses, quat, trans, affs

    def load_model_at_epoch(self, epoch='best'):
        model_name, self.best_loss_epoch, self.best_loss = \
            get_epoch_and_loss(self.args.work_dir, epoch=epoch)
        model_found = False
        try:
            loaded_vars = torch.load(os.path.join(self.args.work_dir, model_name))
            self.model.load_state_dict(loaded_vars['model_dict'])
            model_found = True
        except (FileNotFoundError, IsADirectoryError):
            if epoch == 'best':
                print('Warning! No saved model found.')
            else:
                print('Warning! No saved model found at epoch {:d}.'.format(epoch))
        return model_found

    def adjust_lr(self):
        self.lr = self.lr * self.args.lr_decay
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.lr

    def adjust_tf(self):
        if self.meta_info['epoch'] > 20:
            self.tf = self.tf * self.args.tf_decay

    def show_epoch_info(self):
        print_epochs = [self.best_loss_epoch
                        if self.best_loss_epoch is not None else 0]
        best_metrics = [self.best_loss]
        i = 0
        for k, v in self.epoch_info.items():
            self.io.print_log('\t{}: {}. Best so far: {} (epoch: {:d}).'.format(
                k, v, best_metrics[i], print_epochs[i]))
            i += 1
        if self.args.pavi_log:
            self.io.log('train', self.meta_info['iter'], self.epoch_info)

    def show_iter_info(self):
        if self.meta_info['iter'] % self.args.log_interval == 0:
            info = '\tIter {} Done.'.format(self.meta_info['iter'])
            for k, v in self.iter_info.items():
                if isinstance(v, float):
                    info = info + ' | {}: {:.4f}'.format(k, v)
                else:
                    info = info + ' | {}: {}'.format(k, v)
            self.io.print_log(info)
            if self.args.pavi_log:
                self.io.log('train', self.meta_info['iter'], self.iter_info)

    def yield_batch(self, batch_size, dataset):
        batch_joint_offsets = torch.zeros((batch_size, self.V - 1, self.C)).cuda()
        batch_pos = torch.zeros((batch_size, self.T, self.V, self.C)).cuda()
        batch_affs = torch.zeros((batch_size, self.T, self.A)).cuda()
        batch_quat = torch.zeros((batch_size, self.T, self.V * self.D)).cuda()
        batch_quat_valid_idx = torch.zeros((batch_size, self.T)).cuda()
        batch_text = torch.zeros((batch_size, self.Z)).cuda().long()
        batch_text_valid_idx = torch.zeros((batch_size, self.Z)).cuda()
        batch_intended_emotion = torch.zeros((batch_size, self.IE)).cuda()
        batch_intended_polarity = torch.zeros((batch_size, self.IP)).cuda()
        batch_acting_task = torch.zeros((batch_size, self.AT)).cuda()
        batch_gender = torch.zeros((batch_size, self.G)).cuda()
        batch_age = torch.zeros((batch_size, self.AGE)).cuda()
        batch_handedness = torch.zeros((batch_size, self.H)).cuda()
        batch_native_tongue = torch.zeros((batch_size, self.NT)).cuda()

        pseudo_passes = (len(dataset) + batch_size - 1) // batch_size
        probs = []
        for k in dataset.keys():
            probs.append(dataset[k]['positions'].shape[0])
        probs = np.array(probs) / np.sum(probs)

        for p in range(pseudo_passes):
            rand_keys = np.random.choice(len(dataset), size=batch_size,
                                         replace=True, p=probs)
            for i, k in enumerate(rand_keys):
                joint_offsets = torch.from_numpy(
                    dataset[str(k).zfill(self.zfill)]['joints_dict']
                    ['joints_offsets_all'][1:])
                pos = torch.from_numpy(
                    dataset[str(k).zfill(self.zfill)]['positions'])
                affs = torch.from_numpy(
                    dataset[str(k).zfill(self.zfill)]['affective_features'])
                quat = torch.cat(
                    (self.quats_sos,
                     torch.from_numpy(dataset[str(k).zfill(self.zfill)]['rotations']),
                     self.quats_eos), dim=0)
                quat_length = quat.shape[0]
                quat_valid_idx = torch.zeros(self.T)
                quat_valid_idx[:quat_length] = 1
                text = torch.cat((torch.tensor(
                    [self.word2idx[x] for x in
                     [e for e in str.split(dataset[str(k).zfill(self.zfill)]['Text'])
                      if e.isalnum()]]),
                    torch.from_numpy(np.array([self.text_eos]))))
                if text[0] != self.text_sos:
                    text = torch.cat(
                        (torch.from_numpy(np.array([self.text_sos])), text))
                text_length = text.shape[0]
                text_valid_idx = torch.zeros(self.Z)
                text_valid_idx[:text_length] = 1

                batch_joint_offsets[i] = joint_offsets
                batch_pos[i, :pos.shape[0]] = pos
                batch_pos[i, pos.shape[0]:] = pos[-1:].clone()
                batch_affs[i, :affs.shape[0]] = affs
                batch_affs[i, affs.shape[0]:] = affs[-1:].clone()
                batch_quat[i, :quat_length] = quat.view(quat_length, -1)
                batch_quat[i, quat_length:] = quat[-1:].view(1, -1).clone()
                batch_quat_valid_idx[i] = quat_valid_idx
                batch_text[i, :text_length] = text
                batch_text_valid_idx[i] = text_valid_idx
                batch_intended_emotion[i] = torch.from_numpy(
                    dataset[str(k).zfill(self.zfill)]['Intended emotion'])
                batch_intended_polarity[i] = torch.from_numpy(
                    dataset[str(k).zfill(self.zfill)]['Intended polarity'])
                batch_acting_task[i] = torch.from_numpy(
                    dataset[str(k).zfill(self.zfill)]['Acting task'])
                batch_gender[i] = torch.from_numpy(
                    dataset[str(k).zfill(self.zfill)]['Gender'])
                batch_age[i] = torch.tensor(
                    dataset[str(k).zfill(self.zfill)]['Age'])
                batch_handedness[i] = torch.from_numpy(
                    dataset[str(k).zfill(self.zfill)]['Handedness'])
                batch_native_tongue[i] = torch.from_numpy(
                    dataset[str(k).zfill(self.zfill)]['Native tongue'])

            yield batch_joint_offsets, batch_pos, batch_affs, batch_quat, \
                batch_quat_valid_idx, batch_text, batch_text_valid_idx, \
                batch_intended_emotion, batch_intended_polarity, batch_acting_task, \
                batch_gender, batch_age, batch_handedness, batch_native_tongue

    def return_batch(self, batch_size, dataset, randomized=True):
        if len(batch_size) > 1:
            rand_keys = np.copy(batch_size)
            batch_size = len(batch_size)
        else:
            batch_size = batch_size[0]
            probs = []
            for k in dataset.keys():
                probs.append(dataset[k]['positions'].shape[0])
            probs = np.array(probs) / np.sum(probs)
            if randomized:
                rand_keys = np.random.choice(len(dataset), size=batch_size,
                                             replace=False, p=probs)
            else:
                rand_keys = np.arange(batch_size)

        batch_joint_offsets = torch.zeros((batch_size, self.V - 1, self.C)).cuda()
        batch_pos = torch.zeros((batch_size, self.T, self.V, self.C)).cuda()
        batch_affs = torch.zeros((batch_size, self.T, self.A)).cuda()
        batch_quat = torch.zeros((batch_size, self.T, self.V * self.D)).cuda()
        batch_quat_valid_idx = torch.zeros((batch_size, self.T)).cuda()
        batch_text = torch.zeros((batch_size, self.Z)).cuda().long()
        batch_text_valid_idx = torch.zeros((batch_size, self.Z)).cuda()
        batch_intended_emotion = torch.zeros((batch_size, self.IE)).cuda()
        batch_intended_polarity = torch.zeros((batch_size, self.IP)).cuda()
        batch_acting_task = torch.zeros((batch_size, self.AT)).cuda()
        batch_gender = torch.zeros((batch_size, self.G)).cuda()
        batch_age = torch.zeros((batch_size, self.AGE)).cuda()
        batch_handedness = torch.zeros((batch_size, self.H)).cuda()
        batch_native_tongue = torch.zeros((batch_size, self.NT)).cuda()

        for i, k in enumerate(rand_keys):
            joint_offsets = torch.from_numpy(
                dataset[str(k).zfill(self.zfill)]['joints_dict']
                ['joints_offsets_all'][1:])
            pos = torch.from_numpy(
                dataset[str(k).zfill(self.zfill)]['positions'])
            affs = torch.from_numpy(
                dataset[str(k).zfill(self.zfill)]['affective_features'])
            quat = torch.cat(
                (self.quats_sos,
                 torch.from_numpy(dataset[str(k).zfill(self.zfill)]['rotations']),
                 self.quats_eos), dim=0)
            quat_length = quat.shape[0]
            quat_valid_idx = torch.zeros(self.T)
            quat_valid_idx[:quat_length] = 1
            text = torch.cat((torch.tensor(
                [self.word2idx[x] for x in
                 [e for e in str.split(dataset[str(k).zfill(self.zfill)]['Text'])
                  if e.isalnum()]]),
                torch.from_numpy(np.array([self.text_eos]))))
            if text[0] != self.text_sos:
                text = torch.cat(
                    (torch.from_numpy(np.array([self.text_sos])), text))
            text_length = text.shape[0]
            text_valid_idx = torch.zeros(self.Z)
            text_valid_idx[:text_length] = 1

            batch_joint_offsets[i] = joint_offsets
            batch_pos[i, :pos.shape[0]] = pos
            batch_pos[i, pos.shape[0]:] = pos[-1:].clone()
            batch_affs[i, :affs.shape[0]] = affs
            batch_affs[i, affs.shape[0]:] = affs[-1:].clone()
            batch_quat[i, :quat_length] = quat.view(quat_length, -1)
            batch_quat[i, quat_length:] = quat[-1:].view(1, -1).clone()
            batch_quat_valid_idx[i] = quat_valid_idx
            batch_text[i, :text_length] = text
            batch_text_valid_idx[i] = text_valid_idx
            batch_intended_emotion[i] = torch.from_numpy(
                dataset[str(k).zfill(self.zfill)]['Intended emotion'])
            batch_intended_polarity[i] = torch.from_numpy(
                dataset[str(k).zfill(self.zfill)]['Intended polarity'])
            batch_acting_task[i] = torch.from_numpy(
                dataset[str(k).zfill(self.zfill)]['Acting task'])
            batch_gender[i] = torch.from_numpy(
                dataset[str(k).zfill(self.zfill)]['Gender'])
            batch_age[i] = torch.tensor(
                dataset[str(k).zfill(self.zfill)]['Age'])
            batch_handedness[i] = torch.from_numpy(
                dataset[str(k).zfill(self.zfill)]['Handedness'])
            batch_native_tongue[i] = torch.from_numpy(
                dataset[str(k).zfill(self.zfill)]['Native tongue'])

        return batch_joint_offsets, batch_pos, batch_affs, batch_quat, \
            batch_quat_valid_idx, batch_text, batch_text_valid_idx, \
            batch_intended_emotion, batch_intended_polarity, batch_acting_task, \
            batch_gender, batch_age, batch_handedness, batch_native_tongue

    def per_train(self):
        self.model.train()
        train_loader = self.data_loader['train']
        batch_loss = 0.
        N = 0.

        for joint_offsets, pos, affs, quat, quat_valid_idx, \
                text, text_valid_idx, intended_emotion, intended_polarity, \
                acting_task, gender, age, handedness, \
                native_tongue in self.yield_batch(self.args.batch_size, train_loader):

            quat_prelude = self.quats_eos.view(1, -1).cuda() \
                .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
            quat_prelude[:, -1] = quat[:, 0].clone()

            self.optimizer.zero_grad()
            with torch.autograd.detect_anomaly():
                joint_lengths = torch.norm(joint_offsets, dim=-1)
                scales, _ = torch.max(joint_lengths, dim=-1)
                quat_pred, quat_pred_pre_norm = self.model(
                    text, intended_emotion, intended_polarity, acting_task,
                    gender, age, handedness, native_tongue, quat_prelude,
                    joint_lengths / scales[..., None])

                quat_pred_pre_norm = quat_pred_pre_norm.view(
                    quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1],
                    -1, self.D)
                quat_norm_loss = self.args.quat_norm_reg * \
                    torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

                quat_loss, quat_derv_loss = losses.quat_angle_loss(
                    quat_pred, quat[:, 1:], quat_valid_idx[:, 1:], self.V, self.D,
                    self.lower_body_start, self.args.upper_body_weight)
                quat_loss *= self.args.quat_reg

                root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                       self.C).cuda()
                pos_pred = MocapDataset.forward_kinematics(
                    quat_pred.contiguous().view(quat_pred.shape[0],
                                                quat_pred.shape[1], -1, self.D),
                    root_pos, self.joint_parents,
                    torch.cat((root_pos[:, 0:1], joint_offsets), dim=1).unsqueeze(1))
                affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

                # row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
                # row_sums[row_sums == 0.] = 1.
                shifted_pos = pos - pos[:, :, 0:1]
                shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]
                recons_loss = self.recons_loss_func(shifted_pos_pred,
                                                    shifted_pos[:, 1:])
                # recons_loss = torch.abs(shifted_pos_pred - shifted_pos[:, 1:]).sum(-1)
                # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) +\
                #     recons_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_loss = self.args.recons_reg *\
                #     torch.mean((recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                #
                # recons_derv_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos_pred[:, :-1] -
                #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
                # recons_derv_loss = self.args.upper_body_weight *\
                #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) +\
                #     recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_derv_loss = 2. * self.args.recons_reg *\
                #     torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
                #
                # affs_loss = torch.abs(affs[:, 1:] - affs_pred).sum(-1)
                # affs_loss = self.args.affs_reg * torch.mean((affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])

                train_loss = quat_norm_loss + quat_loss + recons_loss + affs_loss
                # train_loss = quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
                train_loss.backward()
                # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
                self.optimizer.step()

            # animation_pred = {
            #     'joint_names': self.joint_names,
            #     'joint_offsets': joint_offsets,
            #     'joint_parents': self.joint_parents,
            #     'positions': pos_pred,
            #     'rotations': quat_pred
            # }
            # MocapDataset.save_as_bvh(animation_pred,
            #                          dataset_name=self.dataset,
            #                          subset_name='test')
            # animation = {
            #     'joint_names': self.joint_names,
            #     'joint_offsets': joint_offsets,
            #     'joint_parents': self.joint_parents,
            #     'positions': pos,
            #     'rotations': quat
            # }
            # MocapDataset.save_as_bvh(animation,
            #                          dataset_name=self.dataset,
            #                          subset_name='gt')

            # Compute statistics
            batch_loss += train_loss.item()
            N += quat.shape[0]

            # statistics
            self.iter_info['loss'] = train_loss.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            self.iter_info['tf'] = '{:.6f}'.format(self.tf)
            self.show_iter_info()
            self.meta_info['iter'] += 1

        batch_loss = batch_loss / N
        self.epoch_info['mean_loss'] = batch_loss
        self.show_epoch_info()
        self.io.print_timer()
        self.adjust_lr()
        self.adjust_tf()
        # pos_pred_np = np.swapaxes(np.reshape(pos_pred.detach().cpu().numpy(),
        #                                      (pos_pred.shape[0], self.T - 1, -1)), 2, 1)
        # display_animations(pos_pred_np, self.joint_parents,
        #                    save=True, dataset_name=self.dataset, subset_name='test', overwrite=True)

    def per_eval(self):
        self.model.eval()
        test_loader = self.data_loader['test']
        eval_loss = 0.
        N = 0.

        for joint_offsets, pos, affs, quat, quat_valid_idx, \
                text, text_valid_idx, intended_emotion, intended_polarity, \
                acting_task, gender, age, handedness, \
                native_tongue in self.yield_batch(self.args.batch_size, test_loader):
            with torch.no_grad():
                joint_lengths = torch.norm(joint_offsets, dim=-1)
                scales, _ = torch.max(joint_lengths, dim=-1)
                quat_prelude = self.quats_eos.view(1, -1).cuda() \
                    .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
                quat_prelude[:, -1] = quat[:, 0].clone()
                quat_pred, quat_pred_pre_norm = self.model(
                    text, intended_emotion, intended_polarity, acting_task,
                    gender, age, handedness, native_tongue, quat_prelude,
                    joint_lengths / scales[..., None])

                quat_pred_pre_norm = quat_pred_pre_norm.view(
                    quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1],
                    -1, self.D)
                quat_norm_loss = self.args.quat_norm_reg * \
                    torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)

                quat_loss, quat_derv_loss = losses.quat_angle_loss(
                    quat_pred, quat[:, 1:], quat_valid_idx[:, 1:], self.V, self.D,
                    self.lower_body_start, self.args.upper_body_weight)
                quat_loss *= self.args.quat_reg

                root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                       self.C).cuda()
                pos_pred = MocapDataset.forward_kinematics(
                    quat_pred.contiguous().view(quat_pred.shape[0],
                                                quat_pred.shape[1], -1, self.D),
                    root_pos, self.joint_parents,
                    torch.cat((root_pos[:, 0:1], joint_offsets), dim=1).unsqueeze(1))
                affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

                row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
                row_sums[row_sums == 0.] = 1.
                shifted_pos = pos - pos[:, :, 0:1]
                shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]
                recons_loss = self.recons_loss_func(shifted_pos_pred,
                                                    shifted_pos[:, 1:])
                # recons_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos[:, 1:]).sum(-1)
                # recons_loss = self.args.upper_body_weight * (recons_loss[:, :, :self.lower_body_start].sum(-1)) + \
                #     recons_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_loss = self.args.recons_reg * torch.mean(
                #     (recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                #
                # recons_derv_loss = torch.abs(shifted_pos_pred[:, 2:] - shifted_pos_pred[:, 1:-1] -
                #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
                # recons_derv_loss = self.args.upper_body_weight * \
                #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) + \
                #     recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
                # recons_derv_loss = 2. * self.args.recons_reg * \
                #     torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
                #
                # affs_loss = torch.abs(affs[:, 1:] - affs_pred[:, 1:]).sum(-1)
                # affs_loss = self.args.affs_reg * torch.mean((affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
                affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])

                eval_loss += quat_norm_loss + quat_loss + recons_loss + affs_loss
                # eval_loss += quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
                N += quat.shape[0]

        eval_loss /= N
        self.epoch_info['mean_loss'] = eval_loss
        if self.epoch_info['mean_loss'] < self.best_loss and \
                self.meta_info['epoch'] > self.min_train_epochs:
            self.best_loss = self.epoch_info['mean_loss']
            self.best_loss_epoch = self.meta_info['epoch']
            self.loss_updated = True
        else:
            self.loss_updated = False
        self.show_epoch_info()

    def train(self):
        if self.args.load_last_best or (self.args.load_at_epoch is not None):
            model_found = self.load_model_at_epoch(
                epoch='best' if self.args.load_last_best else self.args.load_at_epoch)
            # use != rather than `is not` for string comparison
            if not model_found and self.args.start_epoch != 'best':
                print('Warning! Trying to load best known model: ', end='')
                model_found = self.load_model_at_epoch(epoch='best')
                self.args.start_epoch = self.best_loss_epoch if model_found else 0
                print('loaded.')
                if not model_found:
                    print('Warning! Starting at epoch 0')
                    self.args.start_epoch = 0
            else:
                self.args.start_epoch = self.best_loss_epoch
        else:
            self.args.start_epoch = 0

        for epoch in range(self.args.start_epoch, self.args.num_epoch):
            self.meta_info['epoch'] = epoch

            # training
            self.io.print_log('Training epoch: {}'.format(epoch))
            self.per_train()
            self.io.print_log('Done.')

            # evaluation
            if (epoch % self.args.eval_interval == 0) or (
                    epoch + 1 == self.args.num_epoch):
                self.io.print_log('Eval epoch: {}'.format(epoch))
                self.per_eval()
                self.io.print_log('Done.')

            # save model and weights
            if epoch > self.args.min_train_epochs and (
                    self.loss_updated or epoch % self.args.save_interval == 0):
                torch.save({'model_dict': self.model.state_dict()},
                           os.path.join(
                               self.args.work_dir,
                               'epoch_{}_loss_{:.4f}_model.pth.tar'.format(
                                   epoch, self.epoch_info['mean_loss'])))

            if self.generate_while_train:
                self.generate_motion(load_saved_model=False,
                                     samples_to_generate=1, epoch=epoch)

    def copy_prefix(self, var, prefix_length=None):
        if prefix_length is None:
            prefix_length = self.prefix_length
        return [var[s, :prefix_length].unsqueeze(0) for s in range(var.shape[0])]

    def generate_motion(self, load_saved_model=True, samples_to_generate=10,
                        epoch=None, randomized=True, animations_as_videos=False):
        if epoch is None:
            epoch = 'best'
        if load_saved_model:
            self.load_model_at_epoch(epoch=epoch)
        self.model.eval()
        test_loader = self.data_loader['test']

        joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, intended_emotion, intended_polarity, \
            acting_task, gender, age, handedness, \
            native_tongue = self.return_batch([samples_to_generate], test_loader,
                                              randomized=randomized)
        with torch.no_grad():
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            quat_prelude = self.quats_eos.view(1, -1).cuda() \
                .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
            quat_prelude[:, -1] = quat[:, 0].clone()
            quat_pred, quat_pred_pre_norm = self.model(
                text, intended_emotion, intended_polarity, acting_task, gender,
                age, handedness, native_tongue, quat_prelude,
                joint_lengths / scales[..., None])

            for s in range(len(quat_pred)):
                quat_pred[s] = qfix(
                    quat_pred[s].view(quat_pred[s].shape[0], self.V, -1)).view(
                        quat_pred[s].shape[0], -1)

            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1],
                                   self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(quat_pred.shape[0],
                                            quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets), dim=1).unsqueeze(1))

        animation_pred = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': pos_pred,
            'rotations': quat_pred
        }
        MocapDataset.save_as_bvh(animation_pred,
                                 dataset_name=self.dataset + '_glove',
                                 subset_name='test')
        animation = {
            'joint_names': self.joint_names,
            'joint_offsets': joint_offsets,
            'joint_parents': self.joint_parents,
            'positions': pos,
            'rotations': quat
        }
        MocapDataset.save_as_bvh(animation,
                                 dataset_name=self.dataset + '_glove',
                                 subset_name='gt')

        if animations_as_videos:
            pos_pred_np = pos_pred.contiguous().view(
                pos_pred.shape[0], pos_pred.shape[1], -1).permute(0, 2, 1). \
                detach().cpu().numpy()
            display_animations(pos_pred_np, self.joint_parents, save=True,
                               dataset_name=self.dataset,
                               subset_name='epoch_' + str(self.best_loss_epoch),
                               overwrite=True)
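
# --- Note on the quaternion averaging above (illustrative) --------------------
# `get_quats_sos_and_eos` estimates a mean start/end rotation per joint as the
# dominant eigenvector of the accumulated outer-product matrix sum_s q_s q_s^T,
# the standard eigendecomposition approach to quaternion averaging (Markley et
# al.). A small self-contained sketch of the same idea, written with `eigh` and
# an explicit argmax so the dominant eigenvector is picked regardless of the
# ordering `np.linalg.eig` happens to return:
import numpy as np

def mean_quaternion_sketch(quats):
    """quats: (N, 4) array of unit quaternions; returns an average rotation."""
    M = quats.T @ quats                      # 4x4 symmetric accumulator matrix
    eig_vals, eig_vecs = np.linalg.eigh(M)   # eigh: real output, ascending order
    return eig_vecs[:, np.argmax(eig_vals)]  # eigenvector of the largest eigenvalue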
# Constructor variant that builds its text vocabulary with torchtext instead of
# a precomputed word2idx/embedding table (cf. the Processor class above).
def __init__(self, args, data_path, data_loader, Z, T, A, V, C, D, tag_cats,
             IE, IP, AT, G, AGE, H, NT, joint_names, joint_parents,
             lower_body_start=15, fill=6, min_train_epochs=20,
             generate_while_train=False, save_path=None):

    def get_quats_sos_and_eos():
        quats_sos_and_eos_file = os.path.join(data_path, 'quats_sos_and_eos.npz')
        keys = list(self.data_loader['train'].keys())
        num_samples = len(self.data_loader['train'])
        try:
            mean_quats_sos = np.load(quats_sos_and_eos_file,
                                     allow_pickle=True)['quats_sos']
            mean_quats_eos = np.load(quats_sos_and_eos_file,
                                     allow_pickle=True)['quats_eos']
        except FileNotFoundError:
            mean_quats_sos = np.zeros((self.V, self.D))
            mean_quats_eos = np.zeros((self.V, self.D))
            for j in range(self.V):
                quats_sos = np.zeros((self.D, num_samples))
                quats_eos = np.zeros((self.D, num_samples))
                for s in range(num_samples):
                    quats_sos[:, s] = self.data_loader['train'][
                        keys[s]]['rotations'][0, j]
                    quats_eos[:, s] = self.data_loader['train'][
                        keys[s]]['rotations'][-1, j]
                _, sos_eig_vectors = np.linalg.eig(
                    np.dot(quats_sos, quats_sos.T))
                mean_quats_sos[j] = sos_eig_vectors[:, 0]
                _, eos_eig_vectors = np.linalg.eig(
                    np.dot(quats_eos, quats_eos.T))
                mean_quats_eos[j] = eos_eig_vectors[:, 0]
            np.savez_compressed(quats_sos_and_eos_file,
                                quats_sos=mean_quats_sos,
                                quats_eos=mean_quats_eos)
        mean_quats_sos = torch.from_numpy(mean_quats_sos).unsqueeze(0)
        mean_quats_eos = torch.from_numpy(mean_quats_eos).unsqueeze(0)
        for s in range(num_samples):
            pos_sos = MocapDataset.forward_kinematics(
                mean_quats_sos.unsqueeze(0),
                torch.from_numpy(self.data_loader['train'][keys[s]]
                                 ['positions'][0:1, 0]).double().unsqueeze(0),
                self.joint_parents,
                torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                                 ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
            affs_sos = MocapDataset.get_mpi_affective_features(pos_sos)
            pos_eos = MocapDataset.forward_kinematics(
                mean_quats_eos.unsqueeze(0),
                torch.from_numpy(self.data_loader['train'][keys[s]]
                                 ['positions'][-1:, 0]).double().unsqueeze(0),
                self.joint_parents,
                torch.from_numpy(self.data_loader['train'][keys[s]]['joints_dict']
                                 ['joints_offsets_all']).unsqueeze(0)).squeeze(0).numpy()
            affs_eos = MocapDataset.get_mpi_affective_features(pos_eos)
            self.data_loader['train'][keys[s]]['positions'] = np.concatenate(
                (pos_sos, self.data_loader['train'][keys[s]]['positions'], pos_eos),
                axis=0)
            self.data_loader['train'][keys[s]]['affective_features'] = np.concatenate(
                (affs_sos, self.data_loader['train'][keys[s]]['affective_features'],
                 affs_eos), axis=0)
        return mean_quats_sos, mean_quats_eos

    self.args = args
    self.dataset = args.dataset
    self.channel_map = {
        'Xrotation': 'x',
        'Yrotation': 'y',
        'Zrotation': 'z'
    }
    self.data_loader = data_loader
    self.result = dict()
    self.iter_info = dict()
    self.epoch_info = dict()
    self.meta_info = dict(epoch=0, iter=0)
    self.io = IO(self.args.work_dir,
                 save_log=self.args.save_log,
                 print_log=self.args.print_log)

    # model
    self.T = T + 2
    self.T_steps = 120
    self.A = A
    self.V = V
    self.C = C
    self.D = D
    self.O = 1
    self.tag_cats = tag_cats
    self.IE = IE
    self.IP = IP
    self.AT = AT
    self.G = G
    self.AGE = AGE
    self.H = H
    self.NT = NT
    self.joint_names = joint_names
    self.joint_parents = joint_parents
    self.lower_body_start = lower_body_start
    self.quats_sos, self.quats_eos = get_quats_sos_and_eos()
    # self.quats_sos = torch.from_numpy(Quaternions.id(self.V).qs).unsqueeze(0)
    # self.quats_eos = torch.from_numpy(Quaternions.from_euler(
    #     np.tile([np.pi / 2., 0, 0], (self.V, 1))).qs).unsqueeze(0)
    self.recons_loss_func = nn.L1Loss()
    self.affs_loss_func = nn.MSELoss()
    self.best_loss = np.inf
    self.loss_updated = False
    self.step_epochs = [math.ceil(float(self.args.num_epoch * x))
                        for x in self.args.step]
    self.best_loss_epoch = None
    self.min_train_epochs = min_train_epochs
    self.zfill = fill
    try:
        self.text_processor = torch.load('text_processor.pt')
    except FileNotFoundError:
        self.text_processor = tt.data.Field(
            tokenize=get_tokenizer("basic_english"),
            init_token='<sos>',
            eos_token='<eos>',
            lower=True)
        train_text, eval_text, test_text = tt.datasets.WikiText2.splits(
            self.text_processor)
        self.text_processor.build_vocab(train_text, eval_text, test_text)
    self.text_sos = np.int64(self.text_processor.vocab.stoi['<sos>'])
    self.text_eos = np.int64(self.text_processor.vocab.stoi['<eos>'])
    num_tokens = len(self.text_processor.vocab.stoi)  # the size of the vocabulary
    self.Z = Z + 2  # embedding dimension
    num_hidden_units_enc = 200  # the dimension of the feedforward network model in nn.TransformerEncoder
    num_hidden_units_dec = 200  # the dimension of the feedforward network model in nn.TransformerDecoder
    num_layers_enc = 2  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
    num_layers_dec = 2  # the number of nn.TransformerDecoderLayer in nn.TransformerDecoder
    num_heads_enc = 2  # the number of heads in the multi-head attention in nn.TransformerEncoder
    num_heads_dec = 2  # the number of heads in the multi-head attention in nn.TransformerDecoder
    # num_hidden_units_enc = 216  # the dimension of the feedforward network model in nn.TransformerEncoder
    # num_hidden_units_dec = 512  # the dimension of the feedforward network model in nn.TransformerDecoder
    # num_layers_enc = 2  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
    # num_layers_dec = 4  # the number of nn.TransformerDecoderLayer in nn.TransformerDecoder
    # num_heads_enc = 6  # the number of heads in the multi-head attention in nn.TransformerEncoder
    # num_heads_dec = self.V  # the number of heads in the multi-head attention in nn.TransformerDecoder
    dropout = 0.2  # the dropout value
    self.model = T2GNet(num_tokens, self.T - 1, self.Z, self.V * self.D, self.D,
                        self.V - 1, self.IE, self.IP, self.AT, self.G, self.AGE,
                        self.H, self.NT, num_heads_enc, num_heads_dec,
                        num_hidden_units_enc, num_hidden_units_dec,
                        num_layers_enc, num_layers_dec, dropout)
    if self.args.use_multiple_gpus and torch.cuda.device_count() > 1:
        self.args.batch_size *= torch.cuda.device_count()
        self.model = nn.DataParallel(self.model)
    self.model.to(torch.cuda.current_device())
    print('Total training data:\t\t{}'.format(len(self.data_loader['train'])))
    print('Total validation data:\t\t{}'.format(len(self.data_loader['test'])))
    print('Training with batch size:\t{}'.format(self.args.batch_size))

    # generate
    self.generate_while_train = generate_while_train
    self.save_path = save_path

    # optimizer
    if self.args.optimizer == 'SGD':
        self.optimizer = optim.SGD(self.model.parameters(),
                                   lr=self.args.base_lr,
                                   momentum=0.9,
                                   nesterov=self.args.nesterov,
                                   weight_decay=self.args.weight_decay)
    elif self.args.optimizer == 'Adam':
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.args.base_lr)
        # weight_decay=self.args.weight_decay)
    else:
        raise ValueError()
    self.lr = self.args.base_lr
    self.tf = self.args.base_tr
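
# --- Teacher-forcing schedule (illustrative) -----------------------------------
# The processors in this file keep a teacher-forcing ratio `self.tf`
# (initialized from `args.base_tr`) and multiply it by `args.tf_decay` once per
# epoch after a warm-up of 20 epochs (see `adjust_tf`). A standalone sketch of
# that schedule, with hypothetical values for the base ratio and decay factor:
def teacher_forcing_ratio_sketch(epoch, base_tr=1.0, tf_decay=0.995, warmup_epochs=20):
    """Exponentially decayed teacher-forcing ratio after a fixed warm-up."""
    if epoch <= warmup_epochs:
        return base_tr
    return base_tr * (tf_decay ** (epoch - warmup_epochs))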
class Processor(object): """ Processor for gait generation """ def __init__(self, args, dataset, data_loader, T, V, C, D, A, T_out, joint_parents, num_classes, affs_max, affs_min, min_train_epochs=-1, label_weights=None, generate_while_train=False, poses_mean=None, poses_std=None, save_path=None, device='cuda:0'): self.args = args self.device = device self.dataset = dataset self.data_loader = data_loader self.num_classes = num_classes self.affs_max = affs_max self.affs_min = affs_min self.result = dict() self.iter_info = dict() self.epoch_info = dict() self.meta_info = dict(epoch=0, iter=0) self.io = IO(self.args.work_dir, save_log=self.args.save_log, print_log=self.args.print_log) # model self.T = T self.V = V self.C = C self.D = D self.A = A self.T_out = T_out self.P = int(0.9 * T) self.joint_parents = joint_parents self.model = hap.HAPPY(self.dataset, T, V, C, A, T_out, num_classes, residual=self.args.residual) self.model.cuda(device) self.model.apply(weights_init) self.model_GRU_h_enc = None self.model_GRU_h_dec1 = None self.model_GRU_h_dec = None self.label_weights = torch.from_numpy(label_weights).cuda().float() self.loss = semisup_loss self.best_loss = math.inf self.best_mean_ap = 0. self.loss_updated = False self.mean_ap_updated = False self.step_epochs = [ math.ceil(float(self.args.num_epoch * x)) for x in self.args.step ] self.best_loss_epoch = None self.best_acc_epoch = None self.min_train_epochs = min_train_epochs self.beta = 0.1 # generate self.generate_while_train = generate_while_train self.poses_mean = poses_mean self.poses_std = poses_std self.save_path = save_path self.dataset = dataset # optimizer if self.args.optimizer == 'SGD': self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.base_lr, momentum=0.9, nesterov=self.args.nesterov, weight_decay=self.args.weight_decay) elif self.args.optimizer == 'Adam': self.optimizer = optim.Adam(self.model.parameters(), lr=self.args.base_lr) # weight_decay=self.args.weight_decay) else: raise ValueError() self.lr = self.args.base_lr self.tr = self.args.base_tr def process_data(self, data, poses, diffs, affs): data = data.float().to(self.device) poses = poses.float().to(self.device) diffs = diffs.float().to(self.device) affs = affs.float().to(self.device) return data, poses, diffs, affs def load_best_model(self, ): loaded_vars = torch.load( os.path.join(self.args.work_dir, 'taew_weights.pth.tar')) self.model.load_state_dict(loaded_vars['model_dict']) self.model_GRU_h_enc = loaded_vars['h_enc'] self.model_GRU_h_dec1 = loaded_vars['h_dec1'] self.model_GRU_h_dec = loaded_vars['h_dec'] def adjust_lr(self): self.lr = self.lr * self.args.lr_decay for param_group in self.optimizer.param_groups: param_group['lr'] = self.lr def adjust_tr(self): self.tr = self.tr * self.args.tr_decay def show_epoch_info(self, show_best=True): print_epochs = [ self.best_loss_epoch if self.best_loss_epoch is not None else 0, self.best_acc_epoch if self.best_acc_epoch is not None else 0, self.best_acc_epoch if self.best_acc_epoch is not None else 0 ] best_metrics = [self.best_loss, 0, self.best_mean_ap] i = 0 for k, v in self.epoch_info.items(): if show_best: self.io.print_log( '\t{}: {}. 
Best so far: {} (epoch: {:d}).'.format( k, v, best_metrics[i], print_epochs[i])) else: self.io.print_log('\t{}: {}.'.format(k, v)) i += 1 if self.args.pavi_log: self.io.log('train', self.meta_info['iter'], self.epoch_info) def show_iter_info(self): if self.meta_info['iter'] % self.args.log_interval == 0: info = '\tIter {} Done.'.format(self.meta_info['iter']) for k, v in self.iter_info.items(): if isinstance(v, float): info = info + ' | {}: {:.4f}'.format(k, v) else: info = info + ' | {}: {}'.format(k, v) self.io.print_log(info) if self.args.pavi_log: self.io.log('train', self.meta_info['iter'], self.iter_info) def per_train(self): self.model.train() train_loader = self.data_loader['train'] loss_value = [] ap_values = [] mean_ap_value = [] for data, poses, rots, affs, num_frames, labels in train_loader: # get data num_frames_actual, sort_idx = torch.sort( (num_frames * self.T).type(torch.IntTensor).cuda(), descending=True) seq_lens = num_frames_actual - 1 data = data[sort_idx, :, :] poses = poses[sort_idx, :] rots = rots[sort_idx, :, :] affs = affs[sort_idx, :] data, poses, rots, affs = self.process_data( data, poses, rots, affs) # forward labels_pred, diffs_recons, diffs_recons_pre_norm, affs_pred,\ self.model_GRU_h_enc, self.model_GRU_h_dec1, self.model_GRU_h_dec =\ self.model(poses, rots[:, :, 1:], affs, teacher_steps=int(self.T * self.tr)) loss = self.loss(labels, labels_pred, rots, diffs_recons, diffs_recons_pre_norm, self.meta_info['epoch'], affs, affs_pred, self.V, self.D, self.num_classes[0], label_weights=self.label_weights) # backward self.optimizer.zero_grad() loss.backward() # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.grad_clip) self.optimizer.step() # statistics self.iter_info['loss'] = loss.data.item() self.iter_info['aps'], self.iter_info['mean_ap'], self.iter_info['f1'], _ =\ calculate_metrics(labels[0].detach().cpu().numpy(), labels_pred[0].detach().cpu().numpy()) self.iter_info['lr'] = '{:.6f}'.format(self.lr) self.iter_info['tr'] = '{:.6f}'.format(self.tr) loss_value.append(self.iter_info['loss']) ap_values.append(self.iter_info['aps']) mean_ap_value.append(self.iter_info['mean_ap']) self.show_iter_info() self.meta_info['iter'] += 1 self.epoch_info['mean_loss'] = np.mean(loss_value) self.epoch_info['mean_aps'] = np.mean(ap_values, axis=0) self.epoch_info['mean_mean_ap'] = np.mean(mean_ap_value) self.show_epoch_info() self.io.print_timer() self.adjust_lr() self.adjust_tr() def per_test(self, epoch=None, evaluation=True): self.model.eval() test_loader = self.data_loader['test'] loss_value = [] mean_ap_value = [] ap_values = [] label_frag = [] for data, poses, diffs, affs, num_frames, labels in test_loader: # get data num_frames_actual, sort_idx = torch.sort( (num_frames * self.T).type(torch.IntTensor).cuda(), descending=True) seq_lens = num_frames_actual - 1 data = data[sort_idx, :] poses = poses[sort_idx, :] diffs = diffs[sort_idx, :, :] affs = affs[sort_idx, :] data, poses, diffs, affs = self.process_data( data, poses, diffs, affs) # inference with torch.no_grad(): labels_pred, diffs_recons, diffs_recons_pre_norm, affs_pred, _, _, _ = \ self.model(poses, diffs[:, :, 1:], affs, teacher_steps=int(self.T * self.tr)) # get loss if evaluation: loss = self.loss(labels, labels_pred, diffs, diffs_recons, diffs_recons_pre_norm, self.meta_info['epoch'], affs, affs_pred, self.V, self.D, self.num_classes[0], label_weights=self.label_weights, eval_time=True) loss_value.append(loss.item()) ap, mean_ap, _, _ = calculate_metrics( labels[0].detach().cpu().numpy(), 
labels_pred[0].detach().cpu().numpy(), eval_time=True) ap_values.append(ap) mean_ap_value.append(mean_ap) label_frag.append(labels[0].data.cpu().numpy()) if evaluation: self.epoch_info['mean_loss'] = np.mean(loss_value) self.epoch_info['mean_aps'] = np.mean(ap_values, axis=0) self.epoch_info['mean_mean_ap'] = np.mean(mean_ap_value) if self.epoch_info[ 'mean_loss'] < self.best_loss and epoch > self.min_train_epochs: self.best_loss = self.epoch_info['mean_loss'] self.best_loss_epoch = self.meta_info['epoch'] self.loss_updated = True else: self.loss_updated = False if self.epoch_info[ 'mean_mean_ap'] > self.best_mean_ap and epoch > self.min_train_epochs: self.best_mean_ap = self.epoch_info['mean_mean_ap'] self.best_acc_epoch = self.meta_info['epoch'] self.mean_ap_updated = True else: self.mean_ap_updated = False self.show_epoch_info() def train(self): if self.args.load_last_best: self.load_best_model() self.args.start_epoch = self.best_loss_epoch for epoch in range(self.args.start_epoch, self.args.num_epoch): self.meta_info['epoch'] = epoch # training self.io.print_log('Training epoch: {}'.format(epoch)) self.per_train() self.io.print_log('Done.') # evaluation if (epoch % self.args.eval_interval == 0) or (epoch + 1 == self.args.num_epoch): self.io.print_log('Eval epoch: {}'.format(epoch)) self.per_test(epoch) self.io.print_log('Done.') # save model and weights if self.loss_updated or self.mean_ap_updated: torch.save( { 'model_dict': self.model.state_dict(), 'h_enc': self.model_GRU_h_enc, 'h_dec1': self.model_GRU_h_dec1, 'h_dec': self.model_GRU_h_dec }, os.path.join( self.args.work_dir, 'epoch_{}_loss_{:.4f}_acc_{:.2f}_model.pth.tar'.format( epoch, self.best_loss, self.best_mean_ap * 100.))) if self.generate_while_train: self.generate(load_saved_model=False, samples_to_generate=1) def test(self): # the path of weights must be appointed if self.args.weights is None: raise ValueError('Please appoint --weights.') self.io.print_log('Model: {}.'.format(self.args.model)) self.io.print_log('Weights: {}.'.format(self.args.weights)) # evaluation self.io.print_log('Evaluation Start:') self.per_test() self.io.print_log('Done.\n') # save the output of model if self.args.save_result: result_dict = dict( zip(self.data_loader['test'].dataset.sample_name, self.result)) self.io.save_pkl(result_dict, 'test_result.pkl') def evaluate_model(self, load_saved_model=True): if load_saved_model: self.load_best_model() self.model.eval() test_loader = self.data_loader['test'] loss_value = [] mean_ap_value = [] ap_values = [] label_frag = [] for data, poses, diffs, affs, num_frames, labels in test_loader: # get data num_frames_actual, sort_idx = torch.sort( (num_frames * self.T).type(torch.IntTensor).cuda(), descending=True) seq_lens = num_frames_actual - 1 data = data[sort_idx, :] poses = poses[sort_idx, :] diffs = diffs[sort_idx, :, :] affs = affs[sort_idx, :] data, poses, diffs, affs = self.process_data( data, poses, diffs, affs) # inference with torch.no_grad(): labels_pred, diffs_recons, diffs_recons_pre_norm, affs_pred, _, _, _ = \ self.model(poses, diffs[:, :, 1:], affs, teacher_steps=int(self.T * self.tr)) # get loss loss = self.loss(labels, labels_pred, diffs, diffs_recons, diffs_recons_pre_norm, self.meta_info['epoch'], affs, affs_pred, self.V, self.D, self.num_classes[0], label_weights=self.label_weights, eval_time=True) loss_value.append(loss.item()) ap, mean_ap, _, _ = calculate_metrics( labels[0].detach().cpu().numpy(), labels_pred[0].detach().cpu().numpy(), eval_time=True) ap_values.append(ap) 
            mean_ap_value.append(mean_ap)
            label_frag.append(labels[0].data.cpu().numpy())

        self.epoch_info['mean_loss'] = np.mean(loss_value)
        self.epoch_info['aps'] = np.mean(ap_values, axis=0)
        self.epoch_info['mean_ap'] = np.mean(mean_ap_value)
        self.show_epoch_info(show_best=False)
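# ---------------------------------------------------------------------------
# The statistics tracked above come from an external `calculate_metrics`
# helper that returns per-class average precisions, their mean, an F1 score,
# and a fourth (unused) value. As an illustration only -- not the repo's
# implementation -- a multi-label version of such a metric could be sketched
# with scikit-learn as below; the 0.5 score threshold for F1 is an assumption.

import numpy as np
from sklearn.metrics import average_precision_score, f1_score

def calculate_metrics_sketch(y_true, y_score):
    """y_true: (N, C) binary targets; y_score: (N, C) prediction scores."""
    # per-class average precision over the batch
    aps = np.array([average_precision_score(y_true[:, c], y_score[:, c])
                    for c in range(y_true.shape[1])])
    mean_ap = float(np.nanmean(aps))
    # macro F1 after thresholding the scores (assumed threshold: 0.5)
    f1 = f1_score(y_true, (y_score > 0.5).astype(int), average='macro')
    return aps, mean_ap, f1, None
# ---------------------------------------------------------------------------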
class Processor(object): """ Processor for emotive gait generation """ def __init__(self, args, dataset, data_loader, T, V, C, D, A, S, joints_dict, joint_names, joint_offsets, joint_parents, num_labels, prefix_length, target_length, min_train_epochs=20, generate_while_train=False, save_path=None, device='cuda:0'): self.args = args self.dataset = dataset self.mocap = MocapDataset(V, C, np.arange(V), joints_dict) self.joint_names = joint_names self.joint_offsets = joint_offsets self.joint_parents = joint_parents self.device = device self.data_loader = data_loader self.num_labels = num_labels self.result = dict() self.iter_info = dict() self.epoch_info = dict() self.meta_info = dict(epoch=0, iter=0) self.io = IO( self.args.work_dir, save_log=self.args.save_log, print_log=self.args.print_log) # model self.T = T self.V = V self.C = C self.D = D self.A = A self.S = S self.O = 4 self.Z = 1 self.RS = 1 self.o_scale = 10. self.prefix_length = prefix_length self.target_length = target_length self.model = quater_emonet.QuaterEmoNet(V, D, S, A, self.O, self.Z, self.RS, num_labels[0]) self.model.cuda(device) self.orient_h = None self.quat_h = None self.z_rs_loss_func = nn.L1Loss() self.affs_loss_func = nn.L1Loss() self.spline_loss_func = nn.L1Loss() self.best_loss = math.inf self.loss_updated = False self.mean_ap_updated = False self.step_epochs = [math.ceil(float(self.args.num_epoch * x)) for x in self.args.step] self.best_loss_epoch = None self.min_train_epochs = min_train_epochs # generate self.generate_while_train = generate_while_train self.save_path = save_path # optimizer if self.args.optimizer == 'SGD': self.optimizer = optim.SGD( self.model.parameters(), lr=self.args.base_lr, momentum=0.9, nesterov=self.args.nesterov, weight_decay=self.args.weight_decay) elif self.args.optimizer == 'Adam': self.optimizer = optim.Adam( self.model.parameters(), lr=self.args.base_lr) # weight_decay=self.args.weight_decay) else: raise ValueError() self.lr = self.args.base_lr self.tf = self.args.base_tr def process_data(self, data, poses, quat, trans, affs): data = data.float().to(self.device) poses = poses.float().to(self.device) quat = quat.float().to(self.device) trans = trans.float().to(self.device) affs = affs.float().to(self.device) return data, poses, quat, trans, affs def load_best_model(self, ): model_name, self.best_loss_epoch, self.best_loss =\ get_best_epoch_and_loss(self.args.work_dir) best_model_found = False try: loaded_vars = torch.load(os.path.join(self.args.work_dir, model_name)) self.model.load_state_dict(loaded_vars['model_dict']) self.orient_h = loaded_vars['orient_h'] self.quat_h = loaded_vars['quat_h'] best_model_found = True except (FileNotFoundError, IsADirectoryError): print('No saved model found.') return best_model_found def adjust_lr(self): self.lr = self.lr * self.args.lr_decay for param_group in self.optimizer.param_groups: param_group['lr'] = self.lr def adjust_tf(self): if self.meta_info['epoch'] > 20: self.tf = self.tf * self.args.tf_decay def show_epoch_info(self): print_epochs = [self.best_loss_epoch if self.best_loss_epoch is not None else 0] best_metrics = [self.best_loss] i = 0 for k, v in self.epoch_info.items(): self.io.print_log('\t{}: {}. Best so far: {} (epoch: {:d}).'. 
format(k, v, best_metrics[i], print_epochs[i])) i += 1 if self.args.pavi_log: self.io.log('train', self.meta_info['iter'], self.epoch_info) def show_iter_info(self): if self.meta_info['iter'] % self.args.log_interval == 0: info = '\tIter {} Done.'.format(self.meta_info['iter']) for k, v in self.iter_info.items(): if isinstance(v, float): info = info + ' | {}: {:.4f}'.format(k, v) else: info = info + ' | {}: {}'.format(k, v) self.io.print_log(info) if self.args.pavi_log: self.io.log('train', self.meta_info['iter'], self.iter_info) def yield_batch(self, batch_size, dataset): batch_pos = np.zeros((batch_size, self.T, self.V, self.C), dtype='float32') batch_quat = np.zeros((batch_size, self.T, (self.V - 1) * self.D), dtype='float32') batch_orient = np.zeros((batch_size, self.T, self.O), dtype='float32') batch_z_mean = np.zeros((batch_size, self.Z), dtype='float32') batch_z_dev = np.zeros((batch_size, self.T, self.Z), dtype='float32') batch_root_speed = np.zeros((batch_size, self.T, self.RS), dtype='float32') batch_affs = np.zeros((batch_size, self.T, self.A), dtype='float32') batch_spline = np.zeros((batch_size, self.T, self.S), dtype='float32') batch_labels = np.zeros((batch_size, 1, self.num_labels[0]), dtype='float32') pseudo_passes = (len(dataset) + batch_size - 1) // batch_size probs = [] for k in dataset.keys(): if 'spline' not in dataset[k]: raise KeyError('No splines found. Perhaps you forgot to compute them?') probs.append(dataset[k]['spline'].size()) probs = np.array(probs) / np.sum(probs) for p in range(pseudo_passes): rand_keys = np.random.choice(len(dataset), size=batch_size, replace=True, p=probs) for i, k in enumerate(rand_keys): pos = dataset[str(k)]['positions'][:self.T] quat = dataset[str(k)]['rotations'][:self.T, 1:] orient = dataset[str(k)]['rotations'][:self.T, 0] affs = dataset[str(k)]['affective_features'][:self.T] spline, phase = Spline.extract_spline_features(dataset[str(k)]['spline']) spline = spline[:self.T] phase = phase[:self.T] z = dataset[str(k)]['trans_and_controls'][:, 1][:self.T] z_mean = np.mean(z[:self.prefix_length]) z_dev = z - z_mean root_speed = dataset[str(k)]['trans_and_controls'][:, -1][:self.T] labels = dataset[str(k)]['labels'][:self.num_labels[0]] batch_pos[i] = pos batch_quat[i] = quat.reshape(self.T, -1) batch_orient[i] = orient.reshape(self.T, -1) batch_z_mean[i] = z_mean.reshape(-1, 1) batch_z_dev[i] = z_dev.reshape(self.T, -1) batch_root_speed[i] = root_speed.reshape(self.T, 1) batch_affs[i] = affs batch_spline[i] = spline batch_labels[i] = np.expand_dims(labels, axis=0) yield batch_pos, batch_quat, batch_orient, batch_z_mean, batch_z_dev,\ batch_root_speed, batch_affs, batch_spline, batch_labels def return_batch(self, batch_size, dataset, randomized=True): if len(batch_size) > 1: rand_keys = np.copy(batch_size) batch_size = len(batch_size) else: batch_size = batch_size[0] probs = [] for k in dataset.keys(): if 'spline' not in dataset[k]: raise KeyError('No splines found. 
Perhaps you forgot to compute them?') probs.append(dataset[k]['spline'].size()) probs = np.array(probs) / np.sum(probs) if randomized: rand_keys = np.random.choice(len(dataset), size=batch_size, replace=False, p=probs) else: rand_keys = np.arange(batch_size) batch_pos = np.zeros((batch_size, self.T, self.V, self.C), dtype='float32') batch_quat = np.zeros((batch_size, self.T, (self.V - 1) * self.D), dtype='float32') batch_orient = np.zeros((batch_size, self.T, self.O), dtype='float32') batch_z_mean = np.zeros((batch_size, self.Z), dtype='float32') batch_z_dev = np.zeros((batch_size, self.T, self.Z), dtype='float32') batch_root_speed = np.zeros((batch_size, self.T, self.RS), dtype='float32') batch_affs = np.zeros((batch_size, self.T, self.A), dtype='float32') batch_spline = np.zeros((batch_size, self.T, self.S), dtype='float32') batch_labels = np.zeros((batch_size, 1, self.num_labels[0]), dtype='float32') pseudo_passes = (len(dataset) + batch_size - 1) // batch_size for i, k in enumerate(rand_keys): pos = dataset[str(k)]['positions'][:self.T] quat = dataset[str(k)]['rotations'][:self.T, 1:] orient = dataset[str(k)]['rotations'][:self.T, 0] affs = dataset[str(k)]['affective_features'][:self.T] spline, phase = Spline.extract_spline_features(dataset[str(k)]['spline']) spline = spline[:self.T] phase = phase[:self.T] z = dataset[str(k)]['trans_and_controls'][:, 1][:self.T] z_mean = np.mean(z[:self.prefix_length]) z_dev = z - z_mean root_speed = dataset[str(k)]['trans_and_controls'][:, -1][:self.T] labels = dataset[str(k)]['labels'][:self.num_labels[0]] batch_pos[i] = pos batch_quat[i] = quat.reshape(self.T, -1) batch_orient[i] = orient.reshape(self.T, -1) batch_z_mean[i] = z_mean.reshape(-1, 1) batch_z_dev[i] = z_dev.reshape(self.T, -1) batch_root_speed[i] = root_speed.reshape(self.T, 1) batch_affs[i] = affs batch_spline[i] = spline batch_labels[i] = np.expand_dims(labels, axis=0) return batch_pos, batch_quat, batch_orient, batch_z_mean, batch_z_dev,\ batch_root_speed, batch_affs, batch_spline, batch_labels def per_train(self): self.model.train() train_loader = self.data_loader['train'] batch_loss = 0. N = 0. 
for pos, quat, orient, z_mean, z_dev,\ root_speed, affs, spline, labels in self.yield_batch(self.args.batch_size, train_loader): pos = torch.from_numpy(pos).cuda() orient = torch.from_numpy(orient).cuda() quat = torch.from_numpy(quat).cuda() z_mean = torch.from_numpy(z_mean).cuda() z_dev = torch.from_numpy(z_dev).cuda() root_speed = torch.from_numpy(root_speed).cuda() affs = torch.from_numpy(affs).cuda() spline = torch.from_numpy(spline).cuda() labels = torch.from_numpy(labels).cuda().repeat(1, quat.shape[1], 1) z_rs = torch.cat((z_dev, root_speed), dim=-1) quat_all = torch.cat((orient[:, self.prefix_length - 1:], quat[:, self.prefix_length - 1:]), dim=-1) pos_pred = pos.clone() orient_pred = orient.clone() quat_pred = quat.clone() z_rs_pred = z_rs.clone() affs_pred = affs.clone() spline_pred = spline.clone() pos_pred_all = pos.clone() orient_pred_all = orient.clone() quat_pred_all = quat.clone() z_rs_pred_all = z_rs.clone() affs_pred_all = affs.clone() spline_pred_all = spline.clone() orient_prenorm_terms = torch.zeros_like(orient_pred) quat_prenorm_terms = torch.zeros_like(quat_pred) # forward self.optimizer.zero_grad() for t in range(self.target_length): orient_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\ quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\ z_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\ self.orient_h, self.quat_h,\ orient_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1],\ quat_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1] = \ self.model( orient_pred[:, t:self.prefix_length + t], quat_pred[:, t:self.prefix_length + t], z_rs_pred[:, t:self.prefix_length + t], affs_pred[:, t:self.prefix_length + t], spline_pred[:, t:self.prefix_length + t], labels[:, t:self.prefix_length + t], orient_h=None if t == 0 else self.orient_h, quat_h=None if t == 0 else self.quat_h, return_prenorm=True) pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\ affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \ spline_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1] = \ self.mocap.get_predicted_features( pos_pred[:, :self.prefix_length + t], pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]], z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean, orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1], quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1]) if np.random.uniform(size=1)[0] > self.tf: pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \ pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone() orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \ orient_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone() quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \ quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone() z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \ z_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone() affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \ affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone() spline_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \ spline_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone() prenorm_terms = torch.cat((orient_prenorm_terms, quat_prenorm_terms), dim=-1) prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0], 
prenorm_terms.shape[1], -1, self.D) quat_norm_loss = self.args.quat_norm_reg * torch.mean((torch.sum(prenorm_terms ** 2, dim=-1) - 1) ** 2) quat_loss, quat_derv_loss = losses.quat_angle_loss( torch.cat((orient_pred_all[:, self.prefix_length - 1:], quat_pred_all[:, self.prefix_length - 1:]), dim=-1), quat_all, self.V, self.D) quat_loss *= self.args.quat_reg z_rs_loss = self.z_rs_loss_func(z_rs_pred_all[:, self.prefix_length:], z_rs[:, self.prefix_length:]) spline_loss = self.spline_loss_func(spline_pred_all[:, self.prefix_length:], spline[:, self.prefix_length:]) fs_loss = losses.foot_speed_loss(pos_pred, pos) loss_total = quat_norm_loss + quat_loss + quat_derv_loss + z_rs_loss + spline_loss + fs_loss loss_total.backward() # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip) self.optimizer.step() # animation_pred = { # 'joint_names': self.joint_names, # 'joint_offsets': torch.from_numpy(self.joint_offsets[1:]). # float().unsqueeze(0).repeat(pos_pred_all.shape[0], 1, 1), # 'joint_parents': self.joint_parents, # 'positions': pos_pred_all, # 'rotations': torch.cat((orient_pred_all, quat_pred_all), dim=-1) # } # MocapDataset.save_as_bvh(animation_pred, # dataset_name=self.dataset, # subset_name='test') # Compute statistics batch_loss += loss_total.item() N += quat.shape[0] # statistics self.iter_info['loss'] = loss_total.data.item() self.iter_info['lr'] = '{:.6f}'.format(self.lr) self.iter_info['tf'] = '{:.6f}'.format(self.tf) self.show_iter_info() self.meta_info['iter'] += 1 batch_loss = batch_loss / N self.epoch_info['mean_loss'] = batch_loss self.show_epoch_info() self.io.print_timer() self.adjust_lr() self.adjust_tf() def per_test(self): self.model.eval() test_loader = self.data_loader['test'] valid_loss = 0. N = 0. for pos, quat, orient, z_mean, z_dev,\ root_speed, affs, spline, labels in self.yield_batch(self.args.batch_size, test_loader): with torch.no_grad(): pos = torch.from_numpy(pos).cuda() orient = torch.from_numpy(orient).cuda() quat = torch.from_numpy(quat).cuda() z_mean = torch.from_numpy(z_mean).cuda() z_dev = torch.from_numpy(z_dev).cuda() root_speed = torch.from_numpy(root_speed).cuda() affs = torch.from_numpy(affs).cuda() spline = torch.from_numpy(spline).cuda() labels = torch.from_numpy(labels).cuda().repeat(1, quat.shape[1], 1) z_rs = torch.cat((z_dev, root_speed), dim=-1) quat_all = torch.cat((orient[:, self.prefix_length - 1:], quat[:, self.prefix_length - 1:]), dim=-1) pos_pred = pos.clone() orient_pred = orient.clone() quat_pred = quat.clone() z_rs_pred = z_rs.clone() affs_pred = affs.clone() spline_pred = spline.clone() orient_prenorm_terms = torch.zeros_like(orient_pred) quat_prenorm_terms = torch.zeros_like(quat_pred) # forward for t in range(self.target_length): orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\ quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\ z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\ self.orient_h, self.quat_h,\ orient_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1],\ quat_prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1] = \ self.model( orient_pred[:, t:self.prefix_length + t], quat_pred[:, t:self.prefix_length + t], z_rs_pred[:, t:self.prefix_length + t], affs_pred[:, t:self.prefix_length + t], spline[:, t:self.prefix_length + t], labels[:, t:self.prefix_length + t], orient_h=None if t == 0 else self.orient_h, quat_h=None if t == 0 else self.quat_h, return_prenorm=True) pos_pred[:, self.prefix_length + t:self.prefix_length + t 
+ 1], \ affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1],\ spline_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \ self.mocap.get_predicted_features( pos_pred[:, :self.prefix_length + t], pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]], z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean, orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1], quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1]) prenorm_terms = torch.cat((orient_prenorm_terms, quat_prenorm_terms), dim=-1) prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0], prenorm_terms.shape[1], -1, self.D) quat_norm_loss = self.args.quat_norm_reg *\ torch.mean((torch.sum(prenorm_terms ** 2, dim=-1) - 1) ** 2) quat_loss, quat_derv_loss = losses.quat_angle_loss( torch.cat((orient_pred[:, self.prefix_length - 1:], quat_pred[:, self.prefix_length - 1:]), dim=-1), quat_all, self.V, self.D) quat_loss *= self.args.quat_reg recons_loss = self.args.recons_reg *\ (pos_pred[:, self.prefix_length:] - pos_pred[:, self.prefix_length:, 0:1] - pos[:, self.prefix_length:] + pos[:, self.prefix_length:, 0:1]).norm() valid_loss += recons_loss N += quat.shape[0] valid_loss /= N # if self.meta_info['epoch'] > 5 and self.loss_updated: # pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0], pos_pred.shape[1], -1).permute(0, 2, 1).\ # detach().cpu().numpy() # display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True, # dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch), # overwrite=True) # pos_in_np = pos_in.contiguous().view(pos_in.shape[0], pos_in.shape[1], -1).permute(0, 2, 1).\ # detach().cpu().numpy() # display_animations(pos_in_np, self.V, self.C, self.mocap.joint_parents, save=True, # dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch) + # '_gt', # overwrite=True) self.epoch_info['mean_loss'] = valid_loss if self.epoch_info['mean_loss'] < self.best_loss and self.meta_info['epoch'] > self.min_train_epochs: self.best_loss = self.epoch_info['mean_loss'] self.best_loss_epoch = self.meta_info['epoch'] self.loss_updated = True else: self.loss_updated = False self.show_epoch_info() def train(self): if self.args.load_last_best: best_model_found = self.load_best_model() self.args.start_epoch = self.best_loss_epoch if best_model_found else 0 for epoch in range(self.args.start_epoch, self.args.num_epoch): self.meta_info['epoch'] = epoch # training self.io.print_log('Training epoch: {}'.format(epoch)) self.per_train() self.io.print_log('Done.') # evaluation if (epoch % self.args.eval_interval == 0) or ( epoch + 1 == self.args.num_epoch): self.io.print_log('Eval epoch: {}'.format(epoch)) self.per_test() self.io.print_log('Done.') # save model and weights if self.loss_updated: torch.save({'model_dict': self.model.state_dict(), 'orient_h': self.orient_h, 'quat_h': self.quat_h}, os.path.join(self.args.work_dir, 'epoch_{}_loss_{:.4f}_model.pth.tar'. 
format(epoch, self.best_loss))) if self.generate_while_train: self.generate_motion(load_saved_model=False, samples_to_generate=1) def copy_prefix(self, var, prefix_length=None): if prefix_length is None: prefix_length = self.prefix_length return [var[s, :prefix_length].unsqueeze(0) for s in range(var.shape[0])] def generate_linear_trajectory(self, traj, alpha=0.001): traj_markers = (traj[:, self.prefix_length - 2] + (traj[:, self.prefix_length - 1] - traj[:, self.prefix_length - 2]) / alpha).unsqueeze(1) return traj_markers def generate_circular_trajectory(self, traj, alpha=5., num_segments=10): last_segment = alpha * traj[:, self.prefix_length - 1:self.prefix_length] -\ traj[:, self.prefix_length - 2:self.prefix_length - 1] last_marker = traj[:, self.prefix_length - 1:self.prefix_length] traj_markers = last_marker.clone() angle_per_segment = 2. * np.pi / num_segments for _ in range(num_segments): next_segment = qrot(expmap_to_quaternion( torch.tensor([0, -angle_per_segment, 0]).cuda().float().repeat( last_segment.shape[0], last_segment.shape[1], 1)), torch.cat(( last_segment[..., 0:1], torch.zeros_like(last_segment[..., 0:1]), last_segment[..., 1:]), dim=-1))[..., [0, 2]] next_marker = next_segment + last_marker traj_markers = torch.cat((traj_markers, next_marker), dim=1) last_segment = next_segment.clone() last_marker = next_marker.clone() traj_markers = traj_markers[:, 1:] return traj_markers def compute_next_traj_point(self, traj, traj_marker, rs_pred): tangent = traj_marker - traj tangent /= (torch.norm(tangent, dim=-1) + 1e-9) return tangent * rs_pred + traj def compute_next_traj_point_sans_markers(self, pos_last, quat_next, z_pred, rs_pred): # pos_next = torch.zeros_like(pos_last) offsets = torch.from_numpy(self.mocap.joint_offsets).cuda().float(). 
\ unsqueeze(0).unsqueeze(0).repeat(pos_last.shape[0], pos_last.shape[1], 1, 1) pos_next = MocapDataset.forward_kinematics(quat_next.contiguous().view(quat_next.shape[0], quat_next.shape[1], -1, self.D), pos_last[:, :, 0], self.joint_parents, torch.from_numpy(self.joint_offsets).float().cuda()) # for joint in range(1, self.V): # pos_next[:, :, joint] = qrot(quat_copy[:, :, joint - 1], offsets[:, :, joint]) \ # + pos_next[:, :, self.mocap.joint_parents[joint]] root = pos_next[:, :, 0] l_shoulder = pos_next[:, :, 18] r_shoulder = pos_next[:, :, 25] facing = torch.cross(l_shoulder - root, r_shoulder - root, dim=-1)[..., [0, 2]] facing /= (torch.norm(facing, dim=-1)[..., None] + 1e-9) return rs_pred * facing + pos_last[:, :, 0, [0, 2]] def get_diff_from_traj(self, pos_pred, traj_pred, s): root = pos_pred[s][:, :, 0] l_shoulder = pos_pred[s][:, :, 18] r_shoulder = pos_pred[s][:, :, 25] facing = torch.cross(l_shoulder - root, r_shoulder - root, dim=-1)[..., [0, 2]] facing /= (torch.norm(facing, dim=-1)[..., None] + 1e-9) tangents = traj_pred[s][:, 1:] - traj_pred[s][:, :-1] tangent_norms = torch.norm(tangents, dim=-1) tangents /= (tangent_norms[..., None] + 1e-9) tangents = torch.cat((torch.zeros_like(tangents[:, 0:1]), tangents), dim=1) tangent_norms = torch.cat((torch.zeros_like(tangent_norms[:, 0:1]), tangent_norms), dim=1) axis_diff = torch.cross(torch.cat((facing[..., 0:1], torch.zeros_like(facing[..., 0:1]), facing[..., 1:]), dim=-1), torch.cat((tangents[..., 0:1], torch.zeros_like(tangents[..., 0:1]), tangents[..., 1:]), dim=-1)) axis_diff_norms = torch.norm(axis_diff, dim=-1) axis_diff /= (axis_diff_norms[..., None] + 1e-9) angle_diff = torch.acos(torch.einsum('ijk,ijk->ij', facing, tangents).clamp(min=-1., max=1.)) angle_diff[tangent_norms < 1e-6] = 0. 
return axis_diff, angle_diff def rotate_gaits(self, orient_pred, quat_pred, quat_diff, head_tilt, l_shoulder_slouch, r_shoulder_slouch): quat_reshape = quat_pred.contiguous().view(quat_pred.shape[0], quat_pred.shape[1], -1, self.D).clone() quat_reshape[..., 14, :] = qmul(torch.from_numpy(head_tilt).cuda().float(), quat_reshape[..., 14, :]) quat_reshape[..., 16, :] = qmul(torch.from_numpy(l_shoulder_slouch).cuda().float(), quat_reshape[..., 16, :]) quat_reshape[..., 17, :] = qmul(torch.from_numpy(qinv(l_shoulder_slouch)).cuda().float(), quat_reshape[..., 17, :]) quat_reshape[..., 23, :] = qmul(torch.from_numpy(r_shoulder_slouch).cuda().float(), quat_reshape[..., 23, :]) quat_reshape[..., 24, :] = qmul(torch.from_numpy(qinv(r_shoulder_slouch)).cuda().float(), quat_reshape[..., 24, :]) return qmul(quat_diff, orient_pred), quat_reshape.contiguous().view(quat_reshape.shape[0], quat_reshape.shape[1], -1) def generate_motion(self, load_saved_model=True, samples_to_generate=1534, max_steps=300, randomized=True): if load_saved_model: self.load_best_model() self.model.eval() test_loader = self.data_loader['test'] pos, quat, orient, z_mean, z_dev, \ root_speed, affs, spline, labels = self.return_batch([samples_to_generate], test_loader, randomized=randomized) pos = torch.from_numpy(pos).cuda() traj = pos[:, :, 0, [0, 2]].clone() orient = torch.from_numpy(orient).cuda() quat = torch.from_numpy(quat).cuda() z_mean = torch.from_numpy(z_mean).cuda() z_dev = torch.from_numpy(z_dev).cuda() root_speed = torch.from_numpy(root_speed).cuda() affs = torch.from_numpy(affs).cuda() spline = torch.from_numpy(spline).cuda() z_rs = torch.cat((z_dev, root_speed), dim=-1) quat_all = torch.cat((orient[:, self.prefix_length - 1:], quat[:, self.prefix_length - 1:]), dim=-1) labels = np.tile(labels, (1, max_steps + self.prefix_length, 1)) # Begin for transition # traj[:, self.prefix_length - 2] = torch.tensor([-0.208, 4.8]).cuda().float() # traj[:, self.prefix_length - 1] = torch.tensor([-0.204, 5.1]).cuda().float() # final_emo_idx = int(max_steps/2) # labels[:, final_emo_idx:] = np.array([1., 0., 0., 0.]) # labels[:, :final_emo_idx + 1] = np.linspace(labels[:, 0], labels[:, final_emo_idx], # num=final_emo_idx + 1, axis=1) # End for transition labels = torch.from_numpy(labels).cuda() # traj_np = traj_markers.detach().cpu().numpy() # import matplotlib.pyplot as plt # plt.plot(traj_np[6, :, 0], traj_np[6, :, 1]) # plt.show() happy_idx = [25, 295, 390, 667, 1196] sad_idx = [169, 184, 258, 948, 974] angry_idx = [89, 93, 96, 112, 289, 290, 978] neutral_idx = [72, 106, 143, 237, 532, 747, 1177] sample_idx = np.squeeze(np.concatenate((happy_idx, sad_idx, angry_idx, neutral_idx))) ## CHANGE HERE # scene_corners = torch.tensor([[149.862, 50.833], # [149.862, 36.81], # [161.599, 36.81], # [161.599, 50.833]]).cuda().float() # character_heights = torch.tensor([0.95, 0.88, 0.86, 0.90, 0.95, 0.82]).cuda().float() # num_characters_per_side = torch.tensor([2, 3, 2, 3]).cuda().int() # traj_markers, traj_offsets, character_scale =\ # generate_trajectories(scene_corners, z_mean, character_heights, # num_characters_per_side, traj[:, :self.prefix_length]) # num_characters_per_side = torch.tensor([4, 0, 0, 0]).cuda().int() # traj_markers, traj_offsets, character_scale =\ # generate_simple_trajectories(scene_corners, z_mean[:4], z_mean[:4], # num_characters_per_side, traj[sample_idx, :self.prefix_length]) # traj_markers, traj_offsets, character_scale =\ # generate_rvo_trajectories(scene_corners, z_mean[:4], z_mean[:4], # num_characters_per_side, 
traj[sample_idx, :self.prefix_length]) # traj[sample_idx, :self.prefix_length] += traj_offsets # pos_sampled = pos[sample_idx].clone() # pos_sampled[:, :self.prefix_length, :, [0, 2]] += traj_offsets.unsqueeze(2).repeat(1, 1, self.V, 1) # pos[sample_idx] = pos_sampled # traj_markers = self.generate_linear_trajectory(traj) pos_pred = self.copy_prefix(pos) traj_pred = self.copy_prefix(traj) orient_pred = self.copy_prefix(orient) quat_pred = self.copy_prefix(quat) z_rs_pred = self.copy_prefix(z_rs) affs_pred = self.copy_prefix(affs) spline_pred = self.copy_prefix(spline) labels_pred = self.copy_prefix(labels, prefix_length=max_steps + self.prefix_length) # forward elapsed_time = np.zeros(len(sample_idx)) for counter, s in enumerate(sample_idx): # range(samples_to_generate): start_time = time.time() orient_h_copy = self.orient_h.clone() quat_h_copy = self.quat_h.clone() ## CHANGE HERE num_markers = max_steps + self.prefix_length + 1 # num_markers = traj_markers[s].shape[0] marker_idx = 0 t = -1 with torch.no_grad(): while marker_idx < num_markers: t += 1 if t > max_steps: print('Sample: {}. Did not reach end in {} steps.'.format(s, max_steps), end='') break pos_pred[s] = torch.cat((pos_pred[s], torch.zeros_like(pos_pred[s][:, -1:])), dim=1) traj_pred[s] = torch.cat((traj_pred[s], torch.zeros_like(traj_pred[s][:, -1:])), dim=1) orient_pred[s] = torch.cat((orient_pred[s], torch.zeros_like(orient_pred[s][:, -1:])), dim=1) quat_pred[s] = torch.cat((quat_pred[s], torch.zeros_like(quat_pred[s][:, -1:])), dim=1) z_rs_pred[s] = torch.cat((z_rs_pred[s], torch.zeros_like(z_rs_pred[s][:, -1:])), dim=1) affs_pred[s] = torch.cat((affs_pred[s], torch.zeros_like(affs_pred[s][:, -1:])), dim=1) spline_pred[s] = torch.cat((spline_pred[s], torch.zeros_like(spline_pred[s][:, -1:])), dim=1) orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \ quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \ z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \ orient_h_copy, quat_h_copy = \ self.model( orient_pred[s][:, t:self.prefix_length + t], quat_pred[s][:, t:self.prefix_length + t], z_rs_pred[s][:, t:self.prefix_length + t], affs_pred[s][:, t:self.prefix_length + t], spline_pred[s][:, t:self.prefix_length + t], labels_pred[s][:, t:self.prefix_length + t], orient_h=None if t == 0 else orient_h_copy, quat_h=None if t == 0 else quat_h_copy, return_prenorm=False) traj_curr = traj_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t].clone() # root_speed = torch.norm( # pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] - \ # pos_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t, 0], dim=-1) ## CHANGE HERE # traj_next = \ # self.compute_next_traj_point( # traj_curr, # traj_markers[s, marker_idx], # o_z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 2]) try: traj_next = traj[s, self.prefix_length + t] except IndexError: traj_next = \ self.compute_next_traj_point_sans_markers( pos_pred[s][:, self.prefix_length + t - 1:self.prefix_length + t], quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0], z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 1]) pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \ affs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \ spline_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \ self.mocap.get_predicted_features( pos_pred[s][:, 
:self.prefix_length + t], traj_next, z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean[s:s + 1], orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1]) # min_speed_pred = torch.min(torch.cat((lf_speed_pred.unsqueeze(-1), # rf_speed_pred.unsqueeze(-1)), dim=-1), dim=-1)[0] # if min_speed_pred - diff_speeds_mean[s] - diff_speeds_std[s] < 0.: # root_speed_pred = 0. # else: # root_speed_pred = o_z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 2] # ## CHANGE HERE # traj_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \ # self.compute_next_traj_point( # traj_curr, # traj_markers[s, marker_idx], # root_speed_pred) # if torch.norm(traj_next - traj_curr, dim=-1).squeeze() >= \ # torch.norm(traj_markers[s, marker_idx] - traj_curr, dim=-1).squeeze(): # marker_idx += 1 traj_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = traj_next marker_idx += 1 pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \ affs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], \ spline_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1] = \ self.mocap.get_predicted_features( pos_pred[s][:, :self.prefix_length + t], pos_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]], z_rs_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean[s:s + 1], orient_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1], quat_pred[s][:, self.prefix_length + t:self.prefix_length + t + 1]) print('Sample: {}. Steps: {}'.format(s, t), end='\r') print() # shift = torch.zeros((1, scene_corners.shape[1] + 1)).cuda().float() # shift[..., [0, 2]] = scene_corners[0] # pos_pred[s] = (pos_pred[s] - shift) / character_scale + shift # pos_pred_np = pos_pred[s].contiguous().view(pos_pred[s].shape[0], # pos_pred[s].shape[1], -1).permute(0, 2, 1).\ # detach().cpu().numpy() # display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True, # dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch), # save_file_names=[str(s).zfill(6)], # overwrite=True) # plt.cla() # fig, (ax1, ax2) = plt.subplots(2, 1) # ax1.plot(root_speeds[s]) # ax1.plot(lf_speeds[s]) # ax1.plot(rf_speeds[s]) # ax1.plot(min_speeds[s] - root_speeds[s]) # ax1.legend(['root', 'left', 'right', 'diff']) # ax2.plot(root_speeds_pred) # ax2.plot(lf_speeds_pred) # ax2.plot(rf_speeds_pred) # ax2.plot(min_speeds_pred - root_speeds_pred) # ax2.legend(['root', 'left', 'right', 'diff']) # plt.show() head_tilt = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1)) l_shoulder_slouch = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1)) r_shoulder_slouch = np.tile(np.array([0., 0., 0.]), (1, quat_pred[s].shape[1], 1)) head_tilt = Quaternions.from_euler(head_tilt, order='xyz').qs l_shoulder_slouch = Quaternions.from_euler(l_shoulder_slouch, order='xyz').qs r_shoulder_slouch = Quaternions.from_euler(r_shoulder_slouch, order='xyz').qs # Begin for aligning facing direction to trajectory axis_diff, angle_diff = self.get_diff_from_traj(pos_pred, traj_pred, s) angle_thres = 0.3 # angle_thres = torch.max(angle_diff[:, 1:self.prefix_length]) angle_diff[angle_diff <= angle_thres] = 0. angle_diff[:, self.prefix_length] = 0. 
# End for aligning facing direction to trajectory # pos_copy, quat_copy = self.rotate_gaits(pos_pred, quat_pred, quat_diff, # head_tilt, l_shoulder_slouch, r_shoulder_slouch, s) # pos_pred[s] = pos_copy.clone() # angle_diff_intermediate = self.get_diff_from_traj(pos_pred, traj_pred, s) # if torch.max(angle_diff_intermediate[:, self.prefix_length:]) > np.pi / 2.: # quat_diff = Quaternions.from_angle_axis(-angle_diff.cpu().numpy(), np.array([0, 1, 0])).qs # pos_copy, quat_copy = self.rotate_gaits(pos_pred, quat_pred, quat_diff, # head_tilt, l_shoulder_slouch, r_shoulder_slouch, s) # pos_pred[s] = pos_copy.clone() # axis_diff = torch.zeros_like(axis_diff) # axis_diff[..., 1] = 1. # angle_diff = torch.zeros_like(angle_diff) quat_diff = torch.from_numpy(Quaternions.from_angle_axis( angle_diff.cpu().numpy(), axis_diff.cpu().numpy()).qs).cuda().float() orient_pred[s], quat_pred[s] = self.rotate_gaits(orient_pred[s], quat_pred[s], quat_diff, head_tilt, l_shoulder_slouch, r_shoulder_slouch) if labels_pred[s][:, 0, 0] > 0.5: label_dir = 'happy' elif labels_pred[s][:, 0, 1] > 0.5: label_dir = 'sad' elif labels_pred[s][:, 0, 2] > 0.5: label_dir = 'angry' else: label_dir = 'neutral' ## CHANGE HERE # pos_pred[s] = pos_pred[s][:, self.prefix_length + 5:] # o_z_rs_pred[s] = o_z_rs_pred[s][:, self.prefix_length + 5:] # quat_pred[s] = quat_pred[s][:, self.prefix_length + 5:] traj_pred_np = pos_pred[s][0, :, 0].cpu().numpy() save_file_name = '{:06}_{:.2f}_{:.2f}_{:.2f}_{:.2f}'.format(s, labels_pred[s][0, 0, 0], labels_pred[s][0, 0, 1], labels_pred[s][0, 0, 2], labels_pred[s][0, 0, 3]) animation_pred = { 'joint_names': self.joint_names, 'joint_offsets': torch.from_numpy(self.joint_offsets[1:]).float().unsqueeze(0).repeat( len(pos_pred), 1, 1), 'joint_parents': self.joint_parents, 'positions': pos_pred[s], 'rotations': torch.cat((orient_pred[s], quat_pred[s]), dim=-1) } self.mocap.save_as_bvh(animation_pred, dataset_name=self.dataset, # subset_name='epoch_' + str(self.best_loss_epoch), # save_file_names=[str(s).zfill(6)]) subset_name=os.path.join('no_aff_epoch_' + str(self.best_loss_epoch), str(counter).zfill(2) + '_' + label_dir), save_file_names=['root']) end_time = time.time() elapsed_time[counter] = end_time - start_time print('Elapsed Time: {}'.format(elapsed_time[counter])) # display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True, # dataset_name=self.dataset, # # subset_name='epoch_' + str(self.best_loss_epoch), # # save_file_names=[str(s).zfill(6)], # subset_name=os.path.join('epoch_' + str(self.best_loss_epoch), label_dir), # save_file_names=[save_file_name], # overwrite=True) print('Mean Elapsed Time: {}'.format(np.mean(elapsed_time)))
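# ---------------------------------------------------------------------------
# For reference, the facing-vs-trajectory correction in get_diff_from_traj()
# reduces to: estimate the horizontal facing direction from the root and the
# two shoulders, compare it with the tangent of the predicted trajectory, and
# rotate the gait by the angle between the two. A standalone numpy sketch of
# that geometry (joint indices 0/18/25 follow the skeleton used above; the
# helper name is ours, not part of the repo):

import numpy as np

def facing_vs_tangent_angle_sketch(root, l_shoulder, r_shoulder,
                                   traj_prev, traj_next):
    """root/l_shoulder/r_shoulder: (3,) joint positions; traj_prev/traj_next:
    (2,) consecutive (x, z) ground-plane trajectory points. Returns the
    unsigned angle in radians between facing direction and tangent."""
    # facing = normal of the root/shoulder plane, projected onto the ground
    facing = np.cross(l_shoulder - root, r_shoulder - root)[[0, 2]]
    facing = facing / (np.linalg.norm(facing) + 1e-9)
    tangent = traj_next - traj_prev
    tangent = tangent / (np.linalg.norm(tangent) + 1e-9)
    return float(np.arccos(np.clip(np.dot(facing, tangent), -1., 1.)))
# ---------------------------------------------------------------------------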
class Processor(object): """ Processor for gait generation """ def __init__(self, args, dataset, data_loader, T, V, C, D, A, S, joint_parents, num_labels, prefix_length, target_length, min_train_epochs=-1, generate_while_train=False, save_path=None, device='cuda:0'): self.args = args self.dataset = dataset self.mocap = MocapDataset(V, C, joint_parents) self.device = device self.data_loader = data_loader self.num_labels = num_labels self.result = dict() self.iter_info = dict() self.epoch_info = dict() self.meta_info = dict(epoch=0, iter=0) self.io = IO(self.args.work_dir, save_log=self.args.save_log, print_log=self.args.print_log) # model self.T = T self.V = V self.C = C self.D = D self.A = A self.S = S self.O = 1 self.PRS = 2 self.prefix_length = prefix_length self.target_length = target_length self.joint_parents = joint_parents self.model = quater_emonet.QuaterEmoNet(V, D, S, A, self.O, num_labels[0], self.PRS) self.model.cuda(device) self.quat_h = None self.p_rs_loss_func = nn.L1Loss() self.affs_loss_func = nn.L1Loss() self.best_loss = math.inf self.best_mean_ap = 0. self.loss_updated = False self.mean_ap_updated = False self.step_epochs = [ math.ceil(float(self.args.num_epoch * x)) for x in self.args.step ] self.best_loss_epoch = None self.best_acc_epoch = None self.min_train_epochs = min_train_epochs # generate self.generate_while_train = generate_while_train self.save_path = save_path # optimizer if self.args.optimizer == 'SGD': self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.base_lr, momentum=0.9, nesterov=self.args.nesterov, weight_decay=self.args.weight_decay) elif self.args.optimizer == 'Adam': self.optimizer = optim.Adam(self.model.parameters(), lr=self.args.base_lr) # weight_decay=self.args.weight_decay) else: raise ValueError() self.lr = self.args.base_lr self.tf = self.args.base_tr def process_data(self, data, poses, quat, trans, affs): data = data.float().to(self.device) poses = poses.float().to(self.device) quat = quat.float().to(self.device) trans = trans.float().to(self.device) affs = affs.float().to(self.device) return data, poses, quat, trans, affs def load_best_model(self, ): if self.best_loss_epoch is None: model_name, self.best_loss_epoch, self.best_loss, self.best_mean_ap =\ get_best_epoch_and_loss(self.args.work_dir) # load model # if self.best_loss_epoch > 0: loaded_vars = torch.load(os.path.join(self.args.work_dir, model_name)) self.model.load_state_dict(loaded_vars['model_dict']) self.quat_h = loaded_vars['quat_h'] def adjust_lr(self): self.lr = self.lr * self.args.lr_decay for param_group in self.optimizer.param_groups: param_group['lr'] = self.lr def adjust_tf(self): if self.meta_info['epoch'] > 20: self.tf = self.tf * self.args.tf_decay def show_epoch_info(self): print_epochs = [ self.best_loss_epoch if self.best_loss_epoch is not None else 0, self.best_acc_epoch if self.best_acc_epoch is not None else 0, self.best_acc_epoch if self.best_acc_epoch is not None else 0 ] best_metrics = [self.best_loss, 0, self.best_mean_ap] i = 0 for k, v in self.epoch_info.items(): self.io.print_log( '\t{}: {}. 
Best so far: {} (epoch: {:d}).'.format( k, v, best_metrics[i], print_epochs[i])) i += 1 if self.args.pavi_log: self.io.log('train', self.meta_info['iter'], self.epoch_info) def show_iter_info(self): if self.meta_info['iter'] % self.args.log_interval == 0: info = '\tIter {} Done.'.format(self.meta_info['iter']) for k, v in self.iter_info.items(): if isinstance(v, float): info = info + ' | {}: {:.4f}'.format(k, v) else: info = info + ' | {}: {}'.format(k, v) self.io.print_log(info) if self.args.pavi_log: self.io.log('train', self.meta_info['iter'], self.iter_info) def yield_batch(self, batch_size, dataset): batch_pos = np.zeros((batch_size, self.T, self.V, self.C), dtype='float32') batch_quat = np.zeros((batch_size, self.T, (self.V - 1) * self.D), dtype='float32') batch_orient = np.zeros((batch_size, self.T, self.O), dtype='float32') batch_affs = np.zeros((batch_size, self.T, self.A), dtype='float32') batch_spline = np.zeros((batch_size, self.T, self.S), dtype='float32') batch_phase_and_root_speed = np.zeros((batch_size, self.T, self.PRS), dtype='float32') batch_labels = np.zeros((batch_size, 1, self.num_labels[0]), dtype='float32') pseudo_passes = (len(dataset) + batch_size - 1) // batch_size probs = [] for k in dataset.keys(): if 'spline' not in dataset[k]: raise KeyError( 'No splines found. Perhaps you forgot to compute them?') probs.append(dataset[k]['spline'].size()) probs = np.array(probs) / np.sum(probs) for p in range(pseudo_passes): rand_keys = np.random.choice(len(dataset), size=batch_size, replace=True, p=probs) for i, k in enumerate(rand_keys): pos = dataset[str(k)]['positions_world'] quat = dataset[str(k)]['rotations'] orient = dataset[str(k)]['orientations'] affs = dataset[str(k)]['affective_features'] spline, phase = Spline.extract_spline_features( dataset[str(k)]['spline']) root_speed = dataset[str(k)]['trans_and_controls'][:, -1].reshape( -1, 1) labels = dataset[str(k)]['labels'][:self.num_labels[0]] batch_pos[i] = pos batch_quat[i] = quat.reshape(self.T, -1) batch_orient[i] = orient.reshape(self.T, -1) batch_affs[i] = affs batch_spline[i] = spline batch_phase_and_root_speed[i] = np.concatenate( (phase, root_speed), axis=-1) batch_labels[i] = np.expand_dims(labels, axis=0) yield batch_pos, batch_quat, batch_orient, batch_affs, batch_spline,\ batch_phase_and_root_speed / np.pi, batch_labels def return_batch(self, batch_size, dataset): if len(batch_size) > 1: rand_keys = np.copy(batch_size) batch_size = len(batch_size) else: batch_size = batch_size[0] probs = [] for k in dataset.keys(): if 'spline' not in dataset[k]: raise KeyError( 'No splines found. Perhaps you forgot to compute them?' 
) probs.append(dataset[k]['spline'].size()) probs = np.array(probs) / np.sum(probs) rand_keys = np.random.choice(len(dataset), size=batch_size, replace=False, p=probs) batch_pos = np.zeros((batch_size, self.T, self.V, self.C), dtype='float32') batch_traj = np.zeros((batch_size, self.T, self.C), dtype='float32') batch_quat = np.zeros((batch_size, self.T, (self.V - 1) * self.D), dtype='float32') batch_orient = np.zeros((batch_size, self.T, self.O), dtype='float32') batch_affs = np.zeros((batch_size, self.T, self.A), dtype='float32') batch_spline = np.zeros((batch_size, self.T, self.S), dtype='float32') batch_phase_and_root_speed = np.zeros((batch_size, self.T, self.PRS), dtype='float32') batch_labels = np.zeros((batch_size, 1, self.num_labels[0]), dtype='float32') for i, k in enumerate(rand_keys): pos = dataset[str(k)]['positions_world'] traj = dataset[str(k)]['trajectory'] quat = dataset[str(k)]['rotations'] orient = dataset[str(k)]['orientations'] affs = dataset[str(k)]['affective_features'] spline, phase = Spline.extract_spline_features( dataset[str(k)]['spline']) root_speed = dataset[str(k)]['trans_and_controls'][:, -1].reshape( -1, 1) labels = dataset[str(k)]['labels'][:self.num_labels[0]] batch_pos[i] = pos batch_traj[i] = traj batch_quat[i] = quat.reshape(self.T, -1) batch_orient[i] = orient.reshape(self.T, -1) batch_affs[i] = affs batch_spline[i] = spline batch_phase_and_root_speed[i] = np.concatenate((phase, root_speed), axis=-1) batch_labels[i] = np.expand_dims(labels, axis=0) return batch_pos, batch_traj, batch_quat, batch_orient, batch_affs, batch_spline,\ batch_phase_and_root_speed, batch_labels def per_train(self): self.model.train() train_loader = self.data_loader['train'] batch_loss = 0. N = 0. for pos, quat, orient, affs, spline, p_rs, labels in self.yield_batch( self.args.batch_size, train_loader): pos = torch.from_numpy(pos).cuda() quat = torch.from_numpy(quat).cuda() orient = torch.from_numpy(orient).cuda() affs = torch.from_numpy(affs).cuda() spline = torch.from_numpy(spline).cuda() p_rs = torch.from_numpy(p_rs).cuda() labels = torch.from_numpy(labels).cuda() pos_pred = pos.clone() quat_pred = quat.clone() p_rs_pred = p_rs.clone() affs_pred = affs.clone() pos_pred_all = pos.clone() quat_pred_all = quat.clone() p_rs_pred_all = p_rs.clone() affs_pred_all = affs.clone() prenorm_terms = torch.zeros_like(quat_pred) # forward self.optimizer.zero_grad() for t in range(self.target_length): quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \ p_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \ self.quat_h, prenorm_terms[:, self.prefix_length + t: self.prefix_length + t + 1] = \ self.model( quat_pred[:, t:self.prefix_length + t], p_rs_pred[:, t:self.prefix_length + t], affs_pred[:, t:self.prefix_length + t], spline[:, t:self.prefix_length + t], orient[:, t:self.prefix_length + t], labels, quat_h=None if t == 0 else self.quat_h, return_prenorm=True) pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],\ affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1] = \ self.mocap.get_predicted_features( pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0], quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], orient[:, self.prefix_length + t:self.prefix_length + t + 1]) if np.random.uniform(size=1)[0] > self.tf: pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \ pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1] quat_pred[:, self.prefix_length + t:self.prefix_length + 
t + 1] = \ quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1] p_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \ p_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1] affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \ affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1] prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0], prenorm_terms.shape[1], -1, self.D) quat_norm_loss = self.args.quat_norm_reg * torch.mean( (torch.sum(prenorm_terms**2, dim=-1) - 1)**2) quat_loss, quat_derv_loss = losses.quat_angle_loss( quat_pred_all[:, self.prefix_length - 1:], quat[:, self.prefix_length - 1:], self.V, self.D) quat_loss *= self.args.quat_reg p_rs_loss = self.p_rs_loss_func( p_rs_pred_all[:, self.prefix_length:], p_rs[:, self.prefix_length:]) affs_loss = self.affs_loss_func( affs_pred_all[:, self.prefix_length:], affs[:, self.prefix_length:]) # recons_loss = self.args.recons_reg *\ # (pos_pred_all[:, self.prefix_length:] - pos_pred_all[:, self.prefix_length:, 0:1] - # pos[:, self.prefix_length:] + pos[:, self.prefix_length:, 0:1]).norm() loss_total = quat_norm_loss + quat_loss + quat_derv_loss + p_rs_loss + affs_loss # + recons_loss loss_total.backward() # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip) self.optimizer.step() # Compute statistics batch_loss += loss_total.item() N += quat.shape[0] # statistics self.iter_info['loss'] = loss_total.data.item() self.iter_info['lr'] = '{:.6f}'.format(self.lr) self.iter_info['tf'] = '{:.6f}'.format(self.tf) self.show_iter_info() self.meta_info['iter'] += 1 batch_loss = batch_loss / N self.epoch_info['mean_loss'] = batch_loss self.show_epoch_info() self.io.print_timer() self.adjust_lr() self.adjust_tf() def per_test(self): self.model.eval() test_loader = self.data_loader['test'] valid_loss = 0. N = 0. 
        for pos, quat, orient, affs, spline, p_rs, labels in self.yield_batch(
                self.args.batch_size, test_loader):
            pos = torch.from_numpy(pos).cuda()
            quat = torch.from_numpy(quat).cuda()
            orient = torch.from_numpy(orient).cuda()
            affs = torch.from_numpy(affs).cuda()
            spline = torch.from_numpy(spline).cuda()
            p_rs = torch.from_numpy(p_rs).cuda()
            labels = torch.from_numpy(labels).cuda()

            pos_pred = pos.clone()
            quat_pred = quat.clone()
            p_rs_pred = p_rs.clone()
            affs_pred = affs.clone()
            prenorm_terms = torch.zeros_like(quat_pred)

            # forward. This is an evaluation pass, so run it without building
            # the autograd graph (the original called
            # self.optimizer.zero_grad() here, which has no place in a test
            # pass and needlessly accumulated graph memory).
            with torch.no_grad():
                for t in range(self.target_length):
                    quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                        p_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                        self.quat_h, \
                        prenorm_terms[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        self.model(
                            quat_pred[:, t:self.prefix_length + t],
                            p_rs_pred[:, t:self.prefix_length + t],
                            affs_pred[:, t:self.prefix_length + t],
                            spline[:, t:self.prefix_length + t],
                            orient[:, t:self.prefix_length + t],
                            labels,
                            quat_h=None if t == 0 else self.quat_h,
                            return_prenorm=True)
                    pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                        affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                        self.mocap.get_predicted_features(
                            pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0],
                            quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1],
                            orient[:, self.prefix_length + t:self.prefix_length + t + 1])

                prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0],
                                                   prenorm_terms.shape[1],
                                                   -1, self.D)
                quat_norm_loss = self.args.quat_norm_reg * torch.mean(
                    (torch.sum(prenorm_terms ** 2, dim=-1) - 1) ** 2)
                quat_loss, quat_derv_loss = losses.quat_angle_loss(
                    quat_pred[:, self.prefix_length - 1:],
                    quat[:, self.prefix_length - 1:], self.V, self.D)
                quat_loss *= self.args.quat_reg
                recons_loss = self.args.recons_reg *\
                    (pos_pred[:, self.prefix_length:] -
                     pos_pred[:, self.prefix_length:, 0:1] -
                     pos[:, self.prefix_length:] +
                     pos[:, self.prefix_length:, 0:1]).norm()
                valid_loss += recons_loss
                N += quat.shape[0]

        valid_loss /= N
        # if self.meta_info['epoch'] > 5 and self.loss_updated:
        #     pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0], pos_pred.shape[1], -1).permute(0, 2, 1).\
        #         detach().cpu().numpy()
        #     display_animations(pos_pred_np, self.V, self.C, self.joint_parents, save=True,
        #                        dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch),
        #                        overwrite=True)
        #     pos_in_np = pos_in.contiguous().view(pos_in.shape[0], pos_in.shape[1], -1).permute(0, 2, 1).\
        #         detach().cpu().numpy()
        #     display_animations(pos_in_np, self.V, self.C, self.joint_parents, save=True,
        #                        dataset_name=self.dataset, subset_name='epoch_' + str(self.best_loss_epoch) +
        #                        '_gt',
        #                        overwrite=True)
        self.epoch_info['mean_loss'] = valid_loss
        if self.epoch_info['mean_loss'] < self.best_loss and \
                self.meta_info['epoch'] > self.min_train_epochs:
            self.best_loss = self.epoch_info['mean_loss']
            self.best_loss_epoch = self.meta_info['epoch']
            self.loss_updated = True
        else:
            self.loss_updated = False
        self.show_epoch_info()

    def train(self):
        if self.args.load_last_best:
            self.load_best_model()
            # guard: load_best_model() may leave best_loss_epoch as None
            # when no checkpoint exists in the work directory
            self.args.start_epoch = \
                self.best_loss_epoch if self.best_loss_epoch is not None else 0
        for epoch in range(self.args.start_epoch, self.args.num_epoch):
            self.meta_info['epoch'] = epoch

            # training
            self.io.print_log('Training epoch: {}'.format(epoch))
            self.per_train()
            self.io.print_log('Done.')

            # evaluation
            if (epoch % self.args.eval_interval == 0) or \
                    (epoch + 1 == self.args.num_epoch):
                self.io.print_log('Eval epoch: {}'.format(epoch))
                self.per_test()
self.io.print_log('Done.') # save model and weights if self.loss_updated: torch.save( { 'model_dict': self.model.state_dict(), 'quat_h': self.quat_h }, os.path.join( self.args.work_dir, 'epoch_{}_loss_{:.4f}_acc_{:.2f}_model.pth.tar'.format( epoch, self.best_loss, self.best_mean_ap * 100.))) if self.generate_while_train: self.generate_motion(load_saved_model=False, samples_to_generate=1) def copy_prefix(self, var, target_length): shape = list(var.shape) shape[1] = self.prefix_length + target_length var_pred = torch.zeros(torch.Size(shape)).cuda().float() var_pred[:, :self.prefix_length] = var[:, :self.prefix_length] return var_pred def flip_trajectory(self, traj, target_length): traj_flipped = traj[:, -(target_length - self.target_length):].flip( dims=[1]) orient_flipped = torch.zeros( (traj_flipped.shape[0], traj_flipped.shape[1], 1)).cuda().float() # orient_flipped[:, 0] = np.pi # traj_diff = traj_flipped[:, 1:, [0, 2]] - traj_flipped[:, :-1, [0, 2]] # traj_diff /= torch.norm(traj_diff, dim=-1)[..., None] # orient_flipped[:, 1:, 0] = torch.atan2(traj_diff[:, :, 1], traj_diff[:, :, 0]) return traj_flipped, orient_flipped def generate_motion(self, load_saved_model=True, target_length=100, samples_to_generate=10): if load_saved_model: self.load_best_model() self.model.eval() test_loader = self.data_loader['test'] pos, traj, quat, orient, affs, spline, p_rs, labels = self.return_batch( [samples_to_generate], test_loader) pos = torch.from_numpy(pos).cuda() traj = torch.from_numpy(traj).cuda() quat = torch.from_numpy(quat).cuda() orient = torch.from_numpy(orient).cuda() affs = torch.from_numpy(affs).cuda() spline = torch.from_numpy(spline).cuda() p_rs = torch.from_numpy(p_rs).cuda() labels = torch.from_numpy(labels).cuda() traj_flipped, orient_flipped = self.flip_trajectory( traj, target_length) traj = torch.cat((traj, traj_flipped), dim=1) orient = torch.cat((orient, orient_flipped), dim=1) pos_pred = self.copy_prefix(pos, target_length) quat_pred = self.copy_prefix(quat, target_length) p_rs_pred = self.copy_prefix(p_rs, target_length) affs_pred = self.copy_prefix(affs, target_length) spline_pred = self.copy_prefix(spline, target_length) # forward with torch.no_grad(): for t in range(target_length): quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \ p_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \ self.quat_h = \ self.model( quat_pred[:, t:self.prefix_length + t], p_rs_pred[:, t:self.prefix_length + t], affs_pred[:, t:self.prefix_length + t], spline_pred[:, t:self.prefix_length + t], orient[:, t:self.prefix_length + t], labels, quat_h=None if t == 0 else self.quat_h, return_prenorm=False) data_pred = \ self.mocap.get_predicted_features( pos_pred[:, :self.prefix_length + t], orient[:, :self.prefix_length + t], traj[:, self.prefix_length + t:self.prefix_length + t + 1], quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1], orient[:, self.prefix_length + t:self.prefix_length + t + 1]) pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = data_pred['positions_world'] affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = data_pred['affective_features'] spline_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = data_pred['spline'] recons_loss = self.args.recons_reg *\ (pos_pred[:, self.prefix_length:self.T] - pos_pred[:, self.prefix_length:self.T, 0:1] - pos[:, self.prefix_length:self.T] + pos[:, self.prefix_length:self.T, 0:1]).norm() pos_pred_np = pos_pred.contiguous().view(pos_pred.shape[0], pos_pred.shape[1], 
                                             -1).permute(0, 2, 1).\
            detach().cpu().numpy()
        pos_np = pos.contiguous().view(pos.shape[0], pos.shape[1], -1).permute(0, 2, 1).\
            detach().cpu().numpy()
        display_animations(pos_pred_np, self.V, self.C, self.joint_parents,
                           save=True,
                           dataset_name=self.dataset,
                           subset_name='epoch_' + str(self.best_loss_epoch),
                           overwrite=True)
        display_animations(pos_np, self.V, self.C, self.joint_parents,
                           save=True,
                           dataset_name=self.dataset,
                           subset_name='epoch_' + str(self.best_loss_epoch) + '_gt',
                           overwrite=True)
        self.mocap.save_as_bvh(
            traj.detach().cpu().numpy(),
            orient.detach().cpu().numpy(),
            np.reshape(quat_pred.detach().cpu().numpy(),
                       (quat_pred.shape[0], quat_pred.shape[1], -1, self.D)),
            'render/bvh')
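# ---------------------------------------------------------------------------
# All of the per_train()/per_test()/generate_motion() loops above share one
# pattern: condition on a sliding window of the last `prefix_length` frames,
# predict one frame, write it back into the sequence, and (during training)
# keep the ground-truth frame instead with probability `tf` (teacher forcing).
# A minimal, self-contained torch sketch of that rollout with a toy GRU --
# everything here is illustrative; only the windowing and write-back logic
# mirrors the code above:

import torch
import torch.nn as nn

def rollout_sketch(model, seq, prefix_length, target_length, tf=0.0):
    """seq: (N, prefix_length + target_length, F) ground-truth features.
    Returns seq with frames after the prefix replaced by model predictions."""
    pred = seq.clone()
    h = None
    for t in range(target_length):
        window = pred[:, t:prefix_length + t]        # sliding context window
        out, h = model(window, h)                    # out: (N, prefix_length, F)
        next_frame = out[:, -1:]                     # one-step prediction
        if torch.rand(1).item() > tf:                # use the model's own output
            pred[:, prefix_length + t:prefix_length + t + 1] = next_frame
        # else: keep the ground-truth frame already in `pred` (teacher forcing)
    return pred

# Example with a toy GRU (input size == hidden size == feature size):
# gru = nn.GRU(8, 8, batch_first=True)
# out = rollout_sketch(gru, torch.randn(2, 30, 8),
#                      prefix_length=10, target_length=20, tf=0.5)
# ---------------------------------------------------------------------------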
class Processor(object): """ Processor for gait generation """ def __init__(self, args, data_loader, C, F, num_classes, graph_dict, device='cuda:0'): self.args = args self.data_loader = data_loader self.num_classes = num_classes self.result = dict() self.iter_info = dict() self.epoch_info = dict() self.meta_info = dict(epoch=0, iter=0) self.device = device self.io = IO(self.args.work_dir, save_log=self.args.save_log, print_log=self.args.print_log) # model self.model = classifier.Classifier(C, F, num_classes, graph_dict) self.model.cuda('cuda:0') self.model.apply(weights_init) self.loss = nn.CrossEntropyLoss() self.best_loss = math.inf self.step_epochs = [ math.ceil(float(self.args.num_epoch * x)) for x in self.args.step ] self.best_epoch = None self.best_accuracy = np.zeros((1, np.max(self.args.topk))) self.accuracy_updated = False # optimizer if self.args.optimizer == 'SGD': self.optimizer = optim.SGD(self.model.parameters(), lr=self.args.base_lr, momentum=0.9, nesterov=self.args.nesterov, weight_decay=self.args.weight_decay) elif self.args.optimizer == 'Adam': self.optimizer = optim.Adam(self.model.parameters(), lr=self.args.base_lr, weight_decay=self.args.weight_decay) else: raise ValueError() self.lr = self.args.base_lr def adjust_lr(self): # if self.args.optimizer == 'SGD' and \ if self.meta_info['epoch'] in self.step_epochs: lr = self.args.base_lr * (0.1**np.sum( self.meta_info['epoch'] >= np.array(self.step_epochs))) for param_group in self.optimizer.param_groups: param_group['lr'] = lr self.lr = lr def show_epoch_info(self): for k, v in self.epoch_info.items(): self.io.print_log('\t{}: {}'.format(k, v)) if self.args.pavi_log: self.io.log('train', self.meta_info['iter'], self.epoch_info) def show_iter_info(self): if self.meta_info['iter'] % self.args.log_interval == 0: info = '\tIter {} Done.'.format(self.meta_info['iter']) for k, v in self.iter_info.items(): if isinstance(v, float): info = info + ' | {}: {:.4f}'.format(k, v) else: info = info + ' | {}: {}'.format(k, v) self.io.print_log(info) if self.args.pavi_log: self.io.log('train', self.meta_info['iter'], self.iter_info) def show_topk(self, k, epoch): rank = self.result.argsort() hit_top_k = [l in rank[i, -k:] for i, l in enumerate(self.label)] accuracy = 100. * sum(hit_top_k) * 1.0 / len(hit_top_k) if accuracy > self.best_accuracy[0, k - 1]: self.best_accuracy[0, k - 1] = accuracy self.accuracy_updated = True self.best_epoch = epoch else: self.accuracy_updated = False print_epoch = self.best_epoch if self.best_epoch is not None else 0 self.io.print_log( '\tTop{}: {:.2f}%. 
    def show_topk(self, k, epoch):
        rank = self.result.argsort()
        hit_top_k = [l in rank[i, -k:] for i, l in enumerate(self.label)]
        accuracy = 100. * sum(hit_top_k) * 1.0 / len(hit_top_k)
        if accuracy > self.best_accuracy[0, k - 1]:
            self.best_accuracy[0, k - 1] = accuracy
            self.accuracy_updated = True
            self.best_epoch = epoch
        else:
            self.accuracy_updated = False
        print_epoch = self.best_epoch if self.best_epoch is not None else 0
        self.io.print_log(
            '\tTop{}: {:.2f}%. Best so far: {:.2f}% (epoch: {:d}).'.format(
                k, accuracy, self.best_accuracy[0, k - 1], print_epoch))

    def per_train(self):
        self.model.train()
        self.adjust_lr()
        loader = self.data_loader['train']
        loss_value = []

        for aff, gait, label in loader:
            # get data
            aff = aff.float().to(self.device)
            gait = gait.float().to(self.device)
            label = label.long().to(self.device)

            # forward
            output, _ = self.model(aff, gait)
            loss = self.loss(output, label)

            # backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # statistics
            self.iter_info['loss'] = loss.data.item()
            self.iter_info['lr'] = '{:.6f}'.format(self.lr)
            loss_value.append(self.iter_info['loss'])
            self.show_iter_info()
            self.meta_info['iter'] += 1

        self.epoch_info['mean_loss'] = np.mean(loss_value)
        self.show_epoch_info()
        self.io.print_timer()

    def per_test(self, epoch=None, evaluation=True):
        self.model.eval()
        loader = self.data_loader['test']
        loss_value = []
        result_frag = []
        label_frag = []

        for aff, gait, label in loader:
            # get data
            aff = aff.float().to(self.device)
            gait = gait.float().to(self.device)
            label = label.long().to(self.device)

            # inference
            with torch.no_grad():
                output, _ = self.model(aff, gait)
            result_frag.append(output.data.cpu().numpy())

            # get loss
            if evaluation:
                loss = self.loss(output, label)
                loss_value.append(loss.item())
                label_frag.append(label.data.cpu().numpy())

        self.result = np.concatenate(result_frag)
        if evaluation:
            self.label = np.concatenate(label_frag)
            self.epoch_info['mean_loss'] = np.mean(loss_value)
            self.show_epoch_info()

            # show top-k accuracy
            for k in self.args.topk:
                self.show_topk(k, epoch)

    def train(self):
        for epoch in range(self.args.start_epoch, self.args.num_epoch):
            self.meta_info['epoch'] = epoch

            # training
            self.io.print_log('Training epoch: {}'.format(epoch))
            self.per_train()
            self.io.print_log('Done.')

            # evaluation
            if (epoch % self.args.eval_interval == 0) or (epoch + 1 == self.args.num_epoch):
                self.io.print_log('Eval epoch: {}'.format(epoch))
                self.per_test(epoch=epoch)
                self.io.print_log('Done.')

            # save model and weights
            # note: .item() assumes a single top-k value (i.e., args.topk == [1])
            if self.accuracy_updated:
                torch.save(self.model.state_dict(),
                           os.path.join(self.args.work_dir,
                                        'epoch_{}_acc_{:.2f}_model.pth.tar'.format(
                                            epoch, self.best_accuracy.item())))

    def test(self):
        # the path to the trained weights must be provided
        if self.args.weights is None:
            raise ValueError('Please specify --weights.')
        self.io.print_log('Model: {}.'.format(self.args.model))
        self.io.print_log('Weights: {}.'.format(self.args.weights))

        # evaluation
        self.io.print_log('Evaluation Start:')
        self.per_test()
        self.io.print_log('Done.\n')

        # save the output of the model
        if self.args.save_result:
            result_dict = dict(zip(self.data_loader['test'].dataset.sample_name,
                                   self.result))
            self.io.save_pkl(result_dict, 'test_result.pkl')

    def extract_best_feature(self, data, joints, coords):
        if self.best_epoch is None:
            self.best_epoch, best_accuracy = get_best_epoch_and_accuracy(self.args.work_dir)
        else:
            best_accuracy = self.best_accuracy.item()
        if self.best_epoch is not None:
            filename = os.path.join(self.args.work_dir,
                                    'epoch_{}_acc_{:.2f}_model.pth.tar'.format(
                                        self.best_epoch, best_accuracy))
            self.model.load_state_dict(torch.load(filename))

        label_preds = np.empty(len(data))
        features = np.empty((0, 4))  # assumes a 4-dimensional feature vector
        for i, each_data in enumerate(data):
            # get data
            aff = each_data[0]
            aff = torch.from_numpy(aff).float().to(self.device)
            aff = aff.unsqueeze(0)
            gait = each_data[1]
            # reshape (T, V*C) gait data to (N, C, T, V, M) with N = M = 1
            gait = np.reshape(gait, (1, gait.shape[0], joints, coords, 1))
            gait = np.moveaxis(gait, [1, 2, 3], [2, 3, 1])
            gait = torch.from_numpy(gait).float().to(self.device)
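            # The block below runs the classifier once per sample: eval mode and
            # no_grad disable dropout/batch-norm updates and gradient tracking,
            # argmax over the class scores gives the predicted label, and the
            # penultimate feature vector is stacked row by row into `features`.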
            # get feature
            self.model.eval()
            with torch.no_grad():
                output, feature = self.model(aff, gait)
            label_preds[i] = np.argmax(output.cpu().numpy())
            features = np.append(features,
                                 feature.cpu().numpy().reshape((1, feature.shape[1])),
                                 axis=0)
        return features, label_preds
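
# Illustrative usage sketch (not part of the original source): wiring up the
# classification Processor above with toy data. The argparse namespace fields
# are the ones this class reads; the loader shapes and the commented-out
# constructor arguments (C, F, num_classes, graph_dict) are assumptions and
# must match classifier.Classifier's real signature.
if __name__ == '__main__':
    from types import SimpleNamespace
    from torch.utils.data import DataLoader, TensorDataset

    args = SimpleNamespace(
        work_dir='./work_dir', save_log=False, print_log=True,
        optimizer='Adam', base_lr=1e-3, weight_decay=1e-4, nesterov=False,
        num_epoch=2, start_epoch=0, step=[0.5, 0.75], topk=[1],
        log_interval=10, eval_interval=1, pavi_log=False,
        weights=None, save_result=False, model='classifier')

    # toy tensors: (N, F) affective features, (N, C, T, V, M) gait, class labels
    aff = torch.randn(8, 4)
    gait = torch.randn(8, 3, 30, 16, 1)
    label = torch.randint(0, 4, (8,))
    loader = {
        'train': DataLoader(TensorDataset(aff, gait, label), batch_size=4),
        'test': DataLoader(TensorDataset(aff, gait, label), batch_size=4),
    }
    # processor = Processor(args, loader, C=3, F=4, num_classes=4, graph_dict=graph_dict)
    # processor.train()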