Пример #1
0
def test_balance_by_size_param_scale():
    """Check that ``param_scale`` changes the balance chosen by ``balance_by_size``.

    Six layers are stacked so that the parameter count grows (1..6) while the
    number of intermediate tensors created in ``forward`` shrinks (6..1).
    The expected balance is ``[2, 4]`` with ``param_scale=0`` and ``[4, 2]``
    with ``param_scale=100``.
    """

    class Tradeoff(nn.Module):
        """Layer with independently tunable parameter size and latent size."""

        def __init__(self, param_size, latent_size):
            super().__init__()
            self.latent_size = latent_size
            # Never called in forward(); it only contributes parameters.
            self.fc = nn.Linear(param_size, param_size)

        def forward(self, inputs):
            out = inputs
            # Each iteration adds one more random tensor that requires grad.
            for _ in range(self.latent_size):
                out = out + torch.rand_like(out, requires_grad=True)
            return out

    # (param_size, latent_size) pairs: (1, 6), (2, 5), ..., (6, 1).
    layers = [
        Tradeoff(param_size=size, latent_size=7 - size)
        for size in range(1, 7)
    ]
    model = nn.Sequential(*layers)

    sample = torch.rand(1, requires_grad=True)

    expected_by_scale = ((0, [2, 4]), (100, [4, 2]))
    for scale, expected in expected_by_scale:
        balance = balance_by_size(2, model, sample, param_scale=scale)
        assert balance == expected
Пример #2
0
def test_balance_by_size_param():
    """Check ``balance_by_size`` on stacks of linear layers with ``param_scale=100``.

    A stack whose layers grow in size is expected to be balanced as ``[4, 2]``;
    the mirrored stack whose layers shrink is expected to give ``[2, 4]``.
    """
    # Growing stack: Linear(1, 2), Linear(2, 3), ..., Linear(6, 7).
    growing = nn.Sequential(*(nn.Linear(width, width + 1) for width in range(1, 7)))
    growing_input = torch.rand(7, 1)
    assert balance_by_size(2, growing, growing_input, param_scale=100) == [4, 2]

    # Shrinking stack: Linear(7, 6), Linear(6, 5), ..., Linear(2, 1).
    shrinking = nn.Sequential(*(nn.Linear(width + 1, width) for width in range(6, 0, -1)))
    shrinking_input = torch.rand(1, 7)
    assert balance_by_size(2, shrinking, shrinking_input, param_scale=100) == [2, 4]
Пример #3
0
def test_balance_by_size_tuple():
    """Smoke test: ``balance_by_size`` runs on a model that passes tuples between layers.

    There is no assertion; the test passes as long as the call does not raise.
    """

    class Twin(nn.Module):
        """Return the input together with a detached copy of it."""

        def forward(self, tensor):
            detached = tensor.detach()
            return tensor, detached

    class Add(nn.Module):
        """Sum the two members of a pair."""

        def forward(self, pair):
            left, right = pair
            return left + right

    pipeline = nn.Sequential(Twin(), Add())
    probe = torch.rand(1, requires_grad=True)
    balance_by_size(1, pipeline, probe)
Пример #4
0
def test_balance_by_size_latent():
    """Check ``balance_by_size`` on parameter-free layers that differ only in latent size.

    ``Expand(k)`` creates ``k`` extra tensors in its forward pass.  With
    increasing ``k`` the expected balance is ``[4, 2]``; with decreasing ``k``
    it is ``[2, 4]``.
    """

    class Expand(nn.Module):
        """Add ``times`` random tensors (each requiring grad) to the input."""

        def __init__(self, times):
            super().__init__()
            self.times = times

        def forward(self, inputs):
            out = inputs
            for _ in range(self.times):
                out = out + torch.rand_like(out, requires_grad=True)
            return out

    sample = torch.rand(10, 100, 100)

    # Latent size grows along the stack: 1, 2, ..., 6.
    ascending = nn.Sequential(*(Expand(count) for count in range(1, 7)))
    assert balance_by_size(2, ascending, sample) == [4, 2]

    # Latent size shrinks along the stack: 6, 5, ..., 1.
    descending = nn.Sequential(*(Expand(count) for count in range(6, 0, -1)))
    assert balance_by_size(2, descending, sample) == [2, 4]
