def shuffle(key, Xa, Xb):
    '''Randomly and independently shuffle the examples in Xa and Xb along the zeroth axis.

    Args:
        key: random PRNGKey
        Xa: (P,N) first array to shuffle
        Xb: (P,N) second array to shuffle

    Returns:
        Xaperm: (P,N) shuffled copy of Xa
        Xbperm: (P,N) shuffled copy of Xb
    '''
    keya, keyb = random.split(key)
    perma = random.shuffle(keya, np.arange(len(Xa)))
    permb = random.shuffle(keyb, np.arange(len(Xb)))

    return Xa[perma], Xb[permb]
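A minimal usage sketch for the helper above, under the same `jax.numpy as np` / `jax.random as random` imports; note that recent JAX versions have removed `random.shuffle` in favour of `random.permutation` (see the FutureWarning test in Example #14 below), so this assumes a version where it is still available.

import jax.numpy as np
from jax import random

key = random.PRNGKey(0)
Xa = np.arange(12.0).reshape(4, 3)        # (P, N) = (4, 3)
Xb = np.arange(12.0, 24.0).reshape(4, 3)
Xa_perm, Xb_perm = shuffle(key, Xa, Xb)   # each array gets its own independent permutation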
Example #2
    def init_fun(self, rng, input_shape, permutation=None):
        """
        :param rng: rng used to initialize parameters
        :param input_shape: input shape
        :param permutation: an optional permutation that is applied to the inputs and controls the order of the
            autoregressive factorization. In particular, for the identity permutation the autoregressive
            structure is such that the Jacobian is triangular. By default this is chosen at random.
        :type permutation: array of ints
        """
        if permutation is None:
            # By default set a random permutation of variables, which is important for performance with multiple steps
            rng, rng_perm = random.split(rng)
            self.permutation = onp.array(random.shuffle(rng_perm, np.arange(self.input_dim)))
        else:
            self.permutation = permutation

        # Create masks (no skip connections allowed; TODO add support)
        masks, _ = create_mask(input_dim=self.input_dim, hidden_dims=self.hidden_dims,
                               permutation=self.permutation,
                               output_dim_multiplier=self.output_multiplier)

        # Create masked layers
        self.masked_layers = [MaskedDense(mask) for mask in masks]
        init_params = []
        for masked_layer in self.masked_layers:
            # each MaskedDense layer is a stax-style (init_fun, apply_fun) pair
            layer_init = masked_layer[0]
            input_shape, param = layer_init(rng, input_shape)
            init_params.append(param)

        return input_shape, init_params
Example #3
def test_flows(flow_class, flow_args, input_dim, batch_shape):
    transform = flow_class(*flow_args)
    x = random.normal(random.PRNGKey(0), batch_shape + (input_dim, ))

    # test inverse is correct
    y = transform(x)
    inv = transform.inv(y)
    assert_allclose(x, inv, atol=1e-5)

    # test jacobian shape
    actual = transform.log_abs_det_jacobian(x, y)
    assert onp.shape(actual) == batch_shape

    if batch_shape == ():
        # make sure transform.log_abs_det_jacobian is correct
        jac = jacfwd(transform)(x)
        expected = onp.linalg.slogdet(jac)[1]
        assert_allclose(actual, expected, atol=1e-5)

        # make sure jacobian is triangular, first permute jacobian as necessary
        if isinstance(transform, InverseAutoregressiveTransform):
            permuted_jac = onp.zeros(jac.shape)
            _, rng_key_perm = random.split(random.PRNGKey(0))
            perm = random.shuffle(rng_key_perm, onp.arange(input_dim))

            for j in range(input_dim):
                for k in range(input_dim):
                    permuted_jac[j, k] = jac[perm[j], perm[k]]

            assert onp.sum(onp.abs(onp.triu(permuted_jac, 1))) == 0.00
Example #4
 def shuffle(key_and_data):
     key, data = key_and_data
     key, subkey = random.split(key)
     datapoints_per_device = data[0].shape[0]
     indices = np.arange(datapoints_per_device)
     perm = random.shuffle(subkey, indices)
     return key, list(map(lambda x: x[perm], data)), 0
Example #5
def _find_binning_thresholds(data,
                             max_bins=256,
                             subsample=200000,
                             random_state=None):
    if 2 > max_bins or max_bins > 256:
        raise ValueError(f'max_bins={max_bins} should be no smaller than 2 '
                         f'and no larger than 256.')
    if random_state is None:
        random_state = int(time.time())
    rng = random.PRNGKey(random_state)
    if subsample is not None and data.shape[0] > subsample:
        subset = random.shuffle(rng, np.arange(data.shape[0]))[:subsample]
        data = data[subset]
    dtype = data.dtype
    if dtype.kind != 'f':
        dtype = np.float32

    percentiles = np.linspace(0, 100, num=max_bins + 1)[1:-1]
    binning_thresholds = []
    for f_idx in range(data.shape[1]):
        col_data = np.array(data[:, f_idx], dtype=dtype, order='C')
        distinct_values = onp.unique(col_data)
        if len(distinct_values) <= max_bins:
            midpoints = (distinct_values[:-1] + distinct_values[1:])
            midpoints *= 0.5
        else:
            midpoints = np.percentile(col_data,
                                      percentiles,
                                      interpolation='midpoint').astype(dtype)
        binning_thresholds.append(midpoints)
    return tuple(binning_thresholds)
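A hypothetical usage sketch for the binning helper above, assuming the same `jax.numpy as np` / `numpy as onp` imports and a JAX version contemporary with the snippet (one that still provides `random.shuffle` and the `interpolation=` keyword of `np.percentile`); the data here is illustrative.

import numpy as onp

data = onp.random.RandomState(0).rand(1000, 3)          # 1000 samples, 3 features
thresholds = _find_binning_thresholds(data, max_bins=16,
                                       subsample=500, random_state=0)
assert len(thresholds) == data.shape[1]                  # one array of bin edges per feature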
Example #6
    def run_epoch(rng, _opt_state, epoch_idx):
        _rng, dat_keys = utils.keygen(rng, 1)
        _rng, batch_keys = utils.keygen(_rng, num_batches)

        # Randomize epoch data.
        epoch_data = random.shuffle(next(dat_keys), X_train, axis=0)

        def update(batch_idx, __opt_state):
            """Update func for gradients, includes gradient clipping."""
            kl_warmup = kl_warmup_fun(epoch_idx * num_batches + batch_idx)

            batch_data = lax.dynamic_slice_in_dim(epoch_data,
                                                  batch_idx * BATCH_SIZE,
                                                  BATCH_SIZE,
                                                  axis=0)
            batch_data = batch_data.astype(np.float32)

            params = get_params(__opt_state)
            grads = grad(loss_fn)(params, batch_data, next(batch_keys),
                                  BATCH_SIZE, ic_prior, VAR_MIN, kl_warmup,
                                  L2_REG)
            clipped_grads = optimizers.clip_grads(grads, MAX_GRAD_NORM)

            return opt_update(batch_idx, clipped_grads, __opt_state)

        return lax.fori_loop(0, num_batches, update, _opt_state)
Example #7
    def test_permutation_invariance(self):

        num_nodes = 4
        num_features = 2
        rng = random.PRNGKey(0)

        # Generate random graph.
        adjacency = random.randint(rng, (num_nodes, num_nodes), 0, 2)
        node_feats = random.normal(rng, (num_nodes, num_features))
        sources, targets = jnp.where(adjacency)

        # Get permuted graph.
        perm = random.shuffle(rng, jnp.arange(num_nodes))
        node_feats_perm = node_feats[perm]
        adjacency_perm = adjacency[perm]
        for j in range(len(adjacency)):
            adjacency_perm = jax.ops.index_update(adjacency_perm, j,
                                                  adjacency_perm[j][perm])
        sources_perm, targets_perm = jnp.where(adjacency_perm)

        # Create GNN.
        _, initial_params = GNN.init(rng,
                                     node_x=node_feats,
                                     edge_x=None,
                                     sources=sources,
                                     targets=targets)
        model = nn.Model(GNN, initial_params)

        # Feedforward both original and permuted graph.
        logits = model(node_feats, None, sources, targets)
        logits_perm = model(node_feats_perm, None, sources_perm, targets_perm)

        self.assertAllClose(logits[perm], logits_perm, check_dtypes=False)
Example #8
 def shuffle(key_and_data):
     key, data = key_and_data
     key, subkey = random.split(key)
     datapoints_per_device = data[0].shape[0]
     indices = np.arange(datapoints_per_device)
     perm = random.shuffle(subkey, indices)
     return key, [x[perm] for x in data], 0
Example #9
def get_data(functions, ranges, num_samples=50):
    import random
    random.seed(37)
    onp.random.seed(37)
    random.shuffle(functions)
    X = []
    Y = []
    for i, func in enumerate(functions):
        Xs = list(
            onp.random.uniform(ranges[i], ranges[i + 1], size=num_samples))
        Ys = list(func(Xs) + onp.random.normal(scale=0.3, size=len(Xs)))
        X.append(Xs)
        Y.append(Ys)
    X = np.array(X).reshape(-1, 1)
    Y = np.array(Y).reshape(-1, 1)

    return X, Y, X
Example #10
def _make_iaf_args(input_dim, hidden_dims):
    _, rng_key_perm = random.split(random.PRNGKey(0))
    perm = random.shuffle(rng_key_perm, onp.arange(input_dim))
    arn_init, arn = AutoregressiveNN(input_dim,
                                     hidden_dims,
                                     param_dims=[1, 1],
                                     permutation=perm)
    _, init_params = arn_init(random.PRNGKey(0), (input_dim, ))
    return partial(arn, init_params),
Example #11
def _make_iaf_args(input_dim, hidden_dims):
    _, rng_perm = random.split(random.PRNGKey(0))
    perm = random.shuffle(rng_perm, onp.arange(input_dim))
    # we use the Elu nonlinearity because the default one, Relu, masks out negative hidden values,
    # which in turn creates some zero entries in the lower triangular part of the Jacobian.
    arn_init, arn = AutoregressiveNN(input_dim,
                                     hidden_dims,
                                     param_dims=[1, 1],
                                     permutation=perm,
                                     nonlinearity=stax.Elu)
    _, init_params = arn_init(random.PRNGKey(0), (input_dim, ))
    return partial(arn, init_params),
Example #12
  def testShuffle(self, dtype):
    key = random.PRNGKey(0)
    x = onp.arange(100).astype(dtype)
    rand = lambda key: random.shuffle(key, x)
    crand = api.jit(rand)

    perm1 = rand(key)
    perm2 = crand(key)

    self.assertAllClose(perm1, perm2, check_dtypes=True)
    self.assertFalse(onp.all(perm1 == x))  # seems unlikely!
    self.assertAllClose(onp.sort(perm1), x, check_dtypes=False)
Example #13
    def testShuffle(self, dtype):
        key = random.PRNGKey(0)
        x = onp.arange(100).astype(dtype)
        rand = lambda key: random.shuffle(key, x)
        crand = api.jit(rand)

        perm1 = rand(key)
        perm2 = crand(key)

        self.assertTrue(onp.all(perm1 == perm2))
        self.assertTrue(onp.all(perm1.dtype == perm2.dtype))
        self.assertFalse(onp.all(perm1 == x))  # seems unlikely!
        self.assertTrue(onp.all(onp.sort(perm1) == x))
Example #14
    def testShuffle(self, dtype):
        key = random.PRNGKey(0)
        x = np.arange(100).astype(dtype)
        rand = lambda key: random.shuffle(key, x)
        crand = api.jit(rand)

        with self.assertWarns(FutureWarning):
            perm1 = rand(key)
        with self.assertWarns(FutureWarning):
            perm2 = crand(key)

        self.assertAllClose(perm1, perm2)
        self.assertFalse(np.all(perm1 == x))  # seems unlikely!
        self.assertAllClose(np.sort(perm1), x, check_dtypes=False)
Example #15
def minibatcher(data, batch_size, transform=None, seed=0):
    key = random.PRNGKey(seed)
    size = data.X.shape[0]
    indices = np.arange(size, dtype=np.int32)
    num_batches = size // batch_size

    while True:
        key, subkey = random.split(key)
        perm = random.shuffle(subkey, indices)  # use the fresh subkey rather than reusing key
        for i in range(num_batches):
            batch_ids = perm[i * batch_size:(i + 1) * batch_size]
            b = data._replace(X=data.X[batch_ids], Y=data.Y[batch_ids])
            if transform:
                key, subkey = random.split(key)
                b = transform(b, subkey)
            yield b
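A hypothetical usage sketch for the generator above; `Dataset` is an assumed namedtuple standing in for whatever record type supplies the `X`, `Y`, and `_replace` attributes the code relies on.

from collections import namedtuple
import jax.numpy as np

Dataset = namedtuple('Dataset', ['X', 'Y'])
data = Dataset(X=np.ones((100, 4)), Y=np.zeros((100,)))

batches = minibatcher(data, batch_size=10)
first_batch = next(batches)                 # a Dataset holding one shuffled batch
assert first_batch.X.shape == (10, 4)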
Example #16
def shuffle(key, tensors, axis=0):
    """Shuffles the contents of tensors in unison.

    Args:
        key: Pseudo-random generator state.
        tensors: Iterable of tensors.
        axis: Optional, axis along which to shuffle (default 0).

    Returns:
        List of shuffled tensors.

    Raises:
        ValueError: If the shapes of the tensors do not match along `axis`.
    """
    a = mo.size(tensors, axis=axis)
    p = random.shuffle(key, np.arange(a))

    return [np.take(tsr, p, axis=axis) for tsr in tensors]
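A hypothetical usage sketch for the utility above; `mo` is assumed to be whatever size-checking helper module the original file imports, and the arrays are illustrative.

import jax.numpy as np
from jax import random

x = np.arange(6).reshape(3, 2)
y = np.arange(3)
x_shuf, y_shuf = shuffle(random.PRNGKey(0), [x, y])   # the same permutation is applied to both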
Example #17
def mini_batch(x_train, y_train, batch_size, train_epochs):
    # epoch = 0
    start = 0
    key = random.PRNGKey(0)

    while True:
        end = start + batch_size

        if end > x_train.shape[0]:
            key, split = random.split(key)
            permutation = random.shuffle(
                split, np.arange(x_train.shape[0], dtype=np.int64))
            x_train = x_train[permutation]
            y_train = y_train[permutation]
            # epoch += 1
            start = 0
            # print(epoch)
            continue
        yield x_train[start:end], y_train[start:end]
        start = start + batch_size
Example #18
def test_auto_reg_nn(input_dim, hidden_dims, param_dims, skip_connections):
    arn_init, arn = AutoregressiveNN(input_dim,
                                     hidden_dims,
                                     param_dims=param_dims,
                                     skip_connections=skip_connections)

    rng = random.PRNGKey(0)
    batch_size = 4
    input_shape = (batch_size, input_dim)
    _, init_params = arn_init(rng, input_shape)

    output = arn(init_params, onp.random.rand(*input_shape))

    if param_dims == [1]:
        assert output.shape == (batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x))(onp.random.rand(input_dim))
    elif param_dims == [1, 1]:
        assert output[0].shape == (batch_size, input_dim)
        assert output[1].shape == (batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x)[0])(
            onp.random.rand(input_dim))
    elif param_dims == [2]:
        assert output.shape == (2, batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x))(onp.random.rand(input_dim))
    elif param_dims == [2, 3]:
        assert output[0].shape == (2, batch_size, input_dim)
        assert output[1].shape == (3, batch_size, input_dim)
        jac = jacfwd(lambda x: arn(init_params, x)[0])(
            onp.random.rand(input_dim))

    # permute jacobian as necessary
    permuted_jac = onp.zeros(jac.shape)
    _, rng_perm = random.split(rng)
    perm = random.shuffle(rng_perm, onp.arange(input_dim))

    for j in range(input_dim):
        for k in range(input_dim):
            permuted_jac[..., j, k] = jac[..., perm[j], perm[k]]

    # make sure jacobians are triangular
    assert onp.sum(onp.abs(onp.triu(permuted_jac))) == 0.0
Example #19
def minibatch(x_train, y_train, batch_size, train_epochs):
  """Generate minibatches of data for a set number of epochs."""
  epoch = 0
  start = 0
  key = random.PRNGKey(0)

  while epoch < train_epochs:
    end = start + batch_size

    if end > x_train.shape[0]:
      key, split = random.split(key)
      permutation = random.shuffle(split,
                                   np.arange(x_train.shape[0], dtype=np.int64))
      x_train = x_train[permutation]
      y_train = y_train[permutation]
      epoch += 1
      start = 0
      continue

    yield x_train[start:end], y_train[start:end]
    start = start + batch_size
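A hypothetical usage sketch for the epoch-limited generator above, with illustrative arrays and sizes.

import jax.numpy as np

x_train = np.ones((256, 8))
y_train = np.zeros((256, 1))

for xb, yb in minibatch(x_train, y_train, batch_size=32, train_epochs=2):
    pass  # one optimization step per (xb, yb) batch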
Example #20
    def init_fun(rng, input_shape):
        """
        :param rng: rng used to initialize parameters
        :param input_shape: input shape
        """
        # TODO: consider removing permutation so we can move those layer constructions outside
        # init_fun. It seems that we can add a PermuteTransform layer to achieve the same effect.
        nonlocal permutation, net

        if permutation is None:
            # By default set a random permutation of variables, which is
            # important for performance with multiple steps
            rng, rng_perm = random.split(rng)
            permutation = random.shuffle(rng_perm, np.arange(input_dim))

        # Create masks
        masks, mask_skip = create_mask(input_dim=input_dim,
                                       hidden_dims=hidden_dims,
                                       permutation=permutation,
                                       output_dim_multiplier=output_multiplier)

        main_layers = []
        # Create masked layers
        for i, mask in enumerate(masks):
            main_layers.append(MaskedDense(mask))
            if i < len(masks) - 1:
                main_layers.append(nonlinearity)

        if skip_connections:
            net_init, net = stax.serial(
                stax.FanOut(2),
                stax.parallel(stax.serial(*main_layers),
                              MaskedDense(mask_skip, bias=False)),
                stax.FanInSum)
        else:
            net_init, net = stax.serial(*main_layers)

        return net_init(rng, input_shape)
Example #21
 def shuffle(self, x, axis=0):
     x = x.value if isinstance(x, JaxArray) else x
     return JaxArray(jr.shuffle(self.split_key(), x, axis=axis))
Example #22
def _shuffle_jax(value, seed=None, name=None):  # pylint: disable=unused-argument
    import jax.random as jaxrand  # pylint: disable=g-import-not-at-top
    if seed is None:
        raise ValueError('Must provide PRNGKey to sample in JAX.')
    return jaxrand.shuffle(seed, value, axis=0)
Example #23
def gauss_laplace_leapfrog(current_state,
                           target_log_prob_fn,
                           kinetic_energy_fn,
                           step_size,
                           n_disc,
                           rng=None):
    """
    One numerical integration step of the DHMC integrator for a mixed
    Gaussian and Laplace momentum.

    Params
    ------
    f: function(theta, req_grad)
      Returns the log probability and, if req_grad is True, its gradient.
      The gradient for discrete parameters should be zero.
    f_update: function(theta, step_sizeheta, index, aux)
      Computes the difference in the log probability when theta[index] is
      modified by step_sizeheta. The input 'aux' is whatever the quantity saved from
      the previous call to 'f' or 'f_update' that can be recycled.
    M: column vector
      Represents the diagonal mass matrix
    n_disc: int
      Number of discrete parameters. The parameters theta[:-n_disc] are
      assumed continuous.
    """
    del kinetic_energy_fn
    assert isinstance(current_state.state, list)
    assert isinstance(current_state.state_grads, list)

    M = tree_util.tree_map(np.ones_like, current_state.state)
    state, state_grads = current_state.state, current_state.state_grads
    momentum = current_state.momentum

    n_param = len(state)
    state = list(state)
    # Update the continuous parameters
    momentum[:-n_disc] = tree_util.tree_multimap(
        lambda p, g: p + 0.5 * step_size * g, momentum[:-n_disc],
        state_grads[:-n_disc])

    state[:-n_disc] = tree_util.tree_multimap(
        lambda t, p: t + 0.5 * step_size * p, state[:-n_disc],
        momentum[:-n_disc])
    logp = utils.call_fn(target_log_prob_fn, state)
    if np.isinf(logp):
        return current_state
    # Update discrete
    coord_order = n_param - n_disc + np.arange(n_disc)
    coord_order = random.shuffle(rng, coord_order)
    for index in coord_order:
        state, momentum, logp = _update_coordwise(target_log_prob_fn, index,
                                                  state, momentum, M,
                                                  step_size, logp)
    # Another half step of discrete
    state[:-n_disc] = tree_util.tree_multimap(
        lambda t, p: t + 0.5 * step_size * p, state[:-n_disc],
        momentum[:-n_disc])
    new_target_logp, new_state_grads = utils.call_fn_value_and_grad(
        target_log_prob_fn, state)
    momentum[:-n_disc] = tree_util.tree_multimap(
        lambda p, g: p + 0.5 * step_size * g, momentum[:-n_disc],
        new_state_grads[:-n_disc])
    return IntegratorState(state=state,
                           state_grads=new_state_grads,
                           target_log_prob=new_target_logp,
                           momentum=momentum)
Example #24
def get_masks_from_jax_params(params, nn_density_level, magnitude_base_bool = True, global_bool = False, reshuffle_seed = 0):
    """ Assemble a collection of 0-1 valued masks which are of the same sizes and shapes as layers' weight tensors
        Note that this function ignores bias parameters.
    
    Args:
        params: parameters in a jax.experimental.stax format.
        nn_density_level: the desired density level for weight parameters.
        magnitude_base_bool: a boolean variable that decides whether to prune the network by magnitude or randomly prune the network.
        global_bool: a boolean variable that decides whether to prune globally with a single threshold pooled across layers or layer by layer.
        reshuffle_seed: integer seed for the PRNGKey used to shuffle masks when pruning randomly.
        
    Returns:
        masks: a collection of 0-1 valued masks which are of the same sizes and shapes as the layers' weight tensors.
    """ 
    
    if (type(magnitude_base_bool) != bool) or (type(global_bool) != bool):
        raise ValueError("magnitude_base_bool and global_bool should be boolean variables")
    
    masks = []
    
    if global_bool:
        weight_magnitudes_pooled = np.concatenate([ np.abs(layer_params[0].flatten()) for layer_params in params if len(layer_params) > 1])
        idx = int( (1 - nn_density_level) * np.size(weight_magnitudes_pooled) )
        global_thres = np.sort(weight_magnitudes_pooled)[idx]
    
    for layer_index in range( len(params)):


        if len(params[layer_index]) < 2:
            # In this case, the layer does not contain weight and bias parameters.
            masks.append( [] )
            
        elif len(params[layer_index]) == 2:
            # In this case, the layer contains a tuple of parameters for weights and biases
            
            weights = params[layer_index][0]
            
            weight_magnitudes = np.abs(weights)

            if global_bool and magnitude_base_bool:
                
                this_mask = np.float32(weight_magnitudes > global_thres)
                
            else:
                # index: number of pruned parameters
                idx = int( (1 - nn_density_level) * np.size(weights) )

                # threshold: entries with magnitudes below the threshold will be removed
                thres = np.sort(np.reshape(weight_magnitudes, [-1] ))[idx]

                # 0 selected for weight parameters with magnitudes smaller than the threshold, 1 otherwise
                this_mask = np.float32(weight_magnitudes > thres)

                if not magnitude_base_bool:
                    # in the case of random pruning: randomly shuffle the mask along its first axis
                    this_mask = random.shuffle(random.PRNGKey(reshuffle_seed), this_mask)

            masks.append(this_mask ) 

        else:
            raise NotImplementedError

    return masks
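A hypothetical sketch of applying the masks above to prune a stax-style parameter list; `params` is assumed to come from a `stax.serial(...)` init function, and the masks line up with `params` layer by layer.

masks = get_masks_from_jax_params(params, nn_density_level=0.5)

pruned_params = []
for layer_params, mask in zip(params, masks):
    if len(layer_params) == 2:                 # (weights, biases) layer
        w, b = layer_params
        pruned_params.append((w * mask, b))
    else:                                      # parameter-free layer, e.g. an activation
        pruned_params.append(layer_params)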
Example #25
    def train(self,
              bs,
              solutions=[None],
              retrain=False,
              tensorboard_writer=None,
              work_unit=None):

        if not retrain and not self.flaxd:
            opt_state = self.opt_init(self.net_params)
        if retrain:
            opt_state = self.opt_init(self.opt_params)
        loss = onp.zeros(self.training_iter // 10 + 1)
        gradients = onp.zeros(self.training_iter // 10 + 1)
        if not self.flaxd:
            param = self.get_params(opt_state)
        else:
            param = self.optimizer.target
            opt_state = self.optimizer
        og_loss = self.test_loss(
            self.preconditioner, self.n_test, self.mesh, param,
            np.zeros((bs.shape[1], self.n * self.n)), bs[0].reshape(
                bs.shape[1], self.n * self.n), 0, self.k) / 10000000
        print(og_loss)
        if work_unit is not None:
            work_unit.get_measurement_series(
                label='train/loss').create_measurement(objective_value=og_loss,
                                                       step=0)
        for i in range(self.training_iter):
            m = bs.shape[0]
            order = random.shuffle(random.PRNGKey(i), np.arange(m))
            for _ in range(50):
                for b in bs[order]:
                    current_loss, grad, opt_state = self.step(
                        i, opt_state, np.zeros((b.shape[0], self.n * self.n)),
                        b, solutions[min(m,
                                         len(solutions) - 1)])

            if i % 10 == 0:
                if not self.flaxd:
                    param = self.get_params(opt_state)
                else:
                    param = opt_state.target
                current_loss_test = self.test_loss(
                    self.preconditioner, self.n_test, self.mesh, param,
                    np.zeros((b.shape[0], self.n * self.n)), b, 0,
                    self.k) / 10000000
                current_loss = current_loss / 10000000
                avg_grad = onp.mean(onp.abs(onp_utils.flatten(grad)[-1]))
                print(
                    f'step{i: 5d}: loss { current_loss :1.5f} : avg_gradient \
              { avg_grad :1.5f} : current_loss_test { current_loss_test :1.5f}'
                )
                logging.info(
                    f'step{i: 5d}: loss { current_loss :1.5f} : avg_gradient \
              { avg_grad :1.5f} : current_loss_test { current_loss_test :1.5f}'
                )
                loss[i // 10] = current_loss
                gradients[i // 10] = avg_grad
                if work_unit is not None:
                    work_unit.get_measurement_series(
                        label='train/loss').create_measurement(
                            objective_value=current_loss_test, step=i)
                    tensorboard_writer.scalar('train/loss',
                                              current_loss_test,
                                              step=i + 1)
                    work_unit.get_measurement_series(
                        label='train/loss ' +
                        str(self.iter_gmres(i))).create_measurement(
                            objective_value=current_loss, step=i + 1)
                    tensorboard_writer.scalar('train/loss ' +
                                              str(self.iter_gmres(i)),
                                              current_loss,
                                              step=i + 1)
            if i % 50 == 0:
                if self.flaxd:
                    self.opt_params = opt_state.target.params
                else:
                    self.opt_params = self.get_params(opt_state)
                self.save(str(i))
        if self.flaxd:
            self.optimizer = opt_state
        else:
            self.opt_params = self.get_params(opt_state)
            self.opt_state = opt_state
        if self.model_dir is None:
            self.model_dir = ''

        with open(os.path.join(self.model_dir, 'train_loss.np'), 'wb') as f:
            onp.save(f, loss)
        with open(os.path.join(self.model_dir, 'train_gradients.np'),
                  'wb') as f:
            onp.save(f, gradients)
        self.save()
        if work_unit is not None:
            tensorboard_writer.close()
Example #26
    def search(self,
               method_id,
               method_params,
               problem_id,
               problem_params,
               loss,
               search_space,
               trials=None,
               smoothing=10,
               min_steps=100,
               verbose=0):
        """
        Description: Search for optimal method parameters
        Args:
            method_id (string): id of method
            method_params (dict): initial method parameters dict (updated by search space)
            problem_id (string): id of problem to try on
            problem_params (dict): problem parameters dict
            loss (function): a function mapping y_pred, y_true -> scalar loss
            search_space (dict): dict mapping parameter names to a finite set of options
            trials (int, None): number of parameter settings to try, sampled in random order from the search space; if None, all combinations are tried
            smoothing (int): loss computed over smoothing number of steps to decrease variance
            min_steps (int): minimum number of steps that the method gets to run for
            verbose (int): if 1, print progress and current parameters
        """
        self.method_id = method_id
        self.method_params = method_params
        self.problem_id = problem_id
        self.problem_params = problem_params
        self.loss = loss

        # store the order to test parameters
        param_list = list(
            itertools.product(*[v for k, v in search_space.items()]))
        index = np.arange(
            len(param_list)
        )  # random.shuffle can't shuffle the list of tuples directly, so shuffle indices instead
        shuffled_index = random.shuffle(generate_key(), index)
        param_order = [param_list[int(i)]
                       for i in shuffled_index]  # shuffle order of elements

        # helper method
        def _update_smoothing(l, val):
            """ update smoothing loss list with new val """
            return jax.ops.index_update(np.roll(l, 1), 0, val)

        self._update_smoothing = jit(_update_smoothing)

        # store optimal params and optimal loss
        optimal_params, optimal_loss = {}, None
        t = 0
        for params in param_order:  # loop over all params in the given order
            t += 1
            curr_params = method_params.copy()
            curr_params.update(
                {k: v
                 for k, v in zip(search_space.keys(), params)})
            loss = self._run_test(curr_params,
                                  smoothing=smoothing,
                                  min_steps=min_steps,
                                  verbose=verbose)
            if optimal_loss is None or loss < optimal_loss:
                optimal_params = curr_params
                optimal_loss = loss
            if t == trials:  # break after trials number of attempts, unless trials is None
                break
        return optimal_params, optimal_loss
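A hypothetical usage sketch for the tuner method above; `tuner`, the id strings, and the search-space values are placeholders, not names taken from the library.

search_space = {'lr': [1e-3, 1e-2, 1e-1], 'hidden_dim': [32, 64]}

best_params, best_loss = tuner.search(
    method_id='LSTM',
    method_params={'lr': 1e-2, 'hidden_dim': 32},
    problem_id='ARMA-v0',
    problem_params={},
    loss=lambda y_pred, y_true: np.mean((y_pred - y_true) ** 2),
    search_space=search_space,
    trials=4)                 # sample 4 of the 6 possible combinations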