def test_single_weights_file(weights_file, net, input_dataset):
    """Test already initalized net with a new set of weights"""
    net.copy_from(weights_file)
    net.layers[0].generate_datum_ids()

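    # Sanity-check the data layer: it should expose one datum per annotated
    # object, with its cursor reset to the first data id.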
    input_num_of_objects = sum([len(image_info['object_infos']) for image_info in input_dataset.image_infos()])
    assert net.layers[0].curr_data_ids_idx == 0
    assert net.layers[0].number_of_datapoints() == input_num_of_objects
    assert net.layers[0].data_ids == range(input_num_of_objects)

    data_samples = net.layers[0].data_samples
    num_of_data_samples = len(data_samples)
    batch_size = net.layers[0].batch_size
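    # Number of forward passes needed to cover every data sample (ceiling division).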
    num_of_batches = int(np.ceil(num_of_data_samples / float(batch_size)))

    assert len(net.layers[0].image_loader) == input_dataset.num_of_images()

    # Create Result dataset
    result_dataset = ImageDataset(input_dataset.name())
    result_dataset.set_rootdir(input_dataset.rootdir())
    result_dataset.set_metainfo(input_dataset.metainfo().copy())

    # Add weight file and its md5 checksum to metainfo
    result_dataset.metainfo()['weights_file'] = weights_file
    result_dataset.metainfo()['weights_file_md5'] = md5(open(weights_file, 'rb').read()).hexdigest()

    # Set the image level fields
    for input_im_info in input_dataset.image_infos():
        result_im_info = OrderedDict()
        result_im_info['image_file'] = input_im_info['image_file']
        result_im_info['image_size'] = NoIndent(input_im_info['image_size'])
        result_im_info['image_intrinsic'] = NoIndent(input_im_info['image_intrinsic'])
        result_im_info['object_infos'] = []
        result_dataset.add_image_info(result_im_info)

    assert result_dataset.num_of_images() == input_dataset.num_of_images()

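    # Validation function for each kind of predicted quantity.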
    assert_funcs = {
        "viewpoint": assert_viewpoint,
        "bbx_visible": assert_bbx,
        "bbx_amodal": assert_bbx,
        "center_proj": assert_coord2D,
    }

    performance_metric = {}

    print 'Evaluating {} batches with {} images per batch.'.format(num_of_batches, batch_size)
    for b in tqdm.trange(num_of_batches):
        start_idx = batch_size * b
        end_idx = min(batch_size * (b + 1), num_of_data_samples)
        # print 'Working on batch: %d/%d (Image# %d - %d)' % (b, num_of_batches, start_idx, end_idx)
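        # The data layer advances its own cursor on each forward pass; start_idx
        # and end_idx mirror the samples consumed by this call.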
        output = net.forward()

        # store all scalar metric outputs (accuracy / iou / error) for this batch
        for key in [key for key in output if any(x in key for x in ["accuracy", "iou", "error"])]:
            assert np.squeeze(output[key]).shape == (), "Expects {} output to be scalar but got {}".format(key, output[key].shape)
            current_batch_accuracy = float(np.squeeze(output[key]))
            if key in performance_metric:
                performance_metric[key].append(current_batch_accuracy)
            else:
                performance_metric[key] = [current_batch_accuracy]

        for i in xrange(start_idx, end_idx):
            image_id = data_samples[i]['image_id']
            image_info = result_dataset.image_infos()[image_id]

            object_info = OrderedDict()

            # since we are not changing category or id, they are directly copied
            object_info['id'] = data_samples[i]['id']
            object_info['category'] = data_samples[i]['category']

            # since we are not predicting bbx_visible, it is directly copied
            object_info['bbx_visible'] = NoIndent(data_samples[i]['bbx_visible'].tolist())

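            # Attach network predictions (for whichever prediction blobs the net
            # provides) to the result annotation, validating each one first.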
            for info in ["bbx_amodal", "viewpoint", "center_proj"]:
                pred_info = "pred_" + info
                if pred_info in net.blobs:
                    prediction = np.squeeze(net.blobs[pred_info].data[i - start_idx, ...])
                    assert_funcs[info](prediction)
                    object_info[info] = NoIndent(prediction.tolist())

            image_info['object_infos'].append(object_info)

    for key in sorted(performance_metric):
        performance_metric[key] = np.mean(performance_metric[key])
        print 'Test set {}: {:.4f}'.format(key, performance_metric[key])

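    # Recover the training iteration number from the weights filename
    # (e.g. "..._iter_10000.caffemodel").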
    regex = re.compile(r'iter_([0-9]+)\.caffemodel')
    performance_metric['iter'] = int(regex.findall(weights_file)[0])

    result_num_of_objects = sum([len(image_info['object_infos']) for image_info in result_dataset.image_infos()])
    assert result_num_of_objects == num_of_data_samples
    return result_dataset, performance_metric


def run_inference(weights_file, net, input_dataset):
    """Run inference with already initalized net with a new set of weights"""
    net.copy_from(weights_file)
    net.layers[0].generate_datum_ids()

    num_of_images = input_dataset.num_of_images()
    assert net.layers[0].curr_data_ids_idx == 0
    assert net.layers[0].number_of_datapoints() == num_of_images
    assert net.layers[0].data_ids == range(num_of_images)

    assert len(net.layers[0].image_loader) == num_of_images
    assert len(net.layers[0].data_samples) == num_of_images
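    # The data layer must be configured for deterministic, single-image testing:
    # dynamic number of ROIs, one image per batch, no flipping and no jittering.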
    assert net.layers[0].rois_per_image < 0, "rois_per_image needs to be dynamic for testing"
    assert net.layers[0].imgs_per_batch == 1, "We only support one image per batch while testing"
    assert net.layers[0].flip_ratio < 0, "No flipping while testing"
    assert net.layers[0].jitter_iou_min > 1, "No jittering"

    # Create Result dataset
    result_dataset = ImageDataset(input_dataset.name())
    result_dataset.set_rootdir(input_dataset.rootdir())
    result_dataset.set_metainfo(input_dataset.metainfo().copy())

    # Add weight file and its md5 checksum to metainfo
    result_dataset.metainfo()['weights_file'] = weights_file
    result_dataset.metainfo()['weights_file_md5'] = md5(
        open(weights_file, 'rb').read()).hexdigest()

    # Set the image level fields
    for input_im_info in input_dataset.image_infos():
        result_im_info = OrderedDict()
        result_im_info['image_file'] = input_im_info['image_file']
        result_im_info['image_size'] = input_im_info['image_size']
        if 'image_intrinsic' in input_im_info:
            result_im_info['image_intrinsic'] = input_im_info['image_intrinsic']
        obj_infos = []
        for input_obj_info in input_im_info['object_infos']:
            obj_info = OrderedDict()
            for field in ['id', 'category', 'score', 'bbx_visible']:
                if field in input_obj_info:
                    obj_info[field] = input_obj_info[field]
            obj_infos.append(obj_info)
        result_im_info['object_infos'] = obj_infos
        assert len(result_im_info['object_infos']) == len(
            input_im_info['object_infos'])
        result_dataset.add_image_info(result_im_info)

    assert result_dataset.num_of_images() == num_of_images
    assert len(net.layers[0].data_samples) == num_of_images
    for result_img_info, layer_img_info in zip(result_dataset.image_infos(),
                                               net.layers[0].data_samples):
        assert len(result_img_info['object_infos']) == len(
            layer_img_info['object_infos'])

    assert_funcs = {
        "viewpoint": assert_viewpoint,
        "bbx_visible": assert_bbx,
        "bbx_amodal": assert_bbx,
        "center_proj": assert_coord2D,
    }

    print 'Running inference for {} images.'.format(num_of_images)
    for image_id in tqdm.trange(num_of_images):
        # Run forward pass
        _ = net.forward()

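        # The 'rois' blob is expected to hold one row per annotated object in this
        # image; each row has 5 values (in Fast R-CNN style ROI layers, a batch
        # index followed by the 4 box coordinates).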
        img_info = result_dataset.image_infos()[image_id]
        expected_num_of_rois = len(img_info['object_infos'])
        assert net.blobs['rois'].data.shape == (expected_num_of_rois, 5), \
            "rois blob shape {} does not match expected {} rois".format(
                net.blobs['rois'].data.shape, expected_num_of_rois)

        for info in ["bbx_amodal", "viewpoint", "center_proj"]:
            pred_info = "pred_" + info
            if pred_info in net.blobs:
                assert net.blobs[pred_info].data.shape[0] == expected_num_of_rois

        for i, obj_info in enumerate(img_info['object_infos']):
            for info in ["bbx_amodal", "viewpoint", "center_proj"]:
                pred_info = "pred_" + info
                if pred_info in net.blobs:
                    prediction = np.squeeze(net.blobs[pred_info].data[i, ...])
                    assert_funcs[info](prediction)
                    obj_info[info] = prediction.tolist()

    return result_dataset
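
# Example usage (a minimal sketch): the prototxt path, weights filename and
# dataset construction below are illustrative placeholders.
#
#     import caffe
#     caffe.set_mode_gpu()
#     net = caffe.Net('test.prototxt', caffe.TEST)
#     input_dataset = ...  # an ImageDataset for the split to evaluate
#     weights_file = 'snapshot_iter_10000.caffemodel'
#
#     # evaluate a labelled split (object-level batches) ...
#     result_dataset, metrics = test_single_weights_file(weights_file, net, input_dataset)
#     # ... or run detection-style inference (one image per batch)
#     result_dataset = run_inference(weights_file, net, input_dataset)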