Пример #1
0
def main(args):
    """Train a ResNet-101 IBN-a DeepLab segmentation net on the source domain.

    After every epoch the network is scored on the target validation split
    with ``test_miou`` and its weights are saved as
    ``{epoch}_{round(mIoU * 100)}.pth`` under ``args.save_path_prefix``.

    Args:
        args: parsed command-line options. Read here: ``tensorboard_log_dir``,
            ``input_size`` and ``input_size_target`` (``"W,H"`` strings),
            ``data_dir``/``data_list`` (training set), ``data_dir_target``/
            ``data_list_target`` (validation set), ``batch_size``,
            ``num_workers``, ``model_path_prefix``, ``n_classes``,
            ``learning_rate``, ``momentum``, ``ignore_index``, ``num_epoch``,
            ``print_freq``, ``show_img_freq`` and ``save_path_prefix``.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    # Geometric augmentation applied jointly to the image and its mask.
    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    # ImageNet mean / std.
    normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()

    # The training-set reader is chosen from the data path ('5' -> GTA5).
    if '5' in args.data_dir:
        dataset = GTA5DataSetLMDB(
            args.data_dir,
            args.data_list,
            joint_transform=joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    else:
        dataset = CityscapesDataSetLMDB(
            args.data_dir,
            args.data_list,
            joint_transform=joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    loader = data.DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    # Validation images are not augmented or resized here.
    val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        transform=input_transform,
        target_transform=target_transform)
    # Evaluation order does not matter, so the validation set is not shuffled.
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    # Predictions are resized to the target resolution before scoring.
    upsample = nn.Upsample(size=(h_target, w_target),
                           mode='bilinear',
                           align_corners=True)

    net = resnet101_ibn_a_deeplab(args.model_path_prefix,
                                  n_classes=args.n_classes)
    # Batches are moved to the GPU below and DataParallel requires the wrapped
    # module to live on device_ids[0], so move the model there first (and
    # before the optimizer is built over its parameters).
    net = net.cuda()
    optimizer = torch.optim.SGD(net.parameters(), args.learning_rate,
                                args.momentum)
    net = torch.nn.DataParallel(net)
    # Loss is summed over all non-ignored pixels (same as the deprecated
    # ``size_average=False``).
    criterion = torch.nn.CrossEntropyLoss(reduction='sum',
                                          ignore_index=args.ignore_index)

    num_batches = len(loader)
    for epoch in range(args.num_epoch):

        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(loader):
            show_fig = (batch_index + 1) % args.show_img_freq == 0
            # Global 1-based step index used for TensorBoard.
            iteration = batch_index + 1 + epoch * num_batches

            net.train()
            img, label, name = batch_data
            img = img.cuda()
            label_cuda = label.cuda()
            data_time_rec.update(time.time() - tem_time)

            output = net(img)
            loss = criterion(output, label_cuda)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            loss_rec.update(loss_value)
            writer.add_scalar('A_seg_loss', loss_value, iteration)
            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Loss: {loss_rec.avg:.2f}')
            if show_fig:
                # Log the first sample's prediction next to its ground truth.
                base_lr = optimizer.param_groups[0]["lr"]
                pred = torch.argmax(output, dim=1).detach()[0, ...].cpu()
                fig, axes = plt.subplots(2, 1, figsize=(12, 14))
                axes = axes.flat
                axes[0].imshow(colorize_mask(pred.numpy()))
                axes[0].set_title(name[0])
                axes[1].imshow(colorize_mask(label[0, ...].numpy()))
                axes[1].set_title(f'seg_true_{base_lr:.6f}')
                writer.add_figure('A_seg', fig, iteration)

        mean_iu = test_miou(net, val_loader, upsample,
                            './ae_seg/dataset/info.json')
        torch.save(
            net.module.state_dict(),
            os.path.join(args.save_path_prefix,
                         f'{epoch:d}_{mean_iu*100:.0f}.pth'))

    writer.close()
Пример #2
0
def main():
    """Create the model and start the training.

    Trains an FCN-8s on the source domain with Caffe-style (BGR, [0, 255],
    mean-subtracted) inputs. After every epoch the net is scored on the target
    validation split with ``test_miou``. The weights are written to
    ``<snapshot_dir>/final_fcn.pth`` whenever that score beats the best so far.

    Options come from ``get_arguments()``. Read here: ``input_size`` and
    ``input_size_target`` (``"W,H"`` strings), ``num_classes``,
    ``model_path_prefix``, ``data_dir``/``data_list``,
    ``data_dir_target``/``data_list_target``, ``batch_size``, ``num_workers``,
    ``learning_rate``, ``momentum``, ``weight_decay``, ``num_epoch``,
    ``print_freq`` and ``snapshot_dir``.
    """
    args = get_arguments()

    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))
    n_class = args.num_classes

    # Create network
    student_net = FCN8s(n_class, args.model_path_prefix)
    student_net = torch.nn.DataParallel(student_net)
    student_net = student_net.cuda()

    # Caffe-style VGG preprocessing: BGR order, [0, 255] range, per-channel
    # mean subtraction with unit std.
    mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    # Validation images are resized to the training input size (no joint
    # transform is applied to val images and masks).
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    # The source-set reader is chosen from the data path ('5' -> GTA5).
    if '5' in args.data_dir:
        src_dataset = GTA5DataSetLMDB(
            args.data_dir,
            args.data_list,
            joint_transform=train_joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    else:
        src_dataset = CityscapesDataSetLMDB(
            args.data_dir,
            args.data_list,
            joint_transform=train_joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    src_loader = data.DataLoader(src_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    tgt_val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    # Evaluation order does not matter, so the validation set is not shuffled.
    tgt_val_loader = data.DataLoader(
        tgt_val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    optimizer = optim.SGD(student_net.parameters(),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)

    # Predictions are resized to the target resolution before scoring.
    # align_corners=False is PyTorch's default; stating it keeps results
    # unchanged and silences the bilinear-upsampling warning.
    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear',
                           align_corners=False)

    # Pixel-summed CE (same as the deprecated ``size_average=False``); it is
    # divided by the batch size below to give a per-image loss.
    src_criterion = torch.nn.CrossEntropyLoss(ignore_index=255,
                                              reduction='sum')

    num_batches = len(src_loader)
    highest = 0

    for epoch in range(args.num_epoch):

        cls_loss_rec = AverageMeter()
        miu_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, src_data in enumerate(src_loader):
            student_net.train()
            optimizer.zero_grad()

            # train with source
            src_images, src_label, _ = src_data
            src_images = src_images.cuda()
            src_label = src_label.cuda()
            data_time_rec.update(time.time() - tem_time)

            src_output = student_net(src_images)

            # Segmentation Loss
            cls_loss_value = src_criterion(src_output, src_label)
            cls_loss_value /= src_images.shape[0]

            cls_loss_value.backward()
            optimizer.step()

            # Training-batch mIoU, for monitoring only.
            _, predict_labels = torch.max(src_output, 1)
            lbl_pred = predict_labels.detach().cpu().numpy()
            lbl_true = src_label.detach().cpu().numpy()
            _, _, _, mean_iu, _ = _evaluate(lbl_pred, lbl_true, n_class)

            cls_loss_rec.update(cls_loss_value.item())
            miu_rec.update(mean_iu)

            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Mean iu: {miu_rec.avg*100:.1f}   '
                    f'CLS: {cls_loss_rec.avg:.2f}')

        miu = test_miou(student_net, tgt_val_loader, upsample,
                        './dataset/info.json')
        # Keep only the best-scoring weights.
        if miu > highest:
            torch.save(student_net.module.state_dict(),
                       osp.join(args.snapshot_dir, 'final_fcn.pth'))
            highest = miu
            print('>' * 50 + f'save highest with {miu:.2%}')
Пример #3
0
def main(args):
    """Train a PSPNet (ResNet-101 backbone) on the source domain.

    The learning rate is set every iteration by ``utils.LR_Scheduler``. After
    every epoch the net is scored on the target validation split with
    ``test_miou``, and its weights are saved as
    ``{epoch}_{round(mIoU * 100)}.pth`` under ``args.save_path_prefix``.

    Args:
        args: parsed command-line options. Read here: ``tensorboard_log_dir``,
            ``input_size`` and ``input_size_target`` (``"W,H"`` strings),
            ``data_dir``/``data_list``, ``data_dir_target``/
            ``data_list_target``, ``batch_size``, ``num_workers``,
            ``n_classes``, ``model_path_prefix``, ``learning_rate``,
            ``momentum``, ``weight_decay``, ``lr_scheduler``, ``lr_step``,
            ``num_epoch``, ``print_freq`` and ``save_path_prefix``.
    """
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    # Geometric augmentation applied jointly to the image and its mask.
    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    # ImageNet mean / std.
    normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()

    # The training-set reader is chosen from the data path ('5' -> GTA5).
    if '5' in args.data_dir:
        dataset = GTA5DataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
    else:
        dataset = CityscapesDataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
    loader = data.DataLoader(
        dataset, batch_size=args.batch_size,
        shuffle=True, num_workers=args.num_workers, pin_memory=True
    )
    # Validation images are not augmented or resized here.
    val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target, args.data_list_target,
        transform=input_transform, target_transform=target_transform
    )
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size,
        shuffle=False, num_workers=args.num_workers, pin_memory=True
    )

    # Predictions are resized to the target resolution before scoring.
    upsample = nn.Upsample(size=(h_target, w_target),
                           mode='bilinear', align_corners=True)

    net = PSP(
        nclass=args.n_classes, backbone='resnet101',
        root=args.model_path_prefix, norm_layer=BatchNorm2d,
    )

    # The head and auxiliary layer train at 10x the backbone learning rate.
    params_list = [
        {'params': net.pretrained.parameters(), 'lr': args.learning_rate},
        {'params': net.head.parameters(), 'lr': args.learning_rate * 10},
        {'params': net.auxlayer.parameters(), 'lr': args.learning_rate * 10},
    ]
    optimizer = torch.optim.SGD(params_list,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    criterion = SegmentationLosses(nclass=args.n_classes, aux=True,
                                   ignore_index=255)

    net = DataParallelModel(net).cuda()
    criterion = DataParallelCriterion(criterion).cuda()

    num_batches = len(loader)
    logger = utils.create_logger(args.tensorboard_log_dir, 'PSP_train')
    scheduler = utils.LR_Scheduler(args.lr_scheduler, args.learning_rate,
                                   args.num_epoch, num_batches, logger=logger,
                                   lr_step=args.lr_step)

    # Wrapper used for validation with test_miou.
    net_eval = Eval(net)

    # Best validation mIoU so far; it is handed to the scheduler on every call.
    best_pred = 0.0
    for epoch in range(args.num_epoch):

        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(loader):
            # Per-iteration learning-rate update.
            scheduler(optimizer, batch_index, epoch, best_pred)
            # Global 1-based step index used for TensorBoard.
            iteration = batch_index + 1 + epoch * num_batches

            net.train()
            img, label, _ = batch_data
            img = img.cuda()
            label_cuda = label.cuda()
            data_time_rec.update(time.time() - tem_time)

            output = net(img)
            loss = criterion(output, label_cuda)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            loss_rec.update(loss_value)
            writer.add_scalar('A_seg_loss', loss_value, iteration)
            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Loss: {loss_rec.avg:.2f}'
                )

        mean_iu = test_miou(net_eval, val_loader, upsample,
                            './style_seg/dataset/info.json')
        best_pred = max(best_pred, mean_iu)
        torch.save(
            net.module.state_dict(),
            os.path.join(args.save_path_prefix, f'{epoch:d}_{mean_iu*100:.0f}.pth')
        )

    writer.close()
Пример #4
0
def main(args):
    """Fine-tune a segmentation net on Cityscapes with generator-based pseudo labels.

    A frozen copy of the segmentation net (``net_static``) labels target
    images after they pass through the generator ``gen_model``. The trainable
    copy (``net``) is trained with cross-entropy to predict those labels from
    the raw target images.

    The net is evaluated at three points:

    * before training, both directly and behind the generator;
    * every ``args.checkpoint_freq`` iterations, keeping the best weights in
      ``cityscapes_best_fcn.pth``;
    * once more at the end, with the best weights from this run.

    Args:
        args: parsed command-line options. Read here: ``tensorboard_log_dir``,
            ``input_size``/``input_size_target`` (``"W,H"`` strings),
            ``seg_net`` (``'fcn'`` or ``'deeplab_ibn'``), ``data_dir_target``/
            ``data_list_target`` (training images), ``data_dir_val``/
            ``data_list_val`` (validation), ``batch_size``, ``num_workers``,
            ``n_classes``, ``resume``, ``fcn_name``, ``gen_name``,
            ``learning_rate``, ``momentum``, ``percent``, ``num_epoch``,
            ``print_freq``, ``checkpoint_freq`` and ``save_path_prefix``.

    Raises:
        ValueError: if ``args.seg_net`` is not ``'fcn'`` or ``'deeplab_ibn'``.
    """
    if args.seg_net not in ('fcn', 'deeplab_ibn'):
        raise ValueError(f'unsupported seg_net: {args.seg_net!r}')

    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    # Training images are mapped to [-1, 1]; ``de_normalize`` below undoes
    # this before the segmentation-specific normalization is applied.
    gen_mean_std = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    tgt_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*gen_mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    # Validation images go straight into the segmentation net, so they use
    # that net's own preprocessing.
    if args.seg_net == 'fcn':
        # Caffe-style VGG input: BGR order, [0, 255] range, mean-subtracted.
        mean_std = ([103.939, 116.779, 123.68], [1.0, 1.0, 1.0])
        val_input_transform = standard_transforms.Compose([
            extended_transforms.FreeScale((h, w)),
            extended_transforms.FlipChannels(),
            standard_transforms.ToTensor(),
            standard_transforms.Lambda(lambda x: x.mul_(255)),
            standard_transforms.Normalize(*mean_std),
        ])
    else:
        imagenet_mean_std = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        val_input_transform = standard_transforms.Compose([
            extended_transforms.FreeScale((h, w)),
            standard_transforms.ToTensor(),
            standard_transforms.Normalize(*imagenet_mean_std),
        ])

    tgt_dataset = Cityscapes16DataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        joint_transform=joint_transform,
        transform=tgt_input_transform,
        target_transform=target_transform,
    )
    tgt_loader = data.DataLoader(tgt_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True,
                                 drop_last=True)
    val_dataset = Cityscapes16DataSetLMDB(
        args.data_dir_val,
        args.data_list_val,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    val_loader = data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # Predictions are resized to the target resolution before scoring.
    upsample = nn.Upsample(size=(h_target, w_target),
                           mode='bilinear',
                           align_corners=True)

    # ``net`` is fine-tuned; ``net_static`` is a frozen copy of the same
    # checkpoint that produces the pseudo labels.
    if args.seg_net == 'fcn':
        net = FCN8s(args.n_classes, pretrained=False)
        net_static = FCN8s(args.n_classes, pretrained=False)
        file_name = os.path.join(args.resume, args.fcn_name)
    else:  # 'deeplab_ibn'
        # NOTE(review): built with the constructor's defaults; confirm they
        # match args.n_classes and the saved checkpoint.
        net = resnet101_ibn_a_deeplab()
        net_static = resnet101_ibn_a_deeplab()
        file_name = os.path.join(args.resume, 'deeplab_ibn.pth')
    # load_state_dict copies the tensors, so both nets can share one dict.
    state_dict = torch.load(file_name)
    net.load_state_dict(state_dict)
    net_static.load_state_dict(state_dict)
    for param in net_static.parameters():
        param.requires_grad = False

    optimizer = torch.optim.SGD(net.parameters(), args.learning_rate,
                                args.momentum)
    net = torch.nn.DataParallel(net.cuda())
    net_static = torch.nn.DataParallel(net_static.cuda())
    # Mean CE over non-ignored pixels; 255 marks pixels excluded from the loss.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    # Frozen generator, always in eval mode.
    gen_model = define_G()
    gen_model.load_state_dict(
        torch.load(os.path.join(args.resume, args.gen_name)))
    gen_model.eval()
    for param in gen_model.parameters():
        param.requires_grad = False
    gen_model = torch.nn.DataParallel(gen_model.cuda())

    def normalize(x, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Map a [0, 1] RGB batch into the segmentation net's input space.

        For ``fcn``, mirror ``val_input_transform``: reorder the channels to
        BGR, scale to [0, 255] and subtract the BGR mean. ``mean`` and ``std``
        are ignored in that case. Otherwise return ``(x - mean) / std``.
        Assumes ``x`` is (N, 3, H, W) -- TODO confirm.
        """
        if args.seg_net == 'fcn':
            bgr_mean = x.new_tensor([103.939, 116.779, 123.68]).view(1, -1, 1, 1)
            return x.flip(1).mul(255.0) - bgr_mean
        mean_t = x.new_tensor(mean).view(1, -1, 1, 1)
        std_t = x.new_tensor(std).view(1, -1, 1, 1)
        return (x - mean_t) / std_t

    def de_normalize(x, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        """Invert ``Normalize(mean, std)``: return ``x * std + mean``.

        With the defaults this maps [-1, 1] back to [0, 1].
        """
        mean_t = x.new_tensor(mean).view(1, -1, 1, 1)
        std_t = x.new_tensor(std).view(1, -1, 1, 1)
        return x * std_t + mean_t

    # ###################################################
    # direct test with gen
    # ###################################################
    print('Direct Test')
    mean_iu = test_miou(net, val_loader, upsample, './dataset/info.json')
    # The generator path expects [-1, 1] inputs, so it gets its own loader.
    direct_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        standard_transforms.ToTensor(),
        standard_transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_dataset_direct = Cityscapes16DataSetLMDB(
        args.data_dir_val,
        args.data_list_val,
        transform=direct_input_transform,
        target_transform=target_transform,
    )
    val_loader_direct = data.DataLoader(val_dataset_direct,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        num_workers=args.num_workers,
                                        pin_memory=True,
                                        drop_last=False)

    class NewModel(object):
        """Callable that runs generator -> renormalize -> segmentation net."""

        def __init__(self, gen_net, val_net):
            self.gen_net = gen_net
            self.val_net = val_net

        def __call__(self, x):
            x = de_normalize(self.gen_net(x))
            new_x = normalize(x)
            out = self.val_net(new_x)
            return out

        def eval(self):
            self.gen_net.eval()
            self.val_net.eval()

    new_model = NewModel(gen_model, net)
    print('Test with Gen')
    mean_iu = test_miou(new_model, val_loader_direct, upsample,
                        './dataset/info.json')

    softmax = torch.nn.Softmax(dim=1)
    num_batches = len(tgt_loader)
    highest = 0
    best_path = os.path.join(args.save_path_prefix, 'cityscapes_best_fcn.pth')
    # The pseudo-labelling copy is never trained; keep it in eval mode.
    net_static.eval()

    for epoch in range(args.num_epoch):

        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(tgt_loader):
            # Global 1-based step index (TensorBoard and checkpoint cadence).
            iteration = batch_index + 1 + epoch * num_batches

            # Re-enter training mode each step; the periodic evaluation below
            # presumably switches the net to eval mode.
            net.train()

            img, _, _ = batch_data
            img = img.cuda()
            data_time_rec.update(time.time() - tem_time)

            with torch.no_grad():
                gen_output = gen_model(img)
                gen_seg_output_logits = net_static(
                    normalize(de_normalize(gen_output)))
            ori_seg_output_logits = net(normalize(de_normalize(img)))

            # Pseudo labels are the frozen net's argmax on the generated image;
            # max_value is the matching softmax confidence.
            max_value, label = torch.max(softmax(gen_seg_output_logits), dim=1)
            label_mask = torch.zeros_like(label, dtype=torch.bool)
            for tem_label in range(args.n_classes):
                tem_mask = label == tem_label
                if torch.sum(tem_mask) < 5:
                    continue
                value_vec = max_value[tem_mask]
                # Clamp k to [1, len] so topk(...)[0][0] cannot fail on an
                # empty result when percent * count < 1.
                num_top = int(args.percent * value_vec.shape[0])
                num_top = min(max(num_top, 1), value_vec.shape[0])
                # NOTE(review): topk values are sorted in descending order, so
                # [0][0] is this class's maximum confidence. The strict '>'
                # below therefore never selects a pixel of this class, and no
                # pixel is ignored. Confirm whether the k-th value ([0][-1])
                # and/or a '<' comparison was intended.
                large_value = torch.topk(value_vec, num_top)[0][0]
                large_mask = max_value > large_value
                label_mask = label_mask | (tem_mask & large_mask)
            label[label_mask] = 255

            loss = criterion(ori_seg_output_logits, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            loss_rec.update(loss_value)
            writer.add_scalar('A_seg_loss', loss_value, iteration)
            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Loss: {loss_rec.avg:.2f}')
            if iteration % args.checkpoint_freq == 0:
                mean_iu = test_miou(net,
                                    val_loader,
                                    upsample,
                                    './dataset/info.json',
                                    print_results=False)
                if mean_iu > highest:
                    torch.save(net.module.state_dict(), best_path)
                    highest = mean_iu
                    print(f'save fcn model with {mean_iu:.2%}')

    print(('-' * 100 + '\n') * 3)
    print('>' * 50 + 'Final Model')
    # Reload only if this run wrote a checkpoint. Otherwise the file is
    # missing, or left over from an earlier run.
    if highest > 0:
        net.module.load_state_dict(torch.load(best_path))
    else:
        print('No checkpoint saved in this run; evaluating the last weights')
    mean_iu = test_miou(net, val_loader, upsample, './dataset/info.json')

    writer.close()
Пример #5
0
    def train(self, src_loader, tgt_loader, val_loader, writer):
        print('>'*50+'Direct Test')
        direct_model = _AddNorm(self.seg, self.de_normalize, self.fcn_normalize)
        mean_iu = test_miou(
            direct_model, val_loader,
            direct_model.upsample, self.info_json,
        )

        highest_miu = 0
        tgt_domain_label = 1
        src_domain_label = 0
        num_batches = min(len(src_loader), len(tgt_loader))

        # self.G_optimizer.param_groups[0]['lr'] *= 10
        # self.D_optimizer.param_groups[0]['lr'] *= 10
        # self.D_optimizer.param_groups[1]['lr'] *= 10
        for epoch in range(self.opt.warm_up_epoch):
            for batch_index, batch_data in enumerate(zip(src_loader, tgt_loader)):
                self.G.train()
                src_batch, tgt_batch = batch_data
                src_img, _, src_name = src_batch
                tgt_img, _, tgt_name = tgt_batch
                src_img_cuda = src_img.cuda()
                tgt_img_cuda = tgt_img.cuda()

                rec_tgt = self.G(tgt_img_cuda)  # output [-1,1]
                rec_loss = self.mse_criterion(rec_tgt, tgt_img_cuda)
                self.G_optimizer.zero_grad()
                rec_loss.backward()
                self.G_optimizer.step()

                tgt_img_cuda = self.de_normalize(tgt_img_cuda).detach()
                tgt_D_loss = self.compute_discrim_loss(
                    tgt_img_cuda, tgt_domain_label
                )
                rec_D_loss = self.compute_discrim_loss(
                    src_img_cuda, src_domain_label
                )
                D_loss = tgt_D_loss + rec_D_loss
                self.D_optimizer.zero_grad()
                D_loss.backward()
                self.D_optimizer.step()

                if (batch_index+1) % self.opt.print_freq == 0:
                    print(
                        f'Warm Up Epoch [{epoch+1:d}/{self.opt.warm_up_epoch:d}]'
                        f'[{batch_index+1:d}/{num_batches:d}]\t'
                        f'G Loss: {rec_loss.item():.2f}   '
                        f'D Loss: {D_loss.item():.2f}'
                    )

        # self.G_optimizer.param_groups[0]['lr'] /= 10
        # self.D_optimizer.param_groups[0]['lr'] /= 10
        # self.D_optimizer.param_groups[1]['lr'] /= 10
        for epoch in range(self.opt.num_epoch):

            content_loss_rec = AverageMeter()
            style_loss1_rec = AverageMeter()
            style_loss2_rec = AverageMeter()
            data_time_rec = AverageMeter()
            batch_time_rec = AverageMeter()

            tem_time = time.time()
            for batch_index, batch_data in enumerate(zip(src_loader, tgt_loader)):
                iteration = batch_index+1+epoch*num_batches

                self.G.train()
                src_batch, tgt_batch = batch_data
                src_img, _, src_name = src_batch
                tgt_img, tgt_label, tgt_name = tgt_batch
                src_img_cuda = src_img.cuda()
                tgt_img_cuda = tgt_img.cuda()
                data_time_rec.update(time.time()-tem_time)

                rec_tgt = self.G(tgt_img_cuda) # output [-1,1]
                if (batch_index+1) % self.opt.show_img_freq == 0:
                    rec_results = rec_tgt.detach().clone().cpu()
                # return to [0,1], for VGG takes input [0,1]
                rec_tgt_de_norm = self.de_normalize(rec_tgt) 

                '''
                --------------------------------------------------
                NOTICE: 
                DO NOT ADD DE-NORM HERE 
                WITHOUT NORM WE CAN ADD MORE NOISE TO CONTENT LOSS 
                --------------------------------------------------
                '''
                content_loss = self.compute_content_loss(rec_tgt_de_norm, tgt_img_cuda)

                # style_loss1, style_loss2 =\
                #     self.compute_style_loss(rec_tgt, src_img_cuda)
                style_loss1 = torch.zeros(1).cuda()
                style_loss2 = torch.zeros(1).cuda()
                loss_style = content_loss * self.lambda_values[0] +\
                             style_loss1 * self.lambda_values[1] +\
                             style_loss2 * self.lambda_values[2]

                # adv train G
                for param in self.D.parameters():
                    param.requires_grad = False
                

                adv_tgt_rec_discrim_loss = self.compute_discrim_loss(
                    rec_tgt_de_norm, src_domain_label
                )
                G_loss = loss_style +\
                         adv_tgt_rec_discrim_loss * self.lambda_values[3]

                self.G_optimizer.zero_grad()
                G_loss.backward()
                self.G_optimizer.step()

                # train D
                for param in self.D.parameters():
                    param.requires_grad = True

                # add de norm here, since D do not need noise for training
                tgt_img_cuda_de_norm = self.de_normalize(tgt_img_cuda)

                rec_tgt_de_norm = rec_tgt_de_norm.detach()

                tgt_rec_discrim_loss = self.compute_discrim_loss(
                    rec_tgt_de_norm, tgt_domain_label
                )
                # tgt_rec_discrim_loss = 0
                tgt_discrim_loss = self.compute_discrim_loss(
                    tgt_img_cuda_de_norm, tgt_domain_label
                )
                src_discrim_loss = self.compute_discrim_loss(
                    src_img_cuda, src_domain_label
                )
                D_loss = 0.5 * (tgt_rec_discrim_loss + tgt_discrim_loss) +\
                         src_discrim_loss

                self.D_optimizer.zero_grad()
                D_loss.backward()
                self.D_optimizer.step()

                content_loss_rec.update(content_loss.item())
                style_loss1_rec.update(style_loss1.item())
                style_loss2_rec.update(style_loss2.item())
                writer.add_scalar(
                    'AA_content_loss', content_loss.item(), iteration
                )
                writer.add_scalar(
                    'AA_style_loss_1', style_loss1.item(), iteration
                )
                writer.add_scalar(
                    'AA_style_loss_2', style_loss2.item(), iteration
                )
                writer.add_scalar(
                    'AA_G_loss', G_loss.item(), iteration
                )
                writer.add_scalar(
                    'AA_D_loss', D_loss.item(), iteration
                )
                batch_time_rec.update(time.time()-tem_time)
                tem_time = time.time()

                if (batch_index+1) % self.opt.print_freq == 0:
                    print(
                        f'Epoch [{epoch+1:d}/{self.opt.num_epoch:d}]'
                        f'[{batch_index+1:d}/{num_batches:d}]\t'
                        f'Time: {batch_time_rec.avg:.2f}   '
                        f'Data: {data_time_rec.avg:.2f}   '
                        f'Loss: {content_loss_rec.avg:.2f}   '
                        f'Style1: {style_loss1_rec.avg:.2f}   '
                        f'Style2: {style_loss2_rec.avg:.2f}'
                    )
                if (batch_index+1) % self.opt.show_img_freq == 0:
                    fig, axes = plt.subplots(5, 1, figsize=(6, 20), dpi=120)
                    axes = axes.flat
                    axes[0].imshow(self.to_image(rec_results[0, ...]))
                    axes[0].set_title(f'rec')
                    axes[1].imshow(self.to_image(tgt_img[0, ...]))
                    axes[1].set_title(tgt_name[0])

                    rec_seg = self.compute_seg_map(rec_results).cpu().numpy() # already normed in to_image method
                    # tgt_img_cuda = self.de_normalize(tgt_img_cuda)
                    ori_seg = self.compute_seg_map(tgt_img_cuda_de_norm).cpu().numpy()

                    axes[2].imshow(colorize_mask(rec_seg[0, ...]))
                    # axes[2].set_title(f'rec_label_{rec_miu*100:.2f}')
                    axes[2].set_title(f'rec_label')

                    axes[3].imshow(colorize_mask(ori_seg[0, ...]))
                    # axes[3].set_title(f'ori_label_{ori_miu*100:.2f}')
                    axes[3].set_title(f'ori_label')

                    tgt_label = tgt_label.numpy()
                    gt_label = tgt_label[0, ...]
                    axes[4].imshow(colorize_mask(gt_label))
                    axes[4].set_title(f'gt_label')

                    writer.add_figure('A_rec', fig, iteration)
                if iteration % self.opt.checkpoint_freq == 0:

                    combine_model = _CombineModel(
                        self.G, self.seg, 
                        self.de_normalize, self.fcn_normalize,
                    )
                    mean_iu = test_miou(
                        combine_model, val_loader,
                        combine_model.upsample, self.info_json,
                        print_results=False
                    )

                    if mean_iu > highest_miu:
                        torch.save(
                            self.G.module.state_dict(), 
                            os.path.join(self.opt.save_path_prefix, 'gen.pth')
                        )
                        highest_miu = mean_iu
                        print('>'*50+f'save highest with {mean_iu:.2%}')
# Example #6
# 0
def main():
    """Create the DeepLabV3 model and train it on the source domain.

    Hyper-parameters come from ``get_arguments()``. The network is trained with
    SGD on the source dataset: ``GTA5DataSetLMDB`` when ``'5'`` appears in
    ``args.data_dir``, ``CityscapesDataSetLMDB`` otherwise.

    Every ``args.eval_freq`` iterations, mean IoU on the target validation set
    is computed with ``test_miou``. Whenever it improves, the weights are
    written to ``args.snapshot_dir`` as ``'{iter}_{round(miu*1000)}.pth'``.

    Returns ``None`` once ``args.iterations`` iterations have run.
    """
    args = get_arguments()

    # Sizes are given as "width,height" strings.
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    # Create network
    if args.bn_sync:
        print('Using Sync BN')
        # Swap deeplabv3's BatchNorm2d for synchronized in-place ABN *before*
        # the network is built; presumably picked up by get_deeplabV3.
        deeplabv3.BatchNorm2d = partial(InPlaceABNSync, activation='none')
    net = get_deeplabV3(args.num_classes, args.model_path_prefix)
    if not args.bn_sync:
        net.freeze_bn()
    net = torch.nn.DataParallel(net)
    net = net.cuda()

    # Images are scaled back to [0, 255] (ToTensor yields [0, 1]) and then
    # mean-subtracted with unit std. FlipChannels presumably reorders
    # RGB -> BGR to match these Caffe-style means -- confirm.
    mean_std = ([104.00698793, 116.66876762, 122.67891434], [1.0, 1.0, 1.0])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
    ])
    input_transform = standard_transforms.Compose([
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    # Validation: only the image is resized (no joint transform), so labels
    # keep their original size; `upsample` (target size) is handed to
    # test_miou for bringing predictions back.
    val_input_transform = standard_transforms.Compose([
        extended_transforms.FreeScale((h, w)),
        extended_transforms.FlipChannels(),
        standard_transforms.ToTensor(),
        standard_transforms.Lambda(lambda x: x.mul_(255)),
        standard_transforms.Normalize(*mean_std),
    ])
    target_transform = extended_transforms.MaskToTensor()

    src_dataset_cls = (GTA5DataSetLMDB if '5' in args.data_dir
                       else CityscapesDataSetLMDB)
    src_dataset = src_dataset_cls(
        args.data_dir,
        args.data_list,
        joint_transform=train_joint_transform,
        transform=input_transform,
        target_transform=target_transform,
    )
    src_loader = data.DataLoader(src_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    tgt_val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        transform=val_input_transform,
        target_transform=target_transform,
    )
    tgt_val_loader = data.DataLoader(
        tgt_val_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # Freeze BN affine parameters so the optimizer below never sees them.
    for module in net.module.modules():
        if isinstance(module, torch.nn.BatchNorm2d):
            for param in module.parameters():
                param.requires_grad = False
    trainable_params = [
        p for p in net.module.parameters() if p.requires_grad
    ]
    optimizer = optim.SGD(
        [{'params': trainable_params, 'lr': args.learning_rate}],
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    # Used by test_miou to resize predictions to the target resolution.
    upsample = nn.Upsample(size=(h_target, w_target), mode='bilinear')

    n_class = args.num_classes
    criterion = CriterionDSN(ignore_index=255)

    max_iter = args.iterations
    i_iter = 0
    highest_miu = 0

    # Iterate over the source loader repeatedly ("epochs") until the
    # iteration budget is exhausted; the exit is the `return` below.
    while True:

        cls_loss_rec = AverageMeter()
        miu_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, src_data in enumerate(src_loader):
            i_iter += 1
            lr = adjust_learning_rate(args, optimizer, i_iter, max_iter)
            # Re-enter train mode each step (test_miou may leave the net in
            # eval mode). NOTE(review): if freeze_bn() works only by putting
            # BN layers in eval mode, this call undoes it -- confirm.
            net.train()
            optimizer.zero_grad()

            # Train with source
            src_images, src_label, _ = src_data
            src_images = src_images.cuda()
            src_label = src_label.cuda()
            data_time_rec.update(time.time() - tem_time)

            # The criterion consumes the whole network output; only element 0
            # (presumably the main head of a DSN-style output) is used for the
            # training mIoU below.
            src_output = net(src_images)
            cls_loss_value = criterion(src_output, src_label)
            cls_loss_value.backward()
            optimizer.step()

            main_output = torch.nn.functional.interpolate(
                src_output[0], size=(h, w), mode='bilinear',
                align_corners=True)
            _, predict_labels = torch.max(main_output, 1)
            lbl_pred = predict_labels.detach().cpu().numpy()
            lbl_true = src_label.detach().cpu().numpy()
            # Use the configured class count instead of a hard-coded 19.
            _, _, _, mean_iu, _ = _evaluate(lbl_pred, lbl_true, n_class)

            cls_loss_rec.update(cls_loss_value.item())
            miu_rec.update(mean_iu)

            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if i_iter % args.print_freq == 0:
                print(
                    f'Iter: [{i_iter}/{max_iter}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Mean iu: {miu_rec.avg*100:.1f}   '
                    f'CLS: {cls_loss_rec.avg:.2f}')
            if i_iter % args.eval_freq == 0:
                miu = test_miou(net, tgt_val_loader, upsample,
                                './dataset/info.json')
                # Keep only snapshots that improve target-domain mIoU.
                if miu > highest_miu:
                    torch.save(
                        net.module.state_dict(),
                        osp.join(args.snapshot_dir,
                                 f'{i_iter:d}_{miu*1000:.0f}.pth'))
                    highest_miu = miu
                print(f'>>>>>>>>>Learning Rate {lr}<<<<<<<<<')
            # `>=` rather than `==`: a non-positive budget must not leave the
            # outer `while True` spinning forever.
            if i_iter >= max_iter:
                return