def main():
    """Run the model on 30 random test images and dump data for mAP.

    Reads ``config.ini`` (UTF-8) and builds the test split of the dataset
    named by ``[dataset] type`` (``mpii`` or ``coco``). It then creates the
    model with ``create_model(config)`` and, for 30 randomly chosen test
    images (drawn with replacement):

    * runs ``estimate`` on the image,
    * saves the drawing to ``results/result{i}.png``,
    * records ground-truth keypoints, predicted humans and ground-truth
      bboxes.

    The collected ``[gt_keypoints, predictions, gt_bboxes]`` lists are
    finally passed to ``gene_json``.

    Raises:
        Exception: if ``[dataset] type`` is neither ``mpii`` nor ``coco``.
    """
    import os

    config = configparser.ConfigParser()
    config.read('config.ini', 'UTF-8')
    dataset_type = config.get('dataset', 'type')
    logger.info('loading {}'.format(dataset_type))
    if dataset_type == 'mpii':
        _, test_set = get_mpii_dataset(
            insize=parse_size(config.get('model_param', 'insize')),
            image_root=config.get(dataset_type, 'images'),
            annotations=config.get(dataset_type, 'annotations'),
            train_size=config.getfloat(dataset_type, 'train_size'),
            min_num_keypoints=config.getint(dataset_type, 'min_num_keypoints'),
            seed=config.getint('training_param', 'seed'),
        )
    elif dataset_type == 'coco':
        # The original images have already been replaced by fixed-size ones.
        test_set = get_coco_dataset(
            insize=parse_size(config.get('model_param', 'insize')),
            image_root=config.get(dataset_type, 'val_images'),
            annotations=config.get(dataset_type, 'val_annotations'),
            min_num_keypoints=config.getint(dataset_type, 'min_num_keypoints'),
        )
    else:
        raise Exception('Unknown dataset {}'.format(dataset_type))

    model = create_model(config)

    # Per-image ground-truth keypoints, predicted humans and ground-truth
    # bboxes, later turned into the mAP input by gene_json.
    mAP = [[], [], []]

    # PIL cannot create missing directories; make sure the output dir exists.
    os.makedirs('results', exist_ok=True)

    # Evaluate several randomly chosen test images.
    for i in range(30):
        idx = random.randrange(len(test_set))
        # Fetch the example once. Every get_example call re-reads the sample,
        # so calling it per field was wasteful. It could also mix fields from
        # different reads if loading is not deterministic.
        example = test_set.get_example(idx)
        image = example['image']
        gt_kps = example['keypoints']
        # Per the original note: person box for coco, head box for mpii.
        gt_bboxs = example['bbox']
        humans = estimate(model, image.astype(np.float32))
        mAP[0].append(gt_kps)
        mAP[1].append(humans)
        mAP[2].append(gt_bboxs)

        # Images are channel-first; PIL expects HWC uint8.
        pil_image = Image.fromarray(image.transpose(1, 2, 0).astype(np.uint8))
        pil_image = draw_humans(
            keypoint_names=model.keypoint_names,
            edges=model.edges,
            pil_image=pil_image,
            humans=humans
        )

        pil_image.save('results/result{}.png'.format(i), 'PNG')

    gene_json(mAP)
Example #2
0
def create_model(args, config):
    """Build a PoseProposalNet for the configured dataset and load its weights.

    Side effect: sets the module-level globals ``DIRECTED_GRAPHS`` and
    ``COLOR_MAP`` from the selected dataset module (``mpii_dataset`` or
    ``coco_dataset``).

    Args:
        args: either an object with a ``model`` attribute naming the result
            directory (e.g. an argparse namespace), or the result directory
            itself.
        config: ConfigParser with ``dataset``, ``model_param`` and
            per-dataset sections.

    Returns:
        The model with ``<result_dir>/bestmodel.npz`` loaded. It is moved to
        the GPU when CUDA is available, otherwise to iDeep when available.

    Raises:
        Exception: if ``[dataset] type`` is neither ``mpii`` nor ``coco``.
    """
    global DIRECTED_GRAPHS, COLOR_MAP

    dataset_type = config.get('dataset', 'type')

    if dataset_type == 'mpii':
        import mpii_dataset as x_dataset
    elif dataset_type == 'coco':
        import coco_dataset as x_dataset
    else:
        raise Exception('Unknown dataset {}'.format(dataset_type))

    keypoint_names = x_dataset.KEYPOINT_NAMES
    edges = x_dataset.EDGES
    DIRECTED_GRAPHS = x_dataset.DIRECTED_GRAPHS
    COLOR_MAP = x_dataset.COLOR_MAP

    model = PoseProposalNet(
        model_name=config.get('model_param', 'model_name'),
        insize=parse_size(config.get('model_param', 'insize')),
        keypoint_names=keypoint_names,
        edges=np.array(edges),
        local_grid_size=parse_size(config.get('model_param',
                                              'local_grid_size')),
        parts_scale=parse_size(config.get(dataset_type, 'parts_scale')),
        instance_scale=parse_size(config.get(dataset_type, 'instance_scale')),
        width_multiplier=config.getfloat('model_param', 'width_multiplier'),
    )

    logger.info('input size = {}'.format(model.insize))
    logger.info('output size = {}'.format(model.outsize))

    # ``args`` may carry the result directory in ``.model`` or be the
    # directory itself. Only a missing attribute should select the fallback;
    # the previous bare ``except`` also hid KeyboardInterrupt and real errors.
    result_dir = getattr(args, 'model', args)

    chainer.serializers.load_npz(os.path.join(result_dir, 'bestmodel.npz'),
                                 model)

    logger.info('cuda enable {}'.format(chainer.backends.cuda.available))
    logger.info('ideep enable {}'.format(
        chainer.backends.intel64.is_ideep_available()))
    if chainer.backends.cuda.available:
        logger.info('gpu mode')
        model.to_gpu()
    elif chainer.backends.intel64.is_ideep_available():
        logger.info('Intel64 mode')
        model.to_intel64()
    return model
