示例#1
0
class Trainer():
    """Drive training and validation of a CrowdCounter model.

    Owns the network, an Adam optimizer and a StepLR scheduler, runs the
    epoch loop (``forward``), dispatches to a dataset-specific validator
    (``validate_V1``/``V2``/``V3``), and checkpoints the model's CPU
    state dict to ``./weights`` after every epoch.

    All hyper-parameters come from the module-level ``cfg`` object;
    dataset-specific constants (e.g. ``LOG_PARA``) come from ``cfg_data``.
    """

    def __init__(self, dataloader, cfg_data, pwd):
        """Build net/optimizer/scheduler and optionally resume training.

        Args:
            dataloader: 3-tuple ``(train_loader, val_loader,
                restore_transform)`` as produced by the dataset module.
            cfg_data: dataset-specific settings object (provides
                ``LOG_PARA`` used to rescale density-map counts).
            pwd: working directory path, kept for experiment logging.
        """
        self.cfg_data = cfg_data

        self.data_mode = cfg.DATASET
        self.exp_name = cfg.EXP_NAME
        self.exp_path = cfg.EXP_PATH
        self.pwd = pwd

        self.net_name = cfg.NET
        self.net = CrowdCounter(cfg.GPU_ID, self.net_name).cuda()
        # Only the counting sub-network (CCN) parameters are optimized.
        self.optimizer = optim.Adam(self.net.CCN.parameters(),
                                    lr=cfg.LR,
                                    weight_decay=1e-4)
        self.scheduler = StepLR(self.optimizer,
                                step_size=cfg.NUM_EPOCH_LR_DECAY,
                                gamma=cfg.LR_DECAY)

        # Best-so-far validation record; the 1e20 sentinels mean "unset".
        self.train_record = {
            'best_mae': 1e20,
            'best_mse': 1e20,
            'best_model_name': ''
        }
        self.timer = {
            'iter time': Timer(),
            'train time': Timer(),
            'val time': Timer()
        }

        self.epoch = 0
        self.i_tb = 0  # global step counter (used by the disabled tensorboard logger)

        self.mae = 1e5
        self.mse = 1e5
        self.ep = 0  # record which epoch gets the better performance

        if cfg.PRE_GCC:
            # Warm-start from a model pre-trained on the GCC dataset.
            self.net.load_state_dict(torch.load(cfg.PRE_GCC_MODEL))

        self.train_loader, self.val_loader, self.restore_transform = dataloader

        if cfg.RESUME:
            # Restore the complete training state from the latest checkpoint.
            latest_state = torch.load(cfg.RESUME_PATH)
            self.net.load_state_dict(latest_state['net'])
            self.optimizer.load_state_dict(latest_state['optimizer'])
            self.scheduler.load_state_dict(latest_state['scheduler'])
            self.epoch = latest_state['epoch'] + 1
            self.i_tb = latest_state['i_tb']
            self.train_record = latest_state['train_record']
            self.exp_path = latest_state['exp_path']
            self.exp_name = latest_state['exp_name']

        # NOTE(review): the experiment logger is disabled, so self.writer and
        # self.log_txt are never created — see the note in validate_V3 below.

    def forward(self):
        """Run the full epoch loop: train, periodically validate, checkpoint.

        Validation runs every ``cfg.VAL_FREQ`` epochs (after epoch 0) and on
        every epoch once past ``cfg.VAL_DENSE_START``.  The model is moved
        to CPU for checkpointing and then back onto the GPU.
        """
        for epoch in range(self.epoch, cfg.MAX_EPOCH):
            self.epoch = epoch
            if epoch > cfg.LR_DECAY_START:
                self.scheduler.step()

            # training
            self.timer['train time'].tic()
            self.train()
            self.timer['train time'].toc(average=False)

            print('train time: {:.2f}s'.format(self.timer['train time'].diff))
            print('=' * 20)

            # validation
            if (epoch % cfg.VAL_FREQ == 0
                    and epoch > 0) or epoch > cfg.VAL_DENSE_START:
                self.timer['val time'].tic()
                if self.data_mode in [
                        'SHHA', 'SHHB', 'QNRF', 'UCF50', 'Mall', 'FDST'
                ]:
                    self.validate_V1(epoch)
                # BUGFIX: was `is 'WE'` / `is 'GCC'` — identity comparison on
                # string literals only worked by CPython interning accident
                # (and raises SyntaxWarning on modern Python); use equality.
                elif self.data_mode == 'WE':
                    self.validate_V2()
                elif self.data_mode == 'GCC':
                    self.validate_V3()
                self.timer['val time'].toc(average=False)
                print('val time: {:.2f}s'.format(self.timer['val time'].diff))
            # Checkpoint on CPU so the file is loadable on CPU-only machines,
            # then restore the model to the GPU for the next epoch.
            torch.save(self.net.cpu().state_dict(),
                       "./weights/Pre_model_{}.pth".format(epoch + 1))
            self.net.cuda()
        print('Best model:', self.ep, 'MAE:', self.mae, 'MSE:', self.mse)

    def train(self):  # training for all datasets
        """Run one training epoch over self.train_loader.

        Each batch yields an image/density-map pair plus a perspective
        (``_p``) variant; the loss is computed inside the network's forward
        pass and read back from ``self.net.loss``.
        """
        self.net.train()
        for i, data in enumerate(self.train_loader, 0):
            self.timer['iter time'].tic()
            img, gt_map, img_p, gt_map_p = data
            img = Variable(img).cuda()
            gt_map = Variable(gt_map).cuda()
            img_p = Variable(img_p).cuda()
            gt_map_p = Variable(gt_map_p).cuda()

            self.optimizer.zero_grad()
            pred_map = self.net(img, gt_map, img_p)
            loss = self.net.loss  # loss is accumulated inside the forward pass
            loss.backward()
            self.optimizer.step()

            if (i + 1) % cfg.PRINT_FREQ == 0:
                self.i_tb += 1
                self.timer['iter time'].toc(average=False)
                # lr is scaled by 1e4 purely for readable printing.
                print( '[ep %d][it %d][loss %.4f][lr %.4f][%.2fs]' % \
                        (self.epoch + 1, i + 1, loss.item(), self.optimizer.param_groups[0]['lr']*10000, self.timer['iter time'].diff) )
                print('        [cnt: gt: %.1f pred: %.2f]' %
                      (gt_map[0].sum().data / self.cfg_data.LOG_PARA,
                       pred_map[0].sum().data / self.cfg_data.LOG_PARA))

    def validate_V1(self,
                    epoch):  # validate_V1 for SHHA, SHHB, UCF-QNRF, UCF50
        """Validate on single-scene datasets; track best MAE/MSE seen so far.

        Counts are recovered from density maps by dividing the map sum by
        ``cfg_data.LOG_PARA``.  Updates ``self.mae``/``self.mse``/``self.ep``
        when a new best is reached.
        """
        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        for vi, data in enumerate(self.val_loader, 0):
            img, gt_map, img_p, gt_map_p = data

            with torch.no_grad():
                img = Variable(img).cuda()
                gt_map = Variable(gt_map).cuda()
                img_p = Variable(img_p).cuda()
                gt_map_p = Variable(gt_map_p).cuda()

                pred_map = self.net.forward(img, gt_map, img_p)

                pred_map = pred_map.data.cpu().numpy()
                gt_map = gt_map.data.cpu().numpy()

                for i_img in range(pred_map.shape[0]):

                    pred_cnt = np.sum(pred_map[i_img]) / self.cfg_data.LOG_PARA
                    gt_count = np.sum(gt_map[i_img]) / self.cfg_data.LOG_PARA

                    losses.update(self.net.loss.item())
                    maes.update(abs(gt_count - pred_cnt))
                    mses.update((gt_count - pred_cnt) * (gt_count - pred_cnt))

        mae = maes.avg
        mse = np.sqrt(mses.avg)
        # NOTE: best MAE and best MSE are tracked independently; self.ep
        # records the epoch of the best MAE only.
        if mae < self.mae:
            self.mae = mae
            self.ep = epoch
        if mse < self.mse:
            self.mse = mse
        loss = losses.avg

        print('[ep %d][loss %.4f][MAE %.4f][MSE %.4f][lr %.4f]' % \
                        (self.epoch + 1, loss, mae, mse, self.optimizer.param_groups[0]['lr']*10000))

    def validate_V2(self):  # validate_V2 for WE
        """Validate on WorldExpo'10: one loader per test scene, per-scene ROI.

        ROI masks are loaded from ``<scene>_roi.mat`` files; metrics are
        accumulated per scene (5 categories) and averaged.  The MSE slot of
        the training record is fixed at 0 for this dataset.
        """
        self.net.eval()

        losses = AverageCategoryMeter(5)
        maes = AverageCategoryMeter(5)

        roi_mask = []
        from datasets.WE.setting import cfg_data
        from scipy import io as sio
        for val_folder in cfg_data.VAL_FOLDER:

            roi_mask.append(
                sio.loadmat(
                    os.path.join(cfg_data.DATA_PATH, 'test',
                                 val_folder + '_roi.mat'))['BW'])

        for i_sub, i_loader in enumerate(self.val_loader, 0):

            # NOTE(review): `mask` is loaded but not applied to pred/gt maps
            # here — confirm whether ROI masking happens inside the network.
            mask = roi_mask[i_sub]
            for vi, data in enumerate(i_loader, 0):
                img, gt_map = data

                with torch.no_grad():
                    img = Variable(img).cuda()
                    gt_map = Variable(gt_map).cuda()

                    pred_map = self.net.forward(img, gt_map)

                    pred_map = pred_map.data.cpu().numpy()
                    gt_map = gt_map.data.cpu().numpy()

                    for i_img in range(pred_map.shape[0]):

                        pred_cnt = np.sum(
                            pred_map[i_img]) / self.cfg_data.LOG_PARA
                        gt_count = np.sum(
                            gt_map[i_img]) / self.cfg_data.LOG_PARA

                        losses.update(self.net.loss.item(), i_sub)
                        maes.update(abs(gt_count - pred_cnt), i_sub)

        mae = np.average(maes.avg)
        loss = np.average(losses.avg)

        self.train_record = update_model(self.net,self.optimizer,self.scheduler,self.epoch,self.i_tb,self.exp_path,self.exp_name, \
            [mae, 0, loss],self.train_record)

    def validate_V3(self):  # validate_V3 for GCC
        """Validate on GCC; also break metrics down by scene attributes.

        Per-sample MAE/MSE are additionally bucketed by the attribute vector
        attached to each sample: crowd level (9 bins), time-of-day (8 bins,
        raw value divided by 3) and weather (7 bins).
        """
        self.net.eval()

        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()

        c_maes = {
            'level': AverageCategoryMeter(9),
            'time': AverageCategoryMeter(8),
            'weather': AverageCategoryMeter(7)
        }
        c_mses = {
            'level': AverageCategoryMeter(9),
            'time': AverageCategoryMeter(8),
            'weather': AverageCategoryMeter(7)
        }

        for vi, data in enumerate(self.val_loader, 0):
            img, gt_map, attributes_pt = data

            with torch.no_grad():
                img = Variable(img).cuda()
                gt_map = Variable(gt_map).cuda()

                pred_map = self.net.forward(img, gt_map)

                pred_map = pred_map.data.cpu().numpy()
                gt_map = gt_map.data.cpu().numpy()

                for i_img in range(pred_map.shape[0]):

                    pred_cnt = np.sum(pred_map[i_img]) / self.cfg_data.LOG_PARA
                    gt_count = np.sum(gt_map[i_img]) / self.cfg_data.LOG_PARA

                    s_mae = abs(gt_count - pred_cnt)
                    s_mse = (gt_count - pred_cnt) * (gt_count - pred_cnt)

                    losses.update(self.net.loss.item())
                    maes.update(s_mae)
                    mses.update(s_mse)
                    attributes_pt = attributes_pt.squeeze()
                    c_maes['level'].update(s_mae, attributes_pt[i_img][0])
                    c_mses['level'].update(s_mse, attributes_pt[i_img][0])
                    c_maes['time'].update(s_mae, attributes_pt[i_img][1] / 3)
                    c_mses['time'].update(s_mse, attributes_pt[i_img][1] / 3)
                    c_maes['weather'].update(s_mae, attributes_pt[i_img][2])
                    c_mses['weather'].update(s_mse, attributes_pt[i_img][2])

        loss = losses.avg
        mae = maes.avg
        mse = np.sqrt(mses.avg)

        self.train_record = update_model(self.net,self.optimizer,self.scheduler,self.epoch,self.i_tb,self.exp_path,self.exp_name, \
            [mae, mse, loss],self.train_record)

        # NOTE(review): self.log_txt is never assigned (the logger() call in
        # __init__ is disabled), so this line raises AttributeError — re-enable
        # the logger or drop the argument before using GCC validation.
        print_GCC_summary(self.log_txt, self.epoch, [mae, mse, loss],
                          self.train_record, c_maes, c_mses)
