def optimize_params(self, p0, num_iters, step_size, tolerance, verbose):
    """Perform gradient descent using JAX's Adam optimizer.

    Parameters
    ----------
    p0 : pytree
        Initial model parameters.
    num_iters : int
        Maximum number of optimization steps.
    step_size : float
        Adam learning rate.
    tolerance : int
        Length of the trailing window of costs used for early stopping.
    verbose : int
        Print progress every `verbose` steps; 0 disables printing.

    Returns
    -------
    params : pytree
        Optimized parameters: the oldest iterate in the window if the cost
        was monotonically increasing (divergence), otherwise the latest.
    """
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
    opt_state = opt_init(p0)

    @jit
    def step(_i, _opt_state):
        # One Adam update on the negative log evidence.
        p = get_params(_opt_state)
        g = grad(self.negative_log_evidence)(p)
        return opt_update(_i, g, _opt_state)

    cost_list = []
    params_list = []

    if verbose:
        self.print_progress_header(p0)

    for i in range(num_iters):
        opt_state = step(i, opt_state)
        params_list.append(get_params(opt_state))
        cost_list.append(self.negative_log_evidence(params_list[-1]))

        if verbose:
            if i % verbose == 0:
                self.print_progress(i, params_list[-1], cost_list[-1])

        # Early stopping: check once a full window of `tolerance` cost
        # differences is available (window length tolerance + 1).
        if len(params_list) > tolerance:
            if jnp.all(jnp.array(cost_list[1:]) - jnp.array(cost_list[:-1]) > 0):
                # Cost diverging: fall back to the oldest params in the window.
                params = params_list[0]
                if verbose:
                    print('Stop: cost has been monotonically increasing for {} steps.'.format(tolerance))
                break
            elif jnp.all(jnp.array(cost_list[:-1]) - jnp.array(cost_list[1:]) < 1e-5):
                # Cost plateaued: keep the latest params.
                params = params_list[-1]
                if verbose:
                    # BUG FIX: message previously read "cost has been stop
                    # changing" (ungrammatical).
                    print('Stop: cost has stopped changing for {} steps.'.format(tolerance))
                break
            else:
                # Slide the window forward.
                params_list.pop(0)
                cost_list.pop(0)
    else:
        # Loop exhausted without triggering an early stop.
        params = params_list[-1]
        if verbose:
            print('Stop: reached {0} steps, final cost={1:.5f}.'.format(num_iters, cost_list[-1]))

    return params
def __init__(self, vocab_size, output_size, loss_fn, rng, temperature=1,
             learning_rate=0.001, conv_depth=300, n_conv_layers=2,
             n_dense_units=300, n_dense_layers=0, kernel_size=5,
             across_batch=False, add_pos_encoding=False, mean_over_pos=False,
             model_fn=build_model_stax):
    """Build the train/eval model pair and set up parameters and optimizer.

    `model_fn` is called twice with identical settings except for `mode`
    ("train" vs "eval"); both resulting apply functions are jit-compiled.
    """
    self.output_size = output_size
    self.temperature = temperature

    # Setup randomness.
    self.rng = rng

    settings = dict(
        output_size=output_size,
        n_dense_units=n_dense_units,
        n_dense_layers=n_dense_layers,
        conv_depth=conv_depth,
        n_conv_layers=n_conv_layers,
        across_batch=across_batch,
        kernel_size=kernel_size,
        add_pos_encoding=add_pos_encoding,
        mean_over_pos=mean_over_pos,
        mode="train",
    )
    self._model_init, train_apply = model_fn(**settings)
    self._model_train = jax.jit(train_apply)

    # Same architecture, evaluation mode.
    settings["mode"] = "eval"
    _, eval_apply = model_fn(**settings)
    self._model_predict = jax.jit(eval_apply)

    self.rng, subrng = jrand.split(self.rng)
    _, init_params = self._model_init(subrng, (-1, -1, vocab_size))
    self.params = init_params

    # Setup parameters for model and optimizer.
    self.make_state, self._opt_update_state, self._get_params = adam(
        learning_rate)

    self.loss_fn = functools.partial(loss_fn, run_model_fn=self.run_model)
    self.loss_grad_fn = jax.grad(self.loss_fn)

    # Track steps of optimization so far.
    self._step_idx = 0