Example #3
0
def predict(args):
    """Estimate poses on a fixed test image and save the annotated result.

    Loads the configuration with ``load_config(args)`` and builds the test
    split of the dataset named by ``[dataset] type`` (``mpii`` or ``coco``).
    It loads the model with ``create_model(args, config)`` and runs
    ``estimate`` on test example 50, using ``[predict] detection_thresh``
    and ``[predict] min_num_keypoints``.

    The drawing is written to the current directory as
    ``result_<insize[0]>X<insize[1]>_idx_<idx>_time_<seconds>s.png``.
    ``<seconds>`` is the wall-clock duration of the ``estimate`` call,
    rounded to 3 decimals.

    Raises:
        Exception: if ``[dataset] type`` is neither ``mpii`` nor ``coco``.
    """
    import time

    config = load_config(args)
    detection_thresh = config.getfloat('predict', 'detection_thresh')
    min_num_keypoints = config.getint('predict', 'min_num_keypoints')
    dataset_type = config.get('dataset', 'type')
    logger.info('loading {}'.format(dataset_type))
    if dataset_type == 'mpii':
        _, test_set = get_mpii_dataset(
            insize=parse_size(config.get('model_param', 'insize')),
            image_root=config.get(dataset_type, 'images'),
            annotations=config.get(dataset_type, 'annotations'),
            train_size=config.getfloat(dataset_type, 'train_size'),
            min_num_keypoints=config.getint(dataset_type, 'min_num_keypoints'),
            seed=config.getint('training_param', 'seed'),
        )
    elif dataset_type == 'coco':
        test_set = get_coco_dataset(
            insize=parse_size(config.get('model_param', 'insize')),
            image_root=config.get(dataset_type, 'val_images'),
            annotations=config.get(dataset_type, 'val_annotations'),
            min_num_keypoints=config.getint(dataset_type, 'min_num_keypoints'),
        )
    else:
        raise Exception('Unknown dataset {}'.format(dataset_type))

    model = create_model(args, config)

    # Use one specific test image so results are comparable across runs.
    idx = 50
    image = test_set.get_example(idx)['image']

    # Time only the inference call; the result is embedded in the file name.
    start = time.perf_counter()
    humans = estimate(
        model,
        image.astype(np.float32),
        detection_thresh,
        min_num_keypoints,
    )
    inference_time = time.perf_counter() - start

    # Images are channel-first; PIL expects HWC uint8.
    pil_image = Image.fromarray(image.transpose(1, 2, 0).astype(np.uint8))
    pil_image = draw_humans(keypoint_names=model.keypoint_names,
                            edges=model.edges,
                            pil_image=pil_image,
                            humans=humans,
                            visbbox=config.getboolean('predict', 'visbbox'))

    # Name the output after the model input size and the measured time.
    # ``model.insize`` is always defined. The dataset-based ``_`` previously
    # used here does not exist in the coco branch.
    insize = model.insize
    pil_image.save(
        'result_' + 'X'.join((str(insize[0]), str(insize[1]))) + '_idx_' +
        str(idx) + '_time_' + str(round(inference_time, 3)) + 's.png', 'PNG')
Example #4
0
def process_proxy_options(parser, options):
    """Validate and normalise proxy-related command-line options.

    Expands ``~`` in the certificate paths. Calls ``certutils.dummy_ca``
    when ``<confdir>/mitmproxy-ca.pem`` does not exist. Parses the
    body-size limit and builds the transparent-proxy, reverse-proxy and
    authentication settings.

    Every validation failure is reported through ``parser.error``, and its
    return value is returned immediately.

    NOTE(review): in this copy ``body_size_limit``, ``trans``, ``rp`` and
    ``authenticator`` are computed but never used or returned. The tail of
    the function appears to be missing; confirm against the full source.
    """
    if options.cert:
        options.cert = os.path.expanduser(options.cert)
        if not os.path.exists(options.cert):
            return parser.error(
                "Manually created certificate does not exist: %s" %
                options.cert)

    cacert = os.path.join(options.confdir, "mitmproxy-ca.pem")
    cacert = os.path.expanduser(cacert)
    if not os.path.exists(cacert):
        certutils.dummy_ca(cacert)
    body_size_limit = utils.parse_size(options.body_size_limit)
    if options.reverse_proxy and options.transparent_proxy:
        return parser.error(
            "Can't set both reverse proxy and transparent proxy.")

    if options.transparent_proxy:
        if not platform.resolver:
            return parser.error(
                "Transparent mode not supported on this platform.")
        trans = dict(resolver=platform.resolver(),
                     sslports=TRANSPARENT_SSL_PORTS)
    else:
        trans = None

    if options.reverse_proxy:
        rp = utils.parse_proxy_spec(options.reverse_proxy)
        if not rp:
            return parser.error("Invalid reverse proxy specification: %s" %
                                options.reverse_proxy)
    else:
        rp = None

    if options.clientcerts:
        options.clientcerts = os.path.expanduser(options.clientcerts)
        # isdir() is False for missing paths, so it also covers existence.
        if not os.path.isdir(options.clientcerts):
            return parser.error(
                "Client certificate directory does not exist or is not a directory: %s"
                % options.clientcerts)

    if (options.auth_nonanonymous or options.auth_singleuser
            or options.auth_htpasswd):
        if options.auth_singleuser:
            parts = options.auth_singleuser.split(':')
            if len(parts) != 2:
                return parser.error(
                    "Invalid single-user specification. Please use the format username:password"
                )
            username, password = parts
            password_manager = http_auth.PassManSingleUser(username, password)
        elif options.auth_nonanonymous:
            password_manager = http_auth.PassManNonAnon()
        elif options.auth_htpasswd:
            try:
                password_manager = http_auth.PassManHtpasswd(
                    options.auth_htpasswd)
            except ValueError as v:
                # ``except X, v`` is Python-2-only syntax, and Python 3
                # exceptions have no ``.message``. This form works on 2.6+
                # and 3.
                return parser.error(str(v))
        authenticator = http_auth.BasicProxyAuth(password_manager, "mitmproxy")
