# Example #1
# 0
class NLT_Trainer(object):
    """Joint source / NLT-target trainer.

    Every epoch ``sou_model`` is trained on ``sou_loader`` (until
    ``cfg.DA_stop_epoch``) and its weights are copied into ``tar_model``,
    which is then updated on the few-shot ``tar_shot_loader`` with an extra
    penalty (``weight_decay_loss``) that keeps the ``nlt_weight`` /
    ``nlt_bias`` parameters close to the identity transform.
    All hyper-parameters come from the module-level ``cfg`` object.
    """

    def __init__(self, cfg_data, pwd):
        """Build loaders, both models, optimisers, schedulers and the logger.

        Args:
            cfg_data: dataset config; ``cfg_data.LOG_PARA`` is the factor
                density maps are divided by to obtain counts.
            pwd: directory forwarded to ``logger``.
        """
        # Must be set before anything can initialise CUDA (including a
        # torch.load of a GPU checkpoint); afterwards it has no effect.
        os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu_id

        self.cfg_data = cfg_data
        self.pwd = pwd
        self.exp_path = cfg.EXP_PATH
        self.exp_name = cfg.EXP_NAME
        os.makedirs(self.exp_path, exist_ok=True)

        (self.sou_loader, self.tar_shot_loader, self.tar_val_loader,
         self.tar_test_loader, self.restore_transform) = loading_data(cfg)

        self.sou_model = NLT_Counter(backbone=cfg.model_type)
        self.tar_model = NLT_Counter(mode='nlt', backbone=cfg.model_type)

        self.sou_optimizer = torch.optim.Adam(
            self.sou_model.parameters(), lr=cfg.nlt_lr,
            weight_decay=cfg.nlt_lr_decay)

        # Only the trainable encoder/decoder parameters of the target model
        # are optimised.
        self.tar_optimizer = torch.optim.Adam([
            {'params': [p for p in self.tar_model.encoder.parameters() if p.requires_grad],
             'lr': cfg.nlt_lr},
            {'params': [p for p in self.tar_model.decoder.parameters() if p.requires_grad],
             'lr': cfg.nlt_lr},
        ])

        self.sou_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.sou_optimizer, step_size=cfg.step_size, gamma=cfg.gamma)
        self.tar_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.tar_optimizer, step_size=cfg.step_size, gamma=cfg.gamma)

        if cfg.init_weights is not None:
            # Load onto the CPU; load_state_dict copies into the models,
            # which are moved to the GPU only below.
            self.pretrained_dict = torch.load(cfg.init_weights, map_location='cpu')
            self.sou_model.load_state_dict(self.pretrained_dict, strict=False)
            self.tar_model.load_state_dict(self.pretrained_dict, strict=False)

        self.sou_model = torch.nn.DataParallel(self.sou_model).cuda()
        self.tar_model = torch.nn.DataParallel(self.tar_model).cuda()

        self.tar_model_record = {"best_mae": 1e20, "best_mse": 1e20, "best_model_name": "", "update_flag": 0,
                                 "temp_test_mae": 1e20, "temp_test_mse": 1e20}
        self.sou_model_record = {"best_mae": 1e20, "best_mse": 1e20, "best_model_name": "", "update_flag": 0,
                                 "temp_test_mae": 1e20, "temp_test_mse": 1e20}
        self.epoch = 0
        self.writer, self.log_txt = logger(self.exp_path, self.exp_name, self.pwd, ["exp"])

    def forward(self):
        """Run ``cfg.max_epoch`` epochs of training, fine-tuning and validation."""
        timer = Timer()
        self.global_count = 0
        for epoch in range(1, cfg.max_epoch + 1):
            self.epoch = epoch
            self.train()
            self.fine_tune()
            if self.epoch % cfg.val_freq == 0:
                # ``==`` rather than ``is``: identity of equal strings is an
                # interpreter detail and must not drive control flow.
                if cfg.target_dataset == "WE":
                    self.tar_model_V2(self.tar_val_loader, "val")
                if cfg.target_dataset in ["VENICE", "QNRF", "SHHA", "SHHB", "MALL", "UCSD"]:
                    self.tar_model_V1(self.tar_val_loader, "val")

                print('=' * 50)
                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(), timer.measure(self.epoch / cfg.max_epoch)))
            self.sou_lr_scheduler.step()
            self.tar_lr_scheduler.step()
        self.writer.close()

    def train(self):
        """One epoch over the zipped source / shot loaders.

        ``zip`` stops at the shorter of the two loaders. Before
        ``cfg.DA_stop_epoch`` the source model is also stepped and its weights
        are copied into the target model (``strict=False``).
        """
        self.sou_model.train()
        self.tar_model.train()
        train_loss = AverageMeter()
        train_mae = AverageMeter()
        train_mse = AverageMeter()

        shot_loss = AverageMeter()
        shot_mae = AverageMeter()
        shot_mse = AverageMeter()

        for i, (sou_batch, shot_batch) in enumerate(zip(self.sou_loader, self.tar_shot_loader), 1):
            self.global_count = self.global_count + 1
            sou_img, sou_label = sou_batch[0].cuda(), sou_batch[1].cuda()
            shot_img, shot_label = shot_batch[0].cuda(), shot_batch[1].cuda()

            if self.epoch < cfg.DA_stop_epoch:
                # ================== update sou_model parameters ==================
                sou_pred = self.sou_model(sou_img)
                loss = F.mse_loss(sou_pred.squeeze(), sou_label.squeeze())

                self.sou_optimizer.zero_grad()
                loss.backward()
                self.sou_optimizer.step()
                self.writer.add_scalar('data/sou_loss', loss.item(), self.global_count)
                train_loss.update(loss.item())
                sou_pred_cnt, sou_label_cnt = self.mae_mse_update(sou_pred, sou_label, train_mae, train_mse)
                self.tar_model.load_state_dict(self.sou_model.state_dict(), strict=False)
            else:
                sou_label_cnt = 0
                sou_pred_cnt = 0

            # ===================== update tar_model parameters ====================
            shot_pred = self.tar_model(shot_img)
            loss_mse = F.mse_loss(shot_pred.squeeze(), shot_label.squeeze())
            loss = self.weight_decay_loss(self.tar_model, 1e-4) + loss_mse
            self.tar_optimizer.zero_grad()
            loss.backward()
            self.tar_optimizer.step()
            self.writer.add_scalar('data/shot_loss', loss.item(), self.global_count)
            shot_loss.update(loss.item())
            pred_cnt, label_cnt = self.mae_mse_update(shot_pred, shot_label, shot_mae, shot_mse)

            if i % cfg.print_freq == 0:
                print('Epoch {}, Loss={:.4f} s_gt={:.1f} s_pre={:.1f},t_gt={:.1f} t_pre={:.1f} lr={:.4f}'.format(
                    self.epoch, loss.item(), sou_label_cnt, sou_pred_cnt, label_cnt, pred_cnt,
                    self.sou_optimizer.param_groups[0]['lr'] * 10000))

        self.writer.add_scalar('data/train_loss_tar', float(shot_loss.avg), self.epoch)
        self.writer.add_scalar('data/train_mae_tar', float(shot_mae.avg), self.epoch)
        self.writer.add_scalar('data/train_mse_tar', float(np.sqrt(shot_mse.avg)), self.epoch)

        self.writer.add_scalar('data/train_loss_sou', float(train_loss.avg), self.epoch)
        self.writer.add_scalar('data/train_mae_sou', float(train_mae.avg), self.epoch)
        self.writer.add_scalar('data/train_mse_sou', float(np.sqrt(train_mse.avg)), self.epoch)

    def fine_tune(self, max_iters=50):
        """Run at most ``max_iters`` extra target-model steps on the shot loader."""
        for i, (shot_img, shot_label) in enumerate(self.tar_shot_loader, 1):
            if i > max_iters:
                break
            shot_img = shot_img.cuda()
            shot_label = shot_label.cuda()
            shot_pred = self.tar_model(shot_img)

            loss_mse = F.mse_loss(shot_pred.squeeze(), shot_label.squeeze())
            loss = self.weight_decay_loss(self.tar_model, 1e-4) + loss_mse
            self.tar_optimizer.zero_grad()
            loss.backward()
            self.tar_optimizer.step()

    def tar_model_V2(self, dataset, mode=None):
        """Evaluate ``tar_model`` on a per-scene (5-category) target.

        ``mode == 'val'``: ``dataset`` is a single loader; the result updates
        ``tar_model_record`` via ``update_model``. Otherwise ``dataset`` is
        iterated as a sequence of per-scene loaders and the per-scene MAEs
        are logged.
        """
        self.tar_model.eval()
        losses = AverageCategoryMeter(5)
        maes = AverageCategoryMeter(5)
        val_losses = AverageMeter()
        val_maes = AverageMeter()
        if mode == 'val':
            for i, batch in enumerate(dataset, 1):
                with torch.no_grad():
                    img = batch[0].cuda()
                    label = batch[1].cuda()
                    pred = self.tar_model(img)
                    self.mae_mse_update(pred, label, val_maes, losses=val_losses)
            mae = np.average(val_maes.avg)
            loss = np.average(val_losses.avg)

            self.writer.add_scalar('data/val_mae', mae, self.epoch)
            self.writer.add_scalar('data/val_loss', loss, self.epoch)
            self.tar_model_record = update_model(
                self.tar_model.module, self.epoch, self.exp_path, self.exp_name, [mae, 0, loss],
                self.tar_model_record, self.log_txt)
            print_summary(self.exp_name, [mae, 0, loss], self.tar_model_record)
        else:
            for i_sub, i_loader in enumerate(dataset, 0):
                for i, batch in enumerate(i_loader, 1):
                    with torch.no_grad():
                        img = batch[0].cuda()
                        label = batch[1].cuda()
                        pred = self.tar_model(img)
                        self.mae_mse_update(pred, label, maes=maes, losses=losses, cls_id=i_sub)
                        if i == 1 and self.epoch % 10 == 0:
                            vis_results(self.epoch, self.writer, self.restore_transform,
                                        img, pred.data.cpu().numpy(), label.data.cpu().numpy(), self.exp_name)

            mae = np.average(maes.avg)
            loss = np.average(losses.avg)

            for scene in range(5):
                self.writer.add_scalar("data/mae_s%d" % (scene + 1), maes.avg[scene], self.epoch)

            self.writer.add_scalar("data/test_mae", float(mae), self.epoch)
            self.writer.add_scalar('data/test_loss', float(loss), self.epoch)
            logger_txt(self.log_txt, self.epoch, [mae, 0, loss])
            self.tar_model_record['temp_test_mae'] = mae
            self.tar_model_record['temp_test_mse'] = 0

    def tar_model_V1(self, dataset, mode=None):
        """Evaluate ``tar_model`` on a single target loader.

        ``mode == 'val'`` updates ``tar_model_record`` via ``update_model``;
        ``mode == 'test'`` additionally computes SSIM/PSNR and stores the
        temporary test metrics.
        """
        self.tar_model.eval()
        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()
        ssims = AverageMeter()
        psnrs = AverageMeter()

        for i, batch in enumerate(dataset, 1):
            with torch.no_grad():
                img = batch[0].cuda()
                label = batch[1].cuda()
                pred = self.tar_model(img)
                if mode == 'test':
                    self.mae_mse_update(pred, label, maes, mses, ssims, psnrs, losses)
                else:
                    self.mae_mse_update(pred, label, maes, mses, losses=losses)
                if i == 1 and self.epoch % 10 == 0:
                    vis_results(self.epoch, self.writer, self.restore_transform,
                                img, pred.data.cpu().numpy(), label.cpu().detach().numpy(), self.exp_name)
        mae = maes.avg
        mse = np.sqrt(mses.avg)
        loss = losses.avg

        if mode == "val":
            self.writer.add_scalar('data/val_mae', mae, self.epoch)
            self.writer.add_scalar('data/val_mse', mse, self.epoch)
            self.writer.add_scalar('data/val_loss', loss, self.epoch)
            self.tar_model_record = update_model(
                self.tar_model.module, self.epoch, self.exp_path, self.exp_name, [mae, mse, loss],
                self.tar_model_record, self.log_txt)
            print_summary(self.exp_name, [mae, mse, loss], self.tar_model_record)
        elif mode == "test":
            self.writer.add_scalar('data/test_mae', mae, self.epoch)
            self.writer.add_scalar('data/test_mse', mse, self.epoch)
            self.writer.add_scalar('data/test_loss', loss, self.epoch)
            self.writer.add_scalar("data/test_ssim", ssims.avg, self.epoch)
            self.writer.add_scalar("data/test_psnr", psnrs.avg, self.epoch)

            self.tar_model_record['temp_test_mae'] = mae
            self.tar_model_record['temp_test_mse'] = mse
            logger_txt(self.log_txt, self.epoch, [mae, mse, loss])

    def weight_decay_loss(self, model, lamda):
        """Penalty pulling the NLT modulation parameters towards identity.

        Parameters whose name contains ``nlt_weight`` are pulled towards 1,
        those containing ``nlt_bias`` towards 0 (bias term weighted 10x).

        Returns:
            ``lamda * sum_w + 10 * lamda * sum_b``; a plain ``0`` when the
            model has no matching parameters.
        """
        loss_weight = 0
        loss_bias = 0
        for name, param in model.named_parameters():
            if 'nlt_weight' in name:
                loss_weight += 0.5 * torch.sum(torch.pow(param - 1, 2))
            elif 'nlt_bias' in name:
                loss_bias += 0.5 * torch.sum(torch.pow(param, 2))
        # Return only after every parameter has been visited.
        return lamda * loss_weight + lamda * 10 * loss_bias

    def mae_mse_update(self, pred, label, maes, mses=None, ssims=None, psnrs=None, losses=None, cls_id=None):
        """Accumulate per-image counting metrics for one batch.

        Args:
            pred, label: batch tensors; image ``k`` is ``pred[k]``/``label[k]``.
            maes: meter receiving ``|pred_cnt - gt_cnt|`` per image.
            mses: optional meter receiving the squared count error per image.
            ssims, psnrs: optional meters; SSIM/PSNR are computed only when
                both are provided.
            losses: optional meter receiving the batch MSE loss (once per image).
            cls_id: when given, passed as the category index to the
                ``maes``/``losses``/``mses`` updates.

        Returns:
            ``(pred_cnt, gt_cnt)`` of the last image (``(0, 0)`` for an
            empty batch).
        """
        # The batch loss does not depend on the image index: compute it once.
        batch_loss = None
        if losses is not None:
            batch_loss = F.mse_loss(pred.detach().squeeze(), label.detach().squeeze()).item()
        category = () if cls_id is None else (cls_id,)

        pred_cnt, gt_cnt = 0, 0
        for num in range(pred.size(0)):
            sub_pred = pred[num].detach().cpu().squeeze().numpy() / self.cfg_data.LOG_PARA
            sub_label = label[num].detach().cpu().squeeze().numpy() / self.cfg_data.LOG_PARA
            pred_cnt = np.sum(sub_pred)
            gt_cnt = np.sum(sub_label)
            err = pred_cnt - gt_cnt
            mae = abs(err)
            mse = err * err

            if ssims is not None and psnrs is not None:
                ssims.update(get_ssim(sub_label, sub_pred))
                psnrs.update(get_psnr(sub_label, sub_pred))

            maes.update(mae, *category)
            if losses is not None:
                losses.update(batch_loss, *category)
            if mses is not None:
                mses.update(mse, *category)

        return pred_cnt, gt_cnt
