コード例 #1
0
    # Use the training volume and its labels as inputs for the crop sequence.
    data = voldata_train
    labels = vollabels_train

    # Shape of the data volume; passed to the crop and position generators below.
    volume_shape = data.shape

    crop_seq_train = VolumeCropSequence(

        data_volume=data,
        labels_volume=labels,

        batch_size=batch_size,

        meta_crop_generator=MetaCrop3DGenerator(

            volume_shape=volume_shape,
            crop_shape=crop_shape,
            # True only when the model type is T2SModelType.d2half.
            is_2halfd=(model_type == T2SModelType.d2half),

            # The walrus binds the position generator to `grid_pos_gen` so the
            # same instance can be handed to `gt_field` and `et_field` below.
            x0y0z0_generator=(grid_pos_gen := t2s_volseq.UniformGridPosition.build_from_volume_crop_shapes(
                volume_shape=volume_shape,
                crop_shape=crop_shape,
                random_state=RandomState(args.random_state_seed),
            )),

            # NOTE(review): this builds a second RandomState from the same seed
            # as the position generator's -- confirm identical streams are intended.
            gt_field=t2s_volseq.GTUniformEverywhere(
                gt_type=gt_type,
                grid_position_generator=grid_pos_gen,
                random_state=RandomState(args.random_state_seed),
            ),

            et_field=t2s_volseq.ET3DConstantEverywhere.build_no_displacement(grid_position_generator=grid_pos_gen),
コード例 #2
0
    # Crop position generator built from the volume and crop shapes; its
    # RandomState is seeded with args.random_state_seed.
    position_random_state = RandomState(args.random_state_seed)
    grid_pos_gen = t2s_volseq.UniformGridPosition.build_from_volume_crop_shapes(
        volume_shape=volume_shape,
        crop_shape=crop_shape,
        random_state=position_random_state,
    )

    # data augmentation: the generator comes from the `build_no_augmentation`
    # constructor and shares the position generator created above.
    is_2halfd = model_type == T2SModelType.d2half
    train_meta_crop_generator = MetaCrop3DGenerator.build_no_augmentation(
        grid_pos_gen=grid_pos_gen,
        volume_shape=volume_shape,
        crop_shape=crop_shape,
        common_random_state_seed=args.random_state_seed,
        gt_type=gt_type,
        is_2halfd=is_2halfd,
    )

    crop_seq_train = VolumeCropSequence(
        data_volume=data,
        labels_volume=labels,
        batch_size=batch_size,
        meta_crop_generator=train_meta_crop_generator,
        # The cropper only yields random crops, so the number of crops per
        # epoch is an arbitrary choice rather than something the data dictates.
        epoch_size=10,
        **vol_crop_seq_common_kwargs,
        meta_crops_hist_path=t2s_model.train_metacrop_history_path,
    )

    # ## Val

    # Rebind the working volume and labels to the `*_val` arrays.
    data, labels = voldata_val, vollabels_val