def test_dataset(self):
    train_transform = build_transforms(cfg)
    val_transform = build_transforms(cfg, False)
    train_set = build_dataset(train_transform)
    val_set = build_dataset(val_transform, False)
    from IPython import embed
    embed()
Example #2
def test_dataset(self):
    train_transform = build_transforms(cfg, True)
    val_transform = build_transforms(cfg, False)
    train_set = build_dataset(cfg, train_transform, True)
    val_set = build_dataset(cfg, val_transform, False)
    from IPython import embed
    embed()
Example #3
def _get_feat_data_loader(cfg, source_name, feat):
    dataset = init_dataset(source_name,
                           root=cfg.DATASET.ROOT_DIR,
                           verbose=False)
    generate_train = []
    for i in range(feat.size(0)):
        img_path, _, _ = dataset.train[i]
        generate_train.append((img_path, feat[i], -1))
    dataset.train = generate_train
    dataset.print_dataset_statistics(dataset.train, dataset.query,
                                     dataset.gallery)
    batch_size = cfg.TRAIN.BATCH_SIZE

    train_transforms = build_transforms(cfg, is_train=False)
    train_set = ImageDataset(dataset.train, train_transforms)

    def train_collate_fn_by_feat(batch):
        imgs, feats, _, _ = zip(*batch)
        imgs = torch.stack(imgs, dim=0)
        feats = torch.stack(feats, dim=0)
        return imgs, feats

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=cfg.DATALOADER.NUM_WORKERS,
                              collate_fn=train_collate_fn_by_feat)
    return train_loader, dataset.num_train_pids
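
A minimal usage sketch (illustrative only; the config, dataset name, and feature tensor are assumptions, not from the source):

    feat = torch.randn(num_train_images, 2048)   # one embedding per training image
    loader, num_pids = _get_feat_data_loader(cfg, 'market1501', feat)
    for imgs, feats in loader:
        pass  # imgs: (B, C, H, W), feats: (B, 2048)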
Example #4
def main():

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--config-file',
                        type=str,
                        default='',
                        help='path to config file')
    parser.add_argument('--output-name', type=str, default='model')
    parser.add_argument('--verbose',
                        default=False,
                        action='store_true',
                        help='Verbose mode for onnx.export')
    args = parser.parse_args()

    cfg = get_default_config()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.freeze()

    model = build_model(
        name=cfg.model.name,
        num_classes=1041,  # Does not matter in conversion
        loss=cfg.loss.name,
        pretrained=False,
        use_gpu=True,
        feature_dim=cfg.model.feature_dim,
        fpn=cfg.model.fpn,
        fpn_dim=cfg.model.fpn_dim,
        gap_as_conv=cfg.model.gap_as_conv,
        input_size=(cfg.data.height, cfg.data.width),
        IN_first=cfg.model.IN_first)

    load_pretrained_weights(model, cfg.model.load_weights)
    model.eval()

    _, transform = build_transforms(cfg.data.height,
                                    cfg.data.width,
                                    transforms=cfg.data.transforms,
                                    norm_mean=cfg.data.norm_mean,
                                    norm_std=cfg.data.norm_std,
                                    apply_masks_to_test=False)

    input_size = (cfg.data.height, cfg.data.width, 3)
    img = np.random.rand(*input_size).astype(np.float32)
    img = np.uint8(img * 255)
    im = Image.fromarray(img)
    blob = transform(im).unsqueeze(0)

    torch.onnx.export(
        model,
        blob,
        args.output_name + '.onnx',
        verbose=args.verbose,
        export_params=True,
        input_names=['data'],
        output_names=['reid_embedding'],
        opset_version=9)  # 9th version resolves nearest upsample issue
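
A quick sanity check of the exported graph, sketched with onnxruntime (not part of the original script; the input name 'data' matches the export call above):

    import onnxruntime as ort

    sess = ort.InferenceSession(args.output_name + '.onnx')
    (embedding,) = sess.run(None, {'data': blob.numpy()})
    print(embedding.shape)  # (1, feature_dim)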
Example #5
    def __init__(self, cfg, use_cuda=True, device=None):
        self.cfg = cfg
        self.net = init_extractor(cfg)
        if device is None:
            self.device = "cuda" if use_cuda else "cpu"
        else:
            self.device = device
        self.net.to(self.device)
        self.norm = build_transforms(cfg, is_train=False)
Example #6
def _get_train_loader(cfg,
                      batch_size,
                      train_set,
                      sampler,
                      shuffle,
                      is_train=True):
    train_transforms = build_transforms(cfg, is_train=is_train)
    train_set = ImageDataset(train_set, train_transforms)
    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              sampler=sampler,
                              shuffle=shuffle,
                              num_workers=cfg.DATALOADER.NUM_WORKERS,
                              collate_fn=train_collate_fn)
    return train_loader
Example #7
def get_train_dataloader(cfg):
    print('prepare training set ...')
    tng_tfms = build_transforms(cfg, is_train=False)
    num_workers = cfg.DATALOADER.NUM_WORKERS

    train_img_items = list()
    for d in cfg.DATASETS.NAMES:
        dataset = init_dataset(d)
        train_img_items.extend(dataset.train)

    tng_set = ImageDataset(train_img_items, tng_tfms, relabel=True)

    tng_dataloader = DataLoader(tng_set,
                                cfg.TEST.IMS_PER_BATCH,
                                num_workers=num_workers,
                                collate_fn=fast_collate_fn,
                                pin_memory=True)

    return tng_dataloader, tng_set.c
Example #8
def make_multi_valid_data_loader(cfg, data_set_names, verbose=False):
    valid = OrderedDict()
    for name in data_set_names:
        dataset = init_dataset(name,
                               root=cfg.DATASET.ROOT_DIR,
                               verbose=verbose)
        val_transforms = build_transforms(cfg, is_train=False)
        val_set = ImageDataset(dataset.query + dataset.gallery, val_transforms)
        val_loader = DataLoader(val_set,
                                batch_size=cfg.TEST.BATCH_SIZE,
                                shuffle=False,
                                num_workers=cfg.DATALOADER.NUM_WORKERS,
                                collate_fn=val_collate_fn)
        valid[name] = (val_loader, len(dataset.query))

    return valid
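
One way to consume the returned mapping, sketched for illustration (`evaluate` is a stand-in name, not from the source):

    valid = make_multi_valid_data_loader(cfg, ['market1501', 'dukemtmc'])
    for name, (val_loader, num_query) in valid.items():
        # num_query marks where query entries end and gallery entries begin
        evaluate(name, val_loader, num_query)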
Example #9
def get_test_dataloader(cfg):
    print('prepare test set ...')
    val_tfms = build_transforms(cfg, is_train=False)
    num_workers = cfg.DATALOADER.NUM_WORKERS
    
    test_dataloader_collection = list()
    query_names_len_collection = list()
    test_names_collection = list()
    for d in cfg.DATASETS.TEST_NAMES:
        dataset = init_dataset(d)
        query_names, gallery_names = dataset.query, dataset.gallery

        test_set = ImageDataset(query_names+gallery_names, val_tfms, relabel=False)
        
        test_dataloader = DataLoader(test_set,
                                     cfg.TEST.IMS_PER_BATCH,
                                     num_workers=num_workers,
                                     collate_fn=fast_collate_fn,
                                     pin_memory=True)
        test_dataloader_collection.append(test_dataloader)
        query_names_len_collection.append(len(query_names))
        test_names_collection.append(query_names+gallery_names)
    
    return test_dataloader_collection, query_names_len_collection, test_names_collection
Example #10
def make_data_loader_for_val_data(cfg):
    val_transforms = build_transforms(cfg, is_train=False)
    num_workers = cfg.DATALOADER.NUM_WORKERS
    if len(cfg.DATASETS.NAMES) == 1:
        dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)
    else:
        # TODO: add multi dataset to train
        dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)

    num_classes = dataset.num_train_pids

    val_set = ImageDataset(dataset=dataset.query + dataset.gallery,
                           rap_data_=dataset.rap_data,
                           transform=val_transforms,
                           is_train=False,
                           swap_roi_rou=False)
    val_loader = DataLoader(val_set,
                            batch_size=cfg.TEST.IMS_PER_BATCH,
                            shuffle=False,
                            num_workers=num_workers,
                            collate_fn=val_collate_fn)
    return val_loader, len(dataset.query), num_classes
Example #11
def _get_target_data_loader(cfg, target_name):
    dataset = init_dataset(target_name,
                           root=cfg.DATASET.ROOT_DIR,
                           verbose=False)
    batch_size, sampler, shuffle = _get_train_sampler(cfg, dataset.train)

    train_transforms = build_transforms(cfg, is_train=True)
    train_set = ImageDataset(dataset.train, train_transforms)

    def train_collate_fn_add_feat(batch):
        imgs, pids, _, _ = zip(*batch)
        pids = torch.tensor(pids, dtype=torch.int64)
        imgs = torch.stack(imgs, dim=0)
        return imgs, pids

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              sampler=sampler,
                              shuffle=shuffle,
                              num_workers=cfg.DATALOADER.NUM_WORKERS,
                              collate_fn=train_collate_fn_add_feat)
    return train_loader, dataset.num_train_pids
Example #12
def make_data_loader(cfg, is_train=True, max_iter=None, start_iter=0):
    train_transform = build_transforms(cfg, is_train=is_train)
    target_transform = build_target_transform(cfg) if is_train else None
    dataset_list = cfg.DATASETS.TRAIN if is_train else cfg.DATASETS.TEST
    datasets = build_dataset(dataset_list,
                             transform=train_transform,
                             target_transform=target_transform,
                             is_train=is_train)

    shuffle = is_train

    data_loaders = []

    for dataset in datasets:
        if shuffle:
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.sampler.SequentialSampler(dataset)

        batch_size = cfg.SOLVER.BATCH_SIZE if is_train else cfg.TEST.BATCH_SIZE
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler=sampler, batch_size=batch_size, drop_last=False)
        if max_iter is not None:
            batch_sampler = samplers.IterationBasedBatchSampler(
                batch_sampler, num_iterations=max_iter, start_iter=start_iter)

        data_loader = DataLoader(dataset,
                                 num_workers=cfg.DATA_LOADER.NUM_WORKERS,
                                 batch_sampler=batch_sampler,
                                 pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
                                 collate_fn=BatchCollator(is_train))
        data_loaders.append(data_loader)

    if is_train:
        # during training, a single (possibly concatenated) data_loader is returned
        assert len(data_loaders) == 1
        return data_loaders[0]
    return data_loaders
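
For reference, an iteration-based batch sampler can be sketched as below (a simplified stand-in for samplers.IterationBasedBatchSampler, under the assumption that it replays the wrapped BatchSampler until a fixed number of iterations has been served):

    class IterationBasedBatchSampler:
        """Yield batches from a BatchSampler until num_iterations is reached."""

        def __init__(self, batch_sampler, num_iterations, start_iter=0):
            self.batch_sampler = batch_sampler
            self.num_iterations = num_iterations
            self.start_iter = start_iter

        def __iter__(self):
            iteration = self.start_iter
            while iteration < self.num_iterations:
                for batch in self.batch_sampler:
                    iteration += 1
                    if iteration > self.num_iterations:
                        break
                    yield batch

        def __len__(self):
            return self.num_iterations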
Example #13
def run_demo(cfg, ckpt, score_threshold, images_dir, output_dir, dataset_type):
    if dataset_type == "voc":
        class_names = VOCDataset.class_names
    elif dataset_type == 'coco':
        class_names = COCODataset.class_names
    else:
        raise NotImplementedError('Not implemented now.')

    if torch.cuda.is_available():
        device = torch.device(cfg.MODEL.DEVICE)
    else:
        device = torch.device("cpu")

    model = SSDDetector(cfg)
    model = model.to(device)
    checkpointer = CheckPointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.load(ckpt, use_latest=ckpt is None)
    weight_file = ckpt if ckpt else checkpointer.get_checkpoint_file()
    print('Loaded weights from {}'.format(weight_file))

    image_paths = glob.glob(os.path.join(images_dir, '*.jpg'))
    mkdir(output_dir)

    cpu_device = torch.device("cpu")
    transforms = build_transforms(cfg, is_train=False)
    model.eval()
    for i, image_path in enumerate(image_paths):
        start = time.time()
        image_name = os.path.basename(image_path)
        image = np.array(Image.open(image_path).convert("RGB"))
        height, width = image.shape[:2]
        images = transforms(image)[0].unsqueeze(0)
        load_time = time.time() - start

        start = time.time()
        result = model(images.to(device))[0]
        inference_time = time.time() - start

        result = result.resize((width, height)).to(cpu_device).numpy()
        boxes, labels, scores = result['boxes'], result['labels'], result['scores']

        # filter predictions that do not overcome the score_threshold
        indices = scores > score_threshold
        boxes = boxes[indices]  # (xmin, ymin, xmax, ymax)
        labels = labels[indices]
        scores = scores[indices]
        centers = np.apply_along_axis(get_mid_point, 1, boxes)

        start = time.time()
        dbscan_center = dbscan.DBSCAN(eps=37)
        dbscan_center.fit(centers)
        print("dbscan clusters", dbscan_center._labels)
        print(f"DBSCAN clustering time {round((time.time()-start)*1000, 3)}ms")
        image = draw_points(image, centers)  # draw center points on image

        start = time.time()
        def reset_range(old_max, old_min, new_max, new_min, arr):
            old_range = old_max - old_min
            if old_range == 0:
                new_val = arr
                new_val[:] = new_min
            else:
                new_range = new_max - new_min
                new_val = (((arr - old_min) * new_range) / old_range) + new_min
            return new_val

        # POINT DATASET
        x = centers[:, 0]
        y = centers[:, 1]

        # x = reset_range(max(x), min(x), 100, 0, x)
        # y = reset_range(max(y), min(y), 100, 0, y)

        # DEFINE GRID SIZE AND RADIUS(h)
        grid_size = 1
        h = 30

        # GETTING X,Y MIN AND MAX
        x_min = min(x)
        x_max = max(x)
        y_min = min(y)
        y_max = max(y)

        # CONSTRUCT GRID
        x_grid = np.arange(x_min - h, x_max + h, grid_size)
        y_grid = np.arange(y_min - h, y_max + h, grid_size)
        x_mesh, y_mesh = np.meshgrid(x_grid, y_grid)

        # GRID CENTER POINT
        xc = x_mesh + (grid_size / 2)
        yc = y_mesh + (grid_size / 2)

        # FUNCTION TO CALCULATE INTENSITY WITH QUARTIC KERNEL
        def kde_quartic(d, h):
            dn = d / h
            P = (15 / 16) * (1 - dn ** 2) ** 2
            return P
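        # kde_quartic above is the quartic (biweight) kernel
        # K(u) = (15/16) * (1 - u**2)**2 for |u| <= 1; the d <= h check below
        # enforces its compact support.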

        # PROCESSING
        intensity_list = []
        for j in range(len(xc)):
            intensity_row = []
            for k in range(len(xc[0])):
                kde_value_list = []
                # fresh loop variable: `i` already indexes the current image
                for n in range(len(x)):
                    # CALCULATE DISTANCE
                    d = math.sqrt((xc[j][k] - x[n]) ** 2 + (yc[j][k] - y[n]) ** 2)
                    if d <= h:
                        p = kde_quartic(d, h)
                    else:
                        p = 0
                    kde_value_list.append(p)
                # SUM ALL INTENSITY VALUE
                p_total = sum(kde_value_list)
                intensity_row.append(p_total)
            intensity_list.append(intensity_row)

        # HEATMAP OUTPUT
        intensity = np.array(intensity_list)
        plt.pcolormesh(x_mesh, y_mesh, intensity)
        plt.plot(x, y, 'ro')  # plot center points
        plt.xticks([])
        plt.yticks([])
        # plt.colorbar()
        plt.gca().invert_yaxis()
        plt.savefig(f'demo/result/heatmap_{i}')

        plt.clf()
        print("Heatmap generation time", time.time() - start)

        meters = ' | '.join(
            [
                'objects {:02d}'.format(len(boxes)),
                'load {:03d}ms'.format(round(load_time * 1000)),
                'inference {:03d}ms'.format(round(inference_time * 1000)),
                'FPS {}'.format(round(1.0 / inference_time))
            ]
        )
        print('({:04d}/{:04d}) {}: {}'.format(i + 1, len(image_paths), image_name, meters))

        # Draw the bounding boxes, labels, and scores on the images
        drawn_image = draw_boxes(image, boxes, labels, scores, class_names).astype(
            np.uint8)
        pil_img = Image.fromarray(drawn_image)
        pil_img.save(os.path.join(output_dir, image_name))
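
The triple loop above is O(grid cells x detections) in pure Python; a vectorized NumPy equivalent of the same quartic-kernel intensity (a sketch, not part of the original demo) is considerably faster:

    def kde_quartic_grid(xc, yc, x, y, h):
        """Vectorized quartic-kernel intensity over grid cell centers."""
        # distances from every grid center to every point: (rows, cols, n_points)
        d = np.sqrt((xc[..., None] - x) ** 2 + (yc[..., None] - y) ** 2)
        p = (15 / 16) * (1 - (d / h) ** 2) ** 2
        p[d > h] = 0  # zero outside the kernel's support
        return p.sum(axis=-1)

    intensity = kde_quartic_grid(xc, yc, x, y, h)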
Example #14
def run_demo(cfg, ckpt, score_threshold, images_dir, dataset_type):
    if dataset_type == "voc":
        class_names = VOCDataset.class_names
    elif dataset_type == 'coco':
        class_names = COCODataset.class_names
    else:
        raise NotImplementedError('Not implemented now.')
    device = torch.device(cfg.MODEL.DEVICE)

    model = SSDDetector(cfg)
    model = model.to(device)
    checkpointer = CheckPointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.load(ckpt, use_latest=ckpt is None)
    weight_file = ckpt if ckpt else checkpointer.get_checkpoint_file()
    print('Loaded weights from {}'.format(weight_file))

    cpu_device = torch.device("cpu")
    transforms = build_transforms(cfg, is_train=False)
    model.eval()

    # CHANGE FROM HERE

    global hand_hist
    is_hand_hist_created = False
    capture = cv2.VideoCapture(0)

    ret, frame = capture.read()
    frame = cv2.flip(frame, 1)

    while capture.isOpened():

        pressed_key = cv2.waitKey(1)

        prev_frame = frame[:]

        ret, frame = capture.read()
        frame = cv2.flip(frame, 1)

        if pressed_key & 0xFF == ord('z'):
            hand_hist = hand_histogram(frame)
            is_hand_hist_created = True

        if not is_hand_hist_created:
            frame = draw_rect(frame)
            drawn_image = frame
        else:
            drawn_image = frame  # fallback when nothing is drawn this frame
            hist_mask_image = hist_masking_improved(frame, hand_hist)
            contour_list = contours(hist_mask_image)
            if len(contour_list) == 0:
                continue
            max_cont = max(contour_list, key=cv2.contourArea)

            # function to draw contours around skin/hand
            #cv2.drawContours(frame, [max_cont], -1, 0xFFFFFF, thickness=4)

            cnt_centroid = centroid(max_cont)
            #cv2.circle(frame, cnt_centroid, 5, [255, 0, 255], -1)

            if max_cont is not None:
                hull = cv2.convexHull(max_cont, returnPoints=False)
                defects = cv2.convexityDefects(max_cont, hull)

                if defects is not None and cnt_centroid is not None:  # Pointing detected
                    finger_tip = farthest_point(defects, max_cont,
                                                cnt_centroid)
                    print("--> Finger tip at", finger_tip)

                    height, width = prev_frame.shape[:2]
                    image = transforms(prev_frame)[0].unsqueeze(0)
                    result = model(image.to(device))[0]
                    result = result.resize(
                        (width, height)).to(cpu_device).numpy()
                    boxes, labels, scores = result['boxes'], result[
                        'labels'], result['scores']

                    if (len(boxes) != 0):
                        best_score = 0.2
                        id_final = 0
                        for i, box in enumerate(
                                boxes):  #(xmin, ymin, xmax, ymax)
                            if (box[0] < finger_tip[0] < box[2]
                                    and box[1] < finger_tip[1] <
                                    box[3]):  # bbox contains finger position
                                if scores[i] > best_score:  # best score
                                    best_score = scores[i]
                                    id_final = i

                        if (boxes[id_final][0] < finger_tip[0] <
                                boxes[id_final][2] and boxes[id_final][1] <
                                finger_tip[1] < boxes[id_final][3]
                                and scores[id_final] > 0.2):

                            drawn_image = draw(frame, boxes[id_final],
                                               labels[id_final],
                                               scores[id_final],
                                               class_names).astype(np.uint8)

                            #cv2.imshow("frame", drawn_image)

                        else:
                            continue

                    # only draw when a finger tip was actually detected
                    cv2.circle(drawn_image, (finger_tip[0], finger_tip[1]), 1,
                               (255, 0, 0), 16)
        cv2.imshow("Live Feed", drawn_image)

        # for OpenCV major version < 3, manual calculation of frame rate for video feed might be required
        fps = capture.get(cv2.CAP_PROP_FPS)
        print(f"Frames per second using video.get(cv2.CAP_PROP_FPS) : {fps}")

    cv2.destroyAllWindows()
    capture.release()
Example #15
def run_demo(cfg, ckpt, score_threshold, images_dir, output_dir, dataset_type):
    if dataset_type == "voc":
        class_names = VOCDataset.class_names
    elif dataset_type == 'coco':
        class_names = COCODataset.class_names
    else:
        raise NotImplementedError('Not implemented now.')

    if torch.cuda.is_available():
        device = torch.device(cfg.MODEL.DEVICE)
    else:
        device = torch.device("cpu")

    model = SSDDetector(cfg)
    model = model.to(device)
    checkpointer = CheckPointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.load(ckpt, use_latest=ckpt is None)
    weight_file = ckpt if ckpt else checkpointer.get_checkpoint_file()
    print('Loaded weights from {}'.format(weight_file))

    cpu_device = torch.device("cpu")
    transforms = build_transforms(cfg, is_train=False)
    model.eval()

    # CHANGE FROM HERE
    capture = cv2.VideoCapture(0)
    while capture.isOpened():
        ret, frame = capture.read()
        image = cv2.flip(frame, 1)
        if ret:
            height, width = image.shape[:2]
            images = transforms(frame)[0].unsqueeze(0)

            result = model(images.to(device))[0]

            result = result.resize((width, height)).to(cpu_device).numpy()
            boxes, labels, scores = result['boxes'], result['labels'], result['scores']

            # filter predictions that do not overcome the score_threshold
            indices = scores > score_threshold
            boxes = boxes[indices]  # (xmin, ymin, xmax, ymax)
            labels = labels[indices]
            scores = scores[indices]

            if len(boxes) != 0:
                centers = np.apply_along_axis(get_mid_point, 1, boxes)
                start = time.time()
                dbscan_center = dbscan.DBSCAN(eps=37)
                dbscan_center.fit(centers)
                print("dbscan clusters", dbscan_center._labels)
                print(f"DBSCAN clustering time {round((time.time() - start) * 1000, 3)}ms")
                image = draw_points(image, centers)  # draw center points on image

            drawn_image = draw_boxes(image, boxes, labels, scores, class_names).astype(
                np.uint8)
            cv2.imshow("frame", drawn_image)

            key = cv2.waitKey(1)
            if key & 0xFF == ord('x'):
                break
        else:
            break
    cv2.destroyAllWindows()
    capture.release()
Example #16
def main():
    parser = argparse.ArgumentParser(description="ReID Baseline Training")
    parser.add_argument("--config_file", type=str)
    parser.add_argument("opts",
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1

    if args.config_file != "":
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    output_dir = cfg.OUTPUT_DIR
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    logger = setup_logger("reid_baseline", output_dir, 0)
    logger.info("Using {} GPUS".format(num_gpus))
    logger.info(args)

    if args.config_file != "":
        logger.info("Loaded configuration file {}".format(args.config_file))
        with open(args.config_file, 'r') as cf:
            config_str = "\n" + cf.read()
            logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    if cfg.MODEL.DEVICE == "cuda":
        os.environ[
            'CUDA_VISIBLE_DEVICES'] = cfg.MODEL.DEVICE_ID  # new add by gu
    cudnn.benchmark = True

    _1, _2, _3, num_classes = make_data_loader(cfg)
    model = build_model(cfg, num_classes)
    model.load_param(cfg.TEST.WEIGHT)

    # gpu_device
    device = cfg.MODEL.DEVICE

    if device:
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.to(device)

    # test data-loader
    test_transforms = build_transforms(cfg, is_train=False)

    query_name = os.listdir(Q_ROOT)
    gallery_name = os.listdir(G_ROOT)

    dataset = [os.path.join(Q_ROOT, x) for x in query_name] + \
              [os.path.join(G_ROOT, x) for x in gallery_name]

    test_set = TestImageDataset(dataset=dataset, transform=test_transforms)

    test_loader = DataLoader(test_set,
                             batch_size=cfg.TEST.IMS_PER_BATCH,
                             shuffle=False,
                             num_workers=12,
                             collate_fn=test_collate_fn)

    result = []

    # _inference
    def _inference(batch):
        model.eval()
        with torch.no_grad():
            data = batch
            data = data.to(device) if torch.cuda.device_count() >= 1 else data
            feat = model(data)
            feat = feat.data.cpu().numpy()
            return feat

    count = 0
    for batch in test_loader:
        count += 1
        feat = _inference(batch)
        result.append(feat)

        if count % 100 == 0:
            print(count)

    result = np.concatenate(result, axis=0)

    query_num = len(query_name)
    query_feat = result[:query_num]
    gallery_feat = result[query_num:]

    pickle.dump([query_feat, query_name],
                open(cfg.OUTPUT_DIR + '/query_feature.feat', 'wb'))
    pickle.dump([gallery_feat, gallery_name],
                open(cfg.OUTPUT_DIR + '/gallery_feature.feat', 'wb'))
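
With the dumped features, a retrieval step might look like the following (a sketch; the metric and any re-ranking depend on the project, and the broadcasting form suits modest set sizes):

    # Euclidean distances between every query and gallery embedding
    dist = np.linalg.norm(query_feat[:, None, :] - gallery_feat[None, :, :], axis=-1)
    ranks = np.argsort(dist, axis=1)            # gallery indices, best match first
    top1 = [gallery_name[r[0]] for r in ranks]  # best gallery match per query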
Example #17
def main():
    args = parse_args()
    cfg = get_default_cfg()
    if args.config_file:
        cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()

    dataset = COCODataset(cfg.data.test[0], cfg.data.test[1])
    num_classes = dataset.num_classes
    label_map = dataset.labels
    model = EfficientDet(num_classes=num_classes, model_name=cfg.model.name)
    device = torch.device(cfg.device)
    model.to(device)
    model.eval()

    inp_size = model.config['inp_size']
    transforms = build_transforms(False, inp_size=inp_size)

    output_dir = cfg.output_dir
    checkpointer = Checkpointer(model, None, None, output_dir, True)
    checkpointer.load(args.ckpt)

    images = []
    if args.img:
        if osp.isdir(args.img):
            for filename in os.listdir(args.img):
                if is_valid_file(filename):
                    images.append(osp.join(args.img, filename))
        else:
            images = [args.img]

    for img_path in images:
        img = cv2.imread(img_path)
        img = inference(model,
                        img,
                        label_map,
                        score_thr=args.score_thr,
                        transforms=transforms)
        save_path = osp.join(args.save, osp.basename(img_path))
        cv2.imwrite(save_path, img)

    if args.vid:
        vCap = cv2.VideoCapture(args.vid)
        fps = int(vCap.get(cv2.CAP_PROP_FPS))
        height = int(vCap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(vCap.get(cv2.CAP_PROP_FRAME_WIDTH))
        size = (width, height)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        save_path = osp.join(args.save, osp.basename(args.vid))
        vWrt = cv2.VideoWriter(save_path, fourcc, fps, size)
        while True:
            flag, frame = vCap.read()
            if not flag:
                break
            frame = inference(model,
                              frame,
                              label_map,
                              score_thr=args.score_thr,
                              transforms=transforms)
            vWrt.write(frame)

        vCap.release()
        vWrt.release()
Example #18
    def __init__(self, cfg, device_ids):
        self.cfg = cfg
        self.net = init_extractor(cfg)

        self.net = nn.DataParallel(self.net.cuda(), device_ids=device_ids)
        self.norm = build_transforms(cfg, is_train=False)
Example #19
File: P2PaLA.py  Project: fendaq/P2PaLA
def main():
    """
    """
    global_start = time.time()
    #--- init logger
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)
    ch = logging.StreamHandler()
    #--- keep this logger at DEBUG level, until arguments are processed
    ch.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s - %(module)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    #--- handle Ctrl-C signal
    #signal.signal(signal.SIGINT,signal_handler)

    #--- Get Input Arguments
    in_args = arguments(logger)
    opts = in_args.parse()
    if check_inputs(opts,logger):
        logger.critical('Execution aborted due to input errors...')
        exit(1)
    # create file handler which logs even debug messages
    fh = logging.FileHandler(opts.log_file, mode='a')
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    #--- restore ch logger to INFO
    ch.setLevel(logging.INFO)
    logger.debug(in_args)
    #--- Init torch random 
    #--- These two are supposed to be merged in the future; for now keep both
    torch.manual_seed(opts.seed)
    torch.cuda.manual_seed_all(opts.seed)
    #--- Init model variable
    nnG = None
    bestState = None
    torch.set_default_tensor_type("torch.FloatTensor")
    #--- configure TensorBoard display
    opts.img_size = np.array(opts.img_size, dtype=int)
    #--------------------------------------------------------------------------
    #-----  TRAIN STEP
    #--------------------------------------------------------------------------
    if opts.do_train:
        train_start = time.time()
        logger.info('Working on training stage...')
        #--- display is used only on training step
        if not opts.no_display:
            import socket
            from datetime import datetime
            try:
                from tensorboardX import SummaryWriter
                if opts.use_global_log:
                    run_dir = opts.use_global_log
                else:
                    run_dir = os.path.join(opts.work_dir, 'runs')
                log_dir = os.path.join(run_dir, 
                               "".join([datetime.now().strftime('%b%d_%H-%M-%S'),
                               '_',socket.gethostname(), opts.log_comment]))
                                    
                writer = SummaryWriter(log_dir=log_dir) 
                logger.info('TensorBoard log will be stored at {}'.format(log_dir))
                logger.info('run: tensorboard --logdir {}'.format(run_dir))
            except ImportError:
                logger.warning('tensorboardX is not installed, display logger set to OFF.')
                opts.no_display = True
    
        #--- Build transforms
        transform = transforms.build_transforms(opts,train=True)
        #--- Get Train Data
        if opts.tr_img_list == '':
            logger.info('Preprocessing data from {}'.format(opts.tr_data))
            tr_data = dp.htrDataProcess(
                                         opts.tr_data,
                                         os.path.join(opts.work_dir,'data','train'),
                                         opts,
                                         logger=logger)
            tr_data.pre_process()
            opts.tr_img_list = tr_data.img_list
            opts.tr_label_list = tr_data.label_list
        else:
            logger.info('Reading data from pre-processed input {}'.format(opts.tr_img_list))
            tr_data = dp.htrDataProcess(
                                        opts.tr_data,
                                        os.path.join(opts.work_dir,'data','train'),
                                        opts,
                                        logger=logger)
            tr_data.set_img_list(opts.tr_img_list)
            tr_data.set_label_list(opts.tr_label_list)

        train_data = dataset.htrDataset(img_lst=opts.tr_img_list,
                                        label_lst=opts.tr_label_list,
                                        transform=transform,
                                        opts=opts)
        train_dataloader = DataLoader(train_data,
                                      batch_size=opts.batch_size,
                                      shuffle=opts.shuffle_data,
                                      num_workers=opts.num_workers,
                                      pin_memory=opts.pin_memory)
        #--- Get Val data, if needed
        if opts.do_val:
            if opts.val_img_list == '':
                logger.info('Preprocessing data from {}'.format(opts.val_data))
                va_data = dp.htrDataProcess(
                                             opts.val_data,
                                             os.path.join(opts.work_dir,'data','val/'),
                                             opts,
                                             logger=logger)
                va_data.pre_process()
                opts.val_img_list = va_data.img_list
                opts.val_label_list = va_data.label_list
            val_transform = transforms.build_transforms(opts,train=False)

            val_data = dataset.htrDataset(img_lst=opts.val_img_list,
                                          label_lst=opts.val_label_list,
                                          transform=val_transform,
                                          opts=opts)
            val_dataloader = DataLoader(val_data,
                                        batch_size=opts.batch_size,
                                        shuffle=False,
                                        num_workers=opts.num_workers,
                                        pin_memory=opts.pin_memory)

        #--- Build Models
        nnG = models.buildUnet(opts.input_channels,
                               opts.output_channels,
                               ngf=opts.cnn_ngf,
                               net_type=opts.net_out_type,
                               out_mode=opts.out_mode)
        #--- TODO: create a function @ models to define loss function
        if opts.do_class:
            lossG = loss_dic['NLL']
            opts.g_loss = 'NLL'
        else:
            lossG = loss_dic[opts.g_loss]
        #--- TODO: implement other initialization methods
        optimizerG = optim.Adam(nnG.parameters(),
                                lr=opts.adam_lr,
                                betas=(opts.adam_beta1,opts.adam_beta2))
        if opts.cont_train:
            logger.info('Resuming training from model {}'.format(opts.prev_model))
            checkpoint = torch.load(opts.prev_model)
            nnG.load_state_dict(checkpoint['nnG_state'])
            optimizerG.load_state_dict(checkpoint['nnG_optimizer_state'])
            if opts.g_loss != checkpoint['g_loss']:
                logger.warning(("Previous model loss function differs from "
                                "current loss function {} != {}").format(
                                                                opts.g_loss,
                                                                checkpoint['g_loss']))
                logger.warning('Using {} loss function to resume training...'.format(opts.g_loss))
            if opts.use_gpu:
                nnG = nnG.cuda()
                lossG = lossG.cuda()
        else:
            #--- send to GPU before init weigths
            if opts.use_gpu:
                nnG = nnG.cuda()
                lossG = lossG.cuda()
            nnG.apply(models.weights_init_normal)
        logger.debug('GEN Network:\n{}'.format(nnG)) 
        logger.debug('GEN Network, number of parameters: {}'.format(nnG.num_params))

        if opts.use_gan:
            if opts.net_out_type == 'C':
                if opts.out_mode == 'LR':
                    d_out_ch = 2
                else:
                    d_out_ch = 1
            elif opts.net_out_type == 'R':
                d_out_ch = opts.output_channels
            else:
                pass
            nnD = models.buildDNet(opts.input_channels,
                                   d_out_ch,
                                   ngf=opts.cnn_ngf,
                                   n_layers=opts.gan_layers)
            lossD = torch.nn.BCELoss(size_average=True)
            optimizerD = optim.Adam(nnD.parameters(),
                                    lr=opts.adam_lr,
                                    betas=(opts.adam_beta1,opts.adam_beta2))
            if opts.cont_train:
                if 'nnD_state' in checkpoint:
                    nnD.load_state_dict(checkpoint['nnD_state'])
                    optimizerD.load_state_dict(checkpoint['nnD_optimizer_state'])
                else:
                    logger.warning('Previous model did not use GAN, but current does.')
                    logger.warning('Using new GAN model from scratch.')
                if opts.use_gpu:
                    nnD = nnD.cuda()
                    lossD = lossD.cuda()
                    #loss_lambda = loss_lambda.cuda()
            else:
                if opts.use_gpu:
                    nnD = nnD.cuda()
                    lossD = lossD.cuda()
                    #loss_lambda = loss_lambda.cuda()
                nnD.apply(models.weights_init_normal) 
            logger.debug('DIS Network:\n{}'.format(nnD)) 
            logger.debug('DIS Network, number of parameters: {}'.format(nnD.num_params))

        #--- Do the actual train
        #--- TODO: compute statistical bootstrap to define if a model is
        #---    statistically better than previous
        best_val = np.inf
        best_tr = np.inf
        best_model = ''
        best_epoch = 0
        if opts.net_out_type == 'C' and opts.fix_class_imbalance:
            if opts.out_mode == 'LR':
                l_w = torch.from_numpy(train_data.w[0])
                r_w = torch.from_numpy(train_data.w[1])
                if opts.use_gpu:
                    l_w = l_w.type(torch.FloatTensor).cuda()
                    r_w = r_w.type(torch.FloatTensor).cuda()
                class_weight = [l_w,r_w]
                logger.debug('class weight: {}'.format(train_data.w))
            else:
                lossG.weight = torch.from_numpy(train_data.w).type(torch.FloatTensor).cuda()
                logger.debug('class weight: {}'.format(train_data.w))

        for epoch in range(opts.epochs):
            epoch_start = time.time()
            epoch_lossG = 0
            epoch_lossGAN = 0
            epoch_lossR = 0
            epoch_lossD = 0
            for batch,sample in enumerate(train_dataloader):
                #--- Reset Grads
                #nnG.apply(models.zero_bias)
                optimizerG.zero_grad()
                x = Variable(sample['image'], requires_grad=False)
                #y_gt_D = Variable(sample['label'].clone().type(torch.FloatTensor), requires_grad=False)
                y_gt = Variable(sample['label'], requires_grad=False)
                if opts.use_gpu:
                    x = x.cuda()
                    y_gt = y_gt.cuda()
                    #y_gt_D = y_gt_D.cuda()
                y_gen = nnG(x)
                if opts.out_mode == 'LR' and opts.net_out_type == 'C':
                    if (y_gen[0] != y_gen[0]).any() or (y_gen[1] != y_gen[1]).any():
                        logger.error('NaN values found in hypothesis')
                        logger.error("Inputs: {}".format(sample['id']))
                        raise RuntimeError 
                    y_l,y_r = torch.split(y_gt,1,dim=1)
                    if opts.fix_class_imbalance:
                        lossG.weight = class_weight[0]
                        g_loss = lossG(y_gen[0],torch.squeeze(y_l))
                        lossG.weight = class_weight[1]
                        g_loss += lossG(y_gen[1],torch.squeeze(y_r))
                    else:
                        g_loss = lossG(y_gen[0],torch.squeeze(y_l)) + lossG(y_gen[1],torch.squeeze(y_r))
                    #g_loss = lossG(y_gen[0],torch.squeeze(y_l)) + lossG(y_gen[1],torch.squeeze(y_r))
                else:
                    if (y_gen != y_gen).any():
                        logger.error('NaN values found in hypothesis')
                        logger.error("Inputs: {}".format(sample['id']))
                        raise RuntimeError 
                    g_loss = lossG(y_gen,y_gt)
                #--- reduce is not implemented, average is implemented in loss
                #--- function itself
                #g_loss = g_loss * (1/y_gen.data[0].numel())
                if opts.use_gan:
                    #nnD.apply(models.zero_bias)
                    optimizerD.zero_grad()
                    if opts.net_out_type == 'C':
                        if opts.out_mode == 'LR':
                            real_D = torch.cat([x,y_gt.type(torch.cuda.FloatTensor)],1)
                            #real_D = torch.cat([x,y_gt_D],1)
                            y_dis_real = nnD(real_D)
                            _, arg_l = torch.max(y_gen[0],dim=1,keepdim=True)
                            _, arg_r = torch.max(y_gen[1],dim=1,keepdim=True)
                            y_fake = torch.cat([arg_l,arg_r],1)
                            fake_D = torch.cat([x,y_fake.type(torch.cuda.FloatTensor)],1).detach()
                        elif opts.out_mode == 'L' or opts.out_mode == 'R':
                            real_D = torch.cat([x,torch.unsqueeze(y_gt.type(torch.cuda.FloatTensor),1)],1)
                            y_dis_real = nnD(real_D)
                            _, arg_y = torch.max(y_gen,dim=1)
                            fake_D = torch.cat([x,torch.unsqueeze(arg_y.type(torch.cuda.FloatTensor),1)],1).detach()
                        else:
                            pass
                    else:
                        real_D = torch.cat([x,y_gt.type(torch.cuda.FloatTensor)],1)
                        y_dis_real = nnD(real_D)
                        fake_D = torch.cat([x,y_gen],1).detach()
                    y_dis_fake = nnD(fake_D) 
                    label_D_size = y_dis_real.size()
                    real_y = Variable(torch.FloatTensor(label_D_size).fill_(1.0),
                                      requires_grad=False)
                    fake_y = Variable(torch.FloatTensor(label_D_size).fill_(0.0),
                                      requires_grad=False)
                    if opts.use_gpu:
                        real_y = real_y.cuda()
                        fake_y = fake_y.cuda()
                    d_loss_real = lossD(y_dis_real,real_y)
                    d_loss_fake = lossD(y_dis_fake,fake_y)
                    d_loss = (d_loss_real + d_loss_fake) * 0.5
                    epoch_lossD += d_loss.data[0]
                    d_loss.backward()
                    optimizerD.step()
                    if opts.net_out_type == 'C':
                        if opts.out_mode == 'LR':
                            _, arg_l = torch.max(y_gen[0],dim=1,keepdim=True)
                            _,arg_r = torch.max(y_gen[1],dim=1,keepdim=True)
                            y_fake = torch.cat([arg_l,arg_r],1)
                            g_fake = torch.cat([x,y_fake.type(torch.cuda.FloatTensor)],1)
                        elif opts.out_mode == 'L' or opts.out_mode == 'R':
                            _, arg_y = torch.max(y_gen,dim=1,keepdim=True)
                            g_fake = torch.cat([x,arg_y.type(torch.cuda.FloatTensor)],1)
                        else:
                            pass
                    else:
                        g_fake = torch.cat([x,y_gen],1)
                    g_y = nnD(g_fake)
                    shared_loss = lossD(g_y,real_y) 
                    epoch_lossR += shared_loss.data[0]
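                    #--- combined generator objective below: the adversarial term
                    #--- plus the task loss g_loss weighted by loss_lambda
                    #--- (pix2pix-style)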
                    gan_loss = (shared_loss + (g_loss * opts.loss_lambda))
                else:
                    gan_loss = g_loss
                epoch_lossG += g_loss.data[0] / y_gt.data.size()[0]
                epoch_lossGAN += gan_loss.data[0] / y_gt.data.size()[0]
                gan_loss.backward()
                optimizerG.step()
            #--- forward pass val
            if opts.do_val:
                val_loss = 0
                for v_batch,v_sample in enumerate(val_dataloader):
                    #--- set vars to volatile, since no backward is used
                    v_img = Variable(v_sample['image'], volatile=True)
                    v_label = Variable(v_sample['label'], volatile=True)
                    if opts.use_gpu:
                        v_img = v_img.cuda()
                        v_label = v_label.cuda()
                    v_y = nnG(v_img)
                    if opts.out_mode == 'LR' and opts.net_out_type == 'C':
                        v_l,v_r = torch.split(v_label,1,dim=1)
                        v_loss = lossG(v_y[0],torch.squeeze(v_l)) + lossG(v_y[1],torch.squeeze(v_r))
                    else:
                        v_loss = lossG(v_y, v_label)
                    #v_loss = v_loss * (1/v_y.data[0].numel())
                    val_loss += v_loss.data[0] / v_label.data.size()[0]
                val_loss = val_loss/v_batch
            #--- Write to Logs
            if not opts.no_display:
                writer.add_scalar('train/lossGAN',epoch_lossGAN/batch,epoch)
                writer.add_scalar('train/lossG',epoch_lossG/batch,epoch)
                writer.add_text('LOG', 'End of epoch {0} of {1} time Taken: {2:.3f} sec'.format(
                             str(epoch),str(opts.epochs),
                             time.time()-epoch_start), epoch)
                if opts.use_gan:
                    writer.add_scalar('train/lossD',epoch_lossD/batch,epoch)
                    writer.add_scalar('train/D_loss_Real',epoch_lossR/batch,epoch)
                if opts.do_val:
                    writer.add_scalar('val/lossG',val_loss,epoch)
            #--- Save model under val or min loss
            if opts.do_val:
                if best_val >= val_loss:
                    best_epoch = epoch
                    state = {
                            'nnG_state':            nnG.state_dict(),
                            'nnG_optimizer_state':  optimizerG.state_dict(),
                            'g_loss':               opts.g_loss
                            }
                    if opts.use_gan:
                        state['nnD_state'] =            nnD.state_dict()
                        state['nnD_optimizer_state'] =  optimizerD.state_dict()
                    best_model = save_checkpoint(state, True, opts, logger, epoch,
                                                 criterion='val' + opts.g_loss)
                    logger.info("New best model, from {} to {}".format(best_val,val_loss))
                    best_val = val_loss
            else:
                if best_tr >= epoch_lossG:
                    best_epoch = epoch
                    state = {
                            'nnG_state':            nnG.state_dict(),
                            'nnG_optimizer_state':  optimizerG.state_dict(),
                            'g_loss':               opts.g_loss
                            }
                    if opts.use_gan:
                        state['nnD_state'] =            nnD.state_dict()
                        state['nnD_optimizer_state'] =  optimizerD.state_dict()
                    best_model = save_checkpoint(state, True, opts, logger, epoch,
                                                 criterion=opts.g_loss)
                    logger.info("New best model, from {} to {}".format(best_tr,epoch_lossG))
                    best_tr = epoch_lossG
            #--- Save checkpoint
            if epoch%opts.save_rate == 0 or epoch == opts.epochs - 1:
                #--- save current model, to test load func
                state = {
                        'nnG_state':            nnG.state_dict(),
                        'nnG_optimizer_state':  optimizerG.state_dict(),
                        'g_loss':               opts.g_loss
                        }
                if opts.use_gan:
                    state['nnD_state'] =            nnD.state_dict()
                    state['nnD_optimizer_state'] =  optimizerD.state_dict()
                best_model = save_checkpoint(state, False, opts, logger, epoch)
        
        logger.info('Training stage done. total time taken: {}'.format(time.time()-train_start))
        #---- Train is done, next is to save validation inference
        if opts.do_val:
            logger.info('Working on validation inference...')
            res_path = os.path.join(opts.work_dir, 'results', 'val')
            try:
                os.makedirs(os.path.join(res_path,'page'))
                os.makedirs(os.path.join(res_path,'mask'))
            except OSError as exc:
                if exc.errno == errno.EEXIST and os.path.isdir(
                                    os.path.join(res_path,'page')):
                    pass
                else:
                    raise
            if opts.save_prob_mat:
                try:
                    os.makedirs(os.path.join(res_path,'prob_mat'))
                except OSError as exc:
                    if exc.errno == errno.EEXIST and os.path.isdir(res_path + '/prob_mat'):
                        pass
                    else:
                        raise
            #--- Set model to eval, to perform inference step 
            if best_epoch == epoch:
                nnG.eval()
                if opts.do_off:
                    nnG.apply(models.off_dropout)
            else:
                #--- load best model for inference
                checkpoint = torch.load(best_model)
                nnG.load_state_dict(checkpoint['nnG_state'])
                if opts.use_gpu:
                    nnG = nnG.cuda()
                nnG.eval()
                if opts.do_off:
                    nnG.apply(models.off_dropout)

            for v_batch,v_sample in enumerate(val_dataloader):
                #--- set vars to volatile, since no backward is used
                v_img = Variable(v_sample['image'], volatile=True)
                v_label = Variable(v_sample['label'], volatile=True)
                v_ids = v_sample['id']
                if opts.use_gpu:
                    v_img = v_img.cuda()
                    v_label = v_label.cuda()
                v_y_gen = nnG(v_img)
                if opts.save_prob_mat:
                    for idx,data in enumerate(v_y_gen.data):
                        fh = open(res_path + '/prob_mat/' + v_ids[idx] + '.pickle', 'wb')  # binary mode for pickle
                        pickle.dump(data.cpu().float().numpy(),fh,-1)
                        fh.close()
                if opts.net_out_type == 'C':
                    if opts.out_mode == 'LR':
                        _, v_l = torch.max(v_y_gen[0],dim=1,keepdim=True)
                        _, v_r = torch.max(v_y_gen[1],dim=1,keepdim=True)
                        v_y_gen = torch.cat([v_l, v_r],1)
                    elif opts.out_mode == 'L' or opts.out_mode == 'R':
                        _, v_y_gen = torch.max(v_y_gen,dim=1,keepdim=True)
                    else:
                        pass
                elif opts.net_out_type == 'R':
                    pass
                else:
                    pass
                #--- save out as image for visual check
                #for idx,data in enumerate(v_label.data):
                #    img = tensor2img(data)
                #    cv2.imwrite(os.path.join(res_path,
                #                             'mask', v_ids[idx] +'_gt.png'),img)
                for idx,data in enumerate(v_y_gen.data):
                    #img = tensor2img(data)
                    #cv2.imwrite(os.path.join(res_path,
                    #                         'mask', v_ids[idx] +'_out.png'),img)
                    va_data.gen_page(v_ids[idx],
                                   data.cpu().float().numpy(),
                                   opts.regions,
                                   approx_alg=opts.approx_alg,
                                   num_segments=opts.num_segments,
                                   out_folder=res_path)
            #--- metrics are taken over the generated PAGE-XML files instead
            #--- of the current data and label because the image size may differ
            #--- from the processed image, so the final image must be used
            #--- during evaluation
            va_results = page2page_eval.compute_metrics(va_data.hyp_xml_list,
                                                        va_data.gt_xml_list,
                                                        opts)
            logger.info('-'*10 + 'VALIDATION RESULTS SUMMARY' + '-'*10)
            logger.info(','.join(va_results.keys()))
            logger.info(','.join(str(x) for x in va_results.values()))
        if not opts.no_display:
            writer.close()
    
    #--------------------------------------------------------------------------
    #---    TEST INFERENCE
    #--------------------------------------------------------------------------
    if opts.do_test:
        logger.info('Working on test inference...')
        res_path = os.path.join(opts.work_dir, 'results', 'test')
        try:
            os.makedirs(os.path.join(res_path,'page'))
        except OSError as exc:
            if exc.errno == errno.EEXIST and os.path.isdir(res_path + '/page'):
                pass
            else:
                raise
        if opts.save_prob_mat:
            try:
                os.makedirs(os.path.join(res_path,'prob_mat'))
            except OSError as exc:
                if exc.errno == errno.EEXIST and os.path.isdir(res_path + '/prob_mat'):
                    pass
                else:
                    raise
        logger.info('Results will be saved to {}'.format(res_path))

        if nnG is None:
            #--- Load Model 
            nnG = models.buildUnet(opts.input_channels,
                                   opts.output_channels,
                                   ngf=opts.cnn_ngf,
                                   net_type=opts.net_out_type,
                                   out_mode=opts.out_mode)
            logger.info('Resuming from model {}'.format(opts.prev_model))
            checkpoint = torch.load(opts.prev_model)
            nnG.load_state_dict(checkpoint['nnG_state'])
            if opts.use_gpu:
                nnG = nnG.cuda()
            nnG.eval()
            if opts.do_off:
                nnG.apply(models.off_dropout)
            logger.debug('GEN Network:\n{}'.format(nnG)) 
            logger.debug('GEN Network, number of parameters: {}'.format(nnG.num_params))
        else:
            logger.debug('Using previously loaded Generative module for test...')
            nnG.eval()
            if opts.do_off:
                nnG.apply(models.off_dropout)

        #--- get test data
        test_start_time = time.time()
        if opts.te_img_list == '':
            logger.info('Preprocessing data from {}'.format(opts.te_data))
            te_data = dp.htrDataProcess(
                                         opts.te_data,
                                         os.path.join(opts.work_dir,'data','test'),
                                         opts,
                                         logger=logger)
            te_data.pre_process()
            opts.te_img_list = te_data.img_list
            opts.te_label_list = te_data.label_list
        
        transform = transforms.build_transforms(opts,train=False)

        test_data = dataset.htrDataset(img_lst=opts.te_img_list,
                                        label_lst=opts.te_label_list,
                                        transform=transform,
                                        opts=opts)
        test_dataloader = DataLoader(test_data,
                                      batch_size=opts.batch_size,
                                      shuffle=opts.shuffle_data,
                                      num_workers=opts.num_workers,
                                      pin_memory=opts.pin_memory)
        for te_batch,sample in enumerate(test_dataloader):
            te_x = Variable(sample['image'], volatile=True)
            te_label = Variable(sample['label'], volatile=True)
            te_ids = sample['id']
            if opts.use_gpu:
                te_x = te_x.cuda()
                te_label = te_label.cuda()
            te_y_gen = nnG(te_x)
            if opts.save_prob_mat:
                #--- save the raw network output as one binary pickle per page
                for idx, data in enumerate(te_y_gen.data):
                    with open(os.path.join(res_path, 'prob_mat', te_ids[idx] + '.pickle'), 'wb') as fh:
                        pickle.dump(data.cpu().float().numpy(), fh, -1)
            if opts.net_out_type == 'C':
                #--- classification output: argmax over the class channel
                #--- turns the per-class scores into per-pixel label maps
                if opts.out_mode == 'LR':
                    _, te_l = torch.max(te_y_gen[0], dim=1, keepdim=True)
                    _, te_r = torch.max(te_y_gen[1], dim=1, keepdim=True)
                    te_y_gen = torch.cat([te_l, te_r], 1)
                elif opts.out_mode == 'L' or opts.out_mode == 'R':
                    _, te_y_gen = torch.max(te_y_gen, dim=1, keepdim=True)
            elif opts.net_out_type == 'R':
                #--- regression output is used as-is
                pass

            for idx, data in enumerate(te_y_gen.data):
                #--- TODO: update this function to process C-dim tensors
                te_data.gen_page(te_ids[idx],
                                   data.cpu().float().numpy(),
                                   opts.regions,
                                   approx_alg=opts.approx_alg,
                                   num_segments=opts.num_segments,
                                   out_folder=res_path)
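            #--- gen_page (assumed) vectorizes each label map into region
            #--- polygons, controlled by approx_alg/num_segments, and writes
            #--- the corresponding PAGE-XML file under res_path.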
        test_end_time = time.time()
        logger.info('Test stage done. Total time taken: {}'.format(test_end_time - test_start_time))
        logger.info('Average time per page: {}'.format((test_end_time - test_start_time) / len(test_data)))
        #--- metrics are taken over the generated PAGE-XML files instead of
        #--- the current data and labels, because the original image size may
        #--- differ from the processed one; the final, full-size image must be
        #--- used during evaluation
        te_results = page2page_eval.compute_metrics(te_data.hyp_xml_list,
                                                    te_data.gt_xml_list,
                                                    opts, logger=logger) 
        logger.info('-'*10 + 'TEST RESULTS SUMMARY' + '-'*10)
        logger.info(','.join(te_results.keys()))
        logger.info(','.join(str(x) for x in te_results.values()))
    #--------------------------------------------------------------------------
    #---    PRODUCTION INFERENCE
    #--------------------------------------------------------------------------
    if opts.do_prod:
        logger.info('Working on prod inference...')
        res_path = os.path.join(opts.work_dir, 'results', 'prod')
        try:
            os.makedirs(os.path.join(res_path, 'page'))
        except OSError as exc:
            if exc.errno == errno.EEXIST and os.path.isdir(os.path.join(res_path, 'page')):
                pass
            else:
                raise
        if opts.save_prob_mat:
            try:
                os.makedirs(os.path.join(res_path, 'prob_mat'))
            except OSError as exc:
                if exc.errno == errno.EEXIST and os.path.isdir(os.path.join(res_path, 'prob_mat')):
                    pass
                else:
                    raise
        logger.info('Results will be saved to {}'.format(res_path))

        if nnG is None:
            #--- Load Model 
            nnG = models.buildUnet(opts.input_channels,
                                   opts.output_channels,
                                   ngf=opts.cnn_ngf,
                                   net_type=opts.net_out_type,
                                   out_mode=opts.out_mode)
            logger.info('Resuming from model {}'.format(opts.prev_model))
            checkpoint = torch.load(opts.prev_model)
            nnG.load_state_dict(checkpoint['nnG_state'])
            if opts.use_gpu:
                nnG = nnG.cuda()
            nnG.eval()
            if opts.do_off:
                nnG.apply(models.off_dropout)
            logger.debug('GEN Network:\n{}'.format(nnG)) 
            logger.debug('GEN Network, number of parameters: {}'.format(nnG.num_params))
        else:
            logger.debug('Using previously loaded Generative module for prod...')
            nnG.eval()
            if opts.do_off:
                nnG.apply(models.off_dropout)

        #--- get prod data
        prod_start_time = time.time()
        pr_data = dp.htrDataProcess(
                                    opts.prod_data,
                                    os.path.join(opts.work_dir,'data','prod'),
                                    opts,
                                    build_labels=False,
                                    logger=logger)
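        #--- build_labels=False: production data comes without ground-truth
        #--- PAGE-XML, so pre-processing only generates the image list.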
        if opts.prod_img_list == '':
            logger.info('Preprocessing data from {}'.format(opts.prod_data))
            pr_data.pre_process()
            opts.prod_img_list = pr_data.img_list
        else:
            logger.info('Loading pre-processed data from {}'.format(opts.prod_img_list))
            pr_data.set_img_list(opts.prod_img_list)

        
        transform = transforms.build_transforms(opts,train=False)

        prod_data = dataset.htrDataset(img_lst=opts.prod_img_list,
                                       transform=transform,
                                       opts=opts)
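        #--- no label_lst here: production images are processed without
        #--- ground truth, so each sample only provides an image and its id.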
        prod_dataloader = DataLoader(prod_data,
                                      batch_size=opts.batch_size,
                                      shuffle=opts.shuffle_data,
                                      num_workers=opts.num_workers,
                                      pin_memory=opts.pin_memory)
        for pr_batch,sample in enumerate(prod_dataloader):
            pr_x = Variable(sample['image'], volatile=True)
            pr_ids = sample['id']
            if opts.use_gpu:
                pr_x = pr_x.cuda()
            pr_y_gen = nnG(pr_x)
            if opts.save_prob_mat:
                #--- save the raw network output as one binary pickle per page
                for idx, data in enumerate(pr_y_gen.data):
                    with open(os.path.join(res_path, 'prob_mat', pr_ids[idx] + '.pickle'), 'wb') as fh:
                        pickle.dump(data.cpu().float().numpy(), fh, -1)
            if opts.net_out_type == 'C':
                #--- same output decoding as in the test stage
                if opts.out_mode == 'LR':
                    _, pr_l = torch.max(pr_y_gen[0], dim=1, keepdim=True)
                    _, pr_r = torch.max(pr_y_gen[1], dim=1, keepdim=True)
                    pr_y_gen = torch.cat([pr_l, pr_r], 1)
                elif opts.out_mode == 'L' or opts.out_mode == 'R':
                    _, pr_y_gen = torch.max(pr_y_gen, dim=1, keepdim=True)
            elif opts.net_out_type == 'R':
                pass
            for idx, data in enumerate(pr_y_gen.data):
                #--- TODO: update this function to process C-dim tensors on the GPU
                pr_data.gen_page(pr_ids[idx],
                                   data.cpu().float().numpy(),
                                   opts.regions,
                                   approx_alg=opts.approx_alg,
                                   num_segments=opts.num_segments,
                                   out_folder=res_path)
        prod_end_time = time.time()
        logger.info('Production stage done. Total time taken: {}'.format(prod_end_time - prod_start_time))
        logger.info('Average time per page: {}'.format((prod_end_time - prod_start_time) / len(prod_data)))

    logger.info('All Done...')