def constuct_inn(img_dims=(3, 32, 32)):
    """Construct a convolutional INN for image data.

    Relies on module-level configuration (``n_coupling_blocks_conv_0``,
    ``n_coupling_blocks_conv_1``, ``n_coupling_blocks_fc``, the subnet
    constructors, ``clamp``, ``random_orthog``) and the FrEIA aliases
    ``fr``/``la``.

    :param img_dims: (channels, height, width) of the input images.
        Default: CIFAR-sized 3x32x32.  (Changed from a mutable list
        default to a tuple to avoid the shared-default pitfall;
        indexing and unpacking callers are unaffected.)
    :return: fr.ReversibleGraphNet implementing the INN.

    NOTE(review): the function name is misspelled ("constuct") but is
    kept unchanged so existing callers keep working.
    """
    nodes = [fr.InputNode(*img_dims, name='input')]

    # The reshape is a no-op on the nominal input shape but makes the
    # expected dimensionality explicit in the graph.
    nodes.append(
        fr.Node(nodes[-1].out0,
                la.Reshape, {'target_dim': tuple(img_dims)},
                name='reshape'))
    nodes.append(
        fr.Node(nodes[-1].out0,
                la.HaarDownsampling, {'rebalance': 0.5},
                name='haar_down_1'))

    # First convolutional stage.
    for k in range(n_coupling_blocks_conv_0):
        nodes.append(
            fr.Node(nodes[-1],
                    la.GLOWCouplingBlock, {
                        'subnet_constructor': conv_constr_0,
                        'clamp': clamp
                    },
                    name=f'CB CONV_0_{k}'))
        # Channel mixing after every second coupling block.
        # NOTE(review): the hard-coded sizes 12 / 48 below appear to
        # assume 3 input channels (3*4 and 3*16 after one/two Haar
        # steps) — confirm before using other channel counts.
        if k % 2:
            nodes.append(
                fr.Node(nodes[-1], la.Fixed1x1Conv, {'M': random_orthog(12)}))

    nodes.append(
        fr.Node(nodes[-1].out0,
                la.HaarDownsampling, {'rebalance': 0.5},
                name='haar_down_2'))

    # Second convolutional stage.
    for k in range(n_coupling_blocks_conv_1):
        nodes.append(
            fr.Node(nodes[-1],
                    la.GLOWCouplingBlock, {
                        'subnet_constructor': conv_constr_1,
                        'clamp': clamp
                    },
                    name=f'CB CONV_1_{k}'))
        if k % 2:
            nodes.append(
                fr.Node(nodes[-1], la.Fixed1x1Conv, {'M': random_orthog(48)}))

    nodes.append(fr.Node(nodes[-1].out0, la.Flatten, {}, name='flatten'))

    # Fully-connected stage with fixed, seeded random permutations.
    for k in range(n_coupling_blocks_fc):
        nodes.append(
            fr.Node(nodes[-1],
                    la.GLOWCouplingBlock, {
                        'subnet_constructor': fc_constr,
                        'clamp': clamp
                    },
                    name=f'CB FC_{k}'))
        if k % 2:
            nodes.append(fr.Node(nodes[-1], la.PermuteRandom, {'seed': k}))

    nodes.append(fr.OutputNode(nodes[-1], name='output'))

    return fr.ReversibleGraphNet(nodes)
Esempio n. 2
0
    def construct_inn(self, input, backbone: InvertibleArchitecture,
                      head: InvertibleArchitecture):
        """Chain an invertible backbone and head into one ReversibleGraphNet.

        Side effect: stores the assembled network in ``self.model`` and
        prints it.

        :param input: the already-constructed FrEIA input node.
            NOTE(review): this parameter shadows the builtin ``input``.
        :param backbone: invertible backbone; its ``construct_inn`` returns
            (nodes, split_nodes, skip_connections).
        :param head: invertible head; its ``construct_inn`` returns
            (nodes, split_nodes).
        :return: the list of computation nodes (without the split nodes).
        """

        nodes = []
        split_nodes = []

        nodes.append(input)

        backbone_nodes, backbone_split_nodes, skip_connections = backbone.construct_inn(
            nodes[-1])

        nodes += backbone_nodes

        if skip_connections:
            print("HAS SKIP CONNECTION")
            # With skip connections the head also consumes the backbone's
            # skip outputs.
            head_nodes, head_split_nodes = head.construct_inn(
                nodes[-1], skip_connections)
            # NOTE(review): backbone split nodes are only registered on
            # this branch — confirm the no-skip case really has none.
            split_nodes += backbone_split_nodes
        else:
            head_nodes, head_split_nodes = head.construct_inn(nodes[-1])

        # NOTE(review): the output node is appended before head_nodes are
        # merged, so it precedes them in the list; presumably the graph is
        # resolved by topology rather than list order — confirm.
        nodes.append(Ff.OutputNode(head_nodes[-1], name='out_fc'))

        nodes += head_nodes
        split_nodes += head_split_nodes

        self.model = Ff.ReversibleGraphNet(nodes + split_nodes, verbose=True)
        print(self.model)

        return nodes
Esempio n. 3
0
    def build_inn(self):
        """Build a conditional INN of GLOW coupling blocks over 1-D data.

        Every coupling block gets its own condition node of size
        ``c.y_dim_features``; the condition nodes are passed to the
        ReversibleGraphNet alongside the computation nodes.
        """

        def make_subnet(width_in, width_out):
            # Two hidden fully-connected layers of c.internal_width units.
            return nn.Sequential(
                nn.Linear(width_in, c.internal_width), nn.LeakyReLU(),
                nn.Linear(c.internal_width, c.internal_width), nn.LeakyReLU(),
                nn.Linear(c.internal_width, width_out))

        graph = [Ff.InputNode(c.x_dim)]
        cond_nodes = []  # one condition input per coupling block

        for _ in range(c.n_blocks):
            cond = Ff.ConditionNode(c.y_dim_features)
            cond_nodes.append(cond)
            graph.append(Ff.Node(
                graph[-1], Fm.GLOWCouplingBlock,
                {'subnet_constructor': make_subnet,
                 'clamp': c.exponent_clamping},
                conditions=cond))

        graph.append(Ff.OutputNode(graph[-1]))

        return Ff.ReversibleGraphNet(graph + cond_nodes, verbose=False)
