Exemplo n.º 1
0
def main(model_path='./chainer-fast-neuralstyle-models/models/starrynight.model',
         input_path='./test1.jpg',
         padding=0,
         onnx_filename='FastStyleNet.onnx'):
    """Stylize one image with a pre-trained FastStyleNet, then export to ONNX.

    Every argument defaults to the value this script previously hard-coded,
    so calling ``main()`` with no arguments behaves exactly as before.

    Args:
        model_path: Path of the ``.npz`` weights loaded into ``FastStyleNet``.
        input_path: Image file to stylize; it is converted to RGB first.
        padding: Pixels of symmetric padding added to each side of the two
            spatial axes before inference. ``0`` (the default) or any
            non-positive value disables padding.
        onnx_filename: Output path for the exported ONNX graph.
    """
    # Put Chainer in non-training mode for inference. Note: this mutates
    # Chainer's global config and is not restored when main() returns.
    chainer.config.train = False

    model = FastStyleNet()
    serializers.load_npz(model_path, model)

    # Use the image as a context manager so the source file is closed
    # deterministically; convert() returns an independent, loaded copy.
    with Image.open(input_path) as img:
        original = img.convert('RGB')
    print(original.size)

    # (H, W, 3) RGB pixels -> (3, H, W) float32 -> (1, 3, H, W) batch of one.
    image = np.asarray(original, dtype=np.float32).transpose(2, 0, 1)
    image = image.reshape((1, ) + image.shape)
    if padding > 0:
        # Pad only the height/width axes; batch and channel axes are untouched.
        # NOTE(review): the padded border is not cropped here; confirm whether
        # postprocess() is expected to remove it.
        image = np.pad(
            image, [[0, 0], [0, 0], [padding, padding], [padding, padding]],
            'symmetric')
    x = image

    out = model(x)
    out = out.data[0]  # first (only) element of the batch as a raw array
    print(out.shape)
    print('model done.')

    postprocess(out)

    print('export onnx...')
    # The same input array is reused so the exported graph matches the run above.
    onnx_chainer.export(model, x, filename=onnx_filename)
Exemplo n.º 2
0
    kanagawa = "../../resources/chainer-fast-neuralstyle-models/models/kanagawa.model"


# NOTE(review): presumably raised because graph construction/serialization
# recurses deeply on this network — confirm the actual need.
sys.setrecursionlimit(10000)

# Command-line options: which pre-trained style model to use, the target
# backend, and an optional encoding setting.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model",
    default=NSTModelPath.kanagawa.name,
    choices=[member.name for member in NSTModelPath],
)
parser.add_argument(
    "--backend",
    default="webgpu",
    choices=["webgpu", "webassembly", "fallback"],
)
parser.add_argument("--encoding")
args = parser.parse_args()

# Echo the effective settings, one "name: value" line each.
for _option in ("model", "backend", "encoding"):
    print(f"{_option}: {getattr(args, _option)}")

# Load chainer pre-trained model
model = FastStyleNet()

# Resolve the enum member's value to a file path and fail early with a
# helpful message when the model repository has not been cloned.
model_path = NSTModelPath[args.model].value
if not path.exists(model_path):
    raise FileNotFoundError(
        f"Model data ({model_path}) is not found. Please clone "
        "'https://github.com/gafr/chainer-fast-neuralstyle-models' under the resource directory. "
        "Clone command takes about a few minute, the repository size is about 200MB."
    )

chainer.serializers.load_npz(model_path, model)

# Execute forward propagation to construct computation graph
if chainer.__version__ >= "2.":
    with chainer.using_config("train", False):  # fixes batch normalization
        x = chainer.Variable(np.zeros((1, 3, 144, 192), dtype=np.float32))
        y = model(x)
else: