Example #1
0
class CC:
    # class for crowd counting with SANet
    def __init__(self, args):
        """Load a pretrained SANet crowd counter onto GPU 0 and open the video.

        Expects on ``args``:
          * ``args.model`` -- path to a saved ``CrowdCounter`` state_dict
          * ``args.video`` -- source accepted by ``VideoCapture``
        """
        # setup the net
        torch.device("cuda")  # NOTE(review): return value discarded -- this call by itself has no effect
        torch.cuda.set_device(0)
        torch.backends.cudnn.enabled = True
        self.net = CrowdCounter(cfg.GPU_ID, cfg.NET, torch.nn.MSELoss(),
                                pytorch_ssim.SSIM(window_size=11))
        self.net.load_state_dict(torch.load(args.model))
        self.net.cuda()
        self.net.eval()  # inference only -- disables dropout/batchnorm updates
        # IMPORTANT: when changing model, make sure you change config for the dataset it was trained in to make sure mean, std, h and w are correct
        self.mean, self.std = cfg.MEAN_STD
        self.h, self.w = cfg.STD_SIZE
        print(self.net)  # print net structure
        # video capture object
        self.vid = VideoCapture(args.video)

    def run(self, thread_queue):
        # thread function to run
        """Inference loop: grab frames until the global ``go_thread`` flag is
        cleared, run the net on each frame, and put ``(gray_frame, density_map)``
        pairs on ``thread_queue`` for a consumer thread to display.
        """
        global go_thread  # owning thread clears this flag to stop the loop
        while go_thread:
            ret, frame = self.vid.get_frame()
            if ret:
                # resize, convert to pytorch tensor, normalize
                # resize only when the frame differs from the size the model was trained on
                if frame.shape[0] != self.h or frame.shape[1] != self.w:
                    frame = cv2.resize(frame, (self.w, self.h),
                                       interpolation=cv2.INTER_CUBIC)
                tensor = torchvision.transforms.ToTensor()(frame)
                tensor = torchvision.transforms.functional.normalize(
                    tensor, mean=self.mean, std=self.std)
                # forward propagation
                with torch.no_grad():
                    # add a batch dimension: (C,H,W) -> (1,C,H,W)
                    tensor = torch.autograd.Variable(
                        tensor[None, :, :, :]).cuda()
                    pred_map = self.net.test_forward(tensor)
                # drop batch/channel dims -> 2-D density map
                pred_map = pred_map.cpu().data.numpy()[0, 0, :, :]
                # generate grayscale image for overlap and put results in thread queue
                # NOTE(review): cv2.COLOR_RGB2GRAY assumes the frame is RGB --
                # confirm VideoCapture.get_frame() output order
                gray = np.repeat(cv2.cvtColor(frame,
                                              cv2.COLOR_RGB2GRAY)[:, :,
                                                                  np.newaxis],
                                 3,
                                 axis=2)
                thread_queue.put((gray, pred_map))
Example #2
0
class CC:
    """ Class for crowd counting with SANet (threaded) """
    def __init__(self, model):
        """Load a pretrained SANet crowd counter onto GPU 0.

        model: path to a saved ``CrowdCounter`` state_dict.
        """
        # Setup the net
        torch.device("cuda")
        torch.cuda.set_device(0)
        torch.backends.cudnn.enabled = True
        self.net = CrowdCounter(cfg.GPU_ID, cfg.NET, torch.nn.MSELoss(),
                                pytorch_ssim.SSIM(window_size=11))
        self.net.load_state_dict(torch.load(model))
        self.net.cuda()
        self.net.eval()  # inference only
        # IMPORTANT: when changing model, make sure you change config for the dataset it was trained in to make sure mean and std are correct
        self.mean, self.std = cfg.MEAN_STD
        # Print net structure
        # print(self.net)

    def run(self, stop_event, cam_queue, net_queue):
        """Worker loop: consume frames from ``cam_queue``, run inference, and
        put ``(gray_frame, density_map)`` pairs on ``net_queue`` until
        ``stop_event`` is set.
        """
        while not stop_event.is_set():
            try:
                # fix: get(0) was a NON-blocking get, so an empty queue made
                # this loop spin at 100% CPU. A short blocking timeout keeps
                # the thread responsive to stop_event without busy-waiting.
                frame = cam_queue.get(timeout=0.1)
                # convert to pytorch tensor, normalize
                tensor = torchvision.transforms.ToTensor()(frame)
                tensor = torchvision.transforms.functional.normalize(
                    tensor, mean=self.mean, std=self.std)
                # padding to multiples of 8 (pad right/bottom only)
                # fix: padding xr/yr on BOTH sides gave w + 2*xr, which is a
                # multiple of 8 only when xr == 0, defeating the stated intent.
                xr = (8 - tensor.shape[2] % 8) % 8
                yr = (8 - tensor.shape[1] % 8) % 8
                tensor = torch.nn.functional.pad(tensor, (0, xr, 0, yr),
                                                 'constant', 0)
                # forward propagation
                with torch.no_grad():
                    # add batch dimension: (C,H,W) -> (1,C,H,W)
                    tensor = torch.autograd.Variable(
                        tensor[None, :, :, :]).cuda()
                    pred_map = self.net.test_forward(tensor)
                # drop batch/channel dims -> 2-D density map
                pred_map = pred_map.cpu().data.numpy()[0, 0, :, :]
                # 3-channel grayscale version of the (unpadded) frame, scaled
                # to [0, 1], for overlaying the density map downstream
                gray = np.repeat(np.array(frame.convert("L"))[:, :,
                                                              np.newaxis],
                                 3,
                                 axis=2)
                gray = gray.astype(np.float32) / 255
                net_queue.put((gray, pred_map))
            except queue.Empty:
                # no frame ready yet -- loop back and re-check stop_event
                pass
