class Session:
    """Training/evaluation session for the deraining GAN.

    Owns the generator (``net``), the two discriminators, their optimizers
    and LR schedulers, the loss modules, and caches of evaluation
    dataloaders.  Also provides per-network checkpoint loading.
    """

    # Remap checkpoints saved on any of GPUs 1-7 back onto GPU 0.
    _CUDA_MAP_LOCATION = {'cuda:%d' % i: 'cuda:0' for i in range(1, 8)}

    def __init__(self):
        self.log_dir = settings.log_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.log_dir)
        ensure_dir(settings.model_dir)
        logger.info('set log dir as %s' % settings.log_dir)
        logger.info('set model dir as %s' % settings.model_dir)

        if torch.cuda.is_available():
            self.net = Net().cuda()
            self.dis_rain_img = Discriminator_rain_img().cuda()
            self.dis_img = Discriminator_img().cuda()
            if len(device_ids) > 1:
                # Wrap the modules already built above instead of
                # constructing a second copy of each network.
                self.net = nn.DataParallel(self.net).cuda()
                self.dis_rain_img = nn.DataParallel(self.dis_rain_img).cuda()
                self.dis_img = nn.DataParallel(self.dis_img).cuda()

        self.opt_net = Adam(self.net.parameters(), lr=settings.lr)
        self.sche_net = MultiStepLR(self.opt_net,
                                    milestones=[settings.l1, settings.l2],
                                    gamma=0.1)

        self.opt_dis_rain_img = Adam(self.dis_rain_img.parameters(),
                                     lr=settings.lr)
        self.sche_dis_rain_img = MultiStepLR(
            self.opt_dis_rain_img,
            milestones=[settings.l1, settings.l2],
            gamma=0.1)

        # BUG FIX: this pair previously reused dis_rain_img's parameters and
        # optimizer, so dis_img was never actually optimized or scheduled.
        self.opt_dis_img = Adam(self.dis_img.parameters(), lr=settings.lr)
        self.sche_dis_img = MultiStepLR(self.opt_dis_img,
                                        milestones=[settings.l1, settings.l2],
                                        gamma=0.1)

        self.l2 = MSELoss().cuda()
        self.l1 = nn.L1Loss().cuda()
        self.ssim = SSIM().cuda()
        self.vgg = VGG().cuda()
        self.dataloaders = {}

    def get_dataloader(self, dataset_name):
        """Return a cached single-sample evaluation DataLoader for
        ``dataset_name``, building dataset and loader only on first use."""
        if dataset_name not in self.dataloaders:
            self.dataloaders[dataset_name] = DataLoader(
                TestDataset(dataset_name),
                batch_size=1,
                shuffle=False,
                num_workers=1,
                drop_last=False)
        return self.dataloaders[dataset_name]

    def load_checkpoints_net(self, name):
        """Restore the generator, its optimizer and training step from the
        checkpoint file ``name``; no-op (with a log line) if missing."""
        ckp_path = os.path.join(self.model_dir, name)
        try:
            logger.info('Load checkpoint %s' % ckp_path)
            obj = torch.load(ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        self.net.load_state_dict(obj['net'])
        self.opt_net.load_state_dict(obj['opt_net'])
        self.step = obj['clock_net']
        # Keep the scheduler in sync with the restored global step.
        self.sche_net.last_epoch = self.step

    def load_checkpoints_dis_rain_img(self, name):
        """Restore the rain-image discriminator state; tensors saved on
        other GPUs are remapped to cuda:0."""
        ckp_path = os.path.join(self.model_dir, name)
        try:
            logger.info('Load checkpoint %s' % ckp_path)
            obj = torch.load(ckp_path, map_location=self._CUDA_MAP_LOCATION)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        self.dis_rain_img.load_state_dict(obj['dis_rain_img'])
        self.opt_dis_rain_img.load_state_dict(obj['opt_dis_rain_img'])
        self.step = obj['clock_dis_rain_img']
        self.sche_dis_rain_img.last_epoch = self.step

    def load_checkpoints_dis_img(self, name):
        """Restore the image discriminator state; tensors saved on other
        GPUs are remapped to cuda:0."""
        ckp_path = os.path.join(self.model_dir, name)
        try:
            logger.info('Load checkpoint %s' % ckp_path)
            obj = torch.load(ckp_path, map_location=self._CUDA_MAP_LOCATION)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        self.dis_img.load_state_dict(obj['dis_img'])
        self.opt_dis_img.load_state_dict(obj['opt_dis_img'])
        self.step = obj['clock_dis_img']
        self.sche_dis_img.last_epoch = self.step

    def loss_vgg(self, input, groundtruth):
        """Perceptual loss: summed L1 distance between the VGG feature maps
        of ``input`` and ``groundtruth``."""
        vgg_gt = self.vgg.forward(groundtruth)
        vgg_in = self.vgg.forward(input)  # renamed from `eval` (shadowed builtin)
        return sum(self.l1(vgg_in[m], vgg_gt[m]) for m in range(len(vgg_gt)))

    def inf_batch(self, name, batch):
        """Run one evaluation batch through the generator.

        Returns ``(metrics, psnr)`` where ``metrics`` maps ``'loss0'`` to
        the L1 loss and ``'ssim0'`` to the SSIM against ground truth ``B``.
        """
        with torch.no_grad():
            O, B = batch['O'].cuda(), batch['B'].cuda()
            # Single forward pass; the original re-ran the network inside
            # every network_style branch with identical input.
            img, derain, rain = self.net(O)
            # Select the final restored image for the configured variant.
            if settings.network_style == 'only_rain':
                img = O - rain          # subtract predicted rain layer
            elif settings.network_style == 'only_derain':
                img = derain            # direct derained prediction
            # 'rain_derain_no_guide' / 'rain_derain_with_guide' use the
            # network's fused output `img` as-is.

            loss = self.l1(img, B)
            ssim = self.ssim(img, B)
            psnr = PSNR(img.data.cpu().numpy() * 255,
                        B.data.cpu().numpy() * 255)
            metrics = {'loss0': loss.item(), 'ssim0': ssim.item()}
            return metrics, psnr
# anchor_locations: 所有anchor框转为目标实体框的系数,无效anchor系数全部为0,有效anchor有有效系数 anchor_locations = np.empty((len(anchors), ) + anchors.shape[1:], dtype=anchor_locs.dtype) anchor_locations.fill(0) anchor_locations[valid_anchor_index, :] = anchor_locs print(anchor_locations.shape ) # 所有anchor对应的平移缩放系数(feature_size*feature_size*9,4)=》(22500, 4) # 这里通过候选anchor与目标实体框计算得到anchor框的置信度(anchor_conf)和平移缩放系数(anchor_locations) # ---------------------- # --------------------step_2: VGG 和 RPN 模型: RPN 预测的是anchor转为目标框的平移缩放系数 vgg = VGG() # out_map 特征图, # pred_anchor_locs 预测anchor框到目标框转化的系数, pred_anchor_conf 预测anchor框的分数 out_map, pred_anchor_locs, pred_anchor_conf = vgg.forward(img_var) print(out_map.data.shape ) # (batch_size, num, feature_size, feature_size) => (1, 512, 50, 50) # 1. pred_anchor_locs 预测每个anchor框到目标框转化的系数(平移缩放),与 anchor_locations对应 pred_anchor_locs = pred_anchor_locs.permute(0, 2, 3, 1).contiguous().view(1, -1, 4) print(pred_anchor_locs.shape) # Out: torch.Size([1, 22500, 4]) # 2. 预测anchor框的置信度,每个anchor框都会对应一个置信度,与 anchor_conf对应 pred_anchor_conf = pred_anchor_conf.permute(0, 2, 3, 1).contiguous() print(pred_anchor_conf.shape) # Out torch.Size([1, 50, 50, 18]) objectness_score = pred_anchor_conf.view(1, 50, 50, 9, 2)[:, :, :, :, 1].contiguous().view(1, -1) print(objectness_score.shape) # Out torch.Size([1, 22500])