Example #1
        # cache the raw detections so evaluation can be re-run without another forward pass
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)
        print("saved detection results to", det_file)

    print('Evaluating detections')
    evaluate_detections(all_boxes, output_dir, dataset)


def evaluate_detections(box_list, output_dir, dataset):
    write_voc_results_file(box_list, dataset)
    do_python_eval(output_dir)


if __name__ == '__main__':
    # load net
    num_classes = len(labelmap) + 1  # +1 for background
    net = build_bidet_ssd('test', 300, num_classes, nms_conf_thre=args.confidence_threshold,
                          nms_iou_thre=args.iou_threshold, nms_top_k=args.top_k)

    if args.cuda:
        net = net.cuda()
        cudnn.benchmark = True
    net.load_state_dict(torch.load(args.weight_path, map_location='cuda' if args.cuda else 'cpu'))
    net.eval()
    print('Finished loading model!')

    # load data
    dataset = VOCDetection(args.voc_root, [('2007', set_type)],
                           BaseTransform(300, dataset_mean),
                           VOCAnnotationTransform())

    # evaluation
    test_net(net, dataset)
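
For reference, here is a minimal sketch of the detection cache that test_net fills and evaluate_detections consumes, assuming the usual ssd.pytorch-style layout (the exact dtype and column order in this repo may differ); it reuses dataset and labelmap from the example above:

import numpy as np

# hypothetical layout: all_boxes[class_index][image_index] is an (N, 5) array with
# one row per kept detection, typically (x1, y1, x2, y2, score)
num_images = len(dataset)
all_boxes = [[np.empty((0, 5), dtype=np.float32) for _ in range(num_images)]
             for _ in range(len(labelmap) + 1)]
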
Example #2
def train():
    # REGULARIZATION_LOSS_WEIGHT and PRIOR_LOSS_WEIGHT are module-level weights that are
    # overwritten at the first lr decay step below, so declare them global here to avoid
    # an UnboundLocalError on the earlier iterations that read them
    global REGULARIZATION_LOSS_WEIGHT, PRIOR_LOSS_WEIGHT

    if args.dataset == 'COCO':
        cfg = coco
        dataset = COCODetection(root=args.data_root,
                                transform=SSDAugmentation(cfg['min_dim'], MEANS))
    elif args.dataset == 'VOC':
        cfg = voc
        dataset = VOCDetection(root=args.data_root,
                               transform=SSDAugmentation(cfg['min_dim'], MEANS))

    ssd_net = build_bidet_ssd('train', cfg['min_dim'], cfg['num_classes'],
                              nms_conf_thre=NMS_CONF_THRE)
    net = ssd_net

    if args.cuda:
        cudnn.benchmark = True

    if args.resume:
        print('Resuming training, loading {}...'.format(args.weight_path))
        ssd_net.load_weights(args.weight_path)
    else:
        if args.basenet.lower() != 'none':
            vgg_weights = torch.load(args.basenet)
            print('Loading base network...')
            ssd_net.vgg.layers.load_state_dict(vgg_weights, strict=True)

    if args.cuda:
        net = nn.DataParallel(ssd_net).cuda()

    if args.opt.lower() == 'sgd':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr,
                              momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.opt.lower() == 'adam':
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=args.lr,
                               weight_decay=args.weight_decay)
    else:
        print('Unsupported optimizer: {}'.format(args.opt))
        exit(-1)
    optimizer.zero_grad()
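    # argument order follows the ssd.pytorch MultiBoxLoss signature (assumed here):
    # (num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining,
    #  neg_pos_ratio, neg_overlap, encode_target, use_gpu)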
    criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5, False, args.cuda)

    net.train()

    # loss counters (the *_save variables accumulate values between checkpoints)
    loss_count = 0.  # number of terms contributing to the prior loss
    loc_loss_save = 0.
    conf_loss_save = 0.
    reg_loss_save = 0.
    prior_loss_save = 0.
    loss_l = 0.
    loss_c = 0.
    loss_r = 0.
    loss_p = 0.
    epoch = 0
    print('Loading the dataset...')

    epoch_size = len(dataset) // args.batch_size
    print('Training SSD on:', dataset.name)
    print('Using the specified args:')
    print(args)

    step_index = 0

    data_loader = data.DataLoader(dataset, args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True, collate_fn=detection_collate,
                                  pin_memory=True, drop_last=True)
    # create batch iterator
    batch_iterator = iter(data_loader)
    for iteration in range(args.start_iter, cfg['max_iter']):
        t0 = time.time()

        lr = get_lr(optimizer)

        if iteration % epoch_size == 0 and iteration != 0:
            # one full pass over the dataset finished
            epoch += 1

        if iteration in cfg['lr_steps']:
            if step_index == 0:
                # turn on the regularization and prior losses at the first lr decay step
                args.reg_weight = 0.1
                args.prior_weight = 0.2
                REGULARIZATION_LOSS_WEIGHT = args.reg_weight
                PRIOR_LOSS_WEIGHT = args.prior_weight
            step_index += 1
            adjust_learning_rate(optimizer, args.gamma, step_index)
            print("decay lr")

        # load train data
        try:
            images, targets = next(batch_iterator)
        except StopIteration:
            batch_iterator = iter(data_loader)
            images, targets = next(batch_iterator)

        if args.cuda:
            with torch.no_grad():
                images = Variable(images.float().cuda())
                targets = [Variable(ann.cuda()) for ann in targets]
        else:
            with torch.no_grad():
                images = Variable(images.float())
                targets = [Variable(ann) for ann in targets]

        batch_size = images.size(0)

        if PRIOR_LOSS_WEIGHT != 0.:
            gt_class = [targets[batch_idx][:, -1] for batch_idx in range(batch_size)]

        # forward
        out = net(images)

        loc_data, conf_data, priors, feature_map = out

        # sample loc data from predicted miu and sigma
        # loc_data carries 8 values per prior: 4 log-variances followed by 4 means
        log_sigma_2 = loc_data[:, :, :4]
        miu = loc_data[:, :, 4:]
        sigma = torch.exp(log_sigma_2 / 2.)
        # reparameterization trick: sample = miu + sigma * eps with eps ~ N(0, I);
        # randn_like keeps the noise on the same device as the predictions (works without CUDA too)
        normal_dist = torch.randn_like(miu)
        sample_loc_data = normal_dist * sigma * args.sigma + miu
        loc_data = sample_loc_data

        out = (
            loc_data,
            conf_data,
            priors
        )

        # BP
        loss_l, loss_c = criterion(out, targets)
        loss_temp = loss_l + loss_c

        # an infinite loss occasionally appears on COCO, possibly caused by broken annotations
        if loss_temp.item() == float("Inf"):
            print('inf loss error!')
            # backward the loss anyway so the graph (including feature_map) is freed,
            # then drop the gradients and skip this iteration
            loss_temp.backward()
            net.zero_grad()
            optimizer.zero_grad()
            torch.cuda.empty_cache()
            continue

        if PRIOR_LOSS_WEIGHT != 0.:
            loss_count = 0.

            detect_result = net.module.detect_prior.forward(
                loc_data,  # localization preds
                net.module.softmax(conf_data),  # confidence preds
                priors,  # default boxes
                gt_class
            )  # [batch, classes, top_k, 5 (score, (y1, x1, y2, x2))]

            num_classes = detect_result.size(1)

            # skip j = 0, because it's the background class
            for j in range(1, num_classes):
                all_dets = detect_result[:, j, :, :]  # [batch, top_k, 5]
                all_mask = all_dets[:, :, :1].gt(0.).expand_as(all_dets)  # [batch, top_k, 5]

                for batch_idx in range(batch_size):
                    # skip images that contain no ground-truth box of this class
                    # (gt labels are 0-indexed, detection classes include background, hence j - 1)
                    if not (gt_class[batch_idx] == j - 1).any():
                        continue

                    dets = torch.masked_select(all_dets[batch_idx], all_mask[batch_idx]).view(-1, 5)  # [num, 5]

                    if dets.size(0) == 0:
                        continue

                    # skip when the number of predictions does not exceed the number of gt boxes
                    if dets.size(0) <= ((gt_class[batch_idx] == j - 1).sum().detach().cpu().item()):
                        continue

                    scores = dets[:, 0]  # [num]
                    scores_sum = scores.sum().item()  # no grad
                    scores = scores / scores_sum  # normalization
                    log_scores = log_func(scores)
                    gt_num = (gt_class[batch_idx] == j - 1).sum().detach().cpu().item()
                    loss_p += (-1. * log_scores.sum() / float(gt_num))
                    loss_count += 1.

            loss_p /= (loss_count + 1e-6)
            loss_p *= PRIOR_LOSS_WEIGHT

        # feature-map regularization: mean squared activation (an L2 penalty) on each returned map
        if REGULARIZATION_LOSS_WEIGHT != 0.:
            f_num = len(feature_map)
            loss_r = 0.

            for f_m in feature_map:
                loss_r += (f_m ** 2).mean()

            loss_r *= REGULARIZATION_LOSS_WEIGHT
            loss_r /= float(f_num)

        loss = loss_l + loss_c + loss_r + loss_p

        # an infinite loss occasionally appears on COCO, possibly caused by broken annotations
        if loss.item() == float("Inf"):
            print('inf loss error!')
            # backward the loss anyway so the graph (including feature_map) is freed,
            # then drop the gradients and skip this iteration
            loss.backward()
            net.zero_grad()
            optimizer.zero_grad()
            torch.cuda.empty_cache()
            continue

        # compute gradient and do optimizer step
        loss.backward()
        # clip gradient because binary net training is very unstable
        if args.clip_grad:
            grad_norm = get_grad_norm(net)
            nn.utils.clip_grad_norm_(net.parameters(), GRADIENT_CLIP_NORM)
        optimizer.step()
        optimizer.zero_grad()

        loss_l = loss_l.detach().cpu().item()
        loss_c = loss_c.detach().cpu().item()
        if REGULARIZATION_LOSS_WEIGHT != 0.:
            loss_r = loss_r.detach().cpu().item()
        if PRIOR_LOSS_WEIGHT != 0.:
            loss_p = loss_p.detach().cpu().item()
        loc_loss_save += loss_l
        conf_loss_save += loss_c
        reg_loss_save += loss_r
        prior_loss_save += loss_p
        t1 = time.time()

        if iteration % 100 == 0:
            print('timer: %.4f sec.' % (t1 - t0))
            print('iter:', iteration, 'loss:', round(loss.detach().cpu().item(), 4))
            print('conf_loss:', round(loss_c, 4), 'loc_loss:', round(loss_l, 4),
                  'reg_loss:', round(loss_r, 4), 'prior_loss:', round(loss_p, 4),
                  'lr:', lr)
            if args.clip_grad:
                print('gradient norm:', grad_norm)
            torch.cuda.empty_cache()

        if iteration != 0 and iteration % 5000 == 0:
            print('Saving state, iter:', iteration)

            loss_save = loc_loss_save + conf_loss_save + reg_loss_save + prior_loss_save
            torch.save(net.module.state_dict(), logs_dir + '/model_' + str(iteration) +
                       '_loc_' + str(round(loc_loss_save / 5000., 4)) +
                       '_conf_' + str(round(conf_loss_save / 5000., 4)) +
                       '_reg_' + str(round(reg_loss_save / 5000., 4)) +
                       '_prior_' + str(round(prior_loss_save / 5000., 4)) +
                       '_loss_' + str(round(loss_save / 5000., 4)) +
                       '_lr_' + str(round(args.lr * (args.gamma ** step_index), 6)) + '.pth')

            loc_loss_save = 0.
            conf_loss_save = 0.
            reg_loss_save = 0.
            prior_loss_save = 0.

        loss_l = 0.
        loss_c = 0.
        loss_r = 0.
        loss_p = 0.
        loss_count = 0.

    torch.save(net.module.state_dict(), logs_dir + '/' + args.dataset + '_final.pth')
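
Example #2 calls several helpers that fall outside this excerpt (get_lr, adjust_learning_rate, get_grad_norm, log_func). A rough sketch of what they could look like, assuming the conventional ssd.pytorch step-decay schedule and the module-level args object used throughout the example; get_grad_norm and the clamp value in log_func are illustrative guesses:

import math
import torch


def get_lr(optimizer):
    # current learning rate of the (single) parameter group
    return optimizer.param_groups[0]['lr']


def adjust_learning_rate(optimizer, gamma, step):
    # step decay as in ssd.pytorch: lr = base_lr * gamma ** step
    lr = args.lr * (gamma ** step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def get_grad_norm(net):
    # overall L2 norm of all parameter gradients, used only for logging
    total = 0.
    for p in net.parameters():
        if p.grad is not None:
            total += p.grad.data.norm(2).item() ** 2
    return math.sqrt(total)


def log_func(x):
    # clamp before the log so zero scores cannot produce -inf
    return torch.log(x.clamp(min=1e-6))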