예제 #1
0
        test_opts.num_classes = 20

    if test_opts.data_type == 'coco':
        test_set = COCO_Dataset(root=test_opts.data_root,
                                set_name='val2017',
                                split='test',
                                resize=600)
        test_opts.num_classes = 80

    # 5. data loader
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              collate_fn=test_set.collate_fn,
                                              shuffle=False,
                                              num_workers=0)
    # 6. network
    model = RetinaNet(num_classes=test_opts.num_classes).to(device)
    model = torch.nn.DataParallel(module=model, device_ids=device_ids)
    coder = RETINA_Coder(opts=test_opts)

    # 7. loss
    criterion = Focal_Loss(coder)

    test(epoch=test_opts.epoch,
         vis=vis,
         test_loader=test_loader,
         model=model,
         criterion=criterion,
         coder=coder,
         opts=test_opts)
예제 #2
0
            print('writing {}...'.format(detection_file))
            detections = {'annotations': detections}
            detections['images'] = data_iterator.coco.dataset['images']
            detections['categories'] = [
                data_iterator.coco.dataset['categories']
            ]
            json.dump(detections, open(detection_file, 'w'), indent=4)

            print('evaluating model...')
            coco_pred = data_iterator.coco.loadRes(detections['annotations'])
            coco_eval = COCOeval(data_iterator.coco, coco_pred, 'bbox')
            coco_eval.evaluate()
            coco_eval.accumulate()
            coco_eval.summarize()
        else:
            print('no detections!')


if __name__ == '__main__':
    # Entry point: parse CLI options, join the NCCL process group, build the
    # model and run infer() in this process.
    # NOTE(review): `--local_rank` is presumably supplied by a launcher such as
    # torch.distributed.launch (one process per GPU) — confirm.
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", default=0, type=int)
    # Passed through to infer() inside `args`; its use is not visible here.
    parser.add_argument("--epoch", default='final', type=str)
    args = parser.parse_args()
    # Bind this process to its local GPU before creating the NCCL group.
    torch.cuda.set_device(args.local_rank)
    # init_method='env://' takes the rendezvous settings from environment
    # variables (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE).
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    # `resnet_dir` and `stride` are not defined in the visible portion;
    # presumably module-level constants — confirm.
    model = RetinaNet(state_dict_path=resnet_dir, stride=stride)
    # Only the process with local_rank 0 prints, to avoid duplicated logs.
    if args.local_rank == 0:
        print('FPN initialized!')
    infer(model, args)
예제 #3
0
        print('preparing dataset...')
    data_iterator = DataIterator(coco_dir,
                                 resize=resize,
                                 max_size=max_size,
                                 batch_size=batch_size,
                                 stride=stride,
                                 training=training,
                                 dist=dist)
    if rank == 0:
        print('finish loading dataset!')

    results = []
    with torch.no_grad():
        for i, (data, ids, ratios) in enumerate(data_iterator, start=1):
            scores, boxes, classes = model(data)
            results.append([scores, boxes, classes, ids, ratios])
            if rank == 0:
                size = len(data_iterator.ids)
                msg = '[{:{len}}/{}]'.format(min(i * batch_size, size),
                                             size,
                                             len=len(str(size)))
                print(msg, flush=True)

    results = [torch.cat(r, dim=0) for r in zip(*results)]
    results = [r.cpu() for r in results]


if __name__ == '__main__':
    # Entry point: build RetinaNet from the backbone weights at `resnet_dir`
    # and run infer() on it.
    # NOTE(review): `resnet_dir` is not defined in the visible portion;
    # presumably a module-level constant — confirm.
    model = RetinaNet(state_dict_path=resnet_dir)
    infer(model)