def __init__(self, optimizer_name, model, lr):
    """Wrap a JAX optimizer around `model`.

    Parameters
    ----------
    optimizer_name : str
        Optimizer to use; must be listed in `jax_optimizers`.
    model : object
        Model exposing `net_params` and `loss_func(params, data)`.
    lr : float
        Learning rate (Adam step size).

    Raises
    ------
    ValueError
        If `optimizer_name` is not implemented.
    """
    if optimizer_name not in jax_optimizers:
        # BUG FIX: previously formatted the undefined name `name`, so this
        # path raised NameError instead of the intended ValueError.
        raise ValueError(
            "Optimizer {} is not implemented yet".format(optimizer_name))

    if optimizer_name == "adam":
        opt_init, opt_update, get_params = optimizers.adam(step_size=lr)
        self.opt_state = opt_init(model.net_params)
        self.get_params = get_params

        @jit
        def step(i, _opt_state, data):
            # One optimization step: gradient of the model loss, then update.
            _params = get_params(_opt_state)
            loss_func = model.loss_func
            g = jit(jax.grad(loss_func))(_params, data)
            return opt_update(i, g, _opt_state)

        self._step = step

    self.iter_cnt = 0
    return Trace_ELBO().loss(random.PRNGKey(0), {}, model, guide, x)

# NOTE(review): the `return` above closes a loss function whose `def` lies
# before this chunk; `model`, `guide`, `alpha`, and the surrounding test
# scope come from there.


def renyi_loss_fn(x):
    # Renyi ELBO estimate with a fixed PRNG key, for comparison against
    # the Trace_ELBO above.
    return RenyiELBO(alpha=alpha, num_particles=10).loss(random.PRNGKey(0), {}, model, guide, x)


# The two estimators are expected to agree in both value and gradient
# (asserted below with rtol=1e-6).
elbo_loss, elbo_grad = value_and_grad(elbo_loss_fn)(2.0)
renyi_loss, renyi_grad = value_and_grad(renyi_loss_fn)(2.0)
assert_allclose(elbo_loss, renyi_loss, rtol=1e-6)
assert_allclose(elbo_grad, renyi_grad, rtol=1e-6)


@pytest.mark.parametrize("elbo", [Trace_ELBO(), RenyiELBO(num_particles=10)])
@pytest.mark.parametrize(
    "optimizer", [optim.Adam(0.05), optimizers.adam(0.05)])
def test_beta_bernoulli(elbo, optimizer):
    # 8 successes and 2 failures.
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        # Beta(1, 1) prior over the Bernoulli success probability.
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        with numpyro.plate("N", len(data)):
            numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        # Variational Beta posterior with positively-constrained parameters.
        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
    # NOTE(review): this test function appears truncated here; the SVI run
    # presumably follows outside this chunk.
