# Example #1
# 0
def setup(file, device=None, topk=9, ims_per_batch=2):
    """Build a config, one training batch, its RPN anchors and two matchers.

    The returned objects let a caller compare the plain ``Matcher`` against
    ``TopKMatcher`` on anchors generated for a real batch.

    Args:
        file: path of a config file merged on top of ``get_cfg()``.
        device: device for the backbone, images and anchors. ``None`` (the
            default) means ``torch.device("cuda")``, matching the other
            helpers that build models on CUDA.
        topk: third positional argument passed to ``TopKMatcher``
            (previously hard-coded to 9). NOTE(review): presumably the
            number of anchors kept per ground-truth box — confirm.
        ims_per_batch: value written to ``cfg.SOLVER.IMS_PER_BATCH`` before
            the train loader is built (previously hard-coded to 2).

    Returns:
        ``(cfg, data_loader_iter, anchors, matcher, raw_matcher)``:
        ``data_loader_iter`` has already yielded one batch, ``anchors`` is a
        single ``Boxes`` concatenated over all ``cfg.MODEL.RPN.IN_FEATURES``
        levels, ``matcher`` is the ``TopKMatcher`` and ``raw_matcher`` the
        ``Matcher`` built with ``allow_low_quality_matches=True``.
    """
    import torch

    if device is None:
        # The original read an undefined ``device`` name here, which raised
        # NameError unless a matching global happened to exist.
        device = torch.device("cuda")

    # get cfg
    cfg = get_cfg()
    cfg.merge_from_file(file)
    cfg.SOLVER.IMS_PER_BATCH = ims_per_batch

    # get data loader iter
    data_loader = build_detection_train_loader(cfg)
    data_loader_iter = iter(data_loader)
    batched_inputs = next(data_loader_iter)

    # build anchors
    backbone = build_backbone(cfg).to(device)
    in_features = cfg.MODEL.RPN.IN_FEATURES
    # Only forward values are needed to generate anchors; no_grad avoids
    # building an autograd graph through the whole backbone.
    with torch.no_grad():
        images = [x["image"].to(device) for x in batched_inputs]
        images = ImageList.from_tensors(images, backbone.size_divisibility)
        features = backbone(images.tensor.float())

        input_shape = backbone.output_shape()
        anchor_generator = build_anchor_generator(
            cfg, [input_shape[f] for f in in_features])
        anchors = anchor_generator([features[f] for f in in_features])
    anchors = Boxes.cat(anchors).to(device)

    # build matcher
    raw_matcher = Matcher(cfg.MODEL.RPN.IOU_THRESHOLDS,
                          cfg.MODEL.RPN.IOU_LABELS,
                          allow_low_quality_matches=True)
    matcher = TopKMatcher(cfg.MODEL.RPN.IOU_THRESHOLDS,
                          cfg.MODEL.RPN.IOU_LABELS, topk)

    return cfg, data_loader_iter, anchors, matcher, raw_matcher
# Example #2
# 0
def setup(args):
    """Return a frozen config assembled from command-line ``args``.

    ``args.config_file`` is merged first (skipped when it is falsy), then the
    ``args.opts`` key/value overrides are applied before freezing.
    """
    config = get_cfg()
    config_path = args.config_file
    if config_path:
        config.merge_from_file(config_path)
    config.merge_from_list(args.opts)
    config.freeze()
    return config
# Example #3
# 0
def test_dataloader():
    """Interactively inspect and visualize batches of the Objects365 val loader.

    Breaks into ``ipdb`` once with the first batch bound to ``data``. It then
    walks every batch, including that first one, and saves a ground-truth
    overlay per image as ``/data/tmp/vis_obj365_val_<image_id>.png``. After
    each image it breaks into ``ipdb`` again.
    """
    import itertools

    from slender_det.data import build_detection_test_loader
    from slender_det.config import get_cfg

    cfg = get_cfg()
    cfg.DATASETS.TEST = ("objects365_val", )

    data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    data_iter = iter(data_loader)

    # Peek at the first batch for interactive inspection.
    data = next(data_iter)
    ipdb.set_trace()

    # Chain the peeked batch back in: iterating data_iter alone would
    # silently skip it, since next() above already consumed it.
    for data in itertools.chain([data], data_iter):
        for dataset_dict in data:
            img = dataset_dict['image']
            img = convert_image_to_rgb(img.permute(1, 2, 0), cfg.INPUT.FORMAT)
            # NOTE(review): test-time mappers may drop annotations; this
            # assumes the slender_det test loader keeps "instances".
            instances = dataset_dict["instances"]
            # Box-only datasets carry no gt_masks field; accessing it
            # unconditionally would raise AttributeError.
            masks = instances.gt_masks if instances.has("gt_masks") else None
            v_gt = Visualizer(img, None)
            v_gt = v_gt.overlay_instances(
                boxes=instances.gt_boxes,
                masks=masks,
            )
            v_gt.save('/data/tmp/vis_obj365_val_{}.png'.format(
                dataset_dict['image_id']))

            ipdb.set_trace()
