Example #1
def init_worker():

    import nnabla as nn
    import nnabla.functions as F
    import nnabla.parametric_functions as PF
    #import nnabla.solvers as S
    from nnabla.utils.nnp_graph import NnpLoader

    # "MyChain.nnp" and the module-level `batch_size` are expected to be
    # provided by the surrounding script.
    nnp = NnpLoader("MyChain.nnp")
    net = nnp.get_network("MyChain", batch_size=batch_size)

    # Expose the network's input/output variables to this worker process.
    global xx
    global yy

    xx = net.inputs['xx']
    yy = net.outputs['yy']
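A minimal usage sketch (everything below is an assumption, not part of the original snippet): `init_worker` reads the module-level `batch_size` and is meant to run once per process, e.g. as a `multiprocessing.Pool` initializer. The input shape is a placeholder; the real shape depends on what "MyChain.nnp" actually contains.

import numpy as np
from multiprocessing import Pool

batch_size = 1  # read as a global by init_worker

def predict(batch):
    # xx / yy are the process-local globals created by init_worker.
    xx.d = batch
    yy.forward(clear_buffer=True)
    return yy.d.copy()

if __name__ == "__main__":
    # Placeholder input shape; replace with the shape expected by MyChain.nnp.
    batches = [np.zeros((batch_size, 1, 28, 28), dtype=np.float32)
               for _ in range(4)]
    with Pool(processes=2, initializer=init_worker) as pool:
        results = pool.map(predict, batches)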
Example #2
def load_nnp_model(path, batch_size, output_num):
    from nnabla.utils.nnp_graph import NnpLoader

    nnp = NnpLoader(path)
    network_names = nnp.get_network_names()
    assert (len(network_names) > 0)
    graph = nnp.get_network(network_names[0], batch_size=batch_size)
    # First (and assumed only) input variable of the network.
    input_name = list(graph.inputs.keys())[0]
    x = graph.inputs[input_name]

    # Collect the first `output_num` output variables.
    output_names = list(graph.outputs.keys())
    output_list = [graph.outputs[name] for name in output_names[:output_num]]

    return (x, output_list)
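A hypothetical call of the helper above, assuming a file "model.nnp" whose first network has a single input; the file name, batch size, and dummy batch are placeholders:

import numpy as np

x, outputs = load_nnp_model("model.nnp", batch_size=16, output_num=1)
x.d = np.zeros(x.shape, dtype=np.float32)   # replace with a real input batch
outputs[0].forward(clear_buffer=True)
print(outputs[0].d.shape)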
Example #3
def evaluate(args):
    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn

    # Model (Inception model) from nnp file
    nnp = NnpLoader(args.nnp_inception_model_load_path)
    x, y = get_input_and_output(nnp, args.batch_size, name=args.variable_name)

    if args.evaluation_metric == "IS":
        is_model = None
        compute_metric = compute_inception_score
    if args.evaluation_metric == "FID":
        di = data_iterator_imagenet(args.valid_dir,
                                    args.dirname_to_label_path,
                                    batch_size=args.batch_size,
                                    ih=args.image_size,
                                    iw=args.image_size,
                                    shuffle=True,
                                    train=False,
                                    noise=False)
        compute_metric = functools.partial(compute_frechet_inception_distance,
                                           di=di)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_metric = MonitorSeries("{}".format(args.evaluation_metric),
                                   monitor,
                                   interval=1)

    # Compute the evaluation metric for all models
    def cmp_func(path):
        # e.g. ".../params_100000.h5" -> 100000 (str.strip removes a character
        # set, which is sufficient for this particular file-name pattern).
        return int(path.split("/")[-1].strip("params_").rstrip(".h5"))
    model_load_path = sorted(glob.glob("{}/*.h5".format(args.model_load_path)), key=cmp_func) \
        if os.path.isdir(args.model_load_path) else \
        [args.model_load_path]

    for path in model_load_path:
        # Model (SAGAN)
        nn.load_parameters(path)
        z = nn.Variable([batch_size, latent])
        y_fake = nn.Variable([batch_size])
        x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes, test=True, sn=not_sn)\
            .apply(persistent=True)
        # Compute the evaluation metric
        score = compute_metric(z, y_fake, x_fake, x, y, args)
        itr = cmp_func(path)
        monitor_metric.add(itr, score)
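The helper get_input_and_output used above is not shown in this excerpt. A minimal sketch of what it presumably does, assuming the Inception .nnp file contains a single network and `name` selects one of its variables (the exact lookup in the original may differ):

def get_input_and_output(nnp, batch_size, name=None):
    # Build the first network in the nnp file and return its single input
    # plus either the variable called `name` or the first output.
    net = nnp.get_network(nnp.get_network_names()[0], batch_size=batch_size)
    x = list(net.inputs.values())[0]
    if name is not None and name in net.variables:
        h = net.variables[name]
    else:
        h = list(net.outputs.values())[0]
    return x, h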
Example #4
    def _load_nnp(self, rel_name, rel_url):
        '''
        Args:
            rel_name: relative path where the downloaded nnp is saved.
            rel_url: relative url path from which the nnp is downloaded.
        '''
        from nnabla.utils.download import download
        path_nnp = os.path.join(get_model_home(),
                                'imagenet/{}'.format(rel_name))
        url = os.path.join(get_model_url_base(), 'imagenet/{}'.format(rel_url))
        logger.info('Downloading {} from {}'.format(rel_name, url))
        dir_nnp = os.path.dirname(path_nnp)
        if not os.path.isdir(dir_nnp):
            os.makedirs(dir_nnp)
        download(url, path_nnp, open_file=False, allow_overwrite=False)
        print('Loading {}.'.format(path_nnp))
        self.nnp = NnpLoader(path_nnp)
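A standalone sketch of the same download-and-cache pattern, with an explicit cache directory in place of get_model_home(); the helper name and directory below are assumptions, only the download/NnpLoader calls mirror the method above:

import os
from nnabla.utils.download import download
from nnabla.utils.nnp_graph import NnpLoader

def fetch_nnp(url, cache_dir="~/nnabla_models"):
    # Hypothetical helper: download `url` into `cache_dir` once and load it.
    cache_dir = os.path.expanduser(cache_dir)
    os.makedirs(cache_dir, exist_ok=True)
    path_nnp = os.path.join(cache_dir, os.path.basename(url))
    # allow_overwrite=False keeps an already-downloaded file, as above.
    download(url, path_nnp, open_file=False, allow_overwrite=False)
    return NnpLoader(path_nnp)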
Example #5
def match(args):
    # Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = 1
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn
    threshold = args.truncation_threshold

    # Model (SAGAN)
    nn.load_parameters(args.model_load_path)
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes, test=True, sn=not_sn)\
        .apply(persistent=True)

    # Model (Inception model) from nnp file
    nnp = NnpLoader(args.nnp_inception_model_load_path)
    x, h = get_input_and_output(nnp, batch_size, args.variable_name)

    # DataIterator for a given class_id
    di = data_iterator_imagenet(args.train_dir, args.dirname_to_label_path,
                                batch_size=batch_size, n_classes=args.n_classes,
                                noise=False,
                                class_id=args.class_id)

    # Monitor
    monitor = Monitor(args.monitor_path)
    name = "Matched Image {}".format(args.class_id)
    monitor_image = MonitorImage(name, monitor, interval=1,
                                 num_images=batch_size,
                                 normalize_method=lambda x: (x + 1.) / 2. * 255.)
    name = "Matched Image Tile {}".format(args.class_id)
    monitor_image_tile = MonitorImageTile(name, monitor, interval=1,
                                          num_images=batch_size + args.top_n,
                                          normalize_method=lambda x: (x + 1.) / 2. * 255.)

    # Generate and p(h|x).forward
    # generate
    z_data = resample(batch_size, latent, threshold)
    y_data = generate_one_class(args.class_id, batch_size)
    z.d = z_data
    y_fake.d = y_data
    x_fake.forward(clear_buffer=True)
    # p(h|x).forward
    x_fake_d = x_fake.d.copy()
    x_fake_d = preprocess(
        x_fake_d, (args.image_size, args.image_size), args.nnp_preprocess)
    x.d = x_fake_d
    h.forward(clear_buffer=True)
    h_fake_d = h.d.copy()

    # Feature matching
    norm2_list = []
    x_data_list = []
    x_data_list.append(x_fake.d)
    for i in range(di.size):
        # forward for real data
        x_d, _ = di.next()
        x_data_list.append(x_d)
        x_d = preprocess(
            x_d, (args.image_size, args.image_size), args.nnp_preprocess)
        x.d = x_d
        h.forward(clear_buffer=True)
        h_real_d = h.d.copy()
        # norm computation
        axis = tuple(np.arange(1, len(h.shape)).tolist())
        norm2 = np.sum((h_real_d - h_fake_d) ** 2.0, axis=axis)
        norm2_list.append(norm2)

    # Save the top-n real images closest to the generated one in feature space
    # (x_data_list[0] is the generated image, so real images are offset by one).
    argmins = np.argsort(np.concatenate(norm2_list))
    for i in range(args.top_n):
        monitor_image.add(i, x_data_list[argmins[i] + 1])
    matched_images = np.concatenate(
        [x_data_list[0]] + [x_data_list[idx + 1] for idx in argmins[:args.top_n]])
    monitor_image_tile.add(0, matched_images)
def infer():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Initialize DataIterator for MNIST.
    * Construct a computation graph for inference.
    * Load parameter variables to infer.
    * Create monitor instances for saving and displaying inference stats.
    """
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == 'lenet':
        mnist_cnn_prediction = mnist_lenet_prediction
    elif args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction
    else:
        raise ValueError("Unknown network type {}".format(args.net))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = mnist_cnn_prediction(vimage, test=True, aug=args.augment_test)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(1, False)

    from nnabla.utils.nnp_graph import NnpLoader

    # Read a .nnp file.
    nnp = NnpLoader(args.pretrained)
    # Build the first network found in the nnp file with batch size 1.
    net = nnp.get_network(nnp.get_network_names()[0], batch_size=1)
    # `x` is an input of the graph.
    x = net.inputs['x']
    # `y` is an output of the graph.
    y = net.outputs['y']
    ve = 0.0

    for j in range(10000):
        x.d, vlabel.d = vdata.next()
        y.forward(clear_buffer=True)
        ve += categorical_error(y.d, vlabel.d)
    #monitor_verr.add(1, ve / args.val_iter)

    print("acc=", 1 - ve / 10000, ".")
    # append F.Softmax to the prediction graph so users see intuitive outputs
    runtime_contents = {
        'networks': [{
            'name': 'Validation',
            'batch_size': args.batch_size,
            'outputs': {
                'y': F.softmax(vpred)
            },
            'names': {
                'x': vimage
            }
        }],
        'executors': [{
            'name': 'Runtime',
            'network': 'Validation',
            'data': ['x'],
            'output': ['y']
        }]
    }
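
    # (Sketch, not in the original excerpt.) A contents dictionary like the one
    # above is typically written to a deployable .nnp file with
    # nnabla.utils.save.save; the output file name here is a placeholder.
    from nnabla.utils.save import save
    save('inference.nnp', runtime_contents)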