def read_keys(_dir, _filter, column_names):
    """Collect selected columns from every run's ``progress.csv`` under ``_dir``.

    Each entry returned by ``ls_dir(_dir)`` is treated as one run directory.
    A run is skipped, with an info-level log message, when ``_filter`` is a
    non-empty string that does not occur in the run's directory name, or when
    the run has no ``progress.csv``.

    Args:
        _dir: directory whose entries (as listed by ``ls_dir``) are run dirs.
        _filter: substring a run's directory name must contain; ``None`` or
            ``''`` disables filtering.
        column_names: names of the CSV columns to extract.

    Returns:
        dict mapping each name in ``column_names`` to an ``np.float32`` array
        with one row per included run, built via ``np.asarray``.
        NOTE(review): that conversion assumes every run has the same number of
        rows and that every cell parses as a float; ragged runs or empty cells
        will make it fail. Confirm this holds for the logged data.

    Raises:
        AssertionError: if ``ls_dir(_dir)`` returns no entries, or if a
            ``progress.csv`` lacks one of ``column_names``.
    """
    data = {cn: [] for cn in column_names}

    dirs = ls_dir(_dir)
    # Previously a bare print(); route it through logging so callers control
    # verbosity.
    logger.debug('found %d entries in %s: %s', len(dirs), _dir, dirs)
    # This check runs before filtering, so it detects an empty directory,
    # not a filter mismatch (that case is reported after the loop).
    assert len(dirs) > 0, f'no run directories found in {_dir}'

    n_runs = 0
    for d in dirs:
        model_name = d.split('/')[-1]

        # The old condition `_filter is not None and not ''` contained an
        # always-true `not ''`. Truthiness expresses the intent: None or ''
        # disables filtering.
        if _filter and _filter not in model_name:
            logger.info(f'skipping {model_name} | no match with {_filter}')
            continue

        csv_path = f'{d}/progress.csv'
        if not os.path.isfile(csv_path):
            logger.info(f'skipping {model_name} | file not found')
            continue

        # The context manager closes the file; the original leaked the handle.
        # newline='' is what the csv module expects for opened files.
        with open(csv_path, newline='') as f:
            reader = csv.DictReader(f)
            # fieldnames is None for an empty file; treat that as "no
            # columns" so the assertion fires instead of a TypeError.
            fieldnames = reader.fieldnames or []
            for cn in column_names:
                assert cn in fieldnames, f'{cn} not in {fieldnames}'
            rows = list(reader)

        for cn in column_names:
            data[cn].append([row[cn] for row in rows])
        n_runs += 1

    if n_runs == 0:
        logger.warning(f'no run in {_dir} matched filter {_filter!r}')

    return {cn: np.asarray(values, dtype=np.float32)
            for cn, values in data.items()}
from forkan.models import VAE from forkan.datasets.dsprites import load_dsprites from macode.vae.plot_helper import bars, plot_losses logger = logging.getLogger(__name__) network = 'dsprites' filter = '' plt_shape = [1, 10] # whether to plot sigma-bars, kl plots and losses modes = [True, False] models_dir = '{}vae-{}/'.format(model_path, network) dirs = ls_dir(models_dir) for d in dirs: ds_name = d.split('/')[-1].split('-')[0] model_name = d.split('/')[-1] if filter is not None and not '': if filter not in model_name: logger.info('skipping {}'.format(model_name)) continue # sigma bars if modes[0]: (data, _) = load_dsprites('translation', repetitions=10) v = VAE(load_from=model_name)