Exemplo n.º 1
0
def _get_module(args, margs, dargs, net=None):
    """Wrap a network symbol in an ``mx.mod.Module`` spread over GPUs.

    When *net* is not supplied, a symbol is built from the module-level
    ``model_specs`` mapping.  For ``net_type == 'rna'`` the shared symbol
    config (``util.symbol.symbol.cfg``) is filled in first.  Then
    ``net_name == 'a1'`` builds ``fcrna_model_a1`` without bootstrapping,
    and ``net_name == 'd'`` builds a symbol from ``util.symbol.resnet``.

    Args:
        args: namespace whose ``gpus`` attribute is a comma-separated
            string of GPU ids.
        margs: model arguments; ``classes`` and ``feat_stride`` are read.
        dargs: data/runtime arguments; ``mx_workspace`` is read.
        net: optional prebuilt symbol; construction is skipped if given.

    Returns:
        An ``mx.mod.Module`` over *net* with one ``mx.gpu`` context per id.

    Raises:
        NotImplementedError: if *net* is None and no symbol was built.
    """
    if net is None:
        if model_specs['net_type'] == 'rna':
            from util.symbol.symbol import cfg as symcfg
            # Shared symbol settings, written before the builder below is
            # imported and called.
            symcfg['lr_type'] = 'alex'
            symcfg['workspace'] = dargs.mx_workspace
            symcfg['bn_use_global_stats'] = True
            chosen = model_specs['net_name']
            if chosen == 'a1':
                from util.symbol.resnet_v2 import fcrna_model_a1
                net = fcrna_model_a1(margs.classes, margs.feat_stride,
                                     bootstrapping=False)
            elif chosen == 'd':
                import importlib
                resnet_sym = importlib.import_module('util.symbol.resnet')
                net = resnet_sym.get_symbol(19, 101, '3,512,512',
                                            conv_workspace=1650)
        if net is None:
            raise NotImplementedError(
                'Unknown network: {}'.format(vars(margs)))
    gpu_ids = args.gpus.split(',')
    devices = [mx.gpu(int(gpu_id)) for gpu_id in gpu_ids]
    return mx.mod.Module(net, context=devices)
Exemplo n.º 2
0
def _get_module(args, margs, dargs, net=None):
    """Return an ``mx.mod.Module`` for *net*, building the symbol on demand.

    If *net* is None and the module-level ``model_specs`` selects
    ``net_type == 'rna'``, the shared symbol config is populated.  Then,
    for ``net_name == 'a1'``, ``fcrna_model_a1`` is built with
    bootstrapping enabled.

    Args:
        args: namespace with a comma-separated ``gpus`` string.
        margs: model arguments (``classes``, ``feat_stride`` are read).
        dargs: data/runtime arguments (``mx_workspace`` is read).
        net: optional prebuilt symbol.

    Raises:
        NotImplementedError: when no symbol was given or could be built.
    """
    if net is None and model_specs['net_type'] == 'rna':
        from util.symbol.symbol import cfg as symcfg
        symcfg['lr_type'] = 'alex'
        symcfg['workspace'] = dargs.mx_workspace
        symcfg['bn_use_global_stats'] = True
        if model_specs['net_name'] == 'a1':
            from util.symbol.resnet_v2 import fcrna_model_a1
            net = fcrna_model_a1(margs.classes, margs.feat_stride,
                                 bootstrapping=True)
    # Still None here means either nothing was passed and the specs did not
    # match a known builder.
    if net is None:
        raise NotImplementedError('Unknown network: {}'.format(vars(margs)))
    devices = list(map(mx.gpu, map(int, args.gpus.split(','))))
    module = mx.mod.Module(net, context=devices)
    return module
Exemplo n.º 3
0
def _get_module(margs, dargs, net=None):
    """Build an ``mx.mod.Module`` over *net* on the configured GPUs.

    When *net* is None and the module-level ``model_specs`` has
    ``net_type == 'rna'`` and ``net_name == 'a1'``, the shared symbol
    config gets ``use_global_stats``/``workspace`` set and
    ``fcrna_model_a1(margs.classes, margs.feat_stride)`` is built.

    NOTE(review): unlike sibling variants, ``args`` is not a parameter;
    ``args.gpus`` is resolved from module scope — confirm a global
    ``args`` is always defined before this is called.

    Raises:
        NotImplementedError: when no symbol was given or could be built.
    """
    if net is None and model_specs['net_type'] == 'rna':
        from util.symbol.symbol import cfg as symcfg
        if model_specs['net_name'] == 'a1':
            symcfg['use_global_stats'] = True
            symcfg['workspace'] = dargs.mx_workspace
            from util.symbol.resnet_v2 import fcrna_model_a1
            net = fcrna_model_a1(margs.classes, margs.feat_stride)
    if net is None:
        raise NotImplementedError('Unknown network: {}'.format(
            vars(margs)))
    gpu_contexts = []
    for gpu_id in args.gpus.split(','):
        gpu_contexts.append(mx.gpu(int(gpu_id)))
    return mx.mod.Module(net, context=gpu_contexts)
Exemplo n.º 4
0
def _read_image_list(file_list, data_root, start_idx, end_idx):
    """Return data_root-joined paths for lines start_idx..end_idx of file_list.

    Line indices are 0-based and both bounds are inclusive.  Each kept line
    is stripped of surrounding whitespace before joining.
    """
    image_list = []
    with open(file_list) as f:
        for idx, item in enumerate(f):
            if idx < start_idx:
                continue
            if idx > end_idx:
                break
            image_list.append(os.path.join(data_root, item.strip()))
    return image_list


