Ejemplo n.º 1
0
def run_test():
    """Run single-image RetinaNet inference and save an annotated copy.

    Loads the checkpoint named by ``args.checkpoint``, runs the image at
    ``args.img_path`` through the network, decodes boxes/labels/scores and
    writes the result (boxes + class/score tags) to ``./rst/rst.png``.
    """
    print('Loading model..')
    net = RetinaNet(args.num_classes)

    checkpoint = torch.load(args.checkpoint)
    net.load_state_dict(checkpoint['net'])
    net.eval()
    net.cuda()

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        # ImageNet mean/std normalisation
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    print('Loading image..')
    img = Image.open(args.img_path)
    width, height = img.size

    print('Predicting..')
    batch = preprocess(img).unsqueeze(0)  # add the batch dimension
    with torch.no_grad():
        loc_preds, cls_preds = net(batch.cuda())

        print('Decoding..')
        encoder = DataEncoder()
        boxes, labels, scores = encoder.decode(
            loc_preds.cpu().data.squeeze(),
            cls_preds.cpu().data.squeeze(),
            (width, height))

        label_map = load_pickled_label_map()

        draw = ImageDraw.Draw(img, 'RGBA')
        # NOTE(review): hard-coded font path; ImageFont.getsize was removed
        # in Pillow 10 — confirm the pinned Pillow version before upgrading.
        fnt = ImageFont.truetype('Pillow/Tests/fonts/DejaVuSans.ttf', 11)
        for box, label, score in zip(boxes, labels, scores):
            draw.rectangle(list(box), outline=(255, 0, 0, 200))

            # "classname: 0.87" tag, drawn on a translucent background strip
            item_tag = '{0}: {1:.2f}'.format(label_map[label.item()], score)
            iw, ih = fnt.getsize(item_tag)
            ix, iy = list(box[:2])
            draw.rectangle((ix, iy, ix + iw, iy + ih), fill=(255, 0, 0, 100))
            draw.text(list(box[:2]),
                      item_tag,
                      font=fnt,
                      fill=(255, 255, 255, 255))

        img.save(os.path.join('./rst', 'rst.png'), 'PNG')
Ejemplo n.º 2
0
def test():
    """Run RetinaNet detection on the input given by ``args.img``.

    Loads the checkpoint at ``args.pth``; if ``args.onnx`` is set, exports
    a fixed-size (1x3x416x416) ONNX graph and returns.  Otherwise dispatches
    to image or video detection based on the input's file extension.
    """
    print('initializing network...')
    network = RetinaNet(3, 10, 9)
    checkpoint = torch.load(args.pth)
    network.load_state_dict(checkpoint['net'])
    network = network.cuda().eval()
    if args.onnx:
        dummy_input = torch.randn(1, 3, 416, 416, device='cuda')
        torch.onnx.export(network, dummy_input, "retina-bdd.onnx", verbose=True)
        return
    class_names = ["bus", "traffic light", "traffic sign", "person", "bike",
                   "truck", "motor", "car", "train", "rider"]

    image = args.img
    # Lower-case the extension so '.JPG', '.PNG', ... are recognized too.
    img_tail = image.split('.')[-1].lower()
    if img_tail in ('jpg', 'jpeg', 'png'):
        detect_image(image, network, args.thresh, class_names)
    elif img_tail in ('mp4', 'mkv', 'avi', '0'):
        # '0' selects the default camera device instead of a file.
        detect_vedio(image, network, args.thresh, class_names)
    else:
        print('unknown image type!!!')
