def test():
    """Run photo-realistic style transfer inference from the command line.

    Parses the CLI arguments below, merges ``--config-file`` into the global
    ``cfg`` (which is then frozen), builds ``model_factory[cfg.MODEL.NAME]``
    and writes results into ``os.path.join(cfg.OUTPUT_DIR, --outputDir)``.
    All input paths are joined onto ``cfg.INPUT_DIR``.

    Supported argument combinations:

    * ``--mode 0`` plus ``--content-seg`` or ``--style-seg``: one content and
      one style image, each with a segmentation map (all four paths are then
      required).
    * ``--mode 0`` plus ``--content`` and ``--style``: a single image pair.
    * any other ``--mode`` plus ``--contentDir`` and ``--styleDir``: every
      pair returned by ``prepare_loading`` for the two directories.

    Raises:
        AssertionError: a path required by the segmentation branch is empty.
        RuntimeError: the arguments match none of the combinations above.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Photo-Realistic Style Transfer Library')

    parser.add_argument('--config-file',
                        type=str,
                        default='',
                        help='path to configuration file')
    parser.add_argument('--outputDir',
                        type=str,
                        default='Demo',
                        help='name of output folder')
    parser.add_argument('--saveOrig', default=False, action='store_true')
    parser.add_argument('--contentDir',
                        type=str,
                        default='',
                        help='path to directory of content images')
    parser.add_argument('--styleDir',
                        type=str,
                        default='',
                        help='path to directory of style images')
    parser.add_argument('--content',
                        type=str,
                        default='',
                        help='path to content image')
    parser.add_argument('--style',
                        type=str,
                        default='',
                        help='path to style image')
    parser.add_argument(
        '--mode',
        type=int,
        default=0,
        help=
        'Inference mode: 0 - Single Content; 1 - Multiple Content (Stored in a directory)'
    )

    # advanced options
    parser.add_argument('--content-seg',
                        default='',
                        type=str,
                        help='path to content mask image')
    parser.add_argument('--style-seg',
                        default='',
                        type=str,
                        help='path to style mask image')
    parser.add_argument('--resize',
                        default=False,
                        action='store_true',
                        help='resize original image to accelerate computing')
    args = parser.parse_args()

    # update configuration
    cfg.merge_from_file(args.config_file)

    cfg.freeze()

    # Photos: bicubic resampling + normalization. Segmentation maps: nearest
    # neighbour and no normalization, so label values are not mixed.
    test_transform = build_transform(cfg,
                                     train=False,
                                     interpolation=Image.BICUBIC,
                                     normalize=True)
    test_seg_transform = build_transform(cfg,
                                         train=False,
                                         interpolation=Image.NEAREST,
                                         normalize=False)

    # Either segmentation path switches to the masked branch below.
    mask_on = bool(args.content_seg or args.style_seg)

    # create output dir
    if cfg.OUTPUT_DIR:
        os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    # create logger
    logger = setup_logger(cfg.MODEL.NAME,
                          save_dir=cfg.OUTPUT_DIR,
                          filename=cfg.MODEL.NAME + '.txt')

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    logger.info("Using {} GPUs".format(num_gpus))

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + get_pretty_env_info())

    logger.info('Loaded configuration file {}'.format(args.config_file))
    logger.info("Running with config:\n{}".format(cfg))

    # create output dir
    output_dir = os.path.join(cfg.OUTPUT_DIR, args.outputDir)
    os.makedirs(output_dir, exist_ok=True)
    logger.info('Output Dir Created: {}'.format(output_dir))

    # create model
    model = model_factory[cfg.MODEL.NAME](cfg)
    logger.info(model)

    # inference
    if args.mode == 0:
        if mask_on:
            # 1-content | 1-style | 2-masks, process single content image
            assert args.content, 'Path to the content image should be non-empty'
            assert args.style, 'Paths to the style images should be non-empty'
            assert args.content_seg, 'Path to the content segment image should be non-empty'
            assert args.style_seg, 'Path to the style segment image should be non-empty'

            content_img_path = os.path.join(cfg.INPUT_DIR, args.content)
            style_img_path = os.path.join(cfg.INPUT_DIR, args.style)
            content_seg_path = os.path.join(
                cfg.INPUT_DIR,
                args.content_seg) if args.content_seg else args.content_seg
            style_seg_path = os.path.join(
                cfg.INPUT_DIR,
                args.style_seg) if args.style_seg else args.style_seg

            # output name = content file name without directory or extension
            name = os.path.splitext(os.path.basename(content_img_path))[0]

            # load image
            content_img = default_loader(content_img_path)
            style_img = default_loader(style_img_path)

            content_copy = content_img.copy()
            cw, ch = content_copy.width, content_copy.height
            sw, sh = style_img.width, style_img.height

            if args.resize:
                # new size after resizing content image
                new_cw, new_ch = memory_limit_image_size(content_img,
                                                         cfg.INPUT.MIN_SIZE,
                                                         cfg.INPUT.MAX_SIZE,
                                                         logger=logger)
                # new size after resizing style image
                new_sw, new_sh = memory_limit_image_size(style_img,
                                                         cfg.INPUT.MIN_SIZE,
                                                         cfg.INPUT.MAX_SIZE,
                                                         logger=logger)
            else:
                new_cw, new_ch = cw, ch
                new_sw, new_sh = sw, sh

            content_img = test_transform(content_img).unsqueeze(0)
            style_img = test_transform(style_img).unsqueeze(0)

            # Resize each segmentation map to its image's (possibly limited)
            # size. The context managers release the file handles; resize()
            # returns a new, fully loaded image.
            with Image.open(content_seg_path) as seg:
                cont_seg = seg.resize((new_cw, new_ch), Image.NEAREST)
            with Image.open(style_seg_path) as seg:
                styl_seg = seg.resize((new_sw, new_sh), Image.NEAREST)
            cont_seg = test_seg_transform(cont_seg)
            styl_seg = test_seg_transform(styl_seg)

            with torch.no_grad():
                infer_image(cfg,
                            name,
                            model,
                            content_img,
                            style_img,
                            logger,
                            output_dir,
                            ch,
                            cw,
                            save_orig=args.saveOrig,
                            content_seg_img=cont_seg,
                            style_seg_img=styl_seg,
                            orig_content=content_copy,
                            test_transform=test_transform)

        elif args.content and args.style:
            # 1-content | 1-style, process single pair of images
            content_img_path = os.path.join(cfg.INPUT_DIR, args.content)
            style_img_path = os.path.join(cfg.INPUT_DIR, args.style)
            name = os.path.splitext(os.path.basename(content_img_path))[0]

            content_img = default_loader(content_img_path)
            style_img = default_loader(style_img_path)
            # (width, height) order, as in the other branches: infer_image
            # receives ch = height and cw = width.
            cw, ch = content_img.width, content_img.height
            content_copy = content_img.copy()

            if args.resize:
                # new size after resizing content image
                new_cw, new_ch = memory_limit_image_size(content_img,
                                                         cfg.INPUT.MIN_SIZE,
                                                         cfg.INPUT.MAX_SIZE,
                                                         logger=logger)
                # new size after resizing style image
                new_sw, new_sh = memory_limit_image_size(style_img,
                                                         cfg.INPUT.MIN_SIZE,
                                                         cfg.INPUT.MAX_SIZE,
                                                         logger=logger)
            else:
                new_cw, new_ch = cw, ch

            content_img = test_transform(content_img).unsqueeze(0)
            style_img = test_transform(style_img).unsqueeze(0)

            with torch.no_grad():
                infer_image(cfg,
                            name,
                            model,
                            content_img,
                            style_img,
                            logger,
                            output_dir,
                            ch,
                            cw,
                            save_orig=args.saveOrig,
                            orig_content=content_copy,
                            test_transform=test_transform)
        else:
            raise RuntimeError('Invalid Argument Setting')

    else:
        if args.contentDir and args.styleDir:
            # 1-vs-1, but process multiple images in the directory
            content_img, style_img, names = prepare_loading(
                cfg,
                os.path.join(cfg.INPUT_DIR, args.contentDir),
                os.path.join(cfg.INPUT_DIR, args.styleDir),
            )
            iterator = tqdm(range(len(content_img)))
            for i in iterator:
                c_img, s_img = content_img[i], style_img[i]
                cw, ch = c_img.width, c_img.height
                c_copy = c_img.copy()

                if args.resize:
                    # new size after resizing content image
                    new_cw, new_ch = memory_limit_image_size(
                        c_img,
                        cfg.INPUT.MIN_SIZE,
                        cfg.INPUT.MAX_SIZE,
                        logger=logger)
                    # new size after resizing style image
                    new_sw, new_sh = memory_limit_image_size(
                        s_img,
                        cfg.INPUT.MIN_SIZE,
                        cfg.INPUT.MAX_SIZE,
                        logger=logger)
                else:
                    new_cw, new_ch = cw, ch

                c_img = test_transform(c_img).unsqueeze(0)
                s_img = test_transform(s_img).unsqueeze(0)

                name = names[i]

                with torch.no_grad():
                    infer_image(cfg,
                                name,
                                model,
                                c_img,
                                s_img,
                                logger,
                                output_dir,
                                ch,
                                cw,
                                save_orig=args.saveOrig,
                                orig_content=c_copy,
                                test_transform=test_transform)

                iterator.set_description(desc='Test Case {}'.format(i))
        else:
            raise RuntimeError('Invalid Argument Setting')

    logger.info('Done!')
# Example #2
def _build_iteration_loader(cfg, dataset):
    """Create a randomly sampled, iteration-based DataLoader over ``dataset``.

    Indices are drawn by a ``RandomSampler`` and grouped into batches of
    ``cfg.DATALOADER.BATCH_SIZE``. Because ``drop_last=False``, the last
    batch of an epoch may be smaller. The batches are wrapped in
    ``IterationBasedBatchSampler`` for ``cfg.OPTIMIZER.MAX_ITER`` iterations,
    starting at iteration 0.
    """
    sampler = torch.utils.data.sampler.RandomSampler(dataset)
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler, cfg.DATALOADER.BATCH_SIZE, drop_last=False)
    return DataLoader(dataset,
                      batch_sampler=IterationBasedBatchSampler(
                          batch_sampler,
                          cfg.OPTIMIZER.MAX_ITER,
                          start_iter=0),
                      num_workers=cfg.DATALOADER.NUM_WORKERS)


def train_lst():
    """Train the transformation layer of a LinearStyleTransfer model.

    Merges ``--config-file`` into the global ``cfg`` (then frozen), builds the
    model with ``get_model`` and optimizes ``model.trans_layer`` for
    ``cfg.OPTIMIZER.MAX_ITER`` iterations on random content/style batches.
    Outputs written to ``cfg.OUTPUT_DIR``:

    * TensorBoard scalars ``loss_content`` and ``loss_style``;
    * ``<i>.jpg`` every 1000 iterations (rows: content, style, output);
    * ``<i>_lst.pth`` every 10000 iterations and ``final_lst.pth`` at the end,
      each holding ``model.trans_layer.state_dict()``.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Style Transfer -- LinearStyleTransfer')

    parser.add_argument('--config-file',
                        type=str,
                        default='',
                        help='path to configuration file')

    args = parser.parse_args()

    cfg.merge_from_file(args.config_file)

    cfg.freeze()

    # create output dir
    if cfg.OUTPUT_DIR:
        os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    # create logger
    logger = setup_logger(cfg.MODEL.NAME,
                          save_dir=cfg.OUTPUT_DIR,
                          filename=cfg.MODEL.NAME + '.txt')

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    logger.info("Using {} GPUs".format(num_gpus))

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + get_pretty_env_info())

    logger.info('Loaded configuration file {}'.format(args.config_file))
    logger.info("Running with config:\n{}".format(cfg))

    # create model
    model = get_model(cfg.MODEL.NAME, cfg)

    # push model to device
    model.to(cfg.DEVICE)

    logger.info(model)

    # create dataloaders
    train_path_content, train_path_style = get_data(cfg, dtype='train')
    content_dataset = DatasetNoSeg(cfg, train_path_content, train=True)
    style_dataset = DatasetNoSeg(cfg, train_path_style, train=True)

    content_loader = _build_iteration_loader(cfg, content_dataset)
    logger.info('Content Loader Created!')

    style_loader = _build_iteration_loader(cfg, style_dataset)
    logger.info('Style Loader Created!')
    content_loader = iter(content_loader)
    style_loader = iter(style_loader)

    # only the transformation layer's parameters are optimized
    optimizer = build_optimizer(cfg, model.trans_layer)
    lr_scheduler = build_lr_scheduler(cfg, optimizer)
    logger.info("Using Optimizer: ")
    logger.info(optimizer)
    logger.info("Using LR Scheduler: {}".format(
        cfg.OPTIMIZER.LR_SCHEDULER.NAME))

    iterator = tqdm(range(cfg.OPTIMIZER.MAX_ITER))

    writer = SummaryWriter(log_dir=cfg.OUTPUT_DIR)
    # try/finally so the TensorBoard writer is closed (and its pending
    # events flushed) even if training raises or is interrupted.
    try:
        # start training
        for i in iterator:
            content_img = next(content_loader).to(cfg.DEVICE)
            style_img = next(style_loader).to(cfg.DEVICE)
            # With drop_last=False the two loaders can yield batches of
            # different sizes. Such iterations are skipped entirely: no
            # optimizer/scheduler step and no image/checkpoint save for this i.
            if content_img.shape[0] != style_img.shape[0]:
                continue

            g_t = model.forward_with_trans(content_img, style_img)

            loss, style_loss, content_loss = model.cal_trans_loss(
                g_t, content_img, style_img)

            # update info
            iterator.set_description(
                desc=
                'Iteration: {} -- Loss: {:.3f} -- Content Loss: {:.3f} -- Style Loss: {:.3f}'
                .format(i + 1, loss.item(), content_loss.item(),
                        style_loss.item()))
            writer.add_scalar('loss_content', content_loss.item(), i + 1)
            writer.add_scalar('loss_style', style_loss.item(), i + 1)

            # update model
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update lr
            lr_scheduler.step()

            # save image grid: one row each of content, style and output
            if i % 1000 == 0:
                n = content_img.shape[0]
                all_imgs = torch.cat((content_img, style_img, g_t), dim=0)
                save_image(all_imgs,
                           os.path.join(cfg.OUTPUT_DIR, '{}.jpg'.format(i)),
                           nrow=n)

            if i % 10000 == 0:
                torch.save(
                    model.trans_layer.state_dict(),
                    os.path.join(cfg.OUTPUT_DIR, '{}_lst.pth'.format(i)))

        torch.save(model.trans_layer.state_dict(),
                   os.path.join(cfg.OUTPUT_DIR, 'final_lst.pth'))
    finally:
        writer.close()