Example #5
0
def put_files(dist: Distributor, f_type, count, size):
    """Create ``count`` filesystem entries of kind ``f_type`` through ``dist``.

    Args:
        dist: distributor that performs the creation.
        f_type: ``"regular"``, ``"dir"``, ``"symlink"`` or ``"hardlink"``.
            Any other value creates nothing and logs a warning.
        count: number of entries to create (logged with ``%d``, so an int).
        size: size string parsed with ``utils.parse_size``; only used for
            ``"regular"`` files.

    Example layout whose entries map onto these arguments
    (``type`` -> ``f_type``, ``count`` -> ``count``, ``size`` -> ``size``)::

        depth: 4
        width: 4
        layers:
          - layer1:
              - size: 10KB
                type: regular
                count: 2000
              - size: 12MB
                type: regular
                count: 10
              - size: 90MB
                type: regular
                count: 1
              - type: symlink
                count: 100
    """

    logging.info("putting %s, count %d", f_type, count)
    if f_type == "regular":
        size_in_bytes = utils.parse_size(size)
        dist.put_multiple_files(count, Size(size_in_bytes))
    elif f_type == "dir":
        dist.put_directories(count)
    elif f_type == "symlink":
        dist.put_symlinks(count)
    elif f_type == "hardlink":
        dist.put_hardlinks(count)
    else:
        # Previously ignored silently, which hid typos in layout files.
        # Keep it non-fatal but visible.
        logging.warning("unknown file type %s, nothing created", f_type)
Example #6
0
    def __init__(self, config):
        """Build the feature extractor and the final convolution from ``config``.

        Args:
            config: ConfigParser with ``[dataset] type`` (``mpii`` or
                ``coco``) and ``[model_param]`` keys ``model_name``,
                ``local_grid_size`` and optionally ``width_multiplier``.

        Raises:
            Exception: if the dataset type is neither ``mpii`` nor ``coco``.
        """
        super(MyModel, self).__init__()

        dataset_type = config.get('dataset', 'type')
        if dataset_type == 'mpii':
            import mpii_dataset as x_dataset
        elif dataset_type == 'coco':
            import coco_dataset as x_dataset
        else:
            raise Exception('Unknown dataset {}'.format(dataset_type))

        with self.init_scope():
            dtype = np.float32
            # Take the width multiplier from the config, as the PoseProposalNet
            # construction elsewhere in this project does. Weights trained with
            # a non-default multiplier then load into this network. Falls back
            # to the previous hard-coded 1.0 when the key is absent.
            width_multiplier = config.getfloat(
                'model_param', 'width_multiplier', fallback=1.0)
            self.feature_layer = get_network(
                config.get('model_param', 'model_name'),
                dtype=dtype,
                width_multiplier=width_multiplier)
            ksize = self.feature_layer.last_ksize
            self.local_grid_size = parse_size(
                config.get('model_param', 'local_grid_size'))
            self.keypoint_names = x_dataset.KEYPOINT_NAMES
            self.edges = x_dataset.EDGES
            # Output channels: 6 per keypoint plus one per edge for every
            # cell of the local grid.
            self.lastconv = L.Convolution2D(
                None,
                6 * len(self.keypoint_names) + self.local_grid_size[0] *
                self.local_grid_size[1] * len(self.edges),
                ksize=ksize,
                stride=1,
                pad=ksize // 2,
                initialW=initializers.HeNormal(1 / np.sqrt(2), dtype))
