class Session:
    def __init__(self):
        self.log_dir = settings.log_dir
        self.model_dir = settings.model_dir
        ensure_dir(settings.log_dir)
        ensure_dir(settings.model_dir)
        logger.info('set log dir as %s' % settings.log_dir)
        logger.info('set model dir as %s' % settings.model_dir)
        if torch.cuda.is_available():
            self.net = Net().cuda()
            self.dis_rain_img = Discriminator_rain_img().cuda()
            self.dis_img = Discriminator_img().cuda()
            if len(device_ids) > 1:
                # Wrap the existing modules rather than building fresh ones
                self.net = nn.DataParallel(self.net, device_ids=device_ids)
                self.dis_rain_img = nn.DataParallel(self.dis_rain_img,
                                                    device_ids=device_ids)
                self.dis_img = nn.DataParallel(self.dis_img,
                                               device_ids=device_ids)
        self.opt_net = Adam(self.net.parameters(), lr=settings.lr)
        self.sche_net = MultiStepLR(self.opt_net,
                                    milestones=[settings.l1, settings.l2],
                                    gamma=0.1)

        self.opt_dis_rain_img = Adam(self.dis_rain_img.parameters(),
                                     lr=settings.lr)
        self.sche_dis_rain_img = MultiStepLR(
            self.opt_dis_rain_img,
            milestones=[settings.l1, settings.l2],
            gamma=0.1)
        self.opt_dis_img = Adam(self.dis_img.parameters(), lr=settings.lr)
        self.sche_dis_img = MultiStepLR(self.opt_dis_img,
                                        milestones=[settings.l1, settings.l2],
                                        gamma=0.1)

        self.l2 = MSELoss().cuda()
        self.l1 = nn.L1Loss().cuda()
        self.ssim = SSIM().cuda()
        self.vgg = VGG().cuda()
        self.dataloaders = {}

    def get_dataloader(self, dataset_name):
        # Build the dataset only when it is not cached yet
        if dataset_name not in self.dataloaders:
            dataset = TestDataset(dataset_name)
            self.dataloaders[dataset_name] = \
                DataLoader(dataset, batch_size=1,
                           shuffle=False, num_workers=1, drop_last=False)
        return self.dataloaders[dataset_name]

    def load_checkpoints_net(self, name):
        ckp_path = os.path.join(self.model_dir, name)
        try:
            logger.info('Load checkpoint %s' % ckp_path)
            obj = torch.load(ckp_path)
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        self.net.load_state_dict(obj['net'])
        self.opt_net.load_state_dict(obj['opt_net'])
        self.step = obj['clock_net']
        self.sche_net.last_epoch = self.step

    def load_checkpoints_dis_rain_img(self, name):
        ckp_path = os.path.join(self.model_dir, name)
        try:
            logger.info('Load checkpoint %s' % ckp_path)
            obj = torch.load(ckp_path, map_location='cuda:0')
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        self.dis_rain_img.load_state_dict(obj['dis_rain_img'])
        self.opt_dis_rain_img.load_state_dict(obj['opt_dis_rain_img'])
        self.step = obj['clock_dis_rain_img']
        self.sche_dis_rain_img.last_epoch = self.step

    def load_checkpoints_dis_img(self, name):
        ckp_path = os.path.join(self.model_dir, name)
        try:
            logger.info('Load checkpoint %s' % ckp_path)
            obj = torch.load(ckp_path, map_location='cuda:0')
        except FileNotFoundError:
            logger.info('No checkpoint %s!!' % ckp_path)
            return
        self.dis_img.load_state_dict(obj['dis_img'])
        self.opt_dis_img.load_state_dict(obj['opt_dis_img'])
        self.step = obj['clock_dis_img']
        self.sche_dis_img.last_epoch = self.step
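
    # A minimal save counterpart is sketched below; it is not part of the
    # original listing, but mirrors the keys the loader above expects
    # ('net', 'opt_net', 'clock_net').
    def save_checkpoints_net(self, name):
        ckp_path = os.path.join(self.model_dir, name)
        obj = {
            'net': self.net.state_dict(),
            'opt_net': self.opt_net.state_dict(),
            'clock_net': self.step,
        }
        torch.save(obj, ckp_path)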

    def loss_vgg(self, input, groundtruth):
        # Perceptual loss: sum of L1 distances between VGG feature maps
        vgg_gt = self.vgg.forward(groundtruth)
        vgg_out = self.vgg.forward(input)
        loss_list = [self.l1(vgg_out[m], vgg_gt[m]) for m in range(len(vgg_gt))]
        return sum(loss_list)

    def inf_batch(self, name, batch):
        with torch.no_grad():
            O, B = batch['O'].cuda(), batch['B'].cuda()
            R = O - B  # rain residual (kept for reference, unused below)
            # Run the network once; network_style only decides which output is
            # taken as the restored image.
            img, derain, rain = self.net(O)
            if settings.network_style == 'only_rain':
                img = O - rain
            elif settings.network_style == 'only_derain':
                img = derain
            # 'rain_derain_no_guide' and 'rain_derain_with_guide' use img as-is
        loss_list_img = [self.l1(img, B)]
        ssim_list = [self.ssim(img, B)]
        psnr = PSNR(img.data.cpu().numpy() * 255, B.data.cpu().numpy() * 255)
        losses = {
            'loss%d' % i: loss.item()
            for i, loss in enumerate(loss_list_img)
        }
        ssimes = {
            'ssim%d' % i: ssim.item()
            for i, ssim in enumerate(ssim_list)
        }
        losses.update(ssimes)

        return losses, psnr
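
# A minimal usage sketch (assumed driver code, not part of the original
# listing): restore the generator checkpoint, then evaluate batches from a
# test dataloader. 'latest_net' and 'test' are placeholder names.
if __name__ == '__main__':
    sess = Session()
    sess.load_checkpoints_net('latest_net')
    for batch in sess.get_dataloader('test'):
        losses, psnr = sess.inf_batch('test', batch)
        logger.info('losses: %s, psnr: %.2f' % (losses, psnr))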
Example #2
# anchor_locations: shift/scale regression targets mapping every anchor to its
# ground-truth box; invalid anchors are left as zeros, valid anchors receive
# the computed coefficients.
anchor_locations = np.empty((len(anchors), ) + anchors.shape[1:],
                            dtype=anchor_locs.dtype)
anchor_locations.fill(0)
anchor_locations[valid_anchor_index, :] = anchor_locs
# Shift/scale coefficients for all anchors:
# (feature_size * feature_size * 9, 4) => (22500, 4)
print(anchor_locations.shape)
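
# For reference, these "shift/scale coefficients" follow the standard Faster
# R-CNN box parameterization. A sketch for a single anchor/ground-truth pair,
# assuming (y1, x1, y2, x2) box layout; this helper is illustrative only and
# not part of the original listing:
def box_to_loc_sketch(anchor, gt):
    ah, aw = anchor[2] - anchor[0], anchor[3] - anchor[1]
    ay, ax = anchor[0] + 0.5 * ah, anchor[1] + 0.5 * aw
    gh, gw = gt[2] - gt[0], gt[3] - gt[1]
    gy, gx = gt[0] + 0.5 * gh, gt[1] + 0.5 * gw
    dy, dx = (gy - ay) / ah, (gx - ax) / aw       # center shift, normalized
    dh, dw = np.log(gh / ah), np.log(gw / aw)     # log scale change
    return np.array([dy, dx, dh, dw])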

# Up to here: the anchor confidences (anchor_conf) and shift/scale coefficients
# (anchor_locations) were derived from the candidate anchors and the
# ground-truth boxes.
# ----------------------

# -------------------- step_2: VGG + RPN model: the RPN predicts the
# shift/scale coefficients that turn anchors into target boxes
vgg = VGG()
# out_map: backbone feature map; pred_anchor_locs: predicted anchor-to-box
# transform coefficients; pred_anchor_conf: predicted anchor scores
out_map, pred_anchor_locs, pred_anchor_conf = vgg.forward(img_var)
# (batch_size, channels, feature_size, feature_size) => (1, 512, 50, 50)
print(out_map.data.shape)

# 1. pred_anchor_locs: predicted shift/scale coefficients that map each anchor
#    to a target box, aligned with anchor_locations
pred_anchor_locs = pred_anchor_locs.permute(0, 2, 3, 1).contiguous().view(1, -1, 4)
print(pred_anchor_locs.shape)  # Out: torch.Size([1, 22500, 4])

# 2. pred_anchor_conf: one confidence per anchor, aligned with anchor_conf
pred_anchor_conf = pred_anchor_conf.permute(0, 2, 3, 1).contiguous()
print(pred_anchor_conf.shape)  # Out: torch.Size([1, 50, 50, 18])
objectness_score = pred_anchor_conf.view(1, 50, 50, 9, 2)[:, :, :, :, 1]
objectness_score = objectness_score.contiguous().view(1, -1)
print(objectness_score.shape)  # Out: torch.Size([1, 22500])
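
# A minimal sketch of how these predictions are typically compared against the
# targets built above (standard RPN loss). It assumes anchor_conf holds the
# per-anchor labels computed earlier (1 positive, 0 negative, -1 ignored); the
# import and the weighting are assumptions, not part of the original listing.
import torch.nn.functional as F

gt_locs = torch.from_numpy(anchor_locations).float()   # (22500, 4)
gt_labels = torch.from_numpy(anchor_conf).long()       # (22500,)
pred_locs = pred_anchor_locs[0]                        # (22500, 4)
pred_scores = pred_anchor_conf.view(1, -1, 2)[0]       # (22500, 2)

# Classification loss over labelled anchors only (label == -1 is ignored)
cls_loss = F.cross_entropy(pred_scores, gt_labels, ignore_index=-1)
# Regression loss over positive anchors only
pos = gt_labels == 1
loc_loss = F.smooth_l1_loss(pred_locs[pos], gt_locs[pos])
rpn_loss = cls_loss + 10.0 * loc_loss  # the lambda weight is an assumption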