def train(self):
    """Train the model for one epoch, evaluate on the validation set and
    save the best model.

    Writes ``ckpt.pth`` after every epoch. Writes ``best.pth`` when the dev
    mAP beats ``self.best_dev_mAP`` (only once ``self.epoch > 1``). Dev
    metrics are passed to ``log_result`` and printed to stdout.
    """
    start_time = datetime.now()
    self._train(self.train_loader)
    utils.save_state_dict(self.state_dict(),
                          os.path.join(self.save_path, 'ckpt.pth'))

    # Domain output and features (4th/5th values) are not needed here.
    dev_class_loss, dev_domain_loss, dev_class_output, _, _, \
        dev_domain_accuracy = self._test(self.dev_loader)
    dev_predict_prob = self.inference(dev_class_output)
    dev_per_class_AP = utils.compute_weighted_AP(self.dev_target,
                                                 dev_predict_prob,
                                                 self.dev_class_weight)
    dev_mAP = utils.compute_mAP(dev_per_class_AP, self.subclass_idx)

    # Divide by the number of dev batches to report a per-batch mean
    # (presumably _test returns losses summed over batches -- confirm).
    num_dev_batches = len(self.dev_loader)
    dev_class_loss_avg = dev_class_loss / num_dev_batches
    dev_domain_loss_avg = dev_domain_loss / num_dev_batches

    self.log_result('Dev epoch',
                    {'class_loss': dev_class_loss_avg,
                     'domain_loss': dev_domain_loss_avg,
                     'mAP': dev_mAP,
                     'domain_accuracy': dev_domain_accuracy},
                    self.epoch)

    # NOTE(review): epochs <= 1 are never considered for the best model, so
    # best.pth is only written from epoch 2 onward; confirm this warm-up is
    # intended.
    if self.epoch > 1 and dev_mAP > self.best_dev_mAP:
        self.best_dev_mAP = dev_mAP
        utils.save_state_dict(self.state_dict(),
                              os.path.join(self.save_path, 'best.pth'))

    duration = datetime.now() - start_time
    print('Finish training epoch {}, dev class loss: {}, dev domain loss: {}, '
          'dev mAP: {}, domain_accuracy: {}, time used: {}'
          .format(self.epoch, dev_class_loss_avg, dev_domain_loss_avg,
                  dev_mAP, dev_domain_accuracy, duration))
    def _compute_result(self,
                        model_name,
                        data_loader,
                        target,
                        class_weight,
                        inference_fn,
                        save_name,
                        conditional=False):
        """Load a saved model and evaluate it with the given inference method.

        Args:
            model_name: checkpoint file name inside ``self.save_path``. The
                loaded object is indexed with ``'model'`` to obtain the
                network weights.
            data_loader: loader passed to ``self._test``.
            target: ground-truth labels used for the weighted AP, and passed
                to ``inference_fn`` when ``conditional`` is true.
            class_weight: per-class weights forwarded to
                ``utils.compute_weighted_AP``.
            inference_fn: callable turning the network output into
                prediction scores.
            save_name: file name inside ``self.save_path`` for the pickled
                result dict (output, feature, per_class_AP, mAP).
            conditional: if true, call ``inference_fn(output, target)``
                instead of ``inference_fn(output)``.

        Returns:
            The mAP computed by ``utils.compute_mAP`` over
            ``self.subclass_idx``.
        """
        state_dict = torch.load(os.path.join(self.save_path, model_name))
        self.network.load_state_dict(state_dict['model'])
        # The loss returned by _test is not used for scoring.
        _, output, feature = self._test(data_loader)
        if conditional:
            predict = inference_fn(output, target)
        else:
            predict = inference_fn(output)
        per_class_AP = utils.compute_weighted_AP(target, predict, class_weight)
        mAP = utils.compute_mAP(per_class_AP, self.subclass_idx)
        result = {
            'output': output.cpu().numpy(),
            'feature': feature.cpu().numpy(),
            'per_class_AP': per_class_AP,
            'mAP': mAP
        }
        utils.save_pkl(result, os.path.join(self.save_path, save_name))
        return mAP
    def test(self):
        """Evaluate the best checkpoint on the dev and test splits.

        Loads ``best.pth`` from ``self.save_path`` and evaluates the dev split,
        then the test split. Each split's results are pickled to
        ``dev_result.pkl`` / ``test_result.pkl``, and both mAP values are
        written to ``result.txt``.
        """
        state_dict = torch.load(os.path.join(self.save_path, 'best.pth'))
        self.load_state_dict(state_dict)

        dev_mAP = self._evaluate_split(self.dev_loader, self.dev_target,
                                       self.dev_class_weight,
                                       'dev_result.pkl')
        test_mAP = self._evaluate_split(self.test_loader, self.test_target,
                                        self.test_class_weight,
                                        'test_result.pkl')

        # Output the mean AP for the best model on dev and test set
        info = 'Dev mAP: {}\nTest mAP: {}'.format(dev_mAP, test_mAP)
        utils.write_info(os.path.join(self.save_path, 'result.txt'), info)

    def _evaluate_split(self, loader, target, class_weight, save_name):
        """Evaluate the current model on one split, pickle the results and
        return the split's mAP.

        Args:
            loader: data loader passed to ``self._test``.
            target: ground-truth labels for ``utils.compute_weighted_AP``.
            class_weight: per-class weights for ``utils.compute_weighted_AP``.
            save_name: pickle file name inside ``self.save_path``.
        """
        # The first two values returned by _test (class and domain loss)
        # are not needed for the saved results.
        _, _, class_output, domain_output, feature, domain_accuracy = \
            self._test(loader)
        predict_prob = self.inference(class_output)
        per_class_AP = utils.compute_weighted_AP(target, predict_prob,
                                                 class_weight)
        mAP = utils.compute_mAP(per_class_AP, self.subclass_idx)
        result = {
            'output': class_output.cpu().numpy(),
            'feature': feature.cpu().numpy(),
            'per_class_AP': per_class_AP,
            'mAP': mAP,
            'domain_output': domain_output.cpu().numpy(),
            'domain_accuracy': domain_accuracy
        }
        utils.save_pkl(result, os.path.join(self.save_path, save_name))
        return mAP
    def train(self):
        """Train the model for one epoch, evaluate on the validation set and
        save the best model for each inference method.

        ``ckpt.pth`` is written after every epoch. Each inference method
        (conditional, max, sum_prob, sum_out) is then scored on the dev set.
        A method's ``best-<method>.pth`` checkpoint is saved when its dev mAP
        beats ``self.best_dev_mAP_<method>``.
        """
        start_time = datetime.now()
        self._train(self.train_loader)
        utils.save_state_dict(self.state_dict(),
                              os.path.join(self.save_path, 'ckpt.pth'))
        dev_loss, dev_output, _ = self._test(self.dev_loader)

        # Evaluated in this order. Each name selects the best_dev_mAP_<name>
        # attribute, the best-<name>.pth checkpoint and the mAP_<name> log
        # key. Only the conditional method also needs the dev targets.
        inference_methods = [
            ('conditional',
             lambda output: self.inference_conditional(output,
                                                       self.dev_target)),
            ('max', self.inference_max),
            ('sum_prob', self.inference_sum_prob),
            ('sum_out', self.inference_sum_out),
        ]
        dev_mAP = {}
        for name, inference_fn in inference_methods:
            dev_mAP[name] = self._update_best_dev_mAP(
                name, inference_fn(dev_output))

        log_info = {'loss': dev_loss / len(self.dev_loader)}
        log_info.update(('mAP_' + name, value)
                        for name, value in dev_mAP.items())
        self.log_result('Dev epoch', log_info, self.epoch)

        duration = datetime.now() - start_time
        print(('Finish training epoch {}, dev mAP conditional: {}, '
               'dev mAP max: {}, dev mAP sum prob: {}, '
               'dev mAP sum out: {}, time used: {}').format(
                   self.epoch, dev_mAP['conditional'], dev_mAP['max'],
                   dev_mAP['sum_prob'], dev_mAP['sum_out'], duration))

    def _update_best_dev_mAP(self, name, dev_predict):
        """Score one inference method's dev predictions and checkpoint the
        model if it is that method's best so far.

        Returns the dev mAP. When it exceeds ``self.best_dev_mAP_<name>``, the
        attribute is updated and ``best-<name>.pth`` is saved.
        """
        dev_per_class_AP = utils.compute_weighted_AP(
            self.dev_target, dev_predict, self.dev_class_weight)
        dev_mAP = utils.compute_mAP(dev_per_class_AP, self.subclass_idx)
        best_attr = 'best_dev_mAP_' + name
        if dev_mAP > getattr(self, best_attr):
            setattr(self, best_attr, dev_mAP)
            utils.save_state_dict(
                self.state_dict(),
                os.path.join(self.save_path, 'best-{}.pth'.format(name)))
        return dev_mAP