def train_nas(arg):
    """Run the NAS training loop and persist the controller weights.

    Builds per-dataset configs/managers via ``read_configs``, trains the
    controller in a (Parallel)MultiManagerEnvironment, and saves the
    controller weights to ``<wd>/controller_weights.h5``.

    NOTE(review): a second ``train_nas`` is defined later in this module
    and shadows this one at import time — confirm which is intended.

    Parameters
    ----------
    arg : argparse.Namespace
        Expects at least ``wd`` (working dir), ``parallel`` (bool),
        ``resume`` (bool), plus whatever ``read_configs`` consumes.
    """
    wd = arg.wd
    logger = setup_logger(wd, verbose_level=logging.INFO)
    gpus = get_available_gpus()
    configs, config_keys, controller, model_space = read_configs(arg)

    # Keyword arguments shared by both environment flavors.
    env_kwargs = dict(
        data_descriptive_features=np.stack(
            [configs[k]["dfeatures"] for k in config_keys]),
        controller=controller,
        manager=[configs[k]["manager"] for k in config_keys],
        logger=logger,
        max_episode=200,
        max_step_per_ep=3,
        working_dir=wd,
        time_budget="150:00:00",
        with_input_blocks=False,
        with_skip_connection=False,
        save_controller_every=1,
        resume_prev_run=arg.resume,
    )
    # Idiom fix: test the flag's truthiness instead of `is True`.
    if arg.parallel:
        # One worker process per visible GPU.
        env = ParallelMultiManagerEnvironment(processes=len(gpus), **env_kwargs)
    else:
        env = MultiManagerEnvironment(**env_kwargs)

    try:
        env.train()
    except KeyboardInterrupt:
        # Allow a manual stop; weights are still saved below.
        print("user interrupted training")
    controller.save_weights(os.path.join(wd, "controller_weights.h5"))
def get_manager_common(train_data, val_data, controller, model_space, wd,
                       data_description, verbose=2, n_feats=1, **kwargs):
    """Build a GeneralManager that trains child models on one dataset.

    Parameters
    ----------
    train_data, val_data
        Training / validation data sources (generators in this codebase).
    controller
        The NAS controller; only its session is implicitly shared.
    model_space
        Search space passed through to the model builder and manager.
    wd : str
        Working directory for this manager's artifacts.
    data_description
        Accepted for interface parity with the distributed manager getter;
        not used here.
    verbose : int
        Verbosity forwarded to the manager.
    n_feats : int
        Number of output units of the child model's dense head.
    **kwargs
        Ignored; absorbs extra arguments the distributed getter accepts.

    Returns
    -------
    GeneralManager
    """
    # Child model endpoints: fixed 1000bp x 4-channel one-hot input,
    # sigmoid multi-label head sized by n_feats.
    input_node = State('input', shape=(1000, 4), name="input", dtype='float32')
    output_node = State('dense', units=n_feats, activation='sigmoid')
    model_compile_dict = {
        'loss': 'binary_crossentropy',
        'optimizer': 'adam',
        'metrics': ['acc'],
    }
    reward_fn = LossAucReward(method='auc')
    num_gpus = len(get_available_gpus())
    mb = KerasModelBuilder(inputs=input_node,
                           outputs=output_node,
                           model_compile_dict=model_compile_dict,
                           model_space=model_space,
                           gpus=num_gpus)
    # TODO: batch_size here is not effective because it's been set at generator init
    # Fixed: guard against CPU-only hosts — `1000 * 0` would yield a
    # batch size of 0.
    child_batch_size = 1000 * max(1, num_gpus)
    manager = GeneralManager(train_data=train_data,
                             validation_data=val_data,
                             epochs=1000,
                             child_batchsize=child_batch_size,
                             reward_fn=reward_fn,
                             model_fn=mb,
                             store_fn='model_plot',
                             model_compile_dict=model_compile_dict,
                             working_dir=wd,
                             verbose=verbose,
                             save_full_model=True,
                             model_space=model_space,
                             fit_kwargs={
                                 'steps_per_epoch': 50,
                                 'workers': 8,
                                 'max_queue_size': 50,
                                 'earlystop_patience': 20,
                             })
    return manager