Example #7
0
def process_proxy_options(parser, options):
    """Validate and normalise proxy-related command-line options.

    Expands ``~`` in the certificate paths. Calls ``certutils.dummy_ca``
    when ``<confdir>/mitmproxy-ca.pem`` does not exist. Parses the
    body-size limit and builds the transparent, reverse and forward proxy
    settings plus authentication.

    Every validation failure is reported through ``parser.error``, and its
    return value is returned immediately.

    NOTE(review): in this copy ``body_size_limit``, ``trans``, ``rp``,
    ``fp`` and ``authenticator`` are computed but never used or returned.
    The tail of the function appears to be missing; confirm.
    """
    if options.cert:
        options.cert = os.path.expanduser(options.cert)
        if not os.path.exists(options.cert):
            return parser.error("Manually created certificate does not exist: %s" % options.cert)

    cacert = os.path.join(options.confdir, "mitmproxy-ca.pem")
    cacert = os.path.expanduser(cacert)
    if not os.path.exists(cacert):
        certutils.dummy_ca(cacert)
    body_size_limit = utils.parse_size(options.body_size_limit)
    if options.reverse_proxy and options.transparent_proxy:
        return parser.error("Can't set both reverse proxy and transparent proxy.")

    if options.transparent_proxy:
        if not platform.resolver:
            return parser.error("Transparent mode not supported on this platform.")
        trans = dict(
            resolver=platform.resolver(),
            sslports=TRANSPARENT_SSL_PORTS,
        )
    else:
        trans = None

    if options.reverse_proxy:
        rp = utils.parse_proxy_spec(options.reverse_proxy)
        if not rp:
            return parser.error("Invalid reverse proxy specification: %s" % options.reverse_proxy)
    else:
        rp = None

    if options.forward_proxy:
        fp = utils.parse_proxy_spec(options.forward_proxy)
        if not fp:
            return parser.error("Invalid forward proxy specification: %s" % options.forward_proxy)
    else:
        fp = None

    if options.clientcerts:
        options.clientcerts = os.path.expanduser(options.clientcerts)
        # isdir() is False for missing paths, so it also covers existence.
        if not os.path.isdir(options.clientcerts):
            return parser.error(
                "Client certificate directory does not exist or is not a directory: %s" % options.clientcerts
            )

    if options.auth_nonanonymous or options.auth_singleuser or options.auth_htpasswd:
        if options.auth_singleuser:
            parts = options.auth_singleuser.split(':')
            if len(parts) != 2:
                return parser.error("Invalid single-user specification. Please use the format username:password")
            username, password = parts
            password_manager = http_auth.PassManSingleUser(username, password)
        elif options.auth_nonanonymous:
            password_manager = http_auth.PassManNonAnon()
        elif options.auth_htpasswd:
            try:
                password_manager = http_auth.PassManHtpasswd(options.auth_htpasswd)
            except ValueError as v:
                # ``except X, v`` is Python-2-only syntax, and Python 3
                # exceptions have no ``.message``. This form works on 2.6+
                # and 3.
                return parser.error(str(v))
        authenticator = http_auth.BasicProxyAuth(password_manager, "mitmproxy")
Example #8
0
def create_model(config, dataset):
    """Construct an untrained PoseProposalNet for ``dataset``.

    Architecture settings and loss weights come from ``[model_param]``. The
    part and instance scales come from the section named after
    ``[dataset] type``. Keypoint names and skeleton edges are taken from
    ``dataset``.
    """
    dataset_type = config.get('dataset', 'type')

    def model_str(key):
        return config.get('model_param', key)

    def model_float(key):
        return config.getfloat('model_param', key)

    kwargs = dict(
        model_name=model_str('model_name'),
        insize=parse_size(model_str('insize')),
        keypoint_names=dataset.keypoint_names,
        edges=dataset.edges,
        local_grid_size=parse_size(model_str('local_grid_size')),
        parts_scale=parse_size(config.get(dataset_type, 'parts_scale')),
        instance_scale=parse_size(config.get(dataset_type, 'instance_scale')),
        width_multiplier=model_float('width_multiplier'),
    )
    # Loss-term weights, read in the same order as before.
    for term in ('resp', 'iou', 'coor', 'size', 'limb'):
        kwargs['lambda_' + term] = model_float('lambda_' + term)

    return PoseProposalNet(**kwargs)
Example #9
0
def process_proxy_options(parser, options):
    """Validate proxy command-line options and build a ProxyConfig.

    Expands ``~`` in the certificate, cache and cert-dir paths. Calls
    ``certutils.dummy_ca`` when ``<confdir>/mitmproxy-ca.pem`` does not
    exist. Invalid options are reported through ``parser.error``. That call
    is not returned from, so the parser is expected to abort (as argparse /
    optparse ``error`` does).

    Returns:
        A ``ProxyConfig`` built from the validated options.
    """
    if options.cert:
        options.cert = os.path.expanduser(options.cert)
        if not os.path.exists(options.cert):
            parser.error("Manually created certificate does not exist: %s" % options.cert)

    cacert = os.path.join(options.confdir, "mitmproxy-ca.pem")
    cacert = os.path.expanduser(cacert)
    if not os.path.exists(cacert):
        certutils.dummy_ca(cacert)
    if getattr(options, "cache", None) is not None:
        options.cache = os.path.expanduser(options.cache)
    body_size_limit = utils.parse_size(options.body_size_limit)

    if options.reverse_proxy and options.transparent_proxy:
        # Was ``parser.errror`` (typo), which raised AttributeError instead
        # of reporting the conflicting options.
        parser.error("Can't set both reverse proxy and transparent proxy.")

    if options.transparent_proxy:
        if not platform.resolver:
            parser.error("Transparent mode not supported on this platform.")
        # NOTE(review): other copies of this function pass
        # ``platform.resolver()``; this one passes the resolver itself.
        # Confirm which form this ProxyConfig expects.
        trans = dict(
            resolver=platform.resolver,
            sslports=TRANSPARENT_SSL_PORTS,
        )
    else:
        trans = None

    if options.reverse_proxy:
        rp = utils.parse_proxy_spec(options.reverse_proxy)
        if not rp:
            parser.error("Invalid reverse proxy specification: %s" % options.reverse_proxy)
    else:
        rp = None

    if options.clientcerts:
        options.clientcerts = os.path.expanduser(options.clientcerts)
        if not os.path.exists(options.clientcerts) or not os.path.isdir(options.clientcerts):
            parser.error("Client certificate directory does not exist or is not a directory: %s" % options.clientcerts)

    if options.certdir:
        options.certdir = os.path.expanduser(options.certdir)
        if not os.path.exists(options.certdir) or not os.path.isdir(options.certdir):
            parser.error("Dummy cert directory does not exist or is not a directory: %s" % options.certdir)

    return ProxyConfig(
        certfile=options.cert,
        cacert=cacert,
        clientcerts=options.clientcerts,
        cert_wait_time=options.cert_wait_time,
        body_size_limit=body_size_limit,
        no_upstream_cert=options.no_upstream_cert,
        reverse_proxy=rp,
        transparent_proxy=trans,
        certdir=options.certdir,
    )
