Example #1
class Trainer(object):
    def __init__(self, args):
        self.args = args

    def setup(self):
        args = self.args
        sub_dir = '{}_input-{}_wot-{}_wtv-{}_reg-{}_nIter-{}_normCood-{}'.format(
            args.dataset, args.crop_size, args.wot, args.wtv, args.reg,
            args.num_of_iter_in_ot, args.norm_cood)

        self.save_dir = os.path.join(args.out_path, 'ckpts', sub_dir)
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        time_str = datetime.strftime(datetime.now(), '%m%d-%H%M%S')
        self.logger = log_utils.get_logger(
            os.path.join(self.save_dir, 'train-{:s}.log'.format(time_str)))
        log_utils.print_config(vars(args), self.logger)

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.device_count = torch.cuda.device_count()
            assert self.device_count == 1
            self.logger.info('Using {} GPU(s)'.format(self.device_count))
        else:
            raise RuntimeError("GPU is not available")

        dataset_name = args.dataset.lower()
        if dataset_name == 'qnrf':
            from datasets.crowd import Crowd_qnrf as Crowd
        elif dataset_name == 'nwpu':
            from datasets.crowd import Crowd_nwpu as Crowd
        elif dataset_name in ('sha', 'shb'):
            from datasets.crowd import Crowd_sh as Crowd
        else:
            raise NotImplementedError

        downsample_ratio = 8
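        # The density map is predicted at 1/8 of the input resolution.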
        self.datasets = {
            'train':
            Crowd(os.path.join(args.data_path,
                               DATASET_PATHS[dataset_name]["train_path"]),
                  crop_size=args.crop_size,
                  downsample_ratio=downsample_ratio,
                  method='train'),
            'val':
            Crowd(os.path.join(args.data_path,
                               DATASET_PATHS[dataset_name]["val_path"]),
                  crop_size=args.crop_size,
                  downsample_ratio=downsample_ratio,
                  method='val')
        }

        self.dataloaders = {
            x: DataLoader(self.datasets[x],
                          collate_fn=(train_collate
                                      if x == 'train' else default_collate),
                          batch_size=(args.batch_size if x == 'train' else 1),
                          shuffle=(x == 'train'),
                          num_workers=args.num_workers * self.device_count,
                          pin_memory=(x == 'train'))
            for x in ['train', 'val']
        }
        self.model = vgg19()
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.weight_decay)

        self.start_epoch = 0
        if args.resume:
            self.logger.info('loading pretrained model from ' + args.resume)
            suf = args.resume.rsplit('.', 1)[-1]
            if suf == 'tar':
                checkpoint = torch.load(args.resume, self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(
                    checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
            elif suf == 'pth':
                self.model.load_state_dict(torch.load(args.resume,
                                                      self.device))
        else:
            self.logger.info('random initialization')

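        # OT loss between the normalized density map and GT points;
        # num_of_iter_in_ot and reg suggest entropic (Sinkhorn-style) iterations.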
        self.ot_loss = OT_Loss(args.crop_size, downsample_ratio,
                               args.norm_cood, self.device,
                               args.num_of_iter_in_ot, args.reg)
        self.tv_loss = nn.L1Loss(reduction='none').to(self.device)
        self.mse = nn.MSELoss().to(self.device)
        self.mae = nn.L1Loss().to(self.device)
        self.save_list = Save_Handle(max_num=1)
        self.best_mae = np.inf
        self.best_mse = np.inf
        self.best_count = 0

    def train(self):
        """training process"""
        args = self.args
        for epoch in range(self.start_epoch, args.max_epoch + 1):
            self.logger.info('-' * 5 +
                             'Epoch {}/{}'.format(epoch, args.max_epoch) +
                             '-' * 5)
            self.epoch = epoch
            self.train_epoch()
            if epoch % args.val_epoch == 0 and epoch >= args.val_start:
                self.val_epoch()

    def train_epoch(self):
        epoch_ot_loss = AverageMeter()
        epoch_ot_obj_value = AverageMeter()
        epoch_wd = AverageMeter()
        epoch_count_loss = AverageMeter()
        epoch_tv_loss = AverageMeter()
        epoch_loss = AverageMeter()
        epoch_mae = AverageMeter()
        epoch_mse = AverageMeter()
        epoch_start = time.time()
        self.model.train()  # Set model to training mode

        for step, (inputs, points, st_sizes,
                   gt_discrete) in enumerate(self.dataloaders['train']):
            inputs = inputs.to(self.device)
            gd_count = np.array([len(p) for p in points], dtype=np.float32)
            points = [p.to(self.device) for p in points]
            gt_discrete = gt_discrete.to(self.device)
            N = inputs.size(0)

            with torch.set_grad_enabled(True):
                outputs, outputs_normed = self.model(inputs)
                # Compute OT loss.
                ot_loss, wd, ot_obj_value = self.ot_loss(
                    outputs_normed, outputs, points)
                ot_loss = ot_loss * self.args.wot
                ot_obj_value = ot_obj_value * self.args.wot
                epoch_ot_loss.update(ot_loss.item(), N)
                epoch_ot_obj_value.update(ot_obj_value.item(), N)
                epoch_wd.update(wd, N)

                # Compute counting loss.
                count_loss = self.mae(
                    outputs.sum(1).sum(1).sum(1),
                    torch.from_numpy(gd_count).float().to(self.device))
                epoch_count_loss.update(count_loss.item(), N)

                # Compute TV loss.
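                # Per-sample L1 between the normalized maps, scaled back by the
                # GT count before averaging.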
                gd_count_tensor = torch.from_numpy(gd_count).float().to(
                    self.device).unsqueeze(1).unsqueeze(2).unsqueeze(3)
                gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6)
                tv_loss = (self.tv_loss(
                    outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(1) *
                           torch.from_numpy(gd_count).float().to(
                               self.device)).mean(0) * self.args.wtv
                epoch_tv_loss.update(tv_loss.item(), N)

                loss = ot_loss + count_loss + tv_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                pred_count = torch.sum(outputs.view(N, -1),
                                       dim=1).detach().cpu().numpy()
                pred_err = pred_count - gd_count
                epoch_loss.update(loss.item(), N)
                epoch_mse.update(np.mean(pred_err * pred_err), N)
                epoch_mae.update(np.mean(abs(pred_err)), N)

        self.logger.info(
            'Epoch {} Train, Loss: {:.2f}, OT Loss: {:.2e}, Wass Distance: {:.2f}, OT obj value: {:.2f}, '
            'Count Loss: {:.2f}, TV Loss: {:.2f}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
            .format(self.epoch, epoch_loss.get_avg(), epoch_ot_loss.get_avg(),
                    epoch_wd.get_avg(), epoch_ot_obj_value.get_avg(),
                    epoch_count_loss.get_avg(), epoch_tv_loss.get_avg(),
                    np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(),
                    time.time() - epoch_start))
        model_state_dic = self.model.state_dict()
        save_path = os.path.join(self.save_dir,
                                 '{}_ckpt.tar'.format(self.epoch))
        torch.save(
            {
                'epoch': self.epoch,
                'optimizer_state_dict': self.optimizer.state_dict(),
                'model_state_dict': model_state_dic
            }, save_path)
        self.save_list.append(save_path)

    def val_epoch(self):
        args = self.args
        epoch_start = time.time()
        self.model.eval()  # Set model to evaluate mode
        epoch_res = []
        for inputs, count, name in self.dataloaders['val']:
            inputs = inputs.to(self.device)
            assert inputs.size(0) == 1, \
                'the batch size should be 1 in validation mode'
            with torch.set_grad_enabled(False):
                outputs, _ = self.model(inputs)
                res = count[0].item() - torch.sum(outputs).item()
                epoch_res.append(res)

        epoch_res = np.array(epoch_res)
        mse = np.sqrt(np.mean(np.square(epoch_res)))
        mae = np.mean(np.abs(epoch_res))
        self.logger.info(
            'Epoch {} Val, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'.format(
                self.epoch, mse, mae,
                time.time() - epoch_start))

        model_state_dic = self.model.state_dict()
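        # Model selection uses a weighted criterion, 2*MSE + MAE, so MSE
        # improvements count double.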
        if (2.0 * mse + mae) < (2.0 * self.best_mse + self.best_mae):
            self.best_mse = mse
            self.best_mae = mae
            self.logger.info(
                "save best mse {:.2f} mae {:.2f} model epoch {}".format(
                    self.best_mse, self.best_mae, self.epoch))
            torch.save(
                model_state_dic,
                os.path.join(self.save_dir,
                             'best_model_{}.pth'.format(self.best_count)))
            self.best_count += 1
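
Usage sketch (not part of the original example): the Trainer above only needs an
argparse namespace carrying the fields read in setup() and train(). The flag
names and defaults below are illustrative assumptions, not values from the repo.

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='crowd-counting training (illustrative)')
    parser.add_argument('--dataset', default='sha', help='qnrf | nwpu | sha | shb')
    parser.add_argument('--data-path', default='./data')
    parser.add_argument('--out-path', default='./out')
    parser.add_argument('--crop-size', type=int, default=512)
    parser.add_argument('--wot', type=float, default=0.1)
    parser.add_argument('--wtv', type=float, default=0.01)
    parser.add_argument('--reg', type=float, default=10.0)
    parser.add_argument('--num-of-iter-in-ot', type=int, default=100)
    parser.add_argument('--norm-cood', type=int, default=0)
    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--max-epoch', type=int, default=1000)
    parser.add_argument('--val-epoch', type=int, default=5)
    parser.add_argument('--val-start', type=int, default=50)
    parser.add_argument('--resume', default='')  # empty string -> random init
    return parser.parse_args()

if __name__ == '__main__':
    trainer = Trainer(parse_args())
    trainer.setup()
    trainer.train()
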
Example #2
class Trainer(object):
    def __init__(self, args):
        self.args = args

    def setup(self):
        args = self.args
        sub_dir = 'input-{}_wot-{}_wtv-{}_reg-{}_nIter-{}_normCood-{}'.format(
            args.crop_size, args.wot, args.wtv, args.reg,
            args.num_of_iter_in_ot, args.norm_cood)

        self.save_dir = os.path.join('ckpts', sub_dir)
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        time_str = datetime.strftime(datetime.now(), '%m%d-%H%M%S')
        self.logger = log_utils.get_logger(
            os.path.join(self.save_dir, 'train-{:s}.log'.format(time_str)))
        log_utils.print_config(vars(args), self.logger)

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.device_count = torch.cuda.device_count()
            assert self.device_count == 1
            self.logger.info('Using {} GPU(s)'.format(self.device_count))
        else:
            raise RuntimeError("GPU is not available")

        downsample_ratio = 8
        if args.dataset.lower() == 'qnrf':
            self.datasets = {
                x: Crowd_qnrf(os.path.join(args.data_dir, x), args.crop_size,
                              downsample_ratio, x)
                for x in ['train', 'val']
            }
        elif args.dataset.lower() == 'nwpu':
            self.datasets = {
                x: Crowd_nwpu(os.path.join(args.data_dir, x), args.crop_size,
                              downsample_ratio, x)
                for x in ['train', 'val']
            }
        elif args.dataset.lower() in ('sha', 'shb'):
            self.datasets = {
                'train':
                Crowd_sh(os.path.join(args.data_dir, 'train_data'),
                         args.crop_size, downsample_ratio, 'train'),
                'val':
                Crowd_sh(os.path.join(args.data_dir, 'test_data'),
                         args.crop_size, downsample_ratio, 'val'),
            }
        else:
            raise NotImplementedError

        self.dataloaders = {
            x: DataLoader(self.datasets[x],
                          collate_fn=(train_collate
                                      if x == 'train' else default_collate),
                          batch_size=(args.batch_size if x == 'train' else 1),
                          shuffle=(x == 'train'),
                          num_workers=args.num_workers * self.device_count,
                          pin_memory=(x == 'train'))
            for x in ['train', 'val']
        }
        #self.model = vgg19()
        self.model = TR_CC()
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=args.lr,
                                    weight_decay=args.weight_decay)

        self.start_epoch = 0
        if args.resume:
            self.logger.info('loading pretrained model from ' + args.resume)
            suf = args.resume.rsplit('.', 1)[-1]
            if suf == 'tar':
                checkpoint = torch.load(args.resume, self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(
                    checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
            elif suf == 'pth':
                self.model.load_state_dict(torch.load(args.resume,
                                                      self.device))
        else:
            self.logger.info('random initialization')

        self.ot_loss = OT_Loss(args.crop_size, downsample_ratio,
                               args.norm_cood, self.device,
                               args.num_of_iter_in_ot, args.reg)
        self.tv_loss = nn.L1Loss(reduction='none').to(self.device)
        self.mse = nn.MSELoss().to(self.device)
        self.mae = nn.L1Loss().to(self.device)
        self.save_list = Save_Handle(max_num=1)
        self.best_mae = np.inf
        self.best_mse = np.inf
        self.best_count = 0

    def train(self):
        """training process"""
        args = self.args
        for epoch in range(self.start_epoch, args.max_epoch + 1):
            self.logger.info('-' * 5 +
                             'Epoch {}/{}'.format(epoch, args.max_epoch) +
                             '-' * 5)
            self.epoch = epoch
            self.train_epoch()
            if epoch % args.val_epoch == 0 and epoch >= args.val_start:
                self.val_epoch()

    def train_epoch(self):
        epoch_ot_loss = AverageMeter()
        epoch_ot_obj_value = AverageMeter()
        epoch_wd = AverageMeter()
        epoch_count_loss = AverageMeter()
        epoch_tv_loss = AverageMeter()
        epoch_loss = AverageMeter()
        epoch_mae = AverageMeter()
        epoch_mse = AverageMeter()
        epoch_start = time.time()
        self.model.train()  # Set model to training mode

        for step, (inputs, points, st_sizes,
                   gt_discrete) in enumerate(self.dataloaders['train']):
            inputs = inputs.to(self.device)
            gd_count = np.array([len(p) for p in points], dtype=np.float32)
            points = [p.to(self.device) for p in points]
            gt_discrete = gt_discrete.to(self.device)
            N = inputs.size(0)

            with torch.set_grad_enabled(True):
                outputs, outputs_normed = self.model(inputs)
                # Compute OT loss.
                ot_loss, wd, ot_obj_value = self.ot_loss(
                    outputs_normed, outputs, points)
                ot_loss = ot_loss * self.args.wot
                ot_obj_value = ot_obj_value * self.args.wot
                epoch_ot_loss.update(ot_loss.item(), N)
                epoch_ot_obj_value.update(ot_obj_value.item(), N)
                epoch_wd.update(wd, N)

                # Compute counting loss.
                count_loss = self.mae(
                    outputs.sum(1).sum(1).sum(1),
                    torch.from_numpy(gd_count).float().to(self.device))
                epoch_count_loss.update(count_loss.item(), N)

                # Compute TV loss.
                gd_count_tensor = torch.from_numpy(gd_count).float().to(
                    self.device).unsqueeze(1).unsqueeze(2).unsqueeze(3)
                gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6)
                tv_loss = (self.tv_loss(
                    outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(1) *
                           torch.from_numpy(gd_count).float().to(
                               self.device)).mean(0) * self.args.wtv
                epoch_tv_loss.update(tv_loss.item(), N)

                loss = ot_loss + count_loss + tv_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                pred_count = torch.sum(outputs.view(N, -1),
                                       dim=1).detach().cpu().numpy()
                pred_err = pred_count - gd_count
                epoch_loss.update(loss.item(), N)
                epoch_mse.update(np.mean(pred_err * pred_err), N)
                epoch_mae.update(np.mean(abs(pred_err)), N)

        self.logger.info(
            'Epoch {} Train, Loss: {:.2f}, OT Loss: {:.2e}, Wass Distance: {:.2f}, OT obj value: {:.2f}, '
            'Count Loss: {:.2f}, TV Loss: {:.2f}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
            .format(self.epoch, epoch_loss.get_avg(), epoch_ot_loss.get_avg(),
                    epoch_wd.get_avg(), epoch_ot_obj_value.get_avg(),
                    epoch_count_loss.get_avg(), epoch_tv_loss.get_avg(),
                    np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(),
                    time.time() - epoch_start))
        model_state_dic = self.model.state_dict()
        save_path = os.path.join(self.save_dir,
                                 '{}_ckpt.tar'.format(self.epoch))
        torch.save(
            {
                'epoch': self.epoch,
                'optimizer_state_dict': self.optimizer.state_dict(),
                'model_state_dict': model_state_dic
            }, save_path)
        self.save_list.append(save_path)

    def val_epoch(self):
        args = self.args
        counter_dir = os.path.join(
            args.data_dir, 'test_data',
            'base_dir_metric_{}'.format(args.counter_type))
        if not os.path.exists(counter_dir):
            os.makedirs(counter_dir)
        epoch_start = time.time()
        self.model.eval()  # Set model to evaluate mode
        epoch_res = []
        for inputs, count, name in self.dataloaders['val']:
            inputs = inputs.to(self.device)
            assert inputs.size(0) == 1, \
                'the batch size should be 1 in validation mode'
            with torch.set_grad_enabled(False):
                outputs, _ = self.model(inputs)
                res = count[0].item() - torch.sum(outputs).item()
                epoch_res.append(res)

        epoch_res = np.array(epoch_res)
        mse = np.sqrt(np.mean(np.square(epoch_res)))
        mae = np.mean(np.abs(epoch_res))
        self.logger.info(
            'Epoch {} Val, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'.format(
                self.epoch, mse, mae,
                time.time() - epoch_start))

        model_state_dic = self.model.state_dict()
        if (2.0 * mse + mae) < (2.0 * self.best_mse + self.best_mae):
            self.best_mse = mse
            self.best_mae = mae
            self.logger.info(
                "save best mse {:.2f} mae {:.2f} model epoch {}".format(
                    self.best_mse, self.best_mae, self.epoch))
            torch.save(
                model_state_dic,
                os.path.join(self.save_dir,
                             'best_model_{}.pth'.format(self.best_count)))
            self.best_count += 1
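            # Note: CUDA was already initialized above, so setting
            # CUDA_VISIBLE_DEVICES at this point no longer affects device selection.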
            os.environ['CUDA_VISIBLE_DEVICES'] = '0'
            device = torch.device('cuda')
            part_B_train = os.path.join(args.data_dir, 'train_data', 'images')
            part_B_train = part_B_train.replace(
                '{}/train_data'.format(args.counter_type), 'train_data')
            part_B_test = os.path.join(args.data_dir, 'test_data', 'images')
            part_B_test = part_B_test.replace(
                '{}/test_data'.format(args.counter_type), 'test_data')
            model_path = os.path.join(
                self.save_dir, 'best_model_{}.pth'.format(self.best_count - 1))
            model = vgg19()
            model.to(device)
            model.load_state_dict(torch.load(model_path, device))
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
            path_sets_B = [part_B_test]
            img_paths_B = []
            for path in path_sets_B:
                for img_path in glob.glob(os.path.join(path, '*.png')):
                    img_paths_B.append(img_path)
            image_errs_temp = []
            for img_path in tqdm(img_paths_B):
                for i in range(0, 3):
                    for j in range(0, 3):
                        image_path = img_path.replace(
                            'test_data',
                            '{}/test_data'.format(args.counter_type)).replace(
                                '.png', '_{}_{}.png'.format(i, j))
                        name = os.path.basename(image_path).split('.')[0]
                        mat_path = image_path.replace('.png', '.mat').replace(
                            'images',
                            'ground-truth').replace(name, 'GT_{}'.format(name))
                        mat = io.loadmat(mat_path)
                        img = transform(
                            Image.open(image_path).convert('RGB')).cuda()
                        inputs = img.unsqueeze(0)

                        with torch.set_grad_enabled(False):
                            outputs, _ = model(inputs)
                        img_err = abs(mat["image_info"][0, 0][0, 0][1] -
                                      torch.sum(outputs).item())
                        img_err = np.squeeze(img_err)
                        print(image_path, img_err)
                        image_errs_temp.append(img_err)

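                # Nine crops per image (the 3x3 i/j grid above); arrange their
                # errors into a 3x3 array before saving.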
                image_errs = np.reshape(image_errs_temp, (3, 3))

                with open(
                        img_path.replace(
                            'test_data/images',
                            '{}/test_data/base_dir_metric_{}'.format(
                                args.counter_type,
                                args.counter_type)).replace('.png', '.npy'),
                        'wb') as f:
                    np.save(f, image_errs)
                image_errs_temp.clear()

Example #3
class Trainer(object):
    def __init__(self, args, datargs):
        self.train_args = args
        self.datargs = datargs

    def setup(self):
        train_args = self.train_args
        datargs = self.datargs
        sub_dir = 'input-{}_wot-{}_wtv-{}_reg-{}_nIter-{}_normCood-{}'.format(
            train_args['crop_size'], train_args['wot'], train_args['wtv'],
            train_args['reg'], train_args['num_of_iter_in_ot'],
            train_args['norm_cood'])

        time_str = datetime.strftime(datetime.now(), '%m%d-%H%M%S')
        self.save_dir = os.path.join(train_args['out_path'], 'ckpts',
                                     train_args['conf_name'],
                                     train_args['dataset'], sub_dir, time_str)
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        log_dir = os.path.join(train_args['out_path'], 'runs',
                               train_args['dataset'], train_args['conf_name'],
                               time_str)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        # TODO: Verify args
        self.logger = SummaryWriter(log_dir)
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
            self.device_count = torch.cuda.device_count()
            assert self.device_count == 1
        else:
            raise Exception("Gpu is not available")

        dataset_name = train_args['dataset'].lower()
        if dataset_name == 'qnrf':
            from datasets.crowd import Crowd_qnrf as Crowd
        elif dataset_name == 'nwpu':
            from datasets.crowd import Crowd_nwpu as Crowd
        elif dataset_name in ('sha', 'shb'):
            from datasets.crowd import Crowd_sh as Crowd
        elif dataset_name[:3] == 'ucf':
            from datasets.crowd import Crowd_ucf as Crowd
        else:
            raise NotImplementedError
        if dataset_name in ('sha', 'shb'):
            downsample_ratio = train_args['downsample_ratio']
            train_val = Crowd(os.path.join(datargs['data_path'],
                                           datargs["train_path"]),
                              crop_size=train_args['crop_size'],
                              downsample_ratio=downsample_ratio,
                              method='train')
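            # Carve a fixed validation split out of the training set; the seeded
            # generator keeps the split reproducible (280/20 for SHA, 380/20 for SHB).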
            if dataset_name == 'sha':
                train_set, val = random_split(
                    train_val, [280, 20],
                    generator=torch.Generator().manual_seed(42))
                val_set = ValSubset(val)
            else:
                train_set, val = random_split(
                    train_val, [380, 20],
                    generator=torch.Generator().manual_seed(42))
                val_set = ValSubset(val)
            self.datasets = {'train': train_set, 'val': val_set}
        else:
            downsample_ratio = train_args['downsample_ratio']
            self.datasets = {
                'train':
                Crowd(os.path.join(datargs['data_path'],
                                   datargs["train_path"]),
                      crop_size=train_args['crop_size'],
                      downsample_ratio=downsample_ratio,
                      method='train'),
                'val':
                Crowd(os.path.join(datargs['data_path'], datargs["val_path"]),
                      crop_size=train_args['crop_size'],
                      downsample_ratio=downsample_ratio,
                      method='val')
            }
        self.dataloaders = {
            x: DataLoader(
                self.datasets[x],
                collate_fn=(train_collate
                            if x == 'train' else default_collate),
                batch_size=(train_args['batch_size'] if x == 'train' else 1),
                shuffle=(x == 'train'),
                num_workers=train_args['num_workers'] * self.device_count,
                pin_memory=(x == 'train'))
            for x in ['train', 'val']
        }
        self.model = vgg16dres(map_location=self.device)
        self.model.to(self.device)
        # for p in self.model.features.parameters():
        #     p.requires_grad = True
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=train_args['lr'],
                                    weight_decay=train_args['weight_decay'],
                                    amsgrad=False)
        # for _, p in zip(range(10000), next(self.model.children()).children()):
        #     p.requires_grad = False
        #     print("freeze: ", p)
        # print(self.optimizer.param_groups[0])
        self.start_epoch = 0
        self.ot_loss = OT_Loss(train_args['crop_size'], downsample_ratio,
                               train_args['norm_cood'], self.device,
                               self.logger, train_args['num_of_iter_in_ot'],
                               train_args['reg'])
        self.tv_loss = nn.L1Loss(reduction='none').to(self.device)
        self.mse = nn.MSELoss().to(self.device)
        self.mae = nn.L1Loss().to(self.device)
        self.save_list = Save_Handle(max_num=1)
        self.best_mae = np.inf
        self.best_mse = np.inf
        self.best_count = 0
        if train_args['resume']:
            self.logger.add_text(
                'log/train',
                'loading pretrained model from ' + train_args['resume'], 0)
            suf = train_args['resume'].rsplit('.', 1)[-1]
            if suf == 'tar':
                checkpoint = torch.load(train_args['resume'], self.device)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.optimizer.load_state_dict(
                    checkpoint['optimizer_state_dict'])
                self.start_epoch = checkpoint['epoch'] + 1
                self.best_count = checkpoint['best_count']
                self.best_mae = checkpoint['best_mae']
                self.best_mse = checkpoint['best_mse']
                print(self.best_mae, self.best_mse, self.best_count)
            elif suf == 'pth':
                self.model.load_state_dict(
                    torch.load(train_args['resume'], self.device))
        else:
            self.logger.add_text('log/train', 'random initialization', 0)
        img_cnts = {
            'val_image_count': len(self.dataloaders['val']),
            'train_image_count': len(self.dataloaders['train'])
        }
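        # Register hparams with placeholder metrics; val_epoch later writes the
        # real best values under the same run_name.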
        self.logger.add_hparams(
            {**self.train_args, **img_cnts},
            {'best_mse': np.inf, 'best_mae': np.inf, 'best_count': 0},
            run_name='hparams')

    def train(self):
        """training process"""
        train_args = self.train_args
        for epoch in range(self.start_epoch, train_args['max_epoch'] + 1):
            print(
                'log/train', '-' * 5 +
                'Epoch {}/{}'.format(epoch, train_args['max_epoch']) + '-' * 5)
            self.epoch = epoch
            self.train_epoch()
            if epoch % train_args['val_epoch'] == 0 and epoch >= train_args[
                    'val_start']:
                self.val_epoch()

    def train_epoch(self):
        epoch_ot_loss = AverageMeter()
        epoch_ot_obj_value = AverageMeter()
        epoch_wd = AverageMeter()
        epoch_count_loss = AverageMeter()
        epoch_tv_loss = AverageMeter()
        epoch_loss = AverageMeter()
        epoch_mae = AverageMeter()
        epoch_mse = AverageMeter()
        epoch_start = time.time()
        self.model.train()  # Set model to training mode

        for step, (inputs, points, st_sizes,
                   gt_discrete) in enumerate(self.dataloaders['train']):
            inputs = inputs.to(self.device)
            gd_count = np.array([len(p) for p in points], dtype=np.float32)
            points = [p.to(self.device) for p in points]
            gt_discrete = gt_discrete.to(self.device)
            N = inputs.size(0)
            wot = self.train_args['wot']
            wtv = self.train_args['wtv']
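            # With probability 0.5, switch the dl* blocks (presumably
            # dropout-style layers) to eval mode for this training step.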
            drop = random() >= 0.5
            with torch.set_grad_enabled(True):
                if drop:
                    self.model.dl1.eval()
                    self.model.dl1a.eval()
                    self.model.dl2.eval()
                    self.model.dl2a.eval()
                    self.model.dl3.eval()
                    self.model.dl3a.eval()

                outputs, outputs_normed = self.model(inputs)
                # Compute OT loss.
                ot_loss, wd, ot_obj_value = self.ot_loss(
                    outputs_normed, outputs, points)
                ot_loss = ot_loss * wot
                ot_obj_value = ot_obj_value * wot
                epoch_ot_loss.update(ot_loss.item(), N)
                epoch_ot_obj_value.update(ot_obj_value.item(), N)
                epoch_wd.update(wd, N)

                # Compute counting loss.
                count_loss = self.mae(
                    outputs.sum(1).sum(1).sum(1),
                    torch.from_numpy(gd_count).float().to(self.device))
                epoch_count_loss.update(count_loss.item(), N)

                # Compute TV loss.
                gd_count_tensor = torch.from_numpy(gd_count).float().to(
                    self.device).unsqueeze(1).unsqueeze(2).unsqueeze(3)
                gt_discrete_normed = gt_discrete / (gd_count_tensor + 1e-6)
                tv_loss = (self.tv_loss(
                    outputs_normed, gt_discrete_normed).sum(1).sum(1).sum(1) *
                           torch.from_numpy(gd_count).float().to(
                               self.device)).mean(0) * wtv
                epoch_tv_loss.update(tv_loss.item(), N)

                loss = ot_loss + count_loss + tv_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                pred_count = torch.sum(outputs.view(N, -1),
                                       dim=1).detach().cpu().numpy()
                pred_err = pred_count - gd_count
                epoch_loss.update(loss.item(), N)
                epoch_mse.update(np.mean(pred_err * pred_err), N)
                epoch_mae.update(np.mean(abs(pred_err)), N)
                if drop:
                    self.model.dl1.train()
                    self.model.dl1a.train()
                    self.model.dl2.train()
                    self.model.dl2a.train()
                    self.model.dl3.train()
                    self.model.dl3a.train()
        mae = epoch_mae.get_avg()
        mse = np.sqrt(epoch_mse.get_avg())
        self.logger.add_scalar('loss/train', epoch_loss.get_avg(), self.epoch)
        self.logger.add_scalar('mse/train', mse, self.epoch)
        self.logger.add_scalar('mae/train', mae, self.epoch)
        self.logger.add_scalar('ot_loss/train', epoch_ot_loss.get_avg(),
                               self.epoch)
        self.logger.add_scalar('wd/train', epoch_wd.get_avg(), self.epoch)
        self.logger.add_scalar('ot_obj_val/train',
                               epoch_ot_obj_value.get_avg(), self.epoch)
        self.logger.add_scalar('count_loss/train', epoch_count_loss.get_avg(),
                               self.epoch)
        self.logger.add_scalar('tv_loss/train', epoch_tv_loss.get_avg(),
                               self.epoch)
        self.logger.add_scalar('time_cost/train',
                               time.time() - epoch_start, self.epoch)
        print(
            'log/train',
            'Epoch {} Train, Loss: {:.2f}, OT Loss: {:.2e}, Wass Distance: {:.2f}, OT obj value: {:.2f}, '
            'Count Loss: {:.2f}, TV Loss: {:.2f}, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'
            .format(self.epoch, epoch_loss.get_avg(), epoch_ot_loss.get_avg(),
                    epoch_wd.get_avg(), epoch_ot_obj_value.get_avg(),
                    epoch_count_loss.get_avg(), epoch_tv_loss.get_avg(),
                    np.sqrt(epoch_mse.get_avg()), epoch_mae.get_avg(),
                    time.time() - epoch_start), self.epoch)
        model_state_dic = self.model.state_dict()
        save_path = os.path.join(self.save_dir, str(self.epoch) + '_ckpt.tar')
        # TODO: Reset best counts option

        torch.save(
            {
                'epoch': self.epoch,
                'best_mae': self.best_mae,
                'best_mse': self.best_mse,
                'best_count': self.best_count,
                'optimizer_state_dict': self.optimizer.state_dict(),
                'model_state_dict': model_state_dic
            }, save_path)
        self.save_list.append(save_path)

    def val_epoch(self):
        epoch_start = time.time()
        self.model.eval()  # Set model to evaluate mode
        epoch_res = []
        for inputs, count, name in self.dataloaders['val']:
            inputs = inputs.to(self.device)
            assert inputs.size(0) == 1, \
                'the batch size should be 1 in validation mode'
            with torch.set_grad_enabled(False):
                outputs, _ = self.model(inputs)
                res = count[0].item() - torch.sum(outputs).item()
                epoch_res.append(res)

        epoch_res = np.array(epoch_res)
        mse = np.sqrt(np.mean(np.square(epoch_res)))
        mae = np.mean(np.abs(epoch_res))
        self.logger.add_scalar('mse/val', mse, self.epoch)
        self.logger.add_scalar('mae/val', mae, self.epoch)
        self.logger.add_scalar('time_cost/val',
                               time.time() - epoch_start, self.epoch)
        print(
            'log/val',
            'Epoch {} Val, MSE: {:.2f} MAE: {:.2f}, Cost {:.1f} sec'.format(
                self.epoch, mse, mae,
                time.time() - epoch_start))

        model_state_dic = self.model.state_dict()
        if (2.0 * mse + mae) < (2.0 * self.best_mse + self.best_mae):
            self.best_mse = mse
            self.best_mae = mae
            filename = 'best_model_{:.2f}_{:.2f}_{}.pth'.format(
                self.best_mse, self.best_mae, self.best_count)
            txt = "save best mse {:.2f} mae {:.2f} model epoch {}".format(
                self.best_mse, self.best_mae, self.epoch)
            print(txt)
            self.logger.add_text('log/val', txt, self.best_count)
            best_metrics = {
                'best_mse': mse,
                'best_mae': mae,
                'best_count': self.best_count
            }
            for k, v in best_metrics.items():
                self.logger.add_scalar(k + '/val', v, self.epoch)
            self.logger.add_hparams(
                {},
                {'best_mse': mse, 'best_mae': mae, 'best_count': self.best_count},
                run_name='hparams')
            torch.save(model_state_dic, os.path.join(self.save_dir, filename))
            self.best_count += 1
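        # Also snapshot checkpoints that improve on only one of the two metrics
        # (flagged with ** in the log message).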
        elif mse < self.best_mse or mae < self.best_mae:
            filename = 'best_model_{:.2f}_{:.2f}.pth'.format(mse, mae)
            txt = "save best mse {:.2f} mae {:.2f} model epoch {}**".format(
                mse, mae, self.epoch)
            print(txt)

            torch.save(model_state_dic, os.path.join(self.save_dir, filename))
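
Usage sketch for the dict-driven variant (not part of the original example):
each key mirrors a lookup in setup()/train(); the concrete values below are
placeholders.

train_args = {
    'conf_name': 'baseline', 'dataset': 'sha', 'out_path': './out',
    'crop_size': 256, 'downsample_ratio': 8,
    'wot': 0.1, 'wtv': 0.01, 'reg': 10.0,
    'num_of_iter_in_ot': 100, 'norm_cood': 0,
    'lr': 1e-5, 'weight_decay': 1e-4,
    'batch_size': 8, 'num_workers': 4,
    'max_epoch': 1000, 'val_epoch': 5, 'val_start': 50,
    'resume': '',  # empty string -> random init
}
datargs = {'data_path': './data',
           'train_path': 'train_data',
           'val_path': 'test_data'}

trainer = Trainer(train_args, datargs)
trainer.setup()
trainer.train()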