示例#1
0
def eval(args, model, data_loader):
    """Run ``model`` over ``data_loader``, dump real/fake images, return the FID.

    For every sample the ground-truth map and the generated map are written as
    ``<stem>.tif`` into ``<args.save>/real_result`` and
    ``<args.save>/fake_result``; the two directories are then passed to
    ``fid_score``.

    NOTE: this function shadows the builtin ``eval``; the name is kept because
    callers in this file use it.

    Args:
        args: run configuration. ``args.save`` and ``args.gpu`` are read here
            and the whole object is forwarded to ``get_device``.
        model: generator to evaluate. It is moved to the evaluation device and
            its previous train/eval mode is restored before returning.
        data_loader: iterable of dict batches providing the keys ``'image'``,
            ``'map'`` and ``'im_name'``.

    Returns:
        Whatever ``fid_score`` returns for the two directories.
    """
    device = get_device(args)
    data_loader = tqdm(data_loader)

    # Remember the caller's mode: the training loop keeps optimising ``model``
    # after this call, so it must not be left in eval mode.
    was_training = model.training
    model.eval()
    model = model.to(device)

    fake_dir = osp.join(args.save, 'fake_result')
    real_dir = osp.join(args.save, 'real_result')
    create_dir(real_dir)
    create_dir(fake_dir)

    try:
        for sample in data_loader:
            imgs = sample['image'].to(device)
            # The targets are only written to disk, so they are not moved to
            # the evaluation device.
            maps = sample['map']
            im_name = sample['im_name']
            with torch.no_grad():
                fakes = model(imgs)

            for b in range(imgs.size(0)):
                # File stem = base name up to its first dot.
                file_name = osp.split(im_name[b])[-1].split('.')[0]
                real_file = osp.join(real_dir, f'{file_name}.tif')
                fake_file = osp.join(fake_dir, f'{file_name}.tif')

                from_std_tensor_save_image(filename=real_file,
                                           data=maps[b].cpu())
                from_std_tensor_save_image(filename=fake_file,
                                           data=fakes[b].cpu())

        fid = fid_score(real_path=real_dir,
                        fake_path=fake_dir,
                        gpu=str(args.gpu))
    finally:
        model.train(was_training)

    print(f'===> fid score:{fid:.4f}')
    return fid
def get_E(args):
    """Build the encoder network, initialise it and wrap it in ``DataParallel``.

    The encoder is constructed from ``args.output_nc``, ``args.feat_num``,
    ``args.ngf``, ``args.n_downsample_global`` and the norm layer selected by
    ``args.norm``.  ``weights_init`` is applied, the network is printed, and
    the ``nn.DataParallel`` wrapper is moved to ``get_device(args)``.

    Returns:
        The wrapped encoder.
    """
    norm = get_norm_layer(norm_type=args.norm)
    encoder_args = (
        args.output_nc,
        args.feat_num,
        args.ngf,
        args.n_downsample_global,
        norm,
    )
    encoder = Encoder(*encoder_args)
    encoder.apply(weights_init)
    print(encoder)

    wrapped = nn.DataParallel(encoder)
    return wrapped.to(get_device(args))
示例#3
0
文件: networks.py 项目: GAIMJKP/GAN-1
def get_G(args, input_nc=None):
    """Build the generator selected by ``args.netG`` and wrap it in ``DataParallel``.

    Args:
        args: run configuration.  ``args.netG`` must be ``'global'``,
            ``'local'`` or ``'encoder'``; the remaining attributes read depend
            on the selected architecture.
        input_nc: number of input channels.  When ``None`` it is derived as
            ``args.label_nc`` plus one if ``args.use_instance`` is set, plus
            ``args.feat_num`` if that is positive.

    Returns:
        The generator after ``weights_init``, wrapped in ``nn.DataParallel``
        and moved to ``get_device(args)``.

    Raises:
        NotImplementedError: if ``args.netG`` is not a supported name.
    """
    # Validate first so that a bad configuration fails with a clear error
    # before anything is constructed.
    if args.netG not in ('global', 'local', 'encoder'):
        raise NotImplementedError(
            f'generator {args.netG!r} not implemented!')

    if input_nc is None:
        input_nc = args.label_nc
        if args.use_instance:
            input_nc += 1
        if args.feat_num > 0:
            input_nc += args.feat_num

    norm_layer = get_norm_layer(norm_type=args.norm)
    if args.netG == 'global':
        netG = GlobalGenerator(input_nc, args.output_nc, args.ngf,
                               args.n_downsample_global, args.n_blocks_global,
                               norm_layer)
    elif args.netG == 'local':
        netG = LocalEnhancer(input_nc, args.output_nc, args.ngf,
                             args.n_downsample_global, args.n_blocks_global,
                             args.n_local_enhancers, args.n_blocks_local,
                             norm_layer)
    else:  # 'encoder'
        netG = Encoder(input_nc, args.output_nc, args.ngf,
                       args.n_downsample_global, norm_layer)

    print(netG)
    netG.apply(weights_init)
    netG = nn.DataParallel(netG).to(get_device(args))
    return netG
    def __init__(self, args):
        """Set up the VGG feature extractor, the L1 criterion and the layer weights.

        ``args.vgg_type`` must be ``'vgg16'`` or ``'vgg19'`` (asserted).  The
        chosen network is wrapped in ``nn.DataParallel``, moved to
        ``get_device(args)`` and switched to eval mode.  A one-line summary of
        the configuration is printed at the end.
        """
        super(VGGLoss, self).__init__()
        assert args.vgg_type in ('vgg16', 'vgg19')

        if args.vgg_type == 'vgg16':
            backbone_cls = Vgg16
        else:
            backbone_cls = Vgg19

        feature_net = nn.DataParallel(backbone_cls())
        self.vgg = feature_net.to(get_device(args))
        self.vgg.eval()

        self.criterion = nn.DataParallel(nn.L1Loss())
        # One weight per feature level: 1/32, 1/16, 1/8, 1/4, 1.
        self.weights = [1.0 / divisor for divisor in (32, 16, 8, 4, 1)]

        summary = f'===> {self.__class__.__name__} | vgg:{args.vgg_type} | loss:{self.criterion}'
        print(summary)
