Example #1
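# Standard imports implied by this snippet. The project-specific helpers
# (datasets, models, T, device, CompleteLogger, ForeverDataIterator,
# Discriminator, DomainAdversarialEntropyLoss, DeNormalizeAndTranspose,
# train, validate) are assumed to come from the surrounding repository:
import argparse
import random
import shutil
import warnings

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from PIL import Image
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
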
def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True
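    # note: benchmark=True lets cuDNN auto-tune convolution algorithms for the
    # fixed input sizes; it can reintroduce non-determinism even when a seed
    # and cudnn.deterministic were set above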

    # Data loading code
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=args.resize_ratio,
                                scale=(0.5, 1.)),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(
        root=args.target_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size,
                                ratio=(2., 2.),
                                scale=(0.5, 1.)),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(
        root=args.target_root,
        split='val',
        transforms=T.Compose([
            T.Resize(image_size=args.test_input_size,
                     label_size=args.test_output_size),
            T.NormalizeAndTranspose(),
        ]),
    )
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=1,
                                   shuffle=False,
                                   pin_memory=True)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    num_classes = train_source_dataset.num_classes
    model = models.__dict__[args.arch](num_classes=num_classes).to(device)
    discriminator = Discriminator(num_classes=num_classes).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(model.get_parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    optimizer_d = Adam(discriminator.parameters(),
                       lr=args.lr_d,
                       betas=(0.9, 0.99))
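    # polynomial ("poly") decay over all training iterations:
    # factor = (1 - iter / (epochs * iters_per_epoch)) ** lr_power.
    # The extra args.lr factor in the first lambda presumably assumes that
    # model.get_parameters() returns parameter groups whose 'lr' entries are
    # relative multipliers, which LambdaLR then scales.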
    lr_scheduler = LambdaLR(
        optimizer,
        lambda x: args.lr * (1. - float(x) / args.epochs / args.iters_per_epoch) ** args.lr_power)
    lr_scheduler_d = LambdaLR(
        optimizer_d,
        lambda x: (1. - float(x) / args.epochs / args.iters_per_epoch) ** args.lr_power)

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d'])
        args.start_epoch = checkpoint['epoch'] + 1

    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=args.ignore_label).to(device)
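    # ADVENT-style adversarial entropy minimization: the discriminator learns
    # to tell source and target entropy maps apart, and the segmentation model
    # is trained to fool it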
    dann = DomainAdversarialEntropyLoss(discriminator)
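    # upsample logits back to label resolution; nn.Upsample expects (H, W),
    # hence the [::-1] on sizes presumably stored as (W, H)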
    interp_train = nn.Upsample(size=args.train_size[::-1],
                               mode='bilinear',
                               align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1],
                             mode='bilinear',
                             align_corners=True)

    # define visualization function
    decode = train_source_dataset.decode_target

    def visualize(image, pred, label, prefix):
        """
        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W
            label (tensor): H x W
            prefix: prefix of the saving image
        """
        image = image.detach().cpu().numpy()
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for tensor, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))), "image"),
            (decode(label), "label"),
            (decode(pred), "pred"),
        ]:
            tensor.save(logger.get_image_path("{}_{}.png".format(prefix, name)))

    if args.phase == 'test':
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           visualize, args)
        print(confmat)
        return

    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler.get_last_lr(), lr_scheduler_d.get_last_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, model, interp_train,
              criterion, dann, optimizer, lr_scheduler, optimizer_d,
              lr_scheduler_d, epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           None, args)
        print(confmat.format(train_source_dataset.classes))
        acc_global, acc, iu = confmat.compute()

        # compute the mean IoU over the subset of classes used for evaluation
        indexes = [
            train_source_dataset.classes.index(name)
            for name in train_source_dataset.evaluate_classes
        ]
        iu = iu[indexes]
        mean_iou = iu.mean()

        # remember best mean IoU and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'discriminator': discriminator.state_dict(),
                'optimizer': optimizer.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'lr_scheduler_d': lr_scheduler_d.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))

    logger.close()
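Both examples pull batches through ForeverDataIterator. A minimal sketch of such an iterator, assuming the usual behavior of restarting the underlying DataLoader once it is exhausted (the repository's own implementation may differ in details):

class ForeverDataIterator:
    """Yields batches from a DataLoader endlessly, restarting it on exhaustion."""

    def __init__(self, data_loader):
        self.data_loader = data_loader
        self.iter = iter(self.data_loader)

    def __next__(self):
        try:
            data = next(self.iter)
        except StopIteration:
            # the epoch of the underlying loader ended: start a new one
            self.iter = iter(self.data_loader)
            data = next(self.iter)
        return data

    def __len__(self):
        return len(self.data_loader)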
Example #2

# Standard imports implied by this second snippet. The project-specific
# helpers (datasets, T, cyclegan, device, CompleteLogger, ForeverDataIterator,
# ImagePool, Compose, Denormalize, ToPILImage, train) are assumed to come
# from the surrounding repository:
import itertools
import random
import warnings

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

def main(args):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = T.Compose([
        T.RandomResizedCrop(size=args.train_size,
                            ratio=args.resize_ratio,
                            scale=(0.5, 1.)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # define networks (both generators and discriminators)
    netG_S2T = cyclegan.generator.__dict__[args.netG](
        ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netG_T2S = cyclegan.generator.__dict__[args.netG](
        ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netD_S = cyclegan.discriminator.__dict__[args.netD](
        ndf=args.ndf, norm=args.norm).to(device)
    netD_T = cyclegan.discriminator.__dict__[args.netD](
        ndf=args.ndf, norm=args.norm).to(device)

    # create image buffer to store previously generated images
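    # (with some probability the pool returns an older generated image instead
    # of the newest one, which decorrelates consecutive discriminator updates;
    # a sketch of this behavior follows the example below)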
    fake_S_pool = ImagePool(args.pool_size)
    fake_T_pool = ImagePool(args.pool_size)

    # define optimizer and lr scheduler
    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(),
                                       netG_T2S.parameters()),
                       lr=args.lr,
                       betas=(args.beta1, 0.999))
    optimizer_D = Adam(itertools.chain(netD_S.parameters(),
                                       netD_T.parameters()),
                       lr=args.lr,
                       betas=(args.beta1, 0.999))
    # keep the lr constant for the first args.epochs epochs, then decay it
    # linearly toward zero over the following args.epochs_decay epochs
    def lr_decay_function(epoch):
        return 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)

    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)

    # optionally resume from a checkpoint
    if args.resume:
        print("Resume from", args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.phase == 'test':
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)
        return

    # define loss function
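    # LSGAN objective (MSE against real/fake targets) as in the CycleGAN paper,
    # plus L1 cycle-consistency and identity-mapping terms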
    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()

    # define visualization function
    tensor_to_image = Compose([
        Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ToPILImage()
    ])

    def visualize(image, name):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            name: name of the saving image
        """
        tensor_to_image(image).save(
            logger.get_image_path("{}.png".format(name)))

    # start training
    for epoch in range(args.start_epoch, args.epochs + args.epochs_decay):
        logger.set_epoch(epoch)
        print(lr_scheduler_G.get_last_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S,
              netD_T, criterion_gan, criterion_cycle, criterion_identity,
              optimizer_G, optimizer_D, fake_S_pool, fake_T_pool, epoch,
              visualize, args)

        # update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()

        # save checkpoint
        torch.save(
            {
                'netG_S2T': netG_S2T.state_dict(),
                'netG_T2S': netG_T2S.state_dict(),
                'netD_S': netD_S.state_dict(),
                'netD_T': netD_T.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'lr_scheduler_G': lr_scheduler_G.state_dict(),
                'lr_scheduler_D': lr_scheduler_D.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))

    if args.translated_root is not None:
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)

    logger.close()
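For reference, a minimal sketch of the image-pool trick used by Example #2, assuming the classic CycleGAN buffer behavior (the repository's ImagePool may differ in details):

import random
import torch

class ImagePool:
    """History buffer of previously generated images."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.images = []

    def query(self, images):
        if self.pool_size == 0:  # pool disabled: pass images straight through
            return images
        out = []
        for image in images:
            image = image.unsqueeze(0)
            if len(self.images) < self.pool_size:
                # keep filling the buffer until it reaches pool_size
                self.images.append(image)
                out.append(image)
            elif random.random() < 0.5:
                # half the time, return a stored fake and stash the new one
                idx = random.randint(0, self.pool_size - 1)
                out.append(self.images[idx].clone())
                self.images[idx] = image
            else:
                # otherwise return the newest fake unchanged
                out.append(image)
        return torch.cat(out, dim=0)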