# Example #2
# 0
class Fine_tune_Trainer(object):
    """Fine-tunes a (GCC pre-trained) counter directly on the target shots.

    All hyper-parameters come from the module-level ``cfg`` object; the
    best-model bookkeeping lives in ``self.sou_model_record``.
    """

    def __init__(self, cfg_data, pwd):
        """Build loaders, model, optimiser, scheduler and logger.

        Args:
            cfg_data: dataset config; ``cfg_data.LOG_PARA`` is the factor
                density maps are divided by to obtain counts.
            pwd: directory forwarded to ``logger``.
        """
        self.cfg_data = cfg_data
        self.pwd = pwd
        self.exp_path = osp.join(cfg.EXP_PATH, 'fine_tune')
        self.exp_name = cfg.EXP_NAME
        # makedirs also creates a missing cfg.EXP_PATH parent (os.mkdir would fail).
        os.makedirs(self.exp_path, exist_ok=True)

        (self.sou_query_loader, self.tar_shot_loader, self.tar_val_loader,
         self.tar_test_loader, self.restore_transform) = loading_data(cfg)

        self.sou_model = NLT_Counter(mode='fine_tune', backbone=cfg.model_type)

        self.sou_optimizer = torch.optim.Adam(
            [p for p in self.sou_model.parameters() if p.requires_grad],
            lr=cfg.fine_lr, weight_decay=cfg.fine_weight_decay)

        self.sou_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.sou_optimizer, step_size=cfg.fine_step_size, gamma=cfg.fine_gamma)

        if cfg.GCC_pre_train_model is not None:
            print('load GCC pre_trained model')
            self.pretrained_dict = torch.load(cfg.GCC_pre_train_model)
            self.sou_model.load_state_dict(self.pretrained_dict)

        self.sou_model = torch.nn.DataParallel(self.sou_model).cuda()

        self.sou_model_record = {"best_mae": 1e20, "best_mse": 1e20, "best_model_name": "", "update_flag": 0,
                                 "temp_test_mae": 1e20, "temp_test_mse": 1e20}
        self.epoch = 0
        self.writer, self.log_txt = logger(self.exp_path, self.exp_name, self.pwd, ["exp"])

    def forward(self):
        """Run ``cfg.max_epoch`` epochs of fine-tuning with periodic validation."""
        timer = Timer()
        self.global_count = 0
        for epoch in range(1, cfg.max_epoch + 1):
            # Set the epoch first so train() logs under the current epoch.
            self.epoch = epoch
            self.train()
            if self.epoch % cfg.val_freq == 0:
                self.sou_model_V1(self.tar_val_loader, "val")
                print('=' * 50)
                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(), timer.measure(self.epoch / cfg.max_epoch)))
            self.sou_lr_scheduler.step()
        self.writer.close()

    def train(self):
        """One epoch of plain MSE fine-tuning over ``tar_shot_loader``."""
        self.sou_model.train()

        train_loss = AverageMeter()
        train_mae = AverageMeter()
        train_mse = AverageMeter()

        for i, (img, gt_map) in enumerate(self.tar_shot_loader, 1):
            self.global_count = self.global_count + 1
            shot_img, shot_label = img.cuda(), gt_map.cuda()

            shot_pred = self.sou_model(shot_img)
            loss = F.mse_loss(shot_pred.squeeze(), shot_label.squeeze())
            self.sou_optimizer.zero_grad()
            loss.backward()
            self.sou_optimizer.step()
            train_loss.update(loss.item())
            self.writer.add_scalar('data/fine_tune_loss', float(loss), self.global_count)
            sou_pred_cnt, sou_label_cnt = self.mae_mse_update(shot_pred, shot_label, train_mae, train_mse)

            if i % 50 == 0:
                print('Epoch {}, Loss={:.4f} s_gt={:.1f} s_pre={:.1f}'.format(
                    self.epoch, loss.item(), sou_label_cnt, sou_pred_cnt))

        self.writer.add_scalar('data/train_loss_tar', float(train_loss.avg), self.epoch)
        self.writer.add_scalar('data/train_mae_tar', float(train_mae.avg), self.epoch)
        self.writer.add_scalar('data/train_mse_tar', float(np.sqrt(train_mse.avg)), self.epoch)

    def validation(self):
        """Per-scene validation for the ``WE`` / ``SHFD`` targets.

        ``tar_val_loader`` is iterated as a sequence of 5 per-scene loaders.
        Does nothing (besides switching to eval mode) for other targets.
        """
        self.sou_model.eval()
        if cfg.target_dataset not in ["WE", "SHFD"]:
            return

        val_loss = AverageCategoryMeter(5)
        val_mae = AverageCategoryMeter(5)
        for i_sub, i_loader in enumerate(self.tar_val_loader, 0):
            for i, batch in enumerate(tqdm.tqdm(i_loader), 1):
                img = batch[0].cuda()
                gt_map = batch[1].cuda()
                with torch.no_grad():
                    pred = self.sou_model(inp=img)
                    self.mae_mse_update(pred, gt_map, val_mae, losses=val_loss, cls_id=i_sub)
                    if i == 1:
                        vis_results(self.epoch, self.writer, self.restore_transform,
                                    img, pred.data.cpu().numpy(), gt_map.data.cpu().numpy(), 'temp_val/sou')

        mae = np.average(val_mae.avg)
        loss = np.average(val_loss.avg)

        for scene in range(5):
            self.writer.add_scalar("data/mae_s%d" % (scene + 1), val_mae.avg[scene], self.epoch)

        self.writer.add_scalar("data/tar_val_mae", float(mae), self.epoch)
        self.writer.add_scalar('data/tar_val_loss', float(loss), self.epoch)

        # This class keeps its bookkeeping in ``sou_model_record`` (there is
        # no ``self.record``); the history lists are created on first use.
        self.sou_model_record = update_model(
            self.sou_model.module, self.epoch, self.exp_path, self.exp_name, [mae, 0, loss],
            self.sou_model_record, self.log_txt)
        print('Epoch {}, Val, mae={:.2f} mse={:.2f}'.format(self.epoch, mae, 0))
        self.sou_model_record.setdefault('val_loss', []).append(loss)
        self.sou_model_record.setdefault('val_mae', []).append(mae)

    def sou_model_V1(self, dataset, mode=None):
        """Evaluate ``sou_model`` on a single target loader.

        ``mode == 'val'`` updates ``sou_model_record`` via ``update_model``;
        ``mode == 'test'`` additionally computes SSIM/PSNR and stores the
        temporary test metrics in ``sou_model_record``.
        """
        self.sou_model.eval()
        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()
        ssims = AverageMeter()
        psnrs = AverageMeter()

        for i, batch in enumerate(dataset, 1):
            with torch.no_grad():
                img = batch[0].cuda()
                label = batch[1].cuda()
                pred = self.sou_model(img)
                if mode == 'test':
                    self.mae_mse_update(pred, label, maes, mses, ssims, psnrs, losses)
                else:
                    self.mae_mse_update(pred, label, maes, mses, losses=losses)
                if i == 1 and self.epoch % 10 == 0:
                    vis_results(self.epoch, self.writer, self.restore_transform,
                                img, pred.data.cpu().numpy(), label.cpu().detach().numpy(), self.exp_name)
        mae = maes.avg
        mse = np.sqrt(mses.avg)
        loss = losses.avg

        if mode == "val":
            self.writer.add_scalar('data/val_mae', mae, self.epoch)
            self.writer.add_scalar('data/val_mse', mse, self.epoch)
            self.writer.add_scalar('data/val_loss', loss, self.epoch)
            self.sou_model_record = update_model(
                self.sou_model.module, self.epoch, self.exp_path, self.exp_name, [mae, mse, loss],
                self.sou_model_record, self.log_txt)
            print_summary(self.exp_name, [mae, mse, loss], self.sou_model_record)
        elif mode == "test":
            self.writer.add_scalar('data/test_mae', mae, self.epoch)
            self.writer.add_scalar('data/test_mse', mse, self.epoch)
            self.writer.add_scalar('data/test_loss', loss, self.epoch)
            self.writer.add_scalar("data/test_ssim", ssims.avg, self.epoch)
            self.writer.add_scalar("data/test_psnr", psnrs.avg, self.epoch)

            self.sou_model_record['temp_test_mae'] = mae
            self.sou_model_record['temp_test_mse'] = mse
            logger_txt(self.log_txt, self.epoch, [mae, mse, loss])

    def weight_decay_loss(self, model, lamda):
        """Penalty pulling ``mtl_weight`` params towards 1 and ``mtl_bias`` towards 0.

        Returns:
            ``lamda * (sum_w + sum_b)``; a plain ``0`` when the model has no
            matching parameters.
        """
        loss_weight = 0
        loss_bias = 0
        for name, param in model.named_parameters():
            if 'mtl_weight' in name:
                loss_weight += 0.5 * torch.sum(torch.pow(param - 1, 2))
            elif 'mtl_bias' in name:
                loss_bias += 0.5 * torch.sum(torch.pow(param, 2))
        # Return only after every parameter has been visited.
        return lamda * loss_weight + lamda * loss_bias

    def mae_mse_update(self, pred, label, maes, mses=None, ssims=None, psnrs=None, losses=None, cls_id=None):
        """Accumulate per-image counting metrics for one batch.

        Args:
            pred, label: batch tensors; image ``k`` is ``pred[k]``/``label[k]``.
            maes: meter receiving ``|pred_cnt - gt_cnt|`` per image.
            mses: optional meter receiving the squared count error per image.
            ssims, psnrs: optional meters; SSIM/PSNR are computed only when
                both are provided.
            losses: optional meter receiving the batch MSE loss (once per image).
            cls_id: when given, passed as the category index to the
                ``maes``/``losses``/``mses`` updates.

        Returns:
            ``(pred_cnt, gt_cnt)`` of the last image (``(0, 0)`` for an
            empty batch).
        """
        # The batch loss does not depend on the image index: compute it once.
        batch_loss = None
        if losses is not None:
            batch_loss = F.mse_loss(pred.detach().squeeze(), label.detach().squeeze()).item()
        category = () if cls_id is None else (cls_id,)

        pred_cnt, gt_cnt = 0, 0
        for num in range(pred.size(0)):
            sub_pred = pred[num].detach().cpu().squeeze().numpy() / self.cfg_data.LOG_PARA
            sub_label = label[num].detach().cpu().squeeze().numpy() / self.cfg_data.LOG_PARA
            pred_cnt = np.sum(sub_pred)
            gt_cnt = np.sum(sub_label)
            err = pred_cnt - gt_cnt
            mae = abs(err)
            mse = err * err

            if ssims is not None and psnrs is not None:
                ssims.update(get_ssim(sub_label, sub_pred))
                psnrs.update(get_psnr(sub_label, sub_pred))

            maes.update(mae, *category)
            if losses is not None:
                losses.update(batch_loss, *category)
            if mses is not None:
                mses.update(mse, *category)

        return pred_cnt, gt_cnt