def get_D(args, input_nc=None):
    """Build the multi-scale discriminator and wrap it in ``DataParallel``.

    Args:
        args: run configuration (``norm``, ``ndf``, ``n_layers_D``,
            ``use_lsgan``, ``num_D``, ``use_ganFeat_loss`` are read, plus the
            channel attributes below when ``input_nc`` is ``None``).
        input_nc: number of input channels.  When ``None`` it is
            ``args.input_nc + args.output_nc``, plus one if
            ``args.use_instance`` is set.

    Returns:
        The discriminator after ``weights_init``, wrapped in
        ``nn.DataParallel`` and moved to ``get_device(args)``.
    """
    if input_nc is None:
        channels = args.input_nc + args.output_nc
        input_nc = channels + 1 if args.use_instance else channels

    norm = get_norm_layer(norm_type=args.norm)
    discriminator = MultiscaleDiscriminator(
        input_nc,
        args.ndf,
        args.n_layers_D,
        norm,
        args.use_lsgan,
        args.num_D,
        args.use_ganFeat_loss,
    )
    print(discriminator)
    discriminator.apply(weights_init)

    wrapped = nn.DataParallel(discriminator)
    return wrapped.to(get_device(args))
示例#6
0
    def __init__(self, args):
        """Prepare the VGG16 extractor, the MSE criteria and the style targets.

        The image at ``args.style_image`` is loaded as RGB, passed through the
        second transform returned by ``get_transform(args)``, repeated
        ``args.batch_size`` times along the first dimension and fed through
        the VGG once (without gradients) to cache its feature maps and their
        ``gram`` results.
        """
        self.content_layer = args.content_layer
        target_device = get_device(args)

        extractor = nn.DataParallel(Vgg16())
        self.vgg = extractor.to(target_device)
        self.vgg.eval()

        self.mse = nn.DataParallel(nn.MSELoss())
        self.mse_sum = nn.DataParallel(nn.MSELoss(reduction='sum'))

        pil_style = Image.open(args.style_image).convert('RGB')
        _, to_tensor = get_transform(args)
        style_batch = to_tensor(pil_style)
        style_batch = style_batch.repeat(args.batch_size, 1, 1, 1)
        style_batch = style_batch.to(target_device)

        # The style targets are constants, so no autograd graph is needed.
        with torch.no_grad():
            self.style_features = self.vgg(style_batch)
            self.style_gram = list(map(gram, self.style_features))
示例#7
0
    model_saver = ModelSaver(save_path=args.save,
                             name_list=[
                                 f'{style_name}',
                                 f'{style_name}_{args.optimizer}',
                                 f'{style_name}_{args.scheduler}'
                             ])
    criterion = PerceptualLoss(args)
    model = ImageTransformNet()
    model = model_accelerate(args, model)
    model_saver.load(f'{style_name}', model=model)

    optimizer = Adam(model.parameters(), lr=args.lr)
    model_saver.load(f'{style_name}_{args.optimizer}', model=optimizer)

    epoch_now = len(logger.get_data(key='loss'))
    device = get_device(args)

    ####
    data_loader = get_dataloader_from_dir(args)
    data_loader = tqdm(data_loader)
    model.eval()
    with torch.no_grad():
        # counter = 10
        for i, (imgs, path) in enumerate(data_loader):
            imgs = imgs.to(device)
            # counter -= 1
            # if counter < 0:
            #     break
            y_hat = model(imgs)
            for index in range(y_hat.size(0)):
                # data = torch.cat([y_hat[index], imgs[index]], dim=2)
