# Single-process version: trains and fine-tunes each pyramid module in this process.
def train_pyramid(pyramid_path, dataset_path, train_stages, checkpoint_name,
                  generate_field_stages, train_params=None, aug_params=None):
    pyramid_path = os.path.expanduser(pyramid_path)
    module_dict = get_pyramid_modules(pyramid_path)

    print(f"Loading dataset {dataset_path}...")
    dataset = MultimipDataset(dataset_path, aug_params, field_tag=checkpoint_name)

    for stage in sorted(module_dict.keys()):
        module_mip = module_dict[stage]["mip_in"]
        module_path = module_dict[stage]["path"]

        model = None
        if train_stages is None or stage in train_stages:
            print(f"Training module {stage}...")
            model = modelhouse.load_model_simple(module_path,
                                                 finetune=False,
                                                 pass_field=True,
                                                 checkpoint_name=checkpoint_name)
            if str(module_mip) not in train_params:
                raise Exception(f"Training parameters not specified for mip {module_mip}")
            mip_train_params = train_params[str(module_mip)]

            train_module(model,
                         train_params=mip_train_params,
                         train_dset=dataset.get_train_dset(mip=module_mip, stage=stage),
                         val_dset=dataset.get_val_dset(mip=module_mip, stage=stage),
                         checkpoint_path=os.path.join(module_path, "model"))
            print(f"Done training module {stage}!")

        if generate_field_stages is None or stage in generate_field_stages:
            print(f"Generating fields with module {stage}...")
            model = modelhouse.load_model_simple(module_path,
                                                 finetune=True,
                                                 pass_field=True,
                                                 finetune_iter=300 // (2**stage),
                                                 checkpoint_name=checkpoint_name)
            dataset.generate_fields(model, mip=module_mip, stage=stage)
            print(f"Done generating fields with module {stage}...")
# Distributed version: spawns one train_module worker per process (GPU) via
# torch.multiprocessing; each worker builds its own model and dataset.
def train_pyramid(world_size, pyramid_path, dataset_path, train_stages,
                  checkpoint_name, generate_field_stages, train_params=None,
                  aug_params=None):
    pyramid_path = os.path.expanduser(pyramid_path)
    module_dict = get_pyramid_modules(pyramid_path)

    print(f"Loading dataset {dataset_path}...")
    dataset = MultimipDataset(dataset_path, aug_params, field_tag=checkpoint_name)

    for stage in sorted(module_dict.keys()):
        module_mip = module_dict[stage]["mip_in"]
        module_path = module_dict[stage]["path"]

        if train_stages is None or stage in train_stages:
            print(f"Training module {stage}...")
            if str(module_mip) not in train_params:
                raise Exception(f"Training parameters not specified for mip {module_mip}")
            mip_train_params = train_params[str(module_mip)]

            mp.spawn(
                train_module,
                args=(
                    world_size,
                    module_path,
                    mip_train_params,   # train_params
                    dataset_path,       # dataset_path
                    module_mip,         # module_mip
                    stage,              # stage
                    checkpoint_name,    # checkpoint_name
                    None),              # aug_params
                nprocs=world_size,
                join=True)
            print(f"Done training module {stage}!")

        if generate_field_stages is None or stage in generate_field_stages:
            print(f"Generating fields with module {stage}...")
            model = modelhouse.load_model_simple(
                module_path,
                finetune=True,
                pass_field=True,
                finetune_iter=300 // (2**stage),
                checkpoint_name=checkpoint_name)
            dataset.generate_fields(model, mip=module_mip, stage=stage)
            print(f"Done generating fields with module {stage}...")
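# Illustrative sketch (not from the source): train_params is expected to be a dict
# keyed by str(mip), each value a list of per-epoch parameter dicts consumed by
# train_module below. All concrete values here are placeholder assumptions; the
# contents of mse_keys_to_apply / sm_keys_to_apply depend on loss.unsupervised_loss.
EXAMPLE_TRAIN_PARAMS = {
    "4": [
        {
            "smoothness": 0.5,               # weight of the smoothness penalty
            "lr": 4e-4,                      # Adam learning rate
            "num_epochs": 2,                 # epochs run with these settings
            "num_samples": 10000,            # optional cap on training samples
            "print_every": 50,               # optional logging interval
            "mse_keys_to_apply": {},         # placeholder; structure set by the loss
            "sm_keys_to_apply": {},          # placeholder; structure set by the loss
            "loss_spec": {"type": "plain"},  # or {"type": "metric", "params": {...}}
        },
    ],
}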
def train_module(rank, world_size, module_path, train_params, dataset_path,
                 module_mip, stage, checkpoint_name, aug_params=None):
    """Train one pyramid module on its own dataset (specific MIP).

    Args:
        rank (int): process index; dictates the GPU used in multi-GPU training
        world_size (int): total no. of processes
        module_path (str): path to modelhouse directory
        train_params (list of dict): per-epoch training parameters
        dataset_path (str): path to the multi-MIP dataset
        module_mip (int): input MIP level of this module
        stage (int): pyramid stage index
        checkpoint_name (str): checkpoint to load/save
        aug_params: must be None; augmentations are given per epoch in train_params
    """
    assert aug_params is None
    print(f"Running DDP on rank {rank}.")
    setup(rank, world_size)
    torch.cuda.set_device(rank)

    model = modelhouse.load_model_simple(module_path,
                                         finetune=False,
                                         pass_field=True,
                                         checkpoint_name=checkpoint_name)
    checkpoint_path = os.path.join(module_path, "model")
    model.aligner.net = model.aligner.net.to(rank)
    model = DDP(model, device_ids=[rank])

    dataset = MultimipDataset(dataset_path, aug_params, field_tag=checkpoint_name)
    train_dset = dataset.get_train_dset(mip=module_mip, stage=stage)
    val_dset = dataset.get_val_dset(mip=module_mip, stage=stage)
    val_data_loader = torch.utils.data.DataLoader(val_dset,
                                                  batch_size=1,
                                                  shuffle=True,
                                                  num_workers=0,
                                                  pin_memory=False)

    trainable = list(model.parameters())

    for epoch_params in train_params:
        smoothness = epoch_params["smoothness"]
        print_every = epoch_params.get("print_every", None)
        num_samples = epoch_params.get("num_samples", 10000000)
        lr = epoch_params["lr"]
        num_epochs = epoch_params["num_epochs"]
        mse_keys_to_apply = epoch_params["mse_keys_to_apply"]
        sm_keys_to_apply = epoch_params["sm_keys_to_apply"]
        loss_spec = epoch_params["loss_spec"]
        loss_type = loss_spec["type"]

        simple_loss = loss.unsupervised_loss(smoothness,
                                             use_defect_mask=True,
                                             sm_keys_to_apply=sm_keys_to_apply,
                                             mse_keys_to_apply=mse_keys_to_apply)
        if loss_type == "plain":
            training_loss = simple_loss
        elif loss_type == "metric":
            training_loss = loss.multilevel_metric_loss(loss_fn=simple_loss,
                                                        mip_in=0,
                                                        **loss_spec["params"])
        else:
            raise Exception(f"Bad loss type: {loss_type}")

        augmentor = None
        if "augmentations" in epoch_params:
            augmentor = augmentations.Augmentor(epoch_params["augmentations"])

        train_dset.set_size_limit(num_samples)

        # Divide the dataset across processes.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dset, num_replicas=world_size, rank=rank)
        train_data_loader = torch.utils.data.DataLoader(train_dset,
                                                        batch_size=1,
                                                        shuffle=False,
                                                        num_workers=0,
                                                        pin_memory=False,
                                                        sampler=train_sampler)

        optimizer = torch.optim.Adam(trainable, lr=lr, weight_decay=0)
        aligner_train_loop(rank, model,
                           mip_in=0,
                           train_loader=train_data_loader,
                           val_loader=val_data_loader,
                           optimizer=optimizer,
                           num_epochs=num_epochs,
                           loss_fn=training_loss,
                           print_every=print_every,
                           checkpoint_folder=checkpoint_path,
                           augmentor=augmentor)

    cleanup()
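# The DDP workers in this file call setup()/cleanup() to join and leave the process
# group. A minimal sketch of these helpers follows, assuming single-node training,
# the NCCL backend, and environment-variable rendezvous; the MASTER_ADDR/MASTER_PORT
# values are placeholders, not taken from the source.
def setup(rank, world_size):
    """Initialize the default process group for this worker (one process per GPU)."""
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    torch.distributed.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    """Tear down the default process group."""
    torch.distributed.destroy_process_group()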
def generate_shard(rank, world_size, module_path, checkpoint_name, img_path,
                   prev_field_path, dst_dir, src_mip, dst_mip):
    """Generate fields for the subset of image pairs assigned to this rank.

    Args:
        rank (int): process index
        world_size (int): total no. of processes
        module_path (str): path to modelhouse directory
        checkpoint_name (str): checkpoint for weights
        img_path (str): path to image pairs h5
        prev_field_path (str): path to previous fields h5
        dst_dir (str): path where temporary field volumes are stored
        src_mip (int): MIP level the fields are computed at
        dst_mip (int): MIP level the fields are upsampled to
    """
    print(f"Running DDP on rank {rank}.")
    setup(rank, world_size)
    torch.cuda.set_device(rank)

    model = modelhouse.load_model_simple(module_path,
                                         finetune=True,
                                         finetune_lr=3e-1,
                                         finetune_sm=300e0,
                                         finetune_iter=200,
                                         pass_field=True,
                                         checkpoint_name=checkpoint_name)
    model.aligner.net = model.aligner.net.to(rank)
    model = DDP(model, device_ids=[rank])

    img_dset = h5py.File(img_path, 'r')['main']
    prev_field_dset = h5py.File(prev_field_path, 'r')['main']

    # Split the image pairs evenly across ranks; the last rank takes the remainder.
    assert img_dset.shape[0] >= world_size
    n = img_dset.shape[0] // world_size
    n_start = rank * n
    n_stop = min(n_start + n, img_dset.shape[0])
    if rank + 1 == world_size:
        n_stop = img_dset.shape[0]

    src_mip_filepath = os.path.join(dst_dir, str(src_mip))
    dst_mip_filepath = os.path.join(dst_dir, str(dst_mip))
    src_field_dset = CloudVolume(src_mip_filepath, mip=0)
    dst_field_dset = CloudVolume(dst_mip_filepath, mip=0)

    for b in range(n_start, n_stop):
        print(f"{b} / {img_dset.shape[0]}")
        src = helpers.to_tensor(img_dset[b, 0])
        tgt = helpers.to_tensor(img_dset[b, 1])
        if prev_field_dset is not None:
            prev_field = helpers.to_tensor(prev_field_dset[b])
        else:
            prev_field = None

        field = model(src_img=src, tgt_img=tgt, src_agg_field=prev_field,
                      train=False, return_state=False)

        # Margin (in dst_mip pixels) to crop after upsampling from src_mip.
        hsz = (src_field_dset.shape[0] * 2**(src_mip - dst_mip)
               - dst_field_dset.shape[0]) // 2
        src_field_dset[:, :, b - n_start, :] = helpers.get_np(
            field.permute(2, 3, 0, 1))

        # Rescale displacements to MIP0 units, upsample the grid from src_mip to
        # dst_mip, then rescale back to dst_mip units.
        field = field * (2**src_mip)
        field = field.up(mips=src_mip - dst_mip)
        field = field / (2**dst_mip)

        field_cropped = field[:, :, hsz:-hsz, hsz:-hsz]
        field_cropped = field_cropped.permute(2, 3, 0, 1)
        dst_field_dset[:, :, b - n_start, :] = helpers.get_np(field_cropped)

    cleanup()
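# generate_shard is written as a per-rank worker. A driver analogous to the mp.spawn
# call in train_pyramid could fan it out across processes; the sketch below is an
# assumption, not the original driver, and it presumes the temporary CloudVolume
# layers under dst_dir/{src_mip} and dst_dir/{dst_mip} already exist.
def generate_fields_ddp(world_size, module_path, checkpoint_name, img_path,
                        prev_field_path, dst_dir, src_mip, dst_mip):
    """Spawn one generate_shard worker per process over the shared image-pair h5."""
    mp.spawn(
        generate_shard,
        args=(world_size, module_path, checkpoint_name, img_path,
              prev_field_path, dst_dir, src_mip, dst_mip),
        nprocs=world_size,
        join=True)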