# Example #1
                        0,
                        confidence_threshold=args.confidence_threshold,
                        nms_threshold=args.nms_threshold,
                        top_k=args.top_k,
                        keep_top_k=args.keep_top_k)
    else:
        detect = Detect_RefineDet(
            num_classes,
            int(args.input_size),
            0,
            objectness_threshold,
            confidence_threshold=args.confidence_threshold,
            nms_threshold=args.nms_threshold,
            top_k=args.top_k,
            keep_top_k=args.keep_top_k)
    net = build_refinedet('test', int(args.input_size), num_classes,
                          backbone_dict)

    # test multi models, to filter out the best model.
    # start_epoch = 10; step = 10
    start_epoch = 160
    step = 5
    ToBeTested = []
    ToBeTested = [
        prefix + f'/RefineDet{args.input_size}_VOC_epoches_{epoch}.pth'
        for epoch in range(start_epoch, 240, step)
    ]
    ToBeTested.append(prefix + f'/RefineDet{args.input_size}_VOC_final.pth')
    # ToBeTested.append(prefix + f'/RefineDet{args.input_size}_VOC_epoches_10.pth')

    ap_stats = {"ap50": [], "epoch": []}
    for index, model_path in enumerate(ToBeTested):
# Example #2
def train():
    """Train RefineDet on VOC or COCO.

    Relies on module-level configuration: ``args`` (parsed CLI options),
    ``backbone_dict``, ``pretrained``, ``negpos_ratio``, the dataset roots
    (``VOC_ROOT``/``COCOroot``) and the ``voc_refinedet``/``coco_refinedet``
    cfg dicts — TODO(review): confirm these are defined at module level in
    the full file.

    Side effects: saves intermediate checkpoints into ``args.save_folder``
    every 10 epochs (every 5 epochs in the last third of training) and a
    final model at the end; optionally plots losses to a visdom server.

    Raises:
        ValueError: if ``args.dataset`` is neither 'VOC' nor 'COCO'.
    """
    if args.visdom:
        import visdom
        viz = visdom.Visdom()

    print('Loading the dataset...')
    if args.dataset == 'COCO':
        if args.dataset_root == VOC_ROOT:
            if not os.path.exists(COCOroot):
                parser.error('Must specify dataset_root if specifying dataset')
            print("WARNING: Using default COCO dataset_root because " +
                  "--dataset_root was not specified.")
            args.dataset_root = COCOroot
        cfg = coco_refinedet[args.input_size]
        train_sets = [('train2017')]
        # train_sets = [('train2017', 'val2017')]
        dataset = COCODetection(COCOroot, train_sets,
                                SSDAugmentation(cfg['min_dim'], MEANS))
    elif args.dataset == 'VOC':
        # if args.dataset_root == COCO_ROOT:
        #     parser.error('Must specify dataset if specifying dataset_root')
        cfg = voc_refinedet[args.input_size]
        dataset = VOCDetection(root=VOC_ROOT,
                               transform=SSDAugmentation(
                                   cfg['min_dim'], MEANS))
    else:
        # Fail fast: without this guard `cfg`/`dataset` stay unbound and the
        # function would die below with a confusing NameError instead.
        raise ValueError('Unsupported dataset: {!r} '
                         '(expected \'VOC\' or \'COCO\')'.format(args.dataset))
    print('Training RefineDet on:', dataset.name)
    print('Using the specified args:')
    print(args)

    refinedet_net = build_refinedet('train', int(args.input_size),
                                    cfg['num_classes'], backbone_dict)
    net = refinedet_net
    print(net)

    device = torch.device('cuda:0' if args.cuda else 'cpu')
    if args.ngpu > 1 and args.cuda:
        # `net` becomes the (possibly wrapped) module used for forward/save;
        # `refinedet_net` keeps a handle on the bare module for state loading.
        net = torch.nn.DataParallel(refinedet_net,
                                    device_ids=list(range(args.ngpu)))
    cudnn.benchmark = True
    net = net.to(device)

    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        state_dict = torch.load(args.resume)
        # Checkpoints saved from a DataParallel-wrapped model have keys
        # prefixed with `module.`; strip it so the bare module can load them.
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            head = k[:7]
            if head == 'module.':
                name = k[7:]  # remove `module.`
            else:
                name = k
            new_state_dict[name] = v
        refinedet_net.load_state_dict(new_state_dict)
    else:
        print('Initializing weights...')
        refinedet_net.init_weights(pretrained=pretrained)

    optimizer = optim.SGD(net.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    # ARM is class-agnostic (2 classes: object vs background); ODM refines
    # ARM anchors and classifies into the full label set (use_ARM=True).
    arm_criterion = RefineDetMultiBoxLoss(2, 0.5, True, 0, True, negpos_ratio,
                                          0.5, False, args.cuda)
    odm_criterion = RefineDetMultiBoxLoss(cfg['num_classes'],
                                          0.5,
                                          True,
                                          0,
                                          True,
                                          negpos_ratio,
                                          0.5,
                                          False,
                                          args.cuda,
                                          use_ARM=True)
    priorbox = PriorBox(cfg)
    with torch.no_grad():
        priors = priorbox.forward()
        priors = priors.to(device)

    net.train()
    # loss counters (accumulated per epoch for the visdom epoch plot)
    arm_loc_loss = 0
    arm_conf_loss = 0
    odm_loc_loss = 0
    odm_conf_loss = 0
    epoch = 0 + args.resume_epoch

    epoch_size = math.ceil(len(dataset) / args.batch_size)
    max_iter = args.max_epoch * epoch_size

    # LR decay milestones (in iterations): at 2/3 and 8/9 of training for
    # COCO, 2/3 and 5/6 for VOC.
    stepvalues = (args.max_epoch * 2 // 3 * epoch_size,
                  args.max_epoch * 8 // 9 * epoch_size,
                  args.max_epoch * epoch_size)
    if args.dataset == 'VOC':
        stepvalues = (args.max_epoch * 2 // 3 * epoch_size,
                      args.max_epoch * 5 // 6 * epoch_size,
                      args.max_epoch * epoch_size)
    step_index = 0

    if args.resume_epoch > 0:
        # Fast-forward the LR schedule past milestones already crossed.
        start_iter = args.resume_epoch * epoch_size
        for step in stepvalues:
            if step < start_iter:
                step_index += 1
    else:
        start_iter = 0

    if args.visdom:
        vis_title = 'RefineDet.PyTorch on ' + dataset.name
        vis_legend = ['Loc Loss', 'Conf Loss', 'Total Loss']
        iter_plot = create_vis_plot(viz, 'Iteration', 'Loss', vis_title,
                                    vis_legend)
        epoch_plot = create_vis_plot(viz, 'Epoch', 'Loss', vis_title,
                                     vis_legend)

    data_loader = data.DataLoader(dataset,
                                  args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)
    for iteration in range(start_iter, max_iter):
        # `start_iter` is always a multiple of `epoch_size`, so the iterator
        # below is guaranteed to be created before it is first consumed.
        if iteration % epoch_size == 0:
            if args.visdom and iteration != 0:
                update_vis_plot(viz, epoch, arm_loc_loss, arm_conf_loss,
                                epoch_plot, None, 'append', epoch_size)
                # reset epoch loss counters
                arm_loc_loss = 0
                arm_conf_loss = 0
                odm_loc_loss = 0
                odm_conf_loss = 0
            # create batch iterator
            batch_iterator = iter(data_loader)
            # Checkpoint every 10 epochs, densifying to every 5 in the last
            # third of training.  NOTE: saved from `net`, so keys carry a
            # `module.` prefix under DataParallel; the resume path above
            # strips it.
            if (epoch % 10 == 0 and epoch > 0) or (epoch % 5 == 0 and epoch >
                                                   (args.max_epoch * 2 // 3)):
                torch.save(
                    net.state_dict(),
                    args.save_folder + 'RefineDet' + args.input_size + '_' +
                    args.dataset + '_epoches_' + repr(epoch) + '.pth')
            epoch += 1

        t0 = time.time()
        if iteration in stepvalues:
            step_index += 1
        lr = adjust_learning_rate(optimizer, args.gamma, epoch, step_index,
                                  iteration, epoch_size)

        # load train data
        images, targets = next(batch_iterator)
        images = images.to(device)
        targets = [ann.to(device) for ann in targets]

        # forward
        out = net(images)

        # backprop
        optimizer.zero_grad()

        arm_loss_l, arm_loss_c = arm_criterion(out, priors, targets)
        odm_loss_l, odm_loss_c = odm_criterion(out, priors, targets)
        arm_loss = arm_loss_l + arm_loss_c
        odm_loss = odm_loss_l + odm_loss_c
        loss = arm_loss + odm_loss

        loss.backward()
        optimizer.step()

        arm_loc_loss += arm_loss_l.item()
        arm_conf_loss += arm_loss_c.item()
        odm_loc_loss += odm_loss_l.item()
        odm_conf_loss += odm_loss_c.item()
        t1 = time.time()
        batch_time = t1 - t0
        eta = int(batch_time * (max_iter - iteration))
        print('Epoch:{}/{} || Epochiter: {}/{} || Iter: {}/{} || ARM_L Loss: {:.4f} ARM_C Loss: {:.4f} ODM_L Loss: {:.4f} ODM_C Loss: {:.4f} loss: {:.4f} || LR: {:.8f} || Batchtime: {:.4f} s || ETA: {}'.\
            format(epoch, args.max_epoch, (iteration % epoch_size) + 1, epoch_size, iteration + 1, max_iter, arm_loss_l.item(), arm_loss_c.item(), odm_loss_l.item(), odm_loss_c.item(), loss.item(), lr, batch_time, str(datetime.timedelta(seconds=eta))))

        if args.visdom:
            update_vis_plot(viz, iteration, arm_loss_l.item(),
                            arm_loss_c.item(), iter_plot, epoch_plot, 'append')

    # Final save uses the bare module, so keys have no `module.` prefix.
    torch.save(
        refinedet_net.state_dict(), args.save_folder +
        '/RefineDet{}_{}_final.pth'.format(args.input_size, args.dataset))