Пример #1
0
def _read_darknet_anchors(cfg_filename):
    """Return the numbers on the first ``anchors`` line of a darknet cfg file.

    Args:
        cfg_filename: Path to the darknet cfg file. An empty value means
            "no cfg file" and yields None without touching the filesystem.

    Returns:
        A list of floats parsed from the first line whose text (after leading
        whitespace) starts with ``anchors``, or None when no such line exists.

    Raises:
        ValueError: if the value after ``=`` on that line is not a
            comma-separated list of numbers. The offending line is printed
            before the exception propagates.
    """
    if not cfg_filename:
        return None
    with open(cfg_filename, 'r') as fp:
        for line in fp:
            if not line.lstrip().startswith('anchors'):
                continue
            # Only the first ``anchors`` line is used. A YOLOv3 cfg may hold
            # several such lines; reading all of them is left for the future.
            # partition() yields '' when '=' is missing, which float() rejects
            # and which is then reported like any other malformed value.
            value = line.partition('=')[2]
            try:
                return [float(s) for s in value.split(',')]
            except ValueError:
                print('Failed to extract numbers from the line [%s]' %
                      (line, ))
                raise
    return None


def _yolo_v2_anchors(anchor_info):
    """Turn flat ``w0, h0, w1, h1, ...`` values into ``((h0, w0), ...)``.

    Values are kept as the floats read from the cfg file.

    Raises:
        ValueError: if ``anchor_info`` has an odd number of values.
    """
    if len(anchor_info) % 2 != 0:
        raise ValueError(
            'YOLOv2 anchors need an even number of values, got %d' %
            (len(anchor_info), ))
    return tuple((anchor_info[i + 1], anchor_info[i])
                 for i in range(0, len(anchor_info), 2))


def _yolo_v3_anchors(anchor_info):
    """Split flat ``w, h`` values into three groups of ``(h, w)`` int pairs.

    The pairs are cut into three equally sized consecutive groups (presumably
    one per YOLOv3 output scale), and the group order is reversed relative to
    the cfg file. Each value is truncated with ``int()``.

    Raises:
        ValueError: if the number of values is not a multiple of 6.
    """
    if len(anchor_info) % (2 * 3) != 0:
        raise ValueError(
            'YOLOv3 anchors need a multiple of 6 values, got %d' %
            (len(anchor_info), ))
    pairs = [(int(anchor_info[i + 1]), int(anchor_info[i]))
             for i in range(0, len(anchor_info), 2)]
    per_group = len(pairs) // 3
    groups = [tuple(pairs[g * per_group:(g + 1) * per_group])
              for g in range(3)]
    return tuple(reversed(groups))


