def test_made(self):
    def get_masks(input_dim, hidden_dim=64, num_hidden=1):
        # Build MADE masks: assign a degree to every unit, then connect a
        # unit in layer l+1 to a unit in layer l only if its degree is >=
        # the earlier unit's degree.
        masks = []
        input_degrees = np.arange(input_dim)
        degrees = [input_degrees]
        for n_h in range(num_hidden + 1):
            degrees += [np.arange(hidden_dim) % (input_dim - 1)]
        degrees += [input_degrees % input_dim - 1]
        for (d0, d1) in zip(degrees[:-1], degrees[1:]):
            masks += [
                np.transpose(
                    np.expand_dims(d1, -1) >= np.expand_dims(d0, 0)
                ).astype(np.float32)
            ]
        return masks

    def masked_transform(rng, input_dim):
        masks = get_masks(input_dim, hidden_dim=64, num_hidden=1)
        act = stax.Relu
        # The final mask is tiled so the network outputs both a shift and a
        # log-scale per input dimension.
        init_fun, apply_fun = stax.serial(
            flows.MaskedDense(masks[0]),
            act,
            flows.MaskedDense(masks[1]),
            act,
            flows.MaskedDense(masks[2].tile(2)),
        )
        _, params = init_fun(rng, (input_dim,))
        return params, apply_fun

    for test in (returns_correct_shape, is_bijective):
        test(self, flows.MADE(masked_transform))

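# Hedged sanity check of the mask construction above (not part of the
# original test): since get_masks is local to test_made, the degrees are
# re-derived here with plain numpy, and the masks are built in
# (output, input) orientation. The chained connectivity should be strictly
# autoregressive, i.e. output i only sees inputs with index < i.
import numpy as onp

input_dim, hidden_dim = 4, 8
d_in = onp.arange(input_dim)
d_h = onp.arange(hidden_dim) % (input_dim - 1)
d_out = d_in % input_dim - 1
m0 = (d_h[:, None] >= d_in[None, :]).astype(float)   # input -> hidden
m1 = (d_h[:, None] >= d_h[None, :]).astype(float)    # hidden -> hidden
m2 = (d_out[:, None] >= d_h[None, :]).astype(float)  # hidden -> output
conn = m2 @ m1 @ m0                                  # path counts, (output, input)
assert onp.all(onp.triu(conn) == 0)                  # zero on and above the diagonal
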
def load_CNF(fbest, num_inputs=None, num_cond_inputs=None, num_blocks=10,
             num_hidden=1024, act='relu'):
    '''
    Load a trained conditional normalizing flow from the saved
    best-validation state dict `fbest`.
    '''
    import copy
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.utils.data

    import flows as fnn

    #########################################################################
    cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if cuda else "cpu")

    seed = 12387
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)
        kwargs = {'num_workers': 4, 'pin_memory': True}
    else:
        kwargs = {}

    #########################################################################
    # set up MAF
    modules = []
    for _ in range(num_blocks):
        modules += [
            fnn.MADE(num_inputs, num_hidden, num_cond_inputs, act=act),
            fnn.BatchNormFlow(num_inputs),
            fnn.Reverse(num_inputs)
        ]
    model = fnn.FlowSequential(*modules)

    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                module.bias.data.fill_(0)
    model.to(device)

    #########################################################################
    # load best model
    #########################################################################
    state = torch.load(fbest, map_location=device)
    model.load_state_dict(state)
    model.num_inputs = num_inputs

    return model, device

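# Hedged usage sketch for load_CNF (not from the original source): the
# checkpoint path and the dimensions are placeholders, and the architecture
# arguments must match whatever train_CNF (below) was run with. log_probs is
# the same FlowSequential method that train_CNF optimizes.
import torch

model, device = load_CNF('best_model.pt', num_inputs=5, num_cond_inputs=2,
                         num_blocks=10, num_hidden=1024)
model.eval()
x = torch.randn(16, 5)  # batch of variables
c = torch.randn(16, 2)  # matching conditional variables
with torch.no_grad():
    logp = model.log_probs(x.to(device), c.to(device))  # one value per row
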
def load_p_Y_model():
    num_inputs = np.sum(central_pixel)
    num_cond_inputs = None

    modules = []
    for _ in range(num_blocks):
        modules += [
            fnn.MADE(num_inputs, num_hidden, num_cond_inputs, act=act),
            fnn.BatchNormFlow(num_inputs),
            fnn.Reverse(num_inputs)
        ]
    model = fnn.FlowSequential(*modules)
    model.to(device)
    return model

    for _ in range(args.num_blocks):
        modules += [
            fnn.CouplingLayer(
                num_inputs, num_hidden, mask, num_cond_inputs,
                s_act='tanh', t_act='relu'),
            fnn.BatchNormFlow(num_inputs)
        ]
        mask = 1 - mask
elif args.flow == 'maf':
    for _ in range(args.num_blocks):
        modules += [
            fnn.MADE(num_inputs, num_hidden, num_cond_inputs, act=act),
            fnn.BatchNormFlow(num_inputs),
            fnn.Reverse(num_inputs)
        ]
elif args.flow == 'maf-split':
    for _ in range(args.num_blocks):
        modules += [
            fnn.MADESplit(num_inputs, num_hidden, num_cond_inputs,
                          s_act='tanh', t_act='relu'),
            fnn.BatchNormFlow(num_inputs),
            fnn.Reverse(num_inputs)
        ]
elif args.flow == 'maf-split-glow':

    mask = mask.to(device).float()
    for _ in range(args.num_blocks):
        modules += [
            fnn.CouplingLayer(num_inputs, num_hidden, mask,
                              s_act='tanh', t_act='relu'),
            fnn.BatchNormFlow(num_inputs)
        ]
        mask = 1 - mask
elif args.flow == 'maf':
    for _ in range(args.num_blocks):
        modules += [
            fnn.MADE(num_inputs, num_hidden, act=act),
            fnn.BatchNormFlow(num_inputs),
            fnn.Reverse(num_inputs)
        ]

model = fnn.FlowSequential(*modules)

for module in model.modules():
    if isinstance(module, nn.Linear):
        nn.init.orthogonal_(module.weight)
        module.bias.data.fill_(0)

model.to(device)

optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-6)

def learn_flow(X, hidden_dim=48, num_hidden=2, num_unit=5,
               learning_rate=1e-3, num_epochs=200, batch_size=4000,
               interval=None, seed=123):
    """Training with 400K of riz data works well with the default args.

    Make sure to remove samples with undetected flux since these otherwise
    create a delta function.
    """
    # Preprocess input data.
    scaler = preprocessing.StandardScaler().fit(X)
    X_preproc = scaler.transform(X)
    input_dim = X.shape[1]
    # Initialize random numbers.
    rng, flow_rng = jax.random.split(jax.random.PRNGKey(seed))
    # Initialize our flow bijection.
    transform = functools.partial(masked_transform, hidden_dim=hidden_dim,
                                  num_hidden=num_hidden)
    bijection_init_fun = flows.Serial(
        *(flows.MADE(transform), flows.Reverse()) * num_unit)
    # Create direct and inverse bijection functions.
    rng, bijection_rng = jax.random.split(rng)
    bijection_params, bijection_direct, bijection_inverse = bijection_init_fun(
        bijection_rng, input_dim)
    # Initialize our flow model.
    prior_init_fun = flows.Normal()
    flow_init_fun = flows.Flow(bijection_init_fun, prior_init_fun)
    initial_params, log_pdf, sample = flow_init_fun(flow_rng, input_dim)

    if interval is not None:
        import matplotlib.pyplot as plt
        bins = np.linspace(-0.05, 1.05, 111)

    def loss_fn(params, inputs):
        return -log_pdf(params, inputs).mean()

    @jax.jit
    def step(i, opt_state, inputs):
        params = get_params(opt_state)
        loss_value, gradients = jax.value_and_grad(loss_fn)(params, inputs)
        return opt_update(i, gradients, opt_state), loss_value

    opt_init, opt_update, get_params = optimizers.adam(step_size=learning_rate)
    opt_state = opt_init(initial_params)
    root2 = np.sqrt(2.)
    itercount = itertools.count()
    for epoch in range(num_epochs):
        rng, permute_rng = jax.random.split(rng)
        X_epoch = jax.random.permutation(permute_rng, X_preproc)
        for batch_index in range(0, len(X), batch_size):
            opt_state, loss = step(
                next(itercount), opt_state,
                X_epoch[batch_index:batch_index + batch_size])
        if interval is not None and (epoch + 1) % interval == 0:
            print(f'epoch {epoch + 1} loss {loss:.3f}')
            # Map the input data back through the flow to the prior space.
            epoch_params = get_params(opt_state)
            X_normal, _ = bijection_direct(epoch_params, X_epoch)
            X_uniform = 0.5 * (
                1 + scipy.special.erf(np.array(X_normal, np.float64) / root2))
            for i in range(input_dim):
                plt.hist(X_uniform[:, i], bins, histtype='step')
            plt.show()

    # Return a function that maps samples to a ~uniform distribution on
    # [0,1] ** input_dim. Takes a numpy array as input and returns a numpy
    # array of the same shape.
    final_params = get_params(opt_state)

    def flow_map(Y):
        Y_preproc = scaler.transform(Y)
        Y_normal, _ = bijection_direct(final_params, jnp.array(Y_preproc))
        Y_uniform = 0.5 * (
            1 + scipy.special.erf(np.array(Y_normal, np.float64) / root2))
        return Y_uniform.astype(np.float32)

    return flow_map

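# Hedged usage sketch for learn_flow (not from the original source): random
# Gaussian data stands in for the riz photometry mentioned in the docstring,
# and the sample size, epoch count, and batch size are placeholders.
import numpy as onp

X_demo = onp.random.RandomState(0).normal(size=(5000, 3)).astype(onp.float32)
flow_map = learn_flow(X_demo, num_epochs=20, batch_size=1000)
U = flow_map(X_demo)  # ~uniform on [0, 1] ** 3 once the flow has converged
print(U.shape, U.min(), U.max())
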
def get_modules(flow, num_blocks, normalization, hidden_dim=64):
    modules = []
    if flow == 'realnvp':
        for _ in range(num_blocks):
            modules += [
                flows.AffineCoupling(get_transform(hidden_dim)),
                flows.Reverse(),
            ]
            if normalization:
                modules += [
                    flows.ActNorm(),
                ]
    elif flow == 'realnvp-conv':
        for _ in range(num_blocks):
            modules += [
                MNISTAffineCoupling(get_conv_transform(hidden_dim)),
                flows.Reverse(),
            ]
            if normalization:
                modules += [
                    flows.ActNorm(),
                ]
    elif flow == 'nice':
        for _ in range(num_blocks):
            modules += [
                flows.AffineCoupling(get_nice_transform(hidden_dim)),
                flows.Reverse(),
            ]
            if normalization:
                modules += [
                    flows.ActNorm(),
                ]
    elif flow == 'glow':
        for _ in range(num_blocks):
            modules += [
                flows.AffineCoupling(get_transform(hidden_dim)),
                flows.InvertibleLinear(),
            ]
            if normalization:
                modules += [
                    flows.ActNorm(),
                ]
    elif flow == 'maf':
        for _ in range(num_blocks):
            modules += [
                flows.MADE(get_masked_transform(hidden_dim)),
                flows.Reverse(),
            ]
            if normalization:
                modules += [
                    flows.ActNorm(),
                ]
    elif flow == 'neural-spline':
        for _ in range(num_blocks):
            modules += [
                flows.NeuralSplineCoupling(),
            ]
    elif flow == 'maf-glow':
        for _ in range(num_blocks):
            modules += [
                flows.MADE(get_masked_transform(hidden_dim)),
                flows.InvertibleLinear(),
            ]
            if normalization:
                modules += [
                    flows.ActNorm(),
                ]
    elif flow == 'custom':
        for _ in range(num_blocks):
            modules += [
                flows.MADE(get_masked_transform(hidden_dim)),
                flows.InvertibleLinear(),
            ]
            modules += [
                flows.ActNorm(),
            ]
    else:
        raise Exception('Invalid flow: {}'.format(flow))
    return modules

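# Hedged sketch (not from the original source) of how the module list is
# typically consumed with the jax-flows API already used in learn_flow above.
# It assumes the get_*_transform helpers referenced by get_modules are
# defined in scope; the input dimension is a placeholder.
import jax

modules = get_modules('maf', num_blocks=5, normalization=True)
bijection_init_fun = flows.Serial(*modules)
flow_init_fun = flows.Flow(bijection_init_fun, flows.Normal())
params, log_pdf, sample = flow_init_fun(jax.random.PRNGKey(0), 4)
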
def init_model(args, num_inputs=72):
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device
        device = torch.device("cuda:" + args.device)
    else:
        device = torch.device("cpu")

    # network structure
    num_hidden = args.num_hidden
    num_cond_inputs = None
    act = 'relu'
    assert act in ['relu', 'sigmoid', 'tanh']

    modules = []

    # normalizing flow
    assert args.flow in ['maf', 'realnvp', 'glow']
    if args.flow == 'glow':
        mask = torch.arange(0, num_inputs) % 2
        mask = mask.to(device).float()
        print("Warning: Results for GLOW are not as good as for MAF yet.")
        for _ in range(args.num_blocks):
            modules += [
                fnn.BatchNormFlow(num_inputs),
                fnn.LUInvertibleMM(num_inputs),
                fnn.CouplingLayer(num_inputs, num_hidden, mask,
                                  num_cond_inputs,
                                  s_act='tanh', t_act='relu')
            ]
            mask = 1 - mask
    elif args.flow == 'realnvp':
        mask = torch.arange(0, num_inputs) % 2
        mask = mask.to(device).float()
        for _ in range(args.num_blocks):
            modules += [
                fnn.CouplingLayer(num_inputs, num_hidden, mask,
                                  num_cond_inputs,
                                  s_act='tanh', t_act='relu'),
                fnn.BatchNormFlow(num_inputs)
            ]
            mask = 1 - mask
    elif args.flow == 'maf':
        for _ in range(args.num_blocks):
            modules += [
                fnn.MADE(num_inputs, num_hidden, num_cond_inputs, act=act),
                fnn.BatchNormFlow(num_inputs),
                fnn.Reverse(num_inputs)
            ]

    model = fnn.FlowSequential(*modules)

    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                module.bias.data.fill_(0)

    return model

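# Hedged usage sketch for init_model (not from the original source):
# init_model only reads the namespace fields used above, so a bare Namespace
# with placeholder values is enough to build the model on CPU.
from argparse import Namespace

args = Namespace(no_cuda=True, device='0', num_hidden=128,
                 num_blocks=5, flow='maf')
model = init_model(args, num_inputs=72)
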
def train_CNF(thetas, condit, Ntrain, num_blocks=10, epochs=10000,
              batch_size=100, test_batch_size=1000, num_hidden=1024,
              act='relu', learning_rate=1e-4,
              fbest=None, fcheck=None, fmodel=None):
    '''
    Train a conditional normalizing flow given (variable, conditional
    variable) training pairs.
    '''
    import copy
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.utils.data

    import flows as fnn

    #########################################################################
    cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if cuda else "cpu")

    seed = 12387
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)
        kwargs = {'num_workers': 4, 'pin_memory': True}
    else:
        kwargs = {}

    #########################################################################
    print('Ntrain = %i; Nvalid = %i' % (Ntrain, thetas.shape[0] - Ntrain))

    train_tensor = torch.from_numpy(thetas[:Ntrain, :])
    if condit is not None:
        train_cond = torch.from_numpy(condit[:Ntrain, :])
        train_dataset = torch.utils.data.TensorDataset(train_tensor, train_cond)
    else:
        train_dataset = torch.utils.data.TensorDataset(train_tensor)

    valid_tensor = torch.from_numpy(thetas[Ntrain:, :])
    if condit is not None:
        valid_cond = torch.from_numpy(condit[Ntrain:, :])
        valid_dataset = torch.utils.data.TensorDataset(valid_tensor, valid_cond)
    else:
        valid_dataset = torch.utils.data.TensorDataset(valid_tensor)

    # number of inputs and of conditional inputs
    num_inputs = thetas.shape[1]
    if condit is not None:
        num_cond_inputs = condit.shape[1]
    else:
        num_cond_inputs = None

    # set up loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=test_batch_size, shuffle=False,
        drop_last=False, **kwargs)

    #########################################################################
    # set up MAF
    modules = []
    for _ in range(num_blocks):
        modules += [
            fnn.MADE(num_inputs, num_hidden, num_cond_inputs, act=act),
            fnn.BatchNormFlow(num_inputs),
            fnn.Reverse(num_inputs)
        ]
    model = fnn.FlowSequential(*modules)

    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                module.bias.data.fill_(0)
    model.to(device)

    #########################################################################
    # train
    #########################################################################
    # Adam optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate,
                           weight_decay=1e-6)

    def train(epoch):
        model.train()
        train_loss = 0

        for batch_idx, data in enumerate(train_loader):
            if isinstance(data, list):
                if len(data) > 1:
                    cond_data = data[1].float()
                    cond_data = cond_data.to(device)
                else:
                    cond_data = None
                data = data[0]
            else:
                cond_data = None
            data = data.to(device)
            optimizer.zero_grad()
            loss = -model.log_probs(data, cond_data).mean()
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

        # Recompute the BatchNormFlow statistics over the full training set.
        for module in model.modules():
            if isinstance(module, fnn.BatchNormFlow):
                module.momentum = 0
        with torch.no_grad():
            if condit is not None:
                model(train_loader.dataset.tensors[0].to(data.device),
                      train_loader.dataset.tensors[1].to(data.device))
            else:
                model(train_loader.dataset.tensors[0].to(data.device))
        for module in model.modules():
            if isinstance(module, fnn.BatchNormFlow):
                module.momentum = 1

    def validate(epoch, model, loader, prefix='Validation'):
        model.eval()
        val_loss = 0

        for batch_idx, data in enumerate(loader):
            if isinstance(data, list):
                if len(data) > 1:
                    cond_data = data[1].float()
                    cond_data = cond_data.to(device)
                else:
                    cond_data = None
                data = data[0]
            else:
                cond_data = None
            data = data.to(device)
            with torch.no_grad():
                # sum up batch loss
                val_loss += -model.log_probs(data, cond_data).sum().item()

        print('validation loss: %f' % (val_loss / len(loader.dataset)))
        return val_loss / len(loader.dataset)

    best_validation_loss = float('inf')
    best_validation_epoch = 0
    best_model = model

    for epoch in range(epochs):
        print('\nEpoch: {}'.format(epoch))

        train(epoch)
        validation_loss = validate(epoch, model, valid_loader)

        # early stopping: quit after 30 epochs without improvement
        if epoch - best_validation_epoch >= 30:
            break

        if not np.isfinite(validation_loss):
            print("validation loss is NaN")
            break

        if validation_loss < best_validation_loss:
            best_validation_epoch = epoch
            best_validation_loss = validation_loss
            best_model = copy.deepcopy(model)
            torch.save(best_model.state_dict(), fbest)

        print('Best validation at epoch {}: Average Log Likelihood in nats: {:.4f}'
              .format(best_validation_epoch, -best_validation_loss))

        # save training checkpoint
        torch.save(
            {
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict()
            }, fcheck)

    # save model only
    torch.save(model.state_dict(), fmodel)
    torch.save(best_model.state_dict(), fbest)

    return model, device

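# Hedged usage sketch for train_CNF (not from the original source): toy
# Gaussian data with made-up shapes, and the file names are placeholders.
# The same num_blocks / num_hidden must be passed to load_CNF (above) when
# reloading the saved weights.
import numpy as np

thetas = np.random.normal(size=(2000, 4)).astype(np.float32)
condit = np.random.normal(size=(2000, 2)).astype(np.float32)
model, device = train_CNF(thetas, condit, Ntrain=1600,
                          num_blocks=4, epochs=100, num_hidden=64,
                          fbest='best.pt', fcheck='check.pt',
                          fmodel='model.pt')
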
}[args.dataset]

modules = []

assert args.flow in ['maf', 'glow']
for _ in range(args.num_blocks):
    if args.flow == 'glow':
        print("Warning: Results for GLOW are not as good as for MAF yet.")
        modules += [
            fnn.BatchNormFlow(num_inputs),
            fnn.InvertibleMM(num_inputs),
            fnn.CouplingLayer(num_inputs, num_hidden)
        ]
    elif args.flow == 'maf':
        modules += [
            fnn.MADE(num_inputs, num_hidden),
            fnn.BatchNormFlow(num_inputs),
            fnn.Reverse(num_inputs)
        ]

model = fnn.FlowSequential(*modules)

for module in model.modules():
    if isinstance(module, nn.Linear):
        nn.init.orthogonal_(module.weight)
        module.bias.data.fill_(0)

model.to(device)

optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-6)