def build_model(num_blocks, num_inputs, num_hidden, K, M, lr,
                device=torch.device("cpu"), use_bn=True):
    """Build a SumSqMAF normalizing-flow stack and its Adam optimizer.

    Each block is a ``gmaf.SumSqMAF`` layer, optionally followed by a
    ``fnn.BatchNormFlow``, and always followed by a ``fnn.Reverse``
    permutation so successive blocks see the dimensions in a different order.

    Parameters
    ----------
    num_blocks : int
        Number of (SumSqMAF [, BatchNorm], Reverse) blocks to stack.
    num_inputs : int
        Dimensionality of the modeled variable.
    num_hidden : int
        Hidden width of each SumSqMAF layer.
    K, M :
        SumSqMAF-specific hyperparameters (passed through unchanged).
    lr : float
        Adam learning rate.
    device : torch.device, optional
        Device the model is moved to (default CPU).
    use_bn : bool, optional
        Whether to insert a BatchNormFlow in every block (default True).

    Returns
    -------
    (model, optimizer) : (fnn.FlowSequential, optim.Adam)
    """
    modules = []
    for _ in range(num_blocks):
        modules.append(gmaf.SumSqMAF(num_inputs, num_hidden, K, M))
        if use_bn:
            modules.append(fnn.BatchNormFlow(num_inputs))
        modules.append(fnn.Reverse(num_inputs))
    model = fnn.FlowSequential(*modules)

    # Orthogonal init for every linear layer; zero the bias.
    # BUG FIX: guard against bias-free Linear layers — the original
    # `module.bias.data.fill_(0)` raises AttributeError when a layer was
    # constructed with bias=False (module.bias is None). The sibling
    # builders in this file already use this guard.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                module.bias.data.fill_(0)

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)
    return model, optimizer
def load_CNF(fbest, num_inputs=None, num_cond_inputs=None, num_blocks=10,
             num_hidden=1024, act='relu'):
    """Load a trained conditional normalizing flow from a checkpoint.

    Rebuilds the MAF architecture used by ``train_CNF`` (the architecture
    arguments must match the ones used at training time), loads the saved
    ``state_dict`` from ``fbest``, and returns the model in eval mode.

    Parameters
    ----------
    fbest : str
        Path to the saved ``state_dict`` (as written by ``train_CNF``).
    num_inputs : int
        Dimensionality of the modeled variable (required; the ``None``
        default is a legacy of the original signature).
    num_cond_inputs : int or None
        Dimensionality of the conditioning variable, or None for an
        unconditional flow.
    num_blocks, num_hidden, act :
        Architecture hyperparameters; must match training.

    Returns
    -------
    (model, device) : (fnn.FlowSequential, torch.device)
    """
    import torch
    import torch.nn as nn
    import flows as fnn
    #############################################################################
    cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if cuda else "cpu")
    # Seed as in train_CNF so module construction is reproducible.
    seed = 12387
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)
    #############################################################################
    # set up MAF
    modules = []
    for _ in range(num_blocks):
        modules += [
            fnn.MADE(num_inputs, num_hidden, num_cond_inputs, act=act),
            fnn.BatchNormFlow(num_inputs),
            fnn.Reverse(num_inputs)
        ]
    model = fnn.FlowSequential(*modules)
    # Initialization is overwritten by load_state_dict below; kept for
    # parity with train_CNF.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                module.bias.data.fill_(0)
    model.to(device)
    #############################################################################
    # load best model
    #############################################################################
    state = torch.load(fbest, map_location=device)
    model.load_state_dict(state)
    # BUG FIX: switch to eval mode. In train mode BatchNormFlow uses
    # per-batch statistics, so log-probabilities of a loaded "best" model
    # would depend on the evaluation batch instead of the stored
    # running statistics.
    model.eval()
    model.num_inputs = num_inputs
    return model, device