def load_darknet_model(filename,
                       model_type,
                       number_of_foreground_classes,
                       retry_with_update_if_possible=False,
                       darknet_library='libdarknet.so',
                       darknet_cfg_filename=''):
    """Build a ChainerCV YOLO model and load darknet weights into it.

    Args:
        filename: Path to the darknet weights file.
        model_type: ``'yolo_v2'`` or ``'yolo_v3'``.
        number_of_foreground_classes: Passed to the model as ``n_fg_class``.
        retry_with_update_if_possible: When the weights header is in the old
            format (major == 0 and minor <= 1), convert the file with
            ``update_format_of_darknet_model`` into a temporary file and load
            that instead.
        darknet_library: Passed to ``update_format_of_darknet_model``.
        darknet_cfg_filename: darknet cfg file. The first ``anchors`` line in
            it overrides the model's class-level ``_anchors``. When empty, the
            model's default anchors are kept.

    Returns:
        The loaded model. Returns None when the weights file is in the old
        format and was not (or could not be) converted.

    Raises:
        ValueError: for an unknown ``model_type``, malformed anchors, or a
            truncated or unsupported weights header.
    """
    anchor_info = _read_darknet_anchors(darknet_cfg_filename)
    print('anchors = %s' % (anchor_info, ))

    if model_type == 'yolo_v2':
        link_class = chainercv.links.YOLOv2
        group_anchors = _yolo_v2_anchors
        load_weights = load_yolo_v2
    elif model_type == 'yolo_v3':
        link_class = chainercv.links.YOLOv3
        group_anchors = _yolo_v3_anchors
        load_weights = load_yolo_v3
    else:
        raise ValueError(
            'Unsupported model_type %r; expected "yolo_v2" or "yolo_v3".' %
            (model_type, ))

    overwritten_class_variables = None
    if anchor_info is not None:
        anchors = group_anchors(anchor_info)
        overwritten_class_variables = {'_anchors': anchors}
        print('anchors: %s' % (anchors, ))

    model = make_serializable_object(
        link_class,
        constructor_args={
            'n_fg_class': number_of_foreground_classes,
        },
        overwritten_class_variables=overwritten_class_variables)

    # NOTE(review): this dummy forward pass presumably allocates lazily
    # initialized parameters so the loaders below can copy weights into them.
    with chainer.using_config('train', False):
        model(np.empty((1, 3, model.insize, model.insize), dtype=np.float32))

    model_is_old_format = False
    with open(filename, mode='rb') as f:
        # Header: int32 major, minor, revision.
        header = np.fromfile(f, dtype=np.int32, count=3)
        if header.size < 3:
            raise ValueError(
                'The file "%s" is too short to contain a darknet header.' %
                (filename, ))
        major, minor = int(header[0]), int(header[1])
        if major == 0 and minor <= 1:
            model_is_old_format = True
            print(
                'The file "%s" is written in an old format, where (major, minor) == (%d, %d) '
                % (filename, major, minor))
            sys.stdout.flush()
            model = None
        else:
            print(
                'The file "%s" is written in a new format, where (major, minor) == (%d, %d) '
                % (filename, major, minor))
            if not (major * 10 + minor >= 2 and major < 1000 and minor < 1000):
                raise ValueError(
                    'Unsupported darknet weights version (%d, %d) in "%s".' %
                    (major, minor, filename))
            np.fromfile(f, dtype=np.int64, count=1)  # seen (int64 here)
            load_weights(f, model)

    if model_is_old_format and retry_with_update_if_possible:
        import os
        import tempfile
        print('Try to convert "%s" to the new format and load it.' %
              (filename, ))
        sys.stdout.flush()
        fd, tmp_path = tempfile.mkstemp()
        # The converter only receives the path. Close our descriptor so it
        # does not leak (an open handle would also block removal on Windows).
        os.close(fd)
        try:
            succeeded = update_format_of_darknet_model(tmp_path, filename,
                                                       darknet_library,
                                                       darknet_cfg_filename)
            if succeeded:
                print(
                    'Succeeded to update "%s" and save the updated version into "%s".'
                    % (filename, tmp_path))
                sys.stdout.flush()
                model = load_darknet_model(
                    tmp_path,
                    model_type,
                    number_of_foreground_classes,
                    retry_with_update_if_possible=False,
                    darknet_library=darknet_library,
                    darknet_cfg_filename=darknet_cfg_filename)
            else:
                print('Failed to convert "%s".' % (filename, ))
                sys.stdout.flush()
        finally:
            # Remove the temporary converted file whether or not loading worked.
            os.remove(tmp_path)
    return model
    import rospy
    parsed_args = rospy.myargv(argv=sys.argv)
except ImportError:
    pass

# Parse the command line. parsed_args[0] is the program name and is skipped.
# NOTE(review): parsed_args comes from the rospy.myargv() call just above when
# rospy is importable (presumably to drop ROS remapping arguments) — confirm
# how it is initialized when rospy is missing.
args = parser.parse_args(parsed_args[1:])

# Flush so the progress message is visible before the model setup below,
# which may fetch pretrained weights.
print("Preparing the pretrained model...")
sys.stdout.flush()
if args.model == 'faster_rcnn':
    model = make_serializable_object(
        chainercv.links.FasterRCNNVGG16,
        constructor_args={
            'n_fg_class': len(voc_bbox_label_names),
            'pretrained_model': args.pretrained_model,
        },
        template_args={
            # Do not retrieve the pre-trained model again on generating a
            # template object for loading weights in a file.
            'n_fg_class': len(voc_bbox_label_names),
            'pretrained_model': None,
        })
elif args.model == 'yolo_v2':
    model = make_serializable_object(
        chainercv.links.YOLOv2,
        constructor_args={
            'n_fg_class': len(voc_bbox_label_names),
            'pretrained_model': args.pretrained_model,
        },
        template_args={
            # Do not retrieve the pre-trained model again on generating a
            # template object for loading weights in a file.