Ejemplo n.º 3
0
def train():
    """Train RetinaNet (3 channels, 10 classes, 9 anchors) on the BDD dataset.

    Uses gradient accumulation over ``args.substep`` sub-batches, optional
    visdom plotting (``args.vis``), DataParallel when ``args.ngpus > 1``, and
    can resume from the backup checkpoint (``args.resume``).  Writes a backup
    checkpoint every epoch, a numbered checkpoint every 10 epochs, and the
    whole model object at the end.
    """
    max_epoch = 120
    lr = 0.001
    step_epoch = 50   # decay the learning rate every `step_epoch` epochs
    lr_decay = 0.1
    train_batch_size = 64
    val_batch_size = 16
    if args.vis:
        vis = visdom.Visdom(env=u'test1')
    # dataset
    print('importing dataset...')
    substep = args.substep
    trainset = bdd.bddDataset(416, 416)
    loader_train = data.DataLoader(trainset,
                                   batch_size=train_batch_size // substep,
                                   shuffle=True,
                                   num_workers=4,
                                   drop_last=True)
    valset = bdd.bddDataset(416, 416, train=0)
    loader_val = data.DataLoader(valset,
                                 batch_size=val_batch_size // substep,
                                 shuffle=True,
                                 num_workers=4,
                                 drop_last=True)
    # model
    print('initializing network...')
    network = RetinaNet(3, 10, 9)
    if args.resume:
        print('Resuming from checkpoint..')
        checkpoint = torch.load('./checkpoint/retina-bdd-backup.pth')
        network.load_state_dict(checkpoint['net'])
        best_loss = checkpoint['loss']  # NOTE(review): restored but never used below
        start_epoch = checkpoint['epoch']
    else:
        start_epoch = 0
    if args.ngpus > 1:
        net = torch.nn.DataParallel(network).cuda()
    else:
        net = network.cuda()
    # criterion
    criterion = FocalLoss(10, 4, 9)
    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-4)
    # start training
    for i in range(start_epoch, max_epoch):
        print('--------start training epoch %d --------' % i)
        trainset.seen = 0
        valset.seen = 0
        loss_train = 0.0
        net.train()
        t0 = time.time()
        optimizer.zero_grad()
        for ii, (image, cls_truth, box_truth) in enumerate(loader_train):
            # `Variable` is a deprecated no-op since torch 0.4; move to GPU directly.
            image = image.cuda()
            cls_truth = cls_truth.cuda()
            box_truth = box_truth.cuda()
            # forward
            cls_pred, box_pred = net(image)
            # loss
            loss = criterion(cls_pred, box_pred, cls_truth, box_truth)
            # backward: gradients accumulate across `substep` sub-batches
            loss.backward()
            # update once per full (accumulated) batch
            if (ii + 1) % substep == 0:
                optimizer.step()
                optimizer.zero_grad()
            loss_train += loss.data
            print('%3d/%3d => loss: %f, cls_loss: %f, box_loss: %f' %
                  (ii, i, criterion.loss, criterion.cls_loss,
                   criterion.box_loss))
            if args.vis:
                vis.line(Y=loss.data.cpu().view(1, 1).numpy(),
                         X=np.array([ii]),
                         win='loss',
                         update='append' if ii > 0 else None)
        t1 = time.time()
        print('---one training epoch time: %fs---' % ((t1 - t0)))
        if i < 3:
            # early epochs are very noisy: report only the last batch's loss
            loss_train = loss.data
        else:
            # ii is the last 0-based batch index, so the batch count is ii + 1
            # (dividing by ii was an off-by-one and crashed on a single batch)
            loss_train = loss_train / (ii + 1)
        loss_val = 0.0
        net.eval()
        # no_grad: validation needs no autograd graphs (saves memory and time)
        with torch.no_grad():
            for jj, (image, cls_truth, box_truth) in enumerate(loader_val):
                image = image.cuda()
                cls_truth = cls_truth.cuda()
                box_truth = box_truth.cuda()
                cls_pred, box_pred = net(image)
                loss = criterion(cls_pred, box_pred, cls_truth, box_truth)
                loss_val += loss.data
                print('val: %3d/%3d => loss: %f, cls_loss: %f, box_loss: %f' %
                      (jj, i, criterion.loss, criterion.cls_loss,
                       criterion.box_loss))
        loss_val = loss_val / (jj + 1)
        if args.vis:
            vis.line(Y=torch.cat((loss_val.view(1,1), loss_train.view(1,1)),1).cpu().numpy(),X=np.array([i]),\
                        win='eval-train loss',update='append' if i>0 else None)
        print('Saving weights...')
        if args.ngpus > 1:
            # DataParallel wraps the model: save the underlying module's weights
            state = {
                'net': net.module.state_dict(),
                'loss': loss_val,
                'epoch': i,
            }
        else:
            state = {
                'net': net.state_dict(),
                'loss': loss_val,
                'epoch': i,
            }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        if (i + 1) % 10 == 0:
            torch.save(state, './checkpoint/retina-bdd-%03d.pth' % i)
        torch.save(state, './checkpoint/retina-bdd-backup.pth')
        if (i + 1) % step_epoch == 0:
            lr = lr * lr_decay
            print('learning rate: %f' % lr)
            # Update the LR in place instead of recreating the optimizer,
            # which would discard Adam's running moment estimates.
            for group in optimizer.param_groups:
                group['lr'] = lr
    torch.save(network, 'retina-bdd_final.pkl')
    print('finished training!!!')
