def train():
    """Train the SSD detector on the VOC dataset.

    Builds a shuffled VOC data loader (``SSDAugmentation`` transform,
    ``detection_collate`` batching) and an SSD network whose VGG backbone is
    initialised from ``./weights/vgg16_reducedfc.pth``. It then runs SGD
    until ``Config.max_iter`` iterations have been performed (or 1000 epochs
    elapse, whichever comes first). A checkpoint is saved every 10,000
    iterations and once more when training ends.

    NOTE(review): ``Config.weight_decacy`` looks like a misspelling of
    ``weight_decay``; it is kept as-is because it must match the Config
    attribute name.
    """
    dataset = voc0712.VOCDetection(root=Config.dataset_root,
                                   transform=augmentations.SSDAugmentation(
                                       Config.image_size, Config.MEANS))
    data_loader = data.DataLoader(dataset,
                                  Config.batch_size,
                                  num_workers=Config.data_load_number_worker,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)

    net = ssd_net_vgg.SSD()
    # Load the pretrained VGG backbone *before* wrapping in DataParallel:
    # the wrapper only exposes the model as `net.module`, so `net.vgg`
    # would raise AttributeError afterwards.
    vgg_weights = torch.load('./weights/vgg16_reducedfc.pth')
    net.vgg.load_state_dict(vgg_weights)
    # Wrap exactly once, so saved state_dict keys carry a single 'module.'
    # prefix whether or not CUDA is used.
    net = torch.nn.DataParallel(net)
    if Config.use_cuda:
        net = net.cuda()
    net.train()

    loss_fun = loss_function.LossFun()
    optimizer = optim.SGD(net.parameters(),
                          lr=Config.lr,
                          momentum=Config.momentum,
                          weight_decay=Config.weight_decacy)
    iteration = 0
    step_index = 0
    for epoch in range(1000):
        for step, (img, target) in enumerate(data_loader):
            # Make sure the batch is a float tensor *before* moving it to the
            # GPU; the legacy `torch.Tensor(x)` constructor raises TypeError
            # for CUDA tensors. `as_tensor` does not copy an existing float32
            # tensor.
            img = torch.as_tensor(img, dtype=torch.float32)
            if Config.use_cuda:
                img = img.cuda()
                target = [ann.cuda() for ann in target]
            loc_pre, conf_pre = net(img)
            priors = utils.default_prior_box()
            optimizer.zero_grad()
            loss_l, loss_c = loss_fun((loc_pre, conf_pre), target, priors)
            loss = loss_l + loss_c
            loss.backward()
            optimizer.step()
            # Log every iteration.
            print('epoch : ', epoch, ' iter : ', iteration, ' step : ', step,
                  ' loss : ', loss.item())
            iteration += 1
            if iteration in Config.lr_steps:
                step_index += 1
                adjust_learning_rate(optimizer, Config.gamma, step_index)
            # Save a checkpoint every 10,000 training iterations. The name is
            # offset by 100000, presumably to continue the numbering of an
            # earlier 100k-iteration run -- confirm.
            if iteration % 10000 == 0:
                torch.save(
                    net.state_dict(), 'weights/final_20200226_VOC_' +
                    repr(100000 + iteration) + '.pth')
            if iteration >= Config.max_iter:
                break
        # `break` above only leaves the batch loop; stop the epoch loop too so
        # training ends at exactly Config.max_iter iterations.
        if iteration >= Config.max_iter:
            break
    torch.save(net.state_dict(), 'weights/final_20200223_voc_200000.pth')
Example #2
0
    with open(det_file, 'wb') as f:
        pickle.dump(all_boxes, f, pickle.HIGHEST_PROTOCOL)

    print('Evaluating detections')
    evaluate_detections(all_boxes, output_dir, dataset)


def evaluate_detections(box_list, output_dir, dataset):
    """Dump detections to result files, then score them.

    Args:
        box_list: detections to evaluate; forwarded unchanged to
            ``write_voc_results_file`` (its layout is defined there).
        output_dir: directory forwarded to ``do_python_eval``.
        dataset: dataset forwarded to ``write_voc_results_file``.

    Returns:
        None.
    """
    # Step 1: persist the detections.
    write_voc_results_file(box_list, dataset)

    # Step 2: run the evaluation. It presumably reads the files written in
    # step 1, so the order of these calls matters -- confirm.
    do_python_eval(output_dir)
    return None