def testBatchNorm(self):
    """BatchNormFlow round-trip: forward then inverse must recover the
    input, with cancelling log-determinants, in train mode (twice, so the
    running statistics get updated) and then in eval mode."""
    m1 = fnn.FlowSequential(fnn.BatchNormFlow(NUM_INPUTS))
    m1.train()
    x = torch.randn(BATCH_SIZE, NUM_INPUTS)
    y, logdets = m1(x)
    z, inv_logdets = m1(y, mode='inverse')
    self.assertTrue((logdets + inv_logdets).abs().max() < EPS,
                    'BatchNorm Det is not zero.')
    self.assertTrue((x - z).abs().max() < EPS, 'BatchNorm is wrong.')

    # Second run.
    x = torch.randn(BATCH_SIZE, NUM_INPUTS)
    y, logdets = m1(x)
    z, inv_logdets = m1(y, mode='inverse')
    self.assertTrue((logdets + inv_logdets).abs().max() < EPS,
                    'BatchNorm Det is not zero for the second run.')
    self.assertTrue((x - z).abs().max() < EPS,
                    'BatchNorm is wrong for the second run.')

    m1.eval()
    # BUG FIX: the original rebuilt m1 here, which discarded the trained
    # running statistics and left the fresh module in default train mode,
    # so the "eval" assertions never actually exercised eval-mode
    # BatchNorm (compare testSequentialBN, which does it correctly).
    x = torch.randn(BATCH_SIZE, NUM_INPUTS)
    y, logdets = m1(x)
    z, inv_logdets = m1(y, mode='inverse')
    self.assertTrue((logdets + inv_logdets).abs().max() < EPS,
                    'BatchNorm Det is not zero in eval.')
    self.assertTrue((x - z).abs().max() < EPS,
                    'BatchNorm is wrong in eval.')
def load_p_Y_model():
    """Construct the unconditional MAF density model for p(Y) and move it
    to the module-level `device`.

    Relies on module-level globals: `central_pixel` (mask whose sum gives
    the input dimensionality), `num_blocks`, `num_hidden`, `act`, and
    `device`. Returns an uninitialized `fnn.FlowSequential`.
    """
    dim = np.sum(central_pixel)
    layers = []
    for _ in range(num_blocks):
        # Each block: autoregressive layer, batch-norm flow, then a
        # dimension-reversing permutation. num_cond_inputs=None makes the
        # flow unconditional.
        layers.append(fnn.MADE(dim, num_hidden, None, act=act))
        layers.append(fnn.BatchNormFlow(dim))
        layers.append(fnn.Reverse(dim))
    flow = fnn.FlowSequential(*layers)
    flow.to(device)
    return flow
def testSequentialBN(self):
    """Round-trip test for a BatchNorm + InvertibleMM + CouplingLayer
    stack: forward followed by inverse must recover the input and the
    forward/inverse log-determinants must cancel, across two train-mode
    runs and one eval-mode run."""
    flow = fnn.FlowSequential(fnn.BatchNormFlow(NUM_INPUTS),
                              fnn.InvertibleMM(NUM_INPUTS),
                              fnn.CouplingLayer(NUM_INPUTS, NUM_HIDDEN))
    flow.train()

    def roundtrip(det_msg, inv_msg):
        # One forward/inverse pass on fresh random data.
        data = torch.randn(BATCH_SIZE, NUM_INPUTS)
        out, fwd_logdets = flow(data)
        back, bwd_logdets = flow(out, mode='inverse')
        self.assertTrue((fwd_logdets + bwd_logdets).abs().max() < EPS,
                        det_msg)
        self.assertTrue((data - back).abs().max() < EPS, inv_msg)

    roundtrip('Sequential BN Det is not zero.',
              'Sequential BN is wrong.')
    # Second run (updates BatchNorm running statistics again).
    roundtrip('Sequential BN Det is not zero for the second run.',
              'Sequential BN is wrong for the second run.')
    # Eval run uses the accumulated running statistics.
    flow.eval()
    roundtrip('Sequential BN Det is not zero for the eval run.',
              'Sequential BN is wrong for the eval run.')