示例#8
0
def train(args, get_dataloader_func=get_pix2pix_maps_dataloader):
    """Train the image-to-map generator/discriminator pair.

    Networks, optimizers and schedulers are created and then passed to
    ``ModelSaver.load`` so that a previous run stored under ``args.save`` can
    be picked up.  Training starts at the epoch equal to the number of
    ``'G_loss'`` entries already held by the logger.

    Per step the discriminator is updated on real and (detached) fake pairs,
    then the generator is updated with the GAN loss plus the optional VGG,
    discriminator-feature and low-level losses, each scaled by
    ``args.lambda_feat``.  Per epoch the schedulers are stepped, the FID is
    evaluated every 10th epoch and on the last epoch, the mean losses are
    logged and all checkpoints are saved.

    Args:
        args: run configuration.
        get_dataloader_func: callable ``(args, train=bool)`` returning an
            iterable of dict batches with the keys ``'image'`` and ``'map'``.
    """
    logger = Logger(save_path=args.save, json_name='img2map')

    model_saver = ModelSaver(save_path=args.save,
                             name_list=[
                                 'G', 'D', 'E', 'G_optimizer', 'D_optimizer',
                                 'E_optimizer', 'G_scheduler', 'D_scheduler',
                                 'E_scheduler'
                             ])
    visualizer = Visualizer(
        keys=['image', 'encode_feature', 'fake', 'label', 'instance'])
    sw = SummaryWriter(args.tensorboard_path)
    G = get_G(args)
    D = get_D(args)
    model_saver.load('G', G)
    model_saver.load('D', D)

    G_optimizer = Adam(G.parameters(), lr=args.G_lr, betas=(args.beta1, 0.999))
    D_optimizer = Adam(D.parameters(), lr=args.D_lr, betas=(args.beta1, 0.999))

    model_saver.load('G_optimizer', G_optimizer)
    model_saver.load('D_optimizer', D_optimizer)

    G_scheduler = get_hinge_scheduler(args, G_optimizer)
    D_scheduler = get_hinge_scheduler(args, D_optimizer)

    model_saver.load('G_scheduler', G_scheduler)
    model_saver.load('D_scheduler', D_scheduler)

    device = get_device(args)

    GANLoss = get_GANLoss(args)

    # Optional losses are only built when enabled; the same flags guard their
    # use inside the loop.
    if args.use_ganFeat_loss:
        DFLoss = get_DFLoss(args)
    if args.use_vgg_loss:
        VGGLoss = get_VGGLoss(args)
    if args.use_low_level_loss:
        LLLoss = get_low_level_loss(args)

    epoch_now = len(logger.get_data('G_loss'))
    for epoch in range(epoch_now, args.epochs):
        G_loss_list = []
        D_loss_list = []

        data_loader = get_dataloader_func(args, train=True)
        data_loader = tqdm(data_loader)

        for step, sample in enumerate(data_loader):
            imgs = sample['image'].to(device)
            maps = sample['map'].to(device)

            # ---- discriminator update ----
            D_optimizer.zero_grad()
            reals_maps = torch.cat([imgs.float(), maps.float()], dim=1)
            # Detached so that this backward pass does not reach G.
            fakes = G(imgs).detach()
            fakes_maps = torch.cat([imgs.float(), fakes.float()], dim=1)

            D_real_outs = D(reals_maps)
            D_real_loss = GANLoss(D_real_outs, True)

            D_fake_outs = D(fakes_maps)
            D_fake_loss = GANLoss(D_fake_outs, False)

            D_loss = 0.5 * (D_real_loss + D_fake_loss)
            D_loss = D_loss.mean()
            D_loss.backward()
            D_loss = D_loss.item()
            D_optimizer.step()

            # ---- generator update ----
            G_optimizer.zero_grad()
            fakes = G(imgs)
            fakes_maps = torch.cat([imgs.float(), fakes.float()], dim=1)
            D_fake_outs = D(fakes_maps)

            gan_loss = GANLoss(D_fake_outs, True)

            G_loss = 0
            G_loss += gan_loss
            # From here on the *_loss names hold plain floats for reporting.
            gan_loss = gan_loss.mean().item()

            if args.use_vgg_loss:
                vgg_loss = VGGLoss(fakes, imgs)
                G_loss += args.lambda_feat * vgg_loss
                vgg_loss = vgg_loss.mean().item()
            else:
                vgg_loss = 0.

            if args.use_ganFeat_loss:
                df_loss = DFLoss(D_fake_outs, D_real_outs)
                G_loss += args.lambda_feat * df_loss
                df_loss = df_loss.mean().item()
            else:
                df_loss = 0.

            if args.use_low_level_loss:
                ll_loss = LLLoss(fakes, maps)
                G_loss += args.lambda_feat * ll_loss
                ll_loss = ll_loss.mean().item()
            else:
                ll_loss = 0.

            G_loss = G_loss.mean()
            G_loss.backward()
            G_loss = G_loss.item()

            G_optimizer.step()

            data_loader.write(
                f'Epochs:{epoch} | Dloss:{D_loss:.6f} | Gloss:{G_loss:.6f}'
                f'| GANloss:{gan_loss:.6f} | VGGloss:{vgg_loss:.6f} | DFloss:{df_loss:.6f} '
                f'| LLloss:{ll_loss:.6f} | lr:{get_lr(G_optimizer):.8f}')

            G_loss_list.append(G_loss)
            D_loss_list.append(D_loss)

            # display (``args.display`` is the step interval; falsy disables)
            if args.display and step % args.display == 0:
                visualizer.display(transforms.ToPILImage()(imgs[0].cpu()),
                                   'image')
                visualizer.display(transforms.ToPILImage()(fakes[0].cpu()),
                                   'fake')
                visualizer.display(transforms.ToPILImage()(maps[0].cpu()),
                                   'label')

            # tensorboard log (``args.tensorboard_log`` is the step interval)
            if args.tensorboard_log and step % args.tensorboard_log == 0:
                total_steps = epoch * len(data_loader) + step
                sw.add_scalar('Loss/G', G_loss, total_steps)
                sw.add_scalar('Loss/D', D_loss, total_steps)
                sw.add_scalar('Loss/gan', gan_loss, total_steps)
                sw.add_scalar('Loss/vgg', vgg_loss, total_steps)
                sw.add_scalar('Loss/df', df_loss, total_steps)
                sw.add_scalar('Loss/ll', ll_loss, total_steps)

                sw.add_scalar('LR/G', get_lr(G_optimizer), total_steps)
                sw.add_scalar('LR/D', get_lr(D_optimizer), total_steps)

                # Use the global step, like the scalars above, so images from
                # different epochs do not land on the same step index.
                sw.add_image('img/real', imgs[0].cpu(), total_steps)
                sw.add_image('img/fake', fakes[0].cpu(), total_steps)
                sw.add_image('visual/label', maps[0].cpu(), total_steps)

        D_scheduler.step(epoch)
        G_scheduler.step(epoch)

        # ``epoch`` never reaches ``args.epochs`` inside this loop, so the
        # final evaluation has to be keyed on the last index.
        if epoch % 10 == 0 or epoch == args.epochs - 1:
            fid = eval(args,
                       model=G,
                       data_loader=get_dataloader_func(args, train=False))
            logger.log(key='FID', data=fid)
            # NOTE(review): the new FID is logged before it is compared with
            # the logged maximum, so whether this branch can fire depends on
            # Logger.get_max; also confirm that "greater than" is the
            # intended direction for FID.
            if fid > logger.get_max(key='FID'):
                model_saver.save(f'G_{fid:.4f}', G)
                model_saver.save(f'D_{fid:.4f}', D)

        logger.log(key='D_loss', data=sum(D_loss_list) / len(D_loss_list))
        logger.log(key='G_loss', data=sum(G_loss_list) / len(G_loss_list))
        logger.save_log()
        logger.visualize()

        model_saver.save('G', G)
        model_saver.save('D', D)

        model_saver.save('G_optimizer', G_optimizer)
        model_saver.save('D_optimizer', D_optimizer)

        model_saver.save('G_scheduler', G_scheduler)
        model_saver.save('D_scheduler', D_scheduler)
示例#9
0
def eval_fidiou(args, model_G, model_seg, data_loader):
    """Evaluate the generator and the segmenter together; return ``(fid, iou)``.

    For every batch the segmenter is run on ``sample['A_seg']``; its second
    output is concatenated channel-wise with ``sample['A']`` and fed to the
    generator.  Source, target, generated image, ground-truth segmentation and
    predicted segmentation are written as ``.tif`` files into sub-directories
    of ``args.save``.  The FID is computed from the ``real_result`` and
    ``fake_result`` directories and the IoU from the collected label arrays.

    Both models are switched to eval mode for the run and to train mode
    before returning.

    Args:
        args: run configuration (``args.save`` and ``args.gpu`` are read here).
        model_G: generator network.
        model_seg: segmentation network returning ``(outputs, feature_map)``.
        data_loader: iterable of dict batches with the keys ``'A_seg'``,
            ``'seg'``, ``'A'``, ``'B'`` and ``'A_paths'``.

    Returns:
        Tuple ``(fid, iou)``.
    """
    # Local imports: ``torch`` was not referenced by name in this function
    # before, and the IoU helper was already imported at function scope.
    import torch
    from src.pix2pixHD.eval_iou import label_accuracy_score

    device = get_device(args)
    data_loader = tqdm(data_loader)
    model_G.eval()
    model_seg.eval()
    model_G = model_G.to(device)
    model_seg = model_seg.to(device)

    label_preds = []
    label_targets = []

    real_seg_dir = osp.join(args.save, 'real_seg')
    real_dir = osp.join(args.save, 'real_result')
    A_dir = osp.join(args.save, 'real_source')
    seg_dir = osp.join(args.save, 'seg_result')
    fake_dir = osp.join(args.save, 'fake_result')
    for directory in (real_dir, real_seg_dir, A_dir, seg_dir, fake_dir):
        create_dir(directory)

    for sample in data_loader:
        # Inputs go to the same device the models were moved to.
        inputs = sample['A_seg'].to(device)
        # Labels are only converted to numpy, so they are not moved.
        labels = sample['seg'].squeeze(dim=1)
        imgs = sample['A'].to(device)

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            outputs, feature_map = model_seg(inputs)
            imgs_plus = torch.cat((imgs, feature_map), 1)
            fakes = model_G(imgs_plus)

        bs, n_class, h, w = outputs.shape
        # Per-pixel class index: argmax over the class axis -> (bs, h, w).
        pred = outputs.detach().cpu().numpy().argmax(axis=1)
        target = labels.cpu().numpy().reshape(bs, h, w)
        label_preds.append(pred)
        label_targets.append(target)

        im_name = sample['A_paths']
        for b in range(inputs.size(0)):
            # <grandparent dir>_<parent dir>_<base name up to its first dot>
            parent, base = osp.split(im_name[b])
            parent_parts = parent.split(os.sep)
            stem = base.split('.')[0]
            file_name = parent_parts[-2] + '_' + parent_parts[-1] + '_' + stem

            real_file = osp.join(real_dir, f'{file_name}.tif')
            real_seg_file = osp.join(real_seg_dir, f'{file_name}.tif')
            A_file = osp.join(A_dir, f'{file_name}.tif')
            seg_file = osp.join(seg_dir, f'{file_name}.tif')
            fake_file = osp.join(fake_dir, f'{file_name}.tif')

            from_std_tensor_save_image(filename=real_file,
                                       data=sample['B'][b].cpu())
            from_std_tensor_save_image(filename=A_file,
                                       data=sample['A'][b].cpu())
            from_std_tensor_save_image(filename=fake_file, data=fakes[b].cpu())

            # Ground-truth segmentation, colourised by ``gray2rgb``.
            tmpimg = sample['seg'][b].data.cpu().numpy()
            tmpimg = gray2rgb(tmpimg)
            tmpimg = Image.fromarray(tmpimg)
            tmpimg.save(fp=real_seg_file)

            # Predicted segmentation, colourised the same way.
            tmpimg = gray2rgb(pred[b])
            tmpimg = Image.fromarray(tmpimg)
            tmpimg.save(fp=seg_file)

    fid = fid_score(real_path=real_dir, fake_path=fake_dir, gpu=str(args.gpu))
    print(f'===> fid score:{fid:.4f}')

    # ``n_class`` is the class count seen in the last batch.
    _, _, iou, _, _ = label_accuracy_score(label_targets, label_preds, n_class)
    print(f'===> iou score:{iou:.4f}')

    model_seg.train()
    model_G.train()
    return fid, iou