Example #10
0
def process_proxy_options(parser, options):
    """Validate proxy command-line options and build a ProxyConfig.

    Expands ``~`` in the certificate, cache and cert-dir paths. Calls
    ``certutils.dummy_ca`` when ``<confdir>/mitmproxy-ca.pem`` does not
    exist. Invalid options are reported through ``parser.error``. That call
    is not returned from, so the parser is expected to abort (as argparse /
    optparse ``error`` does).

    Returns:
        A ``ProxyConfig`` built from the validated options.
    """
    if options.cert:
        options.cert = os.path.expanduser(options.cert)
        if not os.path.exists(options.cert):
            parser.error("Manually created certificate does not exist: %s" % options.cert)

    cacert = os.path.join(options.confdir, "mitmproxy-ca.pem")
    cacert = os.path.expanduser(cacert)
    if not os.path.exists(cacert):
        certutils.dummy_ca(cacert)
    if getattr(options, "cache", None) is not None:
        options.cache = os.path.expanduser(options.cache)
    body_size_limit = utils.parse_size(options.body_size_limit)

    if options.reverse_proxy and options.transparent_proxy:
        # Was ``parser.errror`` (typo), which raised AttributeError instead
        # of reporting the conflicting options.
        parser.error("Can't set both reverse proxy and transparent proxy.")

    if options.transparent_proxy:
        if not platform.resolver:
            parser.error("Transparent mode not supported on this platform.")
        trans = dict(
            resolver=platform.resolver(),
            sslports=TRANSPARENT_SSL_PORTS,
        )
    else:
        trans = None

    if options.reverse_proxy:
        rp = utils.parse_proxy_spec(options.reverse_proxy)
        if not rp:
            parser.error("Invalid reverse proxy specification: %s" % options.reverse_proxy)
    else:
        rp = None

    if options.clientcerts:
        options.clientcerts = os.path.expanduser(options.clientcerts)
        if not os.path.exists(options.clientcerts) or not os.path.isdir(options.clientcerts):
            parser.error("Client certificate directory does not exist or is not a directory: %s" % options.clientcerts)

    if options.certdir:
        options.certdir = os.path.expanduser(options.certdir)
        if not os.path.exists(options.certdir) or not os.path.isdir(options.certdir):
            parser.error("Dummy cert directory does not exist or is not a directory: %s" % options.certdir)

    return ProxyConfig(
        certfile=options.cert,
        cacert=cacert,
        clientcerts=options.clientcerts,
        cert_wait_time=options.cert_wait_time,
        body_size_limit=body_size_limit,
        no_upstream_cert=options.no_upstream_cert,
        reverse_proxy=rp,
        transparent_proxy=trans,
        certdir=options.certdir,
    )
Example #11
0
File: api.py Project: larsks/vmm
    def create_volume(self, disk):
        """Create the storage volume described by ``disk`` in its pool.

        Steps:

        * ``disk['size']`` goes through ``utils.parse_size``; the two parsed
          values are not used further here.
        * If ``disk`` has a ``backing_store`` entry, ``resolve_volume`` is
          called on it first; its result is discarded.
        * The volume XML is rendered from the ``volume.xml`` template with
          ``vol=disk`` and created in the pool named by ``disk['pool']``.

        Returns:
            Whatever ``pool.createXML`` returns.
        """
        _size, _unit = utils.parse_size(disk['size'])
        if 'backing_store' in disk:
            self.resolve_volume(disk['backing_store'])

        template = env.get_template('volume.xml')
        volume_xml = template.render(vol=disk)
        target_pool = self.find_pool(disk['pool'])
        self.log.debug('creating vol %s in pool %s', disk['name'],
                       target_pool.name())
        return target_pool.createXML(volume_xml, 0)
Example #12
0
def main():
    """Run pose estimation on one random test image and save ``result.png``.

    Reads ``config.ini`` (UTF-8) and builds the test split of the dataset
    named by ``[dataset] type`` (``mpii`` or ``coco``). It creates the model
    with ``create_model(config)`` and draws the detected humans on the
    image.

    Raises:
        Exception: if the dataset type is neither ``mpii`` nor ``coco``.
    """
    config = configparser.ConfigParser()
    config.read('config.ini', 'UTF-8')
    dataset_type = config.get('dataset', 'type')
    logger.info('loading {}'.format(dataset_type))

    if dataset_type not in ('mpii', 'coco'):
        raise Exception('Unknown dataset {}'.format(dataset_type))

    if dataset_type == 'coco':
        test_set = get_coco_dataset(
            insize=parse_size(config.get('model_param', 'insize')),
            image_root=config.get(dataset_type, 'val_images'),
            annotations=config.get(dataset_type, 'val_annotations'),
            min_num_keypoints=config.getint(dataset_type, 'min_num_keypoints'),
        )
    else:
        # MPII yields a (train, test) pair; only the test split is used.
        _train_set, test_set = get_mpii_dataset(
            insize=parse_size(config.get('model_param', 'insize')),
            image_root=config.get(dataset_type, 'images'),
            annotations=config.get(dataset_type, 'annotations'),
            train_size=config.getfloat(dataset_type, 'train_size'),
            min_num_keypoints=config.getint(dataset_type, 'min_num_keypoints'),
            seed=config.getint('training_param', 'seed'),
        )

    model = create_model(config)

    chosen = random.choice(range(len(test_set)))
    raw = test_set.get_example(chosen)['image']
    detections = estimate(model, raw.astype(np.float32))

    # Images are channel-first; PIL expects HWC uint8.
    canvas = Image.fromarray(raw.transpose(1, 2, 0).astype(np.uint8))
    canvas = draw_humans(
        keypoint_names=model.keypoint_names,
        edges=model.edges,
        pil_image=canvas,
        humans=detections,
    )
    canvas.save('result.png', 'PNG')