'MNIST': 1024 }[args.dataset] act = 'tanh' if args.dataset is 'GAS' else 'relu' modules = [] assert args.flow in ['maf', 'maf-split', 'maf-split-glow', 'realnvp', 'glow'] if args.flow == 'glow': mask = torch.arange(0, num_inputs) % 2 mask = mask.to(device).float() print("Warning: Results for GLOW are not as good as for MAF yet.") for _ in range(args.num_blocks): modules += [ fnn.BatchNormFlow(num_inputs), fnn.LUInvertibleMM(num_inputs), fnn.CouplingLayer(num_inputs, num_hidden, mask, num_cond_inputs, s_act='tanh', t_act='relu') ] mask = 1 - mask elif args.flow == 'realnvp': mask = torch.arange(0, num_inputs) % 2 mask = mask.to(device).float() for _ in range(args.num_blocks): modules += [
def init_model(args, num_inputs=72):
    """Build a normalizing-flow model (maf / realnvp / glow) from CLI args.

    Parameters
    ----------
    args :
        Parsed arguments; uses ``args.no_cuda``, ``args.device``,
        ``args.num_hidden``, ``args.flow``, ``args.num_blocks``.
    num_inputs : int, optional
        Dimensionality of the modeled variable (default 72).

    Returns
    -------
    fnn.FlowSequential
        The constructed (unconditional) flow, moved to the selected device.
    """
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device
        device = torch.device("cuda:" + args.device)
    else:
        device = torch.device("cpu")

    # network structure
    num_hidden = args.num_hidden
    num_cond_inputs = None  # unconditional flow
    act = 'relu'
    assert act in ['relu', 'sigmoid', 'tanh']

    modules = []
    # normalization flow
    assert args.flow in ['maf', 'realnvp', 'glow']
    if args.flow == 'glow':
        # Alternating binary mask for the coupling layers.
        mask = torch.arange(0, num_inputs) % 2
        mask = mask.to(device).float()
        print("Warning: Results for GLOW are not as good as for MAF yet.")
        for _ in range(args.num_blocks):
            modules += [
                fnn.BatchNormFlow(num_inputs),
                fnn.LUInvertibleMM(num_inputs),
                fnn.CouplingLayer(num_inputs, num_hidden, mask,
                                  num_cond_inputs,
                                  s_act='tanh', t_act='relu')
            ]
            mask = 1 - mask  # flip which half is transformed next block
    elif args.flow == 'realnvp':
        mask = torch.arange(0, num_inputs) % 2
        mask = mask.to(device).float()
        for _ in range(args.num_blocks):
            modules += [
                fnn.CouplingLayer(num_inputs, num_hidden, mask,
                                  num_cond_inputs,
                                  s_act='tanh', t_act='relu'),
                fnn.BatchNormFlow(num_inputs)
            ]
            mask = 1 - mask
    elif args.flow == 'maf':
        for _ in range(args.num_blocks):
            modules += [
                fnn.MADE(num_inputs, num_hidden, num_cond_inputs, act=act),
                fnn.BatchNormFlow(num_inputs),
                fnn.Reverse(num_inputs)
            ]
    model = fnn.FlowSequential(*modules)

    # Orthogonal init for linear layers; zero biases where present.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                module.bias.data.fill_(0)

    # BUG FIX: the original never moved the model to `device`, although the
    # coupling masks above were moved there — leaving mask buffers and
    # weights on different devices when CUDA is selected. Siblings in this
    # file (build_model, load_CNF, train_CNF) all call model.to(device).
    model.to(device)
    return model
def train_CNF(thetas, condit, Ntrain, num_blocks=10, epochs=10000,
              batch_size=100, test_batch_size=1000, num_hidden=1024,
              act='relu', learning_rate=1e-4, fbest=None, fcheck=None,
              fmodel=None):
    ''' train conditional normalizing flow given training (variable,
    conditional variable) pairs.

    thetas : (N, num_inputs) numpy array of modeled variables; the first
        Ntrain rows are used for training and the rest for validation.
    condit : (N, num_cond_inputs) numpy array of conditioning variables,
        or None for an unconditional flow.
    Ntrain : number of rows of `thetas` used for training.
    num_blocks, num_hidden, act : MAF architecture hyperparameters.
    epochs, batch_size, test_batch_size, learning_rate : optimization
        settings; training also stops early after 30 epochs without
        validation improvement.
    fbest, fcheck, fmodel : file paths for the best state_dict, a training
        checkpoint, and the final state_dict respectively. NOTE(review):
        torch.save is called unconditionally, so passing the default None
        for any of these would fail — confirm callers always supply paths.

    Returns (model, device): the final (not necessarily best) model and
    the device it lives on; the best model is persisted to `fbest`.
    '''
    import copy
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.utils.data
    import flows as fnn
    #############################################################################
    cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if cuda else "cpu")
    # Fixed seed for reproducible init and shuffling.
    seed = 12387
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)
        kwargs = {'num_workers': 4, 'pin_memory': True}
    else:
        kwargs = {}
    #############################################################################
    print('Ntrain = %i; Nvalid = %i' % (Ntrain, thetas.shape[0] - Ntrain))
    # Split into train/validation tensors; pair with conditioning data
    # when it is provided.
    train_tensor = torch.from_numpy(thetas[:Ntrain, :])
    if condit is not None:
        train_cond = torch.from_numpy(condit[:Ntrain, :])
        train_dataset = torch.utils.data.TensorDataset(train_tensor,
                                                       train_cond)
    else:
        train_dataset = torch.utils.data.TensorDataset(train_tensor)
    valid_tensor = torch.from_numpy(thetas[Ntrain:, :])
    if condit is not None:
        valid_cond = torch.from_numpy(condit[Ntrain:, :])
        valid_dataset = torch.utils.data.TensorDataset(valid_tensor,
                                                       valid_cond)
    else:
        valid_dataset = torch.utils.data.TensorDataset(valid_tensor)
    # number of conditional inputs
    num_inputs = thetas.shape[1]
    if condit is not None:
        num_cond_inputs = condit.shape[1]
    else:
        num_cond_inputs = None
    # set up loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=test_batch_size,
                                               shuffle=False,
                                               drop_last=False,
                                               **kwargs)
    #############################################################################
    # set up MAF: each block is MADE -> BatchNormFlow -> Reverse.
    modules = []
    for _ in range(num_blocks):
        modules += [
            fnn.MADE(num_inputs, num_hidden, num_cond_inputs, act=act),
            fnn.BatchNormFlow(num_inputs),
            fnn.Reverse(num_inputs)
        ]
    model = fnn.FlowSequential(*modules)
    # Orthogonal init for linear layers; zero biases where present.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                module.bias.data.fill_(0)
    model.to(device)
    #############################################################################
    # train
    #############################################################################
    # adam optimizer
    optimizer = optim.Adam(model.parameters(), lr=learning_rate,
                           weight_decay=1e-6)

    def train(epoch):
        # One epoch of maximum-likelihood training (minimize -log p).
        model.train()
        train_loss = 0
        for batch_idx, data in enumerate(train_loader):
            # TensorDataset yields a list; element 1 (if present) is the
            # conditioning variable.
            if isinstance(data, list):
                if len(data) > 1:
                    cond_data = data[1].float()
                    cond_data = cond_data.to(device)
                else:
                    cond_data = None
                data = data[0]
            else:
                cond_data = None
            data = data.to(device)
            optimizer.zero_grad()
            loss = -model.log_probs(data, cond_data).mean()
            train_loss += loss.item()
            loss.backward()
            optimizer.step()
        # Recompute BatchNormFlow statistics over the FULL training set:
        # momentum=0 makes the running stats equal to this single pass;
        # momentum is then restored to 1 for the next epoch. (This mirrors
        # the standard MAF training recipe.)
        for module in model.modules():
            if isinstance(module, fnn.BatchNormFlow):
                module.momentum = 0
        with torch.no_grad():
            if condit is not None:
                model(train_loader.dataset.tensors[0].to(data.device),
                      train_loader.dataset.tensors[1].to(data.device))
            else:
                model(train_loader.dataset.tensors[0].to(data.device))
        for module in model.modules():
            if isinstance(module, fnn.BatchNormFlow):
                module.momentum = 1

    def validate(epoch, model, loader, prefix='Validation'):
        # Average negative log-likelihood over the whole loader.
        model.eval()
        val_loss = 0
        for batch_idx, data in enumerate(loader):
            if isinstance(data, list):
                if len(data) > 1:
                    cond_data = data[1].float()
                    cond_data = cond_data.to(device)
                else:
                    cond_data = None
                data = data[0]
            else:
                cond_data = None
            data = data.to(device)
            with torch.no_grad():
                val_loss += -model.log_probs(
                    data, cond_data).sum().item()  # sum up batch loss
        print('validation loss: %f' % (val_loss / len(loader.dataset)))
        return val_loss / len(loader.dataset)

    best_validation_loss = float('inf')
    best_validation_epoch = 0
    best_model = model
    for epoch in range(epochs):
        print('\nEpoch: {}'.format(epoch))
        train(epoch)
        validation_loss = validate(epoch, model, valid_loader)
        # Early stopping: 30 epochs without improvement.
        # NOTE(review): this check runs BEFORE the best-model update
        # below, so an improvement arriving exactly at the patience
        # boundary is discarded — matches the upstream recipe, but worth
        # confirming it is intentional.
        if epoch - best_validation_epoch >= 30:
            break
        if not np.isfinite(validation_loss):
            print("validation loss is NAN")
            break
        if validation_loss < best_validation_loss:
            best_validation_epoch = epoch
            best_validation_loss = validation_loss
            best_model = copy.deepcopy(model)
            # Persist the best model as soon as it improves.
            torch.save(best_model.state_dict(), fbest)
        print(
            'Best validation at epoch {}: Average Log Likelihood in nats: {:.4f}'
            .format(best_validation_epoch, -best_validation_loss))
    # save training checkpoint
    torch.save(
        {
            'epoch': epoch,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict()
        }, fcheck)
    # save model only
    torch.save(model.state_dict(), fmodel)
    torch.save(best_model.state_dict(), fbest)
    return model, device