예제 #4
0
    parser.add_argument('--data_type',
                        type=str,
                        default='coco',
                        help='choose voc or coco')
    parser.add_argument('--num_classes', type=int, default=80)
    demo_opts = parser.parse_args()
    print(demo_opts)

    if demo_opts.data_type == 'voc':
        demo_opts.n_classes = 20

    elif demo_opts.data_type == 'coco':
        demo_opts.n_classes = 80

    model = RetinaNet(num_classes=demo_opts.num_classes)

    model = torch.nn.DataParallel(module=model, device_ids=device_ids)

    # use custom training pth file
    checkpoint = torch.load(
        os.path.join(demo_opts.save_path, demo_opts.save_file_name) +
        '.{}.pth.tar'.format(demo_opts.epoch),
        map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'], strict=True)

    model = model.to(device)
    model.eval()

    coder = RETINA_Coder(opts=demo_opts)
    demo(demo_opts, coder, model, 'jpg')
예제 #5
0
from dataset import CocoDataset
from model import RetinaNet

if __name__ == '__main__':
    # Smoke test: push the first dataset sample through RetinaNet's forward
    # pass and print whatever losses it returns.
    dataset = CocoDataset()
    sample = dataset[0]
    net = RetinaNet()

    # `.data` unwraps each field (presumably mmcv DataContainer objects —
    # confirm); unsqueeze(0) adds a leading axis, presumably the batch axis.
    batched_img = sample['img'].data.unsqueeze(0)
    img_meta = sample['img_meta'].data
    gt_bboxes = sample['gt_bboxes'].data
    gt_labels = sample['gt_labels'].data

    losses = net(batched_img, img_meta, gt_bboxes, gt_labels)
    print(losses)
예제 #6
0
def main():
    """Train RetinaNet on PASCAL VOC or COCO, evaluating after every epoch.

    Options are read with ``parse(sys.argv[1:])``. The function builds the
    train/test datasets and loaders for ``opts.data_type``, the
    ``DataParallel``-wrapped model, the box coder, the focal loss, an SGD
    optimizer and a ``MultiStepLR`` scheduler (lr multiplied by 0.1 at
    epochs 30 and 45). When ``opts.start_epoch`` is non-zero, model,
    optimizer and scheduler state are restored from the checkpoint saved for
    epoch ``start_epoch - 1``. Then, for every epoch in
    ``range(opts.start_epoch, opts.epoch)``, it trains, tests, and steps the
    scheduler once.

    Side effect: sets ``opts.num_classes`` (20 for VOC, 80 for COCO).

    Raises:
        ValueError: if ``opts.data_type`` is neither ``'voc'`` nor ``'coco'``.
    """

    # 1. argparser
    opts = parse(sys.argv[1:])
    print(opts)

    # 3. visdom
    vis = visdom.Visdom(port=opts.port)

    # 4. data set
    if opts.data_type == 'voc':
        train_set = VOC_Dataset(root=opts.data_root, split='train', resize=opts.resize)
        test_set = VOC_Dataset(root=opts.data_root, split='test', resize=opts.resize)
        opts.num_classes = 20

    elif opts.data_type == 'coco':
        train_set = COCO_Dataset(root=opts.data_root, set_name='train2017', split='train', resize=opts.resize)
        test_set = COCO_Dataset(root=opts.data_root, set_name='val2017', split='test', resize=opts.resize)
        opts.num_classes = 80

    else:
        # Without this, the datasets stayed None and the DataLoader below
        # failed with an opaque AttributeError on `None.collate_fn`.
        raise ValueError(
            "unsupported data_type {!r}; expected 'voc' or 'coco'".format(opts.data_type))

    # 5. data loader
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=opts.batch_size,
                                               collate_fn=train_set.collate_fn,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              collate_fn=test_set.collate_fn,
                                              shuffle=False,
                                              num_workers=2,
                                              pin_memory=True)

    # 6. network
    model = RetinaNet(num_classes=opts.num_classes).to(device)
    model = torch.nn.DataParallel(module=model, device_ids=device_ids)
    coder = RETINA_Coder(opts=opts)  # there is center_anchor in coder.

    # 7. loss
    criterion = Focal_Loss(coder=coder)

    # 8. optimizer
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=opts.lr,
                                momentum=opts.momentum,
                                weight_decay=opts.weight_decay)

    # 9. scheduler
    scheduler = MultiStepLR(optimizer=optimizer, milestones=[30, 45], gamma=0.1)

    # 10. resume
    if opts.start_epoch != 0:
        # Load the checkpoint saved for the previous epoch and continue
        # training from opts.start_epoch.
        resume_epoch = opts.start_epoch - 1
        checkpoint_path = os.path.join(opts.save_path, opts.save_file_name) + '.{}.pth.tar'.format(resume_epoch)
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])          # load model state dict
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])  # load optim state dict
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])  # load sched state dict
        print('\nLoaded checkpoint from epoch %d.\n' % resume_epoch)

    else:
        print('\nNo check point to resume.. train from scratch.\n')

    for epoch in range(opts.start_epoch, opts.epoch):

        # 11. train
        train(epoch=epoch,
              vis=vis,
              train_loader=train_loader,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              scheduler=scheduler,
              opts=opts)

        # 12. test
        test(epoch=epoch,
             vis=vis,
             test_loader=test_loader,
             model=model,
             criterion=criterion,
             coder=coder,
             opts=opts)

        # The MultiStepLR milestones are in epochs, so step once per epoch.
        scheduler.step()
