Example #1
    def generate_kernel_test_case(cls,
                                  description: str,
                                  graph: Graph,
                                  inputs: Dict[Variable, np.array],
                                  expected: Dict[Variable, np.array],
                                  backend=None,
                                  raise_skip: bool = True,
                                  EPS: float = 1.0e-3,
                                  ABS_EPS: float = 0.0):
        """Generate test data for generated kernel codes
    
        Generated data are saved in JSON format, and BrowserTestRunner executes it.
        """

        if backend is None:
            backend = ["webgpu", "webassembly", "fallback"]

        if not cls.flag_initialized:
            cls.setup()

        if not isinstance(backend, str):
            for b in backend:
                generate_kernel_test_case(description=description,
                                          graph=graph,
                                          inputs=inputs,
                                          expected=expected,
                                          backend=b,
                                          raise_skip=False,
                                          EPS=EPS,
                                          ABS_EPS=ABS_EPS)

            if raise_skip:
                raise SkipTest(f"[BrowserTest|{backend}] {description}")

            return

        graph_descriptor = generate_descriptor(backend, graph)

        testcase_dirname = f"testcase-{str(cls.counter)}"
        cls.counter += 1

        graph_descriptor.save(path.join(cls.OUTPUT_ROOT, testcase_dirname))

        cls.cases.append({
            "description": description,
            "inputs": [list(inputs[v].flatten()) for v in graph.inputs],
            "expected": [list(expected[v].flatten()) for v in graph.outputs],
            "dirname": testcase_dirname,
            "backend": backend,
            "EPS": EPS,
            "ABS_EPS": ABS_EPS
        })

        if raise_skip:
            raise SkipTest(f"[BrowserTest|{backend}] {description}")
Example #2
def main():
    sys.setrecursionlimit(10000)  # workaround for deep copying large graph

    parser = argparse.ArgumentParser()
    parser.add_argument("--model",
                        default="resnet50",
                        choices=["vgg16", "resnet50"])
    parser.add_argument("--backend", default="webgpu,webassembly,fallback")
    parser.add_argument("--encoding")
    parser.add_argument('--out',
                        '-o',
                        default='output_chainer',
                        help='Directory to output the graph descriptor')

    args = parser.parse_args()

    os.makedirs(args.out, exist_ok=True)

    # placeholder for an image loaded with PIL.Image.open(...)
    sample_image = np.zeros((224, 224, 3), dtype=np.uint8)
    if args.model == "vgg16":
        link = chainer.links.model.vision.vgg.VGG16Layers()
        prepared_image = chainer.links.model.vision.vgg.prepare(
            sample_image)  # BGR, CHW
        out_layer_name = "fc8"

    elif args.model == "resnet50":
        link = chainer.links.model.vision.resnet.ResNet50Layers()
        prepared_image = chainer.links.model.vision.resnet.prepare(
            sample_image)
        out_layer_name = "fc6"

    nn_input = chainer.Variable(np.array([prepared_image], dtype=np.float32))
    # 'prob' is also a valid output layer (it applies softmax)
    nn_output = link(nn_input, layers=[out_layer_name])[out_layer_name]
    chainer_cg = chainer.computational_graph.build_computational_graph(
        [nn_output])
    converter = ChainerConverter()
    graph = converter.convert(chainer_cg, [nn_input],
                              [nn_output])  # type: Graph

    any_backend_failed = False
    last_backend_exception = None
    for backend in args.backend.split(","):
        try:
            graph_exec_data = generate_descriptor(
                backend, graph, constant_encoder_name=args.encoding)
            graph_exec_data.save(args.out)
        except Exception as ex:
            any_backend_failed = True
            last_backend_exception = ex
            console.error(
                f"Failed generating descriptor for backend {backend}: {str(ex)}\n"
            )

    if any_backend_failed:
        raise last_backend_exception
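A possible addition at the end of main(), not present in the snippet (the file name expected_output.npy is my own): save the Chainer model's prediction next to the descriptors so the browser-side result has a reference to compare against, similar to the example_output.npz written by the Caffe example further below.

    # reference output computed by Chainer; shape (1, 1000) for these ImageNet models
    np.save(os.path.join(args.out, "expected_output.npy"), nn_output.data)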
Example #3
def main():
    sys.setrecursionlimit(10000)  # workaround for deep copying large graph
    parser = argparse.ArgumentParser()
    parser.add_argument("kerasmodel")
    parser.add_argument("--backend",
                        default="webgpu,webassembly,fallback",
                        help="comma-separated list of backends")
    parser.add_argument(
        "--input_shape",
        required=True,
        help="shape of blobs for inputs (example: '(1,3,224,224)')")
    # parser.add_argument("--input_data_format", choices=["channels_first", "channels_last"])
    parser.add_argument(
        "--out",
        help="output directory (default: <model>/webdnn_graph_descriptor)")
    parser.add_argument("--encoding", help="name of weight encoder")
    args = parser.parse_args()

    sys.stderr.write("Generating feedforward graph\n")
    input_shape = ast.literal_eval(args.input_shape)
    input_shapes = [input_shape]
    model = h5py.File(args.kerasmodel, "r")
    converter = KerasGraphConverter()
    graph = converter.convert(model, input_shapes)

    if args.out:
        output_dir = args.out
    else:
        output_dir = path.join(path.dirname(args.kerasmodel),
                               "webdnn_graph_descriptor")
    os.makedirs(output_dir, exist_ok=True)

    sys.stderr.write("Generating descriptors\n")
    any_backend_failed = False
    last_backend_exception = None
    for backend in args.backend.split(","):
        try:
            graph_exec_data = generate_descriptor(
                backend, graph, constant_encoder_name=args.encoding)
            graph_exec_data.save(output_dir)
        except Exception as ex:
            any_backend_failed = True
            last_backend_exception = ex
            sys.stderr.write(
                f"Failed generating descriptor for backend {backend}: {str(ex)}\n"
            )

    if any_backend_failed:
        raise last_backend_exception
Example #4
    def generate_kernel_test_case(cls,
                                  description: str,
                                  backend: Union[str, Iterable[str]],
                                  graph: Graph,
                                  inputs: Dict[Variable, np.array],
                                  expected: Dict[Variable, np.array],
                                  raise_skip: bool = True):
        """Generate test data for generated kernel codes
    
        Generated data are saved in JSON format, and BrowserTestRunner executes it.
        """

        if not cls.flag_initialized:
            cls.setup()

        if not isinstance(backend, str):
            for b in backend:
                generate_kernel_test_case(description, b, graph, inputs,
                                          expected, False)

            if raise_skip:
                raise SkipTest(f"[BrowserTest|{backend}] {description}")

            return

        graph_descriptor = generate_descriptor(backend, graph)

        testcase_dirname = f"testcase-{str(cls.counter)}"
        cls.counter += 1

        graph_descriptor.save(path.join(cls.OUTPUT_ROOT, testcase_dirname))

        cls.cases.append({
            "description": description,
            "inputs": [list(inputs[v].flatten()) for v in graph.inputs],
            "expected": [list(expected[v].flatten()) for v in graph.outputs],
            "dirname": testcase_dirname,
            "backend": backend
        })

        if raise_skip:
            raise SkipTest(f"[BrowserTest|{backend}] {description}")
Example #5
def main():
    sys.setrecursionlimit(10000)  # workaround for deep copying large graph

    parser = argparse.ArgumentParser()
    parser.add_argument("--model",
                        default="resnet50",
                        choices=["vgg16", "resnet50"])
    parser.add_argument("--backend",
                        default="webgpu",
                        choices=["webgpu", "webassembly", "fallback"])
    parser.add_argument("--encoding")
    args = parser.parse_args()

    # placeholder for an image loaded with PIL.Image.open(...)
    sample_image = np.zeros((224, 224, 3), dtype=np.uint8)
    if args.model == "vgg16":
        link = chainer.links.model.vision.vgg.VGG16Layers()
        prepared_image = chainer.links.model.vision.vgg.prepare(
            sample_image)  # BGR, CHW
        out_layer_name = "fc8"

    elif args.model == "resnet50":
        link = chainer.links.model.vision.resnet.ResNet50Layers()
        prepared_image = chainer.links.model.vision.resnet.prepare(
            sample_image)
        out_layer_name = "fc6"

    nn_input = chainer.Variable(np.array([prepared_image], dtype=np.float32))
    # 'prob' is also a valid output layer (it applies softmax)
    nn_output = link(nn_input, layers=[out_layer_name])[out_layer_name]
    chainer_cg = chainer.computational_graph.build_computational_graph(
        [nn_output])
    converter = ChainerGraphConverter()
    graph = converter.convert(chainer_cg, [nn_input],
                              [nn_output])  # type: Graph

    graph_exec_data = generate_descriptor(args.backend,
                                          graph,
                                          constant_encoder_name=args.encoding)

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    graph_exec_data.save(OUTPUT_DIR)
Example #6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="resnet50", choices=["resnet50"])
    parser.add_argument('--out',
                        '-o',
                        default='output_keras',
                        help='Directory to output the graph descriptor')
    parser.add_argument("--encoding", help="name of weight encoder")
    args = parser.parse_args()

    model = resnet50.ResNet50(include_top=True, weights='imagenet')

    sys.setrecursionlimit(10000)
    graph = KerasConverter(batch_size=1).convert(model)
    for backend in ["webgpu", "webassembly", "fallback"]:
        graph_exec_data = generate_descriptor(
            backend, graph, constant_encoder_name=args.encoding)
        graph_exec_data.save(args.out)

    console.stderr("Done.")
Example #7
    def generate_kernel_test_case(cls,
                                  description: str,
                                  graph: Graph,
                                  inputs: Dict[Variable, np.array],
                                  expected: Dict[Variable, np.array],
                                  backend=None,
                                  raise_skip: bool = True,
                                  EPS: float = 1.0e-3,
                                  ABS_EPS: float = 0.0):
        """Generate test data for generated kernel codes
    
        Generated data are saved in JSON format, and BrowserTestRunner executes it.
        """

        if not cls.flag_initialized:
            cls.setup()

        if backend is None:
            backend = ["webgpu", "webassembly", "fallback"]

        if not isinstance(backend, str):
            for b in backend:
                generate_kernel_test_case(description=description,
                                          graph=graph,
                                          inputs=inputs,
                                          expected=expected,
                                          backend=b,
                                          raise_skip=False,
                                          EPS=EPS,
                                          ABS_EPS=ABS_EPS)

            if raise_skip:
                raise SkipTest(f"[BrowserTest|{backend}] {description}")

            return

        backend_flag_map = {
            "webgpu": flags.test.TEST_WEBGPU,
            "webassembly": flags.test.TEST_WEBASSEMBLY,
            "fallback": flags.test.TEST_FALLBACK
        }

        if not backend_flag_map[backend]:
            return

        graph_descriptor = generate_descriptor(backend, graph)

        cls.counter += 1
        testcase_dirname = f"testcase-{str(cls.counter)}"

        output_root = path.join(cls.OUTPUT_ROOT, testcase_dirname)
        graph_descriptor.save(output_root)
        with open(path.join(output_root, "./cg.dot"), "w") as f:
            f.write(traverse.dump_dot(graph_descriptor.graph))

        cls.cases.append({
            "description": description,
            "inputs": [list(inputs[v].flatten()) for v in graph.inputs],
            "expected": [list(expected[v].flatten()) for v in graph.outputs],
            "dirname": testcase_dirname,
            "backend": backend,
            "EPS": EPS,
            "ABS_EPS": ABS_EPS
        })

        if raise_skip:
            raise SkipTest(f"[BrowserTest|{backend}] {description}")
Example #8
def main():
    sys.setrecursionlimit(10000)  # workaround for deep copying large graph
    parser = argparse.ArgumentParser()
    parser.add_argument("kerasmodel")
    parser.add_argument("--backend", default="webgpu,webassembly,fallback",
                        help="comma-separated list of backends")
    parser.add_argument("--input_shape", required=True,
                        help="shape of blobs for inputs (example: '(1,3,224,224)')")
    # parser.add_argument("--input_data_format", choices=["channels_first", "channels_last"])
    parser.add_argument("--out",
                        help="output directory (default: <model>/webdnn_graph_descriptor)")
    parser.add_argument("--encoding", help="name of weight encoder")
    parser.add_argument("--visualize_ir", action="store_true")
    parser.add_argument("--plugin", action="append", help="plugin python files which are imported before transpiling")
    args = parser.parse_args()

    console.stderr(f"[{path.basename(__file__)}] Generating feedforward graph")
    class_list = []
    if args.plugin:
        for plugin_path in args.plugin:
            class_list += _load_plugin(plugin_path)
    # custom_objects maps names of user-defined custom layers to their classes so
    # that load_model can reconstruct them; it must exist even when no plugin is given
    custom_objects = {}
    for k, v in class_list:
        custom_objects[k] = v

    input_shape, _ = Shape.parse(args.input_shape)
    input_shapes = [input_shape]

    model = keras.models.load_model(args.kerasmodel, custom_objects=custom_objects)
    model.build()
    converter = KerasConverter()
    graph = converter.convert(model)

    for graph_input, input_shape in zip(graph.inputs, input_shapes):
        for p1, p2 in zip(graph_input.shape, input_shape):
            if not Placeholder.check_resolved(p1) and Placeholder.check_resolved(p2):
                p1.value = Placeholder.force_int(p2)

            elif Placeholder.check_resolved(p1) and not Placeholder.check_resolved(p2):
                raise ValueError(f'Shape mismatch: {p1} != {p2}')

            elif Placeholder.check_resolved(p1) and Placeholder.check_resolved(p2):
                assert p1 == p2, f'Shape mismatch: {p1} != {p2}'

    if args.out:
        output_dir = args.out
    else:
        output_dir = path.join(path.dirname(args.kerasmodel), "webdnn_graph_descriptor")
    os.makedirs(output_dir, exist_ok=True)

    if args.visualize_ir:
        ir_dot_path = path.join(output_dir, "ir.dot")
        with open(ir_dot_path, "w") as f:
            f.write(dump_dot(graph))
        console.stderr(f"IR graph can be visualized with graphviz command: 'dot {ir_dot_path} -T png -o output.png'")

    console.stderr(f"[{path.basename(__file__)}] Generating graph descriptor")

    any_backend_failed = False
    backends = args.backend.split(",")
    for i, backend in enumerate(backends):
        console.stderr(f"[{path.basename(__file__)}] Backend: {console.colorize(backend, console.Color.Cyan)}")
        try:
            graph_exec_data = generate_descriptor(backend, graph, constant_encoder_name=args.encoding)
            graph_exec_data.save(output_dir)
        except Exception as ex:
            if flags.DEBUG:
                raise ex

            any_backend_failed = True
            console.error(f"[{path.basename(__file__)}] Failed generating descriptor for {backend} backend")
            console.stderr(traceback.format_exc())
            continue

    if any_backend_failed:
        exit(1)
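The placeholder-resolution loop above is what makes a symbolic dimension in --input_shape possible. A sketch of the assumed behaviour (the placeholder name N is my choice; the second value returned by Shape.parse appears to collect such placeholders):

input_shape, placeholders = Shape.parse("(N,224,224,3)")
# the batch dimension stays unresolved in the generated graph and is bound at
# runtime from the JavaScript side instead of at conversion time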
Example #9
def main():
    sys.setrecursionlimit(10000)  # workaround for deep copying large graph
    parser = argparse.ArgumentParser()
    # default is Caffenet of Caffe example
    parser.add_argument("caffemodel")
    parser.add_argument("--backend",
                        default="webgpu,webassembly,fallback",
                        help="comma-separated list of backends")
    parser.add_argument("--input_name", help="blob name for input (mandatory)")
    parser.add_argument(
        "--input_shape",
        help="shape of blobs for inputs (example: '(1,3,224,224)')")
    parser.add_argument("--input_npy",
                        help="npy file containing sample inputs")
    parser.add_argument(
        "--output_names",
        required=True,
        help="comma-separated blob name for output (mandatory)")
    parser.add_argument(
        "--out",
        help="output directory (default: <model>/webdnn_graph_descriptor)")
    parser.add_argument("--encoding", help="name of weight encoder")
    args = parser.parse_args()

    # multiple blob inputs could easily be supported, but the command-line arguments would become complicated.
    input_blob, input_filled = parse_input_blob(args)
    output_names = args.output_names.split(",")

    sys.stderr.write(
        "Loading caffe model... (usually takes several minutes)\n")
    link = chainer.links.caffe.CaffeFunction(args.caffemodel)

    sys.stderr.write("Generating feedforward graph\n")
    output_blobs = list(
        link(inputs={args.input_name: input_blob},
             outputs=output_names,
             train=False))  # list of Variable
    chainer_cg = chainer.computational_graph.build_computational_graph(
        output_blobs)
    converter = ChainerGraphConverter()
    graph = converter.convert(chainer_cg, [input_blob],
                              output_blobs)  # type: Graph

    if args.out:
        output_dir = args.out
    else:
        output_dir = path.join(path.dirname(args.caffemodel),
                               "webdnn_graph_descriptor")
    os.makedirs(output_dir, exist_ok=True)

    if input_filled:
        # save output of Caffe Network (not required for inference)
        output_arrays = {
            output_name: output_blob.data
            for output_name, output_blob in zip(output_names, output_blobs)
        }
        np.savez(path.join(output_dir, "example_output.npz"), **output_arrays)

    sys.stderr.write("Generating descriptors\n")
    any_backend_failed = False
    for backend in args.backend.split(","):
        try:
            graph_exec_data = generate_descriptor(
                backend, graph, constant_encoder_name=args.encoding)
            graph_exec_data.save(output_dir)
        except Exception as ex:
            any_backend_failed = True
            sys.stderr.write(
                f"Failed generating descriptor for backend {backend}: {str(ex)}\n"
            )

    if any_backend_failed:
        sys.exit(1)
Example #10
    #    generator_class = DCGANGenerator64
    #elif args.arch == 'resnet128':
    #    generator_class = ResNetGenerator128
    #elif args.arch == 'resnet256':
    #else:
    #    raise ValueError('Unknown -arch %s' % FLAGS.arch)
    generator_class = ResNetGenerator256
    gen = generator_class()
    serializers.load_npz(args.chainer_model_path, gen)
    print("Generator model loaded")

    x = chainer.Variable(np.empty((1, args.latent_len), dtype=np.float32))
    with chainer.using_config('train', False):
        y = gen(x)
    print("Start Convert")
    graph = ChainerConverter().convert([x], [y])
    #exec_info = generate_descriptor("webgpu", graph)
    #exec_info.save(args.out)
    #exec_info = generate_descriptor("webgl", graph)
    #exec_info.save(args.out)
    exec_info = generate_descriptor("webgl", graph)
    exec_info.save(args.out)
    exec_info = generate_descriptor("webgl",
                                    graph,
                                    constant_encoder_name="eightbit")
    exec_info.save(args.out + "_8bit")
    #exec_info = generate_descriptor("webassembly", graph, constant_encoder_name="eightbit")
    #exec_info.save(args.out+"_8bit")
    #exec_info = generate_descriptor("webgpu", graph, constant_encoder_name="eightbit")
    #exec_info.save(args.out+"_8bit")
Example #11
def main():
    sys.setrecursionlimit(10000)  # workaround for deep copying large graph

    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", default="webgpu,webassembly")
    parser.add_argument("--encoding", default="eightbit")
    parser.add_argument('--out',
                        '-o',
                        default='webdnn/image-caption-model',
                        help='Directory to output the graph descriptor')
    parser.add_argument('--sentence',
                        '-s',
                        required=True,
                        type=str,
                        help='sentence dataset file path')
    parser.add_argument('--model',
                        '-m',
                        required=True,
                        type=str,
                        help='input model file path')
    parser.add_argument("--example_image",
                        help="example image for comparing output")
    parser.add_argument("--visualize_ir", action="store_true")

    args = parser.parse_args()

    os.makedirs(args.out, exist_ok=True)
    out_dir_graph1 = os.path.join(args.out, "image-feature")
    out_dir_graph2 = os.path.join(args.out, "caption-generation")

    hidden_num = 512
    with open(args.sentence, 'rb') as f:
        sentence_dataset = pickle.load(f)
    word_ids = sentence_dataset['word_ids']
    word_num = len(word_ids)
    id_to_word = [""] * word_num
    for k, v in word_ids.items():
        id_to_word[v] = k

    with open(os.path.join(args.out, "word_data.json"), "w") as f:
        json.dump(
            {
                "id_to_word": id_to_word,
                "bos_id": word_ids["<S>"],
                "eos_id": word_ids["</S>"],
                "word_num": word_num,
                "hidden_num": hidden_num
            }, f)

    caption_net = ImageCaption(word_num=word_num,
                               feature_num=2048,
                               hidden_num=hidden_num)
    chainer.serializers.load_hdf5(args.model, caption_net)
    graph1 = generate_graph_model1(caption_net)
    graph2 = generate_graph_model2(caption_net, hidden_num=hidden_num)

    if args.example_image:
        example_io = generate_example_io(caption_net, word_ids,
                                         args.example_image)
        with open(os.path.join(args.out, "example_io.json"), "w") as f:
            json.dump(example_io, f)

    if args.visualize_ir:
        ir_dot_path = os.path.join(args.out, "ir.dot")
        with open(ir_dot_path, "w") as f:
            f.write(dump_dot(graph2))
        console.stderr(
            f"IR graph can be visualized with graphviz command: 'dot {ir_dot_path} -T png -o output.png'"
        )

    any_backend_failed = False
    last_backend_exception = None
    for backend in args.backend.split(","):
        try:
            graph_exec_data = generate_descriptor(
                backend, graph1, constant_encoder_name=args.encoding)
            graph_exec_data.save(out_dir_graph1)
            graph_exec_data = generate_descriptor(
                backend, graph2, constant_encoder_name=args.encoding)
            graph_exec_data.save(out_dir_graph2)
        except Exception as ex:
            any_backend_failed = True
            last_backend_exception = ex
            console.error(
                f"Failed generating descriptor for backend {backend}: {str(ex)}\n"
            )

    if any_backend_failed:
        raise last_backend_exception
Example #12
#os.environ["OPTIMIZE"] = "0"
from webdnn.frontend.chainer import ChainerConverter
from webdnn.backend.interface.generator import generate_descriptor

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='discriminator testing script')
    parser.add_argument('--gen_class',  default='', help='Default generator class')
    parser.add_argument("--load_gen_model", '-l', default='', help='load generator model')
    parser.add_argument('--out', '-o', default='test.jpg', help='output image path')
    parser.add_argument("--image_channels", type=int, default=3, help='number of image channels')
    parser.add_argument("--image_size", type=int, default=64, help='image size')
    parser.add_argument("--latent_len", type=int, default=128, help='latent vector length')
    parser.add_argument("--attr_len", type=int, default=38, help='attribute vector length')
    args = parser.parse_args()
    print(args)

    if args.gen_class != '':
        gen = eval(args.gen_class)
    else:
        gen = DCGANGenerator(latent=args.latent_len+args.attr_len, out_ch=args.image_channels)

    if args.load_gen_model != '':
        serializers.load_npz(args.load_gen_model, gen)
        print("Generator model loaded")
    np.random.seed(0)
    x = chainer.Variable(np.empty((1, args.latent_len + args.attr_len), dtype=np.float32))
    y = gen(x, test=True)
    graph = ChainerConverter().convert_from_inout_vars([x], [y])
    exec_info = generate_descriptor("webassembly", graph)
    exec_info.save("./output_model")
Example #13
print(f"model: {args.model}")
print(f"backend: {args.backend}")
print(f"encoding: {args.encoding}")

# Load chainer pre-trained model
model = FastStyleNet()

model_path = NSTModelPath[args.model].value
if not path.exists(model_path):
    raise FileNotFoundError(
        f"Model data ({model_path}) is not found. Please clone "
        "'https://github.com/gafr/chainer-fast-neuralstyle-models' under the resource directory. "
        "Cloning takes a few minutes; the repository is about 200 MB."
    )

chainer.serializers.load_npz(model_path, model)

# Execute forward propagation to construct computation graph
x = chainer.Variable(np.zeros((1, 3, 144, 192), dtype=np.float32))
y = model(x, test=True)

# Convert chainer computation graph into IR
graph = ChainerGraphConverter().convert_from_inout_vars([x], [y])

# Generate graph descriptor
generate_descriptor(args.backend, graph,
                    constant_encoder_name=args.encoding).save(
                        path.join(path.dirname(__file__), "./output"))
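FastStyleNet is fully convolutional, so other input resolutions only require a different dummy Variable when tracing the graph; this variation is my own and keeps the rest of the pipeline unchanged:

x = chainer.Variable(np.zeros((1, 3, 192, 256), dtype=np.float32))
y = model(x, test=True)
graph = ChainerGraphConverter().convert_from_inout_vars([x], [y])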
Example #14
def main():
    parser = argparse.ArgumentParser(description='Chainer example: MNIST')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=5,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--frequency', '-f', type=int, default=-1,
                        help='Frequency of taking a snapshot')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='output',
                        help='Directory to output the graph descriptor and sample test data')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=100,
                        help='Number of units')
    args = parser.parse_args()

    print('GPU: {}'.format(args.gpu))
    print('# unit: {}'.format(args.unit))
    print('# Minibatch-size: {}'.format(args.batchsize))
    print('# epoch: {}'.format(args.epoch))
    print('')

    os.makedirs(args.out, exist_ok=True)

    # Set up a neural network to train
    # Classifier reports softmax cross entropy loss and accuracy at every
    # iteration, which will be used by the PrintReport extension below.
    model = L.Classifier(MLP(args.unit, 10))
    if args.gpu >= 0:
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(args.gpu).use()
        model.to_gpu()  # Copy the model to the GPU

    # Setup an optimizer
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(model)

    # Load the MNIST dataset
    train, test = chainer.datasets.get_mnist()

    train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
    test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                                 repeat=False, shuffle=False)

    # Set up a trainer
    updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=os.path.join(args.out, 'chainer_model'))

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

    # Dump a computational graph from 'loss' variable at the first iteration
    # The "main" refers to the target link of the "main" optimizer.
    trainer.extend(extensions.dump_graph('main/loss'))

    # Take a snapshot for each specified epoch
    frequency = args.epoch if args.frequency == -1 else max(1, args.frequency)
    trainer.extend(extensions.snapshot(), trigger=(frequency, 'epoch'))

    # Write a log of evaluation statistics for each epoch
    trainer.extend(extensions.LogReport())

    # Save two plot images to the result dir
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport(['main/loss', 'validation/main/loss'],
                                  'epoch', file_name='loss.png'))
        trainer.extend(
            extensions.PlotReport(
                ['main/accuracy', 'validation/main/accuracy'],
                'epoch', file_name='accuracy.png'))

    # Print selected entries of the log to stdout
    # Here "main" refers to the target link of the "main" optimizer again, and
    # "validation" refers to the default name of the Evaluator extension.
    # Entries other than 'epoch' are reported by the Classifier link, called by
    # either the updater or the evaluator.
    trainer.extend(extensions.PrintReport(
        ['epoch', 'main/loss', 'validation/main/loss',
         'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

    # Print a progress bar to stdout
    trainer.extend(extensions.ProgressBar())

    if args.resume:
        # Resume from a snapshot
        chainer.serializers.load_npz(args.resume, trainer)

    # Run the training
    trainer.run()

    # conversion

    print('Transpiling model to WebDNN graph descriptor')

    example_input = numpy.expand_dims(train[0][0], axis=0)  # example input (any values work; shape (batch_size, 784))
    x = chainer.Variable(example_input)
    y = model.predictor(x)  # run model (without softmax)
    graph = ChainerGraphConverter().convert_from_inout_vars([x], [y])  # convert graph to intermediate representation
    for backend in ["webgpu", "webassembly", "fallback"]:
        try:
            exec_info = generate_descriptor(backend, graph)
            exec_info.save(args.out)
        except Exception as ex:
            print(f"Failed generating descriptor for backend {backend}: {str(ex)}\n")
        else:
            print(f"Backend {backend} ok\n")

    print('Exporting test samples (for demo purpose)')
    test_samples_json = []
    for i in range(10):
        image, label = test[i]
        test_samples_json.append({'x': image.tolist(), 'y': int(label)})
    with open(os.path.join(args.out, 'test_samples.json'), 'w') as f:
        json.dump(test_samples_json, f)
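As a last sanity check before loading the descriptor in a browser, the exported samples can be run back through the trained predictor. This block is my own addition, assumes the model is on the CPU (move the data with chainer.cuda first if args.gpu >= 0), and would sit at the end of main():

    sample = test_samples_json[0]
    x = chainer.Variable(numpy.asarray([sample['x']], dtype=numpy.float32))
    with chainer.using_config('train', False):
        predicted = int(numpy.argmax(model.predictor(x).data))
    print(f"sample 0: predicted {predicted}, label {sample['y']}")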