def read_configs(arg):
    """Load the run configuration table and build per-dataset managers.

    Reads descriptive-feature names, constructs the controller, parses the
    config CSV/TSV, and for each config row builds train/validate data
    sources plus a manager.

    Returns
    -------
    tuple
        (configs, config_keys, controller, model_space)
    """
    # Descriptive-feature column names: one per non-empty line.
    dfeature_names = []
    with open(arg.dfeature_name_file, "r") as fh:
        for raw in fh:
            stripped = raw.strip()
            if stripped:
                dfeature_names.append(stripped)

    wd = arg.wd
    model_space, layer_embedding_sharing = get_model_space_common()
    print(layer_embedding_sharing)

    # TF1 vs TF2-compat session creation.
    try:
        session = tf.Session()
    except AttributeError:
        session = tf.compat.v1.Session()
    controller = get_controller(
        model_space=model_space,
        session=session,
        data_description_len=len(dfeature_names),
        layer_embedding_sharing=layer_embedding_sharing)

    # Load in datasets and configurations for them.
    sep = "\t" if arg.config_file.endswith("tsv") else ","
    configs = pd.read_csv(arg.config_file, sep=sep)
    header = list(configs.columns)
    # pandas doesn't infer quoting; re-read with quoting=2
    # (csv.QUOTE_NONNUMERIC) if quote characters leaked into the header.
    if any("\"" in col for col in header):
        configs = pd.read_csv(arg.config_file, sep=sep, quoting=2)
        print("Re-read with quoting")
    configs = configs.to_dict(orient='index')

    # Available GPUs, repeated so each config row gets a device slot.
    gpus = get_available_gpus()
    gpus_ = gpus * len(configs)

    manager_getter = get_manager_distributed if arg.parallel else get_manager_common
    config_keys = []
    seed_generator = np.random.RandomState(seed=1337)
    for i, k in enumerate(configs.keys()):
        # Build datasets for the train/validate splits.
        for split in ["train", "validate"]:
            if arg.lockstep_sampling is False and split == "train":
                cur_seed = seed_generator.randint(0, np.iinfo(np.uint32).max)
            else:
                cur_seed = 1337
            # NOTE(review): cur_seed is never forwarded to the generator
            # kwargs below, unlike the sibling train_nas which passes a
            # 'seed' entry — looks like an omission; confirm whether
            # BatchedHDF5Generator accepts a seed.
            gen_kwargs = {
                'hdf5_fp': arg.train_file if split == 'train' else arg.val_file,
                'y_selector': configs[k]['ds_col'],
                'batch_size': 512 if arg.parallel else 512 * len(gpus),
                'shuffle': split == 'train',
            }
            if arg.parallel is True:
                # Defer construction: ship class + kwargs to worker processes.
                configs[k][split] = BatchedHDF5Generator
                configs[k]['%s_data_kwargs' % split] = gen_kwargs
            else:
                configs[k][split] = BatchedHDF5Generator(**gen_kwargs)

        # Build covariates and manager.
        configs[k]["dfeatures"] = np.array(
            [configs[k][name] for name in dfeature_names])  # TODO: Make cols dynamic.
        extra_kwargs = dict() if arg.parallel is False else dict(
            train_data_kwargs=configs[k]['train_data_kwargs'],
            validate_data_kwargs=configs[k]["validate_data_kwargs"])
        configs[k]["manager"] = manager_getter(
            devices=[gpus_[i]],
            train_data=configs[k]["train"],
            val_data=configs[k]["validate"],
            controller=controller,
            model_space=model_space,
            wd=os.path.join(wd, "manager_%s" % k),
            data_description=configs[k]["dfeatures"],
            dag_name="AmberDAG{}".format(k),
            verbose=0 if arg.parallel is True else 2,
            n_feats=configs[k]["n_feats"],
            **extra_kwargs)
        config_keys.append(k)

    return configs, config_keys, controller, model_space
