# Assumed imports (the original snippet omits its header; these are inferred
# from the calls below). The repo-local helpers -- NLT_Counter, loading_data,
# logger, logger_txt, update_model, print_summary, vis_results, Timer,
# Averager, AverageMeter, AverageCategoryMeter, get_ssim, get_psnr -- and the
# global `cfg` / `cfg_data` configs are assumed to come from the project's
# own packages.
import os
import os.path as osp

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import tqdm
from PIL import Image


class NLT_Trainer(object):
    """Joint source/target training with neuron linear transformation (NLT)."""

    def __init__(self, cfg_data, pwd):
        self.cfg_data = cfg_data
        self.pwd = pwd
        self.exp_path = cfg.EXP_PATH
        self.exp_name = cfg.EXP_NAME
        if not osp.exists(self.exp_path):
            os.makedirs(self.exp_path)

        (self.sou_loader, self.tar_shot_loader, self.tar_val_loader,
         self.tar_test_loader, self.restore_transform) = loading_data(cfg)

        self.sou_model = NLT_Counter(backbone=cfg.model_type)
        self.tar_model = NLT_Counter(mode='nlt', backbone=cfg.model_type)

        self.sou_optimizer = torch.optim.Adam(self.sou_model.parameters(),
                                              lr=cfg.nlt_lr,
                                              weight_decay=cfg.nlt_lr_decay)
        self.tar_optimizer = torch.optim.Adam(
            [{'params': filter(lambda p: p.requires_grad, self.tar_model.encoder.parameters()), 'lr': cfg.nlt_lr},
             {'params': filter(lambda p: p.requires_grad, self.tar_model.decoder.parameters()), 'lr': cfg.nlt_lr}])
        self.sou_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.sou_optimizer, step_size=cfg.step_size, gamma=cfg.gamma)
        self.tar_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.tar_optimizer, step_size=cfg.step_size, gamma=cfg.gamma)

        # Initialize both models from the same pre-trained weights.
        if cfg.init_weights is not None:
            self.pretrained_dict = torch.load(cfg.init_weights)  # ['params']
            self.sou_model.load_state_dict(self.pretrained_dict, strict=False)
            self.tar_model.load_state_dict(self.pretrained_dict, strict=False)

        os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu_id
        self.sou_model = torch.nn.DataParallel(self.sou_model).cuda()
        self.tar_model = torch.nn.DataParallel(self.tar_model).cuda()

        self.tar_model_record = {"best_mae": 1e20, "best_mse": 1e20,
                                 "best_model_name": "", "update_flag": 0,
                                 "temp_test_mae": 1e20, "temp_test_mse": 1e20}
        self.sou_model_record = {"best_mae": 1e20, "best_mse": 1e20,
                                 "best_model_name": "", "update_flag": 0,
                                 "temp_test_mae": 1e20, "temp_test_mse": 1e20}
        self.epoch = 0
        self.writer, self.log_txt = logger(self.exp_path, self.exp_name,
                                           self.pwd, ["exp"])

    def forward(self):
        timer = Timer()
        self.global_count = 0
        for epoch in range(1, cfg.max_epoch + 1):
            self.epoch = epoch
            self.train()
            self.fine_tune()
            if self.epoch % cfg.val_freq == 0:
                # WorldExpo'10 is evaluated per scene; the other target sets
                # use a single flat loader.
                if cfg.target_dataset == "WE":
                    self.tar_model_V2(self.tar_val_loader, "val")
                if cfg.target_dataset in ["VENICE", "QNRF", "SHHA", "SHHB", "MALL", "UCSD"]:
                    self.tar_model_V1(self.tar_val_loader, "val")
            print('=' * 50)
            print('Running Time: {}, Estimated Time: {}'.format(
                timer.measure(), timer.measure(self.epoch / cfg.max_epoch)))
            self.sou_lr_scheduler.step()
            self.tar_lr_scheduler.step()
        self.writer.close()

    def train(self):
        self.sou_model.train()
        self.tar_model.train()
        train_loss = AverageMeter()
        train_mae = AverageMeter()
        train_mse = AverageMeter()
        shot_loss = AverageMeter()
        shot_mae = AverageMeter()
        shot_mse = AverageMeter()
        for i, (a, b) in enumerate(zip(self.sou_loader, self.tar_shot_loader), 1):
            self.global_count = self.global_count + 1
            sou_img, sou_label = a[0].cuda(), a[1].cuda()
            shot_img, shot_label = b[0].cuda(), b[1].cuda()

            if self.epoch < cfg.DA_stop_epoch:
                # ================ update sou_model parameters ================
                sou_pred = self.sou_model(sou_img)
                loss = F.mse_loss(sou_pred.squeeze(), sou_label.squeeze())
                self.sou_optimizer.zero_grad()
                loss.backward()
                self.sou_optimizer.step()
                self.writer.add_scalar('data/sou_loss', loss.item(), self.global_count)
                train_loss.update(loss.item())
                sou_pred_cnt, sou_label_cnt = self.mae_mse_update(
                    sou_pred, sou_label, train_mae, train_mse)
                # The target model inherits the updated source weights
                # (strict=False leaves its NLT parameters untouched).
                self.tar_model.load_state_dict(self.sou_model.state_dict(), strict=False)
                # =============================================================
            else:
                sou_label_cnt = 0
                sou_pred_cnt = 0

            # ================ update tar_model parameters ====================
            shot_pred = self.tar_model(shot_img)
            loss_mse = F.mse_loss(shot_pred.squeeze(), shot_label.squeeze())
            loss = self.weight_decay_loss(self.tar_model, 1e-4) + loss_mse
            self.tar_optimizer.zero_grad()
            loss.backward()
            self.tar_optimizer.step()
            self.writer.add_scalar('data/shot_loss', loss.item(), self.global_count)
            shot_loss.update(loss.item())
            pred_cnt, label_cnt = self.mae_mse_update(shot_pred, shot_label,
                                                      shot_mae, shot_mse)
            # =================================================================

            if i % cfg.print_freq == 0:
                print('Epoch {}, Loss={:.4f} s_gt={:.1f} s_pre={:.1f} t_gt={:.1f} t_pre={:.1f} lr={:.4f}'.format(
                    self.epoch, loss.item(), sou_label_cnt, sou_pred_cnt,
                    label_cnt, pred_cnt,
                    self.sou_optimizer.param_groups[0]['lr'] * 10000))

        self.writer.add_scalar('data/train_loss_tar', float(shot_loss.avg), self.epoch)
        self.writer.add_scalar('data/train_mae_tar', float(shot_mae.avg), self.epoch)
        self.writer.add_scalar('data/train_mse_tar', float(np.sqrt(shot_mse.avg)), self.epoch)
        self.writer.add_scalar('data/train_loss_sou', float(train_loss.avg), self.epoch)
        self.writer.add_scalar('data/train_mae_sou', float(train_mae.avg), self.epoch)
        self.writer.add_scalar('data/train_mse_sou', float(np.sqrt(train_mse.avg)), self.epoch)

    def fine_tune(self):
        # A short adaptation pass: at most 50 batches from the target shot set.
        for i, (shot_img, shot_label) in enumerate(self.tar_shot_loader, 1):
            if i <= 50:
                shot_img = shot_img.cuda()
                shot_label = shot_label.cuda()
                shot_pred = self.tar_model(shot_img)
                loss_mse = F.mse_loss(shot_pred.squeeze(), shot_label.squeeze())
                loss = self.weight_decay_loss(self.tar_model, 1e-4) + loss_mse
                self.tar_optimizer.zero_grad()
                loss.backward()
                self.tar_optimizer.step()
            else:
                break

    def tar_model_V2(self, dataset, mode=None):
        # Run validation/testing on scene-grouped datasets (e.g. WorldExpo'10).
        self.tar_model.eval()
        losses = AverageCategoryMeter(5)
        maes = AverageCategoryMeter(5)
        val_losses = AverageMeter()
        val_maes = AverageMeter()
        if mode == 'val':
            for i, batch in enumerate(dataset, 1):
                with torch.no_grad():
                    img = batch[0].cuda()
                    label = batch[1].cuda()
                    pred = self.tar_model(img)
                    self.mae_mse_update(pred, label, val_maes, losses=val_losses)
            mae = np.average(val_maes.avg)
            loss = np.average(val_losses.avg)
            self.writer.add_scalar('data/val_mae', mae, self.epoch)
            self.writer.add_scalar('data/val_loss', loss, self.epoch)
            self.tar_model_record = update_model(
                self.tar_model.module, self.epoch, self.exp_path, self.exp_name,
                [mae, 0, loss], self.tar_model_record, self.log_txt)
            print_summary(self.exp_name, [mae, 0, loss], self.tar_model_record)
        else:
            for i_sub, i_loader in enumerate(dataset, 0):
                for i, batch in enumerate(i_loader, 1):
                    with torch.no_grad():
                        img = batch[0].cuda()
                        label = batch[1].cuda()
                        pred = self.tar_model(img)
                        self.mae_mse_update(pred, label, maes=maes,
                                            losses=losses, cls_id=i_sub)
                        if i == 1 and self.epoch % 10 == 0:
                            vis_results(self.epoch, self.writer, self.restore_transform,
                                        img, pred.data.cpu().numpy(),
                                        label.data.cpu().numpy(), self.exp_name)
            mae = np.average(maes.avg)
            loss = np.average(losses.avg)
            self.writer.add_scalar("data/mae_s1", maes.avg[0], self.epoch)
            self.writer.add_scalar("data/mae_s2", maes.avg[1], self.epoch)
            self.writer.add_scalar("data/mae_s3", maes.avg[2], self.epoch)
            self.writer.add_scalar("data/mae_s4", maes.avg[3], self.epoch)
            self.writer.add_scalar("data/mae_s5", maes.avg[4], self.epoch)
            self.writer.add_scalar("data/test_mae", float(mae), self.epoch)
            self.writer.add_scalar('data/test_loss', float(loss), self.epoch)
            logger_txt(self.log_txt, self.epoch, [mae, 0, loss])
            self.tar_model_record['temp_test_mae'] = mae
            self.tar_model_record['temp_test_mse'] = 0

    def tar_model_V1(self, dataset, mode=None):
        # Run validation/testing on single-scene datasets.
        self.tar_model.eval()
        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()
        ssims = AverageMeter()
        psnrs = AverageMeter()
        for i, batch in enumerate(dataset, 1):
            with torch.no_grad():
                img = batch[0].cuda()
                label = batch[1].cuda()
                pred = self.tar_model(img)
                if mode == 'test':
                    self.mae_mse_update(pred, label, maes, mses, ssims, psnrs, losses)
                else:
                    self.mae_mse_update(pred, label, maes, mses, losses=losses)
                if i == 1 and self.epoch % 10 == 0:
                    vis_results(self.epoch, self.writer, self.restore_transform,
                                img, pred.data.cpu().numpy(),
                                label.cpu().detach().numpy(), self.exp_name)
        mae = maes.avg
        mse = np.sqrt(mses.avg)
        loss = losses.avg
        if mode == "val":
            self.writer.add_scalar('data/val_mae', mae, self.epoch)
            self.writer.add_scalar('data/val_mse', mse, self.epoch)
            self.writer.add_scalar('data/val_loss', loss, self.epoch)
            self.tar_model_record = update_model(
                self.tar_model.module, self.epoch, self.exp_path, self.exp_name,
                [mae, mse, loss], self.tar_model_record, self.log_txt)
            print_summary(self.exp_name, [mae, mse, loss], self.tar_model_record)
        elif mode == "test":
            self.writer.add_scalar('data/test_mae', mae, self.epoch)
            self.writer.add_scalar('data/test_mse', mse, self.epoch)
            self.writer.add_scalar('data/test_loss', loss, self.epoch)
            self.writer.add_scalar("data/test_ssim", ssims.avg, self.epoch)
            self.writer.add_scalar("data/test_psnr", psnrs.avg, self.epoch)
            self.tar_model_record['temp_test_mae'] = mae
            self.tar_model_record['temp_test_mse'] = mse
            logger_txt(self.log_txt, self.epoch, [mae, mse, loss])

    def weight_decay_loss(self, model, lamda):
        # Pull the NLT scales toward 1 and the NLT biases toward 0:
        # lamda * 0.5 * sum((w - 1)^2) + 10 * lamda * 0.5 * sum(b^2)
        loss_weight = 0
        loss_bias = 0
        for name, param in model.named_parameters():
            if 'nlt_weight' in name:
                loss_weight += 0.5 * torch.sum(torch.pow(param - 1, 2))
            elif 'nlt_bias' in name:
                loss_bias += 0.5 * torch.sum(torch.pow(param, 2))
        return lamda * loss_weight + lamda * 10 * loss_bias

    def mae_mse_update(self, pred, label, maes, mses=None, ssims=None,
                       psnrs=None, losses=None, cls_id=None):
        # Accumulate per-image counting errors; counts are recovered by
        # dividing the log-scaled maps by LOG_PARA.
        for num in range(pred.size()[0]):
            sub_pred = pred[num].data.cpu().squeeze().numpy() / self.cfg_data.LOG_PARA
            sub_label = label[num].data.cpu().squeeze().numpy() / self.cfg_data.LOG_PARA
            pred_cnt = np.sum(sub_pred)
            gt_cnt = np.sum(sub_label)
            mae = abs(pred_cnt - gt_cnt)
            mse = (pred_cnt - gt_cnt) * (pred_cnt - gt_cnt)
            if ssims is not None and psnrs is not None:
                ssims.update(get_ssim(sub_label, sub_pred))
                psnrs.update(get_psnr(sub_label, sub_pred))
            if cls_id is not None:
                maes.update(mae, cls_id)
                if losses is not None:
                    loss = F.mse_loss(pred.detach().squeeze(), label.detach().squeeze())
                    losses.update(loss.item(), cls_id)
                if mses is not None:
                    mses.update(mse, cls_id)
            else:
                maes.update(mae)
                if losses is not None:
                    loss = F.mse_loss(pred.detach().squeeze(), label.detach().squeeze())
                    losses.update(loss.item())
                if mses is not None:
                    mses.update(mse)
        return pred_cnt, gt_cnt
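# Usage sketch (not part of the original file): a minimal driver for the NLT
# stage, assuming the repo's launch script has already populated the global
# `cfg` and the dataset-specific `cfg_data`.
def run_nlt_stage():
    pwd = os.path.split(os.path.realpath(__file__))[0]
    trainer = NLT_Trainer(cfg_data, pwd)
    # Each epoch: train() on source + shot batches, fine_tune() on the shot
    # set, then validate every cfg.val_freq epochs.
    trainer.forward()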
class PreTrainer(object):
    """The class that contains the code for the pretrain phase."""

    def __init__(self, cfg, pwd):
        # Set the folder to save the records and checkpoints,
        # and keep cfg shareable inside the class.
        self.cfg_data = cfg_data
        self.pwd = pwd
        self.exp_path = cfg.EXP_PATH
        self.exp_name = cfg.EXP_NAME
        self.exp_path = osp.join(self.exp_path, 'pre')

        self.train_loader, self.val_loader, self.restore_transform = loading_data(cfg)

        self.model = NLT_Counter(mode='pre', backbone=cfg.model_type)
        if cfg.init_weights is not None:
            self.pretrained_dict = torch.load(cfg.init_weights)  # ['params']
            self.model.load_state_dict(self.pretrained_dict)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.pre_lr,
                                          weight_decay=cfg.pre_weight_decay)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=cfg.pre_step_size, gamma=cfg.pre_gamma)

        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            os.environ["CUDA_VISIBLE_DEVICES"] = cfg.gpu_id
            self.model = torch.nn.DataParallel(self.model).cuda()

        self.record = {'train_loss': [], 'train_mae': [], 'train_mse': [],
                       'val_loss': [], 'val_mae': [], 'val_mse': [],
                       'best_mae': 1e10, 'best_mse': 1e10,
                       'best_model_name': '', 'update_flag': 0}
        self.writer, self.log_txt = logger(self.exp_path, self.exp_name,
                                           self.pwd, ["exp"])

    def save_model(self, name):
        torch.save(dict(params=self.model.module.state_dict()),
                   osp.join(self.exp_path, self.exp_name, name + '.pth'))

    def train(self):
        """The function for pre-training on the GCC dataset."""
        timer = Timer()
        global_count = 0
        for epoch in range(1, cfg.pre_max_epoch + 1):
            self.model.train()
            train_loss_avg = Averager()
            train_mae_avg = Averager()
            train_mse_avg = Averager()

            # Read samples from the train loader with a tqdm progress bar.
            tqdm_gen = tqdm.tqdm(self.train_loader)
            for i, batch in enumerate(tqdm_gen, 1):
                global_count = global_count + 1
                img = batch[0].cuda()
                label = batch[1].cuda()
                pred = self.model(img)
                loss = F.mse_loss(pred.squeeze(), label)

                # Report loss and counts for this step.
                label_cnt = label.sum().data / self.cfg_data.LOG_PARA
                pred_cnt = pred.sum().data / self.cfg_data.LOG_PARA
                mae = torch.abs(label_cnt - pred_cnt).item()
                mse = (label_cnt - pred_cnt).pow(2).item()
                tqdm_gen.set_description(
                    'Epoch {}, Loss={:.4f} gt={:.1f} pred={:.1f} lr={:.4f}'.format(
                        epoch, loss.item(), label_cnt, pred_cnt,
                        self.optimizer.param_groups[0]['lr'] * 10000))

                # Accumulate loss and errors in the averagers.
                train_loss_avg.add(loss.item())
                train_mae_avg.add(mae)
                train_mse_avg.add(mse)

                # Loss backwards and optimizer updates.
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Collapse the averagers to scalars.
            train_loss_avg = train_loss_avg.item()
            train_mae_avg = train_mae_avg.item()
            train_mse_avg = np.sqrt(train_mse_avg.item())
            self.writer.add_scalar('data/loss', train_loss_avg, global_count)
            self.writer.add_scalar('data/mae', train_mae_avg, global_count)
            self.writer.add_scalar('data/mse', train_mse_avg, global_count)

            # Start validation for this epoch; set the model to eval mode.
            self.model.eval()
            val_loss_avg = Averager()
            val_mae_avg = Averager()
            val_mse_avg = Averager()

            # Print the best record so far.
            if epoch % 10 == 0:
                print('Best Epoch {}, Best Val mae={:.2f} mse={:.2f}'.format(
                    self.record['best_model_name'], self.record['best_mae'],
                    self.record['best_mse']))

            # Run validation.
            for i, batch in enumerate(self.val_loader, 1):
                with torch.no_grad():
                    data = batch[0].cuda()
                    label = batch[1].cuda()
                    pred = self.model(inp=data)
                    loss = F.mse_loss(pred.squeeze(), label)
                    val_loss_avg.add(loss.item())
                    for img in range(pred.size()[0]):
                        pred_cnt = (pred[img] / self.cfg_data.LOG_PARA).sum().data
                        gt_cnt = (label[img] / self.cfg_data.LOG_PARA).sum().data
                        mae = torch.abs(pred_cnt - gt_cnt).item()
                        mse = (pred_cnt - gt_cnt).pow(2).item()
                        val_mae_avg.add(mae)
                        val_mse_avg.add(mse)

            # Collapse validation averagers to scalars.
            val_loss_avg = val_loss_avg.item()
            val_mae_avg = val_mae_avg.item()
            val_mse_avg = np.sqrt(val_mse_avg.item())
            self.writer.add_scalar('data/val_loss', float(val_loss_avg), epoch)
            self.writer.add_scalar('data/val_mae', float(val_mae_avg), epoch)
            self.writer.add_scalar('data/val_mse', float(val_mse_avg), epoch)

            # Print loss and MAE/MSE for this epoch.
            print('Epoch {}, Val, Loss={:.4f} mae={:.4f} mse={:.4f}'.format(
                epoch, val_loss_avg, val_mae_avg, val_mse_avg))

            # Save the model every 10 epochs.
            if epoch % 10 == 0:
                self.save_model('epoch' + str(epoch) + '_' + str(val_mae_avg))

            # Update the logs.
            self.record['train_loss'].append(train_loss_avg)
            self.record['train_mae'].append(train_mae_avg)
            self.record['train_mse'].append(train_mse_avg)
            self.record['val_loss'].append(val_loss_avg)
            self.record['val_mae'].append(val_mae_avg)
            self.record['val_mse'].append(val_mse_avg)  # was missing in the original

            self.record = update_model(
                self.model.module, epoch, self.exp_path, self.exp_name,
                [val_mae_avg, val_mse_avg, val_loss_avg], self.record, self.log_txt)

            if epoch % 10 == 0:
                # pre_max_epoch, not max_epoch: matches the loop bound above.
                print('Running Time: {}, Estimated Time: {}'.format(
                    timer.measure(), timer.measure(epoch / cfg.pre_max_epoch)))
            self.lr_scheduler.step()
        self.writer.close()
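# Usage sketch (not part of the original file): pre-training on the synthetic
# GCC source data, under the same `cfg` / `cfg_data` assumptions as above.
def run_pretrain_stage():
    pwd = os.path.split(os.path.realpath(__file__))[0]
    pre_trainer = PreTrainer(cfg, pwd)
    pre_trainer.train()  # checkpoints and logs land under cfg.EXP_PATH/pre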
class Fine_tune_Trainer(object):
    """Supervised fine-tuning on the few labeled target-domain shot images."""

    def __init__(self, cfg_data, pwd):
        self.cfg_data = cfg_data
        self.pwd = pwd
        self.exp_path = cfg.EXP_PATH
        self.exp_name = cfg.EXP_NAME
        self.exp_path = osp.join(self.exp_path, 'fine_tune')
        if not osp.exists(self.exp_path):
            os.mkdir(self.exp_path)

        (self.sou_query_loader, self.tar_shot_loader, self.tar_val_loader,
         self.tar_test_loader, self.restore_transform) = loading_data(cfg)

        self.sou_model = NLT_Counter(mode='fine_tune', backbone=cfg.model_type)
        self.sou_optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, self.sou_model.parameters()),
            lr=cfg.fine_lr, weight_decay=cfg.fine_weight_decay)
        self.sou_lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.sou_optimizer, step_size=cfg.fine_step_size, gamma=cfg.fine_gamma)

        if cfg.GCC_pre_train_model is not None:
            print('load GCC pre_trained model')
            self.pretrained_dict = torch.load(cfg.GCC_pre_train_model)
            self.sou_model.load_state_dict(self.pretrained_dict)
        self.sou_model = torch.nn.DataParallel(self.sou_model).cuda()

        self.sou_model_record = {"best_mae": 1e20, "best_mse": 1e20,
                                 "best_model_name": "", "update_flag": 0,
                                 "temp_test_mae": 1e20, "temp_test_mse": 1e20}
        self.epoch = 0
        self.writer, self.log_txt = logger(self.exp_path, self.exp_name,
                                           self.pwd, ["exp"])

    def forward(self):
        timer = Timer()
        self.global_count = 0
        for epoch in range(1, cfg.max_epoch + 1):
            self.epoch = epoch  # set before train() so logging uses the right epoch
            self.train()
            if self.epoch % cfg.val_freq == 0:
                self.sou_model_V1(self.tar_val_loader, "val")
            print('=' * 50)
            print('Running Time: {}, Estimated Time: {}'.format(
                timer.measure(), timer.measure(self.epoch / cfg.max_epoch)))
            self.sou_lr_scheduler.step()
        self.writer.close()

    def train(self):
        self.sou_model.train()
        train_loss = AverageMeter()
        train_mae = AverageMeter()
        train_mse = AverageMeter()
        for i, (img, gt_map) in enumerate(self.tar_shot_loader, 1):
            self.global_count = self.global_count + 1
            shot_img, shot_label = img.cuda(), gt_map.cuda()
            # ================ update sou_model parameters ====================
            shot_pred = self.sou_model(shot_img)
            loss = F.mse_loss(shot_pred.squeeze(), shot_label.squeeze())
            self.sou_optimizer.zero_grad()
            loss.backward()
            self.sou_optimizer.step()
            train_loss.update(loss.item())
            self.writer.add_scalar('data/fine_tune_loss', float(loss), self.global_count)
            sou_pred_cnt, sou_label_cnt = self.mae_mse_update(
                shot_pred, shot_label, train_mae, train_mse)
            # =================================================================
            if i % 50 == 0:
                print('Epoch {}, Loss={:.4f} s_gt={:.1f} s_pre={:.1f}'.format(
                    self.epoch, loss.item(), sou_label_cnt, sou_pred_cnt))
        self.writer.add_scalar('data/train_loss_tar', float(train_loss.avg), self.epoch)
        self.writer.add_scalar('data/train_mae_tar', float(train_mae.avg), self.epoch)
        self.writer.add_scalar('data/train_mse_tar', float(np.sqrt(train_mse.avg)), self.epoch)

    def validation(self):
        # Meta-validation for scene-grouped datasets. NOTE: this method is not
        # called by forward() and expects a PreTrainer-style `self.record`
        # dict; it is kept as in the original source.
        self.sou_model.eval()
        if cfg.target_dataset in ["WE", "SHFD"]:
            val_loss = AverageCategoryMeter(5)
            val_mae = AverageCategoryMeter(5)
            for i_sub, i_loader in enumerate(self.tar_val_loader, 0):
                tqdm_gen = tqdm.tqdm(i_loader)
                for i, batch in enumerate(tqdm_gen, 1):
                    img = batch[0].cuda()
                    gt_map = batch[1].cuda()
                    with torch.no_grad():
                        pred = self.sou_model(inp=img)
                        self.mae_mse_update(pred, gt_map, val_mae,
                                            losses=val_loss, cls_id=i_sub)
                        if i == 1:
                            vis_results(self.epoch, self.writer, self.restore_transform,
                                        img, pred.data.cpu().numpy(),
                                        gt_map.data.cpu().numpy(), 'temp_val/sou')
            mae = np.average(val_mae.avg)
            loss = np.average(val_loss.avg)
            self.writer.add_scalar("data/mae_s1", val_mae.avg[0], self.epoch)
            self.writer.add_scalar("data/mae_s2", val_mae.avg[1], self.epoch)
            self.writer.add_scalar("data/mae_s3", val_mae.avg[2], self.epoch)
            self.writer.add_scalar("data/mae_s4", val_mae.avg[3], self.epoch)
            self.writer.add_scalar("data/mae_s5", val_mae.avg[4], self.epoch)
            self.writer.add_scalar("data/tar_val_mae", float(mae), self.epoch)
            self.writer.add_scalar('data/tar_val_loss', float(loss), self.epoch)
            self.record = update_model(
                self.sou_model.module, self.epoch, self.exp_path, self.exp_name,
                [mae, 0, loss], self.record, self.log_txt)
            print('Epoch {}, Val, mae={:.2f} mse={:.2f}'.format(self.epoch, mae, 0))
            self.record['val_loss'].append(loss)
            self.record['val_mae'].append(mae)

    def sou_model_V1(self, dataset, mode=None):
        self.sou_model.eval()
        losses = AverageMeter()
        maes = AverageMeter()
        mses = AverageMeter()
        ssims = AverageMeter()
        psnrs = AverageMeter()
        for i, batch in enumerate(dataset, 1):
            with torch.no_grad():
                img = batch[0].cuda()
                label = batch[1].cuda()
                pred = self.sou_model(img)
                if mode == 'test':
                    self.mae_mse_update(pred, label, maes, mses, ssims, psnrs, losses)
                else:
                    self.mae_mse_update(pred, label, maes, mses, losses=losses)
                if i == 1 and self.epoch % 10 == 0:
                    vis_results(self.epoch, self.writer, self.restore_transform,
                                img, pred.data.cpu().numpy(),
                                label.cpu().detach().numpy(), self.exp_name)
        mae = maes.avg
        mse = np.sqrt(mses.avg)
        loss = losses.avg
        if mode == "val":
            self.writer.add_scalar('data/val_mae', mae, self.epoch)
            self.writer.add_scalar('data/val_mse', mse, self.epoch)
            self.writer.add_scalar('data/val_loss', loss, self.epoch)
            # The original assigned this result to self.tar_model_record, which
            # this class never defines; sou_model_record is the intended target.
            self.sou_model_record = update_model(
                self.sou_model.module, self.epoch, self.exp_path, self.exp_name,
                [mae, mse, loss], self.sou_model_record, self.log_txt)
            print_summary(self.exp_name, [mae, mse, loss], self.sou_model_record)
        elif mode == "test":
            self.writer.add_scalar('data/test_mae', mae, self.epoch)
            self.writer.add_scalar('data/test_mse', mse, self.epoch)
            self.writer.add_scalar('data/test_loss', loss, self.epoch)
            self.writer.add_scalar("data/test_ssim", ssims.avg, self.epoch)
            self.writer.add_scalar("data/test_psnr", psnrs.avg, self.epoch)
            self.sou_model_record['temp_test_mae'] = mae
            self.sou_model_record['temp_test_mse'] = mse
            logger_txt(self.log_txt, self.epoch, [mae, mse, loss])

    def weight_decay_loss(self, model, lamda):
        # Same identity-pulling regularizer as in NLT_Trainer, but keyed on the
        # 'mtl_weight'/'mtl_bias' parameter names used in fine-tune mode.
        loss_weight = 0
        loss_bias = 0
        for name, param in model.named_parameters():
            if 'mtl_weight' in name:
                loss_weight += 0.5 * torch.sum(torch.pow(param - 1, 2))
            elif 'mtl_bias' in name:
                loss_bias += 0.5 * torch.sum(torch.pow(param, 2))
        return lamda * loss_weight + lamda * loss_bias

    def mae_mse_update(self, pred, label, maes, mses=None, ssims=None,
                       psnrs=None, losses=None, cls_id=None):
        # Accumulate per-image counting errors, as in NLT_Trainer.
        for num in range(pred.size()[0]):
            sub_pred = pred[num].data.cpu().squeeze().numpy() / self.cfg_data.LOG_PARA
            sub_label = label[num].data.cpu().squeeze().numpy() / self.cfg_data.LOG_PARA
            pred_cnt = np.sum(sub_pred)
            gt_cnt = np.sum(sub_label)
            mae = abs(pred_cnt - gt_cnt)
            mse = (pred_cnt - gt_cnt) * (pred_cnt - gt_cnt)
            if ssims is not None and psnrs is not None:
                ssims.update(get_ssim(sub_label, sub_pred))
                psnrs.update(get_psnr(sub_label, sub_pred))
            if cls_id is not None:
                maes.update(mae, cls_id)
                if losses is not None:
                    loss = F.mse_loss(pred.detach().squeeze(), label.detach().squeeze())
                    losses.update(loss.item(), cls_id)
                if mses is not None:
                    mses.update(mse, cls_id)
            else:
                maes.update(mae)
                if losses is not None:
                    loss = F.mse_loss(pred.detach().squeeze(), label.detach().squeeze())
                    losses.update(loss.item())
                if mses is not None:
                    mses.update(mse)
        return pred_cnt, gt_cnt
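# Usage sketch (not part of the original file): supervised fine-tuning on the
# few target-domain shot images, mirroring the drivers above.
def run_fine_tune_stage():
    pwd = os.path.split(os.path.realpath(__file__))[0]
    ft_trainer = Fine_tune_Trainer(cfg_data, pwd)
    ft_trainer.forward()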
class den_test:
    """Standalone testing: accumulates MAE/MSE/PSNR/SSIM over a target image
    list and saves the GT and predicted density maps as figures."""

    def __init__(self, model_path, tar_list, tarRoot, cfg_data, img_transform,
                 den_path='./den_vis'):
        # `den_path` was an undefined global in the original snippet; it is
        # promoted to a constructor argument here.
        self.cfg_data = cfg_data
        self.img_transform = img_transform
        self.tarRoot = tarRoot
        self.den_path = den_path
        if not os.path.exists(self.den_path):
            os.makedirs(self.den_path)
        self.net = NLT_Counter(mode='nlt', backbone=cfg.model_type)
        self.net.load_state_dict(torch.load(model_path))
        self.net = torch.nn.DataParallel(self.net).cuda()
        self.net.eval()
        with open(tar_list) as f:
            lines = f.readlines()
        self.tar_list = [line.strip('\n') for line in lines]

    def forward(self):
        score = {'MAE': 0, 'MSE': 0, 'PSNR': 0, 'SSIM': 0}
        count = 0
        tar_list = tqdm.tqdm(self.tar_list)
        for fname in tar_list:
            count += 1
            imgname = os.path.join(self.tarRoot, 'train/img', fname + '.jpg')
            denname = imgname.replace('img', 'den').replace('jpg', 'csv')
            # Alternative layout: denname = os.path.join(self.tarRoot, 'test/den', fname + '.csv')
            den = pd.read_csv(denname, sep=',', header=None).values
            den = den.astype(np.float32, copy=False)

            img = Image.open(imgname)
            if img.mode == 'L':
                img = img.convert('RGB')
            img = self.img_transform(img)
            gt = np.sum(den)
            img = img[None, :, :, :].cuda()
            with torch.no_grad():  # inference only; added to avoid storing gradients
                pred_map = self.net(img)
            pred_map = pred_map.cpu().data.numpy()[0, 0, :, :]
            pred = np.sum(pred_map) / self.cfg_data.LOG_PARA

            score['MAE'] += np.abs(gt - pred)
            score['MSE'] += (gt - pred) * (gt - pred)
            score['SSIM'] += get_ssim(den, pred_map)
            score['PSNR'] += get_psnr(den, pred_map)

            # Normalize both maps to [0, 1] for visualization.
            pred_map = pred_map / np.max(pred_map + 1e-20)
            den = den / np.max(den + 1e-20)

            # Save the ground-truth density map.
            den_frame = plt.gca()
            plt.imshow(den, cmap='jet')
            den_frame.axes.get_yaxis().set_visible(False)
            den_frame.axes.get_xaxis().set_visible(False)
            for spine in ['top', 'bottom', 'left', 'right']:
                den_frame.spines[spine].set_visible(False)
            plt.savefig(self.den_path + '/' + fname + '_gt_' + str(int(gt)) + '.png',
                        bbox_inches='tight', pad_inches=0, dpi=600)
            plt.close()

            # Save the predicted density map.
            pred_frame = plt.gca()
            plt.imshow(pred_map, cmap='jet')
            pred_frame.axes.get_yaxis().set_visible(False)
            pred_frame.axes.get_xaxis().set_visible(False)
            for spine in ['top', 'bottom', 'left', 'right']:
                pred_frame.spines[spine].set_visible(False)
            plt.savefig(self.den_path + '/' + fname + '_DA_' + str(float(pred)) + '.png',
                        bbox_inches='tight', pad_inches=0, dpi=600)
            plt.close()

        score['MAE'] = score['MAE'] / count
        score['MSE'] = np.sqrt(score['MSE'] / count)
        score['SSIM'] = score['SSIM'] / count
        score['PSNR'] = score['PSNR'] / count
        print("processed MAE_in: %.2f MSE_in: %.2f" % (score['MAE'], score['MSE']))
        print("processed PSNR: %.2f SSIM: %.2f" % (score['PSNR'], score['SSIM']))
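# Usage sketch (not part of the original file): standalone density-map testing.
# The normalization stats are assumed to come from `cfg_data.MEAN_STD`, as in
# the rest of the repo; all paths below are placeholders.
import torchvision.transforms as standard_transforms

def run_den_test():
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*cfg_data.MEAN_STD),
    ])
    tester = den_test(model_path='path/to/best_model.pth',
                      tar_list='path/to/target_img_list.txt',
                      tarRoot='path/to/target_dataset_root',
                      cfg_data=cfg_data,
                      img_transform=img_transform,
                      den_path='path/to/vis_output')
    tester.forward()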