Esempio n. 4
0
def generate_rcINN_old():
    """Create the conditional INN plus its optimizer and LR scheduler.

    Builds config.n_blocks pairs of (GLOW coupling, random permutation),
    all conditioned on a single condition node, re-initializes every
    trainable weight with scale config.init_scale, and returns the
    (model, optimizer, scheduler) triple.
    """
    condition = Ff.ConditionNode(config.n_cond_features, name='condition')
    graph = [Ff.InputNode(config.n_x_features, name='input')]

    for i in range(config.n_blocks):
        coupling = Ff.Node(graph[-1],
                           Fm.GLOWCouplingBlock,
                           {'subnet_constructor': subnet_func,
                            'clamp': config.clamp},
                           conditions=condition,
                           name=F'coupling_{i}')
        graph.append(coupling)
        graph.append(Ff.Node(coupling,
                             Fm.PermuteRandom, {'seed': i},
                             name=F'permute_{i}'))

    graph.append(Ff.OutputNode(graph[-1], name='output'))
    model = Ff.ReversibleGraphNet(graph + [condition])
    model.to(device)

    # Re-initialize all trainable parameters with small Gaussian noise.
    params_trainable = [p for p in model.parameters() if p.requires_grad]
    for p in params_trainable:
        p.data = config.init_scale * torch.randn_like(p).to(device)

    optim = torch.optim.AdamW(params_trainable, lr=config.lr)
    weight_scheduler = torch.optim.lr_scheduler.StepLR(optim,
                                                       step_size=1,
                                                       gamma=config.gamma)

    return model, optim, weight_scheduler
Esempio n. 5
0
 def build_inn(self, nodes, split_nodes, conditions):
     """Finalize the INN: wrap all node lists in one ReversibleGraphNet.

     Side effects: sets ``self.model`` and ``self.trainable_parameters``
     and re-initializes every trainable weight with N(0, 0.02^2) noise.

     :param nodes: main computation nodes (input through output).
     :param split_nodes: side-branch nodes from Split layers.
     :param conditions: condition nodes for conditional coupling blocks.
     """
     self.model = Ff.ReversibleGraphNet(nodes + split_nodes + conditions,
                                        verbose=False)
     self.trainable_parameters = [
         p for p in self.model.parameters() if p.requires_grad
     ]
     # NOTE(review): 0.02 init scale — presumably tuned elsewhere; confirm.
     for p in self.trainable_parameters:
         p.data = 0.02 * torch.randn_like(p)
def get_invnet(ndim,
               nb_blocks=4,
               hidden_layer=128,
               small_block=False,
               permute=True,
               verbose=False):
    """Convenience wrapper: build the node list via get_nodes and wrap it
    in a ReversibleGraphNet."""
    node_list = get_nodes(ndim, nb_blocks, hidden_layer, small_block, permute)
    return freia_fw.ReversibleGraphNet(node_list, verbose=verbose)
Esempio n. 7
0
    def _inn(self):
        """Build an unconditional INN over flat feature vectors.

        The network is self.depth pairs of (random permutation,
        GLOW coupling block) on inputs of length self.feat_length.
        """

        def coupling_subnet(width_in, width_out):
            # Single hidden layer of 1024 ReLU units.
            return nn.Sequential(nn.Linear(width_in, 1024), nn.ReLU(),
                                 nn.Linear(1024, width_out))

        chain = [ff.InputNode(self.feat_length)]
        for depth_idx in range(self.depth):
            chain.append(ff.Node(chain[-1], fm.PermuteRandom,
                                 {'seed': depth_idx}))
            chain.append(ff.Node(chain[-1], fm.GLOWCouplingBlock,
                                 {'subnet_constructor': coupling_subnet,
                                  'clamp': 2.0}))

        return ff.ReversibleGraphNet(chain + [ff.OutputNode(chain[-1])],
                                     verbose=False)
Esempio n. 8
0
def construct_net_10d(coupling_block, init_identity=True):
    """Build an 8-block INN for 10-dimensional data.

    :param coupling_block: 'gin' or 'glow' — selects the coupling type.
    :param init_identity: forwarded to subnet_fc_10d; if True the subnets
        are initialized so each block starts as (close to) the identity.
    :return: Ff.ReversibleGraphNet over 10-d inputs.
    :raises ValueError: if coupling_block is not 'gin' or 'glow'.
    """
    if coupling_block == 'gin':
        block = Fm.GINCouplingBlock
    elif coupling_block == 'glow':
        block = Fm.GLOWCouplingBlock
    else:
        # Explicit error instead of `assert`, which disappears under -O.
        raise ValueError(f"unknown coupling block: {coupling_block!r}")

    nodes = [Ff.InputNode(10, name='input')]

    for k in range(8):
        nodes.append(Ff.Node(
            nodes[-1], block,
            {'subnet_constructor':
                 lambda c_in, c_out: subnet_fc_10d(c_in, c_out, init_identity),
             'clamp': 2.0},
            name=F'coupling_{k}'))
        # NOTE(review): the unseeded randint makes the permutation layout
        # irreproducible across runs — confirm this is intended.
        nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom,
                             {'seed': np.random.randint(2**31)},
                             name=F'permute_{k+1}'))

    nodes.append(Ff.OutputNode(nodes[-1], name='output'))
    return Ff.ReversibleGraphNet(nodes)
Esempio n. 9
0
def construct_net_emnist(coupling_block):
    """Build the EMNIST INN: two convolutional scales plus an FC head.

    :param coupling_block: 'gin' or 'glow' — selects the coupling type.
    :return: Ff.ReversibleGraphNet over (1, 28, 28) inputs.
    :raises ValueError: if coupling_block is not 'gin' or 'glow'.
    """
    if coupling_block == 'gin':
        block = Fm.GINCouplingBlock
    elif coupling_block == 'glow':
        block = Fm.GLOWCouplingBlock
    else:
        # Explicit error instead of `assert`, which disappears under -O.
        raise ValueError(f"unknown coupling block: {coupling_block!r}")

    nodes = [Ff.InputNode(1, 28, 28, name='input')]
    # Invertible space-to-depth: (1, 28, 28) -> (4, 14, 14).
    nodes.append(Ff.Node(nodes[-1], Fm.IRevNetDownsampling, {}, name='downsample1'))

    for k in range(4):
        nodes.append(Ff.Node(nodes[-1], block,
                             {'subnet_constructor':subnet_conv1, 'clamp':2.0},
                             name=F'coupling_conv1_{k}'))
        # NOTE(review): unseeded randint makes the permutations
        # irreproducible across runs — confirm this is intended.
        nodes.append(Ff.Node(nodes[-1],
                             Fm.PermuteRandom,
                             {'seed':np.random.randint(2**31)},
                             name=F'permute_conv1_{k}'))

    # Second scale: (4, 14, 14) -> (16, 7, 7).
    nodes.append(Ff.Node(nodes[-1], Fm.IRevNetDownsampling, {}, name='downsample2'))

    for k in range(4):
        nodes.append(Ff.Node(nodes[-1], block,
                             {'subnet_constructor':subnet_conv2, 'clamp':2.0},
                             name=F'coupling_conv2_{k}'))
        nodes.append(Ff.Node(nodes[-1],
                             Fm.PermuteRandom,
                             {'seed':np.random.randint(2**31)},
                             name=F'permute_conv2_{k}'))

    nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten'))

    # Fully-connected head on the flattened representation.
    for k in range(2):
        nodes.append(Ff.Node(nodes[-1], block,
                             {'subnet_constructor':subnet_fc, 'clamp':2.0},
                             name=F'coupling_fc_{k}'))
        nodes.append(Ff.Node(nodes[-1],
                             Fm.PermuteRandom,
                             {'seed':np.random.randint(2**31)},
                             name=F'permute_fc_{k}'))

    nodes.append(Ff.OutputNode(nodes[-1], name='output'))
    return Ff.ReversibleGraphNet(nodes)
