def __init__(self, data, targets, connections, model_out, n_ensembles, rng):
    self.data = data
    self.targets = targets
    self.connections = connections
    self.model_out = model_out
    self.n_ensembles = n_ensembles
    self.rng = rng
    # single-iteration, offline EnKF over one batch
    self.enkf = EnKF(maxit=1, online=False, n_batches=1)
    # overwrite the raw connections with their shaped form
    self.connections = self._shape_connections(connections)
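# Usage sketch (hypothetical): the enclosing class name `EnsembleOptimizee`
# and the dummy shapes below are assumptions for illustration, not from the
# source; only the constructor signature above is taken as given.
import numpy as np

rng = np.random.RandomState(0)
data = rng.randn(64, 784)               # one batch of flattened MNIST images
targets = rng.randint(0, 10, size=64)   # integer class labels
connections = rng.randn(100, 7850)      # 100 ensemble members, flat weights
model_out = np.zeros((100, 64, 10))     # per-member class scores
optimizee = EnsembleOptimizee(data, targets, connections, model_out,
                              n_ensembles=100, rng=rng)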
conv_loss_mnist = []
# average test losses
test_losses = []

np.random.seed(0)
torch.manual_seed(0)

batch_size = 64
model = MnistOptimizee(root=root, batch_size=batch_size, seed=0,
                       n_ensembles=n_ensembles).to(device)
conv_ens = None
# observation noise covariance, one entry per MNIST class
gamma = np.eye(10) * 0.01
enkf = EnKF(tol=1e-5, maxit=1, stopping_crit='', online=False,
            shuffle=False, n_batches=1, converge=False)
# dataset length divided by batch size, times 8 repetitions
rng = int(60000 / batch_size * 8)

for i in range(1):
    model.generation = i + 1
    if i == 0:
        try:
            out = model.load_model('')
            # replace cov matrix with cov from weights (ensembles)
            # m = torch.distributions.Normal(out['conv_params'].mean(),
            #                                out['conv_params'].std())
            # model.cov = m.sample((n_ensembles, model.length))
        except FileNotFoundError as fe:
            print(fe)
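# The commented-out block above sketches re-initializing the ensemble from the
# statistics of the loaded weights. A self-contained version of that idea; the
# 100 x 7850 shape is an illustrative assumption standing in for
# (n_ensembles, model.length), and `loaded_params` stands in for
# out['conv_params']:
import torch

loaded_params = torch.randn(784000)  # stand-in for the loaded flat weights
m = torch.distributions.Normal(loaded_params.mean(), loaded_params.std())
resampled = m.sample((100, 7850))    # one flat weight vector per ensemble member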
dyn_change = {
    'model_reps': [],
    'n_ensembles': [],
    'iteration': [],
    'rng': 60000,
    'test_loss': []
}

# init model
model = MnistOptimizee(root=root, batch_size=batch_size,
                       n_ensembles=n_ensembles).to(device)
conv_ens = None
g_scaler = config['gamma']
gamma = np.eye(10) * g_scaler
enkf = EnKF(maxit=1, online=False, n_batches=1)
# dataset length divided by batch size, times the number of repetitions
rng = int(60000 / batch_size * 8)

i = 0
while i < rng:
    model.generation = i + 1
    if i == 0:
        try:
            out = model.load_model('')
        except FileNotFoundError as fe:
            print(fe)
            print('Model not found! Initializing new ensembles.')
            out = model.create_individual()
            conv_ens = out['conv_params']
            out = model.set_parameters(conv_ens)
            print('loss {} generation {}'.format(out['conv_loss'],
                                                 model.generation))
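# For reference, one EnKF parameter-update step written out in plain NumPy.
# This is a sketch of the underlying technique, not the EnKF.fit API used
# above; the helper name `enkf_update` and all shapes are assumptions.
import numpy as np

def enkf_update(ensemble, model_output, observations, gamma):
    """ensemble: (J, D) member parameters; model_output: (J, K) predictions;
    observations: (K,) targets; gamma: (K, K) observation noise covariance."""
    u_mean = ensemble.mean(axis=0)
    g_mean = model_output.mean(axis=0)
    du = ensemble - u_mean                    # (J, D) parameter deviations
    dg = model_output - g_mean                # (J, K) output deviations
    J = ensemble.shape[0]
    c_ug = du.T @ dg / J                      # (D, K) cross-covariance
    c_gg = dg.T @ dg / J                      # (K, K) output covariance
    residuals = observations - model_output   # (J, K) per-member residuals
    # Kalman-style correction: u_j += C_ug (C_gg + gamma)^-1 (y - g_j)
    gain = np.linalg.solve(c_gg + gamma, residuals.T).T  # (J, K)
    return ensemble + gain @ c_ug.T           # (J, D) updated ensemble

# e.g. with the script's gamma: ens = enkf_update(ens, preds, y, np.eye(10) * 0.01)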