def per_eval(self):
    """Run one evaluation pass over the test set and track the best loss."""
    self.model.eval()
    test_loader = self.data_loader['test']
    eval_loss = 0.
    N = 0.
    for joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, intended_emotion, intended_polarity, \
            acting_task, gender, age, handedness, \
            native_tongue in self.yield_batch(self.args.batch_size, test_loader):
        with torch.no_grad():
            # Normalize bone lengths by the longest bone in each skeleton.
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            # Decoder prelude: EOS quaternions everywhere, with the
            # ground-truth first frame in the last slot.
            quat_prelude = self.quats_eos.view(1, -1).cuda() \
                .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
            quat_prelude[:, -1] = quat[:, 0].clone()
            quat_pred, quat_pred_pre_norm = self.model(
                text, intended_emotion, intended_polarity,
                acting_task, gender, age, handedness, native_tongue,
                quat_prelude, joint_lengths / scales[..., None])
            quat_pred_pre_norm = quat_pred_pre_norm.view(
                quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1], -1, self.D)
            # Penalize pre-normalization quaternions that drift from unit norm.
            quat_norm_loss = self.args.quat_norm_reg * \
                torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)
            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                quat_pred, quat[:, 1:], quat_valid_idx[:, 1:], self.V, self.D,
                self.lower_body_start, self.args.upper_body_weight)
            quat_loss *= self.args.quat_reg

            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1], self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(
                    quat_pred.shape[0], quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets), dim=1).unsqueeze(1))
            affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

            row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
            row_sums[row_sums == 0.] = 1.

            # Root-relative positions for the reconstruction loss.
            shifted_pos = pos - pos[:, :, 0:1]
            shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]
            recons_loss = self.recons_loss_func(shifted_pos_pred, shifted_pos[:, 1:])
            # Alternative masked, upper-body-weighted losses, kept for reference:
            # recons_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos[:, 1:]).sum(-1)
            # recons_loss = self.args.upper_body_weight * \
            #     (recons_loss[:, :, :self.lower_body_start].sum(-1)) + \
            #     recons_loss[:, :, self.lower_body_start:].sum(-1)
            # recons_loss = self.args.recons_reg * torch.mean(
            #     (recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
            # recons_derv_loss = torch.abs(shifted_pos_pred[:, 2:] - shifted_pos_pred[:, 1:-1] -
            #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
            # recons_derv_loss = self.args.upper_body_weight * \
            #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) + \
            #     recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
            # recons_derv_loss = 2. * self.args.recons_reg * \
            #     torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
            # affs_loss = torch.abs(affs[:, 1:] - affs_pred[:, 1:]).sum(-1)
            # affs_loss = self.args.affs_reg * torch.mean(
            #     (affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
            affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])

            eval_loss += quat_norm_loss + quat_loss + recons_loss + affs_loss
            # eval_loss += quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
            N += quat.shape[0]

    eval_loss /= N
    self.epoch_info['mean_loss'] = eval_loss
    if self.epoch_info['mean_loss'] < self.best_loss and \
            self.meta_info['epoch'] > self.min_train_epochs:
        self.best_loss = self.epoch_info['mean_loss']
        self.best_loss_epoch = self.meta_info['epoch']
        self.loss_updated = True
    else:
        self.loss_updated = False
    self.show_epoch_info()
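# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not called by the pipeline): the unit-norm
# regularizer computed above penalizes pre-normalization quaternions whose
# squared norm drifts from 1. A minimal standalone version on dummy tensors,
# with made-up shapes:

def _demo_quat_norm_loss():
    import torch
    batch, time, joints, d = 2, 8, 21, 4          # d = 4 quaternion components
    quat_pre_norm = torch.randn(batch, time, joints, d)
    # Mean squared deviation of ||q||^2 from 1, averaged over all elements.
    return torch.mean((torch.sum(quat_pre_norm ** 2, dim=-1) - 1) ** 2)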
def per_train(self):
    """Run one training epoch of the text-to-motion model."""
    self.model.train()
    train_loader = self.data_loader['train']
    batch_loss = 0.
    N = 0.
    for joint_offsets, pos, affs, quat, quat_valid_idx, \
            text, text_valid_idx, intended_emotion, intended_polarity, \
            acting_task, gender, age, handedness, \
            native_tongue in self.yield_batch(self.args.batch_size, train_loader):
        # Decoder prelude: EOS quaternions everywhere, ground-truth frame 0 last.
        quat_prelude = self.quats_eos.view(1, -1).cuda() \
            .repeat(self.T - 1, 1).unsqueeze(0).repeat(quat.shape[0], 1, 1).float()
        quat_prelude[:, -1] = quat[:, 0].clone()

        self.optimizer.zero_grad()
        # detect_anomaly() helps locate NaN gradients but slows training;
        # disable it for production runs.
        with torch.autograd.detect_anomaly():
            joint_lengths = torch.norm(joint_offsets, dim=-1)
            scales, _ = torch.max(joint_lengths, dim=-1)
            quat_pred, quat_pred_pre_norm = self.model(
                text, intended_emotion, intended_polarity,
                acting_task, gender, age, handedness, native_tongue,
                quat_prelude, joint_lengths / scales[..., None])
            quat_pred_pre_norm = quat_pred_pre_norm.view(
                quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1], -1, self.D)
            quat_norm_loss = self.args.quat_norm_reg * \
                torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)
            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                quat_pred, quat[:, 1:], quat_valid_idx[:, 1:], self.V, self.D,
                self.lower_body_start, self.args.upper_body_weight)
            quat_loss *= self.args.quat_reg

            root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1], self.C).cuda()
            pos_pred = MocapDataset.forward_kinematics(
                quat_pred.contiguous().view(
                    quat_pred.shape[0], quat_pred.shape[1], -1, self.D),
                root_pos, self.joint_parents,
                torch.cat((root_pos[:, 0:1], joint_offsets), dim=1).unsqueeze(1))
            affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)
            # row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
            # row_sums[row_sums == 0.] = 1.

            shifted_pos = pos - pos[:, :, 0:1]
            shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]
            recons_loss = self.recons_loss_func(shifted_pos_pred, shifted_pos[:, 1:])
            # Alternative masked, upper-body-weighted losses, kept for reference:
            # recons_loss = torch.abs(shifted_pos_pred - shifted_pos[:, 1:]).sum(-1)
            # recons_loss = self.args.upper_body_weight * \
            #     (recons_loss[:, :, :self.lower_body_start].sum(-1)) + \
            #     recons_loss[:, :, self.lower_body_start:].sum(-1)
            # recons_loss = self.args.recons_reg * \
            #     torch.mean((recons_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
            # recons_derv_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos_pred[:, :-1] -
            #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
            # recons_derv_loss = self.args.upper_body_weight * \
            #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) + \
            #     recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
            # recons_derv_loss = 2. * self.args.recons_reg * \
            #     torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
            # affs_loss = torch.abs(affs[:, 1:] - affs_pred).sum(-1)
            # affs_loss = self.args.affs_reg * torch.mean(
            #     (affs_loss * quat_valid_idx[:, 1:]).sum(-1) / row_sums)
            affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])

            train_loss = quat_norm_loss + quat_loss + recons_loss + affs_loss
            # train_loss = quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
            train_loss.backward()
            # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
            self.optimizer.step()

        # Debug helpers, kept for reference: save predictions and ground truth as BVH.
        # animation_pred = {
        #     'joint_names': self.joint_names,
        #     'joint_offsets': joint_offsets,
        #     'joint_parents': self.joint_parents,
        #     'positions': pos_pred,
        #     'rotations': quat_pred
        # }
        # MocapDataset.save_as_bvh(animation_pred, dataset_name=self.dataset, subset_name='test')
        # animation = {
        #     'joint_names': self.joint_names,
        #     'joint_offsets': joint_offsets,
        #     'joint_parents': self.joint_parents,
        #     'positions': pos,
        #     'rotations': quat
        # }
        # MocapDataset.save_as_bvh(animation, dataset_name=self.dataset, subset_name='gt')

        # Compute statistics
        batch_loss += train_loss.item()
        N += quat.shape[0]
        self.iter_info['loss'] = train_loss.data.item()
        self.iter_info['lr'] = '{:.6f}'.format(self.lr)
        self.iter_info['tf'] = '{:.6f}'.format(self.tf)
        self.show_iter_info()
        self.meta_info['iter'] += 1

    batch_loss = batch_loss / N
    self.epoch_info['mean_loss'] = batch_loss
    self.show_epoch_info()
    self.io.print_timer()
    self.adjust_lr()
    self.adjust_tf()
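# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical name, not called by the pipeline): how the
# decoder "prelude" above is assembled, assuming `quats_eos` is a flattened
# (V * D,) end-of-sequence quaternion vector and the ground-truth first frame
# seeds the final prelude slot. The `.cuda()` move is omitted here.

def _demo_build_quat_prelude(quats_eos, quat_gt, T):
    # quats_eos: (V * D,); quat_gt: (batch, T, V * D)
    prelude = quats_eos.view(1, -1).repeat(T - 1, 1)             # (T - 1, V * D)
    prelude = prelude.unsqueeze(0).repeat(quat_gt.shape[0], 1, 1).float()
    prelude[:, -1] = quat_gt[:, 0].clone()                       # seed with frame 0
    return prelude                                               # (batch, T - 1, V * D)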
def forward_pass(self, joint_offsets, pos, affs, quat, quat_sos, quat_valid_idx,
                 text, text_valid_idx, perceived_emotion, perceived_polarity,
                 acting_task, gender, age, handedness, native_tongue):
    """Compute the total training loss for one batch."""
    self.optimizer.zero_grad()
    # detect_anomaly() helps locate NaN gradients but slows training.
    with torch.autograd.detect_anomaly():
        joint_lengths = torch.norm(joint_offsets, dim=-1)
        scales, _ = torch.max(joint_lengths, dim=-1)
        # Chunked autoregressive decoding with scheduled sampling, kept for reference:
        # text_latent = self.model(text, perceived_emotion, perceived_polarity,
        #                          acting_task, gender, age, handedness, native_tongue,
        #                          only_encoder=True)
        # quat_pred = torch.zeros_like(quat)
        # quat_pred_pre_norm = torch.zeros_like(quat)
        # quat_in = quat_sos[:, :self.T_steps]
        # quat_valid_idx_max = torch.max(torch.sum(quat_valid_idx, dim=-1))
        # for t in range(0, self.T, self.T_steps):
        #     if t > quat_valid_idx_max:
        #         break
        #     quat_pred[:, t:min(self.T, t + self.T_steps)], \
        #         quat_pred_pre_norm[:, t:min(self.T, t + self.T_steps)] = \
        #         self.model(text_latent, quat=quat_in,
        #                    offset_lengths=joint_lengths / scales[..., None],
        #                    only_decoder=True)
        #     if torch.rand(1) > self.tf:
        #         if t + self.T_steps * 2 >= self.T:
        #             quat_in = quat_pred[:, -(self.T - t - self.T_steps):].clone()
        #         else:
        #             quat_in = quat_pred[:, t:t + self.T_steps].clone()
        #     else:
        #         if t + self.T_steps * 2 >= self.T:
        #             quat_in = quat[:, -(self.T - t - self.T_steps):].clone()
        #         else:
        #             quat_in = quat[:, t:min(self.T, t + self.T_steps)].clone()
        quat_pred, quat_pred_pre_norm = self.model(
            text, perceived_emotion, perceived_polarity,
            acting_task, gender, age, handedness, native_tongue,
            quat_sos[:, :-1], joint_lengths / scales[..., None])

        # qfix enforces quaternion sign continuity across frames.
        quat_fixed = qfix(quat.contiguous().view(
            quat.shape[0], quat.shape[1], -1, self.D)).contiguous().view(
            quat.shape[0], quat.shape[1], -1)
        quat_pred = qfix(quat_pred.contiguous().view(
            quat_pred.shape[0], quat_pred.shape[1], -1, self.D)).contiguous().view(
            quat_pred.shape[0], quat_pred.shape[1], -1)
        quat_pred_pre_norm = quat_pred_pre_norm.view(
            quat_pred_pre_norm.shape[0], quat_pred_pre_norm.shape[1], -1, self.D)
        quat_norm_loss = self.args.quat_norm_reg * \
            torch.mean((torch.sum(quat_pred_pre_norm ** 2, dim=-1) - 1) ** 2)
        # quat_loss, quat_derv_loss = losses.quat_angle_loss(quat_pred, quat_fixed,
        #                                                    quat_valid_idx[:, 1:],
        #                                                    self.V, self.D,
        #                                                    self.lower_body_start,
        #                                                    self.args.upper_body_weight)
        quat_loss, quat_derv_loss = losses.quat_angle_loss(
            quat_pred, quat_fixed[:, 1:], quat_valid_idx[:, 1:], self.V, self.D,
            self.lower_body_start, self.args.upper_body_weight)
        quat_loss *= self.args.quat_reg

        root_pos = torch.zeros(quat_pred.shape[0], quat_pred.shape[1], self.C).cuda()
        pos_pred = MocapDataset.forward_kinematics(
            quat_pred.contiguous().view(
                quat_pred.shape[0], quat_pred.shape[1], -1, self.D),
            root_pos, self.joint_parents,
            torch.cat((root_pos[:, 0:1], joint_offsets), dim=1).unsqueeze(1))
        affs_pred = MocapDataset.get_mpi_affective_features(pos_pred)

        row_sums = quat_valid_idx.sum(1, keepdim=True) * self.D * self.V
        row_sums[row_sums == 0.] = 1.

        shifted_pos = pos - pos[:, :, 0:1]
        shifted_pos_pred = pos_pred - pos_pred[:, :, 0:1]
        # recons_loss = self.recons_loss_func(shifted_pos_pred, shifted_pos)
        # recons_arms = self.recons_loss_func(shifted_pos_pred[:, :, 7:15], shifted_pos[:, :, 7:15])
        recons_loss = self.recons_loss_func(shifted_pos_pred, shifted_pos[:, 1:])
        # recons_arms = self.recons_loss_func(shifted_pos_pred[:, :, 7:15], shifted_pos[:, 1:, 7:15])
        # keepdim keeps the division at (batch, 1) against row_sums of shape
        # (batch, 1), instead of silently broadcasting to (batch, batch).
        recons_loss = self.args.recons_reg * torch.mean(
            (recons_loss * quat_valid_idx[:, 1:]).sum(-1, keepdim=True) / row_sums)
        # recons_derv_loss = torch.abs(shifted_pos_pred[:, 1:] - shifted_pos_pred[:, :-1] -
        #                              shifted_pos[:, 2:] + shifted_pos[:, 1:-1]).sum(-1)
        # recons_derv_loss = self.args.upper_body_weight * \
        #     (recons_derv_loss[:, :, :self.lower_body_start].sum(-1)) + \
        #     recons_derv_loss[:, :, self.lower_body_start:].sum(-1)
        # recons_derv_loss = 2. * self.args.recons_reg * \
        #     torch.mean((recons_derv_loss * quat_valid_idx[:, 2:]).sum(-1) / row_sums)
        affs_loss = self.affs_loss_func(affs_pred, affs[:, 1:])
        affs_loss = self.args.affs_reg * torch.mean(
            (affs_loss * quat_valid_idx[:, 1:]).sum(-1, keepdim=True) / row_sums)
        # affs_loss = self.affs_loss_func(affs_pred, affs)

        total_loss = quat_norm_loss + quat_loss + recons_loss + affs_loss
        # total_loss = quat_norm_loss + quat_loss + recons_loss + recons_derv_loss + affs_loss
        return total_loss
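# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not called by the pipeline): the
# validity-masked averaging used above. Padded frames are zeroed by
# `valid_idx`, and `row_sums` normalizes by the number of valid elements so
# padding does not dilute the mean; a guard prevents division by zero. This
# assumes a per-frame loss of shape (batch, T).

def _demo_masked_mean(per_frame_loss, valid_idx, V, D):
    import torch
    # per_frame_loss, valid_idx: (batch, T)
    row_sums = valid_idx.sum(1, keepdim=True) * V * D    # (batch, 1)
    row_sums[row_sums == 0.] = 1.                        # avoid division by zero
    return torch.mean(
        (per_frame_loss * valid_idx).sum(-1, keepdim=True) / row_sums)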
def per_test(self):
    """Run one validation pass of the motion-prediction model."""
    self.model.eval()
    test_loader = self.data_loader['test']
    valid_loss = 0.
    N = 0.
    for pos, quat, orient, z_mean, z_dev, \
            root_speed, affs, spline, labels in self.yield_batch(self.args.batch_size,
                                                                 test_loader):
        with torch.no_grad():
            pos = torch.from_numpy(pos).cuda()
            orient = torch.from_numpy(orient).cuda()
            quat = torch.from_numpy(quat).cuda()
            z_mean = torch.from_numpy(z_mean).cuda()
            z_dev = torch.from_numpy(z_dev).cuda()
            root_speed = torch.from_numpy(root_speed).cuda()
            affs = torch.from_numpy(affs).cuda()
            spline = torch.from_numpy(spline).cuda()
            labels = torch.from_numpy(labels).cuda().repeat(1, quat.shape[1], 1)
            z_rs = torch.cat((z_dev, root_speed), dim=-1)
            quat_all = torch.cat((orient[:, self.prefix_length - 1:],
                                  quat[:, self.prefix_length - 1:]), dim=-1)

            # Prediction buffers start as copies of the ground truth; the
            # rollout below overwrites the target frames one step at a time.
            pos_pred = pos.clone()
            orient_pred = orient.clone()
            quat_pred = quat.clone()
            z_rs_pred = z_rs.clone()
            affs_pred = affs.clone()
            spline_pred = spline.clone()
            orient_prenorm_terms = torch.zeros_like(orient_pred)
            quat_prenorm_terms = torch.zeros_like(quat_pred)

            # Autoregressive forward rollout over the target frames.
            for t in range(self.target_length):
                orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    self.orient_h, self.quat_h, \
                    orient_prenorm_terms[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    quat_prenorm_terms[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    self.model(
                        orient_pred[:, t:self.prefix_length + t],
                        quat_pred[:, t:self.prefix_length + t],
                        z_rs_pred[:, t:self.prefix_length + t],
                        affs_pred[:, t:self.prefix_length + t],
                        spline[:, t:self.prefix_length + t],
                        labels[:, t:self.prefix_length + t],
                        orient_h=None if t == 0 else self.orient_h,
                        quat_h=None if t == 0 else self.quat_h,
                        return_prenorm=True)
                pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1], \
                    spline_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    self.mocap.get_predicted_features(
                        pos_pred[:, :self.prefix_length + t],
                        pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]],
                        z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean,
                        orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1],
                        quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1])

            prenorm_terms = torch.cat((orient_prenorm_terms, quat_prenorm_terms), dim=-1)
            prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0],
                                               prenorm_terms.shape[1], -1, self.D)
            quat_norm_loss = self.args.quat_norm_reg * \
                torch.mean((torch.sum(prenorm_terms ** 2, dim=-1) - 1) ** 2)
            quat_loss, quat_derv_loss = losses.quat_angle_loss(
                torch.cat((orient_pred[:, self.prefix_length - 1:],
                           quat_pred[:, self.prefix_length - 1:]), dim=-1),
                quat_all, self.V, self.D)
            quat_loss *= self.args.quat_reg
            recons_loss = self.args.recons_reg * \
                (pos_pred[:, self.prefix_length:] - pos_pred[:, self.prefix_length:, 0:1] -
                 pos[:, self.prefix_length:] + pos[:, self.prefix_length:, 0:1]).norm()
            # Note: only the reconstruction term drives validation; the
            # quaternion terms above are computed for monitoring.
            valid_loss += recons_loss
            N += quat.shape[0]

    valid_loss /= N
    # Debug helpers, kept for reference: render predicted and ground-truth animations.
    # if self.meta_info['epoch'] > 5 and self.loss_updated:
    #     pos_pred_np = pos_pred.contiguous().view(
    #         pos_pred.shape[0], pos_pred.shape[1], -1).permute(0, 2, 1).detach().cpu().numpy()
    #     display_animations(pos_pred_np, self.V, self.C, self.mocap.joint_parents, save=True,
    #                        dataset_name=self.dataset,
    #                        subset_name='epoch_' + str(self.best_loss_epoch),
    #                        overwrite=True)
    #     pos_in_np = pos_in.contiguous().view(
    #         pos_in.shape[0], pos_in.shape[1], -1).permute(0, 2, 1).detach().cpu().numpy()
    #     display_animations(pos_in_np, self.V, self.C, self.mocap.joint_parents, save=True,
    #                        dataset_name=self.dataset,
    #                        subset_name='epoch_' + str(self.best_loss_epoch) + '_gt',
    #                        overwrite=True)

    self.epoch_info['mean_loss'] = valid_loss
    if self.epoch_info['mean_loss'] < self.best_loss and \
            self.meta_info['epoch'] > self.min_train_epochs:
        self.best_loss = self.epoch_info['mean_loss']
        self.best_loss_epoch = self.meta_info['epoch']
        self.loss_updated = True
    else:
        self.loss_updated = False
    self.show_epoch_info()
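# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical model and signature, not part of the
# pipeline): the rollout pattern used above, reduced to one feature stream.
# Condition on a sliding window of `prefix_length` frames and write each
# one-frame prediction back into the sequence buffer for the next step.

def _demo_rollout(model, seq, prefix_length, target_length):
    # seq: (batch, prefix_length + target_length, features), a torch tensor.
    pred = seq.clone()
    h = None                                  # recurrent state carried across steps
    for t in range(target_length):
        window = pred[:, t:prefix_length + t]
        frame, h = model(window, h)           # frame: (batch, 1, features)
        pred[:, prefix_length + t:prefix_length + t + 1] = frame
    return pred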
def per_train(self):
    """Run one training epoch of the motion-prediction model with scheduled sampling."""
    self.model.train()
    train_loader = self.data_loader['train']
    batch_loss = 0.
    N = 0.
    for pos, quat, orient, z_mean, z_dev, \
            root_speed, affs, spline, labels in self.yield_batch(self.args.batch_size,
                                                                 train_loader):
        pos = torch.from_numpy(pos).cuda()
        orient = torch.from_numpy(orient).cuda()
        quat = torch.from_numpy(quat).cuda()
        z_mean = torch.from_numpy(z_mean).cuda()
        z_dev = torch.from_numpy(z_dev).cuda()
        root_speed = torch.from_numpy(root_speed).cuda()
        affs = torch.from_numpy(affs).cuda()
        spline = torch.from_numpy(spline).cuda()
        labels = torch.from_numpy(labels).cuda().repeat(1, quat.shape[1], 1)
        z_rs = torch.cat((z_dev, root_speed), dim=-1)
        quat_all = torch.cat((orient[:, self.prefix_length - 1:],
                              quat[:, self.prefix_length - 1:]), dim=-1)

        # `*_pred` buffers hold the model inputs (ground truth or fed-back
        # predictions, chosen per step by the teacher-forcing gate below);
        # `*_pred_all` buffers collect every prediction for the loss.
        pos_pred = pos.clone()
        orient_pred = orient.clone()
        quat_pred = quat.clone()
        z_rs_pred = z_rs.clone()
        affs_pred = affs.clone()
        spline_pred = spline.clone()
        pos_pred_all = pos.clone()
        orient_pred_all = orient.clone()
        quat_pred_all = quat.clone()
        z_rs_pred_all = z_rs.clone()
        affs_pred_all = affs.clone()
        spline_pred_all = spline.clone()
        orient_prenorm_terms = torch.zeros_like(orient_pred)
        quat_prenorm_terms = torch.zeros_like(quat_pred)

        # Autoregressive forward rollout with scheduled sampling.
        self.optimizer.zero_grad()
        for t in range(self.target_length):
            orient_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                z_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                self.orient_h, self.quat_h, \
                orient_prenorm_terms[:, self.prefix_length + t:self.prefix_length + t + 1], \
                quat_prenorm_terms[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                self.model(
                    orient_pred[:, t:self.prefix_length + t],
                    quat_pred[:, t:self.prefix_length + t],
                    z_rs_pred[:, t:self.prefix_length + t],
                    affs_pred[:, t:self.prefix_length + t],
                    spline_pred[:, t:self.prefix_length + t],
                    labels[:, t:self.prefix_length + t],
                    orient_h=None if t == 0 else self.orient_h,
                    quat_h=None if t == 0 else self.quat_h,
                    return_prenorm=True)
            pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                spline_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                self.mocap.get_predicted_features(
                    pos_pred[:, :self.prefix_length + t],
                    pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0, [0, 2]],
                    z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0] + z_mean,
                    orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1],
                    quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1])
            # Teacher-forcing gate: with probability (1 - tf), feed the model's
            # own predictions back in for the next step.
            if np.random.uniform(size=1)[0] > self.tf:
                pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                orient_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    orient_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                z_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    z_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()
                spline_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    spline_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1].clone()

        prenorm_terms = torch.cat((orient_prenorm_terms, quat_prenorm_terms), dim=-1)
        prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0],
                                           prenorm_terms.shape[1], -1, self.D)
        quat_norm_loss = self.args.quat_norm_reg * \
            torch.mean((torch.sum(prenorm_terms ** 2, dim=-1) - 1) ** 2)
        quat_loss, quat_derv_loss = losses.quat_angle_loss(
            torch.cat((orient_pred_all[:, self.prefix_length - 1:],
                       quat_pred_all[:, self.prefix_length - 1:]), dim=-1),
            quat_all, self.V, self.D)
        quat_loss *= self.args.quat_reg
        z_rs_loss = self.z_rs_loss_func(z_rs_pred_all[:, self.prefix_length:],
                                        z_rs[:, self.prefix_length:])
        spline_loss = self.spline_loss_func(spline_pred_all[:, self.prefix_length:],
                                            spline[:, self.prefix_length:])
        fs_loss = losses.foot_speed_loss(pos_pred, pos)

        loss_total = quat_norm_loss + quat_loss + quat_derv_loss + \
            z_rs_loss + spline_loss + fs_loss
        loss_total.backward()
        # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
        self.optimizer.step()

        # Debug helper, kept for reference: save predictions as BVH.
        # animation_pred = {
        #     'joint_names': self.joint_names,
        #     'joint_offsets': torch.from_numpy(self.joint_offsets[1:]).float()
        #         .unsqueeze(0).repeat(pos_pred_all.shape[0], 1, 1),
        #     'joint_parents': self.joint_parents,
        #     'positions': pos_pred_all,
        #     'rotations': torch.cat((orient_pred_all, quat_pred_all), dim=-1)
        # }
        # MocapDataset.save_as_bvh(animation_pred, dataset_name=self.dataset, subset_name='test')

        # Compute statistics
        batch_loss += loss_total.item()
        N += quat.shape[0]
        self.iter_info['loss'] = loss_total.data.item()
        self.iter_info['lr'] = '{:.6f}'.format(self.lr)
        self.iter_info['tf'] = '{:.6f}'.format(self.tf)
        self.show_iter_info()
        self.meta_info['iter'] += 1

    batch_loss = batch_loss / N
    self.epoch_info['mean_loss'] = batch_loss
    self.show_epoch_info()
    self.io.print_timer()
    self.adjust_lr()
    self.adjust_tf()
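# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not called by the pipeline): the
# scheduled-sampling gate used above, reduced to one buffer. With probability
# `tf` the ground-truth frame already in the buffer is kept; otherwise the
# model's prediction is written back for the next step's input window.

def _demo_teacher_forcing_step(feedback_buf, pred_frame, gt_frame, t, tf):
    import numpy as np
    if np.random.uniform() > tf:              # feed back the model's prediction
        feedback_buf[:, t:t + 1] = pred_frame.clone()
    else:                                     # keep the ground-truth frame
        feedback_buf[:, t:t + 1] = gt_frame.clone()
    return feedback_buf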
def per_train(self):
    """Run one training epoch of the quaternion motion model with scheduled sampling."""
    self.model.train()
    train_loader = self.data_loader['train']
    batch_loss = 0.
    N = 0.
    for pos, quat, orient, affs, spline, p_rs, labels in self.yield_batch(
            self.args.batch_size, train_loader):
        pos = torch.from_numpy(pos).cuda()
        quat = torch.from_numpy(quat).cuda()
        orient = torch.from_numpy(orient).cuda()
        affs = torch.from_numpy(affs).cuda()
        spline = torch.from_numpy(spline).cuda()
        p_rs = torch.from_numpy(p_rs).cuda()
        labels = torch.from_numpy(labels).cuda()

        # `*_pred` buffers feed the model; `*_pred_all` collect predictions for the loss.
        pos_pred = pos.clone()
        quat_pred = quat.clone()
        p_rs_pred = p_rs.clone()
        affs_pred = affs.clone()
        pos_pred_all = pos.clone()
        quat_pred_all = quat.clone()
        p_rs_pred_all = p_rs.clone()
        affs_pred_all = affs.clone()
        prenorm_terms = torch.zeros_like(quat_pred)

        # Autoregressive forward rollout with scheduled sampling.
        self.optimizer.zero_grad()
        for t in range(self.target_length):
            quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                p_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                self.quat_h, \
                prenorm_terms[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                self.model(
                    quat_pred[:, t:self.prefix_length + t],
                    p_rs_pred[:, t:self.prefix_length + t],
                    affs_pred[:, t:self.prefix_length + t],
                    spline[:, t:self.prefix_length + t],
                    orient[:, t:self.prefix_length + t],
                    labels,
                    quat_h=None if t == 0 else self.quat_h,
                    return_prenorm=True)
            pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1], \
                affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                self.mocap.get_predicted_features(
                    pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1, 0],
                    quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1],
                    orient[:, self.prefix_length + t:self.prefix_length + t + 1])
            # Teacher-forcing gate: with probability (1 - tf), feed predictions back.
            if np.random.uniform(size=1)[0] > self.tf:
                pos_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    pos_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]
                quat_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    quat_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]
                p_rs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    p_rs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]
                affs_pred[:, self.prefix_length + t:self.prefix_length + t + 1] = \
                    affs_pred_all[:, self.prefix_length + t:self.prefix_length + t + 1]

        prenorm_terms = prenorm_terms.view(prenorm_terms.shape[0],
                                           prenorm_terms.shape[1], -1, self.D)
        quat_norm_loss = self.args.quat_norm_reg * torch.mean(
            (torch.sum(prenorm_terms ** 2, dim=-1) - 1) ** 2)
        quat_loss, quat_derv_loss = losses.quat_angle_loss(
            quat_pred_all[:, self.prefix_length - 1:],
            quat[:, self.prefix_length - 1:], self.V, self.D)
        quat_loss *= self.args.quat_reg
        p_rs_loss = self.p_rs_loss_func(p_rs_pred_all[:, self.prefix_length:],
                                        p_rs[:, self.prefix_length:])
        affs_loss = self.affs_loss_func(affs_pred_all[:, self.prefix_length:],
                                        affs[:, self.prefix_length:])
        # recons_loss = self.args.recons_reg * \
        #     (pos_pred_all[:, self.prefix_length:] - pos_pred_all[:, self.prefix_length:, 0:1] -
        #      pos[:, self.prefix_length:] + pos[:, self.prefix_length:, 0:1]).norm()

        loss_total = quat_norm_loss + quat_loss + quat_derv_loss + \
            p_rs_loss + affs_loss  # + recons_loss
        loss_total.backward()
        # nn.utils.clip_grad_norm_(self.model.parameters(), self.args.gradient_clip)
        self.optimizer.step()

        # Compute statistics
        batch_loss += loss_total.item()
        N += quat.shape[0]
        self.iter_info['loss'] = loss_total.data.item()
        self.iter_info['lr'] = '{:.6f}'.format(self.lr)
        self.iter_info['tf'] = '{:.6f}'.format(self.tf)
        self.show_iter_info()
        self.meta_info['iter'] += 1

    batch_loss = batch_loss / N
    self.epoch_info['mean_loss'] = batch_loss
    self.show_epoch_info()
    self.io.print_timer()
    self.adjust_lr()
    self.adjust_tf()
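# ---------------------------------------------------------------------------
# Illustrative sketch (assumption: the actual `adjust_lr` / `adjust_tf`
# implementations are not shown in this file). A common pattern decays both
# multiplicatively once per epoch; `lr_decay` and `tf_decay` are hypothetical
# hyper-parameters.

def _demo_adjust_schedules(optimizer, lr, tf, lr_decay=0.999, tf_decay=0.995):
    lr *= lr_decay                            # shrink the learning rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    tf *= tf_decay                            # anneal teacher forcing toward 0
    return lr, tf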