gpus = tf.config.list_physical_devices('GPU')
if len(gpus) > 0:  # memory growth can only be set on a visible GPU
  tf.config.experimental.set_memory_growth(gpus[0], True)
tf.debugging.set_log_device_placement(False)
tf.autograph.set_verbosity(0)

tf.random.set_seed(8)
np.random.seed(8)

args = ArgController(
).add("--override", "Override trained model", False
).parse()

SAVE_PATH = "/tmp/vae_audio"
if os.path.exists(SAVE_PATH) and args.override:
  clean_folder(SAVE_PATH, verbose=True)
if not os.path.exists(SAVE_PATH):
  os.makedirs(SAVE_PATH)
MODEL_PATH = os.path.join(SAVE_PATH, 'model')

# ===========================================================================
# Configs
# ===========================================================================
ZDIM = 32
MAX_LENGTH = 48
BUFFER_SIZE = 100
PARALLEL = tf.data.experimental.AUTOTUNE
# GaussianLayer, GammaLayer, NegativeBinomialLayer
# POSTERIOR = partialclass(bay.layers.GammaLayer,
#                          concentration_activation='softplus1',
#                          rate_activation='softplus1')
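# A minimal sketch of the `partialclass` helper referenced in the commented
# POSTERIOR above (an assumption; the real helper likely lives in odin's
# utilities): it returns a subclass with constructor keyword arguments
# pre-bound, i.e. functools.partial for classes.
import functools

def partialclass(cls, **kwargs):
  class _Partial(cls):
    __init__ = functools.partialmethod(cls.__init__, **kwargs)
  _Partial.__name__ = cls.__name__
  return _Partial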
def run(self, overrides=[], ncpu=None, **configs):
  r""" Run the experiment; multiple values per override trigger a multirun.

  Arguments:
    strict: A Boolean, strict configurations prevent access to unknown keys;
      otherwise, the config returns `None` for them.

  Example:
    exp = SisuaExperimenter(ncpu=1)
    exp.run(
        overrides={
            'model': ['sisua', 'dca', 'vae'],
            'dataset.name': ['cortex', 'pbmc8kly'],
            'train.verbose': 0,
            'train.epochs': 2,
            'train': ['adam'],
        })
  """
  overrides = _overrides(overrides) + _overrides(configs)
  strict = False
  command = ' '.join(sys.argv)
  # parse ncpu
  if ncpu is None:
    ncpu = self.ncpu
  ncpu = int(ncpu)
  for idx, arg in enumerate(list(sys.argv)):
    if 'ncpu' in arg:
      if '=' in arg:
        ncpu = int(arg.split('=')[-1])
        sys.argv.pop(idx)
      else:
        ncpu = int(sys.argv[idx + 1])
        sys.argv.pop(idx)
        sys.argv.pop(idx)
      break
  # check reset
  for idx, arg in enumerate(list(sys.argv)):
    if arg in ('--reset', '--clear', '--clean'):
      configs_filter = lambda f: 'configs' != f.split('/')[-1]
      if len(get_all_files(self._save_path, filter_func=configs_filter)) > 0:
        old_exps = '\n'.join([
            " - %s" % i
            for i in os.listdir(self._save_path)
            if configs_filter(i)
        ])
        inp = input("<Enter> to clear all existing experiments:"
                    "\n%s\n'n' to cancel, otherwise continue:" % old_exps)
        if inp.strip().lower() != 'n':
          clean_folder(self._save_path, filter=configs_filter, verbose=True)
      sys.argv.pop(idx)
  # check multirun
  is_multirun = (any(',' in ovr for ovr in overrides) or
                 any(',' in arg and '=' in arg for arg in sys.argv))
  # write history
  self.write_history(command, "overrides: %s" % str(overrides),
                     "strict: %s" % str(strict), "ncpu: %d" % ncpu,
                     "multirun: %s" % str(is_multirun))
  # generate app help
  hlp = '\n\n'.join([
      "%s - %s" % (str(key), ', '.join(sorted(as_tuple(val, t=str))))
      for key, val in dict(self.args_help).items()
  ])

  # replacement for Hydra.run: single run with our config composition
  def _run(self, config_file, task_function, overrides):
    if is_multirun:
      raise RuntimeError(
          "Performing single run with multiple overrides in hydra "
          "(use '-m' for multirun): %s" % str(overrides))
    cfg = self.compose_config(config_file=config_file,
                              overrides=overrides,
                              strict=strict,
                              with_log_configuration=True)
    HydraConfig().set_config(cfg)
    return run_job(
        config=cfg,
        task_function=task_function,
        job_dir_key="hydra.run.dir",
        job_subdir_key=None,
    )

  # replacement for Hydra.multirun: sweep dispatched to ParallelLauncher
  def _multirun(self, config_file, task_function, overrides):
    # Initial config is loaded without strict (individual job configs may
    # have strict).
    from hydra._internal.plugins import Plugins
    cfg = self.compose_config(config_file=config_file,
                              overrides=overrides,
                              strict=strict,
                              with_log_configuration=True)
    HydraConfig().set_config(cfg)
    sweeper = Plugins.instantiate_sweeper(config=cfg,
                                          config_loader=self.config_loader,
                                          task_function=task_function)
    # override launcher for using multiprocessing
    sweeper.launcher = ParallelLauncher(ncpu=ncpu)
    sweeper.launcher.setup(config=cfg,
                           config_loader=self.config_loader,
                           task_function=task_function)
    return sweeper.sweep(arguments=cfg.hydra.overrides.task)

  # monkey-patch Hydra's entry points, restore them after running
  old_multirun = (Hydra.run, Hydra.multirun)
  Hydra.run = _run
  Hydra.multirun = _multirun
  try:
    # append the new overrides
    if len(overrides) > 0:
      sys.argv += overrides
    # help for arguments
    if '--help' in sys.argv:
      # sys.argv.append("hydra.help.header='**** %s ****'" %
      #                 self.__class__.__name__)
      # sys.argv.append("hydra.help.template=%s" % (_APP_HELP % hlp))
      # TODO: fix bug here
      pass
    # append the hydra log path
    job_fmt = "/${now:%d%b%y_%H%M%S}"
    sys.argv.insert(1, "hydra.run.dir=%s" % self.get_hydra_path() + job_fmt)
    sys.argv.insert(1, "hydra.sweep.dir=%s" % self.get_hydra_path() + job_fmt)
    sys.argv.insert(1, "hydra.sweep.subdir=${hydra.job.id}")
    # sys.argv.append(r"hydra.job_logging.formatters.simple.format=" +
    #                 r"[\%(asctime)s][\%(name)s][\%(levelname)s] - \%(message)s")
    args_parser = get_args_parser()
    run_hydra(
        args_parser=args_parser,
        task_function=self._run,
        config_path=self.config_path,
        strict=strict,
    )
  except KeyboardInterrupt:
    sys.exit(-1)
  except SystemExit:
    pass
  Hydra.run = old_multirun[0]
  Hydra.multirun = old_multirun[1]
  # update the summary
  self.summary()
  return self
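# Usage sketch, mirroring the docstring above (the override values are
# illustrative): list or comma-separated values make `is_multirun` True, so
# the sweep is dispatched through `ParallelLauncher` with `ncpu` processes.
#
#   exp = SisuaExperimenter(ncpu=2)
#   exp.run(overrides={'model': ['sisua', 'dca'],  # list -> multirun
#                      'dataset.name': 'cortex',   # fixed for every job
#                      'train.epochs': 2})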
def main(cfg):
  save_to_yaml(cfg)
  if cfg.ds == 'news5':
    ds = Newsgroup5()
  elif cfg.ds == 'news20':
    ds = Newsgroup20()
  elif cfg.ds == 'news20clean':
    ds = Newsgroup20_clean()
  elif cfg.ds == 'cortex':
    ds = Cortex()
  elif cfg.ds == 'lkm':
    ds = LeukemiaATAC()
  else:
    raise NotImplementedError(f"No support for dataset: {cfg.ds}")
  train = ds.create_dataset(batch_size=batch_size,
                            partition='train',
                            drop_remainder=True)
  valid = ds.create_dataset(batch_size=batch_size, partition='valid')
  test = ds.create_dataset(batch_size=batch_size, partition='test')
  n_words = ds.vocabulary_size
  vocabulary = ds.vocabulary
  ######## prepare the path
  output_dir = get_output_dir()
  if not os.path.exists(output_dir):
    os.makedirs(output_dir)
  model_path = os.path.join(output_dir, 'model')
  if cfg.override:
    clean_folder(output_dir, verbose=True)
  ######## preparing all layers
  lda = LatentDirichletDecoder(
      posterior=cfg.posterior,
      distribution=cfg.distribution,
      n_words=n_words,
      n_topics=cfg.n_topics,
      warmup=cfg.warmup,
  )
  fit_kw = dict(train=train,
                valid=valid,
                max_iter=cfg.n_iter,
                optimizer='adam',
                learning_rate=learning_rate,
                batch_size=batch_size,
                valid_freq=valid_freq,
                compile_graph=True,
                logdir=output_dir,
                skip_fitted=True)
  output_dist = RVconf(
      n_words,
      cfg.distribution,
      projection=True,
      preactivation='softmax' if cfg.distribution == 'onehot' else 'linear',
      kwargs=dict(probs_input=True) if cfg.distribution == 'onehot' else {},
      name="Words")
  latent_dist = RVconf(cfg.n_topics, 'mvndiag', projection=True,
                       name="Latents")
  ######## AmortizedLDA
  if cfg.model == 'lda':
    vae = AmortizedLDA(lda=lda,
                       encoder=NetConf([300, 300, 300], name='Encoder'),
                       decoder='identity',
                       latents='identity',
                       path=model_path)
    vae.fit(on_valid_end=partial(callback,
                                 vae=vae,
                                 test=test,
                                 vocabulary=vocabulary),
            **fit_kw)
  ######## VDA - Variational Dirichlet Autoencoder
  elif cfg.model == 'vda':
    vae = BetaVAE(
        beta=cfg.beta,
        encoder=NetConf([300, 150], name='Encoder'),
        decoder=NetConf([150, 300], name='Decoder'),
        latents=RVconf(cfg.n_topics,
                       'dirichlet',
                       projection=True,
                       prior=None,
                       name="Topics"),
        outputs=output_dist,
        # important, MCMC KL for Dirichlet is very unstable
        analytic=True,
        path=model_path,
        name="VDA")
    vae.fit(on_valid_end=partial(callback1,
                                 vae=vae,
                                 test=test,
                                 vocabulary=vocabulary),
            **dict(fit_kw,
                   valid_freq=1000,
                   optimizer=tf.optimizers.Adam(learning_rate=1e-4)))
  ######## VAE
  elif cfg.model == 'vae':
    vae = BetaVAE(beta=cfg.beta,
                  encoder=NetConf([300, 300], name='Encoder'),
                  decoder=NetConf([300], name='Decoder'),
                  latents=latent_dist,
                  outputs=output_dist,
                  path=model_path,
                  name="VAE")
    callback1(vae, test, vocabulary)
    vae.fit(on_valid_end=partial(callback1,
                                 vae=vae,
                                 test=test,
                                 vocabulary=vocabulary),
            **dict(fit_kw,
                   valid_freq=1000,
                   optimizer=tf.optimizers.Adam(learning_rate=1e-4)))
  ######## factorVAE
  elif cfg.model == 'fvae':
    vae = FactorVAE(gamma=6.0,
                    beta=cfg.beta,
                    encoder=NetConf([300, 150], name='Encoder'),
                    decoder=NetConf([150, 300], name='Decoder'),
                    latents=latent_dist,
                    outputs=output_dist,
                    path=model_path)
    vae.fit(on_valid_end=partial(callback1,
                                 vae=vae,
                                 test=test,
                                 vocabulary=vocabulary),
            **dict(fit_kw,
                   valid_freq=1000,
                   optimizer=[
                       tf.optimizers.Adam(learning_rate=1e-4,
                                          beta_1=0.9,
                                          beta_2=0.999),
                       tf.optimizers.Adam(learning_rate=1e-4,
                                          beta_1=0.5,
                                          beta_2=0.9)
                   ]))
  ######## TwoStageLDA
  elif cfg.model == 'lda2':
    vae0_iter = 10000
    vae0 = BetaVAE(beta=1.0,
                   encoder=NetConf(units=[300], name='Encoder'),
                   decoder=NetConf(units=[300, 300], name='Decoder'),
                   outputs=DistributionDense(
                       (n_words,),
                       posterior='onehot',
                       posterior_kwargs=dict(probs_input=True),
                       activation='softmax',
                       name="Words"),
                   latents=RVconf(cfg.n_topics,
                                  'mvndiag',
                                  projection=True,
                                  name="Latents"),
                   input_shape=(n_words,),
                   path=model_path + '_vae0')
    vae0.fit(on_valid_end=lambda: None
             if get_current_trainer().is_training else vae0.save_weights(),
             **dict(fit_kw,
                    logdir=output_dir + "_vae0",
                    max_iter=vae0_iter,
                    learning_rate=learning_rate,
                    track_gradients=False))
    vae = TwoStageLDA(lda=lda,
                      encoder=vae0.encoder,
                      decoder=vae0.decoder,
                      latents=vae0.latent_layers,
                      warmup=cfg.warmup - vae0_iter,
                      path=model_path)
    vae.fit(on_valid_end=partial(callback,
                                 vae=vae,
                                 test=test,
                                 vocabulary=vocabulary),
            **dict(fit_kw,
                   max_iter=cfg.n_iter - vae0_iter,
                   track_gradients=False))
  ######## EM-LDA
  elif cfg.model == 'em':
    n_iter = 0  # default step for the final summary if loaded from disk
    if os.path.exists(model_path):
      with open(model_path, 'rb') as f:
        lda = pickle.load(f)
    else:
      writer = tf.summary.create_file_writer(output_dir)
      lda = LatentDirichletAllocation(n_components=cfg.n_topics,
                                      doc_topic_prior=0.7,
                                      learning_method='online',
                                      verbose=True,
                                      n_jobs=4,
                                      random_state=1)
      with writer.as_default():
        prog = tqdm(train.repeat(-1), desc="Fitting LDA")
        for n_iter, x in enumerate(prog):
          lda.partial_fit(x)
          if n_iter % 500 == 0:
            text = get_topics_text(lda.components_, vocabulary)
            perp = lda.perplexity(test)
            tf.summary.text("topics", text, n_iter)
            tf.summary.scalar("perplexity", perp, n_iter)
            prog.write(f"[#{n_iter}] Perplexity: {perp:.2f}")
            prog.write("\n".join(text))
          if n_iter >= 20000:
            break
      with open(model_path, 'wb') as f:
        pickle.dump(lda, f)
    # final evaluation
    text = get_topics_text(lda.components_, vocabulary)
    final_score = lda.perplexity(test)
    tf.summary.scalar("perplexity", final_score, step=n_iter + 1)
    print("Perplexity:", final_score)
    print("\n".join(text))
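# Wiring sketch (assumed: `run_hydra` is the decorator defined later in this
# repo; the config keys below are exactly the `cfg.*` attributes read in
# `main`, while the default values are placeholders for illustration):
if __name__ == '__main__':
  run_hydra(output_dir='/tmp/lda_vae', exclude_keys=['override'])(main)(
      dict(ds='news20', model='lda', n_topics=20, warmup=10000, beta=1.0,
           n_iter=30000, posterior='dirichlet', distribution='onehot',
           override=False))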
def run_hydra(output_dir: str = '/tmp/outputs',
              exclude_keys: List[str] = []) -> Callable[[TaskFunction], Any]:
  """A modified main function of Hydra-core for flexibility

  Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

  Useful commands:

  - `hydra/launcher=joblib` : enable joblib launcher
  - `hydra.launcher.n_jobs=-1` : set maximum number of processes
  - `--list` or `--summary` : list all existing experiments
  - `-j2` : run multi-processing (with 2 processes)
  - `--override` : override the existing model of the given experiment
  - `--reset` : remove all files and folders in the output_dir

  Examples
  --------
  ```
  @experimenter.run_hydra()
  def run(cfg: DictConfig):
    print(cfg, type(cfg))

  run("/tmp/conf/base.yaml")
  ```

  Note
  ------
  To add a `db` sub-config file to the `database` object, from the command
  line `python tmp.py db=mysql`, with `#@package database` on top of
  `mysql.yaml`; otherwise, from `base.yaml`:
  ```
  defaults:
    - db: mysql
  ```
  To add `db` to a specific attribute of the `database` object, from the
  command line `python tmp.py [email protected]=mysql`; otherwise, from
  `base.yaml`:
  ```
  defaults:
    - [email protected]: mysql
  ```
  """
  output_dir = _abspath(output_dir)
  ### check if reset all the experiments
  for i, a in enumerate(list(sys.argv)):
    if RESET_PATTERN.match(a):
      print('*Reset all experiments:')
      clean_folder(output_dir, verbose=True)
      sys.argv.pop(i)
      break
  ### create the log dir
  log_dir = os.path.join(output_dir, 'logs')
  if not os.path.exists(log_dir):
    os.makedirs(log_dir)

  def main_decorator(task_function: TaskFunction) -> Callable[[], None]:

    @functools.wraps(task_function)
    def decorated_main(
        config: Union[str, dict, list, tuple, DictConfig]) -> Any:
      ### string
      if isinstance(config, string_types):
        # path to a config file
        if os.path.isfile(config):
          config_name = os.path.basename(config).replace(".yaml", "")
          config_path = os.path.dirname(config)
        # path to a directory
        elif os.path.isdir(config):
          config_path = config
          if os.path.exists(os.path.join(config_path, 'base.yaml')):
            config_name = "base"  # default name
          else:
            config_name = sorted([
                i for i in os.listdir(config_path) if '.yaml' in i
            ])[0].replace(".yaml", "")
        # YAML string
        else:
          config_path, config_name = _save_config_to_tempdir(config)
      ### dictionary, tuple, list, DictConfig
      else:
        config_path, config_name = _save_config_to_tempdir(config)
      ### list all experiments command
      for a in sys.argv:
        if LIST_PATTERN.match(a) or SUMMARY_PATTERN.match(a):
          print("Output dir:", output_dir)
          all_logs = defaultdict(list)
          for i in os.listdir(log_dir):
            name, time_str = i.replace('.log', '').split(':')
            all_logs[name].append((time_str, os.path.join(log_dir, i)))
          for fname in sorted(os.listdir(output_dir)):
            path = os.path.join(output_dir, fname)
            # basic meta
            print(
                f" {fname}",
                f"({len(os.listdir(path))} files)"
                if os.path.isdir(path) else "")
            # show the log files info
            if fname in all_logs:
              for time_str, log_file in all_logs[fname]:
                with open(log_file, 'r') as f:
                  log_data = f.read()
                lines = log_data.split('\n')
                n = len(lines)
                print(f'  log {datetime.strptime(time_str, TIME_FMT)} '
                      f'({n} lines)')
                for e in [l for l in lines if '[ERROR]' in l]:
                  print(f'   {e.split("[ERROR]")[1]}')
          exit()
      ### check if overrides provided
      is_overrided = False
      for a in sys.argv:
        match = OVERRIDE_PATTERN.match(a)
        if match and not any(k in match.string for k in exclude_keys):
          is_overrided = True
      ### formatting output dirs
      if is_overrided:
        override_id = r"${hydra.job.override_dirname}"
      else:
        override_id = r"default"
      ### check if removing existing experiments is enabled
      remove_exists = False
      for i, a in enumerate(list(sys.argv)):
        match = REMOVE_EXIST_PATTERN.match(a)
        if match:
          remove_exists = True
          sys.argv.pop(i)
          break
      ### parallel jobs provided
      jobs = 1
      for i, a in enumerate(list(sys.argv)):
        match = JOBS_PATTERN.match(a)
        if match:
          jobs = int(match.groups()[-1])
          sys.argv.pop(i)
          break
      if jobs > 1:
        _insert_argv(key="hydra/launcher",
                     value="joblib",
                     is_value_string=False)
        _insert_argv(key="hydra.launcher.n_jobs",
                     value=f"{jobs}",
                     is_value_string=False)
      ### running dirs
      _insert_argv(key="hydra.run.dir",
                   value=f"{output_dir}/{override_id}",
                   is_value_string=True)
      _insert_argv(key="hydra.sweep.dir",
                   value=f"{output_dir}/multirun/{HYDRA_TIME_FMT}",
                   is_value_string=True)
      _insert_argv(key="hydra.job_logging.handlers.file.filename",
                   value=f"{log_dir}/{override_id}:{HYDRA_TIME_FMT}.log",
                   is_value_string=True)
      _insert_argv(key="hydra.job.config.override_dirname.exclude_keys",
                   value=f"[{','.join([str(i) for i in exclude_keys])}]",
                   is_value_string=False)
      # no return value from run_hydra() as it may sometimes actually run
      # the task_function multiple times (--multirun)
      args = get_args_parser()
      config_path = _abspath(config_path)
      ## prepare arguments for task_function
      spec = inspect.getfullargspec(task_function)

      ## run hydra
      @functools.wraps(task_function)
      def _task_function(_cfg):
        # print out the running config
        cfg_text = '\n ----------- \n'
        cfg_text += OmegaConf.to_yaml(_cfg)[:-1]
        cfg_text += '\n -----------'
        logger.info(cfg_text)
        # remove existing outputs if requested
        if remove_exists:
          exp_dir = get_output_dir()
          dir_base = os.path.dirname(exp_dir)
          dir_name = os.path.basename(exp_dir)
          for folder in get_all_folder(dir_base):
            if dir_name == os.path.basename(folder):
              clean_folder(folder, verbose=True)
        # catch exceptions so the remaining jobs can continue running
        try:
          task_function(_cfg)
        except Exception as e:
          _, value, tb = sys.exc_info()
          for line in traceback.TracebackException(
              type(value), value, tb, limit=None).format(chain=None):
            logger.error(line)
          if jobs == 1:
            raise e

      _run_hydra(
          args_parser=args,
          task_function=_task_function,
          config_path=config_path,
          config_name=config_name,
          strict=None,
      )

    return decorated_main

  return main_decorator
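# `_insert_argv` is used above but not defined in this fragment; a minimal
# sketch of its assumed behaviour: prepend `key=value` to sys.argv unless the
# user already supplied that key, quoting string values for hydra's parser.
def _insert_argv(key: str, value: str, is_value_string: bool = False):
  if any(a.split('=')[0] == key for a in sys.argv[1:]):
    return  # respect an explicit user-provided override
  fmt = "%s='%s'" if is_value_string else "%s=%s"
  sys.argv.insert(1, fmt % (key, value))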
def run_hydra(output_dir: str = '/tmp/outputs',
              exclude_keys: List[str] = []) -> Callable[[TaskFunction], Any]:
  r""" A modified main function of Hydra-core for flexibility

  Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

  Useful commands:

  - `hydra/launcher=joblib` : enable joblib launcher
  - `hydra.launcher.n_jobs=-1` : set maximum number of processes

  Example:
  ```
  @experimenter.run_hydra()
  def run(cfg: DictConfig):
    print(cfg, type(cfg))

  run("/tmp/conf/base.yaml")
  ```

  Note:
  To add a `db` sub-config file to the `database` object, from the command
  line `python tmp.py db=mysql`, with `#@package database` on top of
  `mysql.yaml`; otherwise, from `base.yaml`:
  ```
  defaults:
    - db: mysql
  ```
  To add `db` to a specific attribute of the `database` object, from the
  command line `python tmp.py [email protected]=mysql`; otherwise, from
  `base.yaml`:
  ```
  defaults:
    - [email protected]: mysql
  ```
  """
  output_dir = _abspath(output_dir)
  ### check if reset all the experiments
  for i, a in enumerate(list(sys.argv)):
    if RESET_PATTERN.match(a):
      print('*Reset all experiments:')
      clean_folder(output_dir, verbose=True)
      sys.argv.pop(i)
      break
  ### create the log dir
  log_dir = os.path.join(output_dir, 'logs')
  if not os.path.exists(log_dir):
    os.makedirs(log_dir)

  def main_decorator(task_function: TaskFunction) -> Callable[[], None]:

    @functools.wraps(task_function)
    def decorated_main(
        config: Union[str, dict, list, tuple, DictConfig]) -> Any:
      ### string
      if isinstance(config, string_types):
        # path to a config file
        if os.path.isfile(config):
          config_name = os.path.basename(config).replace(".yaml", "")
          config_path = os.path.dirname(config)
        # path to a directory
        elif os.path.isdir(config):
          config_path = config
          if os.path.exists(os.path.join(config_path, 'base.yaml')):
            config_name = "base"  # default name
          else:
            config_name = sorted([
                i for i in os.listdir(config_path) if '.yaml' in i
            ])[0].replace(".yaml", "")
        # YAML string
        else:
          config_path, config_name = _save_config_to_tempdir(config)
      ### dictionary, tuple, list, DictConfig
      else:
        config_path, config_name = _save_config_to_tempdir(config)
      ### list all experiments command
      for a in sys.argv:
        if LIST_PATTERN.match(a):
          print("Output dir:", output_dir)
          for fname in sorted(os.listdir(output_dir)):
            path = os.path.join(output_dir, fname)
            print(
                f" {fname}",
                f"({len(os.listdir(path))} files)"
                if os.path.isdir(path) else "")
          exit()
      ### check if overrides provided
      is_overrided = False
      for a in sys.argv:
        match = OVERRIDE_PATTERN.match(a)
        if match and not any(k in match.string for k in exclude_keys):
          is_overrided = True
      ### formatting output dirs
      time_fmt = r"${now:%j_%H%M%S}"
      if is_overrided:
        override_id = r"${hydra.job.override_dirname}"
      else:
        override_id = r"default"
      ### parallel jobs provided
      jobs = 1
      for i, a in enumerate(list(sys.argv)):
        match = JOBS_PATTERN.match(a)
        if match:
          jobs = int(match.groups()[-1])
          sys.argv.pop(i)
          break
      if jobs > 1:
        _insert_argv(key="hydra/launcher",
                     value="joblib",
                     is_value_string=False)
        _insert_argv(key="hydra.launcher.n_jobs",
                     value=f"{jobs}",
                     is_value_string=False)
      ### running dirs
      _insert_argv(key="hydra.run.dir",
                   value=f"{output_dir}/{override_id}",
                   is_value_string=True)
      _insert_argv(key="hydra.sweep.dir",
                   value=f"{output_dir}/multirun/{time_fmt}",
                   is_value_string=True)
      _insert_argv(key="hydra.job_logging.handlers.file.filename",
                   value=f"{log_dir}/{override_id}:{time_fmt}.log",
                   is_value_string=True)
      _insert_argv(key="hydra.job.config.override_dirname.exclude_keys",
                   value=f"[{','.join([str(i) for i in exclude_keys])}]",
                   is_value_string=False)
      # no return value from run_hydra() as it may sometimes actually run
      # the task_function multiple times (--multirun)
      args = get_args_parser()
      config_path = _abspath(config_path)
      ## prepare arguments for task_function
      spec = inspect.getfullargspec(task_function)
      ## run hydra
      _run_hydra(
          args_parser=args,
          task_function=task_function,
          config_path=config_path,
          config_name=config_name,
          strict=None,
      )

    return decorated_main

  return main_decorator
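# Command-line sketch for a decorated task function (flag spellings follow
# the docstrings and the RESET/JOBS pattern handling above; `train.py` and
# the override keys are placeholders):
#   python train.py model=vae               # single run
#   python train.py model=vae,fvae -m -j4   # multirun via joblib, 4 jobs
#   python train.py --reset                 # wipe everything in output_dir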