Example #1
    def im_detect(self, im_list, root_dir=None, extension=None, show_timer=False):
        """
        wrapper for detecting multiple images

        Parameters:
        ----------
        im_list : list of str
            image path or list of image paths
        root_dir : str
            directory of input images, optional if image path already
            has full directory information
        extension : str
            image extension, eg. ".jpg", optional

        Returns:
        ----------
        list of detection results in format [det0, det1...], det is in
        format np.array([id, score, xmin, ymin, xmax, ymax]...)
        """
        test_db = TestDB(im_list, root_dir=root_dir, extension=extension)
        test_iter = DetIter(test_db, 1, self.data_shape, self.mean_pixels,
                            is_train=False)
        return self.detect(test_iter, show_timer)
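A minimal usage sketch for this wrapper (a sketch only: the Detector construction below assumes the constructor order symbol/prefix/epoch used elsewhere in this repo, and all paths are illustrative):

import mxnet as mx
# from detect.detector import Detector   # assumed module path

# 'net' is a detection Symbol built elsewhere; prefix/epoch are placeholders
detector = Detector(net, './model/ssd_300', 0, data_shape=300,
                    mean_pixels=(123, 117, 104), batch_size=1, ctx=mx.cpu())
dets = detector.im_detect(['000001.jpg', '000002.jpg'],
                          root_dir='./data/demo', extension='.jpg',
                          show_timer=True)
for det in dets:
    # each row: [class_id, score, xmin, ymin, xmax, ymax]
    print(det)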
Example #2
def train_net(net,
              dataset,
              image_set,
              devkit_path,
              batch_size,
              data_shape,
              mean_pixels,
              resume,
              finetune,
              from_scratch,
              pretrained,
              epoch,
              prefix,
              ctx,
              begin_epoch,
              end_epoch,
              frequent,
              optimizer_name='adam',
              learning_rate=1e-03,
              momentum=0.9,
              weight_decay=5e-04,
              lr_refactor_step=(3, 4, 5, 6),
              lr_refactor_ratio=0.1,
              val_image_set='',
              val_year='',
              use_plateau=True,
              year='',
              freeze_layer_pattern='',
              force_resize=True,
              min_obj_size=32.0,
              use_difficult=False,
              nms_thresh=0.45,
              force_suppress=False,
              ovp_thresh=0.5,
              voc07_metric=True,
              nms_topk=400,
              iter_monitor=0,
              monitor_pattern=".*",
              log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    dataset : str
        pascal_voc, imagenet...
    image_set : str
        train, trainval...
    devkit_path : str
        root directory of dataset
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        trainig momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    year : str
        2007, 2012 or combinations splitted by comma
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    # load dataset
    val_imdb = None
    if dataset == 'pascal_voc':
        imdb = load_pascal(image_set, year, devkit_path, cfg.train['shuffle'])
        if val_image_set:
            assert val_year
            val_imdb = load_pascal(val_image_set, val_year, devkit_path, False)
            max_objects = max(imdb.max_objects, val_imdb.max_objects)
            imdb.pad_labels(max_objects)
            val_imdb.pad_labels(max_objects)
        force_resize = True
    elif dataset == 'wider':
        imdb = load_wider(image_set, devkit_path, cfg.train['shuffle'])
        force_resize = False
    elif dataset == 'mscoco':
        imdb = load_mscoco(image_set, devkit_path, cfg.train['shuffle'])
        force_resize = True
    else:
        raise NotImplementedError("Dataset " + dataset + " not supported")

    # init iterator
    patch_size = data_shape[1]
    min_gt_scale = min_obj_size / float(patch_size)
    rand_scaler = RandScaler(patch_size,
                             min_gt_scale=min_gt_scale,
                             force_resize=force_resize)
    rand_eraser = RandEraser() if force_resize else None
    train_iter = DetIter(imdb,
                         batch_size,
                         data_shape[1],
                         rand_scaler,
                         rand_eraser=rand_eraser,
                         mean_pixels=mean_pixels,
                         rand_mirror=cfg.train['rand_mirror_prob'] > 0,
                         shuffle=cfg.train['shuffle'],
                         rand_seed=cfg.train['seed'],
                         is_train=True)
    if val_imdb:
        rand_scaler = RandScaler(patch_size,
                                 no_random=True,
                                 force_resize=force_resize)
        val_iter = DetIter(val_imdb,
                           batch_size,
                           data_shape[1],
                           rand_scaler,
                           mean_pixels=mean_pixels,
                           is_train=True)
    else:
        val_iter = None

    train_net_common(net, train_iter, val_iter, batch_size, data_shape, resume,
                     finetune, pretrained, epoch, prefix, ctx, begin_epoch,
                     end_epoch, frequent, learning_rate, momentum,
                     weight_decay, use_plateau, lr_refactor_step,
                     lr_refactor_ratio, freeze_layer_pattern, imdb.num_images,
                     imdb.max_objects, nms_thresh, force_suppress, ovp_thresh,
                     use_difficult, imdb.classes, optimizer_name, voc07_metric,
                     nms_topk, iter_monitor, monitor_pattern, logger)
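The lr_refactor_step values above are epoch indices, while MXNet schedulers count optimizer updates; a helper like train_net_common typically converts one to the other. A stand-alone sketch of that conversion (dataset size, batch size, and the ceil rounding are assumptions):

import mxnet as mx

num_example, batch_size = 16551, 32                 # illustrative VOC trainval size
epoch_steps = (3, 4, 5, 6)                          # lr_refactor_step from above
batches_per_epoch = (num_example + batch_size - 1) // batch_size  # ceil division
iter_steps = [e * batches_per_epoch for e in epoch_steps]
# multiply the learning rate by lr_refactor_ratio at each converted step
scheduler = mx.lr_scheduler.MultiFactorScheduler(step=iter_steps, factor=0.1)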
Example #3
def train_net(net, dataset, image_set, year, devkit_path, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, val_set, val_year,
              lr_refactor_epoch, lr_refactor_ratio,
              iter_monitor=0, log_file=None):
    """
    Wrapper for training module

    Parameters:
    ----------
    net : mx.Symbol
        training network
    dataset : str
        pascal, imagenet...
    image_set : str
        train, trainval...
    year : str
        2007, 2012 or combinations separated by comma
    devkit_path : str
        root directory of dataset
    batch_size : int
        training batch size
    data_shape : int or (int, int)
        resize image size
    mean_pixels : tuple (float, float, float)
        mean pixel values in (R, G, B)
    resume : int
        if > 0, will load trained epoch with name given by prefix
    finetune : int
        if > 0, will load trained epoch with name given by prefix; in this
        mode all convolutional layers except the last (prediction) layer are fixed
    pretrained : str
        prefix of pretrained model name
    epoch : int
        epoch of pretrained model
    prefix : str
        prefix of new model
    ctx : mx.gpu(?) or list of mx.gpu(?)
        training context
    begin_epoch : int
        begin epoch, default should be 0
    end_epoch : int
        when to stop training
    frequent : int
        logging frequency of the batch_end_callback
    learning_rate : float
        learning rate, will be divided by batch_size automatically
    momentum : float
        (0, 1), training momentum
    weight_decay : float
        decay weights regardless of gradient
    val_set : str
        similar to image_set, used for validation
    val_year : str
        similar to year, used for validation
    lr_refactor_epoch : int
        interval in epochs at which to rescale the learning rate
    lr_refactor_ratio : float
        new_lr = old_lr * lr_refactor_ratio
    iter_monitor : int
        if larger than 0, will print weights/gradients every iter_monitor iters
    log_file : str
        log to file if not None

    Returns:
    ----------
    None
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (data_shape, data_shape)
    assert len(data_shape) == 2, "data_shape must be (h, w) tuple or list or int"
    prefix += '_' + str(data_shape[0])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    # load dataset
    if dataset == 'pascal':
        imdb = load_pascal(image_set, year, devkit_path, cfg.TRAIN.INIT_SHUFFLE)
        if val_set and val_year:
            val_imdb = load_pascal(val_set, val_year, devkit_path, False)
        else:
            val_imdb = None
    else:
        raise NotImplementedError("Dataset " + dataset + " not supported")

    # init data iterator
    train_iter = DetIter(imdb, batch_size, data_shape, mean_pixels,
                         cfg.TRAIN.RAND_SAMPLERS, cfg.TRAIN.RAND_MIRROR,
                         cfg.TRAIN.EPOCH_SHUFFLE, cfg.TRAIN.RAND_SEED,
                         is_train=True)
    # stretch each epoch over several dataset passes so checkpoints are
    # not saved too frequently
    resize_epoch = int(cfg.TRAIN.RESIZE_EPOCH)
    if resize_epoch > 1:
        batches_per_epoch = ((imdb.num_images - 1) // batch_size + 1) * resize_epoch
        train_iter = mx.io.ResizeIter(train_iter, batches_per_epoch)
    train_iter = mx.io.PrefetchingIter(train_iter)
    if val_imdb:
        val_iter = DetIter(val_imdb, batch_size, data_shape, mean_pixels,
                           cfg.VALID.RAND_SAMPLERS, cfg.VALID.RAND_MIRROR,
                           cfg.VALID.EPOCH_SHUFFLE, cfg.VALID.RAND_SEED,
                           is_train=True)
        val_iter = mx.io.PrefetchingIter(val_iter)
    else:
        val_iter = None

    # load symbol
    sys.path.append(os.path.join(cfg.ROOT_DIR, 'symbol'))
    net = importlib.import_module("symbol_" + net).get_symbol_train(imdb.num_classes)

    # define layers with fixed weight/bias
    fixed_param_names = [name for name in net.list_arguments() \
        if name.startswith('conv1_') or name.startswith('conv2_')]

    # load pretrained or resume from previous state
    ctx_str = '('+ ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
            .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
            .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # prediction convolution layer names start with 'relu', so freezing
        # every 'conv*' parameter leaves the prediction layers trainable
        fixed_param_names = [name for name in net.list_arguments() \
            if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
            .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
            .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    iter_refactor = lr_refactor_epoch * imdb.num_images // train_iter.batch_size
    lr_scheduler = mx.lr_scheduler.FactorScheduler(iter_refactor, lr_refactor_ratio)
    optimizer_params={'learning_rate':learning_rate,
                      'momentum':momentum,
                      'wd':weight_decay,
                      'lr_scheduler':lr_scheduler,
                      'clip_gradient':None,
                      'rescale_grad': 1.0}
    monitor = mx.mon.Monitor(iter_monitor, pattern=".*") if iter_monitor > 0 else None

    mod.fit(train_iter,
            eval_data=val_iter,
            eval_metric=MultiBoxMetric(),
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=CustomInitializer(factor_type="in", magnitude=1),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
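Both the resume and finetune branches rely on mx.model.load_checkpoint, which returns the serialized symbol plus two dicts of NDArrays (weights and auxiliary states). A quick sketch for inspecting a checkpoint (prefix and epoch are placeholders):

import mxnet as mx

sym, arg_params, aux_params = mx.model.load_checkpoint('./model/ssd_300', 1)
print(sym.list_outputs())                    # network output names
for name, array in sorted(arg_params.items())[:5]:
    print(name, array.shape)                 # first few weight shapes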
Example #4
def evaluate_net(net,
                 dataset,
                 devkit_path,
                 mean_pixels,
                 data_shape,
                 model_prefix,
                 epoch,
                 ctx,
                 year=None,
                 sets='test',
                 batch_size=1,
                 nms_thresh=0.5,
                 force_nms=False):
    """
    Evaluate entire dataset, basically simple wrapper for detections

    Parameters:
    ---------
    dataset : str
        name of dataset to evaluate
    devkit_path : str
        root directory of dataset
    mean_pixels : tuple of float
        (R, G, B) mean pixel values
    data_shape : int
        resize input data shape
    model_prefix : str
        load model prefix
    epoch : int
        load model epoch
    ctx : mx.ctx
        running context, mx.cpu() or mx.gpu(0)...
    year : str or None
        evaluate on which year's data
    sets : str
        evaluation set
    batch_size : int
        using batch_size for evaluation
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : bool
        force suppress different categories
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    if dataset == "pascal":
        if not year:
            year = '2007'
        imdb = PascalVoc(sets,
                         year,
                         devkit_path,
                         shuffle=False,
                         is_train=False)
        data_iter = DetIter(imdb,
                            batch_size,
                            data_shape,
                            mean_pixels,
                            rand_samplers=[],
                            rand_mirror=False,
                            is_train=False,
                            shuffle=False)
        sys.path.append(os.path.join(cfg.ROOT_DIR, 'symbol'))
        net = importlib.import_module("symbol_" + net) \
            .get_symbol(imdb.num_classes, nms_thresh, force_nms)
        model_prefix += "_" + str(data_shape)
        detector = Detector(net, model_prefix, epoch, data_shape, mean_pixels,
                            batch_size, ctx)
        logger.info("Start evaluation with {} images, be patient...".format(
            imdb.num_images))
        detections = detector.detect(data_iter)
        imdb.evaluate_detections(detections)
    else:
        raise NotImplementedError("No support for dataset: " + dataset)
Example #5
def train_net(net,
              train_path,
              num_classes,
              batch_size,
              data_shape,
              mean_pixels,
              resume,
              finetune,
              pretrained,
              epoch,
              prefix,
              ctx,
              begin_epoch,
              end_epoch,
              frequent,
              learning_rate,
              momentum,
              weight_decay,
              lr_refactor_step,
              lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000,
              label_pad_width=350,
              nms_thresh=0.45,
              force_nms=False,
              ovp_thresh=0.5,
              use_difficult=False,
              class_names=None,
              voc07_metric=False,
              nms_topk=400,
              force_suppress=False,
              train_list="",
              val_path="",
              val_list="",
              iter_monitor=0,
              monitor_pattern=".*",
              log_file=None):
    """
    Wrapper for training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch-size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        at which epoch to rescale learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers need to be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapping objects from different classes
    ovp_thresh : float
        overlap threshold for counting true positives in AP evaluation
    use_difficult : boolean
        use difficult ground-truths in evaluation if applicable
    class_names : list of str
        class names, must correspond to num_classes if set
    voc07_metric : boolean
        use the 11-point VOC07 metric if True
    nms_topk : int
        keep at most top k detections for non-maximum suppression
    force_suppress : boolean
        force suppression of overlapping objects from different classes
    train_list : str
        list file path for training, this will replace the embedded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    """
    # the membership test must be grouped: in the original or-chain, 'and'
    # bound tighter than 'or', so the resume/pretrained checks only applied
    # to the last network name
    if net in ('resnet101_two_stream', 'resnetsub101_test',
               'resnetsub101_one_shared', 'resnetsub101_two_shared',
               'resnet50_two_stream_w_four_layers') \
            and resume == -1 and pretrained is not False:
        convert_model = True
    else:
        convert_model = False

    if net == 'resnet50_two_stream' \
            and resume == -1 \
            and pretrained is not False:
        convert_model_concat = True
    else:
        convert_model_concat = False

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    num_channel = 3
    if isinstance(data_shape, int):
        data_shape = (num_channel, data_shape, data_shape)
    if isinstance(data_shape, list):
        data_shape = (num_channel, data_shape[0], data_shape[1])
    #assert len(data_shape) == 3 and data_shape[0] == 3
    if prefix.endswith('_'):
        prefix += '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    #train_iter = DetRecordIter(train_path, batch_size, data_shape, mean_pixels=mean_pixels,
    #    label_pad_width=label_pad_width, path_imglist=train_list, **cfg.train)

    # load imdb
    curr_path = os.path.abspath(os.path.dirname(__file__))
    imdb_train = load_caltech(image_set='train',
                              caltech_path=os.path.join(
                                  curr_path, '..', 'data',
                                  'caltech-pedestrian-dataset-converter'),
                              shuffle=True)
    train_iter = DetIter(imdb_train, batch_size, (data_shape[1], data_shape[2]), \
                         mean_pixels=mean_pixels, rand_samplers=[], \
                         rand_mirror=False, shuffle=False, rand_seed=None, \
                         is_train=True, max_crop_trial=50)

    if val_path:
        #val_iter = DetRecordIter(val_path, batch_size, data_shape, mean_pixels=mean_pixels,
        #    label_pad_width=label_pad_width, path_imglist=val_list, **cfg.valid)
        imdb_val = load_caltech(image_set='val',
                                caltech_path=os.path.join(
                                    curr_path, '..', 'data',
                                    'caltech-pedestrian-dataset-converter'),
                                shuffle=False)
        val_iter = DetIter(imdb_val, batch_size, (data_shape[1], data_shape[2]), \
                           mean_pixels=mean_pixels, rand_samplers=[], \
                           rand_mirror=False, shuffle=False, rand_seed=None, \
                           is_train=True, max_crop_trial=50)
    else:
        val_iter = None

    # load symbol
    #net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
    net = get_symbol_train_concat(net,
                                  data_shape[1],
                                  num_classes=num_classes,
                                  nms_thresh=nms_thresh,
                                  force_suppress=force_suppress,
                                  nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [
            name for name in net.list_arguments() if re_prog.match(name)
        ]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}".format(
            ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}".format(
            ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # check what layers mismatch with the loaded parameters
        exe = net.simple_bind(mx.cpu(),
                              data=(1, 3, 300, 300),
                              label=(1, 1, 5),
                              grad_req='null')
        arg_dict = exe.arg_dict
        fixed_param_names = []
        for k, v in arg_dict.items():
            if k in args:
                if v.shape != args[k].shape:
                    del args[k]
                    logging.info("Removed %s" % k)
                else:
                    if 'pred' not in k:
                        fixed_param_names.append(k)
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}".format(
            ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        if convert_model:
            args = convert_pretrained(pretrained, args)
        if convert_model_concat:
            args, auxs = convert_pretrained_concat(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}".format(
            ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Freezed parameters: [" + ','.join(fixed_param_names) +
                    ']')

    # init training module
    #mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
    mod = mx.mod.Module(net,
                        label_names=('label', 'label2'),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size,
                                                 frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate,
                                                   lr_refactor_step,
                                                   lr_refactor_ratio,
                                                   num_example, batch_size,
                                                   begin_epoch)
    optimizer_params = {
        'learning_rate': learning_rate,
        'momentum': momentum,
        'wd': weight_decay,
        'lr_scheduler': lr_scheduler,
        'clip_gradient': None,
        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0
    }
    monitor = mx.mon.Monitor(
        iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit net, every n epochs we run evaluation network to get mAP
    if voc07_metric:
        #valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)
        valid_metric = VOC07MApMetric(
            ovp_thresh,
            use_difficult,
            class_names,
            pred_idx=[0, 1],
            output_names=['det_out_output', 'det_out2_output'],
            label_names=['label', 'label2'])
    else:
        #valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)
        valid_metric = MApMetric(
            ovp_thresh,
            use_difficult,
            class_names,
            pred_idx=[0, 1],
            output_names=['det_out_output', 'det_out2_output'],
            label_names=['label', 'label2'])

    # the messenger is activated in base_module
    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor)
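The freeze_layer_pattern logic above is plain regex matching over the symbol's argument names. A self-contained sketch (the parameter names are made up for illustration):

import re

arg_names = ['conv1_1_weight', 'conv1_1_bias', 'conv2_1_weight',
             'fc7_weight', 'multibox_loc_pred_conv_weight']
pattern = re.compile('^(conv1_|conv2_)')     # e.g. freeze the earliest stages
fixed_param_names = [name for name in arg_names if pattern.match(name)]
print(fixed_param_names)
# ['conv1_1_weight', 'conv1_1_bias', 'conv2_1_weight']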
Example #6
def evaluate_net(net,
                 path_imgrec,
                 num_classes,
                 mean_pixels,
                 data_shape,
                 model_prefix,
                 epoch,
                 ctx=mx.cpu(),
                 batch_size=1,
                 path_imglist="",
                 nms_thresh=0.45,
                 force_nms=False,
                 ovp_thresh=0.5,
                 use_difficult=False,
                 class_names=None,
                 voc07_metric=False,
                 use_second_network=False,
                 net1=None,
                 path_imgrec1=None,
                 epoch1=None,
                 model_prefix1=None,
                 data_shape1=None):
    """
    Evaluate a network given a validation record file.

    Parameters:
    ----------
    net : str or None
        Network name or use None to load from json without modifying
    path_imgrec : str
        path to the record validation file
    path_imglist : str
        path to the list file to replace labels in record file, optional
    num_classes : int
        number of classes, not including background
    mean_pixels : tuple
        (mean_r, mean_g, mean_b)
    data_shape : tuple or int
        (3, height, width) or height/width
    model_prefix : str
        model prefix of saved checkpoint
    epoch : int
        load model epoch
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    voc07_metric : boolean
        whether to use 11-point evaluation as in the VOC07 competition
    use_second_network : boolean
        whether to additionally evaluate a second network
    net1, path_imgrec1, model_prefix1, epoch1, data_shape1 :
        counterparts of net, path_imgrec, model_prefix, epoch and
        data_shape for the optional second network
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    elif isinstance(data_shape, list):
        data_shape = (3, data_shape[0], data_shape[1])
    assert len(data_shape) == 3 and data_shape[0] == 3
    # model_prefix += '_' + str(data_shape[1])

    # iterator
    #eval_iter = DetRecordIter(path_imgrec, batch_size, data_shape,
    #                          path_imglist=path_imglist, **cfg.valid)
    curr_path = os.path.abspath(os.path.dirname(__file__))
    imdb_val = load_caltech(image_set='val',
                            caltech_path=os.path.join(
                                curr_path, '..', 'data',
                                'caltech-pedestrian-dataset-converter'),
                            shuffle=False)
    eval_iter = DetIter(imdb_val, batch_size, (data_shape[1], data_shape[2]), \
                       mean_pixels=[128, 128, 128], rand_samplers=[], \
                       rand_mirror=False, shuffle=False, rand_seed=None, \
                       is_train=True, max_crop_trial=50)
    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    # network
    if net is None:
        net = load_net
    else:
        #net = get_symbol(net, data_shape[1], num_classes=num_classes,
        net = get_symbol_concat(net,
                                data_shape[1],
                                num_classes=num_classes,
                                nms_thresh=nms_thresh,
                                force_suppress=force_nms)
    if 'label' not in net.list_arguments():
        label = mx.sym.Variable(name='label')
        label2 = mx.sym.Variable(name='label2')
        net = mx.sym.Group([net, label, label2])

    # init module
    #mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
    mod = mx.mod.Module(net,
                        label_names=('label', 'label2'),
                        logger=logger,
                        context=ctx,
                        fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data,
             label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=False, force_init=True)

    if voc07_metric:
        #metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=1)
        metric = VOC07MApMetric(
            ovp_thresh,
            use_difficult,
            class_names,
            pred_idx=[0, 1],
            output_names=['detection_output', 'detection2_output'],
            label_names=['label', 'label2'])
    else:
        #metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=1)
        metric = MApMetric(
            ovp_thresh,
            use_difficult,
            class_names,
            pred_idx=[0, 1],
            output_names=['detection_output', 'detection2_output'],
            label_names=['label', 'label2'])

    # run evaluation
    if not use_second_network:
        results = mod.score(eval_iter, metric, num_batch=None)
        for k, v in results:
            print("{}: {}".format(k, v))
    else:
        logging.basicConfig()
        logger1 = logging.getLogger()
        logger1.setLevel(logging.INFO)

        # load sub network
        if isinstance(data_shape1, int):
            data_shape1 = (3, data_shape1, data_shape1)
        elif isinstance(data_shape1, list):
            data_shape1 = (3, data_shape1[0], data_shape1[1])
        assert len(data_shape1) == 3 and data_shape1[0] == 3

        # iterator
        eval_iter1 = DetRecordIter(path_imgrec1,
                                   batch_size,
                                   data_shape1,
                                   path_imglist=path_imglist,
                                   **cfg.valid)
        # model params
        load_net1, args1, auxs1 = mx.model.load_checkpoint(
            model_prefix1, epoch1)
        # network
        if net1 is None:
            net1 = load_net1
        else:
            net1 = net
        if 'label' not in net1.list_arguments():
            label1 = mx.sym.Variable(name='label')
            net1 = mx.sym.Group([net1, label1])

        # init module
        mod1 = mx.mod.Module(net1,
                             label_names=('label', ),
                             logger=logger1,
                             context=ctx,
                             fixed_param_names=net1.list_arguments())
        mod1.bind(data_shapes=eval_iter1.provide_data,
                  label_shapes=eval_iter1.provide_label)
        mod1.set_params(args1, auxs1, allow_missing=False, force_init=True)

        if voc07_metric:
            metric1 = VOC07MApMetric(ovp_thresh, use_difficult, class_names)
        else:
            metric1 = MApMetric(ovp_thresh, use_difficult, class_names)

        # filepath = '/home/binghao/workspace/MXNet-SSD/matlab/kitti/outputs/ssd/'
        filepath1 = '/home/binghao/workspace/MXNet-SSD/matlab/kitti/outputs/ssd_small/'
        # mod.score_m(filepath, eval_iter, metric, num_batch=None)
        mod1.score_m(filepath1, eval_iter1, metric1, num_batch=None)
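For reference, the single-network path above reduces to the standard Module evaluation skeleton (checkpoint names are placeholders; eval_iter and metric stand in for the iterator and mAP metric built earlier):

import mxnet as mx

sym, args, auxs = mx.model.load_checkpoint('./model/ssd_300', 1)
mod = mx.mod.Module(sym, label_names=('label',), context=mx.cpu())
mod.bind(data_shapes=eval_iter.provide_data,         # eval_iter as built above
         label_shapes=eval_iter.provide_label,
         for_training=False)
mod.set_params(args, auxs, allow_missing=False)
for name, value in mod.score(eval_iter, metric):     # metric as built above
    print('{}: {}'.format(name, value))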