def test_balance_by_size_param_scale():
    """Larger ``param_scale`` shifts the balance from latent size toward parameter size."""

    class Tradeoff(nn.Module):
        # Parameter footprint grows with ``param_size`` (the Linear weight),
        # activation (latent) footprint grows with ``latent_size``.
        def __init__(self, param_size, latent_size):
            super().__init__()
            self.fc = nn.Linear(param_size, param_size)
            self.latent_size = latent_size

        def forward(self, x):
            for _ in range(self.latent_size):
                x = x + torch.rand_like(x, requires_grad=True)
            return x

    # Six stages trading parameter size against latent size in opposite directions.
    specs = [(1, 6), (2, 5), (3, 4), (4, 3), (5, 2), (6, 1)]
    model = nn.Sequential(*(Tradeoff(param_size=p, latent_size=q) for p, q in specs))

    sample = torch.rand(1, requires_grad=True)

    # Parameters ignored entirely: the latent-heavy early stages dominate.
    balance = balance_by_size(2, model, sample, param_scale=0)
    assert balance == [2, 4]

    # Parameters weighted heavily: the param-heavy late stages dominate.
    balance = balance_by_size(2, model, sample, param_scale=100)
    assert balance == [4, 2]
def test_balance_by_size_param():
    """With a large ``param_scale`` the balance follows where the parameters live."""
    # Layer widths increase along the pipeline, so parameters concentrate at the tail.
    model = nn.Sequential(*(nn.Linear(k + 1, k + 2) for k in range(6)))
    sample = torch.rand(7, 1)
    balance = balance_by_size(2, model, sample, param_scale=100)
    assert balance == [4, 2]

    # Mirror image: widths decrease, so parameters concentrate at the head.
    model = nn.Sequential(*(nn.Linear(k + 2, k + 1) for k in reversed(range(6))))
    sample = torch.rand(1, 7)
    balance = balance_by_size(2, model, sample, param_scale=100)
    assert balance == [2, 4]
def test_balance_by_size_tuple():
    """``balance_by_size`` must cope with modules passing tuples between stages."""

    class Twin(nn.Module):
        def forward(self, x):
            # Second output is detached, so only one branch carries gradient.
            return x, x.detach()

    class Add(nn.Module):
        def forward(self, a_b):
            lhs, rhs = a_b
            return lhs + rhs

    pipeline = nn.Sequential(Twin(), Add())
    sample = torch.rand(1, requires_grad=True)
    # Smoke test: profiling a tuple-passing pipeline must not raise.
    balance_by_size(1, pipeline, sample)
def test_balance_by_size_latent():
    """Without ``param_scale`` the balance follows latent (activation) memory."""

    class Expand(nn.Module):
        # Each iteration adds another random tensor that requires grad,
        # inflating the stage's activation footprint ``times``-fold.
        def __init__(self, times):
            super().__init__()
            self.times = times

        def forward(self, x):
            for _ in range(self.times):
                x = x + torch.rand_like(x, requires_grad=True)
            return x

    sample = torch.rand(10, 100, 100)

    # Ascending latent cost: the heavy tail forces a 4/2 split.
    model = nn.Sequential(*(Expand(n) for n in (1, 2, 3, 4, 5, 6)))
    balance = balance_by_size(2, model, sample)
    assert balance == [4, 2]

    # Descending latent cost: the heavy head forces a 2/4 split.
    model = nn.Sequential(*(Expand(n) for n in (6, 5, 4, 3, 2, 1)))
    balance = balance_by_size(2, model, sample)
    assert balance == [2, 4]
