def test_styletransfer_gatys():
    """Style transfer works without error for the Gatys algorithm"""
    # The context manager removes the output directory deterministically,
    # even when the assertion below fails
    with TemporaryDirectory() as tmpdir:
        styletransfer([CONTENTS + "dockersmall.png"], [STYLES + "cubism.jpg"],
                      tmpdir,
                      alg="gatys")
        # Exactly one stylised image is expected for one content/style pair
        assert len(glob(tmpdir + "/dockersmall*cubism*")) == 1
def test_styletransfer_chenschmidtinverse():
    """Style transfer works without error for the Chen-Schmidt Inverse algorithm"""
    # The context manager removes the output directory deterministically,
    # even when the assertion below fails
    with TemporaryDirectory() as tmpdir:
        styletransfer([CONTENTS + "dockersmall.png"], [STYLES + "cubism.jpg"],
                      tmpdir,
                      alg="chen-schmidt-inverse")
        # Exactly one stylised image is expected for one content/style pair
        assert len(glob(tmpdir + "/dockersmall*cubism*")) == 1
def test_styletransfer_gatysmultiresolution():
    """Style transfer works without error for the Gatys algorithm with multiresolution"""
    # The context manager removes the output directory deterministically,
    # even when the assertion below fails
    with TemporaryDirectory() as tmpdir:
        styletransfer([CONTENTS + "docker.png"], [STYLES + "cubism.jpg"],
                      tmpdir,
                      alg="gatys-multiresolution",
                      size=600)
        # Exactly one stylised image is expected for one content/style pair
        assert len(glob(tmpdir + "/docker*cubism*")) == 1
def test_styletransfer_gatys_parameters():
    """Algorithm parameters can be passed to the Gatys method"""
    # NOTE(review): no ``alg`` is passed here, so this exercises whatever
    # styletransfer's default algorithm is (presumably gatys -- confirm)
    algparams = ("-num_iterations", "50")
    # The context manager removes the output directory deterministically,
    # even when the assertion below fails
    with TemporaryDirectory() as tmpdir:
        styletransfer([CONTENTS + "dockersmall.png"], [STYLES + "cubism.jpg"],
                      tmpdir,
                      algparams=algparams)
        assert len(glob(tmpdir + "/dockersmall*cubism*")) == 1
def test_formatpsd():
    """PSD format images can be processed correctly"""
    contents = [CONTENTS + "oldtelephone.psd"]
    # The context manager removes the output directory deterministically,
    # even when the assertion below fails
    with TemporaryDirectory() as tmpdir:
        styletransfer(contents, [STYLES + "cubism.jpg"],
                      tmpdir,
                      alg="chen-schmidt-inverse")
        # One output is expected for the single PSD input
        assert len(glob(tmpdir + "/*cubism*")) == 1
def test_formattga():
    """TGA format images can be processed correctly"""
    contents = [CONTENTS + f for f in ("tgasample.tga", "marbles.tga")]
    # The context manager removes the output directory deterministically,
    # even when the assertion below fails
    with TemporaryDirectory() as tmpdir:
        styletransfer(contents, [STYLES + "cubism.jpg"],
                      tmpdir,
                      alg="chen-schmidt-inverse")
        # One output is expected per TGA input
        assert len(glob(tmpdir + "/*cubism*")) == len(contents)
def test_styletransfer_keepsize():
    """Style transfer keeps the original image size if no size parameter is given"""
    img = CONTENTS + "dockersmall.png"
    for alg in ALGORITHMS.keys():
        # A fresh directory per algorithm, removed as soon as the checks finish
        with TemporaryDirectory() as tmpdir:
            styletransfer([img], [STYLES + "cubism.jpg"], tmpdir, alg=alg)
            files = glob(tmpdir + "/" + filename(img) + "*cubism*")
            # Check the count before indexing so a missing output is reported
            # as a failed assertion instead of an IndexError
            assert len(files) == 1
            expected = shape(img)
            actual = shape(files[0])
            print("Expected size", expected)
            print("Actual shape", actual)
            assert actual == expected
def test_styletransfer_sw():
    """Style transfer works for varying style weights"""
    styleweights = [1, 5, 10]
    alg = "gatys"
    img = "docker.png"
    # The context manager removes the output directory deterministically,
    # even when the check below fails
    with TemporaryDirectory() as tmpdir:
        styletransfer([CONTENTS + img], [STYLES + "cubism.jpg"],
                      tmpdir,
                      alg=alg,
                      size=100,
                      weights=styleweights)
        # One output per style weight is expected, and they must all differ
        assertalldifferent(tmpdir + "/" + filename(img) + "*cubism*",
                           len(styleweights))
def test_styletransfer_ss():
    """Style transfer works for varying style scales"""
    stylescales = [0.75, 1, 1.25]
    img = "docker.png"
    for alg in ALGORITHMS.keys():
        # A fresh directory per algorithm, removed as soon as the checks finish
        with TemporaryDirectory() as tmpdir:
            styletransfer([CONTENTS + img], [STYLES + "cubism.jpg"],
                          tmpdir,
                          alg=alg,
                          size=100,
                          stylescales=stylescales)
            # One output per style scale is expected, and they must all differ
            assertalldifferent(tmpdir + "/" + filename(img) + "*cubism*",
                               len(stylescales))
