예제 #1
0
def get_dataset(prefetch=False):
    """Return the train and test image-pair iterators, fetching the data if needed.

    When ``<dataset_path>/BSDS300/images`` does not exist, the archive at
    ``dataset_url`` is downloaded, unpacked into ``dataset_path`` and the
    downloaded archive file is removed again.

    :param prefetch: when true, wrap every iterator in a ``PrefetchingIter``
    :return: tuple ``(train_iter, test_iter)`` of ``ImagePairIter`` objects, or
        a two-element list of ``PrefetchingIter`` wrappers when ``prefetch``
        is set
    """
    image_path = os.path.join(dataset_path, "BSDS300/images")

    if not os.path.exists(image_path):
        # dataset_path itself may already be present (created by hand, or
        # left behind by an interrupted run). os.makedirs raises on an
        # existing directory, so only create it when it is really missing.
        if not os.path.isdir(dataset_path):
            os.makedirs(dataset_path)
        file_name = download(dataset_url)
        with tarfile.open(file_name) as tar:
            tar.extractall(dataset_path)
        # The archive is only needed for extraction.
        os.remove(file_name)

    # Shrink the target crop to a multiple of the upscale factor so that the
    # low-resolution input size is an exact integer division of it.
    crop_size = 256
    crop_size -= crop_size % upscale_factor
    input_crop_size = crop_size // upscale_factor

    # Inputs: CenterCropAug to the target size, then ResizeAug down to the
    # input size. Targets: CenterCropAug only.
    input_transform = [
        CenterCropAug((crop_size, crop_size)),
        ResizeAug(input_crop_size)
    ]
    target_transform = [CenterCropAug((crop_size, crop_size))]

    iters = (ImagePairIter(os.path.join(image_path, "train"),
                           (input_crop_size, input_crop_size),
                           (crop_size, crop_size), batch_size, color_flag,
                           input_transform, target_transform),
             ImagePairIter(os.path.join(image_path, "test"),
                           (input_crop_size, input_crop_size),
                           (crop_size, crop_size), test_batch_size, color_flag,
                           input_transform, target_transform))

    return [PrefetchingIter(i) for i in iters] if prefetch else iters
예제 #2
0
def get_dataset(prefetch=False):
    """Fetch the BSDS500 data on first use and build the train/test iterators.

    If ``data_dir`` is missing, the archive at ``dataset_url`` is downloaded
    into ``datasets_tmpdir``, unzipped into ``tmp_dir``, and its ``images``
    and ``groundTruth`` folders are copied into ``data_dir``; both temporary
    directories are removed afterwards. If ``data_dir`` already exists, a
    notice is written to stderr and nothing is downloaded.

    Returns a tuple ``(train_iter, test_iter)`` of ``ImagePairIter`` objects,
    or a list of ``PrefetchingIter`` wrappers around them when ``prefetch``
    is true.
    """

    if not path.exists(data_dir):
        print("Downloading dataset...", file=sys.stderr)
        archive_path = download(dataset_url, dirname=datasets_tmpdir)
        print("done", file=sys.stderr)

        print("Extracting files...", end="", file=sys.stderr)
        os.makedirs(data_dir)
        os.makedirs(tmp_dir)
        with zipfile.ZipFile(archive_path) as zipped:
            zipped.extractall(tmp_dir)
        # The download area is no longer needed once the archive is unpacked.
        shutil.rmtree(datasets_tmpdir)

        # Only these two folders of the unpacked repository are kept.
        extracted_root = path.join(tmp_dir, "BSDS500-master", "BSDS500",
                                   "data")
        for folder in ("images", "groundTruth"):
            shutil.copytree(
                path.join(extracted_root, folder),
                path.join(data_dir, folder),
            )
        shutil.rmtree(tmp_dir)
        print("done", file=sys.stderr)
    else:
        notice = (
            "Directory {} already exists, skipping.\n"
            "To force download and extraction, delete the directory and re-run."
            ""
        )
        print(notice.format(data_dir), file=sys.stderr)

    # Target edge length, rounded down to a multiple of the upscale factor;
    # the input edge length is the target divided by that factor.
    target_edge = 256
    target_edge -= target_edge % upscale_factor
    input_edge = target_edge // upscale_factor
    target_shape = (target_edge, target_edge)
    input_shape = (input_edge, input_edge)

    input_transform = [CenterCropAug(target_shape), ResizeAug(input_edge)]
    target_transform = [CenterCropAug(target_shape)]

    def build_iter(subset, size):
        # One ImagePairIter over <data_dir>/images/<subset>.
        return ImagePairIter(
            path.join(data_dir, "images", subset),
            input_shape,
            target_shape,
            size,
            color_flag,
            input_transform,
            target_transform,
        )

    train_iter = build_iter("train", batch_size)
    test_iter = build_iter("test", test_batch_size)
    iters = (train_iter, test_iter)

    if prefetch:
        return [PrefetchingIter(it) for it in iters]
    return iters
