def train():
    """Train the SSD detector on the VOC detection dataset.

    Builds a VOC data loader with SSD augmentations, loads pretrained VGG
    backbone weights from ``./weights/vgg16_reducedfc.pth`` and optimises the
    SSD loss with SGD.

    Stopping and checkpoints:
        * Training stops as soon as ``Config.max_iter`` iterations have run,
          or after 1000 epochs, whichever comes first.
        * A checkpoint is saved every 10,000 iterations.
        * A final checkpoint is saved when training ends.
    """
    dataset = voc0712.VOCDetection(root=Config.dataset_root,
                                   transform=augmentations.SSDAugmentation(
                                       Config.image_size, Config.MEANS))
    data_loader = data.DataLoader(dataset,
                                  Config.batch_size,
                                  num_workers=Config.data_load_number_worker,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)

    net = ssd_net_vgg.SSD()
    # Load the pretrained VGG backbone weights.
    vgg_weights = torch.load('./weights/vgg16_reducedfc.pth')
    # Load into the bare model *before* any DataParallel wrapping: the wrapper
    # only exposes the wrapped model through `.module`, so `net.vgg` would
    # not exist on it. Wrap exactly once, and only when running on CUDA.
    net.vgg.load_state_dict(vgg_weights)
    if Config.use_cuda:
        net = torch.nn.DataParallel(net)
        net = net.cuda()
    net.train()
    loss_fun = loss_function.LossFun()
    optimizer = optim.SGD(net.parameters(),
                          lr=Config.lr,
                          momentum=Config.momentum,
                          weight_decay=Config.weight_decacy)
    iteration = 0
    step_index = 0
    for epoch in range(1000):
        for step, (img, target) in enumerate(data_loader):
            if Config.use_cuda:
                img = img.cuda()
                target = [ann.cuda() for ann in target]
            # `img` is normally already a tensor (it is `.cuda()`-ed above).
            # The legacy torch.Tensor constructor rejects CUDA tensors, so
            # only convert when the batch is not a tensor yet.
            if not torch.is_tensor(img):
                img = torch.Tensor(img)
            loc_pre, conf_pre = net(img)
            priors = utils.default_prior_box()
            optimizer.zero_grad()
            loss_l, loss_c = loss_fun((loc_pre, conf_pre), target, priors)
            loss = loss_l + loss_c
            loss.backward()
            optimizer.step()
            print('epoch : ', epoch, ' iter : ', iteration, ' step : ', step,
                  ' loss : ', loss.item())
            iteration += 1
            if iteration in Config.lr_steps:
                step_index += 1
                adjust_learning_rate(optimizer, Config.gamma, step_index)
            # Save a checkpoint every 10,000 iterations (iteration >= 1 here).
            # NOTE(review): the +100000 offset in the file name looks like a
            # leftover from resuming an earlier run — confirm it is intended.
            if iteration % 10000 == 0:
                torch.save(
                    net.state_dict(), 'weights/final_20200226_VOC_' +
                    repr(100000 + iteration) + '.pth')
            # Stop right after the iteration budget is reached instead of
            # finishing the current epoch.
            if iteration >= Config.max_iter:
                break
        if iteration >= Config.max_iter:
            break
    torch.save(net.state_dict(), 'weights/final_20200223_voc_200000.pth')
# Example #2
def train():
    """Train the SSD detector on the SIXray 'core_500' and 'coreless_5000' sets.

    Setup:
        * Builds a SIXray data loader with SSD augmentations.
        * Initialises every layer with ``weights_init``.
        * Overwrites the VGG backbone with the pretrained weights from
          ``./weights/vgg16_reducedfc.pth``.
        * Optimises the SSD loss with SGD.

    Stopping and checkpoints:
        * Training stops as soon as ``Config.max_iter`` iterations have run,
          or after ``Config.epoch_num`` epochs, whichever comes first.
        * A checkpoint is saved every 10,000 iterations.
        * The final model is saved to ``weights/core500.pth``.
    """
    dataset = ml_data.SIXrayDetection(
        Config.dataset_root, ['core_500', 'coreless_5000'],
        augmentations.SSDAugmentation(Config.image_size, Config.MEANS))
    data_loader = data.DataLoader(dataset,
                                  Config.batch_size,
                                  num_workers=Config.data_load_number_worker,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)

    net = ssd_net_vgg.SSD()
    vgg_weights = torch.load('./weights/vgg16_reducedfc.pth')

    # Initialise all layers first, then overwrite the backbone with the
    # pretrained VGG weights. Loading must happen before DataParallel
    # wrapping, which hides submodules behind `.module`.
    net.apply(weights_init)
    net.vgg.load_state_dict(vgg_weights)
    if Config.use_cuda:
        net = torch.nn.DataParallel(net)
        net = net.cuda()
    net.train()
    loss_fun = loss_function.LossFun()
    optimizer = optim.SGD(net.parameters(),
                          lr=Config.lr,
                          momentum=Config.momentum,
                          weight_decay=Config.weight_decacy)
    iteration = 0
    step_index = 0
    for epoch in range(Config.epoch_num):
        for step, (img, target) in enumerate(data_loader):
            if Config.use_cuda:
                img = img.cuda()
                target = [ann.cuda() for ann in target]
            # `img` is normally already a tensor (it is `.cuda()`-ed above).
            # The legacy torch.Tensor constructor raises TypeError for CUDA
            # tensors, so convert explicitly only when needed rather than
            # catching and printing that error every step.
            if not torch.is_tensor(img):
                img = torch.Tensor(img)
            loc_pre, conf_pre = net(img)
            priors = utils.default_prior_box()
            optimizer.zero_grad()
            loss_l, loss_c = loss_fun((loc_pre, conf_pre), target, priors)
            loss = loss_l + loss_c
            loss.backward()
            optimizer.step()
            print('epoch : ', epoch, ' iter : ', iteration, ' step : ', step,
                  ' loss : ', loss.item())
            iteration += 1
            if iteration in Config.lr_steps:
                step_index += 1
                adjust_learning_rate(optimizer, Config.gamma, step_index)
            # Save a checkpoint every 10,000 iterations (iteration >= 1 here).
            if iteration % 10000 == 0:
                torch.save(net.state_dict(),
                           'weights/ssd300_VOC_' + repr(iteration) + '.pth')
            if iteration >= Config.max_iter:
                break
        # Leave the epoch loop too; otherwise every remaining epoch would run
        # one more optimizer step past the iteration budget.
        if iteration >= Config.max_iter:
            break
    torch.save(net.state_dict(), 'weights/core500.pth')