Ejemplo n.º 1
0
def read_keys(_dir, _filter, column_names):
    """Collect selected columns from the ``progress.csv`` of each run directory.

    Every entry returned by ``ls_dir(_dir)`` is treated as one run. The run's
    name is the last ``/``-separated component of the entry. Runs whose name
    does not contain ``_filter``, or that have no ``progress.csv``, are
    skipped with an info log message.

    Args:
        _dir: Directory handed to ``ls_dir`` to enumerate the runs.
        _filter: Substring a run name must contain to be read. ``None`` or
            ``''`` disables filtering.
        column_names: Names of the CSV columns to extract.

    Returns:
        Dict mapping each column name to a ``float32`` numpy array built from
        one list of values per run that was read (in ``ls_dir`` order).
        NOTE(review): presumably all runs have the same number of rows;
        ragged runs cannot be converted into a regular array -- confirm.

    Raises:
        AssertionError: If ``ls_dir`` returns nothing, or a requested column
            is missing from a run's CSV header (including an empty file).
    """
    data = {cn: [] for cn in column_names}

    dirs = ls_dir(_dir)
    print(len(dirs), dirs)
    assert len(dirs) > 0, f'filter {_filter} not found in {_dir}'
    for d in dirs:
        model_name = d.split('/')[-1]

        # Only filter when a non-empty substring was given; an empty string
        # is contained in every name and therefore never excludes a run.
        if _filter and _filter not in model_name:
            logger.info(f'skipping {model_name} | no match with {_filter}')
            continue

        csv_path = '{}/progress.csv'.format(d)
        if not os.path.isfile(csv_path):
            logger.info('skipping {} | file not found'.format(model_name))
            continue

        run_data = {cn: [] for cn in column_names}

        # Context manager guarantees the handle is closed even when an
        # assertion below fails; newline='' is what the csv module expects.
        with open(csv_path, newline='') as csv_file:
            reader = csv.DictReader(csv_file)
            # fieldnames is None for an empty file; normalise so the check
            # below fails with the intended assertion instead of a TypeError.
            fieldnames = reader.fieldnames or []
            for cn in column_names:
                assert cn in fieldnames, f'{cn} not in {fieldnames}'

            for row in reader:
                for cn in column_names:
                    run_data[cn].append(row[cn])

        for cn in column_names:
            data[cn].append(run_data[cn])

    # Values were read as strings; numpy parses them while converting.
    for cn in column_names:
        data[cn] = np.asarray(data[cn], dtype=np.float32)
    return data
Ejemplo n.º 2
0
from forkan.models import VAE
from forkan.datasets.dsprites import load_dsprites

from macode.vae.plot_helper import bars, plot_losses

logger = logging.getLogger(__name__)

network = 'dsprites'
# Substring a model directory name must contain to be processed; empty means
# "process everything". NOTE: this name shadows the builtin filter(); it is
# kept unchanged because code outside this view may read it.
filter = ''
plt_shape = [1, 10]

# whether to plot sigma-bars, kl plots and losses
modes = [True, False]

models_dir = '{}vae-{}/'.format(model_path, network)
dirs = ls_dir(models_dir)

for d in dirs:
    # The last path component is the model name; the part before its first
    # '-' is taken as the dataset name.
    model_name = d.split('/')[-1]
    ds_name = model_name.split('-')[0]

    # Only filter when a non-empty substring is configured; an empty string
    # is contained in every name and therefore never excludes a model.
    if filter and filter not in model_name:
        logger.info('skipping {}'.format(model_name))
        continue

    # sigma bars
    if modes[0]:
        (data, _) = load_dsprites('translation', repetitions=10)

        v = VAE(load_from=model_name)