# Example #3
# 0
class PreTrainer(object):
    """The class that contains the code for the pretrain phase (GCC)."""

    def __init__(self, cfg, pwd):
        """Build loaders, model, optimiser, scheduler, record and logger.

        Args:
            cfg: experiment config (shadows the module-level ``cfg`` here).
            pwd: directory forwarded to ``logger``.

        NOTE(review): ``cfg_data`` is not a parameter of this method, so it
        must resolve to a module-level name -- confirm it is defined.
        """
        # Set before any CUDA query (torch.cuda.is_available, torch.load of a
        # GPU checkpoint) so that the device mask actually takes effect.
        os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu_id

        self.cfg_data = cfg_data
        self.pwd = pwd
        self.exp_path = osp.join(cfg.EXP_PATH, 'pre')
        self.exp_name = cfg.EXP_NAME
        self.train_loader, self.val_loader, self.restore_transform = loading_data(cfg)

        self.model = NLT_Counter(mode='pre', backbone=cfg.model_type)
        if cfg.init_weights is not None:
            self.pretrained_dict = torch.load(cfg.init_weights, map_location='cpu')
            self.model.load_state_dict(self.pretrained_dict)

        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=cfg.pre_lr,
                                          weight_decay=cfg.pre_weight_decay)

        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=cfg.pre_step_size, gamma=cfg.pre_gamma)

        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            self.model = torch.nn.DataParallel(self.model).cuda()

        self.record = {
            'train_loss': [],
            'train_mae': [],
            'train_mse': [],
            'val_loss': [],
            'val_mae': [],
            'val_mse': [],
            'best_mae': 1e10,
            'best_mse': 1e10,
            'best_model_name': '',
            'update_flag': 0,
        }

        self.writer, self.log_txt = logger(self.exp_path, self.exp_name,
                                           self.pwd, ["exp"])

    def save_model(self, name):
        """Save the unwrapped model's state dict as ``{'params': state_dict}``.

        The file is ``<exp_path>/<exp_name>/<name>.pth``; the directory is
        created if missing.
        """
        out_dir = osp.join(self.exp_path, self.exp_name)
        os.makedirs(out_dir, exist_ok=True)
        torch.save(dict(params=self.model.module.state_dict()),
                   osp.join(out_dir, name + '.pth'))

    def train(self):
        """The function for the pre_train on GCC dataset.

        Runs ``cfg.pre_max_epoch`` epochs; each epoch trains, validates, logs,
        updates ``self.record`` and steps the LR scheduler once.
        """
        timer = Timer()
        global_count = 0
        for epoch in range(1, cfg.pre_max_epoch + 1):
            train_loss_avg, train_mae_avg, train_mse_avg, global_count = \
                self._train_one_epoch(epoch, global_count)

            self.writer.add_scalar('data/loss', train_loss_avg, global_count)
            self.writer.add_scalar('data/mae', train_mae_avg, global_count)
            self.writer.add_scalar('data/mse', train_mse_avg, global_count)

            if epoch % 10 == 0:
                print('Best Epoch {}, Best Val mae={:.2f} mse={:.2f}'.format(
                    self.record['best_model_name'], self.record['best_mae'],
                    self.record['best_mse']))

            val_loss_avg, val_mae_avg, val_mse_avg = self._validate()

            self.writer.add_scalar('data/val_loss', float(val_loss_avg), epoch)
            self.writer.add_scalar('data/val_mae', float(val_mae_avg), epoch)
            self.writer.add_scalar('data/val_mse', float(val_mse_avg), epoch)
            print('Epoch {}, Val, Loss={:.4f} mae={:.4f}  mse={:.4f}'.format(
                epoch, val_loss_avg, val_mae_avg, val_mse_avg))

            # Save a snapshot every 10 epochs.
            if epoch % 10 == 0:
                self.save_model('epoch' + str(epoch) + '_' + str(val_mae_avg))

            self.record['train_loss'].append(train_loss_avg)
            self.record['train_mae'].append(train_mae_avg)
            self.record['train_mse'].append(train_mse_avg)
            self.record['val_loss'].append(val_loss_avg)
            self.record['val_mae'].append(val_mae_avg)
            self.record['val_mse'].append(val_mse_avg)

            self.record = update_model(
                self.model.module, epoch, self.exp_path, self.exp_name,
                [val_mae_avg, val_mse_avg, val_loss_avg], self.record,
                self.log_txt)

            if epoch % 10 == 0:
                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(), timer.measure(epoch / cfg.pre_max_epoch)))

            # StepLR's step_size counts step() calls, so step once per epoch
            # (inside the loop) for the decay schedule to apply during training.
            self.lr_scheduler.step()
        self.writer.close()

    def _train_one_epoch(self, epoch, global_count):
        """Run one training epoch; return ``(loss, mae, rmse, global_count)``."""
        self.model.train()
        loss_avg = Averager()
        mae_avg = Averager()
        mse_avg = Averager()

        tqdm_gen = tqdm.tqdm(self.train_loader)
        for batch in tqdm_gen:
            global_count += 1
            img = batch[0].cuda()
            label = batch[1].cuda()

            pred = self.model(img)
            # Squeeze both sides so a (B, 1, H, W) label cannot broadcast
            # against the (B, H, W) prediction.
            loss = F.mse_loss(pred.squeeze(), label.squeeze())

            # Batch-level counts: the whole batch is summed.
            label_cnt = label.sum().detach() / self.cfg_data.LOG_PARA
            pred_cnt = pred.sum().detach() / self.cfg_data.LOG_PARA
            mae = torch.abs(label_cnt - pred_cnt).item()
            mse = (label_cnt - pred_cnt).pow(2).item()

            tqdm_gen.set_description(
                'Epoch {}, Loss={:.4f} gt={:.1f} pred={:.1f} lr={:.4f}'.
                format(epoch, loss.item(), label_cnt.item(), pred_cnt.item(),
                       self.optimizer.param_groups[0]['lr'] * 10000))
            loss_avg.add(loss.item())
            mae_avg.add(mae)
            mse_avg.add(mse)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        return loss_avg.item(), mae_avg.item(), np.sqrt(mse_avg.item()), global_count

    def _validate(self):
        """Evaluate on ``val_loader``; return ``(loss, mae, rmse)`` per image."""
        self.model.eval()
        loss_avg = Averager()
        mae_avg = Averager()
        mse_avg = Averager()
        with torch.no_grad():
            for batch in self.val_loader:
                data = batch[0].cuda()
                label = batch[1].cuda()
                pred = self.model(inp=data)
                loss = F.mse_loss(pred.squeeze(), label.squeeze())
                loss_avg.add(loss.item())

                for k in range(pred.size(0)):
                    pred_cnt = (pred[k] / self.cfg_data.LOG_PARA).sum()
                    gt_cnt = (label[k] / self.cfg_data.LOG_PARA).sum()
                    mae_avg.add(torch.abs(pred_cnt - gt_cnt).item())
                    mse_avg.add((pred_cnt - gt_cnt).pow(2).item())

        return loss_avg.item(), mae_avg.item(), np.sqrt(mse_avg.item())