def optimize_params(self, p0, extra, num_epochs, num_iters, metric,
                    step_size, tolerance, verbose, return_model=None) -> dict:
    """Gradient descent using JAX optimizer, with verbose logging.

    Trains with Adam on `self.cost`, optionally tracking a held-out ("dev")
    set supplied via `extra` for early stopping and model selection.
    Histories are stashed on the instance (`self.cost_train`, etc.).
    Returns the selected parameter iterate according to `return_model`.
    """
    # Default model-selection rule: prefer best dev cost if a dev set exists.
    if return_model is None:
        if extra is not None:
            return_model = 'best_dev_cost'
        else:
            return_model = 'best_train_cost'
    assert (extra is not None) or (
        'dev' not in return_model), 'Cannot use dev set if dev set is not given.'
    if num_epochs != 1:
        raise NotImplementedError()

    @jit
    def step(_i, _opt_state):
        # One Adam update on self.cost. `opt_update`/`get_params` are bound
        # below, before the first call, so the closure resolves correctly.
        p = get_params(_opt_state)
        g = grad(self.cost)(p)
        return opt_update(_i, g, _opt_state)

    # preallocation
    cost_train = np.zeros(num_iters)
    cost_dev = np.zeros(num_iters)
    metric_train = np.zeros(num_iters)
    metric_dev = np.zeros(num_iters)
    params_list = []

    if verbose:
        self.print_progress_header(c_train=True, c_dev=extra,
                                   m_train=metric is not None,
                                   m_dev=metric is not None and extra)

    time_start = time.time()
    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
    opt_state = opt_init(p0)
    i = 0
    c_dev = None
    m_train = None
    m_dev = None
    y_pred_dev = None
    for i in range(num_iters):
        opt_state = step(i, opt_state)
        params = get_params(opt_state)
        params_list.append(params)
        y_pred_train = self.forwardpass(p=params, extra=None)
        c_train = self.cost(p=params, precomputed=y_pred_train)
        cost_train[i] = c_train
        if extra is not None:
            # Track held-out (dev) cost for early stopping / model selection.
            y_pred_dev = self.forwardpass(p=params, extra=extra)
            c_dev = self.cost(p=params, extra=extra, precomputed=y_pred_dev)
            cost_dev[i] = c_dev
        if metric is not None:
            m_train = self.compute_score(self.y, y_pred_train, metric)
            metric_train[i] = m_train
            if extra is not None:
                m_dev = self.compute_score(extra['y'], y_pred_dev, metric)
                metric_dev[i] = m_dev
        time_elapsed = time.time() - time_start
        if verbose and (i % int(verbose) == 0):
            self.print_progress(i, time_elapsed, c_train=c_train, c_dev=c_dev,
                                m_train=m_train, m_dev=m_dev)
        # Early stopping only after a 300-step warm-up.
        if tolerance and i > 300:  # tolerance = 0: no early stop.
            total_time_elapsed = time.time() - time_start
            cost_train_slice = cost_train[i - tolerance:i]
            cost_dev_slice = cost_dev[i - tolerance:i]
            if jnp.all(cost_dev_slice[1:] - cost_dev_slice[:-1] > 0):
                # Dev cost rising for `tolerance` consecutive steps.
                stop = 'dev_stop'
                if verbose:
                    print(
                        f'Stop at {i} steps: ' +
                        f'cost (dev) has been monotonically increasing for {tolerance} steps.\n'
                    )
                break
            if jnp.all(
                    cost_train_slice[:-1] - cost_train_slice[1:] < 1e-5):
                # Training cost plateaued.
                stop = 'train_stop'
                if verbose:
                    print(
                        f'Stop at {i} steps: ' +
                        f'cost (train) has been changing less than 1e-5 for {tolerance} steps.\n'
                    )
                break
    else:
        # Ran all iterations without an early stop (for-else).
        total_time_elapsed = time.time() - time_start
        stop = 'maxiter_stop'
        if verbose:
            print('Stop: reached {0} steps.\n'.format(num_iters))

    # Select which iterate to return.
    if return_model == 'best_dev_cost':
        best = np.argmin(cost_dev[:i + 1])
    elif return_model == 'best_train_cost':
        best = np.argmin(cost_train[:i + 1])
    elif return_model == 'best_dev_metric':
        # For error-like metrics lower is better; otherwise higher is better.
        if metric in ['mse', 'gcv']:
            best = np.argmin(metric_dev[:i + 1])
        else:
            best = np.argmax(metric_dev[:i + 1])
    elif return_model == 'best_train_metric':
        if metric in ['mse', 'gcv']:
            best = np.argmin(metric_train[:i + 1])
        else:
            best = np.argmax(metric_train[:i + 1])
    else:
        if return_model != 'last':
            print(
                'Provided `return_model` is not supported. Fallback to `last`'
            )
        if stop == 'dev_stop':
            # Back off to before the dev cost started increasing.
            best = i - tolerance
        else:
            best = i

    params = params_list[best]
    metric_dev_opt = metric_dev[best]

    # Stash training history on the instance.
    self.cost_train = cost_train[:i + 1]
    self.cost_dev = cost_dev[:i + 1]
    self.metric_train = metric_train[:i + 1]
    self.metric_dev = metric_dev[:i + 1]
    self.metric_dev_opt = metric_dev_opt
    self.total_time_elapsed = total_time_elapsed
    return params