Esempio n. 10
0
    def build_inn(self):
        """INN over flattened 28x28 inputs: 20 (permute, coupling) pairs."""

        def dense_subnet(n_in, n_out):
            return nn.Sequential(nn.Linear(n_in, 512), nn.ReLU(),
                                 nn.Linear(512, n_out))

        chain = [ff.InputNode(28 * 28)]
        chain.append(ff.Node(chain[-1], fm.Flatten, {}))

        for seed in range(20):
            chain.append(ff.Node(chain[-1], fm.PermuteRandom, {'seed': seed}))
            chain.append(ff.Node(chain[-1], fm.GLOWCouplingBlock,
                                 {'subnet_constructor': dense_subnet,
                                  'clamp': 1.0}))

        return ff.ReversibleGraphNet(chain + [ff.OutputNode(chain[-1])],
                                     verbose=False)
def mnist_inn_one(mask_size=(28, 28)):
    """
    Return an invertible autoencoder for single-channel images.

    Pipeline: Haar multiplex -> flatten -> one rev-multiplicative FC
    block -> reshape to (4, 14, 14) -> Haar restore.

    :param mask_size: (height, width) of the input. Default: Size of
        MNIST images.  (Changed from a mutable list default to a tuple
        to avoid the shared-default pitfall; indexing callers are
        unaffected.)
    :return: fr.ReversibleGraphNet autoencoder.
    """

    img_dims = [1, mask_size[0], mask_size[1]]

    inp = fr.InputNode(*img_dims, name='input')

    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')

    r2 = fr.Node([r1.out0],
                 re.reshape_layer,
                 {'target_dim': (img_dims[0] * img_dims[1] * img_dims[2], )},
                 name='r2')

    fc = fr.Node([r2.out0],
                 la.rev_multiplicative_layer, {
                     'F_class': fu.F_fully_connected,
                     'F_args': {},
                     'clamp': 1
                 },
                 name='fc')

    # NOTE(review): (4, 14, 14) hard-codes the 28x28 default; other
    # mask_size values would need a matching target_dim — confirm.
    r3 = fr.Node([fc.out0],
                 re.reshape_layer, {'target_dim': (4, 14, 14)},
                 name='r3')

    r4 = fr.Node([r3.out0], re.haar_restore_layer, {}, name='r4')

    outp = fr.OutputNode([r4.out0], name='output')

    # Order matters: ReversibleGraphNet(nodes, 0, 1) addresses the input
    # and output node by their index in this list.
    nodes = [inp, outp, r1, r2, r3, r4, fc]

    coder = fr.ReversibleGraphNet(nodes, 0, 1)

    return coder
Esempio n. 12
0
    def build_inn(self):
        """Conditional INN for 1x28x28 inputs, conditioned on a 10-d vector."""

        def dense_subnet(n_in, n_out):
            return nn.Sequential(nn.Linear(n_in, 512), nn.LeakyReLU(),
                                 nn.Linear(512, n_out))

        label_cond = Ff.ConditionNode(10)
        chain = [Ff.InputNode(1, 28, 28)]
        chain.append(Ff.Node(chain[-1], Fm.Flatten, {}))

        for seed in range(16):
            chain.append(Ff.Node(chain[-1], Fm.PermuteRandom, {'seed': seed}))
            chain.append(Ff.Node(chain[-1],
                                 Fm.GLOWCouplingBlock,
                                 {'subnet_constructor': dense_subnet,
                                  'clamp': 0.8},
                                 conditions=label_cond))

        return Ff.ReversibleGraphNet(chain +
                                     [label_cond, Ff.OutputNode(chain[-1])],
                                     verbose=False)
Esempio n. 13
0
def build_inn(fe_c, fe_h, fe_w, subnet_mode, flow_blocks, verbose):
    """Build a (possibly multi-scale) flow over feature maps.

    :param fe_c: channels of the input feature map.
    :param fe_h: height of the input feature map.
    :param fe_w: width of the input feature map.
    :param subnet_mode: forwarded to one_step to select the subnet type.
    :param flow_blocks: number of coupling steps per scale; a length > 1
        enables the multi-scale structure with channel splits.
    :param verbose: forwarded to ReversibleGraphNet.
    :return: Ff.ReversibleGraphNet.
    """
    nodes = [Ff.InputNode(fe_c, fe_h, fe_w, name='input')]

    # Output nodes of the intermediate scales.  This replaces the previous
    # trick of mutating the dict returned by locals(), whose effect on the
    # actual local namespace is undefined behavior in Python.
    intermediate_outputs = []

    n_scales, index = len(flow_blocks), 0  # multiscale structure if > 1
    for l in range(n_scales):
        for _ in range(flow_blocks[l]):
            nodes = one_step(nodes, subnet_mode, index, fe_h)
            index += 1
        if l == n_scales - 1:
            # Final scale: everything that is left becomes the main output.
            nodes.append(Ff.OutputNode(nodes[-1], name=F'output{l}'))
        else:
            # Split off half of the channels to reduce dimension redundancy.
            nodes.append(
                Ff.Node(nodes[-1], Fm.SplitChannel, {}, name=F'split{l}'))
            intermediate_outputs.append(
                Ff.OutputNode(nodes[-1].out1, name=F'output{l}'))

    return Ff.ReversibleGraphNet(nodes + intermediate_outputs,
                                 verbose=verbose)
def inn_model(img_dims=4):
    """
    Return INN model.

    :param img_dims: dimensionality of the (flat) model input.
    :return: INN model
    """
    coupling_args = {'subnet_constructor': fc_constr, 'clamp': 2}

    inp = fr.InputNode(img_dims, name='input')
    fc1 = fr.Node([inp.out0], la.GLOWCouplingBlock, dict(coupling_args),
                  name='fc1')
    fc2 = fr.Node([fc1.out0], la.GLOWCouplingBlock, dict(coupling_args),
                  name='fc2')
    fc3 = fr.Node([fc2.out0], la.GLOWCouplingBlock, dict(coupling_args),
                  name='fc3')
    outp = fr.OutputNode([fc3.out0], name='output')

    # Three stacked GLOW coupling blocks, no permutations in between.
    return fr.ReversibleGraphNet([inp, outp, fc1, fc2, fc3])