def train(n_feat, crop_size, bs, ep, optimizer="rmsprop", lr=5e-4, pretrain=None):
    """Train a pipelined 3D UNet for head-and-neck multi-organ segmentation.

    Args:
        n_feat: base feature-channel count for the UNet.
        crop_size: comma-separated string of crop dimensions; ``-1`` in any
            position means "use the full image" (no patch sampling).
        bs: batch size (patch count per image when window sampling is used).
        ep: number of training epochs.
        optimizer: ``"rmsprop"`` or ``"momentum"``.
        lr: learning rate.
        pretrain: optional path to a checkpoint whose ``"weight"`` entry is a
            state dict to warm-start from.

    Raises:
        ValueError: if ``optimizer`` is not one of the two supported names.

    Side effects: saves the best per-class checkpoints under ``model_name`` and
    prints progress to stdout.
    """
    model_name = f"./HaN_{n_feat}_{bs}_{ep}_{crop_size}_{lr}_"
    print(f"save the best model as '{model_name}' during training.")

    crop_size = [int(cz) for cz in crop_size.split(",")]
    print(f"input image crop_size: {crop_size}")

    # starting training set loader
    train_images = ImageLabelDataset(path=TRAIN_PATH, n_class=N_CLASSES)
    if np.any([cz == -1 for cz in crop_size]):  # using full image
        train_transform = Compose([
            AddChannelDict(keys="image"),
            Rand3DElasticd(
                keys=("image", "label"),
                spatial_size=crop_size,
                sigma_range=(10, 50),  # 30
                magnitude_range=(600, 1200),  # 1000
                prob=0.8,
                rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
                translate_range=tuple(sz * 0.05 for sz in crop_size),
                scale_range=(0.2, 0.2, 0.2),
                mode=("bilinear", "nearest"),
                padding_mode=("border", "zeros"),
            ),
        ])
        train_dataset = Dataset(train_images, transform=train_transform)
        # when bs > 1, the loader assumes that the full image sizes are the same across the dataset
        train_dataloader = torch.utils.data.DataLoader(train_dataset, num_workers=4, batch_size=bs, shuffle=True)
    else:
        # draw balanced foreground/background window samples according to the ground truth label
        train_transform = Compose([
            AddChannelDict(keys="image"),
            SpatialPadd(keys=("image", "label"), spatial_size=crop_size),  # ensure image size >= crop_size
            RandCropByPosNegLabeld(
                keys=("image", "label"), label_key="label", spatial_size=crop_size, num_samples=bs
            ),
            Rand3DElasticd(
                keys=("image", "label"),
                spatial_size=crop_size,
                sigma_range=(10, 50),  # 30
                magnitude_range=(600, 1200),  # 1000
                prob=0.8,
                rotate_range=(np.pi / 12, np.pi / 12, np.pi / 12),
                shear_range=(np.pi / 18, np.pi / 18, np.pi / 18),
                translate_range=tuple(sz * 0.05 for sz in crop_size),
                scale_range=(0.2, 0.2, 0.2),
                mode=("bilinear", "nearest"),
                padding_mode=("border", "zeros"),
            ),
        ])
        train_dataset = Dataset(train_images, transform=train_transform)  # each dataset item is a list of windows
        train_dataloader = torch.utils.data.DataLoader(  # stack each dataset item into a single tensor
            train_dataset, num_workers=4, batch_size=1, shuffle=True, collate_fn=list_data_collate
        )
    first_sample = first(train_dataloader)
    print(first_sample["image"].shape)

    # starting validation set loader
    val_transform = Compose([AddChannelDict(keys="image")])
    val_dataset = Dataset(ImageLabelDataset(VAL_PATH, n_class=N_CLASSES), transform=val_transform)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, num_workers=1, batch_size=1)
    print(val_dataset[0]["image"].shape)
    print(f"training images: {len(train_dataloader)}, validation images: {len(val_dataloader)}")

    model = UNetPipe(spatial_dims=3, in_channels=1, out_channels=N_CLASSES, n_feat=n_feat)
    model = flatten_sequential(model)
    # per-class loss weights (background classes 7/8 weighted 1.0)
    lossweight = torch.from_numpy(np.array([2.22, 1.31, 1.99, 1.13, 1.93, 1.93, 1.0, 1.0, 1.90, 1.98], np.float32))

    if optimizer.lower() == "rmsprop":
        optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)  # lr = 5e-4
    elif optimizer.lower() == "momentum":
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)  # lr = 1e-4 for finetuning
    else:
        raise ValueError(f"Unknown optimizer type {optimizer}. (options are 'rmsprop' and 'momentum').")

    # config GPipe
    x = first_sample["image"].float()
    x = torch.autograd.Variable(x.cuda())
    partitions = torch.cuda.device_count()
    print(f"partition: {partitions}, input: {x.size()}")
    balance = balance_by_size(partitions, model, x)
    model = GPipe(model, balance, chunks=4, checkpoint="always")

    # config loss functions
    dice_loss_func = DiceLoss(softmax=True, reduction="none")
    # use the same pipeline and loss in
    # AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy,
    # Medical Physics, 2018.
    focal_loss_func = FocalLoss(reduction="none")

    if pretrain:
        print(f"loading from {pretrain}.")
        pretrained_dict = torch.load(pretrain)["weight"]
        model_dict = model.state_dict()
        # keep only the entries that exist in the current model
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        # BUGFIX: load the merged full state dict, not the filtered subset —
        # loading the subset with the default strict=True raises on any missing key,
        # and the merged model_dict built above was otherwise discarded.
        model.load_state_dict(model_dict)

    b_time = time.time()
    best_val_loss = [0] * (N_CLASSES - 1)  # foreground
    for epoch in range(ep):
        model.train()
        trainloss = 0
        for b_idx, data_dict in enumerate(train_dataloader):
            x_train = data_dict["image"]
            y_train = data_dict["label"]
            flagvec = data_dict["with_complete_groundtruth"]

            x_train = torch.autograd.Variable(x_train.cuda())
            y_train = torch.autograd.Variable(y_train.cuda().float())
            optimizer.zero_grad()
            # gather pipeline output onto device 0 for the loss computation
            o = model(x_train).to(0, non_blocking=True).float()
            loss = (dice_loss_func(o, y_train.to(o)) * flagvec.to(o) * lossweight.to(o)).mean()
            loss += 0.5 * (focal_loss_func(o, y_train.to(o)) * flagvec.to(o) * lossweight.to(o)).mean()
            loss.backward()
            optimizer.step()
            trainloss += loss.item()

            if b_idx % 20 == 0:
                print(f"Train Epoch: {epoch} [{b_idx}/{len(train_dataloader)}] \tLoss: {loss.item()}")
        print(f"epoch {epoch} TRAIN loss {trainloss / len(train_dataloader)}")

        if epoch % 10 == 0:
            model.eval()
            # check validation dice
            val_loss = [0] * (N_CLASSES - 1)
            n_val = [0] * (N_CLASSES - 1)
            for data_dict in val_dataloader:
                x_val = data_dict["image"]
                y_val = data_dict["label"]
                with torch.no_grad():
                    x_val = torch.autograd.Variable(x_val.cuda())
                    o = model(x_val).to(0, non_blocking=True)
                loss = compute_meandice(o, y_val.to(o), mutually_exclusive=True, include_background=False)
                # `l == l` is False only for NaN: skip classes absent from this case
                val_loss = [l.item() + tl if l == l else tl for l, tl in zip(loss[0], val_loss)]
                n_val = [n + 1 if l == l else n for l, n in zip(loss[0], n_val)]
            # NOTE(review): assumes every foreground class appears in >= 1 validation
            # case (n > 0); otherwise this divides by zero — confirm against the dataset.
            val_loss = [l / n for l, n in zip(val_loss, n_val)]
            print("validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f" % tuple(val_loss))
            for c in range(1, 10):
                if best_val_loss[c - 1] < val_loss[c - 1]:
                    best_val_loss[c - 1] = val_loss[c - 1]
                    state = {"epoch": epoch, "weight": model.state_dict(), "score_" + str(c): best_val_loss[c - 1]}
                    torch.save(state, f"{model_name}" + str(c))
            print("best validation scores %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f, %.4f" % tuple(best_val_loss))

    print("total time", time.time() - b_time)
batch_size=batch_size, shuffle=True, num_workers=args.num_workers_dataloader, **dataloader_kwargs) #--------------------------------------------------------------------------------- # Move model to GPU. print("== Creating model '{}' ==".format(args.arch)) # model = model_names[args.arch].cuda() model = model_names[args.arch] print("== Autobalancing partitions ==") partitions = torch.cuda.device_count() sample = torch.empty(batch_size, 1, 28, 28) if args.balance_by == 'time': balance = balance_by_time(partitions, model, sample) elif args.balance_by == 'memory': balance = balance_by_size(partitions, model, sample) else: raise NotImplementedError("Unsupport value specified for 'balance_by' argument") print("== Wrapping model as GPipe model ==") model = GPipe(model, balance, chunks=args.num_microbatches) #--------------------------------------------------------------------------------- # Specify input and output to the correct device devices = list(model.devices) in_device = devices[0] out_device = devices[-1] throughputs = [] elapsed_times = [] optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)