示例#10
0
def train(args, get_dataloader_func=get_pix2pix_maps_dataloader):
    """Jointly train the DeepLabV3+ segmenter and the generator/discriminator.

    The run configuration is dumped to ``<args.save>/args.json``.  Networks,
    optimizers and schedulers are created and passed to ``ModelSaver.load`` so
    a previous run under ``args.save`` can be picked up; training starts at
    the epoch equal to the number of ``'FOCAL_loss'`` entries in the logger.

    Per step:
      1. the segmenter is run on ``sample['A_seg']`` and its loss is the mean
         of the focal loss and the Lovasz-softmax loss;
      2. the discriminator is updated on real and (detached) fake pairs;
      3. the generator, fed the image concatenated with the segmenter's
         feature map, and the segmenter are updated together from
         ``args._1002arg_GANloss_alpha * G_loss +
         args._1002arg_segloss_alpha * seg_loss``.

    Per epoch the schedulers are stepped, FID/IoU are evaluated every
    ``args._val_frequency`` epochs and on the last epoch, the mean losses are
    logged and all checkpoints are saved.

    Args:
        args: run configuration.
        get_dataloader_func: callable ``(args, train=bool[, flag=int])``
            returning an iterable of dict batches with the keys ``'A_seg'``,
            ``'seg'``, ``'A'`` and ``'B'``.
    """
    import copy  # only used to clone ``args`` for the evaluation loader

    with open(os.path.join(args.save, 'args.json'), 'w') as f:
        json.dump(vars(args), f)
    logger = Logger(save_path=args.save, json_name='img2map_seg')
    epoch_now = len(logger.get_data('FOCAL_loss'))

    model_saver = ModelSaver(
        save_path=args.save,
        name_list=[
            'G', 'D', 'G_optimizer', 'D_optimizer', 'G_scheduler',
            'D_scheduler', 'DLV3P', "DLV3P_global_optimizer",
            "DLV3P_backbone_optimizer", "DLV3P_global_scheduler",
            "DLV3P_backbone_scheduler", 'best_G', 'best_D', 'best_G_optimizer',
            'best_D_optimizer', 'best_G_scheduler', 'best_D_scheduler',
            'best_DLV3P', "best_DLV3P_global_optimizer",
            "best_DLV3P_backbone_optimizer", "best_DLV3P_global_scheduler",
            "best_DLV3P_backbone_scheduler"
        ])

    sw = SummaryWriter(args.tensorboard_path)

    # 3 + 256: image channels plus the channel count of the feature map
    # produced by the segmentation network.
    G = get_G(args, input_nc=3 + 256)
    D = get_D(args)
    model_saver.load('G', G)
    model_saver.load('D', D)

    cfg = Configuration()
    cfg.MODEL_NUM_CLASSES = args.label_nc
    DLV3P = deeplabv3plus(cfg)
    if args.gpu:
        DLV3P = DLV3P.cuda()
    model_saver.load('DLV3P', DLV3P)

    G_optimizer = Adam(G.parameters(), lr=args.G_lr, betas=(args.beta1, 0.999))
    D_optimizer = Adam(D.parameters(), lr=args.D_lr, betas=(args.beta1, 0.999))

    # The segmenter exposes two parameter groups, each with its own optimizer
    # and learning rate.  ``initial_lr`` is set explicitly because the
    # schedulers below are created with an explicit ``last_epoch``.
    seg_global_params, seg_backbone_params = DLV3P.get_paras()
    DLV3P_global_optimizer = torch.optim.Adam(
        [{
            'params': seg_global_params,
            'initial_lr': args.seg_lr_global
        }],
        lr=args.seg_lr_global,
        betas=(args.beta1, 0.999))
    DLV3P_backbone_optimizer = torch.optim.Adam(
        [{
            'params': seg_backbone_params,
            'initial_lr': args.seg_lr_backbone
        }],
        lr=args.seg_lr_backbone,
        betas=(args.beta1, 0.999))

    model_saver.load('G_optimizer', G_optimizer)
    model_saver.load('D_optimizer', D_optimizer)
    model_saver.load('DLV3P_global_optimizer', DLV3P_global_optimizer)
    model_saver.load('DLV3P_backbone_optimizer', DLV3P_backbone_optimizer)

    G_scheduler = get_hinge_scheduler(args, G_optimizer)
    D_scheduler = get_hinge_scheduler(args, D_optimizer)
    # Polynomial decay (power 0.9) over ``args.epochs``.
    DLV3P_global_scheduler = torch.optim.lr_scheduler.LambdaLR(
        DLV3P_global_optimizer,
        lr_lambda=lambda epoch: (1 - epoch / args.epochs)**0.9,
        last_epoch=epoch_now)
    DLV3P_backbone_scheduler = torch.optim.lr_scheduler.LambdaLR(
        DLV3P_backbone_optimizer,
        lr_lambda=lambda epoch: (1 - epoch / args.epochs)**0.9,
        last_epoch=epoch_now)

    model_saver.load('G_scheduler', G_scheduler)
    model_saver.load('D_scheduler', D_scheduler)
    model_saver.load('DLV3P_global_scheduler', DLV3P_global_scheduler)
    model_saver.load('DLV3P_backbone_scheduler', DLV3P_backbone_scheduler)

    # Align the GAN learning rates with the resume epoch (for fine-tuning).
    D_scheduler.step(epoch_now)
    G_scheduler.step(epoch_now)

    # Everything that is checkpointed, in save order.
    checkpoint_items = [
        ('G', G),
        ('D', D),
        ('DLV3P', DLV3P),
        ('G_optimizer', G_optimizer),
        ('D_optimizer', D_optimizer),
        ('DLV3P_global_optimizer', DLV3P_global_optimizer),
        ('DLV3P_backbone_optimizer', DLV3P_backbone_optimizer),
        ('G_scheduler', G_scheduler),
        ('D_scheduler', D_scheduler),
        ('DLV3P_global_scheduler', DLV3P_global_scheduler),
        ('DLV3P_backbone_scheduler', DLV3P_backbone_scheduler),
    ]

    device = get_device(args)

    GANLoss = get_GANLoss(args)
    # Optional losses are only built when enabled; the same flags guard their
    # use inside the loop.
    if args.use_ganFeat_loss:
        DFLoss = get_DFLoss(args)
    if args.use_vgg_loss:
        VGGLoss = get_VGGLoss(args)
    if args.use_low_level_loss:
        LLLoss = get_low_level_loss(args)

    LVS_loss = lovasz_softmax

    # Focal-loss class weights: smallest value returned by ``label_nums``
    # divided by each class's value, optionally rescaled per class by
    # ``args.focal_alpha_revise``.
    data_loader_focal = get_dataloader_func(args, train=True)
    data_loader_focal = tqdm(data_loader_focal)
    alpha = label_nums(data_loader_focal, label_num=args.label_nc)
    tmp_min = min(alpha)
    assert tmp_min > 0
    # Updated in place so ``alpha`` keeps the container type it came with.
    for i, count in enumerate(alpha):
        alpha[i] = tmp_min / count
    if args.focal_alpha_revise:
        assert len(args.focal_alpha_revise) == len(alpha)
        for i, factor in enumerate(args.focal_alpha_revise):
            alpha[i] = alpha[i] * factor
    print(alpha)
    FOCAL_loss = FocalLoss(gamma=2, alpha=alpha)

    if epoch_now == args.epochs:
        # Training already finished: only evaluate the stored models.
        print('get final models')
        # ``eval_fidiou`` returns ``(fid, iou)``; only the IoU is logged here.
        _, iou = eval_fidiou(args,
                             model_G=G,
                             model_seg=DLV3P,
                             data_loader=get_pix2pix_maps_dataloader(
                                 args, train=False))
        logger.log(key='iou', data=iou)
        sw.add_scalar('eval/iou', iou, epoch_now)

    for epoch in range(epoch_now, args.epochs):
        G_loss_list = []
        D_loss_list = []
        LVS_loss_list = []
        FOCAL_loss_list = []

        # flag=2: use the fake (artificial) dataset length.
        data_loader = get_dataloader_func(
            args, train=True, flag=(2 if args._usefakelen else 0))
        data_loader = tqdm(data_loader)

        for step, sample in enumerate(data_loader):
            # ---- segmentation forward pass ----
            # Its loss is optimised jointly with the generator further down.
            imgs_seg = sample['A_seg'].to(device)
            label_imgs = sample['seg'].type(torch.LongTensor).to(device)

            outputs, feature_map = DLV3P(imgs_seg)

            soft_outputs = torch.nn.functional.softmax(outputs, dim=1)
            lvs_loss = LVS_loss(soft_outputs, label_imgs, ignore=255)
            lvs_loss_value = lvs_loss.data.cpu().numpy()
            focal_loss = FOCAL_loss(outputs, label_imgs)
            focal_loss_value = focal_loss.data.cpu().numpy()

            seg_loss = (focal_loss + lvs_loss) * 0.5

            # ---- GAN ----
            imgs = sample['A'].to(device)
            maps = sample['B'].to(device)
            # The feature map is deliberately not detached, so the generator
            # loss also back-propagates into the segmenter.
            imgs_plus = torch.cat((imgs, feature_map), 1)

            # discriminator update
            D_optimizer.zero_grad()
            reals_maps = torch.cat([imgs.float(), maps.float()], dim=1)

            # Detached so that this backward pass does not reach G.
            fakes = G(imgs_plus).detach()
            fakes_maps = torch.cat([imgs.float(), fakes.float()], dim=1)

            D_real_outs = D(reals_maps)
            D_real_loss = GANLoss(D_real_outs, True)

            D_fake_outs = D(fakes_maps)
            D_fake_loss = GANLoss(D_fake_outs, False)

            D_loss = 0.5 * (D_real_loss + D_fake_loss)
            D_loss = D_loss.mean()
            D_loss.backward()
            D_loss = D_loss.item()
            D_optimizer.step()

            # generator (+ segmenter) update
            fakes = G(imgs_plus)
            fakes_maps = torch.cat([imgs.float(), fakes.float()], dim=1)
            D_fake_outs = D(fakes_maps)

            gan_loss = GANLoss(D_fake_outs, True)

            G_loss = 0
            G_loss += gan_loss
            # From here on the *_loss names hold plain floats for reporting.
            gan_loss = gan_loss.mean().item()

            if args.use_vgg_loss:
                vgg_loss = VGGLoss(fakes, imgs)
                G_loss += args.lambda_feat * vgg_loss
                vgg_loss = vgg_loss.mean().item()
            else:
                vgg_loss = 0.

            if args.use_ganFeat_loss:
                df_loss = DFLoss(D_fake_outs, D_real_outs)
                G_loss += args.lambda_feat * df_loss
                df_loss = df_loss.mean().item()
            else:
                df_loss = 0.

            if args.use_low_level_loss:
                ll_loss = LLLoss(fakes, maps)
                G_loss += args.lambda_feat * ll_loss
                ll_loss = ll_loss.mean().item()
            else:
                ll_loss = 0.

            G_loss = G_loss.mean()
            # Weighted sum that drives both the generator and the segmenter.
            G_seg_loss = args._1002arg_GANloss_alpha * G_loss + args._1002arg_segloss_alpha * seg_loss
            G_loss = G_loss.item()
            seg_loss = seg_loss.item()

            G_optimizer.zero_grad()
            DLV3P_global_optimizer.zero_grad()
            DLV3P_backbone_optimizer.zero_grad()
            G_seg_loss.backward()
            G_optimizer.step()
            DLV3P_global_optimizer.step()
            DLV3P_backbone_optimizer.step()

            data_loader.write(
                f'Epochs:{epoch}  | Dloss:{D_loss:.6f} | Gloss:{G_loss:.6f}'
                f'| GANloss:{gan_loss:.6f} | VGGloss:{vgg_loss:.6f} | DFloss:{df_loss:.6f} '
                f'| LLloss:{ll_loss:.6f} | lr_gan:{get_lr(G_optimizer):.8f}'
                f'| FOCAL_loss:{focal_loss_value:.6f}|LVS_loss:{lvs_loss_value:.6f} '
                f'| lr_global:{get_lr(DLV3P_global_optimizer):.8f}| lr_backbone:{get_lr(DLV3P_backbone_optimizer):.8f}'
            )

            G_loss_list.append(G_loss)
            D_loss_list.append(D_loss)
            LVS_loss_list.append(lvs_loss_value)
            FOCAL_loss_list.append(focal_loss_value)

            # tensorboard log (``args.tensorboard_log`` is the step interval)
            if args.tensorboard_log and step % args.tensorboard_log == 0:
                total_steps = epoch * len(data_loader) + step
                sw.add_scalar('Loss1/G', G_loss, total_steps)
                sw.add_scalar('Loss1/seg', seg_loss, total_steps)
                # Logged as a plain float, like every other scalar here.
                sw.add_scalar('Loss1/G_seg', G_seg_loss.item(), total_steps)
                sw.add_scalar('Loss/G', G_loss, total_steps)
                sw.add_scalar('Loss/D', D_loss, total_steps)
                sw.add_scalar('Loss/gan', gan_loss, total_steps)
                sw.add_scalar('Loss/vgg', vgg_loss, total_steps)
                sw.add_scalar('Loss/df', df_loss, total_steps)
                sw.add_scalar('Loss/ll', ll_loss, total_steps)
                sw.add_scalar('Loss/LVS', lvs_loss_value, total_steps)
                sw.add_scalar('Loss/focal', focal_loss_value, total_steps)
                sw.add_scalar('LR/G', get_lr(G_optimizer), total_steps)
                sw.add_scalar('LR/D', get_lr(D_optimizer), total_steps)
                sw.add_scalar('LR/global_seg', get_lr(DLV3P_global_optimizer),
                              total_steps)
                sw.add_scalar('LR/backbone_seg',
                              get_lr(DLV3P_backbone_optimizer), total_steps)

                sw.add_image('img2/realA',
                             tensor2im(imgs.data),
                             total_steps,
                             dataformats='HWC')
                sw.add_image('img2/fakeB',
                             tensor2im(fakes.data),
                             total_steps,
                             dataformats='HWC')
                sw.add_image('img2/realB',
                             tensor2im(maps.data),
                             total_steps,
                             dataformats='HWC')

                # Predicted segmentation of the first sample.  ``.cpu()`` is a
                # no-op for CPU tensors and required before ``.numpy()`` when
                # ``pred2gray`` returns a CUDA tensor.
                tmpsegmap = pred2gray(outputs)
                tmpsegmap = tmpsegmap[0].data.cpu().numpy()
                tmpsegmap = gray2rgb(tmpsegmap)
                sw.add_image('img2/fake_segB',
                             tmpsegmap,
                             total_steps,
                             dataformats='HWC')

                # Ground-truth segmentation of the first sample.
                tmpsegmap = label_imgs[0].data.cpu().numpy()
                tmpsegmap = gray2rgb(tmpsegmap)
                sw.add_image('img2/real_segB',
                             tmpsegmap,
                             total_steps,
                             dataformats='HWC')

        D_scheduler.step(epoch)
        G_scheduler.step(epoch)
        DLV3P_global_scheduler.step()
        DLV3P_backbone_scheduler.step()

        if epoch % args._val_frequency == 0 or epoch == (args.epochs - 1):
            # Evaluation uses its own batch size; clone ``args`` so the
            # training configuration is left untouched.
            args2 = copy.deepcopy(args)
            args2.batch_size = args.batch_size_eval
            fid, iou = eval_fidiou(args,
                                   model_G=G,
                                   model_seg=DLV3P,
                                   data_loader=get_pix2pix_maps_dataloader(
                                       args2, train=False))
            # Compare against the best FID logged so far, *before* the new
            # value is logged.
            # NOTE(review): what ``get_min`` returns while no FID has been
            # logged yet is not visible here — confirm that the first
            # evaluation behaves as intended.
            if fid < logger.get_min(key='FID'):
                for name, obj in checkpoint_items:
                    model_saver.save(f'best_{name}_{epoch}', obj)
            logger.log(key='FID', data=fid)
            logger.log(key='iou', data=iou)
            sw.add_scalar('eval/fid', fid, epoch)
            sw.add_scalar('eval/iou', iou, epoch)

        logger.log(key='D_loss', data=sum(D_loss_list) / len(D_loss_list))
        logger.log(key='G_loss', data=sum(G_loss_list) / len(G_loss_list))
        logger.log(key='LVS_loss',
                   data=sum(LVS_loss_list) / len(LVS_loss_list))
        logger.log(key='FOCAL_loss',
                   data=sum(FOCAL_loss_list) / len(FOCAL_loss_list))
        logger.save_log()

        # Rolling checkpoint of every network, optimizer and scheduler.
        for name, obj in checkpoint_items:
            model_saver.save(name, obj)
