def train(self):
    """Train the model for one epoch, evaluate on the validation set and save
    the best model.

    Writes ``ckpt.pth`` to ``self.save_path`` after every epoch, and
    ``best.pth`` whenever the dev mAP beats ``self.best_dev_mAP`` while
    ``self.epoch > 1``. Dev metrics are reported through ``self.log_result``
    and a one-line summary is printed.
    """
    start_time = datetime.now()
    self._train(self.train_loader)
    utils.save_state_dict(self.state_dict(),
                          os.path.join(self.save_path, 'ckpt.pth'))

    # Unpacked as (class_loss, domain_loss, class_output, domain_output,
    # feature, domain_accuracy); domain output and features are unused here.
    (dev_class_loss, dev_domain_loss, dev_class_output, _, _,
     dev_domain_accuracy) = self._test(self.dev_loader)
    # Divide by the number of dev batches (presumably _test returns losses
    # summed over batches -- confirm). Computed once and reused below.
    num_dev_batches = len(self.dev_loader)
    dev_class_loss_avg = dev_class_loss / num_dev_batches
    dev_domain_loss_avg = dev_domain_loss / num_dev_batches

    dev_predict_prob = self.inference(dev_class_output)
    dev_per_class_AP = utils.compute_weighted_AP(
        self.dev_target, dev_predict_prob, self.dev_class_weight)
    dev_mAP = utils.compute_mAP(dev_per_class_AP, self.subclass_idx)
    self.log_result('Dev epoch',
                    {'class_loss': dev_class_loss_avg,
                     'domain_loss': dev_domain_loss_avg,
                     'mAP': dev_mAP,
                     'domain_accuracy': dev_domain_accuracy},
                    self.epoch)

    # NOTE(review): models from epochs <= 1 are never kept as best.pth, so a
    # run that stops that early leaves no best.pth for a later test() to
    # load -- confirm this warm-up skip is intended.
    if self.epoch > 1 and dev_mAP > self.best_dev_mAP:
        self.best_dev_mAP = dev_mAP
        utils.save_state_dict(self.state_dict(),
                              os.path.join(self.save_path, 'best.pth'))

    duration = datetime.now() - start_time
    # Adjacent literals are joined at compile time, so the separator between
    # them must be written explicitly.
    print('Finish training epoch {}, dev class loss: {}, dev domain loss: {}, '
          'dev mAP: {}, domain_accuracy: {}, time used: {}'.format(
              self.epoch, dev_class_loss_avg, dev_domain_loss_avg, dev_mAP,
              dev_domain_accuracy, duration))
def _compute_result(self, model_name, data_loader, target, class_weight,
                    inference_fn, save_name, conditional=False):
    """Load a saved checkpoint, evaluate it and pickle the results.

    Args:
        model_name: checkpoint file name inside ``self.save_path``; the loaded
            object is indexed with ``'model'`` to get the network weights.
        data_loader: loader passed to ``self._test``.
        target: ground-truth labels used for the weighted AP.
        class_weight: per-class weights passed to
            ``utils.compute_weighted_AP``.
        inference_fn: maps the network output to predictions.
        save_name: pickle file name (inside ``self.save_path``) for results.
        conditional: when true, ``inference_fn`` is also given ``target``.

    Returns:
        The mAP returned by ``utils.compute_mAP`` over ``self.subclass_idx``.
    """
    checkpoint = torch.load(os.path.join(self.save_path, model_name))
    self.network.load_state_dict(checkpoint['model'])

    # The loss reported by _test is not part of the saved result.
    _, output, feature = self._test(data_loader)
    predict = (inference_fn(output, target) if conditional
               else inference_fn(output))

    per_class_AP = utils.compute_weighted_AP(target, predict, class_weight)
    mAP = utils.compute_mAP(per_class_AP, self.subclass_idx)

    utils.save_pkl(
        {
            'output': output.cpu().numpy(),
            'feature': feature.cpu().numpy(),
            'per_class_AP': per_class_AP,
            'mAP': mAP,
        },
        os.path.join(self.save_path, save_name))
    return mAP