Ejemplo n.º 4
0
def run_train():
    """Train RetinaNet on ListDataset splits, checkpointing on best test loss.

    Resumes from ``config.checkpoint_filename`` when it exists, otherwise
    starts from pretrained ResNet weights (importing them first if needed).
    Runs one train + one test pass per epoch; the checkpoint is rewritten
    whenever the average test loss improves.
    """
    assert torch.cuda.is_available(), 'Error: CUDA not found!'
    start_epoch = 0  # start from epoch 0 or last epoch
    # Best (lowest) average test loss seen so far; overwritten when resuming.
    # BUG FIX: `test` below previously declared `global best_loss`, but the
    # name was only ever assigned as a local here, so the first comparison
    # raised NameError.  Keep it in this scope and use `nonlocal` instead.
    best_loss = float('inf')

    # Data
    print('Load ListDataset')
    transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    trainset = ListDataset(img_dir=config.img_dir,
                           list_filename=config.train_list_filename,
                           label_map_filename=config.label_map_filename,
                           train=True,
                           transform=transform,
                           input_size=config.img_res)
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=8,
        collate_fn=trainset.collate_fn)

    testset = ListDataset(img_dir=config.img_dir,
                          list_filename=config.test_list_filename,
                          label_map_filename=config.label_map_filename,
                          train=False,
                          transform=transform,
                          input_size=config.img_res)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=config.test_batch_size,
                                             shuffle=False,
                                             num_workers=8,
                                             collate_fn=testset.collate_fn)

    # Model
    net = RetinaNet()

    if os.path.exists(config.checkpoint_filename):
        print('Load saved checkpoint: {}'.format(config.checkpoint_filename))
        checkpoint = torch.load(config.checkpoint_filename)
        net.load_state_dict(checkpoint['net'])
        best_loss = checkpoint['loss']
        start_epoch = checkpoint['epoch']
    else:
        print('Load pretrained model: {}'.format(config.pretrained_filename))
        if not os.path.exists(config.pretrained_filename):
            import_pretrained_resnet()
        net.load_state_dict(torch.load(config.pretrained_filename))

    net = torch.nn.DataParallel(net,
                                device_ids=range(torch.cuda.device_count()))
    net.cuda()

    criterion = FocalLoss()
    optimizer = optim.SGD(net.parameters(),
                          lr=1e-3,
                          momentum=0.9,
                          weight_decay=1e-4)

    # Training
    def train(epoch):
        """Run one training epoch, printing per-batch and running-average loss."""
        print('\nEpoch: %d' % epoch)
        net.train()
        net.module.freeze_bn()  # keep BatchNorm statistics frozen during fine-tuning
        train_loss = 0

        total_batches = int(
            math.ceil(trainloader.dataset.num_samples /
                      trainloader.batch_size))

        for batch_idx, targets in enumerate(trainloader):
            inputs = targets[0]
            loc_targets = targets[1]
            cls_targets = targets[2]

            inputs = inputs.cuda()
            loc_targets = loc_targets.cuda()
            cls_targets = cls_targets.cuda()

            optimizer.zero_grad()
            loc_preds, cls_preds = net(inputs)
            loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.data
            print('[%d| %d/%d] loss: %.3f | avg: %.3f' %
                  (epoch, batch_idx, total_batches, loss.data, train_loss /
                   (batch_idx + 1)))

    # Test
    def test(epoch):
        """Evaluate on the test split; save a checkpoint when the loss improves."""
        nonlocal best_loss  # was an unbound `global`; best_loss lives in run_train
        print('\nTest')
        net.eval()
        test_loss = 0

        total_batches = int(
            math.ceil(testloader.dataset.num_samples / testloader.batch_size))

        # no_grad: evaluation needs no autograd graphs (saves memory)
        with torch.no_grad():
            for batch_idx, targets in enumerate(testloader):
                inputs = targets[0]
                loc_targets = targets[1]
                cls_targets = targets[2]

                inputs = inputs.cuda()
                loc_targets = loc_targets.cuda()
                cls_targets = cls_targets.cuda()

                loc_preds, cls_preds = net(inputs)
                loss = criterion(loc_preds, loc_targets, cls_preds, cls_targets)
                test_loss += loss.data
                print('[%d| %d/%d] loss: %.3f | avg: %.3f' %
                      (epoch, batch_idx, total_batches, loss.data, test_loss /
                       (batch_idx + 1)))

        # Save checkpoint
        test_loss /= len(testloader)
        if test_loss < best_loss:
            print('Save checkpoint: {}'.format(config.checkpoint_filename))
            state = {
                'net': net.module.state_dict(),
                'loss': test_loss,
                'epoch': epoch,
            }
            if not os.path.exists(os.path.dirname(config.checkpoint_filename)):
                os.makedirs(os.path.dirname(config.checkpoint_filename))
            torch.save(state, config.checkpoint_filename)
            best_loss = test_loss

    for epoch in range(start_epoch, start_epoch + 1000):
        train(epoch)
        test(epoch)