示例#11
0
def train(args, get_dataloader_func=get_cityscapes_dataloader):
    """Jointly train the generator ``G``, discriminator ``D`` and encoder ``E``.

    Each epoch rebuilds the training data loader, runs one discriminator
    update followed by one generator+encoder update per batch, steps the three
    LR schedulers, appends the epoch-mean losses to the logger and saves all
    models, optimizers and schedulers through ``ModelSaver``.

    Args:
        args: configuration object. This function reads ``save``,
            ``tensorboard_path``, ``G_lr``, ``D_lr``, ``E_lr``, ``beta1``,
            ``use_ganFeat_loss``, ``use_vgg_loss``, ``lambda_feat``,
            ``label_nc``, ``epochs``, ``display`` and ``tensorboard_log``;
            it is also forwarded to the various ``get_*`` factories.
        get_dataloader_func: callable invoked as
            ``get_dataloader_func(args, train=True)`` once per epoch. Each
            batch must be a mapping with the keys ``'image'``, ``'instance'``,
            ``'label'`` and ``'smask'``.

    Returns:
        None.
    """
    logger = Logger(save_path=args.save, json_name='seg2img')

    model_saver = ModelSaver(save_path=args.save,
                             name_list=['G', 'D', 'E', 'G_optimizer', 'D_optimizer', 'E_optimizer',
                                        'G_scheduler', 'D_scheduler', 'E_scheduler'])
    visualizer = Visualizer(keys=['image', 'encode_feature', 'fake', 'label', 'instance'])
    sw = SummaryWriter(args.tensorboard_path)
    G = get_G(args)
    D = get_D(args)
    E = get_E(args)
    # NOTE(review): presumably restores a previous checkpoint when one exists
    # under args.save -- confirm against ModelSaver.load.
    model_saver.load('G', G)
    model_saver.load('D', D)
    model_saver.load('E', E)

    G_optimizer = Adam(G.parameters(), lr=args.G_lr, betas=(args.beta1, 0.999))
    D_optimizer = Adam(D.parameters(), lr=args.D_lr, betas=(args.beta1, 0.999))
    E_optimizer = Adam(E.parameters(), lr=args.E_lr, betas=(args.beta1, 0.999))

    model_saver.load('G_optimizer', G_optimizer)
    model_saver.load('D_optimizer', D_optimizer)
    model_saver.load('E_optimizer', E_optimizer)

    G_scheduler = get_hinge_scheduler(args, G_optimizer)
    D_scheduler = get_hinge_scheduler(args, D_optimizer)
    E_scheduler = get_hinge_scheduler(args, E_optimizer)

    model_saver.load('G_scheduler', G_scheduler)
    model_saver.load('D_scheduler', D_scheduler)
    model_saver.load('E_scheduler', E_scheduler)

    device = get_device(args)

    GANLoss = get_GANLoss(args)
    # The optional losses are only constructed (and later only used) when the
    # matching flag is set.
    if args.use_ganFeat_loss:
        DFLoss = get_DFLoss(args)
    if args.use_vgg_loss:
        VGGLoss = get_VGGLoss(args)

    # One 'G_loss' entry is logged per finished epoch, so the number of
    # entries already in the log is the epoch to resume from.
    epoch_now = len(logger.get_data('G_loss'))
    for epoch in range(epoch_now, args.epochs):
        G_loss_list = []
        D_loss_list = []

        data_loader = get_dataloader_func(args, train=True)
        data_loader = tqdm(data_loader)

        for step, sample in enumerate(data_loader):
            imgs = sample['image'].to(device)
            instances = sample['instance'].to(device)
            labels = sample['label'].to(device)
            smasks = sample['smask'].to(device)

            instances_edge = get_edges(instances)
            one_hot_labels = label_to_one_hot(smasks.long(), n_class=args.label_nc)

            # Encoder output, used as extra conditioning channels for G.
            encode_features = E(imgs, instances)

            # ---- train the discriminator ----
            D_optimizer.zero_grad()
            labels_instE_encodeF = torch.cat([one_hot_labels.float(), instances_edge.float(), encode_features.float()],
                                             dim=1)
            # Detached so the discriminator loss does not back-propagate into
            # G or E.
            fakes = G(labels_instE_encodeF).detach()

            labels_instE_realimgs = torch.cat([one_hot_labels.float(), instances_edge.float(), imgs.float()], dim=1)
            D_real_outs = D(labels_instE_realimgs)
            D_real_loss = GANLoss(D_real_outs, True)

            labels_instE_fakeimgs = torch.cat([one_hot_labels.float(), instances_edge.float(), fakes.float()], dim=1)
            D_fake_outs = D(labels_instE_fakeimgs)
            D_fake_loss = GANLoss(D_fake_outs, False)

            D_loss = 0.5 * (D_real_loss + D_fake_loss)
            D_loss = D_loss.mean()
            D_loss.backward()
            D_loss = D_loss.item()  # keep only the Python float for logging
            D_optimizer.step()

            # ---- train generator and encoder ----
            G_optimizer.zero_grad()
            E_optimizer.zero_grad()
            # Fresh, non-detached forward pass so gradients reach G and E.
            fakes = G(labels_instE_encodeF)
            labels_instE_fakeimgs = torch.cat([one_hot_labels.float(), instances_edge.float(), fakes.float()], dim=1)
            D_fake_outs = D(labels_instE_fakeimgs)

            gan_loss = GANLoss(D_fake_outs, True)

            G_loss = 0
            G_loss += gan_loss
            gan_loss = gan_loss.mean().item()

            if args.use_vgg_loss:
                vgg_loss = VGGLoss(fakes, imgs)
                # The VGG term is weighted by lambda_feat too; no separate
                # VGG weight is read from args here.
                G_loss += args.lambda_feat * vgg_loss
                vgg_loss = vgg_loss.mean().item()
            else:
                vgg_loss = 0.

            if args.use_ganFeat_loss:
                # Compares discriminator outputs on the fake batch with
                # D_real_outs from the discriminator step above.
                df_loss = DFLoss(D_fake_outs, D_real_outs)
                G_loss += args.lambda_feat * df_loss
                df_loss = df_loss.mean().item()
            else:
                df_loss = 0.

            G_loss = G_loss.mean()
            G_loss.backward()
            G_loss = G_loss.item()

            G_optimizer.step()
            E_optimizer.step()

            data_loader.write(f'Epochs:{epoch} | Dloss:{D_loss:.6f} | Gloss:{G_loss:.6f}'
                              f'| GANloss:{gan_loss:.6f} | VGGloss:{vgg_loss:.6f} '
                              f'| DFloss:{df_loss:.6f} | lr:{get_lr(G_optimizer):.8f}')

            G_loss_list.append(G_loss)
            D_loss_list.append(D_loss)

            # display (args.display doubles as on/off switch and step interval)
            if args.display and step % args.display == 0:
                visualizer.display(transforms.ToPILImage()(encode_features[0].cpu()), 'encode_feature')
                visualizer.display(transforms.ToPILImage()(imgs[0].cpu()), 'image')
                visualizer.display(transforms.ToPILImage()(fakes[0].cpu()), 'fake')
                visualizer.display(transforms.ToPILImage()(labels[0].cpu() * 15), 'label')
                visualizer.display(transforms.ToPILImage()(instances[0].cpu() * 15), 'instance')

            # tensorboard log (args.tensorboard_log is switch and interval)
            if args.tensorboard_log and step % args.tensorboard_log == 0:
                total_steps = epoch * len(data_loader) + step
                sw.add_scalar('Loss/G', G_loss, total_steps)
                sw.add_scalar('Loss/D', D_loss, total_steps)
                sw.add_scalar('Loss/gan', gan_loss, total_steps)
                sw.add_scalar('Loss/vgg', vgg_loss, total_steps)
                sw.add_scalar('Loss/df', df_loss, total_steps)

                sw.add_scalar('LR/G', get_lr(G_optimizer), total_steps)
                sw.add_scalar('LR/D', get_lr(D_optimizer), total_steps)
                sw.add_scalar('LR/E', get_lr(E_optimizer), total_steps)

                # Images use the same global step as the scalars; the
                # per-epoch `step` would make every epoch write to the same
                # step indices.
                sw.add_image('img/real', imgs[0].cpu(), total_steps)
                sw.add_image('img/fake', fakes[0].cpu(), total_steps)
                sw.add_image('visual/encode_feature', encode_features[0].cpu(), total_steps)
                sw.add_image('visual/instance', instances[0].cpu(), total_steps)
                sw.add_image('visual/label', labels[0].cpu(), total_steps)

        D_scheduler.step(epoch)
        G_scheduler.step(epoch)
        E_scheduler.step(epoch)

        # Epoch-mean losses; the 'G_loss' series also drives the resume logic.
        logger.log(key='D_loss', data=sum(D_loss_list) / float(len(D_loss_list)))
        logger.log(key='G_loss', data=sum(G_loss_list) / float(len(G_loss_list)))
        logger.save_log()
        logger.visualize()

        model_saver.save('G', G)
        model_saver.save('D', D)
        model_saver.save('E', E)

        model_saver.save('G_optimizer', G_optimizer)
        model_saver.save('D_optimizer', D_optimizer)
        model_saver.save('E_optimizer', E_optimizer)

        model_saver.save('G_scheduler', G_scheduler)
        model_saver.save('D_scheduler', D_scheduler)
        model_saver.save('E_scheduler', E_scheduler)