Пример #5
0
def train(n_feat,
          crop_size,
          bs,
          ep,
          optimizer="rmsprop",
          lr=5e-4,
          pretrain=None):
    """Train a pipeline-parallel ``UNetPipe`` and save the best checkpoint per class.

    The model is split over all visible CUDA devices with ``GPipe`` using a
    balance computed by ``balance_by_size``.  Every 10 epochs the mean Dice of
    each foreground class is evaluated on the validation set, and whenever a
    class improves on its best score a checkpoint is written to
    ``<model_name><class index>``.

    Args:
        n_feat: forwarded to ``UNetPipe`` as ``n_feat``.
        crop_size: comma-separated integers (e.g. ``"64,64,64"``).  If any
            entry is ``-1`` the full-image pipeline is used; otherwise
            windows of this size are sampled from each image.
        bs: batch size in full-image mode, or the number of windows sampled
            per image (``num_samples``) in crop mode.
        ep: number of training epochs.
        optimizer: ``"rmsprop"`` or ``"momentum"`` (case-insensitive).
        lr: learning rate passed to the optimizer.
        pretrain: optional path to a checkpoint holding a ``"weight"`` entry;
            only entries whose keys exist in the current model are loaded.

    Raises:
        ValueError: if ``optimizer`` is not one of the supported names.
    """
    # The name embeds the *raw* crop_size string, so build it before parsing.
    model_name = f"./HaN_{n_feat}_{bs}_{ep}_{crop_size}_{lr}_"
    print(f"save the best model as '{model_name}' during training.")

    crop_size = [int(cz) for cz in crop_size.split(",")]
    print(f"input image crop_size: {crop_size}")

    # starting training set loader
    train_images = ImageLabelDataset(path=TRAIN_PATH, n_class=N_CLASSES)
    # Both loading modes share the same elastic augmentation settings.
    elastic = Rand3DElasticd(
        keys=("image", "label"),
        spatial_size=crop_size,
        sigma_range=(10, 50),  # 30
        magnitude_range=(600, 1200),  # 1000
        prob=0.8,
        rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
        shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
        translate_range=tuple(sz * 0.05 for sz in crop_size),
        scale_range=(0.2, 0.2, 0.2),
        mode=("bilinear", "nearest"),
        padding_mode=("border", "zeros"),
    )
    if any(cz == -1 for cz in crop_size):  # using full image
        train_transform = Compose([
            AddChannelDict(keys="image"),
            elastic,
        ])
        train_dataset = Dataset(train_images, transform=train_transform)
        # when bs > 1, the loader assumes that the full image sizes are the same across the dataset
        train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                       num_workers=4,
                                                       batch_size=bs,
                                                       shuffle=True)
    else:
        # draw balanced foreground/background window samples according to the ground truth label
        train_transform = Compose([
            AddChannelDict(keys="image"),
            SpatialPadd(
                keys=("image", "label"),
                spatial_size=crop_size),  # ensure image size >= crop_size
            RandCropByPosNegLabeld(keys=("image", "label"),
                                   label_key="label",
                                   spatial_size=crop_size,
                                   num_samples=bs),
            elastic,
        ])
        # each dataset item is a list of windows
        train_dataset = Dataset(train_images, transform=train_transform)
        # stack each dataset item into a single tensor
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            num_workers=4,
            batch_size=1,
            shuffle=True,
            collate_fn=list_data_collate)
    first_sample = first(train_dataloader)
    print(first_sample["image"].shape)

    # starting validation set loader
    val_transform = Compose([AddChannelDict(keys="image")])
    val_dataset = Dataset(ImageLabelDataset(VAL_PATH, n_class=N_CLASSES),
                          transform=val_transform)
    val_dataloader = torch.utils.data.DataLoader(val_dataset,
                                                 num_workers=1,
                                                 batch_size=1)
    print(val_dataset[0]["image"].shape)
    print(
        f"training images: {len(train_dataloader)}, validation images: {len(val_dataloader)}"
    )

    model = UNetPipe(spatial_dims=3,
                     in_channels=1,
                     out_channels=N_CLASSES,
                     n_feat=n_feat)
    model = flatten_sequential(model)
    # NOTE(review): ten hard-coded weights -- presumably one per class, which
    # assumes N_CLASSES == 10; confirm before changing N_CLASSES.
    lossweight = torch.from_numpy(
        np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0, 1.90, 1.98],
                 np.float32))

    optimizer_name = optimizer.lower()
    if optimizer_name == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)  # lr = 5e-4
    elif optimizer_name == "momentum":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                    momentum=0.9)  # lr = 1e-4 for finetuning
    else:
        raise ValueError(
            f"Unknown optimizer type {optimizer}. (options are 'rmsprop' and 'momentum')."
        )

    # config GPipe: one partition per visible CUDA device, balanced using the
    # first training batch as the profiling sample.
    x = first_sample["image"].float().cuda()
    partitions = torch.cuda.device_count()
    print(f"partition: {partitions}, input: {x.size()}")
    balance = balance_by_size(partitions, model, x)
    model = GPipe(model, balance, chunks=4, checkpoint="always")

    # config loss functions
    dice_loss_func = DiceLoss(softmax=True, reduction="none")
    # use the same pipeline and loss in
    # AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy,
    # Medical Physics, 2018.
    focal_loss_func = FocalLoss(reduction="none")

    if pretrain:
        print(f"loading from {pretrain}.")
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        pretrained_dict = torch.load(pretrain)["weight"]
        model_dict = model.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        # Load the *merged* dict: the filtered checkpoint alone may lack keys
        # the model requires, which would make the strict load fail.
        model.load_state_dict(model_dict)

    n_foreground = N_CLASSES - 1  # class 0 (background) is not scored
    b_time = time.time()
    # Despite the name these are Dice scores: higher is better.
    best_val_loss = [0] * n_foreground
    for epoch in range(ep):
        model.train()
        trainloss = 0
        for b_idx, data_dict in enumerate(train_dataloader):
            x_train = data_dict["image"].cuda()
            y_train = data_dict["label"].cuda().float()
            # presumably a per-class mask of which labels are annotated -- TODO confirm
            flagvec = data_dict["with_complete_groundtruth"]

            optimizer.zero_grad()
            # Bring the pipeline output to device 0 for the loss computation.
            o = model(x_train).to(0, non_blocking=True).float()

            loss = (dice_loss_func(o, y_train.to(o)) * flagvec.to(o) *
                    lossweight.to(o)).mean()
            loss += 0.5 * (focal_loss_func(o, y_train.to(o)) * flagvec.to(o) *
                           lossweight.to(o)).mean()
            loss.backward()
            optimizer.step()
            trainloss += loss.item()

            if b_idx % 20 == 0:
                print(
                    f"Train Epoch: {epoch} [{b_idx}/{len(train_dataloader)}] \tLoss: {loss.item()}"
                )
        print(f"epoch {epoch} TRAIN loss {trainloss / len(train_dataloader)}")

        if epoch % 10 == 0:
            model.eval()
            # check validation dice
            val_loss = [0] * n_foreground
            n_val = [0] * n_foreground
            for data_dict in val_dataloader:
                x_val = data_dict["image"]
                y_val = data_dict["label"]
                # No gradients are needed for evaluation, so skip building
                # the autograd graph for the forward pass.
                with torch.no_grad():
                    o = model(x_val.cuda()).to(0, non_blocking=True)
                    loss = compute_meandice(o,
                                            y_val.to(o),
                                            mutually_exclusive=True,
                                            include_background=False)
                # NaN scores are skipped so they do not poison the per-class mean.
                for c, score in zip(range(n_foreground), loss[0]):
                    if not torch.isnan(score):
                        val_loss[c] += score.item()
                        n_val[c] += 1
            # A class with no valid score gets NaN instead of dividing by zero;
            # NaN never compares greater than the best score, so nothing is saved.
            val_loss = [
                total / count if count else float("nan")
                for total, count in zip(val_loss, n_val)
            ]
            print("validation scores " +
                  ", ".join(f"{score:.4f}" for score in val_loss))
            for c in range(1, N_CLASSES):
                if best_val_loss[c - 1] < val_loss[c - 1]:
                    best_val_loss[c - 1] = val_loss[c - 1]
                    state = {
                        "epoch": epoch,
                        "weight": model.state_dict(),
                        "score_" + str(c): best_val_loss[c - 1]
                    }
                    torch.save(state, f"{model_name}{c}")
            print("best validation scores " +
                  ", ".join(f"{score:.4f}" for score in best_val_loss))

    print("total time", time.time() - b_time)
Пример #6
0
        batch_size=batch_size, shuffle=True, num_workers=args.num_workers_dataloader,
        **dataloader_kwargs)

    #---------------------------------------------------------------------------------
    # Move model to GPU.
    print("== Creating model '{}' ==".format(args.arch))
    # model = model_names[args.arch].cuda()
    model = model_names[args.arch]

    print("== Autobalancing partitions ==")
    partitions = torch.cuda.device_count()
    sample = torch.empty(batch_size, 1, 28, 28)
    if args.balance_by == 'time':
        balance = balance_by_time(partitions, model, sample)
    elif args.balance_by == 'memory':
        balance = balance_by_size(partitions, model, sample)
    else:
        raise NotImplementedError("Unsupport value specified for 'balance_by' argument")

    print("== Wrapping model as GPipe model ==")
    model = GPipe(model, balance, chunks=args.num_microbatches)
    
    #---------------------------------------------------------------------------------
    # Specify input and output to the correct device
    devices = list(model.devices)
    in_device  = devices[0]
    out_device = devices[-1]

    throughputs = []
    elapsed_times = []
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)