예제 #7
0
from dataloader import FaceMask
from box import LabelEncoder

# NOTE(review): `tf`, `glob`, `get_backbone`, `RetinaNetLoss` and `RetinaNet`
# are used below but are not imported in the visible portion of this script;
# presumably imported elsewhere — confirm.

# Two detection classes; `classes_name` lists their names (presumably in
# label-id order — confirm).
num_classes = 2
classes_name = ['face', 'mask']
# Not used in the visible portion; presumably used later when batching.
batch_size = 2
# Presumably encodes ground-truth boxes into per-anchor training targets —
# confirm against box.LabelEncoder.
label_encoder = LabelEncoder()

# Piecewise-constant learning rate indexed by optimizer step: values[i]
# applies between boundaries[i-1] and boundaries[i], so there is one more
# value than boundaries. The first three boundaries ramp the rate up
# (2.5e-06 -> 0.0025, a warm-up); the last two each cut it by 10x.
learning_rates = [2.5e-06, 0.000625, 0.00125, 0.0025, 0.00025, 2.5e-05]
learning_rate_boundaries = [125, 250, 500, 240000, 360000]
learning_rate_fn = tf.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=learning_rate_boundaries, values=learning_rates)

resnet50_backbone = get_backbone()
loss_fn = RetinaNetLoss(num_classes)
model = RetinaNet(num_classes, resnet50_backbone)

# SGD with momentum driven by the step-based schedule above.
optimizer = tf.optimizers.SGD(learning_rate=learning_rate_fn, momentum=0.9)
model.compile(loss=loss_fn, optimizer=optimizer)

# Image and label paths are sorted separately, so pairing relies on
# matching file stems in the images/ and labels/ directories.
train_img_paths = sorted(glob('./facedataset/images/train/*.jpg'))
train_labels_paths = sorted(glob('./facedataset/labels/train/*.txt'))
val_img_paths = sorted(glob('./facedataset/images/val/*.jpg'))
val_labels_paths = sorted(glob('./facedataset/labels/val/*.txt'))

train_gen = FaceMask(train_img_paths, train_labels_paths)
# Each element is (image, boxes, class_ids): a 640x640x3 float32 image,
# an (N, 4) float32 box array and an (N,) int32 class-id vector.
# NOTE(review): the lambda returns the SAME FaceMask object every time the
# dataset is iterated; if FaceMask is a one-shot generator rather than a
# re-iterable, epochs after the first would see it exhausted — confirm.
# `output_types`/`output_shapes` are deprecated in favor of
# `output_signature` in newer TensorFlow releases.
train_dataset = tf.data.Dataset.from_generator(
    lambda: train_gen,
    output_types=(tf.float32, tf.float32, tf.int32),
    output_shapes=((640, 640, 3), (None, 4), (None, )))
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"

from model import get_backbone, RetinaNet, DecodePredictions
from dataloader import FaceMask, resize_and_pad_image
from glob import glob

# Directory where training wrote its checkpoints, and the detection class
# names.
weights_dir = "retinanet"
classes_name = ['face', 'mask']
resnet50_backbone = get_backbone()
# Derive the class count from the names so the two cannot drift apart.
model = RetinaNet(len(classes_name), resnet50_backbone)

# tf.train.latest_checkpoint returns None when `weights_dir` holds no
# checkpoint; fail with a clear message instead of handing None to
# load_weights.
latest_checkpoint = tf.train.latest_checkpoint(weights_dir)
if latest_checkpoint is None:
    raise FileNotFoundError(
        "no checkpoint found in {!r}".format(weights_dir))
model.load_weights(latest_checkpoint)

# End-to-end inference graph: an image of any height/width goes in, and
# decoded detections come out.
# DecodePredictions is project-defined; it presumably drops detections
# scoring below 0.5 — confirm.
image = tf.keras.Input(shape=[None, None, 3], name="image")
predictions = model(image, training=False)
detections = DecodePredictions(confidence_threshold=0.5)(image, predictions)
inference_model = tf.keras.Model(inputs=image, outputs=detections)


def visualize_detections(image,
                         boxes,
                         classes,
                         scores,
                         figsize=(7, 7),
                         linewidth=1,
                         color=[0, 1, 1]):