Example #1
    def build_model():
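        # num_inputs, num_hidden, num_cond_inputs, args, and device come
        # from the enclosing scope of the original script.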
        modules = []

        mask = torch.arange(0, num_inputs) % 2
        #mask = torch.ones(num_inputs)
        #mask[round(num_inputs/2):] = 0
        mask = mask.to(device).float()

        # build each module
        for _ in range(args.num_blocks):
            modules += [
                fnn.ActNorm(num_inputs),
                fnn.LUInvertibleMM(num_inputs),
                fnn.CouplingLayer(num_inputs,
                                  num_hidden,
                                  mask,
                                  num_cond_inputs,
                                  s_act='tanh',
                                  t_act='relu')
            ]
            mask = 1 - mask

        # build model
        model = fnn.FlowSequential(*modules)

        # initialize
        for module in model.modules():
            if isinstance(module, nn.Linear):
                nn.init.orthogonal_(module.weight)
                if hasattr(module, 'bias') and module.bias is not None:
                    module.bias.data.fill_(0)

        model.to(device)

        return model
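A minimal usage sketch for the model built above; it assumes the
FlowSequential interface exercised in the test examples below (forward
returns (outputs, logdets), mode='inverse' inverts the flow, and log_probs
accepts optional conditioning inputs), and the batch shapes are
illustrative only:

model = build_model()
x = torch.randn(64, num_inputs, device=device)
cond = torch.randn(64, num_cond_inputs, device=device)

y, logdets = model(x, cond)                          # forward: latents + log|det J|
x_rec, inv_logdets = model(y, cond, mode='inverse')  # inverse recovers x
nll = -model.log_probs(x, cond).mean()               # negative log-likelihood loss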
Example #2
    def __init__(self, input_size, channels_h, K, L, save_memory=False):
        super(Glow, self).__init__()
        self.L = L
        self.save_memory = save_memory
        self.output_sizes = []
        blocks = []
        c, h, w = input_size
        for l in range(L):
            block = [flows.Squeeze()]
            c *= 4
            h //= 2
            w //= 2  # squeeze: space-to-depth reshaping
            for _ in range(K):
                norm_layer = flows.ActNorm(c)
                if save_memory:
                    # easily invertible alternative to the 1x1 convolution
                    perm_layer = flows.RandomRotation(c)
                else:
                    perm_layer = flows.InversibleConv1x1(c)
                coupling_layer = flows.AffineCouplingLayer(c, channels_h)
                block += [norm_layer, perm_layer, coupling_layer]
            blocks.append(flows.FlowSequential(*block))
            self.output_sizes.append((c, h, w))
            c //= 2  # split: half the channels are factored out
        self.blocks = nn.ModuleList(blocks)
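The shape bookkeeping above is easy to trace by hand; this standalone
sketch (with a hypothetical input_size=(3, 32, 32) and L=3) reproduces how
output_sizes is filled, without needing the flows package:

c, h, w = 3, 32, 32
output_sizes = []
for l in range(3):
    c, h, w = c * 4, h // 2, w // 2  # squeeze: space-to-depth
    output_sizes.append((c, h, w))
    c //= 2                          # split halves the channels
print(output_sizes)  # [(12, 16, 16), (24, 8, 8), (48, 4, 4)]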
Example #3
def build_model(num_blocks,
                num_inputs,
                num_hidden,
                K,
                M,
                lr,
                device=torch.device("cpu"),
                use_bn=True):
    modules = []
    for _ in range(num_blocks):
        if use_bn:
            modules += [
                gmaf.SumSqMAF(num_inputs, num_hidden, K, M),
                fnn.BatchNormFlow(num_inputs),
                fnn.Reverse(num_inputs)
            ]
        else:
            modules += [
                gmaf.SumSqMAF(num_inputs, num_hidden, K, M),
                fnn.Reverse(num_inputs)
            ]
    model = fnn.FlowSequential(*modules)
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight)
            if module.bias is not None:
                module.bias.data.fill_(0)

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)

    return model, optimizer
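A hedged single-step training sketch for the returned (model, optimizer)
pair; the constructor arguments and batch shape are placeholders, and it
assumes the same log_probs interface used elsewhere in these examples:

model, optimizer = build_model(num_blocks=5, num_inputs=8,
                               num_hidden=64, K=4, M=4, lr=1e-4)
batch = torch.randn(128, 8)
optimizer.zero_grad()
loss = -model.log_probs(batch, None).mean()  # negative log-likelihood
loss.backward()
optimizer.step()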
Example #4
    def testSigmoid(self):
        m1 = fnn.FlowSequential(fnn.Sigmoid())

        x = torch.randn(BATCH_SIZE, NUM_INPUTS)

        y, logdets = m1(x)
        z, inv_logdets = m1(y, mode='inverse')

        self.assertTrue((logdets + inv_logdets).abs().max() < EPS,
                        'Sigmoid Det is not zero.')
        self.assertTrue((x - z).abs().max() < EPS, 'Sigmoid is wrong.')
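Examples #4 through #11 all repeat the same round-trip check; a small
helper capturing that shared pattern might look like this
(assert_invertible is hypothetical, not part of the library):

def assert_invertible(flow, x, eps=EPS):
    """Forward, invert, and check that x is recovered and log-dets cancel."""
    y, logdets = flow(x)
    z, inv_logdets = flow(y, mode='inverse')
    assert (logdets + inv_logdets).abs().max() < eps, 'log-dets do not cancel'
    assert (x - z).abs().max() < eps, 'inverse does not recover the input'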
Example #5
    def testInv(self):
        m1 = fnn.FlowSequential(fnn.InvertibleMM(NUM_INPUTS))

        x = torch.randn(BATCH_SIZE, NUM_INPUTS)

        y, logdets = m1(x)
        z, inv_logdets = m1(y, mode='inverse')

        self.assertTrue((logdets + inv_logdets).abs().max() < EPS,
                        'InvMM Det is not zero.')
        self.assertTrue((x - z).abs().max() < EPS, 'InvMM is wrong.')
Example #6
    def testCoupling(self):
        m1 = fnn.FlowSequential(fnn.CouplingLayer(NUM_INPUTS, NUM_HIDDEN))

        x = torch.randn(BATCH_SIZE, NUM_INPUTS)

        y, logdets = m1(x)
        z, inv_logdets = m1(y, mode='inverse')

        self.assertTrue((logdets + inv_logdets).abs().max() < EPS,
                        'CouplingLayer Det is not zero.')
        self.assertTrue((x - z).abs().max() < EPS, 'CouplingLayer is wrong')
Example #7
def load_CNF(fbest,
             num_inputs=None,
             num_cond_inputs=None,
             num_blocks=10,
             num_hidden=1024,
             act='relu'):
    ''' Load a trained conditional normalizing flow from the checkpoint
    file `fbest`.
    '''
    import copy
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.utils.data

    import flows as fnn
    #############################################################################
    cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if cuda else "cpu")

    seed = 12387
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)
        kwargs = {'num_workers': 4, 'pin_memory': True}
    else:
        kwargs = {}
    #############################################################################
    # set up MAF
    modules = []
    for _ in range(num_blocks):
        modules += [
            fnn.MADE(num_inputs, num_hidden, num_cond_inputs, act=act),
            fnn.BatchNormFlow(num_inputs),
            fnn.Reverse(num_inputs)
        ]
    model = fnn.FlowSequential(*modules)

    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                module.bias.data.fill_(0)

    model.to(device)
    #############################################################################
    # load best model
    #############################################################################
    state = torch.load(fbest, map_location=device)
    model.load_state_dict(state)
    model.num_inputs = num_inputs
    return model, device
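A usage sketch for load_CNF; the checkpoint path and the input and
conditioning dimensions are hypothetical and must match those used when
the flow was trained:

model, device = load_CNF('flow_best.pt', num_inputs=5, num_cond_inputs=3)
model.eval()
theta = torch.randn(10, 5).to(device)
cond = torch.randn(10, 3).to(device)
with torch.no_grad():
    log_p = model.log_probs(theta, cond)  # per-sample log-densities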
Example #8
    def testBatchNorm(self):
        m1 = fnn.FlowSequential(fnn.BatchNormFlow(NUM_INPUTS))
        m1.train()

        x = torch.randn(BATCH_SIZE, NUM_INPUTS)

        y, logdets = m1(x)
        z, inv_logdets = m1(y, mode='inverse')

        self.assertTrue((logdets + inv_logdets).abs().max() < EPS,
                        'BatchNorm Det is not zero.')
        self.assertTrue((x - z).abs().max() < EPS, 'BatchNorm is wrong.')

        # Second run.
        x = torch.randn(BATCH_SIZE, NUM_INPUTS)

        y, logdets = m1(x)
        z, inv_logdets = m1(y, mode='inverse')

        self.assertTrue((logdets + inv_logdets).abs().max() < EPS,
                        'BatchNorm Det is not zero for the second run.')
        self.assertTrue((x - z).abs().max() < EPS,
                        'BatchNorm is wrong for the second run.')

        m1 = fnn.FlowSequential(fnn.BatchNormFlow(NUM_INPUTS))
        m1.eval()

        x = torch.randn(BATCH_SIZE, NUM_INPUTS)

        y, logdets = m1(x)
        z, inv_logdets = m1(y, mode='inverse')

        self.assertTrue((logdets + inv_logdets).abs().max() < EPS,
                        'BatchNorm Det is not zero in eval.')
        self.assertTrue((x - z).abs().max() < EPS,
                        'BatchNorm is wrong in eval.')
Example #9
def load_p_Y_model():
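    # central_pixel, num_blocks, num_hidden, act, device, np, and fnn come
    # from the enclosing module scope of the original script.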
    num_inputs = np.sum(central_pixel)
    num_cond_inputs = None

    modules = []
    for _ in range(num_blocks):
        modules += [
            fnn.MADE(num_inputs, num_hidden, num_cond_inputs, act=act),
            fnn.BatchNormFlow(num_inputs),
            fnn.Reverse(num_inputs)
        ]

    model = fnn.FlowSequential(*modules)
    model.to(device)
    return model
Example #10
    def testSequentialBN(self):
        m1 = fnn.FlowSequential(fnn.BatchNormFlow(NUM_INPUTS),
                                fnn.InvertibleMM(NUM_INPUTS),
                                fnn.CouplingLayer(NUM_INPUTS, NUM_HIDDEN))

        m1.train()
        x = torch.randn(BATCH_SIZE, NUM_INPUTS)

        y, logdets = m1(x)
        z, inv_logdets = m1(y, mode='inverse')

        self.assertTrue((logdets + inv_logdets).abs().max() < EPS,
                        'Sequential BN Det is not zero.')
        self.assertTrue((x - z).abs().max() < EPS, 'Sequential BN is wrong.')

        # Second run.
        x = torch.randn(BATCH_SIZE, NUM_INPUTS)

        y, logdets = m1(x)
        z, inv_logdets = m1(y, mode='inverse')

        self.assertTrue((logdets + inv_logdets).abs().max() < EPS,
                        'Sequential BN Det is not zero for the second run.')
        self.assertTrue((x - z).abs().max() < EPS,
                        'Sequential BN is wrong for the second run.')

        m1.eval()
        # Eval run.
        x = torch.randn(BATCH_SIZE, NUM_INPUTS)

        y, logdets = m1(x)
        z, inv_logdets = m1(y, mode='inverse')

        self.assertTrue((logdets + inv_logdets).abs().max() < EPS,
                        'Sequential BN Det is not zero for the eval run.')
        self.assertTrue((x - z).abs().max() < EPS,
                        'Sequential BN is wrong for the eval run.')
Example #11
    def testSequential(self):
        m1 = fnn.FlowSequential(
            fnn.ActNorm(NUM_INPUTS), fnn.InvertibleMM(NUM_INPUTS),
            fnn.CouplingLayer(NUM_INPUTS, NUM_HIDDEN, mask))

        x = torch.randn(BATCH_SIZE, NUM_INPUTS)

        y, logdets = m1(x)
        z, inv_logdets = m1(y, mode='inverse')

        self.assertTrue((logdets + inv_logdets).abs().max() < EPS,
                        'Sequential Det is not zero.')
        self.assertTrue((x - z).abs().max() < EPS, 'Sequential is wrong.')

        # Second run.
        x = torch.randn(BATCH_SIZE, NUM_INPUTS)

        y, logdets = m1(x)
        z, inv_logdets = m1(y, mode='inverse')

        self.assertTrue((logdets + inv_logdets).abs().max() < EPS,
                        'Sequential Det is not zero for the second run.')
        self.assertTrue((x - z).abs().max() < EPS,
                        'Sequential is wrong for the second run.')
Example #12
# NOTE: this snippet starts mid-file; the `elif args.flow == 'maf-split'`
# head below is reconstructed from the surrounding branches and is an
# assumption, not verbatim source.
elif args.flow == 'maf-split':
    for _ in range(args.num_blocks):
        modules += [
            fnn.MADESplit(num_inputs,
                          num_hidden,
                          num_cond_inputs,
                          s_act='tanh',
                          t_act='relu'),
            fnn.BatchNormFlow(num_inputs),
            fnn.Reverse(num_inputs)
        ]
elif args.flow == 'maf-split-glow':
    for _ in range(args.num_blocks):
        modules += [
            fnn.MADESplit(num_inputs,
                          num_hidden,
                          num_cond_inputs,
                          s_act='tanh',
                          t_act='relu'),
            fnn.BatchNormFlow(num_inputs),
            fnn.InvertibleMM(num_inputs)
        ]

model = fnn.FlowSequential(*modules)

for module in model.modules():
    if isinstance(module, nn.Linear):
        nn.init.orthogonal_(module.weight)
        if hasattr(module, 'bias') and module.bias is not None:
            module.bias.data.fill_(0)

model.to(device)

optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-6)

writer = SummaryWriter(comment=args.flow + "_" + args.dataset)
global_step = 0

Example #13
def init_model(args, num_inputs=72):
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    if args.cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device
        device = torch.device("cuda:" + args.device)
    else:
        device = torch.device("cpu")
    # network structure
    num_hidden = args.num_hidden
    num_cond_inputs = None

    act = 'relu'
    assert act in ['relu', 'sigmoid', 'tanh']

    modules = []

    # normalizing flow type
    assert args.flow in ['maf', 'realnvp', 'glow']

    if args.flow == 'glow':
        mask = torch.arange(0, num_inputs) % 2
        mask = mask.to(device).float()

        print("Warning: Results for GLOW are not as good as for MAF yet.")
        for _ in range(args.num_blocks):
            modules += [
                fnn.BatchNormFlow(num_inputs),
                fnn.LUInvertibleMM(num_inputs),
                fnn.CouplingLayer(num_inputs,
                                  num_hidden,
                                  mask,
                                  num_cond_inputs,
                                  s_act='tanh',
                                  t_act='relu')
            ]
            mask = 1 - mask

    elif args.flow == 'realnvp':
        mask = torch.arange(0, num_inputs) % 2
        mask = mask.to(device).float()

        for _ in range(args.num_blocks):
            modules += [
                fnn.CouplingLayer(num_inputs,
                                  num_hidden,
                                  mask,
                                  num_cond_inputs,
                                  s_act='tanh',
                                  t_act='relu'),
                fnn.BatchNormFlow(num_inputs)
            ]
            mask = 1 - mask

    elif args.flow == 'maf':
        for _ in range(args.num_blocks):
            modules += [
                fnn.MADE(num_inputs, num_hidden, num_cond_inputs, act=act),
                fnn.BatchNormFlow(num_inputs),
                fnn.Reverse(num_inputs)
            ]

    model = fnn.FlowSequential(*modules)

    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                module.bias.data.fill_(0)

    return model
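Since init_model reads its configuration from args, a sketch with an
argparse-style namespace (all field values hypothetical) shows the minimal
fields it expects:

from argparse import Namespace

args = Namespace(no_cuda=True, device='0', num_hidden=64,
                 num_blocks=5, flow='maf')
model = init_model(args, num_inputs=72)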
Example #14
def train_CNF(thetas,
              condit,
              Ntrain,
              num_blocks=10,
              epochs=10000,
              batch_size=100,
              test_batch_size=1000,
              num_hidden=1024,
              act='relu',
              learning_rate=1e-4,
              fbest=None,
              fcheck=None,
              fmodel=None):
    ''' Train a conditional normalizing flow on (variable,
    conditional variable) training pairs.
    '''
    import copy
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.utils.data

    import flows as fnn
    #############################################################################
    cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if cuda else "cpu")

    seed = 12387
    torch.manual_seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)
        kwargs = {'num_workers': 4, 'pin_memory': True}
    else:
        kwargs = {}
    #############################################################################
    print('Ntrain = %i; Nvalid = %i' % (Ntrain, thetas.shape[0] - Ntrain))
    train_tensor = torch.from_numpy(thetas[:Ntrain, :])
    if condit is not None:
        train_cond = torch.from_numpy(condit[:Ntrain, :])
        train_dataset = torch.utils.data.TensorDataset(train_tensor,
                                                       train_cond)
    else:
        train_dataset = torch.utils.data.TensorDataset(train_tensor)

    valid_tensor = torch.from_numpy(thetas[Ntrain:, :])
    if condit is not None:
        valid_cond = torch.from_numpy(condit[Ntrain:, :])
        valid_dataset = torch.utils.data.TensorDataset(valid_tensor,
                                                       valid_cond)
    else:
        valid_dataset = torch.utils.data.TensorDataset(valid_tensor)

    # number of conditional inputs
    num_inputs = thetas.shape[1]
    if condit is not None:
        num_cond_inputs = condit.shape[1]
    else:
        num_cond_inputs = None

    # set up loaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=test_batch_size,
                                               shuffle=False,
                                               drop_last=False,
                                               **kwargs)

    #############################################################################
    # set up MAF
    modules = []
    for _ in range(num_blocks):
        modules += [
            fnn.MADE(num_inputs, num_hidden, num_cond_inputs, act=act),
            fnn.BatchNormFlow(num_inputs),
            fnn.Reverse(num_inputs)
        ]
    model = fnn.FlowSequential(*modules)

    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight)
            if hasattr(module, 'bias') and module.bias is not None:
                module.bias.data.fill_(0)

    model.to(device)

    #############################################################################
    # train
    #############################################################################
    # adam optimizer
    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=1e-6)

    def train(epoch):
        model.train()
        train_loss = 0

        for batch_idx, data in enumerate(train_loader):
            if isinstance(data, list):
                if len(data) > 1:
                    cond_data = data[1].float()
                    cond_data = cond_data.to(device)
                else:
                    cond_data = None

                data = data[0]
            else:
                cond_data = None

            data = data.to(device)
            optimizer.zero_grad()
            loss = -model.log_probs(data, cond_data).mean()
            train_loss += loss.item()
            loss.backward()
            optimizer.step()

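        # Setting BatchNorm momentum to 0 makes the following full-data
        # forward pass overwrite the running statistics with values computed
        # over the whole training set; momentum is then set to 1 to freeze
        # those statistics for the subsequent validation pass.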
        for module in model.modules():
            if isinstance(module, fnn.BatchNormFlow):
                module.momentum = 0

        with torch.no_grad():
            if condit is not None:
                model(train_loader.dataset.tensors[0].to(data.device),
                      train_loader.dataset.tensors[1].to(data.device))
            else:
                model(train_loader.dataset.tensors[0].to(data.device))

        for module in model.modules():
            if isinstance(module, fnn.BatchNormFlow):
                module.momentum = 1

    def validate(epoch, model, loader, prefix='Validation'):

        model.eval()
        val_loss = 0

        for batch_idx, data in enumerate(loader):
            if isinstance(data, list):
                if len(data) > 1:
                    cond_data = data[1].float()
                    cond_data = cond_data.to(device)
                else:
                    cond_data = None
                data = data[0]
            else:
                cond_data = None

            data = data.to(device)

            with torch.no_grad():
                val_loss += -model.log_probs(
                    data, cond_data).sum().item()  # sum up batch loss

        print('validation loss: %f' % (val_loss / len(loader.dataset)))
        return val_loss / len(loader.dataset)

    best_validation_loss = float('inf')
    best_validation_epoch = 0
    best_model = model

    for epoch in range(epochs):
        print('\nEpoch: {}'.format(epoch))

        train(epoch)
        validation_loss = validate(epoch, model, valid_loader)

        if epoch - best_validation_epoch >= 30:
            break
        if not np.isfinite(validation_loss):
            print("validation loss is NAN")
            break

        if validation_loss < best_validation_loss:
            best_validation_epoch = epoch
            best_validation_loss = validation_loss
            best_model = copy.deepcopy(model)
            torch.save(best_model.state_dict(), fbest)

        print(
            'Best validation at epoch {}: Average Log Likelihood in nats: {:.4f}'
            .format(best_validation_epoch, -best_validation_loss))

        # save training checkpoint
        torch.save(
            {
                'epoch': epoch,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict()
            }, fcheck)
        # save model only
        torch.save(model.state_dict(), fmodel)

    torch.save(best_model.state_dict(), fbest)
    return model, device
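Finally, a hedged invocation sketch for train_CNF with synthetic NumPy
data; the shapes, hyperparameters, and file names are placeholders:

import numpy as np

# float32 to match the model's default dtype
thetas = np.random.randn(1200, 5).astype(np.float32)  # variables
condit = np.random.randn(1200, 3).astype(np.float32)  # conditioning variables
model, device = train_CNF(thetas, condit, Ntrain=1000,
                          num_blocks=5, epochs=100, num_hidden=128,
                          fbest='flow_best.pt', fcheck='flow_check.pt',
                          fmodel='flow_model.pt')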