# Example #3
def test():
    """Run (artistic) style transfer inference from the command line.

    Merges ``--config-file`` into the global ``cfg`` (then frozen), builds
    ``model_factory[cfg.MODEL.NAME]`` and writes results into
    ``os.path.join(cfg.OUTPUT_DIR, --outputDir)``. All input paths are joined
    onto ``cfg.INPUT_DIR``.

    Supported argument combinations (``--mode 0`` unless noted):

    * ``--styleInterpWeights w1,w2,...`` with ``--style s1,s2,...``: blend
      several styles; the weights are normalized to sum to one and must
      match the number of style images. Not supported for ``LST``/``FPS``.
    * ``--mask`` with ``--style s1,s2,...``: spatial control with a mask
      image. Not supported for ``LST``/``FPS``. Cannot be combined with style
      interpolation.
    * ``--content`` and ``--style``: a single image pair.
    * any other ``--mode`` plus ``--contentDir`` and ``--styleDir``: every
      pair returned by ``prepare_loading`` for the two directories.

    Raises:
        AssertionError: incompatible or missing options (see asserts below).
        RuntimeError: the arguments match none of the combinations above.
        SystemExit: via ``parser.error`` when the interpolation weights sum
            to zero.
    """
    parser = argparse.ArgumentParser(
        description='PyTorch Style Transfer Library')

    parser.add_argument('--config-file',
                        type=str,
                        default='',
                        help='path to configuration file')
    parser.add_argument('--outputDir',
                        type=str,
                        default='Demo',
                        help='name of output folder')
    parser.add_argument('--saveOrig', default=False, action='store_true')
    parser.add_argument('--contentDir',
                        type=str,
                        default='',
                        help='path to directory of content images')
    parser.add_argument('--styleDir',
                        type=str,
                        default='',
                        help='path to directory of style images')
    parser.add_argument('--content',
                        type=str,
                        default='',
                        help='path to content image')
    parser.add_argument('--style',
                        type=str,
                        default='',
                        help='path to style image')
    parser.add_argument(
        '--mode',
        type=int,
        default=0,
        help=
        'Inference mode: 0 - Single Content; 1 - Multiple Content (Stored in a directory)'
    )

    # advanced options
    parser.add_argument(
        '--styleInterpWeights',
        default='',
        type=str,
        help='The weight for blending the style of multiple style images')
    parser.add_argument('--mask',
                        default='',
                        type=str,
                        help='path to mask image')
    parser.add_argument('--resize',
                        default=False,
                        action='store_true',
                        help='resize image to acclerate computing')
    args = parser.parse_args()
    # The raw weight string is kept because it becomes part of the output name.
    style_weight_name = args.styleInterpWeights
    if args.styleInterpWeights:
        weights = [
            float(each.strip()) for each in args.styleInterpWeights.split(',')
        ]
        total = sum(weights)
        if total == 0:
            parser.error('--styleInterpWeights must not sum to zero')
        # normalize weights so they sum to one
        args.styleInterpWeights = [each / total for each in weights]

    # update configuration
    cfg.merge_from_file(args.config_file)

    cfg.freeze()

    # Photos: bicubic resampling + normalization. Masks: nearest neighbour
    # and no normalization, so mask values are not mixed.
    test_transform = build_transform(cfg,
                                     train=False,
                                     normalize=True,
                                     interpolation=Image.BICUBIC)
    test_seg_transform = build_transform(cfg,
                                         train=False,
                                         interpolation=Image.NEAREST,
                                         normalize=False)

    mask_on = args.mask != ''
    # styleInterpWeights is either '' (not given) or a non-empty list here.
    interpolate_on = args.styleInterpWeights != ''
    assert not (
        mask_on and interpolate_on
    ), 'Spatial control and Style Interpolation cannot be activated simultaneously.'

    # create output dir
    if cfg.OUTPUT_DIR:
        os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    # create logger
    logger = setup_logger(cfg.MODEL.NAME,
                          save_dir=cfg.OUTPUT_DIR,
                          filename=cfg.MODEL.NAME + '.txt')

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    logger.info("Using {} GPUs".format(num_gpus))

    logger.info("Collecting env info (might take some time)")
    logger.info("\n" + get_pretty_env_info())

    logger.info('Loaded configuration file {}'.format(args.config_file))
    logger.info("Running with config:\n{}".format(cfg))

    # create output dir
    output_dir = os.path.join(cfg.OUTPUT_DIR, args.outputDir)
    os.makedirs(output_dir, exist_ok=True)
    logger.info('Output Dir Created: {}'.format(output_dir))

    # create model
    model = model_factory[cfg.MODEL.NAME](cfg)
    logger.info(model)

    # inference
    if args.mode == 0:
        if interpolate_on:
            # 1-content | N-style, process single content image
            assert args.content, 'Path to the content image should be non-empty'
            assert args.style, 'Paths to the style images should be non-empty'
            assert args.styleInterpWeights, 'Style interpolation weights must be provided'
            assert cfg.MODEL.NAME != 'LST', 'Interpolation of LinearStyleTransfer is currently not supported!'
            assert cfg.MODEL.NAME != 'FPS', 'Interpolation of FastPhotoTransfer is currently not supported, but should be similar to WCT!'

            style_paths = args.style.split(',')
            assert len(style_paths) == len(args.styleInterpWeights), (
                'The number of style images must match the number of style '
                'interpolation weights')
            content_img_path = os.path.join(cfg.INPUT_DIR, args.content)
            style_img_paths = [
                os.path.join(cfg.INPUT_DIR, each) for each in style_paths
            ]
            name = (os.path.splitext(os.path.basename(content_img_path))[0] +
                    '_' + style_weight_name)

            # load image
            content_img = default_loader(content_img_path)
            style_imgs = [default_loader(each) for each in style_img_paths]
            # (width, height) order, as in the other branches: infer_image
            # receives ch = height and cw = width.
            cw, ch = content_img.width, content_img.height

            if args.resize:
                # new size after resizing content image
                new_cw, new_ch = memory_limit_image_size(content_img,
                                                         cfg.INPUT.MIN_SIZE,
                                                         cfg.INPUT.MAX_SIZE,
                                                         logger=logger)
                # new size after resizing style image
                new_sw, new_sh = list(
                    zip(*[
                        memory_limit_image_size(each,
                                                cfg.INPUT.MIN_SIZE,
                                                cfg.INPUT.MAX_SIZE,
                                                logger=logger)
                        for each in style_imgs
                    ]))
            else:
                new_cw, new_ch = cw, ch

            content_img = test_transform(content_img).unsqueeze(0)
            style_imgs = [
                test_transform(each).unsqueeze(0) for each in style_imgs
            ]

            infer_image(cfg,
                        name,
                        model,
                        content_img,
                        style_imgs,
                        logger,
                        output_dir,
                        ch,
                        cw,
                        save_orig=args.saveOrig,
                        alpha=cfg.MODEL.ALPHA,
                        style_interp_weights=args.styleInterpWeights)

        elif mask_on:
            # 1-content | N-style | 1-mask, process single content image
            assert args.content, 'Path to the content image should be non-empty'
            assert args.style, 'Paths to the style images should be non-empty'
            assert args.mask, 'Path to the mask image should be non-empty'
            assert cfg.MODEL.NAME != 'LST', 'Spatial Control of LinearStyleTransfer is currently not supported!'
            assert cfg.MODEL.NAME != 'FPS', 'Spatial Control of FastPhotoTransfer is currently not supported, but should be similar to WCT!'

            style_paths = args.style.split(',')
            content_img_path = os.path.join(cfg.INPUT_DIR, args.content)
            style_img_paths = [
                os.path.join(cfg.INPUT_DIR, each) for each in style_paths
            ]
            name = os.path.splitext(
                os.path.basename(content_img_path))[0] + '_mask'

            # read image
            mask_img = default_loader(os.path.join(cfg.INPUT_DIR, args.mask))
            content_img = default_loader(content_img_path)
            style_imgs = [default_loader(each) for each in style_img_paths]

            cw, ch = content_img.width, content_img.height

            if args.resize:
                # new size after resizing content image
                new_cw, new_ch = memory_limit_image_size(content_img,
                                                         cfg.INPUT.MIN_SIZE,
                                                         cfg.INPUT.MAX_SIZE,
                                                         logger=logger)
                # new size after resizing style image
                new_sw, new_sh = list(
                    zip(*[
                        memory_limit_image_size(each,
                                                cfg.INPUT.MIN_SIZE,
                                                cfg.INPUT.MAX_SIZE,
                                                logger=logger)
                        for each in style_imgs
                    ]))
            else:
                new_cw, new_ch = cw, ch

            content_img = test_transform(content_img).unsqueeze(0)
            style_imgs = [
                test_transform(each).unsqueeze(0) for each in style_imgs
            ]
            # The mask follows the content image's (possibly limited) size.
            mask_img = mask_img.resize((new_cw, new_ch), Image.NEAREST)
            mask_img = test_seg_transform(mask_img)

            infer_image(cfg,
                        name,
                        model,
                        content_img,
                        style_imgs,
                        logger,
                        output_dir,
                        ch,
                        cw,
                        save_orig=args.saveOrig,
                        alpha=cfg.MODEL.ALPHA,
                        mask_img=mask_img)

        elif args.content and args.style:
            # 1-content | 1-style, process single pair of images
            content_img_path = os.path.join(cfg.INPUT_DIR, args.content)
            style_img_path = os.path.join(cfg.INPUT_DIR, args.style)
            name = os.path.splitext(os.path.basename(content_img_path))[0]

            # read images
            content_img = default_loader(content_img_path)
            style_img = default_loader(style_img_path)

            cw, ch = content_img.width, content_img.height

            if args.resize:
                # new size after resizing content image
                new_cw, new_ch = memory_limit_image_size(content_img,
                                                         cfg.INPUT.MIN_SIZE,
                                                         cfg.INPUT.MAX_SIZE,
                                                         logger=logger)
                # new size after resizing style image
                new_sw, new_sh = memory_limit_image_size(style_img,
                                                         cfg.INPUT.MIN_SIZE,
                                                         cfg.INPUT.MAX_SIZE,
                                                         logger=logger)
            else:
                new_cw, new_ch = cw, ch

            content_img = test_transform(content_img).unsqueeze(0)
            style_img = test_transform(style_img).unsqueeze(0)

            infer_image(cfg,
                        name,
                        model,
                        content_img,
                        style_img,
                        logger,
                        output_dir,
                        ch,
                        cw,
                        save_orig=args.saveOrig,
                        alpha=cfg.MODEL.ALPHA)
        else:
            raise RuntimeError('Invalid Argument Setting')

    else:
        if args.contentDir and args.styleDir:
            # 1-vs-1, but process multiple images in the directory
            content_img, style_img, names = prepare_loading(
                cfg,
                os.path.join(cfg.INPUT_DIR, args.contentDir),
                os.path.join(cfg.INPUT_DIR, args.styleDir),
            )
            iterator = tqdm(range(len(content_img)))
            for i in iterator:
                c_img, s_img = content_img[i], style_img[i]
                name = names[i]
                cw, ch = c_img.width, c_img.height

                if args.resize:
                    # new size after resizing content image
                    new_cw, new_ch = memory_limit_image_size(
                        c_img,
                        cfg.INPUT.MIN_SIZE,
                        cfg.INPUT.MAX_SIZE,
                        logger=logger)
                    # new size after resizing style image
                    new_sw, new_sh = memory_limit_image_size(
                        s_img,
                        cfg.INPUT.MIN_SIZE,
                        cfg.INPUT.MAX_SIZE,
                        logger=logger)
                else:
                    new_cw, new_ch = cw, ch

                c_img = test_transform(c_img).unsqueeze(0)
                s_img = test_transform(s_img).unsqueeze(0)

                infer_image(cfg,
                            name,
                            model,
                            c_img,
                            s_img,
                            logger,
                            output_dir,
                            ch,
                            cw,
                            save_orig=args.saveOrig,
                            alpha=cfg.MODEL.ALPHA)

                iterator.set_description(desc='Test Case {}'.format(i))
        else:
            raise RuntimeError('Invalid Argument Setting')

    logger.info('Done!')