コード例 #1
0
def tick(spec, unit):
    '''
    Advance the index of one lab unit (experiment, trial, session) stored in spec['meta'].
    Units below the ticked one are reset to -1 so that their next tick lands on 0;
    units above it that are still at -1 are bumped once so they are no longer -1.
    After ticking, the prepath and the per-folder prepaths are written into spec['meta'],
    and the directory for each folder prepath is created.
    @param dict:spec The spec whose 'meta' entry is updated in place
    @param str:unit One of 'experiment', 'trial', 'session'
    @returns dict:spec The same spec object, after mutation
    @raises ValueError if unit is not a recognized lab unit
    @example

    spec_util.tick(spec, 'session')
    session = Session(spec)
    '''
    meta = spec['meta']
    # reject unknown units up front, before anything in meta is touched
    if unit not in ('experiment', 'trial', 'session'):
        raise ValueError(f'Unrecognized lab unit to tick: {unit}')

    def start_if_unset(key):
        # a parent unit still at -1 has never been ticked; bump it once
        if meta[key] == -1:
            meta[key] += 1

    if unit == 'experiment':
        meta['experiment_ts'] = util.get_ts()
        meta['experiment'] += 1
        # lower units restart from -1
        meta['trial'] = -1
        meta['session'] = -1
    elif unit == 'trial':
        start_if_unset('experiment')
        meta['trial'] += 1
        meta['session'] = -1
    else:  # unit == 'session'
        start_if_unset('experiment')
        start_if_unset('trial')
        meta['session'] += 1

    # the indices are final at this point, so the prepath can be derived and stored
    prepath = util.get_prepath(spec, unit)
    meta['prepath'] = prepath
    for folder in ('graph', 'info', 'log', 'model'):
        folder_prepath = util.insert_folder(prepath, folder)
        folder_dir = os.path.dirname(util.smart_path(folder_prepath))
        os.makedirs(folder_dir, exist_ok=True)
        meta[f'{folder}_prepath'] = folder_prepath
    return spec
コード例 #2
0
ファイル: viz.py プロジェクト: kiseliu/NeuralPipeline_DSTC8
def save_image(figure, filepath):
    '''
    Write a figure to an image file on a best-effort basis.
    Does nothing when the PY_ENV environment variable equals 'test'.
    Any exception raised while writing is handed to orca_warn_once instead of propagating.
    @param figure The figure object passed through to pio.write_image
    @param str:filepath Destination path; resolved with util.smart_path before writing
    @returns None
    '''
    # use .get() so an unset PY_ENV counts as "not test" instead of raising KeyError
    if os.environ.get('PY_ENV') == 'test':
        return
    filepath = util.smart_path(filepath)
    try:
        pio.write_image(figure, filepath)
    except Exception as e:
        # deliberate best-effort: image export failure should not abort the caller
        orca_warn_once(e)
コード例 #3
0
def load(net, model_path):
    '''
    Load model weights from a path into a net module.
    The path is resolved with util.smart_path, read with torch.load, and the result is
    applied through net.load_state_dict.
    @param net The module receiving the weights
    @param str:model_path Location of the saved weights
    '''
    # without CUDA, force tensors onto CPU; with CUDA, pass None to keep torch.load's default
    if torch.cuda.is_available():
        map_location = None
    else:
        map_location = 'cpu'
    weights_path = util.smart_path(model_path)
    state_dict = torch.load(weights_path, map_location=map_location)
    net.load_state_dict(state_dict)
コード例 #4
0
def save(net, model_path):
    '''
    Save the weights of a net module to a path.
    Writes net.state_dict() with torch.save to the path resolved by util.smart_path.
    @param net The module whose state_dict is saved
    @param str:model_path Destination for the saved weights
    '''
    state_dict = net.state_dict()
    target_path = util.smart_path(model_path)
    torch.save(state_dict, target_path)