Example no. 1
0
def args(tmpdir):
    """Return a CLI-style namespace whose directories all point at *tmpdir*."""
    namespace = argparse.Namespace()
    namespace.image_source_dir = tmpdir
    namespace.image_results_dir = tmpdir
    namespace.device = misc.get_device()
    namespace.impl_params = True
    return namespace
Example no. 2
0
def parse_input():
    """Resolve the script's input settings and return them as a ``Namespace``.

    Each directory defaults to a path under this file's location
    (``images/source``, ``images/guides``, ``images/results``), is expanded
    and normalized to an absolute path, and is created if it does not exist.

    Returns:
        Namespace with ``image_source_dir``, ``image_guides_dir``,
        ``image_results_dir``, ``device``, and ``impl_params`` attributes.
    """
    # TODO: write CLI
    image_source_dir = None
    image_guides_dir = None
    image_results_dir = None
    impl_params = True

    def process_dir(directory):
        # Normalize to an absolute path (expanding "~") and make sure it
        # exists.  Renamed from ``dir`` to avoid shadowing the builtin.
        directory = path.abspath(path.expanduser(directory))
        os.makedirs(directory, exist_ok=True)
        return directory

    here = path.dirname(__file__)

    if image_source_dir is None:
        image_source_dir = path.join(here, "images", "source")
    image_source_dir = process_dir(image_source_dir)

    if image_guides_dir is None:
        image_guides_dir = path.join(here, "images", "guides")
    image_guides_dir = process_dir(image_guides_dir)

    if image_results_dir is None:
        image_results_dir = path.join(here, "images", "results")
    image_results_dir = process_dir(image_results_dir)

    device = get_device()

    return Namespace(
        image_source_dir=image_source_dir,
        image_guides_dir=image_guides_dir,
        image_results_dir=image_results_dir,
        device=device,
        impl_params=impl_params,
    )
Example no. 3
0
def args(tmpdir):
    """Return a namespace with every directory attribute set to *tmpdir*.

    NOTE(review): ``impl_params`` and ``instance_norm`` are assigned the
    ``bool`` *type itself* (which is truthy), not a boolean value —
    reproduced verbatim here; confirm against callers whether a literal
    ``True`` was intended.
    """
    namespace = argparse.Namespace()
    for attr in ("image_source_dir", "image_results_dir", "dataset_dir", "model_dir"):
        setattr(namespace, attr, tmpdir)
    namespace.device = misc.get_device()
    namespace.impl_params = bool
    namespace.instance_norm = bool
    return namespace
Example no. 4
0
def stylization_args(tmpdir, patch_argv):
    """Patch ``sys.argv`` for ``stylization.py`` and return matching arguments."""
    style = "candy"
    patch_argv("stylization.py", style)
    namespace = argparse.Namespace(
        style=style,
        content=["chicago", "hoovertowernight"],
        impl_params=True,
        instance_norm=True,
        device=misc.get_device(),
    )
    # All directory-style arguments point at the temporary directory.
    namespace.images_source_dir = tmpdir
    namespace.images_results_dir = tmpdir
    namespace.models_dir = tmpdir
    return namespace
Example no. 5
0
def parse_input():
    """Resolve the training script's input settings as a ``Namespace``.

    Directories default to paths under this file's location
    (``data/images/source``, ``data/images/results``,
    ``data/images/dataset/coco/train2014``, ``data/models``); each is
    expanded, made absolute, and created if missing.

    Returns:
        Namespace with ``image_source_dir``, ``image_results_dir``,
        ``dataset_dir``, ``model_dir``, ``device``, ``impl_params``, and
        ``instance_norm`` attributes.
    """
    # TODO: write CLI
    image_source_dir = None
    image_results_dir = None
    dataset_dir = None
    model_dir = None
    device = None
    impl_params = True
    instance_norm = False

    def process_dir(directory):
        # Normalize to an absolute path (expanding "~") and make sure it
        # exists.  Renamed from ``dir`` to avoid shadowing the builtin.
        directory = path.abspath(path.expanduser(directory))
        os.makedirs(directory, exist_ok=True)
        return directory

    here = path.dirname(__file__)

    if image_source_dir is None:
        image_source_dir = path.join(here, "data", "images", "source")
    image_source_dir = process_dir(image_source_dir)

    if image_results_dir is None:
        image_results_dir = path.join(here, "data", "images", "results")
    image_results_dir = process_dir(image_results_dir)

    if dataset_dir is None:
        dataset_dir = path.join(here, "data", "images", "dataset", "coco",
                                "train2014")
    dataset_dir = process_dir(dataset_dir)

    if model_dir is None:
        model_dir = path.join(here, "data", "models")
    model_dir = process_dir(model_dir)

    # ``device is None`` here, so this resolves the library's default device.
    device = misc.get_device(device=device)

    return Namespace(
        image_source_dir=image_source_dir,
        image_results_dir=image_results_dir,
        dataset_dir=dataset_dir,
        model_dir=model_dir,
        device=device,
        impl_params=impl_params,
        instance_norm=instance_norm,
    )