def get_mnist_conv(mask_size=[28, 28]):
    """
    Return an autoencoder.

    Haar multiplex -> two rev-multiplicative conv blocks -> one rev
    conv block -> flatten -> FC block -> reshape to (4, 14, 14) ->
    Haar restore.

    :param mask_size: size of the input. Default: Size of MNIST images
        NOTE(review): mutable list default — never mutated here, but a
        tuple would be safer.
    :return: fr.ReversibleGraphNet autoencoder
    """

    img_dims = [1, mask_size[0], mask_size[1]]

    inp = fr.InputNode(*img_dims, name='input')

    # Haar multiplex: trades spatial resolution for channels
    # (presumably (1, 28, 28) -> (4, 14, 14) — confirm).
    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')

    conv1 = fr.Node(
        [r1.out0],
        la.rev_multiplicative_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': False
            },
            'clamp': 1
        },
        name='conv1')
    conv2 = fr.Node(
        [conv1.out0],
        la.rev_multiplicative_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': False
            },
            'clamp': 1
        },
        name='conv2')
    # Third conv block uses rev_layer (additive), not the multiplicative one.
    conv3 = fr.Node(
        [conv2.out0],
        la.rev_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': False
            }
        },
        name='conv3')

    # 784 = 1 * 28 * 28 — hard-codes the default mask_size.
    r2 = fr.Node([conv3.out0],
                 re.reshape_layer, {'target_dim': (784, )},
                 name='r2')

    fc = fr.Node([r2.out0],
                 la.rev_multiplicative_layer, {
                     'F_class': fu.F_fully_connected,
                     'F_args': {},
                     'clamp': 1
                 },
                 name='fc')

    r3 = fr.Node([fc.out0],
                 re.reshape_layer, {'target_dim': (4, 14, 14)},
                 name='r3')

    r4 = fr.Node([r3.out0], re.haar_restore_layer, {}, name='r4')

    outp = fr.OutputNode([r4.out0], name='output')

    # Order matters: ReversibleGraphNet(nodes, 0, 1) addresses the input
    # and output node by their index in this list.
    nodes = [inp, outp, conv1, conv2, conv3, r1, r2, r3, r4, fc]

    coder = fr.ReversibleGraphNet(nodes, 0, 1)

    return coder
def celeba_inn_small(mask_size=[156, 128]):
    """
    Return CelebA INN autoencoder for comparison with classical autoencoder (same number of parameters).

    Haar multiplex -> three GLOW conv blocks -> flatten -> one wide FC
    block -> reshape -> Haar restore.

    :param mask_size: size of the input. Default: Size of CelebA images
        NOTE(review): mutable list default — never mutated here, but a
        tuple would be safer.
    :return: CelebA INN autoencoder
    """

    img_dims = [3, mask_size[0], mask_size[1]]

    inp = fr.InputNode(*img_dims, name='input')

    # Haar multiplex: trades spatial resolution for channels
    # (presumably (3, H, W) -> (12, H/2, W/2) — confirm).
    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')

    conv11 = fr.Node([r1.out0],
                     la.glow_coupling_layer, {
                         'F_class': fu.F_conv,
                         'F_args': {
                             'channels_hidden': 128
                         },
                         'clamp': 1
                     },
                     name='conv11')

    conv12 = fr.Node([conv11.out0],
                     la.glow_coupling_layer, {
                         'F_class': fu.F_conv,
                         'F_args': {
                             'channels_hidden': 128
                         },
                         'clamp': 1
                     },
                     name='conv12')

    conv13 = fr.Node([conv12.out0],
                     la.glow_coupling_layer, {
                         'F_class': fu.F_conv,
                         'F_args': {
                             'channels_hidden': 128
                         },
                         'clamp': 1
                     },
                     name='conv13')

    # Flatten to a single vector of all pixels.
    r2 = fr.Node([conv13.out0],
                 re.reshape_layer,
                 {'target_dim': (img_dims[0] * img_dims[1] * img_dims[2], )},
                 name='r2')

    fc = fr.Node(
        [r2.out0],
        la.rev_multiplicative_layer, {
            'F_class': fu.F_small_connected,
            'F_args': {
                'internal_size': 5000
            },
            'clamp': 1
        },
        name='fc')

    # NOTE(review): (12, 78, 64) hard-codes the 156x128 default; other
    # mask_size values would need a matching target_dim — confirm.
    r3 = fr.Node([fc.out0],
                 re.reshape_layer, {'target_dim': (12, 78, 64)},
                 name='r3')

    r4 = fr.Node([r3.out0], re.haar_restore_layer, {}, name='r4')

    outp = fr.OutputNode([r4.out0], name='output')

    # Order matters: ReversibleGraphNet(nodes, 0, 1) addresses the input
    # and output node by their index in this list.
    nodes = [inp, outp, conv11, conv12, conv13, fc, r1, r2, r3, r4]

    coder = fr.ReversibleGraphNet(nodes, 0, 1)

    return coder
Esempio n. 17
0
    def strided_constr(cin, cout):
        """Strided conv subnet: 3x3/stride-2 conv halves the spatial size,
        then a 1x1 conv maps to the requested channel count."""
        return nn.Sequential(
            nn.Conv2d(cin, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, cout, 1, stride=1),
        )

    def low_res_constr(cin, cout):
        """Pointwise (1x1) conv subnet: changes channels, keeps spatial size."""
        return nn.Sequential(nn.Conv2d(cin, cout, 1))
    # Smoke test: one plain coupling block followed by a downsampling
    # coupling block, wired into a tiny ReversibleGraphNet.
    inp = Ff.InputNode(3, 32, 32, name='in')
    node = Ff.Node(inp, HighPerfCouplingBlock, {'subnet_constructor':low_res_constr}, name='coupling')
    node2 = Ff.Node(node, DownsampleCouplingBlock, {'subnet_constructor_strided':strided_constr,
                                                  'subnet_constructor_low_res':low_res_constr}, name='down_coupling')
    out = Ff.OutputNode(node2, name='out')

    net = Ff.ReversibleGraphNet([inp, node, node2, out])

    x = torch.randn(4, 3, 32, 32)

    # Forward pass, then read the log-|det J| of that pass without
    # re-running the network.
    z = net(x)
    jac = net.log_jacobian(run_forward=False)

    # Inverse pass should reconstruct the input up to numerical error.
    x_inv = net(z, rev=True)

    # NOTE(review): diff is computed but never reported — consider
    # printing e.g. its max absolute value to verify invertibility.
    diff = x - x_inv
    print('shape in')
    print(x.shape)
    print('shape out')
    print(z.shape)
    print('shape inv')
    print(x_inv.shape)
Esempio n. 18
0
# Append n_glow (permute, GLOW coupling) pairs to the externally-built
# `nodes` list, close the graph and move the net to the training device.
n_glow = 8
for i in range(n_glow):
    nodes.append(
        Ff.Node(nodes[-1], Fm.PermuteRandom, {"seed": i}, name=f"permute_{i}"))
    nodes.append(
        Ff.Node(
            nodes[-1],
            Fm.GLOWCouplingBlock,
            {
                # CreateFC(128): subnet factory with 128 hidden units —
                # presumably defined earlier in this script; confirm.
                "subnet_constructor": CreateFC(128),
                "clamp": 2
            },
            name=f"glow_{i}",
        ))
nodes.append(Ff.OutputNode(nodes[-1], name="output"))
net = Ff.ReversibleGraphNet(nodes, verbose=False)
net = net.to(device=device)

print("Training")
#
# Training
#
losses = []
val_losses = []
epochs = args.iterations
n_batch = args.batch_size

