def train_pyramid(pyramid_path, dataset_path, train_stages, checkpoint_name,
        generate_field_stages, train_params=None, aug_params=None):
    pyramid_path = os.path.expanduser(pyramid_path)
    module_dict = get_pyramid_modules(pyramid_path)
    print (f"Loading dataset {dataset_path}...")
    dataset = MultimipDataset(dataset_path, aug_params, field_tag=checkpoint_name)

    prev_mip = None
    for stage in sorted(module_dict.keys()):
        module_mip = module_dict[stage]["mip_in"]
        module_path = module_dict[stage]["path"]
        model = None

        if train_stages is None or stage in train_stages:
            print (f"Training module {stage}...")
            model = modelhouse.load_model_simple(module_path,
                                                 finetune=False,
                                                 pass_field=True,
                                                 checkpoint_name=checkpoint_name)
            if train_params is None or str(module_mip) not in train_params:
                raise Exception(f"Training parameters not specified for mip {module_mip}")

            mip_train_params = train_params[str(module_mip)]
            train_module(model, train_params=mip_train_params,
                    train_dset=dataset.get_train_dset(mip=module_mip, stage=stage),
                    val_dset=dataset.get_val_dset(mip=module_mip, stage=stage),
                    checkpoint_path=os.path.join(module_path, "model"))

            print (f"Done training module {stage}!")

        if generate_field_stages is None or stage in generate_field_stages:
            print (f"Generating fields with module {stage}...")
            model = modelhouse.load_model_simple(module_path,
                    finetune=True,
                    pass_field=True,
                    finetune_iter=300//(2**stage),
                    checkpoint_name=checkpoint_name)
            dataset.generate_fields(model, mip=module_mip, stage=stage)
            print (f"Done generating fields with module {stage}...")
def train_pyramid(world_size,
                  pyramid_path,
                  dataset_path,
                  train_stages,
                  checkpoint_name,
                  generate_field_stages,
                  train_params=None,
                  aug_params=None):
    pyramid_path = os.path.expanduser(pyramid_path)
    module_dict = get_pyramid_modules(pyramid_path)
    print(f"Loading dataset {dataset_path}...")

    prev_mip = None
    for stage in sorted(module_dict.keys()):
        module_mip = module_dict[stage]["mip_in"]
        module_path = module_dict[stage]["path"]
        model = None

        if train_stages is None or stage in train_stages:
            print(f"Training module {stage}...")
            mip_train_params = train_params[str(module_mip)]
            mp.spawn(
                train_module,
                args=(
                    world_size,
                    module_path,
                    mip_train_params,  # train_params
                    dataset_path,  # dataset_path
                    module_mip,  # module_mip
                    stage,  # stage
                    checkpoint_name,  # checkpoint_name
                    None),  # aug_params
                nprocs=world_size,
                join=True)
            #  train_module(model, train_params=mip_train_params,
            #          train_dset=dataset.get_train_dset(mip=module_mip, stage=stage),
            #          val_dset=dataset.get_val_dset(mip=module_mip, stage=stage),
            #          checkpoint_path=os.path.join(module_path, "model"))

            print(f"Done training module {stage}!")

        if generate_field_stages is None or stage in generate_field_stages:
            print(f"Generating fields with module {stage}...")
            model = modelhouse.load_model_simple(
                module_path,
                finetune=True,
                pass_field=True,
                finetune_iter=300 // (2**stage),
                checkpoint_name=checkpoint_name)
            dataset.generate_fields(model, mip=module_mip, stage=stage)
            print(f"Done generating fields with module {stage}...")
Example #3
def train_module(rank,
                 world_size,
                 module_path,
                 train_params,
                 dataset_path,
                 module_mip,
                 stage,
                 checkpoint_name,
                 aug_params=None):
    """Train object with its own dataset (specific MIP)

    Args:
        module_path (str): path to modelhouse directory
        train_params (dict)
        train_dset (AlignmentDataLoader)
        val_dset (AlignmentDataLoader)
        checkpoint_name (str)
        rank (int): process, dictates the GPU used in multi-gpu training
        world_size (int): total no. of processes
    """
    assert aug_params is None
    print(f"Running DDP on rank {rank}.")
    setup(rank, world_size)
    torch.cuda.set_device(rank)
    model = modelhouse.load_model_simple(module_path,
                                         finetune=False,
                                         pass_field=True,
                                         checkpoint_name=checkpoint_name)
    checkpoint_path = os.path.join(module_path, "model")
    model.aligner.net = model.aligner.net.to(rank)
    model = DDP(model, device_ids=[rank])

    dataset = MultimipDataset(dataset_path,
                              aug_params,
                              field_tag=checkpoint_name)
    train_dset = dataset.get_train_dset(mip=module_mip, stage=stage)
    val_dset = dataset.get_val_dset(mip=module_mip, stage=stage)
    val_data_loader = torch.utils.data.DataLoader(val_dset,
                                                  batch_size=1,
                                                  shuffle=True,
                                                  num_workers=0,
                                                  pin_memory=False)

    trainable = list(model.parameters())

    for epoch_params in train_params:
        smoothness = epoch_params["smoothness"]
        if "print_every" in epoch_params:
            print_every = epoch_params["print_every"]
        else:
            print_every = None

        if "num_sample" in epoch_params:
            num_samples = epoch_params["num_samples"]
        else:
            num_samples = 10000000

        lr = epoch_params["lr"]
        num_epochs = epoch_params["num_epochs"]
        mse_keys_to_apply = epoch_params["mse_keys_to_apply"]
        sm_keys_to_apply = epoch_params["sm_keys_to_apply"]
        loss_spec = epoch_params["loss_spec"]
        loss_type = loss_spec["type"]

        simple_loss = loss.unsupervised_loss(
            smoothness,
            use_defect_mask=True,
            sm_keys_to_apply=sm_keys_to_apply,
            mse_keys_to_apply=mse_keys_to_apply)
        if loss_type == "plain":
            training_loss = simple_loss
        elif loss_type == "metric":
            training_loss = loss.multilevel_metric_loss(loss_fn=simple_loss,
                                                        mip_in=0,
                                                        **loss_spec['params'])
        else:
            raise Exception(f"Unknown loss type: {loss_type}")

        augmentor = None
        if "augmentations" in epoch_params:
            augmentor = augmentations.Augmentor(epoch_params["augmentations"])

        train_dset.set_size_limit(num_samples)
        # Divide dataset across processes
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dset, num_replicas=world_size, rank=rank)
        train_data_loader = torch.utils.data.DataLoader(train_dset,
                                                        batch_size=1,
                                                        shuffle=False,
                                                        num_workers=0,
                                                        pin_memory=False,
                                                        sampler=train_sampler)

        optimizer = torch.optim.Adam(trainable, lr=lr, weight_decay=0)
        aligner_train_loop(rank,
                           model,
                           mip_in=0,
                           train_loader=train_data_loader,
                           val_loader=val_data_loader,
                           optimizer=optimizer,
                           num_epochs=num_epochs,
                           loss_fn=training_loss,
                           print_every=print_every,
                           checkpoint_folder=checkpoint_path,
                           augmentor=augmentor)
    cleanup()
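
# The setup()/cleanup() helpers called above are not shown in this listing. A minimal
# sketch of the conventional PyTorch DDP pattern they presumably follow (the address
# and port values are assumptions):
import os

import torch.distributed as dist


def setup(rank, world_size):
    # Every spawned process joins the same process group before wrapping with DDP.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


def cleanup():
    dist.destroy_process_group()
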
Example #4
def generate_shard(rank, world_size, module_path, checkpoint_name, img_path,
                   prev_field_path, dst_dir, src_mip, dst_mip):
    """Generate field for subset of image pairs associated with rank

    Args:
        rank (int): process order
        world_size (int): total no. of processes
        module_path (str): path to modelhouse directory
        checkpoint_name (str): checkpoint for weights
        img_path (str): path to image pairs h5
        prev_field_path (str): path to previous fields h5
        dst_dir (str): path where temporary field h5s will be stored
        src_mip (int)
        dst_mip (int)
    """
    print(f"Running DDP on rank {rank}.")
    setup(rank, world_size)
    torch.cuda.set_device(rank)
    model = modelhouse.load_model_simple(module_path,
                                         finetune=True,
                                         finetune_lr=3e-1,
                                         finetune_sm=300e0,
                                         finetune_iter=200,
                                         pass_field=True,
                                         checkpoint_name=checkpoint_name)
    checkpoint_path = os.path.join(module_path, "model")
    model.aligner.net = model.aligner.net.to(rank)
    model = DDP(model, device_ids=[rank])

    img_dset = h5py.File(img_path, 'r')['main']
    prev_field_dset = h5py.File(prev_field_path, 'r')['main']
    assert img_dset.shape[0] >= world_size
    # Split the image pairs into contiguous shards, one per rank.
    n = img_dset.shape[0] // world_size
    n_start = rank * n
    n_stop = min(n_start + n, img_dset.shape[0])
    if rank + 1 == world_size:
        n_stop = img_dset.shape[0]
    src_mip_filepath = os.path.join(dst_dir, '{}'.format(src_mip))
    dst_mip_filepath = os.path.join(dst_dir, '{}'.format(dst_mip))
    src_field_dset = CloudVolume(src_mip_filepath, mip=0)
    dst_field_dset = CloudVolume(dst_mip_filepath, mip=0)
    # src_field_dset = src_field.create_dataset("main",
    #                                        shape=field_shape,
    #                                        dtype=np.float32,
    #                                        chunks=chunks,
    #                                        compression='lzf',
    #                                        scaleoffset=2)
    # dst_field_dset = dst_field.create_dataset("main",
    #                                        shape=field_shape,
    #                                        dtype=np.float32,
    #                                        chunks=chunks,
    #                                        compression='lzf',
    #                                        scaleoffset=2)

    for b in range(n_start, n_stop):
        print('{} / {}'.format(b, img_dset.shape[0]))
        src = helpers.to_tensor(img_dset[b, 0])
        tgt = helpers.to_tensor(img_dset[b, 1])
        if prev_field_dset is not None:
            prev_field = helpers.to_tensor(prev_field_dset[b])
        else:
            prev_field = None

        field = model(src_img=src,
                      tgt_img=tgt,
                      src_agg_field=prev_field,
                      train=False,
                      return_state=False)
        field_shape = field.shape
        hsz = (src_field_dset.shape[0] * 2**(src_mip - dst_mip) -
               dst_field_dset.shape[0]) // 2
        src_field_dset[:, :, b - n_start, :] = helpers.get_np(
            field.permute(2, 3, 0, 1))
        # upsample
        field = field * (2**src_mip)
        field = field.up(mips=src_mip - dst_mip)
        field = field / (2**dst_mip)
        field_cropped = field[:, :, hsz:-hsz, hsz:-hsz]
        field_cropped = field_cropped.permute(2, 3, 0, 1)
        dst_field_dset[:, :, b - n_start, :] = helpers.get_np(field_cropped)

    cleanup()
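
# Hypothetical launcher (not from the original source) mirroring the mp.spawn call
# used for train_module above; generate_shard receives its rank implicitly as the
# first argument. All paths and MIP levels are placeholders.
import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(generate_shard,
             args=(world_size,
                   "~/pyramids/example/stage_0",  # module_path (hypothetical)
                   "run0",                        # checkpoint_name
                   "~/datasets/imgs.h5",          # img_path (hypothetical)
                   "~/datasets/prev_fields.h5",   # prev_field_path (hypothetical)
                   "/tmp/fields",                 # dst_dir for temporary field h5s
                   4,                             # src_mip (placeholder)
                   2),                            # dst_mip (placeholder)
             nprocs=world_size,
             join=True)
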