def export_onnx(args):
    """Export the trained ``MyModel`` from ``args.model`` to ONNX and check it.

    The ``MyModel`` is built from ``load_config(args)`` and its weights are
    loaded from ``<args.model>/bestmodel.npz``. It is exported with train
    mode disabled, using a zero-filled float32 input of shape
    ``(1, 3, h, w)``; ``w, h`` come from ``[model_param] insize``. The
    result is written to ``<args.model>/bestmodel.onnx`` and validated with
    ``onnx.checker.check_model``.
    """
    config = load_config(args)
    net = MyModel(config)
    weights_path = os.path.join(args.model, 'bestmodel.npz')
    chainer.serializers.load_npz(weights_path, net)

    width, height = parse_size(config.get('model_param', 'insize'))
    dummy_input = np.zeros((1, 3, height, width), dtype=np.float32)

    logger.info('begin export')
    onnx_path = os.path.join(args.model, 'bestmodel.onnx')
    with chainer.using_config('train', False):
        onnx_chainer.export(net, dummy_input, filename=onnx_path)
    logger.info('end export')

    logger.info('run onnx.check')
    exported = onnx.load(onnx_path)
    onnx.checker.check_model(exported)
    logger.info('done')
Example #14
0
File: disk.py Project: larsks/vmm
 def size(self):
     """Return ``self['size']`` normalised through ``utils.adjust_size``.

     ``utils.parse_size`` splits the stored value into an amount and a unit.
     The amount is coerced to ``int`` before adjustment.
     """
     amount, unit = utils.parse_size(self['size'])
     amount = int(amount)
     return utils.adjust_size(amount, unit)
Example #15
0
def process_proxy_options(parser, options):
    """Validate proxy command-line options and build a ProxyConfig.

    Expands ``~`` in the certificate, cache and cert-dir paths. Calls
    ``certutils.dummy_ca`` when ``<confdir>/mitmproxy-ca.pem`` does not
    exist. Builds the transparent and reverse proxy settings and chooses a
    proxy authenticator (single user, any non-anonymous user, htpasswd
    file, or none).

    Invalid options are reported through ``parser.error``. That call is not
    returned from, so the parser is expected to abort (as argparse /
    optparse ``error`` does).

    Returns:
        A ``ProxyConfig`` built from the validated options.
    """
    if options.cert:
        options.cert = os.path.expanduser(options.cert)
        if not os.path.exists(options.cert):
            parser.error("Manually created certificate does not exist: %s" % options.cert)

    cacert = os.path.join(options.confdir, "mitmproxy-ca.pem")
    cacert = os.path.expanduser(cacert)
    if not os.path.exists(cacert):
        certutils.dummy_ca(cacert)
    if getattr(options, "cache", None) is not None:
        options.cache = os.path.expanduser(options.cache)
    body_size_limit = utils.parse_size(options.body_size_limit)

    if options.reverse_proxy and options.transparent_proxy:
        # Was ``parser.errror`` (typo), which raised AttributeError instead
        # of reporting the conflicting options.
        parser.error("Can't set both reverse proxy and transparent proxy.")

    if options.transparent_proxy:
        if not platform.resolver:
            parser.error("Transparent mode not supported on this platform.")
        trans = dict(
            resolver=platform.resolver(),
            sslports=TRANSPARENT_SSL_PORTS,
        )
    else:
        trans = None

    if options.reverse_proxy:
        rp = utils.parse_proxy_spec(options.reverse_proxy)
        if not rp:
            parser.error("Invalid reverse proxy specification: %s" % options.reverse_proxy)
    else:
        rp = None

    if options.clientcerts:
        options.clientcerts = os.path.expanduser(options.clientcerts)
        if not os.path.exists(options.clientcerts) or not os.path.isdir(options.clientcerts):
            parser.error("Client certificate directory does not exist or is not a directory: %s" % options.clientcerts)

    if options.certdir:
        options.certdir = os.path.expanduser(options.certdir)
        if not os.path.exists(options.certdir) or not os.path.isdir(options.certdir):
            parser.error("Dummy cert directory does not exist or is not a directory: %s" % options.certdir)

    if options.auth_nonanonymous or options.auth_singleuser or options.auth_htpasswd:
        if options.auth_singleuser:
            parts = options.auth_singleuser.split(':')
            if len(parts) != 2:
                parser.error("Please specify user in the format username:password")
            username, password = parts
            password_manager = authentication.SingleUserPasswordManager(username, password)
        elif options.auth_nonanonymous:
            password_manager = authentication.PermissivePasswordManager()
        elif options.auth_htpasswd:
            password_manager = authentication.HtpasswdPasswordManager(options.auth_htpasswd)
        authenticator = authentication.BasicProxyAuth(password_manager, "mitmproxy")
    else:
        authenticator = authentication.NullProxyAuth(None)

    return ProxyConfig(
        certfile=options.cert,
        cacert=cacert,
        clientcerts=options.clientcerts,
        body_size_limit=body_size_limit,
        no_upstream_cert=options.no_upstream_cert,
        reverse_proxy=rp,
        transparent_proxy=trans,
        certdir=options.certdir,
        authenticator=authenticator,
    )