# TensorBoard logging.
writer = SummaryWriter()

# Hold out 10% of the training data for validation.
n = training_data_npy.shape[0]
n_val = n // 10
def cifar_inn_glow(mask_size=[32, 32], batch_norm=False):
    """
    Return an autoencoder.

    Haar multiplex -> three GLOW conv blocks -> flatten -> one GLOW FC
    block -> reshape -> Haar restore.

    :param mask_size: size of the input. Default: Size of CIFAR images
        NOTE(review): mutable list default — never mutated here, but a
        tuple would be safer.
    :param batch_norm: use batch norm inside the F_conv subnets.
    :return: autoencoder
    """

    img_dims = [3, mask_size[0], mask_size[1]]

    inp = fr.InputNode(*img_dims, name='input')

    # Haar multiplex: trades spatial resolution for channels
    # (presumably (3, 32, 32) -> (12, 16, 16) — confirm).
    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')

    conv11 = fr.Node(
        [r1.out0],
        la.glow_coupling_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': batch_norm
            },
            'clamp': 1
        },
        name='conv11')
    conv12 = fr.Node(
        [conv11.out0],
        la.glow_coupling_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': batch_norm
            },
            'clamp': 1
        },
        name='conv12')
    conv13 = fr.Node(
        [conv12.out0],
        la.glow_coupling_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': batch_norm
            },
            'clamp': 1
        },
        name='conv13')

    # 3072 = 3 * 32 * 32 — hard-codes the default mask_size.
    r2 = fr.Node([conv13.out0],
                 re.reshape_layer, {'target_dim': (3072, )},
                 name='r2')

    fc = fr.Node([r2.out0],
                 la.glow_coupling_layer, {
                     'F_class': fu.F_fully_connected,
                     'clamp': 1
                 },
                 name='fc')

    r3 = fr.Node([fc.out0],
                 re.reshape_layer, {'target_dim': (12, 16, 16)},
                 name='r3')

    r4 = fr.Node([r3.out0], re.haar_restore_layer, {}, name='r4')

    outp = fr.OutputNode([r4.out0], name='output')

    # Order matters: ReversibleGraphNet(nodes, 0, 1) addresses the input
    # and output node by their index in this list.
    nodes = [inp, outp, conv11, conv12, conv13, r1, r2, r3, r4, fc]

    coder = fr.ReversibleGraphNet(nodes, 0, 1)

    return coder
Esempio n. 20
0
    def build_inn(self):
        """Build a conditional multi-scale conv INN for 3x64x64 inputs.

        Structure: 2 conv blocks at full resolution -> Haar -> 4 conv
        blocks -> split -> Haar -> 8 conv blocks -> split -> flatten ->
        2 conditional FC blocks (conditioned on a 40-d attribute vector)
        -> concat with the split branches -> output.
        """

        nodes = [Ff.InputNode(3, 64, 64, name='inp')]
        cond_node = Ff.ConditionNode(40, name='attr_cond')
        split_nodes = []

        # ndim_x = 3 * 64 * 64

        # Stage 0: full resolution, fully-convolutional couplings.
        for i in range(2):
            nodes.append(
                Ff.Node([nodes[-1].out0],
                        permute_layer, {'seed': i},
                        name=F'permute_0_{i}'))
            nodes.append(
                Ff.Node(
                    [nodes[-1].out0],
                    glow_coupling_layer, {
                        'clamp': 1.,
                        'F_class': F_fully_conv,
                        'F_args': {
                            'kernel_size': 3,
                            'channels_hidden': 256,
                            'leaky_slope': 0.2
                        }
                    },
                    name=F'conv_0_{i}'))

        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.HaarDownsampling, {'rebalance': 0.5},
                    name='haar_1'))

        # Stage 1: alternating 1x1 / 3x3 shallow conv couplings.
        for i in range(4):
            kernel_size = 1 if i % 2 == 0 else 3
            nodes.append(
                Ff.Node([nodes[-1].out0],
                        permute_layer, {'seed': i},
                        name=F'permute_1_{i}'))
            nodes.append(
                Ff.Node(
                    [nodes[-1].out0],
                    glow_coupling_layer, {
                        'clamp': 1.,
                        'F_class': F_fully_conv_shallow,
                        'F_args': {
                            'kernel_size': kernel_size,
                            'channels_hidden': 64,
                            'leaky_slope': 0.2
                        }
                    },
                    name=F'conv_1_{i}'))

        # Split: keep 8 channels on the main path, send 4 to the latent
        # output via a flatten side branch ('dim': 0 is the channel dim,
        # batch excluded).
        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.Split1D, {
                        'split_size_or_sections': (8, 4),
                        'dim': 0
                    },
                    name='split_1'))
        split_nodes.append(
            Ff.Node(nodes[-1].out1, Fm.Flatten, {}, name='flatten_1'))

        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.HaarDownsampling, {'rebalance': 0.5},
                    name='haar_2'))

        # Stage 2: deeper stack at the lowest resolution.
        for i in range(8):
            kernel_size = 1 if i % 2 == 0 else 3

            nodes.append(
                Ff.Node(nodes[-1],
                        permute_layer, {'seed': i},
                        name=F'permute_2_{i}'))
            nodes.append(
                Ff.Node(nodes[-1],
                        glow_coupling_layer, {
                            'clamp': 1.,
                            'F_class': F_fully_conv_shallow,
                            'F_args': {
                                'kernel_size': kernel_size,
                                'channels_hidden': 64,
                                'leaky_slope': 0.2
                            }
                        },
                        name=F'conv_2_{i}'))

        # Second split: half of the channels leave via a side branch.
        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.Split1D, {
                        'split_size_or_sections': (16, 16),
                        'dim': 0
                    },
                    name='split_2'))
        split_nodes.append(
            Ff.Node(nodes[-1].out1, Fm.Flatten, {}, name='flatten_2'))
        nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten_3'))

        # FC stage, conditioned on the 40-d attribute vector.
        for i in range(2):
            nodes.append(
                Ff.Node(nodes[-1],
                        permute_layer, {'seed': i},
                        name=F'permute_3_{i}'))
            nodes.append(
                Ff.Node(nodes[-1],
                        glow_coupling_layer, {
                            'clamp': 0.8,
                            'F_class': F_fully_shallow,
                            'F_args': {
                                'internal_size': 256,
                                'dropout': 0.2
                            }
                        },
                        conditions=[cond_node],
                        name=F'fc_{i}'))

        # Re-join the split branches with the main path so the output
        # covers the full dimensionality.
        nodes.append(
            Ff.Node([s.out0 for s in split_nodes] + [nodes[-1].out0],
                    Fm.Concat1d, {'dim': 0},
                    name='concat'))
        nodes.append(Ff.OutputNode(nodes[-1], name='out'))

        conv_inn = Ff.ReversibleGraphNet(nodes + split_nodes + [cond_node],
                                         verbose=True)

        return conv_inn