Example no. 6
0
def training_args(tmpdir, patch_argv):
    """Patch ``sys.argv`` for ``training.py`` and return matching arguments.

    NOTE(review): ``impl_params`` / ``instance_norm`` are set to the ``bool``
    type itself (truthy), reproduced verbatim from the original.
    """
    patch_argv("training.py")
    styles = [
        "starry_night",
        "la_muse",
        "composition_vii",
        "the_wave",
        "candy",
        "udnie",
        "the_scream",
        "mosaic",
        "feathers",
    ]
    namespace = argparse.Namespace(style=styles)
    # Every directory-style argument points at the temporary directory.
    namespace.images_source_dir = tmpdir
    namespace.models_dir = tmpdir
    namespace.dataset_dir = tmpdir
    namespace.impl_params = bool
    namespace.instance_norm = bool
    namespace.device = misc.get_device()
    return namespace
from pystiche.demo import demo_images, demo_logger
from pystiche.enc import vgg19_multi_layer_encoder
from pystiche.image import show_image
from pystiche.loss import PerceptualLoss
from pystiche.misc import get_device, get_input_image
from pystiche.ops import (
    FeatureReconstructionOperator,
    MRFOperator,
    MultiLayerEncodingOperator,
)
from pystiche.optim import default_image_pyramid_optim_loop
from pystiche.pyramid import ImagePyramid

# Script setup: report the library version, select the device, and download
# the demo images used in the rest of the example.
# NOTE(review): the bare name ``pystiche`` is not bound by the visible import
# block — presumably ``import pystiche`` appears earlier in the file; confirm.
print(f"I'm working with pystiche=={pystiche.__version__}")

device = get_device()
print(f"I'm working with {device}")

images = demo_images()
images.download()

########################################################################################
# At first we define a :class:`~pystiche.loss.perceptual.PerceptualLoss` that is used
# as optimization ``criterion``.

multi_layer_encoder = vgg19_multi_layer_encoder()

# Content loss operates on a single deep VGG layer ("relu4_2") with unit weight.
content_layer = "relu4_2"
content_encoder = multi_layer_encoder.extract_single_layer_encoder(
    content_layer)
content_weight = 1e0
Example no. 8
0
 def _process_device(self, args):
     """Resolve ``args.device`` to a concrete device, if the attribute exists."""
     if "device" not in args:
         return
     args.device = misc.get_device(args.device)
Example no. 9
0
def training(
    content_image_loader: DataLoader,
    style_image: Union[str, torch.Tensor],
    impl_params: bool = True,
    instance_norm: Optional[bool] = None,
    hyper_parameters: Optional[HyperParameters] = None,
    quiet: bool = False,
) -> nn.Module:
    r"""Training a transformer for the NST.

    Args:
        content_image_loader: Content images used as input for the ``transformer``.
        style_image: Style image on which the ``transformer`` should be trained. If
            ``str``, the image is read from
            :func:`~pystiche_papers.johnson_alahi_li_2016.images`.
        impl_params: If ``True``, uses the parameters used in the reference
            implementation of the original authors rather than what is described in
            the paper. For details see below.
        instance_norm: If ``True``, use :class:`~torch.nn.InstanceNorm2d` rather than
            :class:`~torch.nn.BatchNorm2d` as described in the paper. If omitted,
            defaults to ``impl_params``.
        hyper_parameters: If omitted,
            :func:`~pystiche_papers.johnson_alahi_li_2016.hyper_parameters` is used.
        quiet: If ``True``, no information is logged during the optimization. Defaults
            to ``False``.

    If ``impl_params is True`` , an external preprocessing of the images is used.

    """
    # Resolve the style image: a tensor carries its own device; a string is
    # looked up in the paper's image collection and read onto the default device.
    if isinstance(style_image, torch.Tensor):
        device = style_image.device
    else:
        device = misc.get_device()
        images = _images()
        style_image = images[style_image].read(device=device)

    # Documented default: instance_norm follows impl_params when omitted.
    if instance_norm is None:
        instance_norm = impl_params

    if hyper_parameters is None:
        hyper_parameters = _hyper_parameters()

    # The transformer is trained; the criterion is only evaluated.
    transformer = _transformer(impl_params=impl_params,
                               instance_norm=instance_norm)
    transformer = transformer.train().to(device)

    criterion = perceptual_loss(impl_params=impl_params,
                                hyper_parameters=hyper_parameters)
    criterion = criterion.eval().to(device)

    optimizer = _optimizer(transformer)

    # Transform the style image once, then replicate it to the loader's batch
    # size so the criterion sees a full batch of style targets.
    style_transform = _style_transform(hyper_parameters=hyper_parameters)
    style_transform = style_transform.to(device)
    style_image = style_transform(style_image)
    style_image = batch_up_image(style_image, loader=content_image_loader)

    # Preprocess the style image before registering it as the loss target.
    preprocessor = _preprocessor()
    preprocessor = preprocessor.to(device)
    style_image = preprocessor(style_image)

    criterion.set_style_image(style_image)

    def criterion_update_fn(input_image: torch.Tensor,
                            criterion: nn.Module) -> None:
        # Each optimization step re-targets the loss on the current batch.
        cast(loss.PerceptualLoss, criterion).set_content_image(input_image)

    return optim.model_optimization(
        content_image_loader,
        transformer,
        criterion,
        criterion_update_fn,
        optimizer=optimizer,
        quiet=quiet,
    )