def do(args, model_specs, logger):
    """Segment a slice of images with an FCRNA-a1 model and save label PNGs.

    Images listed in ``args.file_list`` (lines ``args.start``..``args.end``,
    inclusive, relative to ``args.data_root``) are resized to 1024x2048 if
    needed, run through the network once (plus a horizontally flipped pass
    averaged in when ``args.test_flipping`` is set), and the per-pixel argmax
    is written as ``<args.output>/<image stem>.png``.  Images whose output
    already exists are skipped.

    Args:
        args: namespace; reads ``start``, ``end``, ``file_list``,
            ``data_root``, ``weights``, ``gpus``, ``output`` and
            ``test_flipping``.
        model_specs: mapping; reads ``classes`` and ``feat_stride``
            (``feat_stride`` presumably an int stride — confirm).
        logger: logger used for the final timing message.

    Raises:
        NotImplementedError: if the network symbol could not be built.
    """
    meta = get_dataset_specs(args, model_specs)
    id_2_label = meta['id_2_label']
    cmap = meta['cmap']
    input_h = 1024
    input_w = 2048
    classes = model_specs['classes']
    label_stride = model_specs['feat_stride']

    start_idx = args.start
    end_idx = args.end
    print('{}-{}\n'.format(start_idx, end_idx))

    # Paths are already joined with data_root here; they must not be joined
    # again below (that breaks relative data roots).
    image_list = _read_image_list(args.file_list, args.data_root,
                                  start_idx, end_idx)

    net_args, net_auxs = mxutil.load_params_from_file(args.weights)
    net = fcrna_model_a1(classes, label_stride, bootstrapping=True)
    if net is None:
        raise NotImplementedError('Unknown network')
    contexts = [mx.gpu(int(_)) for _ in args.gpus.split(',')]
    mod = mx.mod.Module(net, context=contexts)

    crop_size = 2048
    save_dir = args.output
    # Make sure the first save does not fail on a missing directory.
    if save_dir and not osp.isdir(save_dir):
        os.makedirs(save_dir)

    x_num = len(image_list)

    transformers = [ts.Scale(crop_size, Image.CUBIC, False)]
    transformers += _get_transformer_image()
    transformer = ts.Compose(transformers)

    # Output geometry depends only on the fixed input size, so compute it
    # once.  Floor division keeps these integral under Python 3 as well
    # (they are used as array shapes), matching the Python 2 behaviour.
    label_h, label_w = input_h // label_stride, input_w // label_stride
    test_steps = 1
    pred_stride = label_stride // test_steps
    pred_h, pred_w = label_h * test_steps, label_w * test_steps

    # The network input shape never changes, so bind once instead of
    # force-rebinding for every image.
    dataiter = mx.io.NDArrayIter(
        np.zeros((1, 3, input_h, input_w), np.single),
        255 * np.ones((1, label_h * label_w), np.single))
    batch = dataiter.next()
    mod.bind(dataiter.provide_data,
             dataiter.provide_label,
             for_training=False,
             force_rebind=True)
    if not mod.params_initialized:
        mod.init_params(arg_params=net_args, aux_params=net_auxs)

    start = time.time()

    for i, im_path in enumerate(image_list):
        time1 = time.time()

        sample_name = osp.splitext(osp.basename(im_path))[0]
        out_path = osp.join(save_dir, '{}.png'.format(sample_name))
        if os.path.exists(out_path):
            continue

        # Close the underlying file once the RGB copy has been taken.
        with Image.open(im_path) as pil_im:
            rim = np.array(pil_im.convert('RGB'), np.uint8)

        h, w = rim.shape[:2]
        need_resize = h != input_h or w != input_w
        if need_resize:
            im = np.array(
                Image.fromarray(rim.astype(np.uint8, copy=False)).resize(
                    (input_w, input_h), Image.NEAREST))
        else:
            im = rim
        im = transformer(im)
        imh, imw = im.shape[:2]

        # Zero-pad the transformed HWC image into the fixed NCHW input.
        input_data = np.zeros((1, 3, input_h, input_w), np.single)
        input_data[0, :, :imh, :imw] = im.transpose(2, 0, 1)
        batch.data[0] = mx.nd.array(input_data)
        mod.forward(batch, is_train=False)
        this_call_preds = mod.get_outputs()[0].asnumpy()[0]
        if args.test_flipping:
            # Average with the prediction on the horizontally mirrored input,
            # mirrored back into place.
            batch.data[0] = mx.nd.array(input_data[:, :, :, ::-1])
            mod.forward(batch, is_train=False)
            this_call_preds = 0.5 * (
                this_call_preds +
                mod.get_outputs()[0].asnumpy()[0][:, :, ::-1])
        net_preds = np.zeros((classes, pred_h, pred_w), np.single)
        net_preds[:, 0:pred_h:test_steps,
                  0:pred_w:test_steps] = this_call_preds

        # compute pixel-wise predictions
        interp_preds = interp_preds_as(rim.shape[:2], net_preds, pred_stride,
                                       imh, imw)
        pred_label = interp_preds.argmax(0)
        if id_2_label is not None:
            pred_label = id_2_label[pred_label]

        # save predicted labels into an image
        im_to_save = Image.fromarray(pred_label.astype(np.uint8))
        if cmap is not None:
            im_to_save.putpalette(cmap.ravel())

        if need_resize:
            im_to_save = im_to_save.resize((w, h), Image.NEAREST)

        im_to_save.save(out_path)

        time2 = time.time()
        print("{}/{} {} finish in {} s\n".format(i, x_num, out_path,
                                                 time2 - time1))

    logger.info('Done in %.2f s.', time.time() - start)