# Example #4
# 0
def setup(file):
    """Return the default config with ``file`` merged in and batch size 2.

    The config is left unfrozen so callers can keep overriding it.
    """
    config = get_cfg()
    config.merge_from_file(file)
    # Override the solver batch size to two images.
    config.SOLVER.IMS_PER_BATCH = 2
    return config
# Example #5
# 0
def test_model(cfg_file):
    """Build the model described by ``cfg_file`` on CUDA, then break into ``ipdb``.

    Meant for interactive inspection: ``cfg`` and ``model`` stay bound as
    locals so they can be examined from the debugger prompt.
    """
    cfg = get_cfg()
    cfg.merge_from_file(cfg_file)
    cfg.SOLVER.IMS_PER_BATCH = 1

    model = build_model(cfg).to(torch.device("cuda"))
    ipdb.set_trace()
# Example #6
# 0
def setup(args):
    """Create the frozen config for ``args`` and run ``default_setup`` on it.

    ``args.config_file`` is merged first, then the ``args.opts`` overrides.
    The config is frozen before ``default_setup(cfg, args)`` is called, and
    it is then returned.
    """
    config = get_cfg()
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    default_setup(config, args)
    return config
# Example #7
# 0
def test_training(cfg_file):
    """Run one training-mode forward pass on a real batch, then break into ``ipdb``.

    The config from ``cfg_file`` is used with a batch size of 2. The first
    batch from the train loader has its length printed, and its first two
    entries are fed to the CUDA model in train mode. The result is left in
    ``outs`` for inspection at the debugger prompt.
    """
    cfg = get_cfg()
    cfg.merge_from_file(cfg_file)
    cfg.SOLVER.IMS_PER_BATCH = 2

    data = next(iter(build_detection_train_loader(cfg)))
    print(len(data))

    model = build_model(cfg).to(torch.device("cuda"))
    model.train()
    outs = model(data[:2])

    ipdb.set_trace()
# Example #8
# 0
import fire

import torch

from detectron2.data import MetadataCatalog, build_detection_train_loader
from detectron2.checkpoint import DetectionCheckpointer

import init_paths
from slender_det.config import get_cfg
from slender_det.data import BorderMaskMapper
from slender_det.modeling.meta_arch.fcos import FCOS, FCOSHead

# Module-level fixtures. Everything below runs at import time: it loads the
# config, builds the FCOS model on CUDA and pulls one training batch. Importing
# this module therefore needs the config file and a CUDA device.

# get cfg
cfg = get_cfg()
# Relative path: resolved against the current working directory.
cfg.merge_from_file("configs/fcos/Base-Fcos.yaml")

# get model
device = torch.device("cuda")
model = FCOS(cfg).to(device)

# get batch data
data_loader = build_detection_train_loader(cfg)
data_loader_iter = iter(data_loader)
# One batch, read by test_training through the module-level name ``data``.
data = next(data_loader_iter)


def test_training():
    """Run one training-mode forward pass of the module-level FCOS ``model``.

    The first two entries of the module-level ``data`` batch are used. The
    trailing ``import pdb`` only imports the debugger module; it does not set
    a breakpoint.
    """
    model.train()
    first_two = data[:2]
    outs = model(first_two)

    import pdb
# Example #9
# 0
def load_cfg(cfg_file):
    """Return the default config with ``cfg_file`` merged in (left unfrozen)."""
    config = get_cfg()
    config.merge_from_file(cfg_file)
    return config