Esempio n. 21
0
def inn_model(img_dims=[1, 28, 28]):
    """
    Return INN model for MNIST.

    Reshape-downsample -> three GLOW conv blocks -> flatten -> one
    rev-multiplicative FC block -> reshape back to the input shape.

    :param img_dims: size of the model input images. Default: Size of MNIST images
        NOTE(review): mutable list default — never mutated here, but a
        tuple would be safer.
    :return: INN model
    """

    inp = fr.InputNode(*img_dims, name='input')

    # NOTE(review): a plain reshape to (4C, H/2, W/2) is not a spatially
    # coherent downsampling (unlike Haar or IRevNet downsampling) — the
    # channels do not correspond to 2x2 neighborhoods; confirm intended.
    r1 = fr.Node(
        [inp.out0],
        re.reshape_layer,
        {'target_dim': (img_dims[0] * 4, img_dims[1] // 2, img_dims[2] // 2)},
        name='r1')

    conv1 = fr.Node([r1.out0],
                    la.glow_coupling_layer, {
                        'F_class': fu.F_conv,
                        'F_args': {
                            'channels_hidden': 128
                        },
                        'clamp': 1
                    },
                    name='conv1')

    conv2 = fr.Node([conv1.out0],
                    la.glow_coupling_layer, {
                        'F_class': fu.F_conv,
                        'F_args': {
                            'channels_hidden': 128
                        },
                        'clamp': 1
                    },
                    name='conv2')

    conv3 = fr.Node([conv2.out0],
                    la.glow_coupling_layer, {
                        'F_class': fu.F_conv,
                        'F_args': {
                            'channels_hidden': 128
                        },
                        'clamp': 1
                    },
                    name='conv3')

    # Flatten to one vector for the fully-connected block.
    r2 = fr.Node([conv3.out0],
                 re.reshape_layer,
                 {'target_dim': (img_dims[0] * img_dims[1] * img_dims[2], )},
                 name='r2')

    fc = fr.Node([r2.out0],
                 la.rev_multiplicative_layer, {
                     'F_class': fu.F_fully_connected,
                     'clamp': 1
                 },
                 name='fc')

    r3 = fr.Node([fc.out0],
                 re.reshape_layer,
                 {'target_dim': (img_dims[0], img_dims[1], img_dims[2])},
                 name='r3')

    outp = fr.OutputNode([r3.out0], name='output')

    # Order matters: ReversibleGraphNet(nodes, 0, 1) addresses the input
    # and output node by their index in this list.
    nodes = [inp, outp, conv1, conv2, conv3, r1, r2, r3, fc]

    model = fr.ReversibleGraphNet(nodes, 0, 1)

    return model
def celeba_inn_glow(mask_size=[156, 128], batch_norm=False):
    """
    Return an autoencoder.

    Two nested Haar multiplex stages with three GLOW conv blocks each,
    then flatten -> one FC block -> reshape -> two Haar restore stages.

    :param mask_size: size of the input. Default: Size of CelebA images
        NOTE(review): mutable list default — never mutated here, but a
        tuple would be safer.
    :param batch_norm: use batch norm inside the F_conv subnets.
    :return: autoencoder
    """

    img_dims = [3, mask_size[0], mask_size[1]]

    inp = fr.InputNode(*img_dims, name='input')

    # First Haar multiplex stage (presumably 4x channels, half spatial
    # resolution — confirm).
    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')

    conv11 = fr.Node(
        [r1.out0],
        la.glow_coupling_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': batch_norm
            },
            'clamp': 1
        },
        name='conv11')
    conv12 = fr.Node(
        [conv11.out0],
        la.glow_coupling_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': batch_norm
            },
            'clamp': 1
        },
        name='conv12')
    conv13 = fr.Node(
        [conv12.out0],
        la.glow_coupling_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': batch_norm
            },
            'clamp': 1
        },
        name='conv13')

    # Second Haar multiplex stage.
    r2 = fr.Node([conv13.out0], re.haar_multiplex_layer, {}, name='r2')

    conv21 = fr.Node(
        [r2.out0],
        la.glow_coupling_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': batch_norm
            },
            'clamp': 1
        },
        name='conv21')
    conv22 = fr.Node(
        [conv21.out0],
        la.glow_coupling_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': batch_norm
            },
            'clamp': 1
        },
        name='conv22')
    conv23 = fr.Node(
        [conv22.out0],
        la.glow_coupling_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': batch_norm
            },
            'clamp': 1
        },
        name='conv23')

    # Flatten to one vector for the fully-connected block.
    r3 = fr.Node([conv23.out0],
                 re.reshape_layer,
                 {'target_dim': (img_dims[0] * img_dims[1] * img_dims[2], )},
                 name='r3')

    fc = fr.Node(
        [r3.out0],
        la.rev_multiplicative_layer, {
            'F_class': fu.F_small_connected,
            'F_args': {
                'internal_size': 500
            },
            'clamp': 1
        },
        name='fc')

    # NOTE(review): (48, 39, 32) hard-codes the 156x128 default
    # (48 = 3 * 16, 39 = 156 / 4, 32 = 128 / 4) — confirm for other sizes.
    r4 = fr.Node([fc.out0],
                 re.reshape_layer, {'target_dim': (48, 39, 32)},
                 name='r4')

    # Undo both Haar multiplex stages.
    r5 = fr.Node([r4.out0], re.haar_restore_layer, {}, name='r5')

    r6 = fr.Node([r5.out0], re.haar_restore_layer, {}, name='r6')

    outp = fr.OutputNode([r6.out0], name='output')

    # Order matters: ReversibleGraphNet(nodes, 0, 1) addresses the input
    # and output node by their index in this list.
    nodes = [
        inp, outp, conv11, conv12, conv13, conv21, conv22, conv23, r1, r2, r3,
        r4, r5, r6, fc
    ]

    coder = fr.ReversibleGraphNet(nodes, 0, 1)

    return coder