Example #3
0
class Trainer():
    def __init__(self, dataloader, cfg_data, pwd):

        self.cfg_data = cfg_data

        self.data_mode = cfg.DATASET
        self.exp_name = cfg.EXP_NAME
        self.exp_path = cfg.EXP_PATH
        self.pwd = pwd

        self.net_name = cfg.NET

        if self.net_name in ['SANet']:
            loss_1_fn = nn.MSELoss()
            from misc import pytorch_ssim
            loss_2_fn = pytorch_ssim.SSIM(window_size=11)

        self.net = CrowdCounter(cfg.GPU_ID, self.net_name, loss_1_fn,
                                loss_2_fn).cuda()
        self.optimizer = optim.Adam(self.net.CCN.parameters(),
                                    lr=cfg.LR,
                                    weight_decay=1e-4)
        # self.optimizer = optim.SGD(self.net.parameters(), cfg.LR, momentum=0.95,weight_decay=5e-4)
        self.scheduler = StepLR(self.optimizer,
                                step_size=cfg.NUM_EPOCH_LR_DECAY,
                                gamma=cfg.LR_DECAY)

        self.train_record = {
            'best_mae': 1e20,
            'best_mse': 1e20,
            'best_model_name': ''
        }
        self.timer = {
            'iter time': Timer(),
            'train time': Timer(),
            'val time': Timer()
        }

        self.epoch = 0
        self.i_tb = 0

        if cfg.PRE_GCC:
            self.net.load_state_dict(torch.load(cfg.PRE_GCC_MODEL))

        self.train_loader, self.val_loader, self.restore_transform = dataloader(
        )

        if cfg.RESUME:
            latest_state = torch.load(cfg.RESUME_PATH)
            self.net.load_state_dict(latest_state['net'])
            self.optimizer.load_state_dict(latest_state['optimizer'])
            self.scheduler.load_state_dict(latest_state['scheduler'])
            self.epoch = latest_state['epoch'] + 1
            self.i_tb = latest_state['i_tb']
            self.train_record = latest_state['train_record']
            self.exp_path = latest_state['exp_path']
            self.exp_name = latest_state['exp_name']

        self.writer, self.log_txt = logger(self.exp_path,
                                           self.exp_name,
                                           self.pwd,
                                           'exp',
                                           resume=cfg.RESUME)

    def forward(self):

        # self.validate_V3()
        for epoch in range(self.epoch, cfg.MAX_EPOCH):
            self.epoch = epoch
            if epoch > cfg.LR_DECAY_START:
                self.scheduler.step()

            # training
            self.timer['train time'].tic()
            self.train()
            self.timer['train time'].toc(average=False)

            print 'train time: {:.2f}s'.format(self.timer['train time'].diff)
            print '=' * 20

            # validation
            if epoch % cfg.VAL_FREQ == 0 or epoch > cfg.VAL_DENSE_START:
                self.timer['val time'].tic()
                if self.data_mode in ['SHHA', 'SHHB', 'QNRF', 'UCF50']:
                    self.validate_V1()
                elif self.data_mode is 'WE':
                    self.validate_V2()
                elif self.data_mode is 'GCC':
                    self.validate_V3()
                self.timer['val time'].toc(average=False)
                print 'val time: {:.2f}s'.format(self.timer['val time'].diff)

    def train(self):  # training for all datasets
        self.net.train()
        for i, data in enumerate(self.train_loader, 0):
            self.timer['iter time'].tic()
            img, gt_map = data
            img = Variable(img).cuda()
            gt_map = Variable(gt_map).cuda()

            self.optimizer.zero_grad()
            pred_map = self.net(img, gt_map)
            loss1, loss2 = self.net.loss
            loss = loss1 + loss2
            loss.backward()
            self.optimizer.step()

            if (i + 1) % cfg.PRINT_FREQ == 0:
                self.i_tb += 1
                self.writer.add_scalar('train_loss', loss.item(), self.i_tb)
                self.writer.add_scalar('train_loss1', loss1.item(), self.i_tb)
                self.writer.add_scalar('train_loss2', loss2.item(), self.i_tb)
                self.timer['iter time'].toc(average=False)
                print '[ep %d][it %d][loss %.4f][lr %.4f][%.2fs]' % \
                        (self.epoch + 1, i + 1, loss.item(), self.optimizer.param_groups[0]['lr']*10000, self.timer['iter time'].diff)
                print '        [cnt: gt: %.1f pred: %.2f]' % (
                    gt_map[0].sum().data / self.cfg_data.LOG_PARA,
                    pred_map[0].sum().data / self.cfg_data.LOG_PARA)

    def validate_V1(self):  # validate_V1 for SHHA, SHHB, UCF-QNRF, UCF50

        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        for vi, data in enumerate(self.val_loader, 0):
            # print vi
            img, gt_map = data

            with torch.no_grad():
                img = Variable(img).cuda()
                gt_map = Variable(gt_map).cuda()

                pred_map = self.net.forward(img, gt_map)

                pred_map = pred_map.data.cpu().numpy()
                gt_map = gt_map.data.cpu().numpy()

                for i_img in range(pred_map.shape[0]):
                    pred_cnt = np.sum(pred_map[i_img]) / self.cfg_data.LOG_PARA
                    gt_count = np.sum(gt_map[i_img]) / self.cfg_data.LOG_PARA

                    loss1, loss2 = self.net.loss
                    loss = loss1.item() + loss2.item()
                    losses.update(loss)
                    maes.update(abs(gt_count - pred_cnt))
                    mses.update((gt_count - pred_cnt) * (gt_count - pred_cnt))
                if vi == 0:
                    vis_results(self.exp_name, self.epoch, self.writer,
                                self.restore_transform, img, pred_map, gt_map)

        mae = maes.avg
        mse = np.sqrt(mses.avg)
        loss = losses.avg

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)
        self.writer.add_scalar('mse', mse, self.epoch + 1)

        self.train_record = update_model(self.net,self.optimizer,self.scheduler,self.epoch,self.i_tb,self.exp_path,self.exp_name, \
            [mae, mse, loss],self.train_record,self.log_txt)
        print_summary(self.exp_name, [mae, mse, loss], self.train_record)

    def validate_V2(self):  # validate_V2 for WE

        self.net.eval()

        losses = AverageCategoryMeter(5)
        maes = AverageCategoryMeter(5)

        for i_sub, i_loader in enumerate(self.val_loader, 0):

            for vi, data in enumerate(i_loader, 0):
                img, gt_map = data

                with torch.no_grad():
                    img = Variable(img).cuda()
                    gt_map = Variable(gt_map).cuda()

                    pred_map = self.net.forward(img, gt_map)

                    pred_map = pred_map.data.cpu().numpy()
                    gt_map = gt_map.data.cpu().numpy()

                    for i_img in range(pred_map.shape[0]):
                        pred_cnt = np.sum(
                            pred_map[i_img]) / self.cfg_data.LOG_PARA
                        gt_count = np.sum(
                            gt_map[i_img]) / self.cfg_data.LOG_PARA

                        losses.update(self.net.loss.item(), i_sub)
                        maes.update(abs(gt_count - pred_cnt), i_sub)
                    if vi == 0:
                        vis_results(self.exp_name, self.epoch, self.writer,
                                    self.restore_transform, img, pred_map,
                                    gt_map)

        mae = np.average(maes.avg)
        loss = np.average(losses.avg)

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)

        self.train_record = update_model(self.net,self.optimizer,self.scheduler,self.epoch,self.i_tb,self.exp_path,self.exp_name, \
            [mae, 0, loss],self.train_record,self.log_txt)
        print_summary(self.exp_name, [mae, 0, loss], self.train_record)

    def validate_V3(self):  # validate_V3 for GCC

        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        c_maes = {
            'level': AverageCategoryMeter(9),
            'time': AverageCategoryMeter(8),
            'weather': AverageCategoryMeter(7)
        }
        c_mses = {
            'level': AverageCategoryMeter(9),
            'time': AverageCategoryMeter(8),
            'weather': AverageCategoryMeter(7)
        }

        for vi, data in enumerate(self.val_loader, 0):
            img, gt_map, attributes_pt = data

            with torch.no_grad():
                img = Variable(img).cuda()
                gt_map = Variable(gt_map).cuda()

                pred_map = self.net.forward(img, gt_map)

                pred_map = pred_map.data.cpu().numpy()
                gt_map = gt_map.data.cpu().numpy()

                for i_img in range(pred_map.shape[0]):
                    pred_cnt = np.sum(pred_map) / self.cfg_data.LOG_PARA
                    gt_count = np.sum(gt_map) / self.cfg_data.LOG_PARA

                    s_mae = abs(gt_count - pred_cnt)
                    s_mse = (gt_count - pred_cnt) * (gt_count - pred_cnt)

                    loss1, loss2 = self.net.loss
                    loss = loss1.item() + loss2.item()
                    losses.update(loss)
                    maes.update(s_mae)
                    mses.update(s_mse)
                    attributes_pt = attributes_pt.squeeze()
                    c_maes['level'].update(s_mae, attributes_pt[0])
                    c_mses['level'].update(s_mse, attributes_pt[0])
                    c_maes['time'].update(s_mae, attributes_pt[1] / 3)
                    c_mses['time'].update(s_mse, attributes_pt[1] / 3)
                    c_maes['weather'].update(s_mae, attributes_pt[2])
                    c_mses['weather'].update(s_mse, attributes_pt[2])

                if vi == 0:
                    vis_results(self.exp_name, self.epoch, self.writer,
                                self.restore_transform, img, pred_map, gt_map)

        loss = losses.avg
        mae = maes.avg
        mse = np.sqrt(mses.avg)

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)
        self.writer.add_scalar('mse', mse, self.epoch + 1)

        self.train_record = update_model(self.net,self.optimizer,self.scheduler,self.epoch,self.i_tb,self.exp_path,self.exp_name, \
            [mae, mse, loss],self.train_record,self.log_txt)

        print_GCC_summary(self.log_txt, self.epoch, [mae, mse, loss],
                          self.train_record, c_maes, c_mses)
