def train():
    # Initialize a wandb run and register default hyperparameters
    ex = wandb.init(project="PQRST-segmentation")
    ex.config.setdefaults(wandb_config)

    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    net = UNet(in_ch=1, out_ch=4)
    net.to(device)

    try:
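        # Run training; on Ctrl-C, offer to save the current weights before exiting.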
        train_model(net=net, device=device, batch_size=wandb.config.batch_size, lr=wandb.config.lr, epochs=wandb.config.epochs)
    except KeyboardInterrupt:
        try:
            save = input("save?(y/n)")
            if save == "y":
                torch.save(net.state_dict(), 'net_params.pkl')
            sys.exit(0)
        except SystemExit:
            os._exit(0)
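A minimal sketch of the defaults dict and entry point this snippet assumes. The keys mirror the wandb.config lookups above (batch_size, lr, epochs); the concrete values and the __main__ guard are placeholders, not the original project's settings, and UNet/train_model are project-local names not shown here.

wandb_config = {
    "batch_size": 32,  # assumed value
    "lr": 1e-3,        # assumed value
    "epochs": 100,     # assumed value
}

if __name__ == "__main__":
    train()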
Example #2
def train(fold_idx=1):

    # 1. Load dataset
    dataset_train = ICH_CT_32(
        ROOT=config['dataset_root'],
        transform=T.Compose([T.ToTensor(),
                             T.Normalize([0.5], [0.5])]),
        is_train=True,
        fold_idx=fold_idx)
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=config['batch_size'],
                                  shuffle=True,
                                  num_workers=1)

    dataset_eval = ICH_CT_32(ROOT=config['dataset_root'],
                             transform=T.Compose(
                                 [T.ToTensor(),
                                  T.Normalize([0.5], [0.5])]),
                             is_train=False,
                             fold_idx=fold_idx)
    dataloader_eval = DataLoader(dataset_eval,
                                 batch_size=config['batch_size'],
                                 shuffle=False,
                                 num_workers=1)

    # 2. Build model
    net = UNet()
    # net.finetune_from('pretrained_weights/vgg16-397923af.pth')
    net = nn.DataParallel(net, device_ids=[0])
    print(net)

    # 3. Criterion
    criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 85.0]))  # class 1 weighted 85x to counter class imbalance

    # 4. Optimizer
    optimizer = optim.SGD(net.parameters(),
                          lr=config['lr'],
                          momentum=config['momentum'],
                          weight_decay=config['weight_decay'])
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.8)

    # 5. Tensorboard logger
    logger_train = Logger(logdir=os.path.join(config['log_folder'],
                                              'fold_{}'.format(fold_idx),
                                              'train'),
                          flush_secs=2)
    logger_eval = Logger(logdir=os.path.join(config['log_folder'],
                                             'fold_{}'.format(fold_idx),
                                             'eval'),
                         flush_secs=2)

    # 6. Train loop
    DSC_MAX, IOU1_MAX, sensitivity_MAX, specificity_MAX = -1.0, -1.0, -1.0, -1.0
    for epoch in range(config['num_epoch']):

        train_op(net, dataloader_train, criterion, optimizer, scheduler, epoch,
                 logger_train)
        DSC, IOU1, sensitivity, specificity = eval_op(net, dataloader_eval,
                                                      criterion, epoch,
                                                      logger_eval)

        torch.save(net.state_dict(),
                   os.path.join(config['save_folder'], 'UNet.newest.pkl'))

        if DSC_MAX <= DSC:
            DSC_MAX = DSC
            torch.save(net.state_dict(),
                       os.path.join(config['save_folder'], 'UNet.pkl'))
        IOU1_MAX = max(IOU1_MAX, IOU1)
        sensitivity_MAX = max(sensitivity_MAX, sensitivity)
        specificity_MAX = max(specificity_MAX, specificity)

    return DSC_MAX, IOU1_MAX, sensitivity_MAX, specificity_MAX, DSC, IOU1, sensitivity, specificity
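A hedged sketch of a cross-validation driver for the function above: it calls train() once per fold and averages the returned metrics. Only the return signature comes from the snippet; the fold count and the summary printout are assumptions.

def cross_validate(num_folds=5):  # fold count is an assumption
    results = []
    for fold_idx in range(1, num_folds + 1):
        # Each call returns (DSC_MAX, IOU1_MAX, sensitivity_MAX, specificity_MAX,
        #                    DSC, IOU1, sensitivity, specificity) for that fold.
        results.append(train(fold_idx=fold_idx))
    # Column-wise mean over folds, in the same order as the tuple above.
    means = [sum(col) / len(col) for col in zip(*results)]
    print('mean DSC_MAX={:.4f} IOU1_MAX={:.4f} '
          'sensitivity_MAX={:.4f} specificity_MAX={:.4f}'.format(*means[:4]))
    return means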
Example #3
def main():
    torch.backends.cudnn.benchmark = True
    args = getArgs()
    torch.manual_seed(args.seed)
    args.cuda = torch.cuda.is_available()
    if args.cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    # Initialize horovod
    hvd.init()
    torch.manual_seed(args.seed)
    # Print the training configuration
    if hvd.rank() == 0:
        print("Training with configure: ")
        for arg in vars(args):
            print("{}:\t{}".format(arg, getattr(args, arg)))
        if not osp.exists(args.save_model_path):
            os.makedirs(args.save_model_path)
        # Save the training configuration
        with open(osp.join(args.save_model_path, 'train-config.json'),
                  'w') as f:
            json.dump(args.__dict__, f, indent=4)
    # Set the random seed so that weight initialization is identical on every GPU
    if args.cuda:
        # Pin GPU to local rank
        torch.cuda.set_device(hvd.local_rank())
        # This line seems unnecessary, but per the horovod maintainers it is safer to keep it.
        torch.cuda.manual_seed(args.seed)
    # data
    dataset_train = SpineDataset(root=args.data, transform=my_transform)
    # Distributed training requires this sampler
    sampler_train = DistributedSampler(dataset_train,
                                       num_replicas=hvd.size(),
                                       rank=hvd.rank())
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=1,
                                  sampler=sampler_train,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    # model
    if args.network == 'DeepLab':
        if args.voc:
            model = gcv.models.get_deeplab_resnet101_voc(pretrained=True)
        elif args.ade:
            model = gcv.models.get_deeplab_resnet101_ade(pretrained=True)
        else:
            model = gcv.models.DeepLabV3(nclass=args.num_classes,
                                         backbone=args.backbone)
        model.auxlayer.conv5[-1] = nn.Conv2d(256,
                                             args.num_classes,
                                             kernel_size=1)
        model.head.block[-1] = nn.Conv2d(256, args.num_classes, kernel_size=1)
    elif args.network == 'FCN':
        if args.voc:
            model = gcv.models.get_fcn_resnet101_voc(pretrained=True)
        elif args.ade:
            model = gcv.models.get_fcn_resnet101_ade(pretrained=True)
        else:
            model = gcv.models.FCN(nclass=args.num_classes,
                                   backbone=args.backbone)
        model.auxlayer.conv5[-1] = nn.Conv2d(256,
                                             args.num_classes,
                                             kernel_size=1)
        model.head.conv5[-1] = nn.Conv2d(512, args.num_classes, kernel_size=1)
    elif args.network == 'PSPNet':
        if args.voc:
            model = gcv.models.get_psp_resnet101_voc(pretrained=True)
        elif args.ade:
            model = gcv.models.get_psp_resnet101_ade(pretrained=True)
        else:
            model = gcv.models.PSP(nclass=args.num_classes,
                                   backbone=args.backbone)
        model.auxlayer.conv5[-1] = nn.Conv2d(256, args.num_classes, kernel_size=1)
        model.head.conv5[-1] = nn.Conv2d(512, args.num_classes, kernel_size=1)
    elif args.network == 'UNet':
        model = UNet(n_class=args.num_classes,
                     backbone=args.backbone,
                     pretrained=True)
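    # convert_syncbn_model (presumably apex-style) swaps BatchNorm layers for synchronized BatchNorm across workers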
    model = convert_syncbn_model(model)
    model = model.to(device)

    # The optimizer must be wrapped with horovod's DistributedOptimizer (see below)
    # optimizer = torch.optim.Adam(model.parameters(), args.learning_rate * hvd.size())
    # Use different learning rates for different layer groups
    if args.network == 'UNet':
        optimizer = torch.optim.SGD([
            {
                'params': model.down_blocks.parameters(),
                'lr': args.learning_rate * 0.5
            },
            {
                'params': model.bridge.parameters()
            },
            {
                'params': model.head.parameters()
            },
        ],
                                    lr=args.learning_rate,
                                    momentum=0.9,
                                    weight_decay=0.0001)
    elif args.network in ['FCN', 'PSPNet', 'DeepLab']:
        optimizer = optim.SGD([{
            'params': model.pretrained.parameters(),
            'lr': args.learning_rate * 0.5
        }, {
            'params': model.auxlayer.parameters()
        }, {
            'params': model.head.parameters()
        }],
                              lr=args.learning_rate,
                              momentum=0.9,
                              weight_decay=0.0001)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=args.learning_rate,
                              momentum=0.9,
                              weight_decay=0.0001)
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())
    # Broadcast the model and optimizer parameters from rank 0 to every GPU
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # lr scheduler
    def poly_lr_scheduler(epoch, num_epochs=args.num_epochs, power=args.power):
        return (1 - epoch / num_epochs)**power

    lr_scheduler = LambdaLR(optimizer=optimizer, lr_lambda=poly_lr_scheduler)

    def train(epoch):
        model.train()
        # Horovod: set epoch to sampler for shuffling.
        sampler_train.set_epoch(epoch)
        lr_scheduler.step()
        loss_fn = nn.CrossEntropyLoss()
        for batch_idx, (data, target) in enumerate(dataloader_train):
            data = data.to(device).squeeze()
            target = target.to(device).squeeze()
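            # Each loaded sample appears to contain multiple slices; train on it in chunks of args.batch_size.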
            for batch_data, batch_target in zip(
                    torch.split(data, args.batch_size),
                    torch.split(target, args.batch_size)):
                optimizer.zero_grad()
                output = model(batch_data)
                if args.network in ['FCN', 'PSPNet', 'DeepLab']:
                    loss = loss_fn(output[0], batch_target) \
                           + 0.2*loss_fn(output[1], batch_target)
                elif args.network == 'UNet':
                    loss = loss_fn(output, batch_target)
                loss.backward()
                optimizer.step()
            if hvd.rank() == 0 and batch_idx % args.log_interval == 0:
                print("Train loss: ", loss.item())

    for epoch in range(args.num_epochs):
        train(epoch)
        if hvd.rank() == 0:
            print("Saving model to {}".format(
                osp.join(args.save_model_path,
                         "checkpoint-{:0>3d}.pth".format(epoch))))
            torch.save({'state_dict': model.state_dict()},
                       osp.join(args.save_model_path,
                                "checkpoint-{:0>3d}.pth".format(epoch)))