def initialize_parametric_nonlinearity(self, init_to='exponential',
                                       method=None, params_dict=None):
    """Fit a parametric nonlinearity (spline basis or small NN) to a target curve.

    The target curve `y0 = f(x0)` over `x0 in [-xrange, xrange]` is chosen by
    `init_to`. The fitted nonlinearity is stored as `self.fnl_fitted`, with
    its parameters in `self.nl_params` and the grid in `self.nl_xrange`.
    """
    if method is None:
        # if no methods specified, use defaults.
        method = self.output_nonlinearity or self.filter_nonlinearity
    else:
        # otherwise, overwrite the default nonlinearity.
        self.output_nonlinearity = method
        if self.filter_nonlinearity is not None:
            self.filter_nonlinearity = method
    assert method is not None

    # prepare data
    if params_dict is None:
        params_dict = {}
    xrange = params_dict['xrange'] if 'xrange' in params_dict else 5
    nx = params_dict['nx'] if 'nx' in params_dict else 1000
    x0 = jnp.linspace(-xrange, xrange, nx)
    if init_to == 'exponential':
        y0 = jnp.exp(x0)
    elif init_to == 'softplus':
        y0 = softplus(x0)
    elif init_to == 'relu':
        y0 = relu(x0)
    elif init_to == 'nonparametric':
        y0 = self.fnl_nonparametric(x0)
    elif init_to == 'gaussian':
        import scipy.signal
        # noinspection PyUnresolvedReferences
        y0 = scipy.signal.gaussian(nx, nx / 10)
    else:
        raise NotImplementedError(init_to)

    # fit nonlin
    if method == 'spline':
        smooth = params_dict['smooth'] if 'smooth' in params_dict else 'cr'
        df = params_dict['df'] if 'df' in params_dict else 7
        if smooth == 'cr':
            X = cr(x0, df)
        elif smooth == 'cc':
            X = cc(x0, df)
        elif smooth == 'bs':
            deg = params_dict['degree'] if 'degree' in params_dict else 3
            X = bs(x0, df, deg)
        else:
            raise NotImplementedError(smooth)
        # Least-squares fit of spline basis coefficients to the target curve.
        opt_params = jnp.linalg.pinv(X.T @ X) @ X.T @ y0
        self.nl_basis = X

        def _nl(_opt_params, x_new):
            # Interpolate the fitted curve and clamp it to be non-negative.
            return jnp.maximum(interp1d(x0, X @ _opt_params)(x_new), 0)
    elif method == 'nn':
        def loss(_params, _data):
            # Mean squared error between the NN output and the target curve.
            x = _data['x']
            y = _data['y']
            yhat = _predict(_params, x)
            return jnp.mean((y - yhat)**2)

        @jit
        def step(_i, _opt_state, _data):
            # One Adam update on the NN fit loss.
            p = get_params(_opt_state)
            g = grad(loss)(p, _data)
            return opt_update(_i, g, _opt_state)

        random_seed = params_dict[
            'random_seed'] if 'random_seed' in params_dict else 2046
        key = random.PRNGKey(random_seed)
        step_size = params_dict[
            'step_size'] if 'step_size' in params_dict else 0.01
        layer_sizes = params_dict[
            'layer_sizes'] if 'layer_sizes' in params_dict else [
                10, 10, 1
            ]
        layers = []
        for layer_size in layer_sizes:
            layers.append(Dense(layer_size))
            layers.append(BatchNorm(axis=(0, 1)))
            layers.append(Relu)
        else:
            # NOTE(review): this for-else always runs (no break in the loop);
            # it pops the trailing Relu so the network does not end with a
            # rectification. Verify this is intentional.
            layers.pop(-1)
        init_random_params, _predict = stax.serial(*layers)
        num_subunits = params_dict[
            'num_subunits'] if 'num_subunits' in params_dict else 1
        _, init_params = init_random_params(key, (-1, num_subunits))
        opt_init, opt_update, get_params = optimizers.adam(step_size)
        opt_state = opt_init(init_params)
        num_iters = params_dict[
            'num_iters'] if 'num_iters' in params_dict else 1000
        if num_subunits == 1:
            data = {'x': x0.reshape(-1, 1), 'y': y0.reshape(-1, 1)}
        else:
            # Tile the grid across subunits so each gets the same input.
            data = {
                'x': jnp.vstack([x0 for _ in range(num_subunits)]).T,
                'y': y0.reshape(-1, 1)
            }
        for i in range(num_iters):
            opt_state = step(i, opt_state, data)
        opt_params = get_params(opt_state)

        def _nl(_opt_params, x_new):
            # Promote 1-D inputs to a column; clamp output to be non-negative.
            if len(x_new.shape) == 1:
                x_new = x_new.reshape(-1, 1)
            return jnp.maximum(_predict(_opt_params, x_new), 0)
    else:
        raise NotImplementedError(method)

    self.nl_xrange = x0
    self.nl_params = opt_params
    self.fnl_fitted = _nl
