Beispiel #1
0
def get_model(args, test=False):
    """
    Create computation graph and variables.

    Builds input variables of spatial size
    ``args.image_height`` x ``args.image_width``, runs the Deeplab v3+
    network, resizes the logits to the label resolution when their spatial
    size differs, and computes the mask-weighted softmax cross-entropy loss.

    Args:
        args: options object; reads ``batch_size``, ``image_height``,
            ``image_width``, ``output_stride`` and ``num_class``.
        test: forwarded to ``model.deeplabv3plus_model``.

    Returns:
        Model namedtuple with fields ``image``, ``label``, ``mask``,
        ``pred`` and ``loss``.
    """
    height, width = args.image_height, args.image_width

    image = nn.Variable([args.batch_size, 3, height, width])
    label = nn.Variable([args.batch_size, 1, height, width])
    mask = nn.Variable([args.batch_size, 1, height, width])

    pred = model.deeplabv3plus_model(image,
                                     args.output_stride,
                                     args.num_class,
                                     test=test,
                                     fix_params=False)

    # Compare only the spatial (H, W) dims: pred carries num_class channels
    # while label has a single channel, so comparing full shapes was always
    # unequal (for num_class > 1) and inserted a redundant same-size resize.
    if pred.shape[2:] != label.shape[2:]:
        pred = F.interpolate(pred,
                             output_size=(label.shape[2], label.shape[3]),
                             mode='linear')

    # Per-pixel cross entropy over the class axis, weighted by mask and
    # normalized by the mask total.
    # NOTE(review): an all-zero mask makes the denominator 0 — confirm the
    # data pipeline never produces one.
    loss = F.sum(
        F.softmax_cross_entropy(pred, label, axis=1) * mask) / F.sum(mask)
    Model = namedtuple('Model', ['image', 'label', 'mask', 'pred', 'loss'])
    return Model(image, label, mask, pred, loss)
Beispiel #2
0
def get_model(args, test=False):
    """
    Build the Deeplab v3+ graph at a fixed 513x513 input resolution.

    Args:
        args: options object; reads ``batch_size``, ``output_stride`` and
            ``num_class``.
        test: forwarded to ``model.deeplabv3plus_model``.

    Returns:
        Model namedtuple with fields ``image``, ``label``, ``mask``,
        ``pred`` and ``loss``.
    """
    side = 513  # input height and width, hard-coded for this variant
    batch = args.batch_size

    image = nn.Variable([batch, 3, side, side])
    label = nn.Variable([batch, 1, side, side])
    mask = nn.Variable([batch, 1, side, side])

    pred = model.deeplabv3plus_model(
        image, args.output_stride, args.num_class,
        test=test, fix_params=False)

    # Initialize moving variance by 1: overwrite every registered parameter
    # whose name contains 'bn/var'.
    # NOTE(review): nn.get_parameters() defaults to grad_only=True; if the
    # batch-norm running variances are created with need_grad=False they
    # are not returned and this loop is a no-op — confirm intent.
    for name, param in nn.get_parameters().items():
        if 'bn/var' in name:
            param.d.fill(1)

    # Cross entropy per pixel over the class axis, weighted by mask and
    # normalized by the mask total.
    pixel_loss = F.softmax_cross_entropy(pred, label, axis=1)
    loss = F.sum(pixel_loss * mask) / F.sum(mask)

    Model = namedtuple('Model', ['image', 'label', 'mask', 'pred', 'loss'])
    return Model(image=image, label=label, mask=mask, pred=pred, loss=loss)
Beispiel #3
0
def main():
    """
    Run Deeplab v3+ inference on a single image and visualize the result.

    Steps: set up the compute context, read the class-name file, load the
    trained parameters, build the inference graph, preprocess
    ``args.test_image_file``, time one forward pass, map the per-pixel
    argmax back to the original image size via ``post_process``, print the
    names of the classes present and call ``visualize`` on the result.
    """
    args = get_args()

    # Get context
    from nnabla.ext_utils import get_extension_context, import_extension_module
    logger.info("Running in %s", args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)
    ext = import_extension_module(args.context)

    # Read class names, one per line. `with` closes the file (it was
    # previously leaked) and rstrip drops the newline so each printed name
    # is not followed by an empty line.
    with open(args.label_file_path, "r") as f:
        labels = [line.rstrip("\n") for line in f]

    # Load parameters
    nn.load_parameters(args.model_load_path)

    # Build a Deeplab v3+ network
    x = nn.Variable((1, 3, args.image_height, args.image_width),
                    need_grad=False)
    y = net.deeplabv3plus_model(x,
                                args.output_stride,
                                args.num_class,
                                test=True)

    # preprocess image
    image = imageio.imread(args.test_image_file, as_gray=False, pilmode="RGB")
    old_size = image.shape[:2]  # (height, width) of the original image

    input_array = image_preprocess.preprocess_image_and_label(
        image,
        label=None,
        target_width=args.image_width,
        target_height=args.image_height,
        train=False)
    print('Input', input_array.shape)
    # HWC -> CHW, then add a leading batch axis of size 1 to match x.
    input_array = np.transpose(input_array, (2, 0, 1))[np.newaxis]

    # Compute inference and inference time
    t = time.time()

    x.d = input_array
    y.forward(clear_buffer=True)
    print("done")
    # Synchronize before stopping the clock so work still queued on the
    # device (presumably asynchronous on GPU contexts) is included.
    available_devices = ext.get_devices()
    ext.device_synchronize(available_devices[0])
    ext.clear_memory_cache()

    elapsed = time.time() - t
    print('Inference time : %s seconds' % (elapsed))

    output = np.argmax(y.d, axis=1)  # (batch,h,w)

    # Apply post processing
    post_processed = post_process(output[0], old_size,
                                  (args.image_height, args.image_width))

    # Get the classes predicted
    for class_id in np.unique(post_processed):
        print('Classes Segmented: ', labels[class_id])

    # Visualize inference result
    visualize(post_processed)