示例#1
0
 def __init__(self, prefix_length, num_joints, num_outputs, num_controls, model_velocities):
     """Store the configuration, build a ``QuaterNet`` and print its size.

     Args:
         prefix_length: stored as ``self.prefix_length``.
         num_joints: forwarded to ``QuaterNet`` only.
         num_outputs: forwarded to ``QuaterNet``; kept as ``self.translations_size``.
         num_controls: forwarded to ``QuaterNet``; kept as ``self.controls_size``.
         model_velocities: forwarded to ``QuaterNet``; kept as ``self.model_velocities``.
     """
     # Scalar settings mirrored onto the instance.
     self.prefix_length = prefix_length
     self.translations_size = num_outputs
     self.controls_size = num_controls
     self.model_velocities = model_velocities
     self.use_cuda = False  # CUDA disabled by default.

     self.model = QuaterNet(num_joints, num_outputs, num_controls, model_velocities)

     # Report the total number of parameter elements in the model.
     total = sum(weight.numel() for weight in self.model.parameters())
     print('# parameters:', total)
示例#2
0
class PoseNetwork:
    """Training/inference wrapper around a ``QuaterNet`` model.

    Subclasses supply the data pipeline (``_prepare_next_batch_impl``) and the
    objective (``_loss_impl``); this base class owns the model, the CUDA flag,
    the training loop and weight (de)serialization.
    """

    def __init__(self, prefix_length, num_joints, num_outputs, num_controls, model_velocities):
        """Build the underlying ``QuaterNet`` and print its parameter count.

        Args:
            prefix_length: number of leading input frames fed to the model in a
                single call before step-by-step prediction starts in ``train``.
            num_joints: forwarded to ``QuaterNet``.
            num_outputs: forwarded to ``QuaterNet``; stored as ``translations_size``.
            num_controls: forwarded to ``QuaterNet``; stored as ``controls_size``.
                When > 0, the last ``num_controls`` input features of a frame are
                re-appended to the model's own predictions during rollouts.
            model_velocities: forwarded to ``QuaterNet`` and stored as-is.
        """
        self.translations_size = num_outputs
        self.controls_size = num_controls
        self.model_velocities = model_velocities

        self.model = QuaterNet(num_joints, num_outputs, num_controls, model_velocities)
        self.use_cuda = False

        self.prefix_length = prefix_length

        num_params = sum(parameter.numel() for parameter in self.model.parameters())
        print('# parameters:', num_params)

    def cuda(self):
        """Move the model to the GPU and make ``train`` move its batches there too.

        Returns:
            self, so calls can be chained.
        """
        self.use_cuda = True
        self.model.cuda()
        return self

    def eval(self):
        """Put the model in evaluation mode; returns self for chaining."""
        self.model.eval()
        return self

    def _prepare_next_batch_impl(self, batch_size, dataset, target_length, sequences):
        """Yield ``(inputs, outputs)`` NumPy array batches; must be overridden.

        ``train`` converts each array with ``torch.from_numpy``, uses
        ``inputs[:, :prefix_length]`` as the priming prefix and, when
        ``controls_size > 0``, the last ``controls_size`` features of later
        frames as controls. For validation only the first yielded batch is used.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # Fail loudly instead of returning None, which ``train`` would try to iterate.
        raise NotImplementedError('Subclasses must implement _prepare_next_batch_impl')

    def _loss_impl(self, predicted, expected):
        """Return a scalar loss tensor comparing ``predicted`` with ``expected``; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('Subclasses must implement _loss_impl')

    def _to_tensor(self, array):
        """Convert a NumPy array to a tensor, moving it to the GPU if ``cuda()`` was called."""
        tensor = torch.from_numpy(array)
        if self.use_cuda:
            tensor = tensor.cuda()
        return tensor

    def _rollout_with_teacher_forcing(self, inputs, target_length, teacher_forcing_ratio):
        """Unroll the model over one training batch; return ``(predictions, terms)`` concatenated on dim 1."""
        start = self.prefix_length
        terms = []
        predictions = []
        # Prime the model on the whole prefix in a single call.
        predicted, hidden, term = self.model(inputs[:, :start], None, True)
        terms.append(term)
        predictions.append(predicted)

        # Per-step coin flips: True means "feed the ground-truth frame".
        tf_mask = np.random.uniform(size=target_length - 1) < teacher_forcing_ratio
        i = 0
        while i < target_length - 1:
            contiguous_frames = 1
            if tf_mask[i]:
                # Batch together consecutive teacher-forced steps into one model call.
                while i + contiguous_frames < target_length - 1 and tf_mask[i + contiguous_frames]:
                    contiguous_frames += 1
                predicted, hidden, term = self.model(inputs[:, start + i:start + i + contiguous_frames],
                                                     hidden, True, True)
            else:
                # Feed own output, re-attaching this frame's ground-truth controls.
                if self.controls_size > 0:
                    controls = inputs[:, start + i:start + i + 1, -self.controls_size:]
                    predicted = torch.cat((predicted, controls), dim=2)
                predicted, hidden, term = self.model(predicted, hidden, True)
            terms.append(term)
            predictions.append(predicted)
            if contiguous_frames > 1:
                # Only the last frame can be fed back at a following free-running step.
                predicted = predicted[:, -1:]
            i += contiguous_frames

        return torch.cat(predictions, dim=1), torch.cat(terms, dim=1)

    def _validation_loss(self, inputs_valid, outputs_valid, target_length):
        """Return the loss (a float) of a fully autoregressive rollout on one validation batch."""
        # Score in eval mode so train-only layers (if the model has any) do not
        # perturb the result; always restore training mode afterwards.
        self.model.eval()
        try:
            with torch.no_grad():
                start = self.prefix_length
                predictions = []
                predicted, hidden = self.model(inputs_valid[:, :start])
                predictions.append(predicted)
                for i in range(target_length - 1):
                    # Feed own output, re-attaching this frame's ground-truth controls.
                    if self.controls_size > 0:
                        controls = inputs_valid[:, start + i:start + i + 1, -self.controls_size:]
                        predicted = torch.cat((predicted, controls), dim=2)
                    predicted, hidden = self.model(predicted, hidden)
                    predictions.append(predicted)
                predictions = torch.cat(predictions, dim=1)
                return self._loss_impl(predictions, outputs_valid).item()
        finally:
            self.model.train()

    def train(self, dataset, target_length, sequences_train, sequences_valid, batch_size, n_epochs=3000, rot_reg=0.01):
        """Train the model with Adam, decaying teacher forcing and gradient clipping.

        Each epoch iterates ``_prepare_next_batch_impl(batch_size, ...)`` over
        ``sequences_train``. The model is primed on the first ``prefix_length``
        input frames and then unrolled for ``target_length - 1`` further steps.
        Each step feeds either the ground-truth frame (with probability
        ``teacher_forcing_ratio``) or the model's own previous prediction.
        After every epoch the learning rate and the teacher-forcing ratio are
        multiplied by fixed decay factors. Ctrl-C stops training early and still
        returns the statistics gathered so far.

        Args:
            dataset: passed through to ``_prepare_next_batch_impl``.
            target_length: number of rollout steps per sequence, including the
                one produced from the prefix.
            sequences_train: training sequences, passed to ``_prepare_next_batch_impl``.
            sequences_valid: validation sequences. If non-empty, one fixed batch
                of 30 is drawn once and scored after every epoch with a fully
                autoregressive rollout.
            batch_size: training batch size.
            n_epochs: number of epochs.
            rot_reg: weight of the penalty pushing each 4-vector of the model's
                third output towards unit norm.

        Returns:
            ``(losses, valid_losses, gradient_norms)``:
            - the per-epoch training loss;
            - the per-epoch validation loss (empty without validation data);
            - the pre-clipping total gradient norm of every training batch.

        Raises:
            ValueError: if an epoch yields no training batches.
        """
        # Reseeds NumPy's *global* RNG so teacher-forcing masks are reproducible.
        np.random.seed(1234)
        self.model.train()

        lr = 0.001
        batch_size_valid = 30
        lr_decay = 0.999
        teacher_forcing_ratio = 1  # Start by forcing the ground truth
        tf_decay = 0.995  # Teacher forcing decay rate
        gradient_clip = 0.1

        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        has_validation = len(sequences_valid) > 0
        if has_validation:
            batch_in_valid, batch_out_valid = next(self._prepare_next_batch_impl(
                batch_size_valid, dataset, target_length, sequences_valid))
            inputs_valid = self._to_tensor(batch_in_valid)
            outputs_valid = self._to_tensor(batch_out_valid)

        losses = []
        valid_losses = []
        gradient_norms = []
        print('Training for %d epochs' % (n_epochs))
        start_time = time()
        start_epoch = 0  # first epoch of the current benchmark window
        try:
            for epoch in range(n_epochs):
                batch_loss = 0.0
                N = 0
                for batch_in, batch_out in self._prepare_next_batch_impl(batch_size, dataset, target_length, sequences_train):
                    inputs = self._to_tensor(batch_in)
                    outputs = self._to_tensor(batch_out)

                    optimizer.zero_grad()
                    predictions, terms = self._rollout_with_teacher_forcing(
                        inputs, target_length, teacher_forcing_ratio)

                    # Penalize the squared deviation of each 4-vector's squared norm from 1.
                    terms = terms.view(terms.shape[0], terms.shape[1], -1, 4)
                    penalty_loss = rot_reg * torch.mean((torch.sum(terms**2, dim=3) - 1)**2)

                    loss = self._loss_impl(predictions, outputs)
                    loss_total = penalty_loss + loss
                    loss_total.backward()
                    # clip_grad_norm_ returns the total norm measured before clipping.
                    grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), gradient_clip)
                    gradient_norms.append(float(grad_norm))
                    optimizer.step()

                    # Weight by batch size so differently sized batches contribute proportionally.
                    batch_loss += loss.item() * inputs.shape[0]
                    N += inputs.shape[0]

                if N == 0:
                    raise ValueError('_prepare_next_batch_impl produced no training batches')
                batch_loss = batch_loss / N
                losses.append(batch_loss)

                if has_validation:
                    valid_loss = self._validation_loss(inputs_valid, outputs_valid, target_length)
                    valid_losses.append(valid_loss)
                    print('[%d] loss: %.5f valid_loss %.5f lr %f tf_ratio %f' % (epoch + 1, batch_loss, valid_loss,
                                                              lr, teacher_forcing_ratio))
                else:
                    print('[%d] loss: %.5f lr %f tf_ratio %f' % (epoch + 1, batch_loss,
                                                              lr, teacher_forcing_ratio))
                teacher_forcing_ratio *= tf_decay
                lr *= lr_decay
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= lr_decay

                if (epoch + 1) % 10 == 0:
                    next_time = time()
                    # Epochs start_epoch..epoch inclusive ran since the last report.
                    time_per_epoch = (next_time - start_time) / (epoch + 1 - start_epoch)
                    print('Benchmark:', time_per_epoch, 's per epoch')
                    start_time = next_time
                    start_epoch = epoch + 1
        except KeyboardInterrupt:
            print('Training aborted.')

        print('Done.')
        return losses, valid_losses, gradient_norms

    def save_weights(self, model_file):
        """Save the model's ``state_dict`` to ``model_file`` with ``torch.save``."""
        print('Saving weights to', model_file)
        torch.save(self.model.state_dict(), model_file)

    def load_weights(self, model_file):
        """Load a ``state_dict`` from ``model_file`` into the model."""
        print('Loading weights from', model_file)
        # torch.load is pickle-based: only load checkpoints from trusted sources.
        # The identity map_location deserializes every tensor onto the CPU,
        # whatever device it was saved from.
        self.model.load_state_dict(torch.load(model_file, map_location=lambda storage, loc: storage))