class OptimizersEquivalenceTest(chex.TestCase):
    """Checks that optax optimizers match jax.example_libraries.optimizers."""

    def setUp(self):
        super().setUp()
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4., 5.]))
        self.per_step_updates = (jnp.array([500., 5.]),
                                 jnp.array([300., 3., 1.]))

    @chex.all_variants()
    @parameterized.named_parameters(
        # BUG FIX: `named_parameters` requires unique case names; the
        # schedule variants previously reused the constant-LR names, which
        # absl-py rejects as duplicate test names. They are now suffixed
        # with `_sched`.
        ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
         optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(LR, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum',
         alias.rmsprop(LR, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad', alias.adagrad(LR, 0., 0.,),
         optimizers.adagrad(LR, 0.), 1e-5),
        ('sgd_sched', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam_sched', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8),
         optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop_sched', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum_sched',
         alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad_sched', alias.adagrad(LR_SCHED, 0., 0.,),
         optimizers.adagrad(LR, 0.), 1e-5),
        ('sm3', alias.sm3(LR), optimizers.sm3(LR), 1e-2),
    )
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                      rtol):
        """Run STEPS updates with both libraries and compare final params."""
        # example_libraries/optimizers.py
        jax_params = self.init_params
        opt_init, opt_update, get_params = jax_optimizer
        state = opt_init(jax_params)
        for i in range(STEPS):
            state = opt_update(i, self.per_step_updates, state)
        jax_params = get_params(state)

        # optax
        optax_params = self.init_params
        state = optax_optimizer.init(optax_params)

        @self.variant
        def step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            updates, state = step(self.per_step_updates, state)
            optax_params = update.apply_updates(optax_params, updates)

        # Check equivalence.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
def main():
    """Train the autoregressive network by minimizing log_q + beta * energy.

    Resumes from the latest checkpoint if one exists; logs progress and
    saves parameters at the configured intervals.
    """
    start_time = time()
    init_out_dir()
    last_step = get_last_ckpt_step()
    if last_step >= 0:
        my_log(f'\nCheckpoint found: {last_step}\n')
    else:
        clear_log()
    print_args()

    net_init, net_apply, net_init_cache, net_apply_fast = get_net()

    rng, rng_net = jrand.split(jrand.PRNGKey(args.seed))
    in_shape = (args.batch_size, args.L, args.L, 1)
    out_shape, params_init = net_init(rng_net, in_shape)
    _, cache_init = net_init_cache(params_init, jnp.zeros(in_shape), (-1, -1))

    # sample_fun = get_sample_fun(net_apply, None)
    sample_fun = get_sample_fun(net_apply_fast, cache_init)
    log_q_fun = get_log_q_fun(net_apply)

    need_beta_anneal = args.beta_anneal_step > 0

    opt_init, opt_update, get_params = optimizers.adam(args.lr)

    @jit
    def update(step, opt_state, rng):
        # One training step: sample spins, estimate the gradient of the
        # expected (log_q + beta * energy), apply an Adam update.
        params = get_params(opt_state)
        rng, rng_sample = jrand.split(rng)
        spins = sample_fun(args.batch_size, params, rng_sample)
        log_q = log_q_fun(params, spins) / args.L**2
        energy = energy_fun(spins) / args.L**2

        def neg_log_Z_fun(params, spins):
            log_q = log_q_fun(params, spins) / args.L**2
            energy = energy_fun(spins) / args.L**2
            beta = args.beta
            if need_beta_anneal:
                # Linearly ramp beta up to its target over beta_anneal_step steps.
                beta *= jnp.minimum(step / args.beta_anneal_step, 1)
            neg_log_Z = log_q + beta * energy
            return neg_log_Z

        loss_fun = partial(expect, log_q_fun, neg_log_Z_fun,
                           mean_grad_expected_is_zero=True)
        grads = grad(loss_fun)(params, spins, spins)
        opt_state = opt_update(step, grads, opt_state)

        return spins, log_q, energy, opt_state, rng

    if last_step >= 0:
        # Resume: overwrite the fresh init with the checkpointed parameters.
        params_init = load_ckpt(last_step)
    opt_state = opt_init(params_init)

    my_log('Training...')
    for step in range(last_step + 1, args.max_step + 1):
        spins, log_q, energy, opt_state, rng = update(step, opt_state, rng)

        if args.print_step and step % args.print_step == 0:
            # Use the final beta, not the annealed beta
            free_energy = log_q / args.beta + energy
            my_log(', '.join([
                f'step = {step}',
                f'F = {free_energy.mean():.8g}',
                f'F_std = {free_energy.std():.8g}',
                f'S = {-log_q.mean():.8g}',
                f'E = {energy.mean():.8g}',
                f'time = {time() - start_time:.3f}',
            ]))

        if args.save_step and step % args.save_step == 0:
            params = get_params(opt_state)
            save_ckpt(params, step)
    return q_state_out

