예제 #1
0
def test_type(eng, tmpdir):
    """An unknown ``return_type`` makes ``load_rdd_from_pickle`` raise ValueError.

    The pickle is written inside the per-test ``tmpdir`` (not its shared parent
    directory) so other tests or parametrizations cannot collide on the path.
    """
    data = td.series.fromrandom(engine=eng)
    path = os.path.join(str(tmpdir), 'test0')
    save_rdd_as_pickle(data, path)
    with pytest.raises(ValueError) as excinfo:
        load_rdd_from_pickle(eng, path, return_type='error')
    assert 'return_type not' in str(excinfo.value)
예제 #2
0
def test_overwrite_true(eng, tmpdir):
    """Saving twice to one path succeeds with ``overwrite=True`` and the data round-trips.

    The pickle is written inside the per-test ``tmpdir`` (not its shared parent
    directory), so the first, non-overwriting save cannot trip over a path left
    behind by another test or parametrization.
    """
    data = td.images.fromrandom(engine=eng)
    path = os.path.join(str(tmpdir), 'test4')
    save_rdd_as_pickle(data, path)
    # the path now exists; the second save must be accepted because of overwrite=True
    save_rdd_as_pickle(data, path, overwrite=True)
    reloaded = load_rdd_from_pickle(eng, path)
    data_local = data.toarray()
    reloaded_local = reloaded.toarray()
    assert np.allclose(data_local, reloaded_local)
    assert data_local.dtype == reloaded_local.dtype
    # the reloaded object is expected to be spread over the engine's default parallelism
    assert reloaded.npartitions() == eng.defaultParallelism
예제 #3
0
def test_rdd(eng, tmpdir):
    """A plain RDD round-trips and ``return_type='rdd'`` yields an ``RDD``.

    The pickle is written inside the per-test ``tmpdir`` (not its shared parent
    directory) so other tests or parametrizations cannot collide on the path.
    """
    data = eng.range(100)
    path = os.path.join(str(tmpdir), 'test3')
    save_rdd_as_pickle(data, path)
    reloaded = load_rdd_from_pickle(eng, path, return_type='rdd')
    # element order is not compared, only contents, hence the sorting
    data_local = np.array(sorted(data.collect()))
    reloaded_local = np.array(sorted(reloaded.collect()))
    assert isinstance(reloaded, RDD)
    assert np.allclose(data_local, reloaded_local)
    assert data_local.dtype == reloaded_local.dtype
    assert reloaded.getNumPartitions() == eng.defaultParallelism
예제 #4
0
def test_images(eng, tmpdir):
    """An ``Images`` object round-trips and is reloaded as ``td.images.Images`` by default.

    The pickle is written inside the per-test ``tmpdir`` (not its shared parent
    directory) so other tests or parametrizations cannot collide on the path.
    """
    data = td.images.fromrandom(engine=eng)
    path = os.path.join(str(tmpdir), 'test2')
    save_rdd_as_pickle(data, path)
    reloaded = load_rdd_from_pickle(eng, path)
    data_local = data.toarray()
    reloaded_local = reloaded.toarray()
    assert isinstance(reloaded, td.images.Images)
    assert np.allclose(data_local, reloaded_local)
    assert data_local.dtype == reloaded_local.dtype
    assert reloaded.npartitions() == eng.defaultParallelism
예제 #5
0
def reloadClean(sc,
                session,
                name=None,
                returnRegDict=False,
                returnTForm=False,
                repartition=True,
                full_path=None):
    """Reload a session and its Clean data from nrs.

    :param sc: Spark context
    :param session: SpineSession object
    :param name: name of session to load; if given, ``session.load(name)`` replaces ``session``
    :param returnRegDict: if True, return ``session.regDict`` with its ``'data'`` entry set to the Clean data
    :param returnTForm: if True, return ``session.embedDict['TForm1']`` (None, with an error logged, if absent)
    :param repartition: how to repartition the Clean data. ``True`` = ``sc.defaultParallelism * 2``;
        any int > 0 = that many partitions; ``False`` / ``None`` / 0 = keep the loaded partitioning
    :param full_path: full path of the pickled data if it is not at the cluster default location
    :return: ``clean`` alone when ``name`` is None and both ``returnRegDict`` and ``returnTForm`` are False;
        otherwise the tuple ``(clean, session, regDict, TForm1)`` (unrequested entries are None)
    """
    if name is not None:
        session = session.load(name)
    if full_path is None:
        full_path = '/nrs/svoboda/moharb/New/' + session.animalID + '_' + session.date + session.run + 'CleanBinaryPickle'
    clean = load_rdd_from_pickle(sc, full_path)

    # ``is True`` (not ``> 1``) so that an explicit repartition=1 means one partition,
    # since True == 1 in Python. None / False / non-positive values skip repartitioning.
    if repartition is True:
        nPartitions = sc.defaultParallelism * 2
    elif repartition and repartition > 0:
        nPartitions = repartition
    else:
        nPartitions = None
    if nPartitions is not None:
        clean = clean.repartition(nPartitions)
        logger.info('Repartitioned Clean to %d partitions', nPartitions)
    clean.cache()
    # count() is an action; presumably called to populate the cache eagerly
    clean.count()

    if returnRegDict:
        regDict = session.regDict
        # no copy is made here, so this presumably also updates session.regDict
        regDict['data'] = clean
    else:
        regDict = None

    TForm1 = None
    if returnTForm:
        if hasattr(session, 'embedDict') and 'TForm1' in session.embedDict:
            TForm1 = session.embedDict['TForm1']
        else:
            logger.error('TForm Not Found')

    # Pick the return shape from what the caller asked for, not from the looked-up
    # values: otherwise a missing TForm1 would silently turn the 4-tuple into a bare
    # ``clean`` and break callers that unpack the result.
    if name is None and not returnRegDict and not returnTForm:
        return clean
    return clean, session, regDict, TForm1
예제 #6
0
def loadRegData(sc,
                session,
                name='step5',
                base='/nrs/svoboda/moharb/New/',
                check=True):
    """Load regData and, if ``name`` is not None, the session as well.

    :param sc: spark context
    :param session: session object
    :param name: session name passed to ``session.load``. If None, ``session`` is used as given
    :param base: base directory of sessions; joined with ``os.path.join``, so a trailing
        separator is optional
    :param check: if True, log an error when the number of loaded records differs from
        ``session.regDict['globalTC'].shape[0]``
    :return: session, regData
    """
    import os

    if name is not None:
        session = session.load(name)
    fileName = session.animalID + '_' + session.date + session.run + 'RegBinaryPickle'
    # os.path.join tolerates a ``base`` without a trailing separator; plain
    # concatenation would silently glue the directory and file names together.
    path = os.path.join(base, fileName)
    regData = load_rdd_from_pickle(sc, path)
    regData.cache()
    # count() is always run (even without check); presumably to populate the cache eagerly
    count = regData.count()
    if check:
        expected = session.regDict['globalTC'].shape[0]
        if count != expected:
            logger.error('count: %d different from globalTC length %d', count, expected)
    logger.info('Loaded regData from: %s', path)
    return session, regData