Example #1
import argparse
import datetime
import os

import torch

# Project-local helpers assumed importable from the surrounding repo:
# ListDataset, parse_data_config, get_anchors, ResNet, print_args.


class evaluate:  # class wrapper inferred from the call evaluate(path=...) in main()
    def __init__(self, path, img_size, batch_size, debug=False):
        # Validation data is loaded without augmentation or multiscale resizing.
        dataset = ListDataset(path,
                              img_size=img_size,
                              augment=False,
                              multiscale=False)
        if debug:
            # Debug mode: evaluate on a single batch only.
            dataset.img_files = dataset.img_files[:batch_size]
        self.dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            collate_fn=dataset.collate_fn)
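
# Note on collate_fn: each image carries a variable number of target boxes, so
# the default collate (which stacks equal-sized tensors) cannot be used. A
# minimal sketch of the usual YOLO-style pattern -- an assumption, the real
# ListDataset.collate_fn may differ:
#
#     def collate_fn(self, batch):
#         imgs, targets, paths = list(zip(*batch))
#         targets = [t for t in targets if t is not None]
#         for i, boxes in enumerate(targets):
#             boxes[:, 0] = i  # tag each row with its sample index in the batch
#         return torch.stack(imgs), torch.cat(targets, 0), paths
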
def main(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs",
                        type=int,
                        default=10000,
                        help="number of epochs")
    parser.add_argument("--batch_size",
                        type=int,
                        default=20,
                        help="size of each image batch")
    parser.add_argument("--data_config",
                        type=str,
                        default="config/adc.data",
                        help="path to data config file")
    # parser.add_argument("--pretrained_weights", type=str, default="config/yolov3_ckpt_5.pth")  # models/model1/yolov3_ckpt_73.pth
    parser.add_argument("--pretrained_weights",
                        type=str)  # models/model1/yolov3_ckpt_73.pth
    parser.add_argument(
        "--n_cpu",
        type=int,
        default=0,
        help="number of cpu threads to use during batch generation")
    parser.add_argument("--img_size",
                        type=int,
                        default=[768, 1024],
                        help="size of each image dimension")
    parser.add_argument("--evaluation_interval",
                        type=int,
                        default=1,
                        help="interval evaluations on validation set")
    parser.add_argument("--multiscale",
                        default='False',
                        choices=['True', 'False'])
    parser.add_argument("--augment",
                        default='False',
                        choices=['True', 'False'])
    parser.add_argument("--save_path",
                        type=str,
                        default='models/weights_1350_0102',
                        help="save model path")
    parser.add_argument("--debug",
                        type=str,
                        default='False',
                        choices=['True', 'False'],
                        help="debug")
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    args = parser.parse_args(argv)

    # argparse returns these flags as strings; convert them to real booleans.
    args.debug = args.debug == 'True'
    args.multiscale = args.multiscale == 'True'
    args.augment = args.augment == 'True'
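
    # A more idiomatic alternative (sketch -- would replace the three
    # 'True'/'False' string arguments and the conversions above):
    #     parser.add_argument("--debug", action="store_true")
    #     parser.add_argument("--multiscale", action="store_true")
    #     parser.add_argument("--augment", action="store_true")
    # argparse then yields real booleans directly.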
    print_args(args)

    print(
        datetime.datetime.strftime(datetime.datetime.now(),
                                   '%Y-%m-%d %H:%M:%S'))

    if args.debug:
        print('debug...')
        import shutil  # only needed by the commented-out cleanup below
        # In debug mode, wipe save_path first and validate after every epoch.
        # if os.path.exists(args.save_path):
        #     shutil.rmtree(args.save_path)
        args.evaluation_interval = 1

    # assert not os.path.exists(args.save_path)
    # os.makedirs(args.save_path)

    # adc.data holds the paths of the train and valid lists plus anchors.txt.
    data_config = parse_data_config(args.data_config)
    train_path = data_config["train"]
    valid_path = data_config["valid"]
    if args.debug:
        valid_path = train_path
    anchors = get_anchors(data_config['anchors']).to('cuda')
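    # Illustrative layout of the data config file (key=value lines; the exact
    # paths here are placeholders, not from the original):
    #     train=data/train.txt
    #     valid=data/valid.txt
    #     anchors=data/anchors.txt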

    model = ResNet(anchors).to('cuda')
    if args.pretrained_weights:
        print('pretrained weights: ', args.pretrained_weights)
        model.load_pretrained_weights(args.pretrained_weights)

    dataset = ListDataset(train_path,
                          img_size=args.img_size,
                          augment=args.augment,
                          multiscale=args.multiscale)
    evaluator = evaluate(path=valid_path,  # renamed from `eval` to avoid shadowing the builtin
                         img_size=args.img_size,
                         batch_size=args.batch_size,
                         debug=args.debug)

    if args.debug:
        dataset.img_files = dataset.img_files[:10 * args.batch_size]
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.n_cpu,
        collate_fn=dataset.collate_fn,
    )
    print('Number of training samples:', len(dataset))
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=5e-5)
    # TODO: should the optimizer choice and learning-rate schedule be tuned here?
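    # One possible refinement (an assumption, not from the original): decoupled
    # weight decay via
    #     optimizer = torch.optim.AdamW(model.parameters(),
    #                                   lr=args.lr, weight_decay=5e-5)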
    print('\n### train ...')
    for epoch in range(args.epochs):
        model.train()

        lr = max(1e-10, args.lr * (0.95**epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
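        # Built-in equivalent (sketch): construct once before the epoch loop
        #     scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
        # and call scheduler.step() at the end of each epoch; the 1e-10 floor
        # would then need enforcing separately.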

        for batch_i, (imgs, targets, _) in enumerate(dataloader):
            # Variable is a deprecated no-op wrapper in modern PyTorch; moving
            # the tensors to the GPU directly is equivalent.
            imgs = imgs.to('cuda')
            # The training set goes through augment_sequential; the validation set does not.
            # Example targets (one row per box, presumably (sample_idx, x, y, w, h)):
            #   [[0.0000, 0.7328, 0.2808, 0.0934, 0.0808],
            #    [1.0000, 0.5255, 0.5466, 0.0596, 0.1587],
            #    [1.0000, 0.5585, 0.8077, 0.0553, 0.2250],
            #    [3.0000, 0.4519, 0.4351, 0.1365, 0.2048]]
            targets = targets.to('cuda')

            yolo_map, _ = model(imgs)
            # yolo_map is a list of 4 feature maps; each is laid out as
            # (batch, featuremap_h, featuremap_w, anchor_num, (x, y, w, h, conf)).
            loss, metrics = model.loss(yolo_map, targets)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if (batch_i + 1) % 100 == 0 or (batch_i + 1) == len(dataloader):
                time_str = datetime.datetime.strftime(datetime.datetime.now(),
                                                      '%Y-%m-%d %H:%M:%S')
                lr = optimizer.param_groups[0]['lr']
                loss = metrics["loss"]
                xy = metrics["xy"]
                wh = metrics["wh"]
                conf = metrics["conf"]
                loss_str = 'loss: {:<8.2f}'.format(loss)
                loss_str += 'xy: {:<8.2f}'.format(xy)
                loss_str += 'wh: {:<8.2f}'.format(wh)
                loss_str += 'conf: {:<8.2f}'.format(conf)
                epoch_str = 'Epoch: {:4}({:4}/{:4})'.format(
                    epoch, batch_i + 1, len(dataloader))
                print('[{}]{} {} lr:{}'.format(time_str, epoch_str, loss_str,
                                               lr))
        print()
        if epoch % args.evaluation_interval == 0:
            print("\n---- Evaluating Model ----")
            save_model_epoch = 'yolov3_ckpt_{}.pth'.format(epoch)
            model.save_weights(os.path.join(args.save_path, save_model_epoch))
            print(save_model_epoch)
            # Save example detections only for the first confidence threshold.
            example_save_path = args.save_path
            for conf in [0.1, 0.3, 0.5, 0.7]:
                metrics = evaluator(model,
                                    iou_thres=0.5,
                                    conf_thres=conf,
                                    nms_thres=0.5,
                                    save_path=example_save_path)
                example_save_path = None
                print(
                    'image_acc: {}\t{}\tbbox_acc: {}\tbbox_recall: {}'.format(
                        *metrics[1:]))

                names = ['image', 'true', 'det', 'box_acc', 'image_acc']
                print('{:<10}{:<10}{:<10}{:<10}{:<10}'.format(*names))
                print('{:<10}{:<10}{:<10}{:<10}{:<10}'.format(*metrics[0][0]))
                print('{:<10}{:<10}{:<10}{:<10}{:<10}'.format(*metrics[0][1]))
                print()
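

if __name__ == '__main__':
    # Entry point -- an assumption, since the excerpt does not show one:
    # forward CLI arguments to main(), which hands them to parser.parse_args(argv).
    import sys
    main(sys.argv[1:])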