예제 #3
0
def pred_eval(predictor,
              test_data,
              imdb,
              cfg,
              vis=False,
              thresh=1e-3,
              logger=None,
              ignore_cache=True):
    """
    Run detection over a test set, store the detections and evaluate them.

    Detections are written to ``<imdb.result_path>/<imdb.name>_detections.pkl``.
    When that file already exists and ``ignore_cache`` is false, the stored
    detections are evaluated directly and no prediction is run.

    :param predictor: Predictor, passed through to ``im_detect``
    :param test_data: data iterator, must be non-shuffle (unless ``vis``)
    :param imdb: image database; supplies ``num_images``, ``num_classes``,
        ``classes``, ``result_path``, ``name`` and ``evaluate_detections``
    :param cfg: configuration object (``TEST.NMS``, ``TEST.max_per_image``,
        ``CLASS_AGNOSTIC`` and ``network.PREDICT_KEYPOINTS`` are read here)
    :param vis: controls visualization
    :param thresh: valid detection threshold; only scores strictly above it
        are kept
    :param logger: optional logger that receives progress and result messages
    :param ignore_cache: when true, always re-run detection even if a
        detection file from an earlier run exists
    :return: None
    """

    det_file = os.path.join(imdb.result_path, imdb.name + '_detections.pkl')
    if os.path.exists(det_file) and not ignore_cache:
        # Re-use the detections stored by an earlier run.
        # NOTE: unpickling can execute arbitrary code, so det_file must come
        # from a trusted location.
        with open(det_file, 'rb') as fid:
            cache_res = cPickle.load(fid)
            all_boxes = cache_res['all_boxes']
            all_keypoints = cache_res.get('all_keypoints')
        info_str = imdb.evaluate_detections(all_boxes,
                                            all_keypoints=all_keypoints)
        if logger:
            logger.info('evaluate detections: \n{}'.format(info_str))
        return

    assert vis or not test_data.shuffle
    data_names = [k[0] for k in test_data.provide_data]

    if not isinstance(test_data, PrefetchingIter):
        test_data = PrefetchingIter(test_data)

    nms = py_nms_wrapper(cfg.TEST.NMS)

    # limit detections to max_per_image over all classes
    max_per_image = cfg.TEST.max_per_image

    num_images = imdb.num_images
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[np.array([]) for _ in range(num_images)]
                 for _ in range(imdb.num_classes)]
    all_keypoints = None
    if cfg.network.PREDICT_KEYPOINTS:
        # Same [cls][image] layout as all_boxes.
        all_keypoints = [[np.array([]) for _ in range(num_images)]
                         for _ in range(imdb.num_classes)]

    idx = 0  # index of the first image of the current batch
    data_time, net_time, post_time = 0.0, 0.0, 0.0
    t = time.time()
    for data_batch in test_data:
        t1 = time.time() - t  # time spent fetching the batch
        t = time.time()

        rets = im_detect(predictor, data_batch, data_names, cfg)
        scores_all = rets[0]
        boxes_all = rets[1]
        if cfg.network.PREDICT_KEYPOINTS:
            pred_kps_all = rets[2]

        t2 = time.time() - t  # time spent in im_detect
        t = time.time()
        for delta, (scores, boxes) in enumerate(zip(scores_all, boxes_all)):
            image_idx = idx + delta
            if image_idx >= num_images:
                # The batch holds more entries than there are images left.
                break
            # Class index 0 is skipped throughout.
            for j in range(1, imdb.num_classes):
                indexes = np.where(scores[:, j] > thresh)[0]
                cls_scores = scores[indexes, j, np.newaxis]
                if cfg.CLASS_AGNOSTIC:
                    # Class-agnostic mode always reads box columns 4:8.
                    cls_boxes = boxes[indexes, 4:8]
                else:
                    # Otherwise each class j owns columns j*4 .. (j+1)*4.
                    cls_boxes = boxes[indexes, j * 4:(j + 1) * 4]
                cls_dets = np.hstack((cls_boxes, cls_scores))
                keep = nms(cls_dets)
                all_boxes[j][image_idx] = cls_dets[keep, :]
                if cfg.network.PREDICT_KEYPOINTS:
                    # Apply the same threshold and NMS selection to keypoints.
                    all_keypoints[j][image_idx] = pred_kps_all[delta][
                        indexes, :][keep, :]

            if max_per_image > 0:
                image_scores = np.hstack([
                    all_boxes[j][image_idx][:, -1]
                    for j in range(1, imdb.num_classes)
                ])
                if len(image_scores) > max_per_image:
                    # Score of the max_per_image-th best detection; anything
                    # scoring below it is dropped from every class.
                    image_thresh = np.sort(image_scores)[-max_per_image]
                    for j in range(1, imdb.num_classes):
                        keep = np.where(
                            all_boxes[j][image_idx][:, -1] >= image_thresh)[0]
                        all_boxes[j][image_idx] = all_boxes[j][image_idx][
                            keep, :]
                        if cfg.network.PREDICT_KEYPOINTS:
                            all_keypoints[j][image_idx] = all_keypoints[j][
                                image_idx][keep, :]

            if vis:
                boxes_this_image = [[]] + [
                    all_boxes[j][image_idx]
                    for j in range(1, imdb.num_classes)
                ]
                # NOTE(review): `data_dict` and `scales` are not assigned
                # anywhere in this function; unless they exist as module-level
                # names, vis=True raises NameError here - confirm and fix
                # against what im_detect returns.
                vis_all_detection(data_dict['data'].asnumpy(),
                                  boxes_this_image, imdb.classes,
                                  scales[delta], cfg)

        idx += test_data.batch_size
        t3 = time.time() - t  # time spent on post-processing
        t = time.time()
        msg = 'testing {}/{} data {:.4f}s net {:.4f}s post {:.4f}s'.format(
            idx, imdb.num_images, t1, t2, t3)
        # Parenthesised form prints the same text on Python 2 and Python 3.
        print(msg)
        if logger:
            logger.info(msg)

    with open(det_file, 'wb') as f:
        cPickle.dump({
            'all_boxes': all_boxes,
            'all_keypoints': all_keypoints
        },
                     f,
                     protocol=cPickle.HIGHEST_PROTOCOL)

    info_str = imdb.evaluate_detections(all_boxes, all_keypoints=all_keypoints)
    if logger:
        logger.info('evaluate detections: \n{}'.format(info_str))