Esempio n. 23
0
    def build_inn(self):
        """Build a conditional multi-resolution INN for 2x64x64 inputs.

        The four condition nodes are the outputs of an external
        conditioning network at different resolution levels.  Two splits
        route part of the channels straight to the output; the remainder
        passes through a fully-connected stage before everything is
        concatenated again.
        """

        def sub_conv(ch_hidden, kernel):
            # Two-layer conv subnet with matching 'same' padding.
            pad = kernel // 2
            return lambda ch_in, ch_out: nn.Sequential(
                                            nn.Conv2d(ch_in, ch_hidden, kernel, padding=pad),
                                            nn.ReLU(),
                                            nn.Conv2d(ch_hidden, ch_out, kernel, padding=pad))

        def sub_fc(ch_hidden):
            # Two-layer fully-connected subnet.
            return lambda ch_in, ch_out: nn.Sequential(
                                            nn.Linear(ch_in, ch_hidden),
                                            nn.ReLU(),
                                            nn.Linear(ch_hidden, ch_out))

        nodes = [Ff.InputNode(2, 64, 64)]
        # outputs of the cond. net at different resolution levels
        conditions = [Ff.ConditionNode(64, 64, 64),
                      Ff.ConditionNode(128, 32, 32),
                      Ff.ConditionNode(128, 16, 16),
                      Ff.ConditionNode(512)]

        split_nodes = []

        # Full resolution: 2 conditional coupling blocks.
        subnet = sub_conv(32, 3)
        for k in range(2):
            nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                                 {'subnet_constructor':subnet, 'clamp':1.0},
                                 conditions=conditions[0]))

        nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance':0.5}))

        # Half resolution, alternating 1x1 / 3x3 subnets.
        for k in range(4):
            subnet = sub_conv(64, 3 if k%2 else 1)

            nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                                 {'subnet_constructor':subnet, 'clamp':1.0},
                                 conditions=conditions[1]))
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed':k}))

        #split off 6/8 ch
        nodes.append(Ff.Node(nodes[-1], Fm.Split1D,
                             {'split_size_or_sections':[2,6], 'dim':0}))
        split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))

        nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance':0.5}))

        # Quarter resolution.
        for k in range(4):
            subnet = sub_conv(128, 3 if k%2 else 1)

            nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                                 {'subnet_constructor':subnet, 'clamp':0.6},
                                 conditions=conditions[2]))
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed':k}))

        #split off 4/8 ch
        nodes.append(Ff.Node(nodes[-1], Fm.Split1D,
                             {'split_size_or_sections':[4,4], 'dim':0}))
        split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))
        nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten'))

        # fully_connected part
        subnet = sub_fc(512)
        for k in range(4):
            nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                                 {'subnet_constructor':subnet, 'clamp':0.6},
                                 conditions=conditions[3]))
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed':k}))

        # concat everything
        nodes.append(Ff.Node([s.out0 for s in split_nodes] + [nodes[-1].out0],
                             Fm.Concat1d, {'dim':0}))
        nodes.append(Ff.OutputNode(nodes[-1]))

        return Ff.ReversibleGraphNet(nodes + split_nodes + conditions, verbose=False)
Esempio n. 24
0
def get_flow(
    n_in: int,
    n_out: int,
    init_identity: bool = False,
    coupling_block: Literal["gin", "glow"] = "gin",
    num_nodes: int = 8,
    node_size_factor: int = 1,
):
    """
    Create a flow-based (invertible) network.

    Args:
        n_in: Dimensionality of the input data.
        n_out: Dimensionality of the output data (must equal ``n_in``,
            since the network is invertible).
        init_identity: Zero-initialize the last subnet layer so every
            coupling block starts out as the identity map.
        coupling_block: Coupling method to use to combine nodes
            ("gin" or "glow").
        num_nodes: Depth of the flow network (number of coupling blocks).
        node_size_factor: Multiplier for the hidden units per node.

    Returns:
        An ``Ff.ReversibleGraphNet`` mapping ``n_in`` to ``n_out`` features.
    """

    # do lazy imports here such that the package is only
    # required if one wants to use the flow mixing
    import FrEIA.framework as Ff
    import FrEIA.modules as Fm

    def _invertible_subnet_fc(c_in, c_out, init_identity):
        # Two-hidden-layer MLP; hidden width scales with node_size_factor.
        # BUGFIX: the original referenced an undefined name ``node_size``
        # here, raising a NameError when the first coupling block was built.
        hidden = c_in * node_size_factor
        subnet = nn.Sequential(
            nn.Linear(c_in, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, c_out),
        )
        if init_identity:
            # Zero the final layer so the coupling block is initially identity.
            subnet[-1].weight.data.fill_(0.0)
            subnet[-1].bias.data.fill_(0.0)
        return subnet

    # Invertibility requires matching input/output dimensionality.
    assert n_in == n_out

    if coupling_block == "gin":
        block = Fm.GINCouplingBlock
    else:
        assert coupling_block == "glow"
        block = Fm.GLOWCouplingBlock

    nodes = [Ff.InputNode(n_in, name="input")]

    for k in range(num_nodes):
        nodes.append(
            Ff.Node(
                nodes[-1],
                block,
                {
                    "subnet_constructor":
                    lambda c_in, c_out: _invertible_subnet_fc(
                        c_in, c_out, init_identity),
                    "clamp":
                    2.0,
                },
                name=f"coupling_{k}",
            ))

    nodes.append(Ff.OutputNode(nodes[-1], name="output"))
    return Ff.ReversibleGraphNet(nodes, verbose=False)
