def setup(self):
    """Initialize the datasets, model, loss, and optimizer."""
    args = self.args
    if torch.cuda.is_available():
        self.device = torch.device("cuda")
        self.device_count = torch.cuda.device_count()
        # for code conciseness, we release the single-GPU version
        assert self.device_count == 1
        logging.info('using {} gpus'.format(self.device_count))
    else:
        raise Exception("gpu is not available")

    self.downsample_ratio = args.downsample_ratio
    self.datasets = {
        x: Crowd(os.path.join(args.data_dir, x),
                 args.crop_size,
                 args.downsample_ratio,
                 args.is_gray, x)
        for x in ['train', 'val']
    }
    self.dataloaders = {
        x: DataLoader(self.datasets[x],
                      collate_fn=(train_collate if x == 'train' else default_collate),
                      batch_size=(args.batch_size if x == 'train' else 1),
                      shuffle=(x == 'train'),
                      num_workers=args.num_workers * self.device_count,
                      pin_memory=(x == 'train'))
        for x in ['train', 'val']
    }
    self.model = vgg19()
    self.model.to(self.device)
    self.optimizer = optim.Adam(self.model.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay)
    self.start_epoch = 0
    if args.resume:
        suf = args.resume.rsplit('.', 1)[-1]
        if suf == 'tar':
            checkpoint = torch.load(args.resume, self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.start_epoch = checkpoint['epoch'] + 1
        elif suf == 'pth':
            self.model.load_state_dict(torch.load(args.resume, self.device))

    self.post_prob = Post_Prob(args.sigma,
                               args.crop_size,
                               args.downsample_ratio,
                               args.background_ratio,
                               args.use_background,
                               self.device)
    self.criterion = Bay_Loss(args.use_background, self.device)
    self.save_list = Save_Handle(max_num=args.max_model_num)
    self.best_mae = np.inf
    self.best_mse = np.inf
    self.best_mae_1 = np.inf
    self.best_mse_1 = np.inf
    self.best_count = 0
    self.best_count_1 = 0
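# The 'train' loader above needs a custom `train_collate` because each image
# carries a variable-length list of annotated head points that the default
# collate cannot stack. If it is not already defined elsewhere in the repo,
# a minimal sketch consistent with how the bayes branch of the training loop
# unpacks its batches (inputs, points, targets, st_sizes) would be:
def train_collate(batch):
    # each batch item is assumed to be (image, points, targets, st_size)
    transposed_batch = list(zip(*batch))
    images = torch.stack(transposed_batch[0], 0)       # crops share a size, so they stack
    points = transposed_batch[1]                       # variable-length point lists stay as lists
    targets = transposed_batch[2]
    st_sizes = torch.FloatTensor(transposed_batch[3])  # shortest-side sizes as a tensor
    return images, points, targets, st_sizes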
def main(args):
    # use gpu
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    cur_device = torch.device('cuda:{}'.format(args.gpu))
    if args.loss == 'bayes':
        root = '/home/datamining/Datasets/CrowdCounting/sha_bayes_512/'
        train_path = root + 'train/'
        test_path = root + 'test/'
    elif args.bn:
        root = '/home/datamining/Datasets/CrowdCounting/sha_512_a/'
        train_path = root + 'train/'
        test_path = root + 'test/'
    else:
        if args.dataset == 'sha':
            root = '/home/datamining/Datasets/CrowdCounting/shanghaitech/part_A_final/'
            train_path = root + 'train_data/images'
            test_path = root + 'test_data/images/'
        elif args.dataset == 'shb':
            root = '/home/datamining/Datasets/CrowdCounting/shb_1024_f15/'
            train_path = root + 'train/'
            test_path = root + 'test/'
        elif args.dataset == 'qnrf':
            root = '/home/datamining/Datasets/CrowdCounting/qnrf_1024_a/'
            train_path = root + 'train/'
            test_path = root + 'test/'

    downsample_ratio = args.downsample
    train_loader, test_loader, train_img_paths, test_img_paths = get_loader(
        train_path, test_path, downsample_ratio, args)

    model_dict = {
        'VGG16_13': M_CSRNet,
        'DefCcNet': DefCcNet,
        'Res50_back3': Res50,
        'InceptionV3': Inception3CC,
        'CAN': CANNet,
    }
    model_name = args.model
    dataset_name = args.dataset
    net = model_dict[model_name](downsample=args.downsample,
                                 bn=args.bn > 0,
                                 objective=args.objective,
                                 sp=(args.sp > 0),
                                 se=(args.se > 0),
                                 NL=args.nl)
    net.cuda()

    if args.bn > 0:
        # note: the original format string had six '{}' slots for only five
        # arguments, which raises an IndexError at runtime; one slot is dropped
        save_name = '{}_{}_bn{}_ps{}_{}'.format(model_name, dataset_name,
                                                str(int(args.bn)),
                                                str(args.crop_size), args.loss)
    else:
        save_name = '{}_d{}{}{}{}{}_{}_{}_cr{}_{}{}{}{}{}{}'.format(
            model_name, str(args.downsample),
            '_sp' if args.sp else '',
            '_se' if args.se else '',
            '_' + args.nl if args.nl != 'relu' else '',
            '_vp' if args.val_patch else '',
            dataset_name, args.crop_mode, str(args.crop_scale), args.loss,
            '_wu' if args.warm_up else '',
            '_cl' if args.curriculum == 'W' else '',
            '_v' + str(int(args.value_factor)) if args.value_factor != 1 else '',
            '_amp' + str(args.amp_k) if args.objective == 'dmp+amp' else '',
            '_bg' if args.use_bg else '')
    save_path = "/home/datamining/Models/CrowdCounting/" + save_name + ".pth"
    logger = get_logger('logs/' + save_name + '.txt')
    for k, v in args.__dict__.items():  # log all args
        logger.info("{}: {}".format(k, v))

    if os.path.exists(save_path) and args.resume:
        net.load_state_dict(torch.load(save_path))
        print('{} loaded!'.format(save_path))

    value_factor = args.value_factor
    freq = 100
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(net.parameters(), lr=args.lr,
                                     weight_decay=args.decay)
    elif args.optimizer == 'SGD':  # does not converge
        optimizer = torch.optim.SGD(net.parameters(), lr=args.lr,
                                    momentum=0.95, weight_decay=args.decay)

    if args.loss == 'bayes':
        bayes_criterion = Bay_Loss(True, cur_device)
        post_prob = Post_Prob(sigma=8.0,
                              c_size=args.crop_size,
                              stride=1,
                              background_ratio=0.15,
                              use_background=True,
                              device=cur_device)
    else:
        mse_criterion = nn.MSELoss().cuda()

    if args.scheduler == 'plt':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',
                                                   factor=0.9, patience=10,
                                                   verbose=True)
    elif args.scheduler == 'cos':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=0)
    elif args.scheduler == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.8)
    elif args.scheduler == 'exp':
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    elif args.scheduler == 'cyclic' and args.optimizer == 'SGD':
        scheduler = lr_scheduler.CyclicLR(optimizer,
                                          base_lr=args.lr * 0.01,
                                          max_lr=args.lr,
                                          step_size_up=25)
    elif args.scheduler == 'None':
        scheduler = None
    else:
        print('scheduler name error!')

    # initial validation to set the best-so-far baseline
    if args.val_patch:
        best_mae, best_rmse = val_patch(net, test_loader, value_factor)
    elif args.loss == 'bayes':
        best_mae, best_rmse = val_bayes(net, test_loader, value_factor)
    else:
        best_mae, best_rmse = val(net, test_loader, value_factor)
    if args.scheduler == 'plt':
        scheduler.step(best_mae)

    ssim_loss = pytorch_ssim.SSIM(window_size=11)
    for epoch in range(args.epochs):
        if args.crop_mode == 'curriculum':
            # every 20% of the epochs, switch to the next stage of the dataset
            if (epoch + 1) % (args.epochs // 5) == 0:
                print('change dataset')
                single_dataset = RawDataset(train_img_paths, transform,
                                            args.crop_mode, downsample_ratio,
                                            args.crop_scale,
                                            (epoch + 1.0 + args.epochs // 5) / args.epochs)
                train_loader = torch.utils.data.DataLoader(single_dataset,
                                                           shuffle=True,
                                                           batch_size=1,
                                                           num_workers=8)
        train_loss = 0.0
        if args.loss == 'bayes':
            epoch_mae = AverageMeter()
            epoch_mse = AverageMeter()
        net.train()
        if args.warm_up and epoch < args.warm_up_steps:
            linear_warm_up_lr(optimizer, epoch, args.warm_up_steps, args.lr)
        for it, data in enumerate(train_loader):
            if args.loss == 'bayes':
                inputs, points, targets, st_sizes = data
                img = inputs.to(cur_device)
                st_sizes = st_sizes.to(cur_device)
                gd_count = np.array([len(p) for p in points], dtype=np.float32)
                points = [p.to(cur_device) for p in points]
                targets = [t.to(cur_device) for t in targets]
            else:
                img, target, _, amp_gt = data
                img = img.cuda()
                target = value_factor * target.float().unsqueeze(1).cuda()
                amp_gt = amp_gt.cuda()
            optimizer.zero_grad()
            if args.objective == 'dmp+amp':
                output, amp = net(img)
                output = output * amp
            else:
                output = net(img)
            if args.loss != 'bayes':
                # target and mse_criterion only exist for the non-bayes losses;
                # the original computed this unconditionally, which raises a
                # NameError in bayes mode
                if args.curriculum == 'W':
                    delta = (output - target) ** 2
                    k_w = 2e-3 * args.value_factor * args.downsample ** 2
                    b_w = 5e-3 * args.value_factor * args.downsample ** 2
                    T = torch.ones_like(target, dtype=torch.float32) * epoch * k_w + b_w
                    W = T / torch.max(T, output)
                    delta = delta * W
                    mse_loss = torch.mean(delta)
                else:
                    mse_loss = mse_criterion(output, target)
            if args.loss == 'mse+lc':
                loss = mse_loss + 1e2 * cal_lc_loss(output, target) * args.downsample
            elif args.loss == 'ssim':
                loss = 1 - ssim_loss(output, target)
            elif args.loss == 'mse+ssim':
                loss = 100 * mse_loss + 1e-2 * (1 - ssim_loss(output, target))
            elif args.loss == 'mse+la':
                loss = mse_loss + cal_spatial_abstraction_loss(output, target)
            elif args.loss == 'la':
                loss = cal_spatial_abstraction_loss(output, target)
            elif args.loss == 'ms-ssim':
                # TODO: multi-scale SSIM not implemented; fall back to MSE so
                # that `loss` is always defined before backward()
                loss = mse_loss
            elif args.loss == 'adversial':
                # TODO: adversarial loss not implemented; same fallback
                loss = mse_loss
            elif args.loss == 'bayes':
                prob_list = post_prob(points, st_sizes)
                loss = bayes_criterion(prob_list, targets, output)
            else:
                loss = mse_loss
            # add the cross-entropy loss for the attention map
            # (amp_gt is only provided by the non-bayes dataloader)
            if args.objective == 'dmp+amp' and args.loss != 'bayes':
                cross_entropy = (amp_gt * torch.log(amp)
                                 + (1 - amp_gt) * torch.log(1 - amp)) * -1
                cross_entropy_loss = torch.mean(cross_entropy)
                loss = loss + cross_entropy_loss * args.amp_k
            loss.backward()
            optimizer.step()
            data_loss = loss.item()
            train_loss += data_loss
            if args.loss == 'bayes':
                N = inputs.size(0)
                pre_count = torch.sum(output.view(N, -1), dim=1).detach().cpu().numpy()
                res = pre_count - gd_count
                epoch_mse.update(np.mean(res * res), N)
                epoch_mae.update(np.mean(abs(res)), N)
            if args.loss != 'bayes' and it % freq == 0:
                print('[ep:{}], [it:{}], [loss:{:.8f}], [output:{:.2f}, target:{:.2f}]'
                      .format(epoch + 1, it, data_loss,
                              output[0].sum().item(), target[0].sum().item()))
        if args.val_patch:
            mae, rmse = val_patch(net, test_loader, value_factor)
        elif args.loss == 'bayes':
            mae, rmse = val_bayes(net, test_loader, value_factor)
        else:
            mae, rmse = val(net, test_loader, value_factor)
        if not (args.warm_up and epoch < args.warm_up_steps):
            if args.scheduler == 'plt':
                scheduler.step(best_mae)
            elif args.scheduler != 'None':
                scheduler.step()
        if mae + 0.1 * rmse < best_mae + 0.1 * best_rmse:
            best_mae, best_rmse = mae, rmse
            torch.save(net.state_dict(), save_path)
        if args.loss == 'bayes':
            logger.info(
                '{} Epoch {}/{} Loss:{:.8f}, MAE:{:.2f}, RMSE:{:.2f}, lr:{:.8f}, '
                '[CUR]:{mae:.1f}, {rmse:.1f}, [Best]:{b_mae:.1f}, {b_rmse:.1f}'
                .format(model_name, epoch + 1, args.epochs,
                        train_loss / len(train_loader),
                        epoch_mae.get_avg(), np.sqrt(epoch_mse.get_avg()),
                        optimizer.param_groups[0]['lr'],
                        mae=mae, rmse=rmse, b_mae=best_mae, b_rmse=best_rmse))
        else:
            logger.info(
                '{} Epoch {}/{} Loss:{:.8f}, lr:{:.8f}, '
                '[CUR]:{mae:.1f}, {rmse:.1f}, [Best]:{b_mae:.1f}, {b_rmse:.1f}'
                .format(model_name, epoch + 1, args.epochs,
                        train_loss / len(train_loader),
                        optimizer.param_groups[0]['lr'],
                        mae=mae, rmse=rmse, b_mae=best_mae, b_rmse=best_rmse))
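# Two small helpers assumed by the training loop above but not shown in this
# section. These are minimal sketches consistent with how they are called
# (`epoch_mae.update(val, N)` / `epoch_mae.get_avg()` and
# `linear_warm_up_lr(optimizer, epoch, warm_up_steps, base_lr)`); they are not
# necessarily the repository's exact implementations:
class AverageMeter:
    """Running weighted average, e.g. for per-epoch MAE/MSE."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    def get_avg(self):
        return self.sum / max(self.count, 1)


def linear_warm_up_lr(optimizer, epoch, warm_up_steps, base_lr):
    """Linearly ramp the learning rate up to base_lr over the first warm_up_steps epochs."""
    lr = base_lr * (epoch + 1) / warm_up_steps
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr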