Example #4
0
        choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
    parser.add_argument("--mode",
                        default='Add',
                        type=str,
                        help="Blend mode to use.",
                        choices=['Add', 'Lighten', 'Mix', 'Multiply'])
    args = parser.parse_args()
    print("Crowd Counter Demo running. Press Q to quit.")

    # setup the model
    device = torch.device("cuda")
    torch.cuda.set_device(0)
    torch.backends.cudnn.enabled = True  # use cudnn?
    net = CrowdCounter(cfg.GPU_ID, cfg.NET, torch.nn.MSELoss(),
                       pytorch_ssim.SSIM(window_size=11))
    net.load_state_dict(torch.load(args.model))
    net.cuda()
    net.eval()
    mean, std = cfg.MEAN_STD

    # open the video stream / file
    cap = cv2.VideoCapture(args.video)
    while (cap.isOpened()):
        _, frame = cap.read()
        if frame is None:
            break
        # convert to pytorch tensor and normalize
        tensor = torchvision.transforms.ToTensor()(cv2.cvtColor(
            frame, cv2.COLOR_BGR2RGB))
        tensor = torchvision.transforms.functional.normalize(tensor,
                                                             mean=mean,
Example #5
0
def test(file_list, model_path):
    """Run SANet inference on every image in file_list and report MAE/MSE.

    Relies on module-level globals: dataRoot (dataset root with img/ and
    den/ subdirs), exp_name (output dir), slicing (bool: tile the image),
    save_graphs (bool: save density-map figures), img_transform.
    Predictions and ground truth are saved as .mat files under exp_name/.
    """
    loss_1_fn = torch.nn.MSELoss()
    loss_2_fn = pytorch_ssim.SSIM(window_size=11)
    net = CrowdCounter(cfg.GPU_ID, cfg.NET, loss_1_fn, loss_2_fn)
    net.load_state_dict(torch.load(model_path))
    net.cuda()
    net.eval()

    f1 = plt.figure(1)

    gts = []
    preds = []

    for filename in file_list:
        print(filename, end=', ')
        imgname = dataRoot + '/img/' + filename
        filename_no_ext = filename.split('.')[0]

        # ground-truth density map stored as CSV alongside the image
        denname = dataRoot + '/den/' + filename_no_ext + '.csv'

        den = pd.read_csv(denname, sep=',', header=None).values
        den = den.astype(np.float32, copy=False)

        img = Image.open(imgname)

        # promote grayscale images to 3 channels for the net
        if img.mode == 'L':
            img = img.convert('RGB')

        img = img_transform(img)

        if slicing:
            # tile the image into a 3x3 grid of half-size crops with 50%
            # overlap and run the net on each crop separately
            # NOTE(review): (xr, xr, yr, yr) pads BOTH sides, so the padded
            # size is w + 2*xr, a multiple of 8 only when xr == 0 -- confirm
            # whether right/bottom-only padding (0, xr, 0, yr) was intended
            xr = (8 - img.shape[2] % 8) % 8
            yr = (8 - img.shape[1] % 8) % 8
            img = torch.nn.functional.pad(img, (xr, xr, yr, yr), 'constant', 0)
            pred_maps = []
            x4 = img.shape[2]  # full image
            x1 = x4 // 2  # half image
            x2 = x1 // 2  # quarter image
            x3 = x1 + x2  # three-quarter mark
            y4 = img.shape[1]
            y1 = y4 // 2
            y2 = y1 // 2
            y3 = y1 + y2
            # 9 half-size crops: rows at [0, 1/4, 1/2], cols at [0, 1/4, 1/2]
            img_list = [
                img[:, 0:y1, 0:x1], img[:, 0:y1, x2:x3], img[:, 0:y1, x1:x4],
                img[:, y2:y3, 0:x1], img[:, y2:y3, x2:x3], img[:, y2:y3,
                                                               x1:x4],
                img[:, y1:y4, 0:x1], img[:, y1:y4, x2:x3], img[:, y1:y4, x1:x4]
            ]

            for inputs in img_list:
                with torch.no_grad():
                    img = torch.autograd.Variable(inputs[None, :, :, :]).cuda()
                    pred_maps.append(net.test_forward(img))

            # stitch: keep only the central strip of each crop's prediction;
            # the kept strips (3/8, 2/8, 3/8 of each axis) tile the full image
            x3, x5 = int(x4 * 3 / 8), int(x4 * 5 / 8)
            y3, y5 = int(y4 * 3 / 8), int(y4 * 5 / 8)
            # same cut points expressed in each crop's local coordinates
            x32, x52, x51, x41 = x3 - x2, x5 - x2, x5 - x1, x4 - x1
            y32, y52, y51, y41 = y3 - y2, y5 - y2, y5 - y1, y4 - y1

            slice0 = pred_maps[0].cpu().data.numpy()[0, 0, 0:y3, 0:x3]
            slice1 = pred_maps[1].cpu().data.numpy()[0, 0, 0:y3, x32:x52]
            slice2 = pred_maps[2].cpu().data.numpy()[0, 0, 0:y3, x51:x41]
            slice3 = pred_maps[3].cpu().data.numpy()[0, 0, y32:y52, 0:x3]
            slice4 = pred_maps[4].cpu().data.numpy()[0, 0, y32:y52, x32:x52]
            slice5 = pred_maps[5].cpu().data.numpy()[0, 0, y32:y52, x51:x41]
            slice6 = pred_maps[6].cpu().data.numpy()[0, 0, y51:y41, 0:x3]
            slice7 = pred_maps[7].cpu().data.numpy()[0, 0, y51:y41, x32:x52]
            slice8 = pred_maps[8].cpu().data.numpy()[0, 0, y51:y41, x51:x41]

            # reassemble the 3x3 grid into the full-size prediction
            pred_map = np.vstack((np.hstack(
                (slice0, slice1, slice2)), np.hstack((slice3, slice4, slice5)),
                                  np.hstack((slice6, slice7, slice8))))
            sio.savemat(exp_name + '/pred/' + filename_no_ext + '.mat',
                        {'data': pred_map / 100.})

        else:
            # whole-image inference
            with torch.no_grad():
                img = torch.autograd.Variable(img[None, :, :, :]).cuda()
                pred_map = net.test_forward(img)
            sio.savemat(exp_name + '/pred/' + filename_no_ext + '.mat',
                        {'data': pred_map.squeeze().cpu().numpy() / 100.})
            pred_map = pred_map.cpu().data.numpy()[0, 0, :, :]

        # density sum -> head count (presumably 100 is the LOG_PARA density
        # scale used at training time -- verify against the training config)
        pred = np.sum(pred_map) / 100.0
        preds.append(pred)

        gt = np.sum(den)
        gts.append(gt)
        sio.savemat(exp_name + '/gt/' + filename_no_ext + '.mat',
                    {'data': den})
        # normalize both maps to [0, 1] for visualization
        pred_map = pred_map / np.max(pred_map + 1e-20)
        den = den / np.max(den + 1e-20)

        if save_graphs:
            # ground-truth density figure (axes and borders stripped)
            den_frame = plt.gca()
            plt.imshow(den, 'jet')
            den_frame.axes.get_yaxis().set_visible(False)
            den_frame.axes.get_xaxis().set_visible(False)
            den_frame.spines['top'].set_visible(False)
            den_frame.spines['bottom'].set_visible(False)
            den_frame.spines['left'].set_visible(False)
            den_frame.spines['right'].set_visible(False)
            plt.savefig(exp_name + '/' + filename_no_ext + '_gt_' +
                        str(int(gt)) + '.png',
                        bbox_inches='tight',
                        pad_inches=0,
                        dpi=150)
            plt.close()
            # sio.savemat(exp_name+'/'+filename_no_ext+'_gt_'+str(int(gt))+'.mat',{'data':den})

            # predicted density figure
            pred_frame = plt.gca()
            plt.imshow(pred_map, 'jet')
            pred_frame.axes.get_yaxis().set_visible(False)
            pred_frame.axes.get_xaxis().set_visible(False)
            pred_frame.spines['top'].set_visible(False)
            pred_frame.spines['bottom'].set_visible(False)
            pred_frame.spines['left'].set_visible(False)
            pred_frame.spines['right'].set_visible(False)
            plt.savefig(exp_name + '/' + filename_no_ext + '_pred_' +
                        str(float(pred)) + '.png',
                        bbox_inches='tight',
                        pad_inches=0,
                        dpi=150)
            plt.close()
            # sio.savemat(exp_name+'/'+filename_no_ext+'_pred_'+str(float(pred))+'.mat',{'data':pred_map})

            # signed difference (gt - pred) figure with colorbar
            diff = den - pred_map
            diff_frame = plt.gca()
            plt.imshow(diff, 'jet')
            plt.colorbar()
            diff_frame.axes.get_yaxis().set_visible(False)
            diff_frame.axes.get_xaxis().set_visible(False)
            diff_frame.spines['top'].set_visible(False)
            diff_frame.spines['bottom'].set_visible(False)
            diff_frame.spines['left'].set_visible(False)
            diff_frame.spines['right'].set_visible(False)
            plt.savefig(exp_name + '/' + filename_no_ext + '_diff.png',
                        bbox_inches='tight',
                        pad_inches=0,
                        dpi=150)
            plt.close()
            # sio.savemat(exp_name+'/diff/'+filename_no_ext+'_diff.mat',{'data':diff})
    # dataset-level metrics over all images
    preds = np.asarray(preds)
    gts = np.asarray(gts)
    print('\nMAE= ' + str(np.mean(np.abs(gts - preds))))
    print('MSE= ' + str(np.sqrt(np.mean((gts - preds)**2))))
Example #6
0
class Trainer():
    def __init__(self, dataloader, cfg_data, pwd):
        """Build net, optimizer, LR scheduler and logging for training.

        dataloader: callable returning (train_loader, val_loader, restore_transform)
        cfg_data: dataset-specific config (provides LOG_PARA)
        pwd: working directory handed through to the experiment logger
        """
        self.cfg_data = cfg_data

        self.data_mode = cfg.DATASET
        self.exp_name = cfg.EXP_NAME
        self.exp_path = cfg.EXP_PATH

        self.pwd = pwd

        self.net_name = cfg.NET

        # NOTE(review): loss_1_fn/loss_2_fn are only bound for SANet or
        # 'OAI' nets; any other net name raises NameError below.
        if self.net_name in ['SANet']:
            loss_1_fn = nn.MSELoss()
            from misc import pytorch_ssim
            loss_2_fn = pytorch_ssim.SSIM(window_size=11)

        if 'OAI' in self.net_name:
            loss_1_fn = nn.SmoothL1Loss()
            from misc import pytorch_ssim
            loss_2_fn = pytorch_ssim.SSIM(window_size=11)

        self.net = CrowdCounter(cfg.GPU_ID, self.net_name, loss_1_fn,
                                loss_2_fn, cfg.PRE).cuda()
        # Adam for training from scratch, SGD for fine-tuning (and then over
        # all net parameters, not just the CCN submodule)
        if not cfg.FINETUNE:
            self.optimizer = optim.Adam(self.net.CCN.parameters(),
                                        lr=cfg.LR,
                                        weight_decay=1e-4)
            print('using ADAM')
        else:
            self.optimizer = optim.SGD(self.net.parameters(),
                                       cfg.LR,
                                       momentum=0.95,
                                       weight_decay=5e-4)
            print('using SGD')

        # LR schedule selected by config; 'rop' is stepped with the val MAE
        # in validate_V1, the others are stepped unconditionally there
        if cfg.LR_CHANGER == 'step':
            self.scheduler = StepLR(self.optimizer,
                                    step_size=cfg.NUM_EPOCH_LR_DECAY,
                                    gamma=cfg.LR_DECAY)
        elif cfg.LR_CHANGER == 'cosann':
            self.scheduler = CosineAnnealingLR(self.optimizer,
                                               2 * cfg.EPOCH_DIS,
                                               eta_min=5e-9,
                                               last_epoch=-1)
        elif cfg.LR_CHANGER == 'expotential':
            self.scheduler = ExponentialLR(self.optimizer, 0.95, last_epoch=-1)
        elif cfg.LR_CHANGER == 'rop':
            self.scheduler = ReduceLROnPlateau(self.optimizer,
                                               mode='min',
                                               verbose=True,
                                               eps=1e-10)

        # best-so-far validation results; updated by update_model()
        self.train_record = {
            'best_mae': 1e20,
            'best_mse': 1e20,
            'best_model_name': ''
        }
        self.timer = {
            'iter time': Timer(),
            'train time': Timer(),
            'val time': Timer()
        }

        self.epoch = 0
        self.i_tb = 0  # tensorboard global step

        # optional GCC-pretrained weights
        if cfg.PRE_GCC:
            self.net.load_state_dict(torch.load(cfg.PRE_GCC_MODEL))

        self.train_loader, self.val_loader, self.restore_transform = dataloader(
        )
        # print at most ITER_DIS times per epoch
        cfg.PRINT_FREQ = min(len(self.train_loader), cfg.ITER_DIS)
        if cfg.RESUME:
            latest_state = torch.load(cfg.RESUME_PATH)
            self.net.load_state_dict(latest_state['net'])
            # when fine-tuning, only the weights are restored -- optimizer,
            # scheduler and counters start fresh
            if not cfg.FINETUNE:
                self.optimizer.load_state_dict(latest_state['optimizer'])
                self.scheduler.load_state_dict(latest_state['scheduler'])
                self.epoch = latest_state['epoch'] + 1
                self.i_tb = latest_state['i_tb']
                self.train_record = latest_state['train_record']
                self.exp_path = latest_state['exp_path']
                self.exp_name = latest_state['exp_name']

        self.writer, self.log_txt = logger(self.exp_path,
                                           self.exp_name,
                                           self.pwd,
                                           'exp',
                                           resume=cfg.RESUME)

    def forward(self):
        """Train for the configured number of epochs, validating after each.

        Dispatches to the dataset-specific validation routine based on
        cfg.DATASET (V1: SHHA/SHHB/QNRF/UCF50, V2: WE, V3: GCC).
        """
        # self.validate_V3()
        for epoch in range(self.epoch, cfg.MAX_EPOCH):

            # training
            self.timer['train time'].tic()
            print('=' * 20 + 'EPOCH %d' % (epoch + 1) + '=' * 30)

            print('### start train ###')
            self.train(epoch)
            self.timer['train time'].toc(average=False)

            print('train time: {:.2f}s'.format(self.timer['train time'].diff))
            print('*' * 20 + '\n')

            print('### start val ###')

            # validation
            # if epoch % cfg.VAL_FREQ == 0 or epoch > cfg.VAL_DENSE_START:
            self.timer['val time'].tic()
            if self.data_mode in ['SHHA', 'SHHB', 'QNRF', 'UCF50']:
                self.validate_V1(epoch)
            # fix: compare strings by value -- `is` on string literals relies
            # on CPython interning and is not a reliable comparison
            elif self.data_mode == 'WE':
                self.validate_V2()
            elif self.data_mode == 'GCC':
                self.validate_V3()
            self.timer['val time'].toc(average=False)
            print('val time: {:.2f}s'.format(self.timer['val time'].diff))
            print('=' * 58)
            print('\n')

    def train(self, epoch):  # training for all datasets
        """Run one optimization pass over the training loader for `epoch`."""
        self.net.train()
        # running averages for the epoch
        loss_meter = AverageMeter()
        mae_meter = AverageMeter()
        mse_meter = AverageMeter()
        l1_meter = AverageMeter()
        ssim_meter = AverageMeter()

        for step, batch in tqdm(enumerate(self.train_loader, 0)):

            self.timer['iter time'].tic()
            img, gt_map = batch

            img = Variable(img).cuda()
            gt_map = Variable(gt_map).cuda()

            self.optimizer.zero_grad()
            pred_map = self.net(img, gt_map)
            loss1, loss2 = self.net.loss
            b, c, h, w = img.shape

            # combined loss scaled by an image-size factor h*w/(h+w)
            loss = torch.mul((loss1 + loss2), h * w / (h + w))
            pred_cnt = np.sum(
                pred_map.data.cpu().numpy()) / self.cfg_data.LOG_PARA
            gt_count = np.sum(
                gt_map.data.cpu().numpy()) / self.cfg_data.LOG_PARA

            # accumulate batch statistics (weighted by batch size)
            count_diff = abs(gt_count - pred_cnt)
            loss_meter.update(loss.item(), b)
            l1_meter.update(loss1.item(), b)
            ssim_meter.update(loss2.item(), b)
            mae_meter.update(count_diff, b)
            mse_meter.update(count_diff * count_diff, b)

            loss.backward()
            self.optimizer.step()

            if (step + 1) % cfg.PRINT_FREQ == 0:
                self.i_tb += 1
                self.timer['iter time'].toc(average=False)
                print('[ep %d][it %d][loss %.4f][lr*10000 %.4f]' % \
                      (epoch + 1, step + 1, loss_meter.avg, self.optimizer.param_groups[0]['lr'] * 10000))
                print('   [ gt: %.3f pre: %.3f diff: %.3f]' %
                      (gt_count, pred_cnt, count_diff))

        # epoch summary: console, text log, tensorboard
        print('[epoch %d] ,[mae %.2f mse %.2f], [train loss %.4f]' %
              (epoch + 1, mae_meter.avg, mse_meter.avg, loss_meter.avg))
        logger_txt(self.log_txt,
                   epoch + 1, [mae_meter.avg, mse_meter.avg, loss_meter.avg],
                   phase='train')
        self.writer.add_scalar('ssim', ssim_meter.avg, epoch + 1)
        self.writer.add_scalar('smoothL1', l1_meter.avg, epoch + 1)
        self.writer.add_scalar('train_loss', loss_meter.avg, epoch + 1)
        self.writer.add_scalar('train_mae', mae_meter.avg, epoch + 1)
        self.writer.add_scalar('train_mse', mse_meter.avg, epoch + 1)

    def validate_V1(self,
                    epoch):  # validate_V1 for SHHA, SHHB, UCF-QNRF, UCF50
        """Validate on single-loader datasets; logs loss/MAE/MSE, saves the
        best model via update_model(), and steps the LR scheduler."""
        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        for vi, data in tqdm(enumerate(self.val_loader, 0)):
            img, gt_map = data
            with torch.no_grad():
                img = Variable(img).cuda()
                gt_map = Variable(gt_map).cuda()

                pred_map = self.net.forward(img, gt_map)

                # rebind to numpy for per-image count sums
                pred_map = pred_map.data.cpu().numpy()
                gt_map = gt_map.data.cpu().numpy()

                for i_img in range(pred_map.shape[0]):
                    pred_cnt = np.sum(pred_map[i_img]) / self.cfg_data.LOG_PARA
                    gt_count = np.sum(gt_map[i_img]) / self.cfg_data.LOG_PARA

                    # NOTE(review): net.loss is the whole-batch loss, re-read
                    # here once per image -- confirm intended weighting
                    loss1, loss2 = self.net.loss
                    w = pred_map[i_img].shape[-1]
                    h = pred_map[i_img].shape[-2]
                    # same h*w/(h+w) size scaling used in train()
                    loss = torch.mul((loss1 + loss2), h * w / (h + w))
                    loss = loss.item()
                    losses.update(loss)
                    maes.update(abs(gt_count - pred_cnt))
                    mses.update((gt_count - pred_cnt) * (gt_count - pred_cnt))

                # visualize only on selected epochs, and then only every
                # ITER_DIS-th batch
                if epoch % cfg.VAL_FREQ == 0 or (epoch < 10
                                                 and epoch % 3 == 0):
                    if vi % cfg.ITER_DIS == 0:
                        vis_results(self.exp_name, epoch + 1, self.writer,
                                    self.restore_transform, img, pred_map,
                                    gt_map, vi)

        mae = maes.avg
        mse = np.sqrt(mses.avg)
        loss = losses.avg

        self.writer.add_scalar('val_loss', loss, epoch + 1)
        self.writer.add_scalar('mae', mae, epoch + 1)
        self.writer.add_scalar('mse', mse, epoch + 1)

        self.train_record = update_model(self.net, self.optimizer, self.scheduler, epoch + 1, self.i_tb, self.exp_path,
                                         self.exp_name, \
                                         [mae, mse, loss], self.train_record, self.log_txt)
        print_summary(self.exp_name, [mae, mse, loss], self.train_record,
                      epoch + 1)
        # LR schedule: ReduceLROnPlateau needs the metric, the others don't
        if epoch > cfg.LR_DECAY_START:
            cprint('start to change lr', color='yellow')
            if cfg.LR_CHANGER != 'rop':
                self.scheduler.step()
            if cfg.LR_CHANGER == 'rop':
                self.scheduler.step(mae)

    def validate_V2(self):  # validate_V2 for WE
        """Validate on the WE dataset: iterate its five sub-scene loaders,
        accumulate per-scene loss/MAE, visualize the first batch of each
        scene, log averages to TensorBoard, and checkpoint via update_model.
        (MSE is not tracked for WE, so 0 is reported in its slot.)
        """
        self.net.eval()

        # One meter slot per WE sub-scene (5 scenes).
        loss_meter = AverageCategoryMeter(5)
        mae_meter = AverageCategoryMeter(5)

        for scene_idx, loader in enumerate(self.val_loader, 0):

            for batch_idx, batch in enumerate(loader, 0):
                img, gt_map = batch

                with torch.no_grad():
                    img = Variable(img).cuda()
                    gt_map = Variable(gt_map).cuda()

                    pred_map = self.net.forward(img, gt_map)

                    pred_map = pred_map.data.cpu().numpy()
                    gt_map = gt_map.data.cpu().numpy()

                    log_para = self.cfg_data.LOG_PARA
                    for k in range(pred_map.shape[0]):
                        # Density maps are stored scaled by LOG_PARA; divide
                        # to recover the actual head count.
                        pred_cnt = np.sum(pred_map[k]) / log_para
                        gt_count = np.sum(gt_map[k]) / log_para

                        loss_meter.update(self.net.loss.item(), scene_idx)
                        mae_meter.update(abs(gt_count - pred_cnt), scene_idx)
                    # Visualize only the first batch of each scene.
                    if batch_idx == 0:
                        vis_results(self.exp_name, self.epoch, self.writer,
                                    self.restore_transform, img, pred_map,
                                    gt_map)

        # Average the per-scene averages into single scalars.
        mae = np.average(mae_meter.avg)
        loss = np.average(loss_meter.avg)

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)

        self.train_record = update_model(
            self.net, self.optimizer, self.scheduler, self.epoch, self.i_tb,
            self.exp_path, self.exp_name,
            [mae, 0, loss], self.train_record, self.log_txt)
        print_summary(self.exp_name, [mae, 0, loss], self.train_record)

    def validate_V3(self):  # validate_V3 for GCC
        """Validate on the GCC dataset.

        Runs the network over the validation loader, accumulating overall
        loss/MAE/MSE plus per-attribute MAE/MSE broken down by crowd level
        (9 bins), time of day (8 bins) and weather (7 bins). Visualizes the
        first batch, logs scalars to TensorBoard, checkpoints via
        update_model, and prints a GCC-specific summary.
        """
        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        # Per-attribute meters: 9 crowd levels, 8 time-of-day bins, 7 weathers.
        c_maes = {
            'level': AverageCategoryMeter(9),
            'time': AverageCategoryMeter(8),
            'weather': AverageCategoryMeter(7)
        }
        c_mses = {
            'level': AverageCategoryMeter(9),
            'time': AverageCategoryMeter(8),
            'weather': AverageCategoryMeter(7)
        }

        for vi, data in enumerate(self.val_loader, 0):
            img, gt_map, attributes_pt = data

            with torch.no_grad():
                img = Variable(img).cuda()
                gt_map = Variable(gt_map).cuda()

                pred_map = self.net.forward(img, gt_map)

                pred_map = pred_map.data.cpu().numpy()
                gt_map = gt_map.data.cpu().numpy()

                for i_img in range(pred_map.shape[0]):
                    # BUG FIX: previously summed the whole batch
                    # (np.sum(pred_map)) instead of the single image
                    # pred_map[i_img], inflating every count when the batch
                    # size exceeds 1. Matches the per-image indexing used by
                    # the other validate variants in this file.
                    # Density maps are scaled by LOG_PARA; divide to recover
                    # the actual head count.
                    pred_cnt = np.sum(pred_map[i_img]) / self.cfg_data.LOG_PARA
                    gt_count = np.sum(gt_map[i_img]) / self.cfg_data.LOG_PARA

                    s_mae = abs(gt_count - pred_cnt)
                    s_mse = (gt_count - pred_cnt) * (gt_count - pred_cnt)

                    loss1, loss2 = self.net.loss
                    loss = loss1.item() + loss2.item()
                    losses.update(loss)
                    maes.update(s_mae)
                    mses.update(s_mse)
                    # NOTE(review): squeeze assumes batch size 1 for the
                    # attribute tensor; with a larger batch, per-image
                    # indexing (attributes_pt[i_img][...]) would be needed
                    # — confirm against the GCC loader's batch size.
                    attributes_pt = attributes_pt.squeeze()
                    c_maes['level'].update(s_mae, attributes_pt[0])
                    c_mses['level'].update(s_mse, attributes_pt[0])
                    # presumably maps the raw time attribute into 8 bins
                    # (3-hour groups) — TODO confirm; note true division
                    # yields a non-integer index here.
                    c_maes['time'].update(s_mae, attributes_pt[1] / 3)
                    c_mses['time'].update(s_mse, attributes_pt[1] / 3)
                    c_maes['weather'].update(s_mae, attributes_pt[2])
                    c_mses['weather'].update(s_mse, attributes_pt[2])

                # Visualize only the first batch of the epoch.
                if vi == 0:
                    vis_results(self.exp_name, self.epoch, self.writer,
                                self.restore_transform, img, pred_map, gt_map)

        loss = losses.avg
        mae = maes.avg
        mse = np.sqrt(mses.avg)

        self.writer.add_scalar('val_loss', loss, self.epoch + 1)
        self.writer.add_scalar('mae', mae, self.epoch + 1)
        self.writer.add_scalar('mse', mse, self.epoch + 1)

        self.train_record = update_model(self.net, self.optimizer, self.scheduler, self.epoch, self.i_tb, self.exp_path,
                                         self.exp_name, \
                                         [mae, mse, loss], self.train_record, self.log_txt)

        print_GCC_summary(self.log_txt, self.epoch, [mae, mse, loss],
                          self.train_record, c_maes, c_mses)