# --- Async PPO / A3C NAS search driver module (reformatted from collapsed source) ---
import os
import json
from math import ceil, log
from pprint import pprint, pformat
from mpi4py import MPI
import math
from nas4candle.nasapi.evaluator import Evaluator
from nas4candle.nasapi.search import util, Search
from nas4candle.nasapi.search.nas.agent import nas_ppo_async_a3c_emb

logger = util.conf_logger('nas4candle.nasapi.search.nas.ppo_a3c_async')


def print_logs(runner):
    """Log the runner's global episode counter and its worker list at DEBUG level."""
    logger.debug('num_episodes = {}'.format(runner.global_episode))
    logger.debug(' workers = {}'.format(runner.workers))


def key(d):
    """Serialize the 'arch_seq' entry of *d* as a JSON string.

    Presumably used as a deduplication/cache key for evaluated
    architectures -- confirm against callers (not visible in this chunk).
    """
    return json.dumps(dict(arch_seq=d['arch_seq']))


# Cluster geometry read from the environment, defaulting to a single node/worker.
# NOTE(review): 'nas4candle.nasapi_WORKERS_PER_NODE' is an unusual environment
# variable name (it contains dots) -- verify it matches what the launcher
# actually exports; it looks like the result of a mechanical package rename.
LAUNCHER_NODES = int(os.environ.get('BALSAM_LAUNCHER_NODES', 1))
WORKERS_PER_NODE = int(os.environ.get('nas4candle.nasapi_WORKERS_PER_NODE', 1))


class NasPPOAsyncA3C(Search):
    # Search driver combining asynchronous PPO with an A3C-style worker layout.
    # The class definition continues beyond this chunk; only the start of
    # __init__ is visible here.
    def __init__(self, problem, run, evaluator, **kwargs):
        # MPI rank of this process (0 is conventionally the parameter server /
        # master -- confirm in the rest of the class, which is not visible here).
        self.rank = MPI.COMM_WORLD.Get_rank()
# --- NAS model evaluation entry point (reformatted from collapsed source) ---
from random import random
import traceback
import numpy as np
from tensorflow import keras
from nas4candle.nasapi.search import util
from nas4candle.nasapi.search.nas.model.trainer.classifier_train_valid import \
    TrainerClassifierTrainValid
from nas4candle.nasapi.search.nas.model.trainer.regressor_train_valid import \
    TrainerRegressorTrainValid

logger = util.conf_logger('nas4candle.nasapi.search.nas.run')