def test(self):
    """Evaluate the best checkpoint on the dev and test sets and save results.

    Loads ``best.pth`` from ``self.save_path`` via ``self.load_state_dict``.
    It then evaluates the dev split followed by the test split. Each split's
    outputs, features, per-class AP, mAP, domain outputs and domain accuracy
    are pickled to ``dev_result.pkl`` / ``test_result.pkl``. Finally both
    mAPs are written to ``result.txt``.
    """
    state_dict = torch.load(os.path.join(self.save_path, 'best.pth'))
    self.load_state_dict(state_dict)

    def evaluate(loader, target, class_weight, save_name):
        """Evaluate one split, pickle its results and return its mAP."""
        # The class and domain losses from _test are not saved here.
        (_, _, class_output, domain_output, feature,
         domain_accuracy) = self._test(loader)
        predict_prob = self.inference(class_output)
        per_class_AP = utils.compute_weighted_AP(target, predict_prob,
                                                 class_weight)
        mAP = utils.compute_mAP(per_class_AP, self.subclass_idx)
        result = {
            'output': class_output.cpu().numpy(),
            'feature': feature.cpu().numpy(),
            'per_class_AP': per_class_AP,
            'mAP': mAP,
            'domain_output': domain_output.cpu().numpy(),
            'domain_accuracy': domain_accuracy,
        }
        utils.save_pkl(result, os.path.join(self.save_path, save_name))
        return mAP

    dev_mAP = evaluate(self.dev_loader, self.dev_target,
                       self.dev_class_weight, 'dev_result.pkl')
    test_mAP = evaluate(self.test_loader, self.test_target,
                        self.test_class_weight, 'test_result.pkl')

    # Output the mean AP for the best model on dev and test set
    info = ('Dev mAP: {}\n'
            'Test mAP: {}'.format(dev_mAP, test_mAP))
    utils.write_info(os.path.join(self.save_path, 'result.txt'), info)
def train(self):
    """Train the model for one epoch, evaluate on the validation set and save
    the best model for each inference method.

    Writes ``ckpt.pth`` to ``self.save_path`` every epoch. Four inference
    methods are evaluated on the dev set: ``conditional``, ``max``,
    ``sum_prob`` and ``sum_out``. For each, the dev mAP is compared against
    ``self.best_dev_mAP_<method>``; on improvement that attribute is updated
    and the current state is saved to ``best-<method>.pth``. Dev metrics are
    logged through ``self.log_result`` and a summary line is printed.
    """
    start_time = datetime.now()
    self._train(self.train_loader)
    utils.save_state_dict(self.state_dict(),
                          os.path.join(self.save_path, 'ckpt.pth'))
    dev_loss, dev_output, _ = self._test(self.dev_loader)

    # (method suffix, inference function) in the original evaluation order.
    # The suffix names the best-score attribute, the checkpoint file and the
    # logged metric key.
    inference_methods = (
        ('conditional',
         lambda output: self.inference_conditional(output, self.dev_target)),
        ('max', self.inference_max),
        ('sum_prob', self.inference_sum_prob),
        ('sum_out', self.inference_sum_out),
    )

    dev_mAPs = {}
    for suffix, infer in inference_methods:
        dev_predict = infer(dev_output)
        dev_per_class_AP = utils.compute_weighted_AP(
            self.dev_target, dev_predict, self.dev_class_weight)
        dev_mAP = utils.compute_mAP(dev_per_class_AP, self.subclass_idx)
        dev_mAPs[suffix] = dev_mAP

        best_attr = 'best_dev_mAP_' + suffix
        if dev_mAP > getattr(self, best_attr):
            setattr(self, best_attr, dev_mAP)
            utils.save_state_dict(
                self.state_dict(),
                os.path.join(self.save_path, 'best-{}.pth'.format(suffix)))

    log = {'loss': dev_loss / len(self.dev_loader)}
    log.update(('mAP_' + suffix, mAP) for suffix, mAP in dev_mAPs.items())
    self.log_result('Dev epoch', log, self.epoch)

    duration = datetime.now() - start_time
    # Adjacent literals are joined at compile time, so each one must carry
    # its own trailing separator.
    print(('Finish training epoch {}, dev mAP conditional: {}, '
           'dev mAP max: {}, dev mAP sum prob: {}, '
           'dev mAP sum out: {}, time used: {}').format(
               self.epoch, dev_mAPs['conditional'], dev_mAPs['max'],
               dev_mAPs['sum_prob'], dev_mAPs['sum_out'], duration))