def train_nas(arg):
    """End-to-end NAS training entry point (interval-sequence flavor).

    Unlike the earlier ``train_nas``/``read_configs`` pair, this variant
    builds the controller and per-dataset managers inline, using
    BioInterval sequence generators backed by a reference genome, then
    trains the controller in a (Parallel)MultiManagerEnvironment and
    saves its weights.

    NOTE(review): this redefinition shadows the earlier ``train_nas`` in
    this module — confirm which one is meant to be exported.

    Parameters
    ----------
    arg : argparse.Namespace
        Expects: wd, dfeature_name_file, config_file, genome_file,
        n_train / n_validate / n_test, lockstep_sampling, parallel,
        resume, and per-config columns referenced below.
    """
    # Descriptive-feature column names, one per non-empty line.
    dfeature_names = []
    with open(arg.dfeature_name_file, "r") as read_file:
        for line in read_file:
            line = line.strip()
            if line:
                dfeature_names.append(line)

    wd = arg.wd
    model_space, layer_embedding_sharing = get_model_space_common()
    print(layer_embedding_sharing)

    # TF1 vs TF2-compat session creation.
    try:
        session = tf.Session()
    except AttributeError:
        session = tf.compat.v1.Session()
    controller = get_controller(
        model_space=model_space,
        session=session,
        data_description_len=len(dfeature_names),
        layer_embedding_sharing=layer_embedding_sharing)

    # Re-load previously saved weights, if specified. Best-effort: a
    # fresh working dir simply starts from scratch.
    if arg.resume:
        try:
            controller.load_weights(os.path.join(wd, "controller_weights.h5"))
            print("loaded existing weights")
        except Exception as e:
            print("cannot load controller weights because of %s" % e)

    # Load in datasets and configurations for them.
    sep = "\t" if arg.config_file.endswith("tsv") else ","
    configs = pd.read_csv(arg.config_file, sep=sep)
    # pandas doesn't infer quoting; re-read with quoting=2
    # (csv.QUOTE_NONNUMERIC) if quote characters leaked into the header.
    if any("\"" in x for x in list(configs.columns)):
        configs = pd.read_csv(arg.config_file, sep=sep, quoting=2)
        print("Re-read with quoting")
    configs = configs.to_dict(orient='index')

    # Get available gpus for parsing to DistributedManager.
    gpus = get_available_gpus()
    gpus_ = gpus * len(configs)

    # Build genome. This only works under the assumption that all configs
    # use the same genome. NOTE(review): `genome` is not referenced below
    # (generators receive arg.genome_file directly); kept in case the
    # constructor validates the genome file — confirm.
    genome = EncodedHDF5Genome(input_path=arg.genome_file, in_memory=False)

    manager_getter = get_manager_distributed if arg.parallel else get_manager_common
    config_keys = []
    seed_generator = np.random.RandomState(seed=1337)
    # Example counts per split; unknown splits raise below (same contract
    # as the original if/elif ladder).
    n_examples_by_split = {
        "train": arg.n_train,
        "validate": arg.n_validate,
        "test": arg.n_test,
    }
    for i, k in enumerate(configs.keys()):
        # Build datasets for the train/validate splits.
        for x in ["train", "validate"]:
            # Unless lockstep sampling is requested, each training split
            # draws its own reproducible seed from the master generator.
            if arg.lockstep_sampling is False and x == "train":
                cur_seed = seed_generator.randint(0, np.iinfo(np.uint32).max)
            else:
                cur_seed = 1337
            if x not in n_examples_by_split:
                raise ValueError("Unknown mode: {}".format(x))
            d = {
                'example_file': configs[k][x + "_file"],
                'reference_sequence': arg.genome_file,
                'batch_size': 512 if arg.parallel else 512 * len(gpus),
                'seed': cur_seed,
                'shuffle': x == 'train',
                'n_examples': n_examples_by_split[x],
                'pad': 400,
            }
            if arg.parallel is True:
                # Defer construction: ship class + kwargs to workers.
                configs[k][x] = (BatchedBioIntervalSequenceGenerator
                                 if x == 'train' else BatchedBioIntervalSequence)
                configs[k]['%s_data_kwargs' % x] = d
            else:
                configs[k][x] = (BatchedBioIntervalSequenceGenerator(**d)
                                 if x == 'train' else BatchedBioIntervalSequence(**d))

        # Build covariates and manager.
        configs[k]["dfeatures"] = np.array(
            [configs[k][x] for x in dfeature_names])  # TODO: Make cols dynamic.
        tmp = dict() if arg.parallel is False else dict(
            train_data_kwargs=configs[k]['train_data_kwargs'],
            validate_data_kwargs=configs[k]["validate_data_kwargs"])
        configs[k]["manager"] = manager_getter(
            devices=[gpus_[i]],
            train_data=configs[k]["train"],
            val_data=configs[k]["validate"],
            controller=controller,
            model_space=model_space,
            wd=os.path.join(wd, "manager_%s" % k),
            data_description=configs[k]["dfeatures"],
            dag_name="AmberDAG{}".format(k),
            verbose=0 if arg.parallel is True else 2,
            n_feats=configs[k]["n_feats"],
            **tmp)
        config_keys.append(k)

    logger = setup_logger(wd, verbose_level=logging.INFO)
    # Setup env kwargs shared by both environment flavors.
    env_kwargs = dict(
        data_descriptive_features=np.stack(
            [configs[k]["dfeatures"] for k in config_keys]),
        controller=controller,
        manager=[configs[k]["manager"] for k in config_keys],
        logger=logger,
        max_episode=200,
        max_step_per_ep=3,
        working_dir=wd,
        time_budget="150:00:00",
        with_input_blocks=False,
        with_skip_connection=False,
        save_controller_every=1,
    )
    if arg.parallel is True:
        # Fixed: `len(gpus) if arg.parallel else 1` was redundant — this
        # branch is only reached when arg.parallel is True.
        env = ParallelMultiManagerEnvironment(processes=len(gpus), **env_kwargs)
    else:
        env = MultiManagerEnvironment(**env_kwargs)

    try:
        env.train()
    except KeyboardInterrupt:
        # Allow a manual stop; weights are still saved below.
        print("user interrupted training")
    controller.save_weights(os.path.join(wd, "controller_weights.h5"))