예제 #1
0
파일: demo.py 프로젝트: coderzbx/ademxapp
def _get_transformer_image():
    """Build the list of per-image colour transformers.

    The scale, mean and std values come from ``_get_scalemeanstd()``.
    A ``ts.ColorScale`` step is included only when the scale is positive;
    a ``ts.ColorNormalize`` step is always the last element.

    Returns:
        list: transformer objects in the order they should be applied.
    """
    scale, mean_, std_ = _get_scalemeanstd()
    # Optional leading step: only present for a positive scale factor.
    scaling_steps = [ts.ColorScale(np.single(scale))] if scale > 0 else []
    return scaling_steps + [ts.ColorNormalize(mean_, std_)]
예제 #2
0
def _val_impl(args, model_specs, logger):
    """Evaluate a model over one or more crop sizes and log top-1/top-5.

    For every crop size the per-sample class scores are either loaded from
    a cache file ``preds_sz<crop_size>`` inside the save directory, or
    computed with the module and then written to that cache file.  After
    all crop sizes are processed, accuracy is logged for each single crop
    size and for the average of all crop sizes together.

    Args:
        args: Parsed options.  Attributes read here: ``prefetch_threads``,
            ``weights``, ``split``, ``test_scales``, ``batch_images``,
            ``data_root``, ``output``, ``log_file``, ``test_3crops``,
            ``test_flipping`` and ``prefetcher``.
        model_specs: Mapping providing ``'crop_size'``,
            ``'split_filename'`` and ``'classes'``.
        logger: Logger used for progress and result messages (``info``).

    Raises:
        AssertionError: If ``args.prefetch_threads`` is not 1 or
            ``args.weights`` is ``None``.
        NotImplementedError: If the split is neither ``'train'`` nor
            ``'val'`` (no ground truth); raised only after inference has
            finished, because saving predictions is not implemented.
    """
    assert args.prefetch_threads == 1
    assert args.weights is not None
    net_args, net_auxs = util.load_params_from_file(args.weights)
    mod = _get_module(model_specs)
    # Ground-truth labels are only assumed to exist for these splits.
    has_gt = args.split in (
        'train',
        'val',
    )
    scale_, mean_, std_ = _get_scalemeanstd()
    if args.test_scales is None:
        crop_sizes = [model_specs['crop_size']]
    else:
        # Comma-separated integers, processed from largest to smallest.
        crop_sizes = sorted([int(_)
                             for _ in args.test_scales.split(',')])[::-1]

    batch_images = args.batch_images

    if has_gt:
        gt_labels = np.array(
            parse_split_file(model_specs['split_filename'], args.data_root)[1])
    save_dir = os.path.join(args.output, os.path.splitext(args.log_file)[0])
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    preds = []
    for crop_size in crop_sizes:
        save_path = os.path.join(save_dir, 'preds_sz{}'.format(crop_size))
        if os.path.isfile(save_path):
            logger.info('File %s exists, skipped crop size %d', save_path,
                        crop_size)
            # The cache is written below in binary mode ('wb'); it must be
            # read back in binary mode too, otherwise newline translation
            # can corrupt the pickle stream.
            # NOTE(review): unpickling executes arbitrary code -- only safe
            # as long as the save directory is trusted.
            with open(save_path, 'rb') as f:
                preds.append(cPickle.load(f))
            continue
        ts_list = [
            ts.Scale(crop_size),
            ts.ThreeCrops(crop_size)
            if args.test_3crops else ts.CenterCrop(crop_size),
        ]
        if scale_ > 0:
            ts_list.append(ts.ListInput(ts.ColorScale(np.single(scale_))))
        ts_list += [ts.ListInput(ts.ColorNormalize(mean_, std_))]
        transformer = ts.Compose(ts_list)
        dataiter = FileIter(
            split_filename=model_specs['split_filename'],
            data_root=args.data_root,
            has_gt=has_gt,
            batch_images=batch_images,
            transformer=transformer,
            prefetch_threads=args.prefetch_threads,
            prefetcher_type=args.prefetcher,
        )
        dataiter.reset()
        # Input shapes change with the crop size, so rebind every time.
        mod.bind(dataiter.provide_data,
                 dataiter.provide_label,
                 for_training=False,
                 force_rebind=True)
        if not mod.params_initialized:
            mod.init_params(arg_params=net_args, aux_params=net_auxs)
        this_call_preds = []
        start = time.time()
        # counter[0]: samples seen so far, counter[1]: correct top-1 so far.
        counter = [0, 0]
        for nbatch, batch in enumerate(dataiter):
            mod.forward(batch, is_train=False)
            outputs = mod.get_outputs()[0].asnumpy()
            # Collapse the middle axis by averaging; presumably this merges
            # the scores of multiple crops per image -- TODO confirm.
            outputs = outputs.reshape(
                (batch_images, -1, model_specs['classes'])).mean(1)
            this_call_preds.append(outputs)
            if args.test_flipping:
                # Second pass on the input flipped along axis 3 (presumably
                # the width axis -- TODO confirm), averaged with the first.
                batch.data[0] = mx.nd.flip(batch.data[0], axis=3)
                mod.forward(batch, is_train=False)
                outputs = mod.get_outputs()[0].asnumpy()
                outputs = outputs.reshape(
                    (batch_images, -1, model_specs['classes'])).mean(1)
                this_call_preds[-1] = (this_call_preds[-1] + outputs) / 2
            score_str = ''
            if has_gt:
                # Running top-1; assumes the iterator yields samples in the
                # same order as the split file -- verify against FileIter.
                counter[0] += batch_images
                counter[1] += (this_call_preds[-1].argmax(1) ==
                               gt_labels[nbatch * batch_images:(nbatch + 1) *
                                         batch_images]).sum()
                score_str = ', Top1 {:.4f}%'.format(100.0 * counter[1] /
                                                    counter[0])
            logger.info('Crop size {}, done {}/{} at speed: {:.2f}/s{}'.\
                format(crop_size, nbatch+1, dataiter.batches_per_epoch, 1.*(nbatch+1)*batch_images / (time.time()-start), score_str))
        logger.info('Done crop size {} in {:.4f}s.'.format(
            crop_size,
            time.time() - start))
        this_call_preds = np.vstack(this_call_preds)
        with open(save_path, 'wb') as f:
            cPickle.dump(this_call_preds, f)
        preds.append(this_call_preds)
    # Report each crop size alone and all crop sizes combined; the set
    # removes the duplicate when there is only one crop size.
    for num_sizes in set((
            1,
            len(crop_sizes),
    )):
        for this_pred_inds in itertools.combinations(xrange(len(crop_sizes)),
                                                     num_sizes):
            this_pred = np.mean([preds[_] for _ in this_pred_inds], axis=0)
            # Indices of the five highest-scoring classes, best first.
            this_pred_label = this_pred.argsort(1)[:, -1 - np.arange(5)]
            logger.info('Done testing crop_size %s',
                        [crop_sizes[_] for _ in this_pred_inds])
            if has_gt:
                top1 = 100. * (this_pred_label[:, 0]
                               == gt_labels).sum() / gt_labels.size
                top5 = 100. * sum(
                    map(lambda x, y: y in x.tolist(), this_pred_label,
                        gt_labels)) / gt_labels.size
                logger.info('Top1 %.4f%%, Top5 %.4f%%', top1, top5)
            else:
                # TODO: Save predictions for submission
                raise NotImplementedError('Save predictions for submission')