Example #1
File: test.py Project: Oneflow-Inc/models
def main(args):
    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)

    net = UNet(n_channels=3, n_classes=1)

    checkpoint = flow.load(args.pretrained_path)
    net.load_state_dict(checkpoint)

    net.to("cuda")

    x_test_dir, y_test_dir = get_datadir_path(args, split="test")

    test_dataset = Dataset(
        x_test_dir, y_test_dir, augmentation=get_test_augmentation(),
    )

    print("Begin Testing...")
    for i, (image, mask) in enumerate(tqdm(test_dataset)):
        show_image = image
        with flow.no_grad():
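            # Preprocess: scale pixel values to [0, 1] and turn the HWC numpy image
            # into a CHW float32 tensor on the GPU before running the network.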
            image = image / 255.0
            image = image.astype(np.float32)
            image = flow.tensor(image, dtype=flow.float32)
            image = image.permute(2, 0, 1)
            image = image.to("cuda")

            pred = net(image.unsqueeze(0).to("cuda"))
            pred = pred.numpy()
            pred = pred > 0.5
        save_picture_name = os.path.join(args.save_path, "test_image_" + str(i))
        visualize(
            save_picture_name, image=show_image, GT=mask[0, :, :], Pred=pred[0, 0, :, :]
        )
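For reference, the function above expects an argparse-style namespace with at least save_path and pretrained_path (the remaining fields are consumed by get_datadir_path). A minimal sketch of such an invocation, with placeholder values that are not taken from the project:

import argparse

# Hypothetical invocation; the real argument parser lives elsewhere in Oneflow-Inc/models.
args = argparse.Namespace(
    save_path="./results",          # directory for the visualized predictions (placeholder)
    pretrained_path="./unet_ckpt",  # OneFlow checkpoint directory (placeholder)
    # plus whatever fields get_datadir_path(args, split="test") reads
)
main(args)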
Example #2
def train():
    ex = wandb.init(project="PQRST-segmentation")
    ex.config.setdefaults(wandb_config)

    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    net = UNet(in_ch=1, out_ch=4)
    net.to(device)

    try:
        train_model(net=net, device=device, batch_size=wandb.config.batch_size, lr=wandb.config.lr, epochs=wandb.config.epochs)
    except KeyboardInterrupt:
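        # On Ctrl+C, ask whether to save the current weights before exiting.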
        try:
            save = input("save?(y/n)")
            if save == "y":
                torch.save(net.state_dict(), 'net_params.pkl')
            sys.exit(0)
        except SystemExit:
            os._exit(0)
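The training function reads batch_size, lr, and epochs from wandb.config, so wandb_config presumably supplies those defaults. A minimal sketch with placeholder values, not taken from the project:

# Hypothetical defaults registered via ex.config.setdefaults(wandb_config).
wandb_config = {
    "batch_size": 32,  # placeholder
    "lr": 1e-3,        # placeholder
    "epochs": 50,      # placeholder
}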
Example #3
def main(config):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    train_dataset, valid_dataset = generate_datasets(
        config['data_dir'], valid_ids=config['val_ids'])
    # TODO: define and add data augmentation + image normalization
    # train_dataset.transform = train_transform
    # valid_dataset.transform = valid_transform
    transforms = A.Compose([
        A.Normalize(),  # TODO: change values
        ToTensorV2()
    ])
    train_dataset.transform = transforms
    valid_dataset.transform = transforms

    train_loader = DataLoader(train_dataset,
                              batch_size=config['batch_size'],
                              shuffle=True,
                              num_workers=config['num_workers'])
    valid_loader = DataLoader(valid_dataset,
                              batch_size=config['batch_size'],
                              shuffle=False,
                              num_workers=config['num_workers'])
    model = UNet()
    model = model.to(device)

    criterion = config['criterion']
    optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-4)

    trainer = Trainer(model=model,
                      criterion=criterion,
                      optimizer=optimizer,
                      config=config,
                      train_loader=train_loader,
                      val_loader=valid_loader,
                      device=device)
    trainer.train()

    return model
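main(config) reads at least data_dir, val_ids, batch_size, num_workers, and criterion (the Trainer may consume more keys). A minimal sketch of such a config, with placeholder values that are not taken from the project:

import torch

# Hypothetical config; the path, ids, and loss below are placeholders.
config = {
    'data_dir': './data',
    'val_ids': [0, 1],
    'batch_size': 8,
    'num_workers': 4,
    'criterion': torch.nn.BCEWithLogitsLoss(),
}
model = main(config)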
Example #4
def main():
    # Arguments
    args = get_args()
    if not osp.exists(args.result_dir):
        os.makedirs(args.result_dir)
    print("Evaluating configuration:")
    for arg in vars(args):
        print("{}:\t{}".format(arg, getattr(args, arg)))
    with open('eval-config.json', 'w') as f:
        json.dump(args.__dict__, f, indent=4)
    # Data
    if args.test:
        dataset = SpineDataset(root=args.root,
                               split='test',
                               transform=test_transform)
    else:
        dataset = SpineDataset(root=args.root,
                               split='val',
                               transform=val_transform)

    dataloader = DataLoader(dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=4,
                            pin_memory=True)

    # Model
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    if args.network == 'DeepLab':
        model = gcv.models.DeepLabV3(nclass=args.num_classes,
                                     backbone=args.backbone)
    elif args.network == 'FCN':
        model = gcv.models.FCN(nclass=args.num_classes, backbone=args.backbone)
    elif args.network == 'PSPNet':
        model = gcv.models.PSP(nclass=args.num_classes, backbone=args.backbone)
    elif args.network == 'UNet':
        model = UNet(n_class=args.num_classes, backbone=args.backbone)
    print('load model from {} ...'.format(args.model))
    model.load_state_dict(
        torch.load(args.model, map_location='cpu')['state_dict'])
    model = model.to(device)
    print('Done!')

    # Evaluation
    def eval():
        with torch.no_grad():
            model.eval()
            result = []
            tq = tqdm.tqdm(total=len(dataloader))
            if args.test:
                tq.set_description('test')
                for i, (data, img_file) in enumerate(dataloader):
                    tq.update(1)
                    data = data.to(device)
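                    # data is laid out as (1, num_slices, C, H, W); run the network one
                    # slice at a time and collect the argmax masks into a uint16 volume.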
                    predict = np.zeros(
                        (data.size()[1], data.size()[3], data.size()[4]),
                        dtype=np.uint16)
                    for idx in range(data.size()[1]):
                        if args.network in ['DeepLab', 'FCN', 'PSPNet']:
                            final_out = model(data[:, idx])[0]
                        elif args.network == 'UNet':
                            final_out = model(data[:, idx])
                        predict[idx] = final_out.argmax(
                            dim=1).cpu().squeeze().numpy().astype(np.uint16)
                    pred_img = sitk.GetImageFromArray(predict)
                    test_img = sitk.ReadImage(
                        osp.join(args.root, 'test', 'image', img_file[0]))
                    pred_img.CopyInformation(test_img)
                    result_file = 'mask_' + img_file[0].lower()
                    sitk.WriteImage(pred_img,
                                    osp.join(args.result_dir, result_file))
            else:
                tq.set_description('val')
                for i, (data, mask, mask_file) in enumerate(dataloader):
                    tq.update(1)
                    gt_img = sitk.ReadImage(
                        osp.join(args.root, 'val', 'groundtruth',
                                 mask_file[0]))
                    data = data.to(device)
                    predict = np.zeros(
                        (data.size()[1], data.size()[3], data.size()[4]),
                        dtype=np.uint16)
                    for idx in range(data.size()[1]):
                        if args.network in ['DeepLab', 'FCN', 'PSPNet']:
                            final_out = model(data[:, idx])[0]
                        elif args.network == 'UNet':
                            final_out = model(data[:, idx])
                        predict[idx] = final_out.argmax(
                            dim=1).cpu().squeeze().numpy().astype(np.uint16)
                    pred_img = sitk.GetImageFromArray(predict)
                    pred_img.CopyInformation(gt_img)
                    sitk.WriteImage(pred_img,
                                    osp.join(args.result_dir, mask_file[0]))
                    ppv, sensitivity, dice, _ = metrics.precision_recall_fscore_support(
                        mask.numpy().flatten(),
                        predict.flatten(),
                        average='binary')
                    result.append([dice, ppv, sensitivity])
                result = np.array(result)
                result_mean = result.mean(axis=0)
                result_std = result.std(axis=0)
                print(result_mean, result_std)
                np.savetxt(osp.join(args.result_dir, 'result.txt'),
                           result_mean,
                           fmt='%.3f',
                           header='Dice, PPV, Sensitivity')  # order matches result.append([dice, ppv, sensitivity])

            tq.close()

    eval()
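The main() that follows appears to be the matching Horovod training entry point; it would typically be launched with horovodrun. The process count and flag names here are placeholders, since the real flags come from getArgs():

# Hypothetical launch on 2 GPUs:
#   horovodrun -np 2 python train.py --data ./spine_data --network UNet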
def main():
    torch.backends.cudnn.benchmark = True
    args = getArgs()
    torch.manual_seed(args.seed)
    args.cuda = torch.cuda.is_available()
    if args.cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    # Horovod initialization
    hvd.init()
    torch.manual_seed(args.seed)
    # Print the training configuration
    if hvd.rank() == 0:
        print("Training with configure: ")
        for arg in vars(args):
            print("{}:\t{}".format(arg, getattr(args, arg)))
        if not osp.exists(args.save_model_path):
            os.makedirs(args.save_model_path)
        # Save the training configuration
        with open(osp.join(args.save_model_path, 'train-config.json'),
                  'w') as f:
            json.dump(args.__dict__, f, indent=4)
    # Set the random seed so that weight initialization is the same on every GPU
    if args.cuda:
        # Pin GPU to local rank
        torch.cuda.set_device(hvd.local_rank())
        # This line is probably redundant, but Horovod's guidance recommends keeping it.
        torch.cuda.manual_seed(args.seed)
    # data
    dataset_train = SpineDataset(root=args.data, transform=my_transform)
    # Distributed training requires this sampler
    sampler_train = DistributedSampler(dataset_train,
                                       num_replicas=hvd.size(),
                                       rank=hvd.rank())
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=1,
                                  sampler=sampler_train,
                                  num_workers=args.num_workers,
                                  pin_memory=True)
    # model
    if args.network == 'DeepLab':
        if args.voc:
            model = gcv.models.get_deeplab_resnet101_voc(pretrained=True)
        elif args.ade:
            model = gcv.models.get_deeplab_resnet101_ade(pretrained=True)
        else:
            model = gcv.models.DeepLabV3(nclass=args.num_classes,
                                         backbone=args.backbone)
        model.auxlayer.conv5[-1] = nn.Conv2d(256,
                                             args.num_classes,
                                             kernel_size=1)
        model.head.block[-1] = nn.Conv2d(256, args.num_classes, kernel_size=1)
    elif args.network == 'FCN':
        if args.voc:
            model = gcv.models.get_fcn_resnet101_voc(pretrained=True)
        elif args.ade:
            model = gcv.models.get_fcn_resnet101_ade(pretrained=True)
        else:
            model = gcv.models.FCN(nclass=args.num_classes,
                                   backbone=args.backbone)
        model.auxlayer.conv5[-1] = nn.Conv2d(256,
                                             args.num_classes,
                                             kernel_size=1)
        model.head.conv5[-1] = nn.Conv2d(512, args.num_classes, kernel_size=1)
    elif args.network == 'PSPNet':
        if args.voc:
            model = gcv.models.get_psp_resnet101_voc(pretrained=True)
        elif args.ade:
            model = gcv.models.get_psp_resnet101_ade(pretrained=True)
        else:
            model = gcv.models.PSP(nclass=args.num_classes,
                                   backbone=args.backbone)
        model.auxlayer.conv5[-1] = nn.Conv2d(256, args.num_classes, kernel_size=1)
        model.head.conv5[-1] = nn.Conv2d(512, args.num_classes, kernel_size=1)
    elif args.network == 'UNet':
        model = UNet(n_class=args.num_classes,
                     backbone=args.backbone,
                     pretrained=True)
    model = convert_syncbn_model(model)
    model = model.to(device)

    # Wrap the optimizer with Horovod's distributed version
    # optimizer = torch.optim.Adam(model.parameters(), args.learning_rate * hvd.size())
    # Use different learning rates for different layer groups
    if args.network == 'UNet':
        optimizer = torch.optim.SGD([
            {
                'params': model.down_blocks.parameters(),
                'lr': args.learning_rate * 0.5
            },
            {
                'params': model.bridge.parameters()
            },
            {
                'params': model.head.parameters()
            },
        ],
                                    lr=args.learning_rate,
                                    momentum=0.9,
                                    weight_decay=0.0001)
    elif args.network in ['FCN', 'PSPNet', 'DeepLab']:
        optimizer = optim.SGD([{
            'params': model.pretrained.parameters(),
            'lr': args.learning_rate * 0.5
        }, {
            'params': model.auxlayer.parameters()
        }, {
            'params': model.head.parameters()
        }],
                              lr=args.learning_rate,
                              momentum=0.9,
                              weight_decay=0.0001)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=args.learning_rate,
                              momentum=0.9,
                              weight_decay=0.0001)
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())
    # Broadcast the model and optimizer state from rank 0 to all GPUs
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)

    # lr scheduler
    def poly_lr_scheduler(epoch, num_epochs=args.num_epochs, power=args.power):
        return (1 - epoch / num_epochs)**power

    lr_scheduler = LambdaLR(optimizer=optimizer, lr_lambda=poly_lr_scheduler)

    def train(epoch):
        model.train()
        # Horovod: set epoch to sampler for shuffling.
        sampler_train.set_epoch(epoch)
        lr_scheduler.step()  # advance the poly LR schedule once per epoch
        loss_fn = nn.CrossEntropyLoss()
        for batch_idx, (data, target) in enumerate(dataloader_train):
            data = data.to(device).squeeze()
            target = target.to(device).squeeze()
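            # The loader yields one whole volume per step (batch_size=1); after squeezing
            # the batch dim, feed the slices through in chunks of args.batch_size.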
            for batch_data, batch_target in zip(
                    torch.split(data, args.batch_size),
                    torch.split(target, args.batch_size)):
                optimizer.zero_grad()
                output = model(batch_data)
                if args.network in ['FCN', 'PSPNet', 'DeepLab']:
                    loss = loss_fn(output[0], batch_target) \
                           + 0.2*loss_fn(output[1], batch_target)
                elif args.network == 'UNet':
                    loss = loss_fn(output, batch_target)
                loss.backward()
                optimizer.step()
            if hvd.rank() == 0 and batch_idx % args.log_interval == 0:
                print("Train loss: ", loss.item())

    for epoch in range(args.num_epochs):
        train(epoch)
        if hvd.rank() == 0:
            print("Saving model to {}".format(
                osp.join(args.save_model_path,
                         "checkpoint-{:0>3d}.pth".format(epoch))))
            torch.save({'state_dict': model.state_dict()},
                       osp.join(args.save_model_path,
                                "checkpoint-{:0>3d}.pth".format(epoch)))