예제 #4
0
def train_net(args, ctx, pretrained, epoch, prefix, begin_epoch, end_epoch, lr, lr_step):
    """
    Build the data loader, module, metrics and optimizer settings, then train.

    :param args: parsed command-line arguments; ``args.cfg`` (config file
        path) and ``args.frequent`` (Speedometer frequency) are read here
    :param ctx: list of device contexts; the batch size is
        ``config.TRAIN.IMAGES_PER_GPU * len(ctx)``
    :param pretrained: prefix of the pretrained parameters loaded when not
        resuming
    :param epoch: epoch of the pretrained parameters to load
    :param prefix: checkpoint name prefix; it is joined onto the output path
    :param begin_epoch: epoch to start from (also the epoch whose parameters
        and optimizer states are loaded when ``config.TRAIN.RESUME`` is set)
    :param end_epoch: passed to ``fit`` as ``num_epoch``
    :param lr: base learning rate
    :param lr_step: comma-separated epochs at which the learning rate is
        multiplied by ``config.TRAIN.lr_factor``
    :return: None
    """
    if config.dataset.dataset != 'JSONList':
        logger, final_output_path = create_logger(config.output_path, args.cfg, config.dataset.image_set)
        prefix = os.path.join(final_output_path, prefix)
        shutil.copy2(args.cfg, prefix+'.yaml')
    else:
        import datetime
        import logging
        final_output_path = config.output_path
        # Time-stamp the prefix before joining it onto the output path.
        prefix = prefix + '_' + datetime.datetime.now().strftime("%Y-%m-%d_%H_%M_%S")
        prefix = os.path.join(final_output_path, prefix)
        shutil.copy2(args.cfg, prefix+'.yaml')
        log_file = prefix + '.log'
        head = '%(asctime)-15s %(message)s'
        logging.basicConfig(filename=log_file, format=head)
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        logger.info('prefix: %s' % prefix)
        print('prefix: %s' % prefix)

    # load symbol
    #shutil.copy2(os.path.join(curr_path, '..', 'symbols', config.symbol + '.py'), final_output_path)
    # NOTE(security): eval() runs a string taken from the config; config.symbol
    # must come from a trusted configuration file.
    sym_instance = eval(config.symbol + '.' + config.symbol)()
    sym = sym_instance.get_symbol(config, is_train=True)

    # setup multi-gpu
    batch_size = config.TRAIN.IMAGES_PER_GPU * len(ctx)

    # print config
    pprint.pprint(config)
    logger.info('training config:{}\n'.format(pprint.pformat(config)))

    # load dataset and prepare imdb for training
    image_sets = config.dataset.image_set.split('+')
    roidbs = [load_gt_roidb(config.dataset.dataset, image_set, config.dataset.root_path, config.dataset.dataset_path,
                            flip=config.TRAIN.FLIP)
              for image_set in image_sets]
    roidb = merge_roidb(roidbs)
    roidb = filter_roidb(roidb, config)
    # load training data
    if config.network.MULTI_RPN:
        # Development guard: this branch is not finished. (The original
        # spelled the constant `Fasle`, which raised NameError instead of the
        # intended AssertionError.)
        assert False, 'still developing' ###
        num_layers = len(config.network.MULTI_RPN_STRIDES)
        rpn_syms = [sym.get_internals()['rpn%d_cls_score_output'% l] for l in range(num_layers)]
        train_data = PyramidAnchorLoader(rpn_syms, roidb, config, batch_size=batch_size, shuffle=config.TRAIN.SHUFFLE,
                                         ctx=ctx, feat_strides=config.network.MULTI_RPN_STRIDES, anchor_scales=config.network.ANCHOR_SCALES,
                                         anchor_ratios=config.network.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING,
                                         allowed_border=np.inf)
    else:
        feat_sym = sym.get_internals()['rpn_cls_score_output']
        train_data = AnchorLoader(feat_sym, roidb, config, batch_size=batch_size, shuffle=config.TRAIN.SHUFFLE, ctx=ctx,
                                  feat_stride=config.network.RPN_FEAT_STRIDE, anchor_scales=config.network.ANCHOR_SCALES,
                                  anchor_ratios=config.network.ANCHOR_RATIOS, aspect_grouping=config.TRAIN.ASPECT_GROUPING)

    # infer max shape
    data_shape_dict = dict(train_data.provide_data + train_data.provide_label)
    pprint.pprint(data_shape_dict)
    sym_instance.infer_shape(data_shape_dict)

    # load and initialize params
    if config.TRAIN.RESUME:
        print('continue training from ', begin_epoch)
        arg_params, aux_params = load_param(prefix, begin_epoch, convert=True)
    else:
        arg_params, aux_params = load_param(pretrained, epoch, convert=True)
        sym_instance.init_weight(config, arg_params, aux_params)

    # check parameter shapes
    sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict)

    # create solver
    fixed_param_prefix = config.network.FIXED_PARAMS
    mod = MutableModule(sym,
                        train_data.data_names,
                        train_data.label_names,
                        context=ctx,
                        logger=logger,
                        fixed_param_prefix=fixed_param_prefix)

    if config.TRAIN.RESUME:
        # Point the module at the optimizer states saved for begin_epoch.
        mod._preload_opt_states = '%s-%04d.states'%(prefix, begin_epoch)

    # decide training params
    # metric
    rpn_eval_metric = metric.RPNAccMetric()
    rpn_cls_metric = metric.RPNLogLossMetric()
    rpn_bbox_metric = metric.RPNL1LossMetric()
    eval_metric = metric.RCNNAccMetric(config)
    cls_metric = metric.RCNNLogLossMetric(config)
    bbox_metric = metric.RCNNL1LossMetric(config)
    eval_metrics = mx.metric.CompositeEvalMetric()
    # rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric
    for child_metric in [rpn_eval_metric, rpn_cls_metric, rpn_bbox_metric, eval_metric, cls_metric, bbox_metric]:
        eval_metrics.add(child_metric)
    if config.network.PREDICT_KEYPOINTS:
        kps_cls_acc = metric.KeypointAccMetric(config)
        kps_cls_loss = metric.KeypointLogLossMetric(config)
        kps_pos_loss = metric.KeypointL1LossMetric(config)
        eval_metrics.add(kps_cls_acc)
        eval_metrics.add(kps_cls_loss)
        eval_metrics.add(kps_pos_loss)

    # callback
    batch_end_callback = callback.Speedometer(batch_size, frequent=args.frequent)
    # NOTE(review): means/stds are computed but never used below; presumably
    # they were meant for a checkpoint callback - confirm before removing.
    means = np.tile(np.array(config.TRAIN.BBOX_MEANS), 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    stds = np.tile(np.array(config.TRAIN.BBOX_STDS), 2 if config.CLASS_AGNOSTIC else config.dataset.NUM_CLASSES)
    epoch_end_callback = [mx.callback.do_checkpoint(prefix)]
    # decide learning rate
    base_lr = lr
    lr_factor = config.TRAIN.lr_factor
    # Comprehension variables are deliberately not called `epoch`: under
    # Python 2 scoping that would overwrite the `epoch` parameter.
    lr_epoch = [float(step) for step in lr_step.split(',')]
    # Steps still ahead of begin_epoch, expressed relative to it.
    lr_epoch_diff = [step - begin_epoch for step in lr_epoch if step > begin_epoch]
    # Apply one lr_factor for every step that is not ahead of begin_epoch.
    lr = base_lr * (lr_factor ** (len(lr_epoch) - len(lr_epoch_diff)))
    # Convert the remaining steps from epochs to iteration counts.
    lr_iters = [int(step * len(roidb) / batch_size) for step in lr_epoch_diff]
    print('lr', lr, 'lr_epoch_diff', lr_epoch_diff, 'lr_iters', lr_iters)
    lr_scheduler = WarmupMultiFactorScheduler(lr_iters, lr_factor, config.TRAIN.warmup, config.TRAIN.warmup_lr, config.TRAIN.warmup_step)
    # optimizer
    optimizer_params = {'momentum': config.TRAIN.momentum,
                        'wd': config.TRAIN.wd,
                        'learning_rate': lr,
                        'lr_scheduler': lr_scheduler,
                        'rescale_grad': 1.0,
                        'clip_gradient': None}

    if not isinstance(train_data, PrefetchingIter):
        train_data = PrefetchingIter(train_data)

    # train
    mod.fit(train_data, eval_metric=eval_metrics, epoch_end_callback=epoch_end_callback,
            batch_end_callback=batch_end_callback, kvstore=config.TRAIN.kvstore,
            optimizer='sgd', optimizer_params=optimizer_params,
            arg_params=arg_params, aux_params=aux_params, begin_epoch=begin_epoch, num_epoch=end_epoch)
    # Wait briefly, then shut down the first iterator held by the prefetching
    # wrapper.
    time.sleep(10)
    train_data.iters[0].terminate()