示例#12
0
def eval(args, model, data_loader, model_seg=None):
    """Run ``model`` over ``data_loader``, dump images and compute FID (and IoU).

    For every batch the source image (``sample['A']``), the target
    (``sample['B']``) and the generated output are written as ``.tif`` files
    into ``real_source``, ``real_result`` and ``fake_result`` under
    ``args.save``. When ``model_seg`` is given, its feature map is resized to
    64x64 and 128x128 and passed to ``model`` as two extra inputs, its
    predictions are written to ``seg_result``, and an IoU value is computed
    against ``sample['seg']``.

    Args:
        args: configuration object; ``save`` and ``gpu`` are read here and the
            object is forwarded to ``get_device``.
        model: generator under evaluation. It is put in eval mode for the
            duration of the call and switched back to train mode on exit.
        data_loader: iterable of batches (mappings) with the keys ``'A'``,
            ``'B'`` and ``'A_paths'`` (plus ``'seg'`` when ``model_seg`` is
            given).
        model_seg: optional segmentation model called as
            ``outputs, feature_map = model_seg(imgs)``.

    Returns:
        ``(fid, iou)``. ``iou`` is ``None`` when ``model_seg`` is ``None`` or
        no batch was processed.
    """
    use_seg = model_seg is not None
    device = get_device(args)
    data_loader = tqdm(data_loader)
    model.eval()
    model = model.to(device)

    label_preds = []
    label_targets = []
    n_class = None  # set from the segmentation output inside the loop
    seg_was_training = False
    if use_seg:
        # Project-local helper, imported once instead of on every batch.
        from src.pix2pixHD.myutils import pred2gray
        # Remember the mode so a model that is being trained by the caller is
        # not silently left in eval mode after this function returns.
        seg_was_training = getattr(model_seg, 'training', False)
        model_seg.eval()
        model_seg = model_seg.to(device)

    fake_dir = osp.join(args.save, 'fake_result')
    real_dir = osp.join(args.save, 'real_result')
    A_dir = osp.join(args.save, 'real_source')
    seg_dir = osp.join(args.save, 'seg_result')
    for out_dir in (real_dir, fake_dir, A_dir, seg_dir):
        create_dir(out_dir)

    for sample in data_loader:
        imgs = sample['A'].to(device)
        maps = sample['B'].to(device)
        im_name = sample['A_paths']
        with torch.no_grad():
            if not use_seg:
                fakes = model(imgs)
            else:
                outputs, feature_map = model_seg(imgs)
                # F.interpolate replaces the deprecated F.upsample;
                # align_corners=False is what F.upsample fell back to.
                input_2 = F.interpolate(feature_map,
                                        size=(64, 64),
                                        mode="bilinear",
                                        align_corners=False)
                input_3 = F.interpolate(feature_map,
                                        size=(128, 128),
                                        mode="bilinear",
                                        align_corners=False)
                fakes = model(imgs, input_2, input_3)
                # Collect per-pixel class predictions and targets for the IoU
                # computation after the loop.
                bs, n_class, h, w = outputs.shape
                pred = outputs.detach().cpu().numpy().argmax(axis=1)
                target = sample['seg'].cpu().numpy().reshape(bs, h, w)
                label_preds.append(pred)
                label_targets.append(target)

        if use_seg:
            outputs = pred2gray(outputs)

        batch_size = imgs.size(0)
        for b in range(batch_size):
            # File name = "<grandparent dir>_<parent dir>_<base name up to the
            # first dot>", so files from different folders do not collide.
            directory, base_name = osp.split(im_name[b])
            dir_parts = directory.split(os.sep)
            stem = base_name.split('.')[0]
            file_name = dir_parts[-2] + '_' + dir_parts[-1] + '_' + stem

            real_file = osp.join(real_dir, f'{file_name}.tif')
            fake_file = osp.join(fake_dir, f'{file_name}.tif')
            A_file = osp.join(A_dir, f'{file_name}.tif')

            from_std_tensor_save_image(filename=real_file, data=maps[b].cpu())
            from_std_tensor_save_image(filename=fake_file, data=fakes[b].cpu())
            from_std_tensor_save_image(filename=A_file, data=imgs[b].cpu())
            if use_seg:
                seg_file = osp.join(seg_dir, f'{file_name}.tif')
                from_std_tensor_save_image(filename=seg_file,
                                           data=outputs[b].cpu())

    fid = fid_score(real_path=real_dir, fake_path=fake_dir, gpu=str(args.gpu))
    print(f'===> fid score:{fid:.4f}')

    iou = None
    # Guard on collected predictions: with an empty loader n_class was never
    # set and there is nothing to score.
    if use_seg and label_preds:
        from src.pix2pixHD.eval_iou import label_accuracy_score
        _, _, iou, _, _ = label_accuracy_score(label_targets, label_preds,
                                               n_class)

    model.train()
    if seg_was_training:
        model_seg.train()
    return fid, iou