if __name__ == '__main__':
    # Evaluation entry point: build the SSD network, load trained weights and
    # construct the SIXray test dataset. The rest of this script is not
    # visible in this excerpt.
    # load net
    num_classes = len(labelmap) + 1  # +1 for background
    # NOTE(review): num_classes is not passed to SSD() below and is unused in
    # the visible lines -- confirm it is used later or remove it.
    net = ssd_net_vgg.SSD()  # initialize SSD
    # Wrapped in DataParallel before loading, presumably because the checkpoint
    # was saved from a DataParallel-wrapped model (keys prefixed 'module.').
    net = torch.nn.DataParallel(net)
    # NOTE(review): .cuda() is unconditional here (no Config.use_cuda check),
    # so this script requires a GPU.
    net = net.cuda()
    # train(mode=False) switches the model to inference mode (same as eval()).
    net.train(mode=False)
    # map_location keeps deserialized tensors on the CPU; load_state_dict then
    # copies them into the model's (CUDA) parameters.
    net.load_state_dict(
        torch.load('./weights/ssd300_VOC_90000.pth',
                   map_location=lambda storage, loc: storage))
    # net.load_state_dict(torch.load(args.trained_model))
    # net.eval()
    print('Finished loading model!')
    # load data
    # test_sets = "./data/sixray/test_1650.txt"
    # NOTE(review): test_sets is unused in the visible lines -- confirm.
    test_sets = imgsetpath
    # NOTE(review): the *test* dataset uses SSDAugmentation. If that is the
    # training-time (randomized) augmentation, evaluation results will be
    # non-deterministic; a deterministic resize/mean-subtract transform is
    # usually wanted for evaluation -- confirm.
    dataset = SIXrayDetection(
        Config.dataset_test_root, ['core_500', 'coreless_5000'],
        augmentations.SSDAugmentation(Config.image_size, Config.MEANS))
Example #3
0
def train():
    """Train the SSD detector on the SIXray ``core_500``/``coreless_5000`` sets.

    Builds a shuffled data loader (``SSDAugmentation`` transform,
    ``detection_collate`` batching) and an SSD network. Every layer is
    initialised with ``weights_init``, after which the VGG backbone is
    overwritten with ``./weights/vgg16_reducedfc.pth``. It then runs SGD
    for at most ``Config.epoch_num`` epochs, stopping as soon as
    ``Config.max_iter`` iterations have been performed. A checkpoint is
    saved every 10,000 iterations and the final model goes to
    ``weights/core500.pth``.

    NOTE(review): ``Config.weight_decacy`` looks like a misspelling of
    ``weight_decay``; it is kept as-is because it must match the Config
    attribute name.
    """
    dataset = ml_data.SIXrayDetection(
        Config.dataset_root, ['core_500', 'coreless_5000'],
        augmentations.SSDAugmentation(Config.image_size, Config.MEANS))
    data_loader = data.DataLoader(dataset,
                                  Config.batch_size,
                                  num_workers=Config.data_load_number_worker,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)

    net = ssd_net_vgg.SSD()
    vgg_weights = torch.load('./weights/vgg16_reducedfc.pth')

    # Initialise all layers first, then load the pretrained backbone so the
    # VGG weights are not overwritten by weights_init.
    net.apply(weights_init)
    net.vgg.load_state_dict(vgg_weights)
    if Config.use_cuda:
        net = torch.nn.DataParallel(net)
        net = net.cuda()
    net.train()

    loss_fun = loss_function.LossFun()
    optimizer = optim.SGD(net.parameters(),
                          lr=Config.lr,
                          momentum=Config.momentum,
                          weight_decay=Config.weight_decacy)
    iteration = 0
    step_index = 0
    for epoch in range(Config.epoch_num):
        for step, (img, target) in enumerate(data_loader):
            # Make sure the batch is a float tensor *before* moving it to the
            # GPU; the legacy `torch.Tensor(x)` constructor raises TypeError
            # for CUDA tensors. `as_tensor` does not copy an existing float32
            # tensor.
            img = torch.as_tensor(img, dtype=torch.float32)
            if Config.use_cuda:
                img = img.cuda()
                target = [ann.cuda() for ann in target]
            loc_pre, conf_pre = net(img)
            priors = utils.default_prior_box()
            optimizer.zero_grad()
            loss_l, loss_c = loss_fun((loc_pre, conf_pre), target, priors)
            loss = loss_l + loss_c
            loss.backward()
            optimizer.step()
            # Log every iteration.
            print('epoch : ', epoch, ' iter : ', iteration, ' step : ', step,
                  ' loss : ', loss.item())
            iteration += 1
            if iteration in Config.lr_steps:
                step_index += 1
                adjust_learning_rate(optimizer, Config.gamma, step_index)
            # Save a checkpoint every 10,000 training iterations.
            if iteration % 10000 == 0:
                torch.save(net.state_dict(),
                           'weights/ssd300_VOC_' + repr(iteration) + '.pth')
            if iteration >= Config.max_iter:
                break
        # `break` above only leaves the batch loop; without this check every
        # remaining epoch would still run one extra optimizer step.
        if iteration >= Config.max_iter:
            break
    torch.save(net.state_dict(), 'weights/core500.pth')