Exemplo n.º 1
0
def train(**kwargs):
    """Run the Faster R-CNN training loop.

    Keyword arguments override the matching attributes on the global
    ``OPT`` configuration object before training starts. Each epoch runs
    one pass over the training loader, one pass over the test loader in
    eval mode, and a mAP evaluation.
    """
    # Merge keyword overrides into the global configuration object.
    OPT._parse(kwargs)

    # Training dataset and loader.
    dataset = Dataset(opt=OPT)
    print("加载数据集")
    dataloader = DataLoader(dataset=dataset,
                            batch_size=1,
                            shuffle=True,
                            num_workers=OPT.num_workers)
    # Test dataset and loader (no shuffling; pinned memory for faster H2D copies).
    testset = TestDataset(opt=OPT)
    test_dataloader = DataLoader(dataset=testset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=OPT.num_workers,
                                 pin_memory=True)
    # Model and trainer wrapper, moved to the GPU.
    faster_rcnn = FasterRCNNVGG16()
    print("模型加载完成")
    trainer = FasterRCNNTrainer(faster_rcnn).cuda()

    for epoch in range(OPT.epoch):
        print("Epoch: %s/%s" % (epoch, OPT.epoch - 1))
        print("-" * 10)
        trainer.reset_meters()  # Zero the loss meters at the start of each epoch.
        for ii, (img, bbox_, label_,
                 scale) in pb.progressbar(enumerate(dataloader),
                                          max_value=len(dataloader)):
            # Scale factor between the original image and the preprocessed
            # image. NOTE(review): assumes a module-level ``scalar`` helper is
            # in scope — other code in this file calls ``at.scalar``; confirm.
            scale = scalar(scale)
            img, bbox, label = img.cuda(), bbox_.cuda(), label_.cuda()
            trainer.train_step(imgs=img,
                               bboxes=bbox,
                               labels=label,
                               scale=scale)
        print("train:", trainer.get_meter_data())
        # Validation pass in eval mode.
        trainer.eval()
        for jj, (img, size, _, bbox, label,
                 _) in pb.progressbar(enumerate(test_dataloader),
                                      max_value=len(test_dataloader)):
            img, bbox, label = img.cuda(), bbox.cuda(), label.cuda()
            trainer.val_step(img, size, bbox, label)
        print("val:", trainer.get_meter_data())
        eval_result = evaluate(dataloader=test_dataloader,
                               faster_rcnn=faster_rcnn,
                               test_num=OPT.test_num)
        print("mAP: %.4f" % eval_result["mAP"])
        print()
        trainer.train()  # Restore training mode for the next epoch.
Exemplo n.º 2
0
def train(**kwargs):
    """Train Faster R-CNN on the CSV-labelled image dataset.

    Keyword arguments override the matching attributes on the global
    ``opt`` configuration object. The data is split 80/20 into train and
    validation subsets with a fixed seed. The best-mAP checkpoint is saved
    to ``opt.model_para``; the final state is always saved to
    ``last_epoch.pkl``. Training can resume from ``opt.load_path``.
    """
    opt._parse(kwargs)

    image_folder_path = 'DataSets/images/'
    cvs_file_path = 'DataSets/labels.csv'

    dataset = DataSets(cvs_file_path, image_folder_path)
    data_size = len(dataset)
    indices = list(range(data_size))
    # 20% of the samples form the validation split.
    split = int(np.floor(data_size * 0.2))
    np.random.seed(42)  # fixed seed so the split is reproducible across runs
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
    valid_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=1,
                                               sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=1,
                                             sampler=valid_sampler)
    print('load data')

    # Loss meters: running average and a 20-step moving average.
    avg_loss = AverageValueMeter()
    ma20_loss = MovingAverageValueMeter(windowsize=20)
    faster_rcnn = FasterRCNNVGG16()
    print('model construct completed')
    start_epoch = 0
    best_map = -100
    trainer = FasterRCNNTrainer(faster_rcnn).cuda()
    optimizer = optim.SGD(trainer.faster_rcnn.parameters(),
                          lr=opt.lr,
                          momentum=0.9)
    # Halve the learning rate every 10 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    if opt.load_path:
        print('load pretrained model from %s' % opt.load_path)
        checkpoint = torch.load(opt.load_path)
        start_epoch = checkpoint['epoch']
        best_map = checkpoint['best_map']
        trainer.faster_rcnn.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        # BUG FIX: the original referenced the undefined name ``args.resume``
        # here, which raised NameError whenever a checkpoint was loaded.
        print("> Loaded checkpoint '{}' (epoch {})".format(
            opt.load_path, start_epoch))

    # Set up TensorBoard for visualization.
    writer = SummaryWriter('runs/' + opt.log_root)

    for epoch in range(start_epoch, opt.epoch):
        trainer.train(mode=True)  # must be in training mode while training
        for ii, (img, _, _, bbox_, label_, scale,
                 _) in enumerate(train_loader):
            scale = at.scalar(scale)
            img, bbox, label = img.cuda().float(), bbox_.cuda(), label_.cuda()
            optimizer.zero_grad()
            loss = trainer.forward(img, bbox, label, scale)
            loss.total_loss.backward()
            optimizer.step()
            loss_value = loss.total_loss.cpu().data.numpy()
            avg_loss.add(float(loss_value))
            ma20_loss.add(float(loss_value))
            print(
                '[epoch:{}/{}]  [batch:{}/{}]  [sample_loss:{:.4f}] [avg_loss:{:.4f}]  [ma20_loss:{:.4f}]'
                .format(epoch, opt.epoch, ii + 1, len(train_loader),
                        loss.total_loss.data,
                        avg_loss.value()[0],
                        ma20_loss.value()[0]))

            if (ii + 1) % opt.plot_every == 0:
                niter = epoch * len(train_loader) + ii
                writer.add_scalar('Train/Loss', ma20_loss.value()[0], niter)

        # Validation mAP once per epoch. NOTE(review): ``eval`` shadows the
        # builtin — presumably the project's evaluation helper; confirm import.
        eval_result = eval(val_loader, faster_rcnn, test_num=opt.test_num)
        print(eval_result['map'])

        # Checkpoint whenever validation mAP improves.
        if eval_result['map'] > best_map:
            best_map = eval_result['map']
            state = {
                "epoch": epoch + 1,
                "best_map": best_map,
                "model_state": trainer.faster_rcnn.state_dict(),
                "optimizer_state": optimizer.state_dict()
            }
            torch.save(state, opt.model_para)
        scheduler.step()

    # Always persist the final epoch's state, regardless of validation score.
    state = {
        "epoch": epoch + 1,
        "best_map": best_map,
        "model_state": trainer.faster_rcnn.state_dict(),
        "optimizer_state": optimizer.state_dict()
    }
    torch.save(state, 'last_epoch.pkl')
    writer.close()