# NOTE(review): the `return` above closes a circuit function whose `def`
# lies before this chunk.


def cost(params):
    """Negative squared overlap of the circuit output with the target state."""
    state = circuit(params)
    # Target amplitudes [0, 0.707, 0.707, 0] -- presumably the Bell state
    # (|01> + |10>)/sqrt(2); confirm basis ordering against `circuit`.
    op_qs = jnp.array([[0], [0.707], [0.707], [0]])
    fid = jnp.abs(jnp.dot(jnp.transpose(jnp.conjugate(op_qs)), state))**2
    # Negate so that minimizing the cost maximizes fidelity.
    return -jnp.real(fid)[0][0]


# fixed random parameter initialization
init_params = [
    2 * jnp.pi * np.random.rand(),
    2 * jnp.pi * np.random.rand(),
    2 * jnp.pi * np.random.rand()
]

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)
opt_state = opt_init(init_params)


def step(i, opt_state, opt_update):
    """One Adam step on `cost`; returns the new optimizer state."""
    params = get_params(opt_state)
    g = grad(cost)(params)
    return opt_update(i, g, opt_state)


# Optimization bookkeeping; the training loop presumably follows this chunk.
epoch = 0
epoch_max = 250
loss_hist = []
loss = cost(init_params)
def __init__(self, n=2**7 - 1, rng=None, channels=8, loss=loss_gmres,
             iter_gmres=lambda i: 10, training_iter=500, name='net',
             model_dir=None, lr=3e-4, k=0.0, n_test=10, beta1=0.9,
             beta2=0.999, lr_og=3e-3, flaxd=False):
    """Set up the learned-preconditioner network (stax or flax backend).

    `flaxd` selects between the stax U-Net-style implementation and the
    flax CNN implementation; all hyperparameters are stored on the instance.
    """
    self.n = n
    self.n_test = n_test
    self.mesh = meshes.Mesh(n)
    self.in_shape = (-1, n, n, 1)
    self.inner_channels = channels

    def itera(i):
        # Randomized GMRES iteration count per training step.
        return onp.random.choice([5, 10, 10, 10, 10, 15, 15, 15, 20, 25])

    # NOTE(review): the `iter_gmres` argument is ignored and replaced by the
    # randomized schedule above -- confirm this is intentional.
    self.iter_gmres = itera
    self.training_iter = training_iter
    self.name = name
    self.k = k
    self.model_dir = model_dir
    if flaxd:
        self.test_loss = loss_gmresR_flax
    else:
        self.test_loss = loss_gmresR
    self.beta1 = beta1
    self.beta2 = beta2
    if rng is None:
        rng = random.PRNGKey(1)
    if not flaxd:
        # Nested encoder-style network of unbiased convolutions (stax),
        # with two strided UNetBlock levels.
        self.net_init, self.net_apply = stax.serial(
            UNetBlock(1, (3, 3), stax.serial(
                UnbiasedConv(self.inner_channels, (3, 3), padding='SAME'),
                UnbiasedConv(self.inner_channels, (3, 3), padding='SAME'),
                UNetBlock(self.inner_channels, (3, 3), stax.serial(
                    UnbiasedConv(self.inner_channels, (3, 3), padding='SAME'),
                    UnbiasedConv(self.inner_channels, (3, 3), padding='SAME'),
                    UnbiasedConv(self.inner_channels, (3, 3), padding='SAME'),
                ),
                          strides=(2, 2), padding='VALID'),
                UnbiasedConv(self.inner_channels, (3, 3), padding='SAME'),
                UnbiasedConv(self.inner_channels, (3, 3), padding='SAME'),
            ),
                      strides=(2, 2), padding='VALID'),
        )
        out_shape, net_params = self.net_init(rng, self.in_shape)
    else:
        #import pdb;pdb.set_trace()
        model_def = flax_cnn.new_CNN.partial(
            inner_channels=self.inner_channels)
        out_shape, net_params = model_def.init_by_shape(
            rng, [(self.in_shape, np.float32)])
        self.model_def = model_def
        self.model = nn.Model(model_def, net_params)
        self.net_apply = lambda param, x: nn.Model(model_def, param)(
            x)  #.reshape(self.in_shape))
    self.out_shape = out_shape
    self.net_params = net_params
    self.loss = loss
    self.lr_og = lr_og
    self.lr = lr
    if not flaxd:
        # Warm-up schedule: higher rate `lr_og` for the first 100 steps.
        self.opt_init, self.opt_update, self.get_params = optimizers.adam(
            step_size=lambda i: np.where(i < 100, lr_og, lr),
            b1=beta1, b2=beta2)
        self.opt_state = self.opt_init(self.net_params)
    self.step = self.step_notflax
    if flaxd:
        self.step = self.step_flax
        self.optimizer = flax.optim.Adam(learning_rate=lr, beta1=beta1,
                                         beta2=beta2).create(self.model)
        #self.optimizer = flax.optim.Momentum(
        #    learning_rate= lr, beta=beta1,
        #    weight_decay=0, nesterov=False).create(self.model)
    self.alpha = lambda i: 0.0
    self.flaxd = flaxd
    if flaxd:
        self.preconditioner = self.preconditioner_flaxed
    else:
        self.preconditioner = self.preconditioner_unflaxed