Example #16
0
def main():
    """Collect PCK evaluation data on randomly sampled test images.

    Command-line options:
        -m/--modelname: model name forwarded to ``create_model``
            (default ''). When empty, it is reported as
            'trained/bestmodel.npz'.
        -n/--testnum: number of test images to sample, with replacement
            (default 1000).
        -t/--thresh: detection threshold passed to ``estimate``
            (default 0.15, the previously hard-coded value).

    Reads ``config.ini`` (UTF-8) and builds the test split of the dataset
    named by ``[dataset] type``. For each sampled image it records the
    ground-truth keypoints, the predicted humans, the ground-truth bboxes
    and the visibility flags. These are passed to ``evaluation``.

    Raises:
        Exception: if ``[dataset] type`` is neither ``mpii`` nor ``coco``.
    """

    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument("-m",
                        "--modelname",
                        help="model full name",
                        default='',
                        dest='modelName')
    parser.add_argument("-n",
                        "--testnum",
                        help="the number of test image",
                        type=int,
                        default=1000,
                        dest='test_num')
    parser.add_argument("-t",
                        "--thresh",
                        help="detection threshold passed to estimate",
                        type=float,
                        default=0.15,
                        dest='detection_thresh')
    args = parser.parse_args()
    modelName = args.modelName

    config = configparser.ConfigParser()
    config.read('config.ini', 'UTF-8')
    dataset_type = config.get('dataset', 'type')
    logger.info('loading {}'.format(dataset_type))
    if dataset_type == 'mpii':
        _, test_set = get_mpii_dataset(
            insize=parse_size(config.get('model_param', 'insize')),
            image_root=config.get(dataset_type, 'images'),
            annotations=config.get(dataset_type, 'annotations'),
            train_size=config.getfloat(dataset_type, 'train_size'),
            min_num_keypoints=config.getint(dataset_type, 'min_num_keypoints'),
            seed=config.getint('training_param', 'seed'),
        )
    elif dataset_type == 'coco':
        test_set = get_coco_dataset(
            insize=parse_size(config.get('model_param', 'insize')),
            image_root=config.get(dataset_type, 'val_images'),
            annotations=config.get(dataset_type, 'val_annotations'),
            min_num_keypoints=config.getint(dataset_type, 'min_num_keypoints'),
        )
    else:
        raise Exception('Unknown dataset {}'.format(dataset_type))

    model = create_model(config, modelName)

    # Inputs for PCK: gt keypoints, predicted humans (keypoints and bboxes),
    # gt bboxes and keypoint visibility flags, one entry per sampled image.
    pck_object = [[], [], [], []]

    # NOTE(review): this default only affects the reported name; the
    # command-line value (possibly '') is what create_model received above.
    # Confirm that create_model applies the same default.
    modelName = modelName if modelName else 'trained/bestmodel.npz'
    test_num = args.test_num
    print('model name: {}\t test image number: {}'.format(modelName, test_num))
    # Evaluate on randomly chosen test images.
    for _ in tqdm(range(test_num)):
        idx = random.randrange(len(test_set))
        # Fetch the example once instead of re-reading it for every field.
        example = test_set.get_example(idx)
        image = example['image']
        gt_kps = example['keypoints']
        gt_bboxs = example['bbox']  # (left down point, w, h)
        is_visible = example['is_visible']

        # Predictions include keypoints and bboxes.
        humans = estimate(model, image.astype(np.float32),
                          args.detection_thresh)
        pck_object[0].append(gt_kps)
        pck_object[1].append(humans)
        pck_object[2].append(gt_bboxs)
        pck_object[3].append(is_visible)
    mylog.info('model name: {}\t test image number: {}'.format(
        modelName, test_num))
    evaluation(config, pck_object)