def test_alpha():
    """Transformation of images with an alpha channel preserves transparency"""
    # Both transfers write into the same directory so their outputs can be
    # compared; the context manager removes it even if a check fails
    with TemporaryDirectory() as tmpdir:
        # Transform image with alpha
        styletransfer([CONTENTS + "dockersmallalpha.png"], [STYLES + "cubism.jpg"],
                      tmpdir,
                      alg="chen-schmidt-inverse")
        assert len(glob(tmpdir + "/*dockersmallalpha_cubism*")) == 1
        # Transform image without alpha
        styletransfer([CONTENTS + "dockersmall.png"], [STYLES + "cubism.jpg"],
                      tmpdir,
                      alg="chen-schmidt-inverse")
        assert len(glob(tmpdir + "/*dockersmall_cubism*")) == 1
        # Check that the generated images are different
        assertalldifferent(tmpdir + "/*cubism*")
def test_styletransfer_size():
    """Style transfer works for varying image sizes, producing correctly scaled images"""
    for alg in ALGORITHMS.keys():
        for size in [50, 100, 200]:
            for img in ["docker.png", "obama.jpg"]:
                originalshape = shape(CONTENTS + img)
                # A fresh directory per combination, removed as soon as the
                # checks finish
                with TemporaryDirectory() as tmpdir:
                    styletransfer([CONTENTS + img], [STYLES + "cubism.jpg"],
                                  tmpdir,
                                  alg=alg,
                                  size=size)
                    files = glob(tmpdir + "/" + filename(img) + "*cubism*")
                    # Check the count before indexing so a missing output is
                    # reported as a failed assertion instead of an IndexError
                    assert len(files) == 1
                    resultshape = shape(files[0])
                # The first dimension is expected to equal the requested size,
                # the second to scale by the same factor (truncated to int)
                rescalefactor = size / originalshape[0]
                expectedshape = [size, int(rescalefactor * originalshape[1])]
                print("Expected shape", expectedshape)
                print("Actual shape", resultshape)
                assert expectedshape == resultshape
def _option_value(argv, index):
    """Return the argument that follows the option at ``argv[index]``.

    Raises:
        ValueError: If the option is the last argument and so has no value.
    """
    if index + 1 >= len(argv):
        raise ValueError("Option %s requires a value" % argv[index])
    return argv[index + 1]


def main(argv=None):
    """Parse command-line arguments and run the style transfer.

    Recognised options are ``--content``, ``--style``, ``--output``,
    ``--alg``, ``--size``, ``--sw``, ``--ss``, ``--tileoverlap`` and
    ``--help``. Any other argument is collected, in order, and forwarded to
    ``styletransfer`` as ``algparams``.

    Args:
        argv: Argument list with the program name first; defaults to
            ``sys.argv``.

    Returns:
        1 after ``styletransfer`` completes; 0 when ``--help`` is requested or
        when any exception is caught (the help text and the traceback are
        printed in that case).
    """
    # NOTE(review): if the return value is handed to sys.exit(), success (1)
    # and failure (0) are inverted with respect to the usual exit-status
    # convention -- confirm against the caller before changing.
    if argv is None:
        argv = sys.argv
    try:
        # Default parameters
        contents = []
        styles = []
        savefolder = "/images"
        size = None
        alg = "gatys"
        weights = None
        stylescales = None
        tileoverlap = None
        otherparams = []

        # Gather parameters
        i = 1
        while i < len(argv):
            option = argv[i]
            # References to inputs/outputs are re-referenced to the mounted /images directory.
            # List-valued options advance past the option plus every value
            # sublist() returned (presumably everything up to the next
            # "-"-prefixed argument -- confirm against sublist).
            if option == "--content":
                contents = ["/images/" + x for x in sublist(argv[i+1:], stopper="-")]
                i += len(contents) + 1
            elif option == "--style":
                styles = ["/images/" + x for x in sublist(argv[i+1:], stopper="-")]
                i += len(styles) + 1
            # Other general parameters
            elif option == "--output":
                savefolder = "/images/" + _option_value(argv, i)
                i += 2
            elif option == "--alg":
                alg = _option_value(argv, i)
                i += 2
            elif option == "--size":
                size = int(_option_value(argv, i))
                i += 2
            elif option == "--sw":
                weights = [float(x) for x in sublist(argv[i+1:], stopper="-")]
                i += len(weights) + 1
            elif option == "--ss":
                stylescales = [float(x) for x in sublist(argv[i+1:], stopper="-")]
                i += len(stylescales) + 1
            elif option == "--tileoverlap":
                tileoverlap = int(_option_value(argv, i))
                i += 2
            # Help
            elif option == "--help":
                print(HELP)
                return 0
            # Additional parameters will be passed on to the specific algorithms
            else:
                otherparams.append(option)
                i += 1

        # Check parameters
        if not contents:
            raise ValueError("At least one content image must be provided")
        if not styles:
            raise ValueError("At least one style image must be provided")

        # Logging arguments are passed separately so formatting is deferred to the logger
        LOGGER.info("Running neural style transfer with")
        LOGGER.info("\tContents = %s", contents)
        LOGGER.info("\tStyle = %s", styles)
        LOGGER.info("\tAlgorithm = %s", alg)
        LOGGER.info("\tStyle weights = %s", weights)
        LOGGER.info("\tStyle scales = %s", stylescales)
        LOGGER.info("\tSize = %s", size)
        LOGGER.info("\tTile overlap = %s", tileoverlap)
        styletransfer(contents, styles, savefolder, size, alg, weights, stylescales, tileoverlap, algparams=otherparams)
        return 1

    except Exception:
        # Top-level boundary: report usage plus the full traceback instead of crashing
        print(HELP)
        traceback.print_exc()
        return 0