def optimize(self, p0, num_iters, metric, step_size, tolerance, verbose,
             return_model):
    """Workhorse of optimization.

    p0: dict
        A dictionary of the initial model parameters to be optimized.
    num_iters: int
        Maximum number of iterations.
    metric: str
        Method of model evaluation. Can be `mse`, `corrcoeff`, `r2`.
    step_size: float or jax scheduler
        Learning rate.
    tolerance: int
        Tolerance for early stop. If the training cost doesn't change more
        than 1e-5 in the last (tolerance) steps, or the dev cost
        monotonically increases, stop.
    verbose: int
        Print progress. If verbose=0, no progress will be printed.
    return_model: str
        Return the 'best' model on dev set metrics or the 'last' model.
    """
    # Default selection rule: prefer best dev cost when a dev split exists.
    if return_model is None:
        if 'dev' in self.y:
            return_model = 'best_dev_cost'
        else:
            return_model = 'best_train_cost'
    assert ('dev' in self.y) or (
        'dev' not in return_model), 'Cannot use dev set if dev set is not given.'

    @jit
    def step(_i, _opt_state):
        # One Adam step; returns (training cost, new optimizer state).
        # `opt_update`/`get_params` are bound below, before the first call.
        p = get_params(_opt_state)
        l, g = value_and_grad(self.cost)(p)
        return l, opt_update(_i, g, _opt_state)

    opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
    opt_state = opt_init(p0)

    # NaN-filled history buffers; only the first i+1 entries are used below.
    cost_train = np.full(num_iters, np.nan)
    cost_dev = np.full(num_iters, np.nan)
    metric_train = np.full(num_iters, np.nan)
    metric_dev = np.full(num_iters, np.nan)
    params_list = []

    extra = 'dev' in self.y
    if verbose:
        self.print_progress_header(c_train=True, c_dev=extra,
                                   m_train=metric is not None,
                                   m_dev=metric is not None and extra)

    time_start = time.time()
    i = 0
    for i in range(num_iters):
        cost_train[i], opt_state = step(i, opt_state)
        params = get_params(opt_state)
        params_list.append(params)
        y_pred_train = self.forwardpass(p=params, kind='train')
        metric_train[i] = self.compute_score(self.y['train'], y_pred_train,
                                             metric)
        if 'dev' in self.y:
            # Held-out cost (no penalty) and metric for model selection.
            y_pred_dev = self.forwardpass(p=params, kind='dev')
            cost_dev[i] = self.cost(p=params, kind='dev',
                                    precomputed=y_pred_dev, penalize=False)
            metric_dev[i] = self.compute_score(self.y['dev'], y_pred_dev,
                                               metric)
        time_elapsed = time.time() - time_start
        if verbose:
            if i % int(verbose) == 0:
                self.print_progress(i, time_elapsed, c_train=cost_train[i],
                                    c_dev=cost_dev[i],
                                    m_train=metric_train[i],
                                    m_dev=metric_dev[i])
        # Early stopping only after a 300-step warm-up.
        if tolerance and i > 300:  # tolerance = 0: no early stop.
            total_time_elapsed = time.time() - time_start
            if 'dev' in self.y and np.all(
                    np.diff(cost_dev[i - tolerance:i]) > 0):
                # Dev cost rising for `tolerance` consecutive steps.
                stop = 'dev_stop'
                if verbose:
                    print(
                        'Stop at {0} steps: cost (dev) has been monotonically increasing for {1} steps.'
                        .format(i, tolerance))
                    print('Total time elapsed: {0:.3f}s.\n'.format(
                        total_time_elapsed))
                break
            if np.all(np.diff(cost_train[i - tolerance:i]) < 1e-5):
                # Training cost plateaued.
                stop = 'train_stop'
                if verbose:
                    print(
                        'Stop at {0} steps: cost (train) has been changing less than 1e-5 for {1} steps.'
                        .format(i, tolerance))
                    print('Total time elapsed: {0:.3f}s.\n'.format(
                        total_time_elapsed))
                break
    else:
        # Ran all iterations without an early stop (for-else).
        total_time_elapsed = time.time() - time_start
        stop = 'maxiter_stop'
        if verbose:
            print('Stop: reached {0} steps.'.format(num_iters))
            print('Total time elapsed: {0:.3f}s.\n'.format(
                total_time_elapsed))

    # Pick which iterate to return according to `return_model`.
    if return_model == 'best_dev_cost':
        best = np.argmin(cost_dev[:i + 1])
    elif return_model == 'best_train_cost':
        best = np.argmin(cost_train[:i + 1])
    elif return_model == 'best_dev_metric':
        # For error-like metrics lower is better; otherwise higher is better.
        if metric in ['mse', 'gcv']:
            best = np.argmin(metric_dev[:i + 1])
        else:
            best = np.argmax(metric_dev[:i + 1])
    elif return_model == 'best_train_metric':
        if metric in ['mse', 'gcv']:
            best = np.argmin(metric_train[:i + 1])
        else:
            best = np.argmax(metric_train[:i + 1])
    else:
        if return_model != 'last':
            print(
                'Provided `return_model` is not supported. Fallback to `last`'
            )
        if stop == 'dev_stop':
            # Back off to before the dev cost started increasing.
            best = i - tolerance
        else:
            best = i

    params = params_list[best]
    metric_dev_opt = metric_dev[best]

    # Stash training history on the instance.
    self.cost_train = cost_train[:i + 1]
    self.cost_dev = cost_dev[:i + 1]
    self.metric_train = metric_train[:i + 1]
    self.metric_dev = metric_dev[:i + 1]
    self.metric_dev_opt = metric_dev_opt
    self.total_time_elapsed = total_time_elapsed
    self.all_params = params_list[:i + 1]  # not sure if this will occupy a lot of RAM.
    self.y_pred['opt'].update({'train': y_pred_train})
    if 'dev' in self.y:
        self.y_pred['opt'].update({'dev': y_pred_dev})
    return params