Example #17
0
def main():
    """Train the pose model described by an INI configuration file.

    Command-line options:
        --config_path: INI file to read (default ``config.ini``, UTF-8).
        --resume: optional trainer snapshot loaded with
            ``serializers.load_npz`` before training starts.
        --plot_samples: if > 0, plot that many training and validation
            samples (``train-{i}.png`` / ``val-{i}.png``) before training.

    Builds the ``coco`` or ``mpii`` train/validation sets and wraps them
    with ``model.encode``. Trains with MomentumSGD plus weight decay on the
    devices from ``[training_param] gpus``. The trainer output directory is
    ``[result] dir``.

    Raises:
        Exception: if ``[dataset] type`` is neither ``coco`` nor ``mpii``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', type=str, default='config.ini')
    parser.add_argument('--resume')
    parser.add_argument('--plot_samples', type=int, default=0)
    args = parser.parse_args()

    config = configparser.ConfigParser()
    config.read(args.config_path, 'UTF-8')

    # Enable Chainer's cuDNN autotuning of convolution algorithms.
    chainer.global_config.autotune = True
    # cuDNN workspace limit in bytes (~10.9 MiB).
    # NOTE(review): unusual value (8 MiB would be 8388608); confirm that
    # 11388608 is intended.
    chainer.cuda.set_max_workspace_size(11388608)

    # create result dir and copy file
    logger.info('> store file to result dir %s', config.get('result', 'dir'))
    save_files(config.get('result', 'dir'))

    logger.info('> set up devices')
    devices = setup_devices(config.get('training_param', 'gpus'))
    # Seeding happens before the datasets are built; set_random_seed is
    # project-defined (presumably seeds the relevant RNGs for `devices`).
    set_random_seed(devices, config.getint('training_param', 'seed'))

    logger.info('> get dataset')
    dataset_type = config.get('dataset', 'type')
    if dataset_type == 'coco':
        # force to set `use_cache = False`
        train_set = get_coco_dataset(
            insize=parse_size(config.get('model_param', 'insize')),
            image_root=config.get(dataset_type, 'train_images'),
            annotations=config.get(dataset_type, 'train_annotations'),
            min_num_keypoints=config.getint(dataset_type, 'min_num_keypoints'),
            use_cache=False,
            do_augmentation=True,
        )
        test_set = get_coco_dataset(
            insize=parse_size(config.get('model_param', 'insize')),
            image_root=config.get(dataset_type, 'val_images'),
            annotations=config.get(dataset_type, 'val_annotations'),
            min_num_keypoints=config.getint(dataset_type, 'min_num_keypoints'),
            use_cache=False,
        )
    elif dataset_type == 'mpii':
        # MPII: a single annotation set split into train/validation
        # (train_size presumably the training fraction; split uses `seed`).
        train_set, test_set = get_mpii_dataset(
            insize=parse_size(config.get('model_param', 'insize')),
            image_root=config.get(dataset_type, 'images'),
            annotations=config.get(dataset_type, 'annotations'),
            train_size=config.getfloat(dataset_type, 'train_size'),
            min_num_keypoints=config.getint(dataset_type, 'min_num_keypoints'),
            use_cache=config.getboolean(dataset_type, 'use_cache'),
            seed=config.getint('training_param', 'seed'),
        )
    else:
        raise Exception('Unknown dataset {}'.format(dataset_type))
    logger.info('dataset type: %s', dataset_type)
    logger.info('training images: %d', len(train_set))
    logger.info('validation images: %d', len(test_set))

    # Optional sanity-check plots of raw samples.
    # NOTE(review): assumes plot_samples <= len(train_set) and len(test_set).
    if args.plot_samples > 0:
        for i in range(args.plot_samples):
            data = train_set[i]
            visualize.plot('train-{}.png'.format(i), data['image'],
                           data['keypoints'], data['bbox'], data['is_labeled'],
                           data['edges'])
            data = test_set[i]
            visualize.plot('val-{}.png'.format(i), data['image'],
                           data['keypoints'], data['bbox'], data['is_labeled'],
                           data['edges'])

    logger.info('> load model')
    # The training set is passed in (presumably for keypoint/edge metadata).
    model = create_model(config, train_set)

    logger.info('> transform dataset')
    # Each example is mapped through model.encode (presumably into the
    # network's training targets).
    train_set = TransformDataset(train_set, model.encode)
    test_set = TransformDataset(test_set, model.encode)

    logger.info('> create iterators')
    # Training batches are loaded by worker processes; validation makes one
    # ordered pass over the set (repeat=False, shuffle=False).
    train_iter = chainer.iterators.MultiprocessIterator(
        train_set,
        config.getint('training_param', 'batchsize'),
        n_processes=config.getint('training_param', 'num_process'))
    test_iter = chainer.iterators.SerialIterator(test_set,
                                                 config.getint(
                                                     'training_param',
                                                     'batchsize'),
                                                 repeat=False,
                                                 shuffle=False)

    logger.info('> setup optimizer')
    optimizer = chainer.optimizers.MomentumSGD()
    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.WeightDecay(0.0005))

    logger.info('> setup trainer')
    updater = training.updaters.ParallelUpdater(train_iter,
                                                optimizer,
                                                devices=devices)
    trainer = training.Trainer(
        updater, (config.getint('training_param', 'train_iter'), 'iteration'),
        config.get('result', 'dir'))

    logger.info('> setup extensions')
    # Linearly decay the learning rate from `learning_rate` to 0 over
    # `train_iter` iterations, updating it every iteration.
    trainer.extend(extensions.LinearShift(
        'lr',
        value_range=(config.getfloat('training_param', 'learning_rate'), 0),
        time_range=(0, config.getint('training_param', 'train_iter'))),
                   trigger=(1, 'iteration'))

    trainer.extend(
        extensions.Evaluator(test_iter, model, device=devices['main']))
    if extensions.PlotReport.available():
        trainer.extend(
            extensions.PlotReport([
                'main/loss',
                'validation/main/loss',
            ],
                                  'epoch',
                                  file_name='loss.png'))
    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.observe_lr())
    trainer.extend(
        extensions.PrintReport([
            'epoch',
            'elapsed_time',
            'lr',
            'main/loss',
            'validation/main/loss',
            'main/loss_resp',
            'validation/main/loss_resp',
            'main/loss_iou',
            'validation/main/loss_iou',
            'main/loss_coor',
            'validation/main/loss_coor',
            'main/loss_size',
            'validation/main/loss_size',
            'main/loss_limb',
            'validation/main/loss_limb',
        ]))
    trainer.extend(extensions.ProgressBar())

    # Save the trainer snapshot and the model weights whenever the
    # validation loss reaches a new minimum.
    trainer.extend(
        extensions.snapshot(filename='best_snapshot'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))
    trainer.extend(
        extensions.snapshot_object(model, filename='bestmodel.npz'),
        trigger=training.triggers.MinValueTrigger('validation/main/loss'))

    # Restore a previously saved trainer snapshot if requested.
    if args.resume:
        serializers.load_npz(args.resume, trainer)

    logger.info('> start training')
    trainer.run()
Example #18
0
 def memory(self):
     """Return ``self['memory']`` normalised through ``utils.adjust_size``.

     ``utils.parse_size`` splits the stored value into an amount and a unit.
     The amount is coerced to ``int`` before adjustment.
     """
     amount, unit = utils.parse_size(self['memory'])
     amount = int(amount)
     return utils.adjust_size(amount, unit)