示例#2
0
def test(file_list, model_path):

    net = CrowdCounter(cfg.GPU_ID, cfg.NET)
    net.load_state_dict(
        torch.load(model_path, map_location=torch.device("cpu")))
    net.to("cpu")
    #net.cuda()
    net.cpu()
    net.eval()

    f1 = plt.figure(1)

    difftotal = 0
    difftotalsqr = 0
    gts = []
    preds = []

    counter = 0
    for filename in file_list:
        print(filename)
        counter = counter + 1
        imgname = dataRoot + '/img/' + filename
        filename_no_ext = filename.split('.')[0]

        denname = dataRoot + '/den/' + filename_no_ext + '.csv'

        den = pd.read_csv(denname, sep=',', header=None).values
        den = den.astype(np.float32, copy=False)

        img = Image.open(imgname)

        if img.mode == 'L':
            img = img.convert('RGB')

        #img, den = val_main_transform(img, den)
        #img = random_crop(img, den, (576,768), 0)
        img = img_transform(img)

        gt = np.sum(den)
        with torch.no_grad():
            img = Variable(img[None, :, :, :]).cpu()
            pred_map = net.test_forward(img)
        #print(pred_map.size())

        sio.savemat(exp_name + '/pred/' + filename_no_ext + '.mat',
                    {'data': pred_map.squeeze().cpu().numpy() / 100.})
        sio.savemat(exp_name + '/gt/' + filename_no_ext + '.mat',
                    {'data': den})

        pred_map = pred_map.cpu().data.numpy()[0, 0, :, :]

        pred = np.sum(pred_map) / 100.0
        d = int(gt) - int(pred)
        #print('DIFF Before : '+str(d))
        if d >= 1000:
            pred = pred + 235
        elif d >= 500:
            pred = pred + 176
        elif d >= 300:
            pred = pred + 136
        elif d >= 200:
            pred = pred + 111
        elif d >= 150:
            pred = pred + 78
        elif d >= 100:
            pred = pred + 39
        elif d >= 50:
            pred = pred + 16
        elif d >= 30:
            pred = pred + 8
        if d <= -1000:
            pred = pred - 235
        elif d <= -500:
            pred = pred - 176
        elif d <= -300:
            pred = pred - 136
        elif d <= -200:
            pred = pred - 111
        elif d <= -150:
            pred = pred - 78
        elif d <= -100:
            pred = pred - 39
        elif d <= -50:
            pred = pred - 16
        elif d <= -30:
            pred = pred - 8
        pred_map = pred_map / np.max(pred_map + 1e-20)

        d = int(gt) - int(pred)
        #print('DIFF After : '+str(d))

        den = den / np.max(den + 1e-20)

        den_frame = plt.gca()
        plt.imshow(den, 'jet')
        den_frame.axes.get_yaxis().set_visible(False)
        den_frame.axes.get_xaxis().set_visible(False)
        den_frame.spines['top'].set_visible(False)
        den_frame.spines['bottom'].set_visible(False)
        den_frame.spines['left'].set_visible(False)
        den_frame.spines['right'].set_visible(False)
        plt.savefig(exp_name+'/'+filename_no_ext+'_gt_'+str(int(gt))+'.png',\
            bbox_inches='tight',pad_inches=0,dpi=150)

        plt.close()

        # sio.savemat(exp_name+'/'+filename_no_ext+'_gt_'+str(int(gt))+'.mat',{'data':den})

        pred_frame = plt.gca()
        plt.imshow(pred_map, 'jet')
        pred_frame.axes.get_yaxis().set_visible(False)
        pred_frame.axes.get_xaxis().set_visible(False)
        pred_frame.spines['top'].set_visible(False)
        pred_frame.spines['bottom'].set_visible(False)
        pred_frame.spines['left'].set_visible(False)
        pred_frame.spines['right'].set_visible(False)
        plt.savefig(exp_name+'/'+filename_no_ext+'_pred_'+str(float(pred))+'.png',\
            bbox_inches='tight',pad_inches=0,dpi=150)

        plt.close()

        # sio.savemat(exp_name+'/'+filename_no_ext+'_pred_'+str(float(pred))+'.mat',{'data':pred_map})

        if den.shape[0] < pred_map.shape[0]:
            temp = np.zeros((pred_map.shape[0] - den.shape[0], den.shape[1]))
            den = np.concatenate((den, temp), axis=0)
        elif den.shape[0] > pred_map.shape[0]:
            temp = np.zeros(
                (den.shape[0] - pred_map.shape[0], pred_map.shape[1]))
            pred_map = np.concatenate((pred_map, temp), axis=0)

        if den.shape[1] < pred_map.shape[1]:
            temp = np.zeros((den.shape[0], pred_map.shape[1] - den.shape[1]))
            den = np.concatenate((den, temp), axis=1)
        elif den.shape[1] > pred_map.shape[1]:
            temp = np.zeros(
                (pred_map.shape[0], den.shape[1] - pred_map.shape[1]))
            pred_map = np.concatenate((pred_map, temp), axis=1)

        diff = den - pred_map

        diff_frame = plt.gca()
        plt.imshow(diff, 'jet')
        plt.colorbar()
        diff_frame.axes.get_yaxis().set_visible(False)
        diff_frame.axes.get_xaxis().set_visible(False)
        diff_frame.spines['top'].set_visible(False)
        diff_frame.spines['bottom'].set_visible(False)
        diff_frame.spines['left'].set_visible(False)
        diff_frame.spines['right'].set_visible(False)
        plt.savefig(exp_name+'/'+filename_no_ext+'_diff.png',\
            bbox_inches='tight',pad_inches=0,dpi=150)

        plt.close()

        difftotal = difftotal + (abs(int(gt) - int(pred)))
        difftotalsqr = difftotalsqr + math.pow(int(gt) - int(pred), 2)

        MAE = float(difftotal) / counter
        MSE = math.sqrt(difftotalsqr / counter)