# Example #4
# 0
class den_test:
    """Evaluate a trained NLT counter on a list of target images.

    For every name in the list file it reads ``<tarRoot>/train/img/<name>.jpg``
    and ``<tarRoot>/train/den/<name>.csv``, accumulates MAE / MSE / SSIM /
    PSNR, and saves GT and predicted density-map heat-maps.

    NOTE(review): ``forward`` writes figures under ``den_path``, which is not
    defined in this class -- presumably a module-level output directory.
    """

    def __init__(self, model_path, tar_list, tarRoot, cfg_data, img_transform):
        """
        Args:
            model_path: checkpoint loaded into an ``NLT_Counter(mode='nlt')``.
            tar_list: text file with one image name (no extension) per line.
            tarRoot: dataset root containing ``train/img`` and ``train/den``.
            cfg_data: dataset config; ``LOG_PARA`` divides the predicted map.
            img_transform: callable turning a PIL image into a CHW tensor.
        """
        self.cfg_data = cfg_data
        self.img_transform = img_transform
        self.tarRoot = tarRoot
        self.net = NLT_Counter(mode='nlt', backbone=cfg.model_type)
        self.net.load_state_dict(torch.load(model_path, map_location='cpu'))
        self.net = torch.nn.DataParallel(self.net).cuda()
        self.net.eval()
        self.tar_list = self._read_name_list(tar_list)

    @staticmethod
    def _read_name_list(list_path):
        """Return the non-blank, whitespace-stripped lines of *list_path*."""
        with open(list_path) as f:
            return [name for name in (line.strip() for line in f) if name]

    @staticmethod
    def _save_density_map(density, out_file):
        """Save *density* as a borderless, axis-free 'jet' heat-map image."""
        frame = plt.gca()
        plt.imshow(density, cmap='jet')
        frame.axes.get_yaxis().set_visible(False)
        frame.axes.get_xaxis().set_visible(False)
        for side in ('top', 'bottom', 'left', 'right'):
            frame.spines[side].set_visible(False)
        plt.savefig(out_file, bbox_inches='tight', pad_inches=0, dpi=600)
        plt.close()

    def forward(self):
        """Evaluate every listed image, print the averaged scores and return them.

        Returns:
            dict with averaged ``'MAE'``, ``'MSE'`` (root of the mean squared
            error), ``'SSIM'`` and ``'PSNR'``.

        Raises:
            ValueError: if the target list is empty.
        """
        if not self.tar_list:
            raise ValueError('den_test: the target list is empty')

        log_para = self.cfg_data.LOG_PARA
        score = {'MAE': 0, 'MSE': 0, 'PSNR': 0, 'SSIM': 0}
        count = 0
        for fname in tqdm.tqdm(self.tar_list):
            count += 1
            # Build both paths explicitly; str.replace('img', 'den') would
            # also rewrite any 'img'/'jpg' occurring in tarRoot or fname.
            imgname = os.path.join(self.tarRoot, 'train', 'img', fname + '.jpg')
            denname = os.path.join(self.tarRoot, 'train', 'den', fname + '.csv')

            den = pd.read_csv(denname, sep=',', header=None).values
            den = den.astype(np.float32, copy=False)
            with Image.open(imgname) as pil_img:
                if pil_img.mode != 'RGB':
                    pil_img = pil_img.convert('RGB')
                img = self.img_transform(pil_img)
            gt = np.sum(den)
            img = img[None, :, :, :].cuda()

            # Inference only: no autograd graph needed.
            with torch.no_grad():
                pred_map = self.net(img)

            pred_map = pred_map.cpu().numpy()[0, 0, :, :]
            pred = np.sum(pred_map) / log_para

            score['MAE'] += np.abs(gt - pred)
            score['MSE'] += (gt - pred) * (gt - pred)
            # The GT density is in count units (gt = sum(den)) while the
            # network output is scaled by LOG_PARA: compare on the same scale.
            score['SSIM'] += get_ssim(den, pred_map / log_para)
            score['PSNR'] += get_psnr(den, pred_map / log_para)

            # Max-normalise both maps for visualisation only.
            self._save_density_map(
                den / np.max(den + 1e-20),
                den_path + '/' + fname + '_gt_' + str(int(gt)) + '.png')
            self._save_density_map(
                pred_map / np.max(pred_map + 1e-20),
                den_path + '/' + fname + '_DA_' + str(float(pred)) + '.png')

        score['MAE'] = score['MAE'] / count
        score['MSE'] = np.sqrt(score['MSE'] / count)
        score['SSIM'] = score['SSIM'] / count
        score['PSNR'] = score['PSNR'] / count

        print("processed   MAE_in: %.2f  MSE_in: %.2f" %
              (score['MAE'], score['MSE']))
        print("processed   PSNR: %.2f  SSIM: %.2f" %
              (score['PSNR'], score['SSIM']))
        return score