def run(config):
    """Load data and train the model described by *config*.

    Resolves the dotted-path strings ``config['load_data']['func']`` and
    ``config['create_structure']['func']`` into callables, then loads the
    data, forwarding ``config['load_data']['kwargs']`` when present.
    The function continues beyond this chunk (it is truncated inside the
    tuple-shape validation below).
    """
    # Resolve dotted-path strings into the actual callables (in place).
    load_data = util.load_attr_from(config['load_data']['func'])
    config['load_data']['func'] = load_data
    config['create_structure']['func'] = util.load_attr_from(
        config['create_structure']['func'])
    # Load the data, forwarding optional keyword arguments when given.
    kwargs = config['load_data'].get('kwargs')
    data = load_data() if kwargs is None else load_data(**kwargs)
    logger.info(f'Data loaded with kwargs: {kwargs}')
    # Set data shape: a tuple result is expected to have exactly two elements
    # (presumably (train, valid) -- confirm in the truncated error message).
    if type(data) is tuple:
        if len(data) != 2:
            raise RuntimeError(
# --- Base trainer with a train/validation split (reformatted from collapsed source) ---
import tensorflow as tf
from tensorflow import keras
import numpy as np
import math
from sklearn.metrics import mean_squared_error
import nas4candle.nasapi.search.nas.model.arch as a
import nas4candle.nasapi.search.nas.model.train_utils as U
from nas4candle.nasapi.search import util
from nas4candle.nasapi.search.nas.utils._logging import JsonMessage as jm

logger = util.conf_logger('nas4candle.nasapi.model.trainer')


class TrainerTrainValid:
    # Only the first part of __init__ is visible here; the class continues
    # beyond this chunk.

    def __init__(self, config, model):
        """Cache configuration sections and hyperparameters for training.

        Args:
            config: mapping indexed by the constants of the ``arch`` module
                (``a.data``, ``a.hyperparameters``, ...); the hyperparameters
                section must provide optimizer, loss_metric, metrics,
                batch_size, learning_rate and num_epochs entries.
            model: the model to train -- presumably a Keras model given the
                imports; confirm in the rest of the class (not visible here).
        """
        self.cname = self.__class__.__name__  # concrete class name, used for logging
        self.config = config
        self.model = model
        self.callbacks = []
        # Sub-sections of the configuration.
        self.data = self.config[a.data]
        self.config_hp = self.config[a.hyperparameters]
        # Hyperparameters pulled out for convenient access.
        self.optimizer_name = self.config_hp[a.optimizer]
        self.loss_metric_name = self.config_hp[a.loss_metric]
        # Metric names are mapped to metric callables/identifiers via train_utils.
        self.metrics_name = [U.selectMetric(m) for m in self.config_hp[a.metrics]]
        self.batch_size = self.config_hp[a.batch_size]
        self.learning_rate = self.config_hp[a.learning_rate]
        self.num_epochs = self.config_hp[a.num_epochs]
# --- Random NAS agent: trajectory collection (reformatted from collapsed source) ---
import json
import os.path as osp
import numpy as np
import tensorflow as tf
from mpi4py import MPI
import nas4candle.nasapi.search.nas.utils.common.tf_util as U
from nas4candle.nasapi.evaluator import Evaluator
from nas4candle.nasapi.search.nas.env import NasEnv
from nas4candle.nasapi.search.nas.utils import bench, logger
from nas4candle.nasapi.search.nas.utils.common import set_global_seeds
from nas4candle.nasapi.search import util
from nas4candle.nasapi.search.nas.utils._logging import JsonMessage as jm

dh_logger = util.conf_logger('nas4candle.nasapi.search.nas.agent.nas_random')


def traj_segment_generator(env, horizon):
    """Collect trajectory segments of length *horizon* from *env*.

    Only the per-episode bookkeeping setup is visible in this chunk; the
    actual rollout loop continues beyond it.
    """
    t = 0
    ac = env.action_space.sample()  # not used, just so we have the datatype
    new = True  # marks if we're on first timestep of an episode
    ob = env.reset()
    cur_ep_ret = 0  # return in current episode
    cur_ep_len = 0  # len of current episode
    ep_rets = []  # returns of completed episodes in this segment
    ep_lens = []  # lengths of completed episodes in this segment
    # Mapping used by the (truncated) rollout loop -- presumably
    # timestep-index -> episode number; confirm in the rest of the function.
    ts_i2n_ep = {}
# --- Async PPO agent: advantage estimation (reformatted from collapsed source) ---
import numpy as np
import tensorflow as tf
from mpi4py import MPI
import nas4candle.nasapi.search.nas.utils.common.tf_util as U
from nas4candle.nasapi.search import util
from nas4candle.nasapi.search.nas.agent.utils import (
    reward_for_final_timestep, traj_segment_generator)
from nas4candle.nasapi.search.nas.utils._logging import JsonMessage as jm
from nas4candle.nasapi.search.nas.utils.common import (Dataset, explained_variance,
                                                       fmt_row, zipsame)
from nas4candle.nasapi.search.nas.utils.common.mpi_adam_async import MpiAdamAsync
from nas4candle.nasapi.search.nas.utils.common.mpi_moments import mpi_moments

dh_logger = util.conf_logger('nas4candle.nasapi.search.nas.agent.pposgd_async')


def add_vtarg_and_adv(seg, gamma, lam):
    """Compute target value using TD(lambda) estimator, and advantage with
    GAE(lambda).

    Mutates *seg* in place: allocates seg["adv"] and fills it via a backward
    recursion over timesteps.  Only the start of that recursion is visible in
    this chunk.
    """
    # Append a sentinel 0 to the episode-boundary flags: the last element is
    # only used for the last vtarg, but we already zeroed it if last new = 1.
    new = np.append(seg["new"], 0)
    # Value predictions, extended with the bootstrap value for the state after
    # the segment.
    vpred = np.append(seg["vpred"], seg["nextvpred"])
    T = len(seg["rew"])
    seg["adv"] = gaelam = np.empty(T, 'float32')
    rew = seg["rew"]
    lastgaelam = 0
    # Backward GAE recursion; the loop body continues beyond this chunk.
    for t in reversed(range(T)):
* ``acq-func`` : Acquisition function
    * ``LCB`` :
    * ``EI`` :
    * ``PI`` :
    * ``gp_hedge`` : (default)
"""
# NOTE: the lines above are the tail of the module docstring; its opening
# falls outside this chunk.
import signal

from nas4candle.nasapi.search.hps.optimizer import Optimizer
from nas4candle.nasapi.search import Search
from nas4candle.nasapi.search import util

logger = util.conf_logger('nas4candle.nasapi.search.hps.ambs')

SERVICE_PERIOD = 2  # Delay (seconds) between main loop iterations
CHECKPOINT_INTERVAL = 10  # How many jobs to complete between optimizer checkpoints

# Module-level flag flipped by the signal handler to request a clean shutdown
# of the (truncated) main service loop.
EXIT_FLAG = False


def on_exit(signum, stack):
    """Signal handler: request a graceful exit from the main service loop."""
    global EXIT_FLAG
    EXIT_FLAG = True


class AMBS(Search):
    # Asynchronous model-based search; the rest of the class lies beyond
    # this chunk.
    def __init__(self, problem, run, evaluator, **kwargs):
        super().__init__(problem, run, evaluator, **kwargs)
        logger.info("Initializing AMBS")
# --- Asynchronous MPI Adam optimizer (reformatted from collapsed source) ---
import time
import numpy as np
import tensorflow as tf
from mpi4py import MPI
import nas4candle.nasapi.search.nas.utils.common.tf_util as U
from nas4candle.nasapi.search import util
from nas4candle.nasapi.search.nas.utils._logging import JsonMessage as jm

# MPI message tags used to coordinate parameter updates between ranks.
TAG_UPDATE_START = 1
TAG_UPDATE_DONE = 2

dh_logger = util.conf_logger(
    'nas4candle.nasapi.baselines.common.mpi_adam_async')


class MpiAdamAsync(object):
    """Adam optimizer whose updates are coordinated asynchronously over MPI.

    Only the first part of __init__ is visible in this chunk; the class
    continues beyond it.
    """

    def __init__(self, var_list, *, beta1=0.9, beta2=0.999, epsilon=1e-08,
                 scale_grad_by_procs=True, comm=None):
        # Variables this optimizer updates, plus the standard Adam
        # hyperparameters (exponential decay rates and numerical epsilon).
        self.var_list = var_list
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        # When True, gradients are presumably divided by the number of MPI
        # processes -- the code that uses this lies beyond this chunk; confirm.
        self.scale_grad_by_procs = scale_grad_by_procs