Example No. 1
def test_dump_nosplit_merge(index_type, batch_size, use_glob, tmp_path):
    """
    Dump content batch-wise for the whole dataset, then merge the dumped files.
    Check that the merged and input file contents are equal.
    The list of files to merge is formed either by listing all files or by using a glob pattern.
    """

    index = index_type(path=PATH, name='raw')

    ppl = (
        SeismicDataset(index).p
        .load(components='raw', fmt='segy', tslice=slice(2000))
        .dump(src='raw',
              path=L(lambda x: os.path.join(tmp_path, str(x) + '.sgy'))(I()),
              fmt='sgy',
              split=False)
        )

    ppl.run(batch_size=batch_size,
            n_epochs=1,
            drop_last=False,
            shuffle=False,
            bar=False)

    if use_glob:
        files_list = os.path.join(tmp_path, "*.sgy")
    else:
        files_list = [
            os.path.join(tmp_path, f) for f in os.listdir(tmp_path)
            if f.endswith('.sgy')
        ]

    merged_path = os.path.join(tmp_path, "out.sgy")
    merge_segy_files(path=files_list, output_path=merged_path, bar=False)

    compare_files(PATH, merged_path, compare_all=True)
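
The test's arguments (index_type, batch_size, use_glob) suggest it is driven by pytest parametrization. A minimal sketch of such a parametrization, assuming FieldIndex and TraceIndex are importable from the library under test; the concrete values are placeholders, not taken from the original test module:

import pytest

# Hypothetical parametrization; the actual values in the suite may differ.
@pytest.mark.parametrize('index_type', [FieldIndex, TraceIndex])
@pytest.mark.parametrize('batch_size', [1, 4])
@pytest.mark.parametrize('use_glob', [True, False])
def test_dump_nosplit_merge(index_type, batch_size, use_glob, tmp_path):
    ...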
Example No. 2
    def test_crop_float_coords_ok(self, index_type, batch_size,
                                  single_iteration):
        """ Make crops using float coords """

        index = index_type(path=PATH, name='raw')

        ppl = (
            SeismicDataset(index).p
            .init_variable('raw', default=[])
            .init_variable('crops', default=[])
            .load(components='raw', fmt='segy', tslice=slice(2000))
            .update(V('raw', 'a'), B('raw'))
            .crop(src='raw', dst='crops', coords=[(0, 0), (0, 0.5)],
                  shape=(1, 1), pad_zeros=False)
            .update(V('crops', 'a'), B('crops'))
            )

        if single_iteration:
            ppl.run(batch_size, n_iters=1, shuffle=True)
        else:
            ppl.run(batch_size, n_epochs=1, shuffle=False)

        raw_batches_list = ppl.get_variable('raw')
        crops_batches_list = ppl.get_variable('crops')

        for raw_list, crops_list in zip(raw_batches_list, crops_batches_list):
            for raw, crops in zip(raw_list, crops_list):
                assert np.allclose(raw[0, 0], crops[0])
                assert np.allclose(raw[0, 999], crops[1])
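
The final assertions depend on how a float coordinate maps to a sample index: with tslice=slice(2000) each trace holds 2000 samples, and comparing the second crop against raw[0, 999] is consistent with scaling the float coordinate by the last valid index. A worked sketch of that arithmetic (the mapping is an assumption inferred from the test, not taken from the library source):

n_samples = 2000                           # traces loaded with tslice=slice(2000)
float_coord = 0.5
sample_index = int(float_coord * (n_samples - 1))   # int(999.5) == 999
assert sample_index == 999                 # matches raw[0, 999] in the test above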
Example No. 3
    def test_crop_assemble(self, index_type, batch_size, crop_shape, single_iteration, assemble_fill_value):
        """
        Make crops that cover the whole array using a regular grid,
        then assemble those crops and
        check that the result equals the original array.

        Checks int coords, using different coords for each item in a batch (the P named expression), and assembling crops.
        """
        index = index_type(path=PATH, name='raw')

        ppl = (
            SeismicDataset(index).p
            .init_variable('raw', default=[])
            .init_variable('assemble', default=[])
            .load(components='raw', fmt='segy', tslice=slice(2000))
            .update(V('raw', 'a'), B('raw'))
            .make_grid_for_crops(src='raw', dst='coords', shape=crop_shape, drop_last=False)
            .crop(src='raw', dst='crops', coords=P(B('coords')), shape=crop_shape, pad_zeros=True)
            .assemble_crops(src='crops', dst='assemble', fill_value=assemble_fill_value)
            .update(V('assemble', 'a'), B('assemble'))
            )

        if single_iteration:
            ppl.run(batch_size, n_iters=1, shuffle=True)
        else:
            ppl.run(batch_size, n_epochs=1, shuffle=False)

        raw_batches_list = ppl.get_variable('raw')
        assemble_batches_list = ppl.get_variable('assemble')

        for raw_list, assemble_list in zip(raw_batches_list, assemble_batches_list):
            for raw, assemble in zip(raw_list, assemble_list):
                assert np.allclose(raw, assemble)
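
make_grid_for_crops supplies the regular grid of crop origins that crop consumes through the P named expression. A minimal sketch of such a grid with drop_last=False semantics (an illustration only, not the library's actual implementation):

def regular_grid(array_shape, crop_shape):
    """Origins step by crop_shape; the trailing crop is kept even when it
    runs past the edge, which is why the pipeline sets pad_zeros=True."""
    rows = range(0, array_shape[0], crop_shape[0])
    cols = range(0, array_shape[1], crop_shape[1])
    return [(i, j) for i in rows for j in cols]

# e.g. a (5, 7) array covered by (2, 3) crops:
regular_grid((5, 7), (2, 3))
# -> [(0, 0), (0, 3), (0, 6), (2, 0), (2, 3), (2, 6), (4, 0), (4, 3), (4, 6)]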
Example No. 4
    def test_wrong_action_order(self):
        """ assembling and plotting crops should fail if no cropping was done"""

        index = FieldIndex(path=PATH, name='raw')

        ppl = (
            SeismicDataset(index).p
            .load(components='raw', fmt='segy')
            .assemble_crops(src='raw', dst='assemble')
            )

        with pytest.raises(Exception):
            ppl.run(5, n_iters=1)

        ppl = SeismicDataset(index).p.load(components='raw', fmt='segy')

        batch = ppl.next_batch(5)

        with pytest.raises(Exception):
            batch.crops_plot('raw', index.indices[0])
Example No. 5
def predict(path_raw, path_model, num_zero=100, save_to='dump.csv',
            batch_size=1000, trace_len=1000, device='cpu', shift=0):
    """Make predictions and dump results using loaded model and path to data.

    Parameters
    ----------
    path_raw: str
        Path to SEGY file.
    path_model: str
        Path to the file with trained model.
    num_zero: int, default: 100
        Required number of consecutive zero values in a trace for the trace to be dropped.
    save_to: str, default: 'dump.csv'
        Path to the CSV file where the results will be stored.
    batch_size: int, default: 1000
        The batch size for inference.
    trace_len: int, default: 1000
        The number of first samples of each trace to load into the pipeline.
    device: str or torch.device, default: 'cpu'
        The device used for inference. Can be 'gpu' if a GPU is available.
    shift: float, default: 0
        Shift the picking times of each trace by the given phase shift, measured in radians.

    """
    data = SeismicDataset(TraceIndex(name='raw', path=path_raw))

    config_predict = {
        'build': False,
        'load/path': path_model,
        'device': device
    }

    try:
        os.remove(save_to)
    except OSError:
        pass

    test_tmpl = (data.p
                 .init_model('dynamic', UNet, 'my_model', config=config_predict)
                 .load(components='raw', fmt='segy', tslice=slice(0, trace_len))
                 .drop_zero_traces(num_zero=num_zero, src='raw')
                 .standardize(src='raw', dst='raw')
                 .add_components(components='predictions')
                 .call(lambda batch: np.stack(batch.raw), save_to=B('raw'))
                 .predict_model('my_model', B('raw'), fetches='predictions',
                                save_to=B('predictions', mode='a'))
                 .mask_to_pick(src='predictions', dst='predictions', labels=False)
                )
    if shift:
        test_tmpl += Pipeline().shift_pick_phase(src='predictions', dst='predictions', src_traces='raw', shift=shift)

    test_pipeline = test_tmpl + Pipeline().dump(src='predictions', fmt='picks', path=save_to, src_traces='raw')
    test_pipeline.run(batch_size, n_epochs=1, drop_last=False, shuffle=False, bar=True)
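
A usage sketch for predict; all file paths below are placeholders, not taken from the source:

# Hypothetical invocation: run picking inference on one SEG-Y file
# and dump the predicted times to a CSV file.
predict(path_raw='data/field.sgy',
        path_model='models/unet_picking.pt',
        save_to='picks.csv',
        batch_size=500,
        device='cpu')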
Example No. 6
def compare_files(path_1, path_2, compare_all):
    """ Checks that all traces from SEG-Y file `path_2` are found in `path_1`,
    or if `compare_all=True` that both files contain exacly same traces

    Parameters
    ----------
    path_1 : str
        path to the larger file
    path_2 : str
        path to the smaller file
    compare_all : bool
        whether to chek exact equality
    """

    index1 = TraceIndex(name='f1', path=path_1)
    index2 = TraceIndex(name='f2', path=path_2)
    index = index1.merge(index2)

    # all traces from the smaller file are in the larger one
    assert len(index2) == len(index)

    if compare_all:
        # both files have same traces
        assert len(index1) == len(index2)

    index = FieldIndex(index)

    ppl = (
        SeismicDataset(index).p
        .load(components='f1', fmt='segy', tslice=slice(2000))
        .load(components='f2', fmt='segy', tslice=slice(2000))
        .sort_traces(src=('f1', 'f2'), sort_by='TraceNumber')
        .add_components('res')
        .init_variable('res', default=[])
        .apply_parallel(lambda arrs: np.allclose(*arrs),
                        src=('f1', 'f2'), dst='res')
        .update(V('res', 'a'), B('res'))
        )

    ppl.run(batch_size=1,
            n_epochs=1,
            drop_last=False,
            shuffle=False,
            bar=False)

    res = np.stack(ppl.get_variable('res'))

    assert np.all(res)
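
The two length checks rest on merge behaving like an inner join on trace identifiers, so the merged index keeps only traces present in both files. A set-based sketch of that reasoning (the inner-join semantics is an assumption about TraceIndex.merge):

# Toy trace-ID sets standing in for the two files.
traces_1 = {'t1', 't2', 't3'}            # larger file
traces_2 = {'t1', 't2'}                  # smaller file
merged = traces_1 & traces_2             # inner join on trace IDs

assert len(merged) == len(traces_2)      # every trace of file 2 is in file 1
# with compare_all=True the test also requires len(traces_1) == len(traces_2)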
Example No. 7
def test_dump_split_merge(index_type, tmp_path):
    """
    Dump content item-wise for one iteration, then merge the dumped files.
    Check that all traces from the merged file are present in the input file.
    """

    index = index_type(path=PATH, name='raw')

    ppl = (
        SeismicDataset(index).p
        .load(components='raw', fmt='segy', tslice=slice(2000))
        .dump(src='raw', path=tmp_path, fmt='sgy', split=True)
        )

    ppl.next_batch(4)

    merged_path = os.path.join(tmp_path, "out.sgy")
    merge_segy_files(path=os.path.join(tmp_path, "*.sgy"), output_path=merged_path, bar=False)

    compare_files(PATH, merged_path, compare_all=False)