Example no. 10
0
def training(
    content_image_loader: DataLoader,
    style: Union[str, torch.Tensor],
    impl_params: bool = True,
    instance_norm: bool = True,
    hyper_parameters: Optional[HyperParameters] = None,
    quiet: bool = False,
) -> nn.Module:
    r"""Training a transformer for the NST.

    Args:
        content_image_loader: Content images used as input for the ``transformer``.
        style: Style image on which the ``transformer`` should be trained. If the
            input is :class:`str`, the style image is read from
            :func:`~pystiche_papers.ulyanov_et_al_2016.images`.
        impl_params: Switch the behavior and hyper-parameters between the reference
            implementation of the original authors and what is described in the paper.
            For details see :ref:`here <li_wand_2016-impl_params>`.
        instance_norm: Switch the behavior and hyper-parameters between both
            publications of the original authors. For details see
            :ref:`here <ulyanov_et_al_2016-instance_norm>`.
        hyper_parameters: Hyper parameters. If omitted,
            :func:`~pystiche_papers.ulyanov_et_al_2016.hyper_parameters` is used.
        quiet: If ``True``, no information is logged during the optimization. Defaults
            to ``False``.

    """
    # Hyper-parameters depend on both behavior switches.
    if hyper_parameters is None:
        hyper_parameters = _hyper_parameters(impl_params=impl_params,
                                             instance_norm=instance_norm)

    # Resolve the style image: a string is looked up in the paper's image
    # collection on the default device; a tensor carries its own device.
    if isinstance(style, str):
        device = misc.get_device()
        images = _images()
        style_image = images[style].read(device=device)
    else:
        style_image = style
        device = style_image.device

    # The transformer is trained; the criterion is only evaluated.
    transformer = _transformer(
        impl_params=impl_params,
        instance_norm=instance_norm,
    )
    transformer = transformer.train()
    transformer = transformer.to(device)

    criterion = perceptual_loss(
        impl_params=impl_params,
        instance_norm=instance_norm,
        hyper_parameters=hyper_parameters,
    )
    criterion = criterion.eval()
    criterion = criterion.to(device)

    # Unlike the single-epoch variant, this training uses an LR scheduler.
    optimizer = _optimizer(transformer,
                           impl_params=impl_params,
                           instance_norm=instance_norm)
    lr_scheduler = _lr_scheduler(
        optimizer,
        impl_params=impl_params,
        instance_norm=instance_norm,
        hyper_parameters=hyper_parameters,
    )

    # Transform the style image once, then replicate it to the loader's batch
    # size so the criterion sees a full batch of style targets.
    style_transform = _style_transform(
        impl_params=impl_params,
        instance_norm=instance_norm,
        hyper_parameters=hyper_parameters,
    )
    style_transform = style_transform.to(device)
    style_image = style_transform(style_image)
    style_image = batch_up_image(style_image, loader=content_image_loader)

    preprocessor = _preprocessor()
    preprocessor = preprocessor.to(device)
    style_image = preprocessor(style_image)

    criterion.set_style_image(style_image)

    def criterion_update_fn(input_image: torch.Tensor,
                            criterion: nn.Module) -> None:
        # Content images are preprocessed here; the style image was
        # preprocessed once above.
        cast(loss.PerceptualLoss,
             criterion).set_content_image(preprocessor(input_image))

    return optim.multi_epoch_model_optimization(
        content_image_loader,
        transformer,
        criterion,
        criterion_update_fn,
        hyper_parameters.num_epochs,
        lr_scheduler=lr_scheduler,
        quiet=quiet,
    )
Example no. 11
0
def test_get_device_str():
    """An explicit device string must resolve to exactly that torch device."""
    name = "mkldnn"
    assert misc.get_device(name) == torch.device(name)
Example no. 12
0
def test_get_device():
    """Without arguments, get_device prefers CUDA when it is available."""
    expected = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    assert misc.get_device() == expected