Esempio n. 25
0
def artset_inn_model(img_dims=[3, 224, 224]):
    """
    Return INN model for Painter by Numbers artset.

    :param img_dims: size of the model input images as
        (channels, height, width). Default: [3, 224, 224].
        (The list default is never mutated, so it is safe here.)
    :return: INN model (``fr.ReversibleGraphNet``)
    """

    inp = fr.InputNode(*img_dims, name='input')

    # Squeeze spatial resolution into channels: (C, H, W) -> (4C, H/2, W/2).
    r1 = fr.Node(
        [inp.out0],
        re.reshape_layer,
        {'target_dim': (img_dims[0] * 4, img_dims[1] // 2, img_dims[2] // 2)},
        name='r1')

    def _glow_conv(prev, channels_hidden, name):
        # One glow coupling block with a conv subnet of the given hidden width.
        return fr.Node([prev.out0],
                       la.glow_coupling_layer, {
                           'F_class': fu.F_conv,
                           'F_args': {
                               'channels_hidden': channels_hidden
                           },
                           'clamp': 1
                       },
                       name=name)

    # Four stages of three coupling blocks each, with subnet hidden widths
    # shrinking 256 -> 128 -> 64 -> 32 (node names conv11 ... conv43).
    conv_nodes = []
    prev = r1
    for stage, hidden in enumerate((256, 128, 64, 32), start=1):
        for block in (1, 2, 3):
            prev = _glow_conv(prev, hidden, f'conv{stage}{block}')
            conv_nodes.append(prev)

    # Flatten for the fully connected coupling layer.
    r2 = fr.Node([prev.out0],
                 re.reshape_layer,
                 {'target_dim': (img_dims[0] * img_dims[1] * img_dims[2], )},
                 name='r2')

    fc = fr.Node(
        [r2.out0],
        la.rev_multiplicative_layer, {
            'F_class': fu.F_small_connected,
            'F_args': {
                'internal_size': 128
            },
            'clamp': 1
        },
        name='fc')

    # Reshape back to the original image dimensions.
    r3 = fr.Node([fc.out0],
                 re.reshape_layer,
                 {'target_dim': (img_dims[0], img_dims[1], img_dims[2])},
                 name='r3')

    outp = fr.OutputNode([r3.out0], name='output')

    # Order matters for the first two entries: ReversibleGraphNet is built
    # with ind_in=0 and ind_out=1 below.
    nodes = [inp, outp] + conv_nodes + [fc, r1, r2, r3]

    coder = fr.ReversibleGraphNet(nodes, 0, 1)

    return coder
Esempio n. 26
0
def constuct_inn(classifier, verbose=False):
    """
    Build the INN for ``classifier`` from its ``args['model']`` configuration.

    (The public name keeps the historical typo "constuct" so existing
    callers keep working.)

    :param classifier: object providing ``args`` (nested config dict of
        strings), ``dims`` (input shape), ``input_channels`` and ``dataset``.
    :param verbose: forwarded to ``Ff.ReversibleGraphNet``.
    :return: the assembled ``Ff.ReversibleGraphNet``.
    """

    fc_width = int(classifier.args['model']['fc_width'])
    n_coupling_blocks_fc = int(
        classifier.args['model']['n_coupling_blocks_fc'])

    # SECURITY NOTE(review): eval() executes arbitrary code from the config
    # file. Acceptable only for fully trusted configs; ast.literal_eval
    # would be the safer choice for these literal lists/bools.
    use_dct = eval(classifier.args['model']['dct_pooling'])
    conv_widths = eval(classifier.args['model']['conv_widths'])
    n_coupling_blocks_conv = eval(
        classifier.args['model']['n_coupling_blocks_conv'])
    dropouts = eval(classifier.args['model']['dropout_conv'])
    dropouts_fc = float(classifier.args['model']['dropout_fc'])

    groups = int(classifier.args['model']['n_groups'])
    clamp = float(classifier.args['model']['clamp'])

    ndim_input = classifier.dims

    # Shared BatchNorm settings for all conv subnets.
    batchnorm_args = {
        'track_running_stats': True,
        'momentum': 0.999,
        'eps': 1e-4,
    }

    # Base kwargs for every AIO coupling block; subnet_constructor is
    # filled in per block below via dict(coupling_args, ...).
    coupling_args = {
        'subnet_constructor': None,
        'clamp': clamp,
        'act_norm': float(classifier.args['model']['act_norm']),
        'gin_block': False,
        'permute_soft': True,
    }

    def weights_init(m):
        # Kaiming init for conv/linear weights, identity-like batchnorm.
        # ``type(m) is ...`` (not isinstance) deliberately excludes
        # subclasses, matching the original exact-type checks.
        if type(m) is nn.Conv2d:
            torch.nn.init.kaiming_normal_(m.weight)
        if type(m) is nn.BatchNorm2d:
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        if type(m) is nn.Linear:
            torch.nn.init.kaiming_normal_(m.weight)
            m.weight.data *= 0.1

    def basic_residual_block(width, groups, dropout, relu_first, cin, cout):
        # Conv subnet used inside a coupling block. ``relu_first`` prepends
        # an activation (used for all but the first block of a stage).
        width = width * groups
        layers = [nn.ReLU()] if relu_first else []

        layers.extend([
            nn.Conv2d(cin, width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(width, **batchnorm_args),
            nn.ReLU(inplace=True),
            nn.Conv2d(width,
                      width,
                      kernel_size=3,
                      padding=1,
                      bias=False,
                      groups=groups),
            nn.BatchNorm2d(width, **batchnorm_args),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(width, cout, 1, padding=0)
        ])

        layers = nn.Sequential(*layers)
        layers.apply(weights_init)

        return layers

    def strided_residual_block(width, groups, cin, cout):
        # Conv subnet with stride 2 for the downsampling coupling block.
        width = width * groups
        layers = nn.Sequential(
            nn.Conv2d(cin, width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(width, **batchnorm_args), nn.ReLU(),
            nn.Conv2d(width,
                      width,
                      kernel_size=3,
                      stride=2,
                      padding=1,
                      bias=False,
                      groups=groups), nn.BatchNorm2d(width, **batchnorm_args),
            nn.ReLU(inplace=True), nn.Conv2d(width, cout, 1, padding=0))

        layers.apply(weights_init)

        return layers

    def fc_constr(c_in, c_out):
        # Fully connected subnet for the final coupling blocks.
        net = [
            nn.Linear(c_in, fc_width),
            nn.ReLU(),
            nn.Dropout(p=dropouts_fc),
            nn.Linear(fc_width, c_out)
        ]

        net = nn.Sequential(*net)
        net.apply(weights_init)
        return net

    nodes = [Ff.InputNode(*ndim_input, name='input')]
    channels = classifier.input_channels

    if classifier.dataset == 'MNIST':
        # MNIST arrives flat: reshape to (1, H, W), then Haar-downsample
        # once so the first conv stage (skipped below) is compensated.
        nodes.append(
            Ff.Node(nodes[-1].out0, Fm.Reshape,
                    {'target_dim': (1, *classifier.dims)}))
        nodes.append(
            Ff.Node(nodes[-1].out0, Fm.HaarDownsampling, {'rebalance': 1.}))
        channels *= 4

    for i, (width,
            n_blocks) in enumerate(zip(conv_widths, n_coupling_blocks_conv)):
        if classifier.dataset == 'MNIST' and i == 0:
            # First resolution level is replaced by the Haar step above.
            continue

        drop = dropouts[i]
        conv_constr = partial(basic_residual_block, width, groups, drop, True)
        conv_strided = partial(strided_residual_block, width * 2, groups)
        conv_lowres = partial(basic_residual_block, width * 2, groups, drop,
                              False)

        # The very first block of the network gets no leading ReLU.
        if i == 0:
            conv_first = partial(basic_residual_block, width, groups, drop,
                                 False)
        else:
            conv_first = conv_constr

        nodes.append(
            Ff.Node(nodes[-1],
                    AIO_Block,
                    dict(coupling_args, subnet_constructor=conv_first),
                    name=f'CONV_{i}_0'))
        for k in range(1, n_blocks):
            nodes.append(
                Ff.Node(nodes[-1],
                        AIO_Block,
                        dict(coupling_args, subnet_constructor=conv_constr),
                        name=f'CONV_{i}_{k}'))

        # Downsample between resolution levels (not after the last one).
        if i < len(conv_widths) - 1:
            nodes.append(
                Ff.Node(nodes[-1],
                        DownsampleCouplingBlock, {
                            'subnet_constructor_low_res': conv_lowres,
                            'subnet_constructor_strided': conv_strided,
                            'clamp': clamp
                        },
                        name=f'DOWN_{i}'))
            channels *= 4

    # Go to a flat vector, either via DCT pooling or a plain flatten.
    if use_dct:
        nodes.append(
            Ff.Node(nodes[-1].out0,
                    DCTPooling2d, {'rebalance': 0.5},
                    name='DCT'))
    else:
        nodes.append(Ff.Node(nodes[-1].out0, Fm.Flatten, {}, name='Flatten'))

    # Fully connected head: permute + glow coupling, repeated.
    for k in range(n_coupling_blocks_fc):
        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.PermuteRandom, {'seed': k},
                    name=f'PERM_FC_{k}'))
        nodes.append(
            Ff.Node(nodes[-1].out0,
                    Fm.GLOWCouplingBlock, {
                        'subnet_constructor': fc_constr,
                        'clamp': 2.0
                    },
                    name=f'FC_{k}'))

    nodes.append(Ff.OutputNode(nodes[-1], name='output'))
    return Ff.ReversibleGraphNet(nodes, verbose=verbose)