def construct_inn(self, input):

        """Assemble the DCT-pooling INN section.

        Pools the incoming node with a DCT, splits off the first
        ``n_loss_dims_1d`` entries, routes the remainder straight to an
        output node, and applies a fixed random permutation to the kept
        part.

        Returns a tuple (nodes, split_nodes).
        """
        graph_nodes = []
        splits = []

        # DCT pooling on the incoming node.
        dct = Ff.Node(input.out0,
                      DCTPooling2d, {'rebalance': 0.5},
                      name='DCT')
        graph_nodes.append(dct)

        # Separate the loss dimensions from everything else.
        section_sizes = (self.n_loss_dims_1d,
                         self.n_total_dims_1d - self.n_loss_dims_1d)
        splitter = Ff.Node(dct.out0,
                           Fm.Split1D,
                           {'split_size_or_sections': section_sizes,
                            'dim': 0},
                           name='exit_flow split')
        splits.append(splitter)

        # The second half of the split leaves the graph immediately.
        graph_nodes.append(Ff.OutputNode(splitter.out1, name="out_conv"))

        # The first half gets a fixed random permutation.
        graph_nodes.append(Ff.Node(splitter.out0,
                                   Fm.PermuteRandom, {'seed': 0},
                                   name=f'PERM_FC_{0} 1'))

        return graph_nodes, splits
def constuct_inn(img_dims=[3, 32, 32]):
    """Build a multi-scale convolutional INN over image-shaped inputs.

    NOTE(review): the function name is misspelled ('constuct'); it is kept
    as-is because external callers rely on it.

    Relies on module-level settings: n_coupling_blocks_conv_0/_1,
    n_coupling_blocks_fc, conv_constr_0/_1, fc_constr, clamp and
    random_orthog.

    :param img_dims: input shape as [channels, height, width].
        NOTE(review): mutable list default -- harmless here since it is
        only read, never mutated.
    :return: fr.ReversibleGraphNet over the assembled node list.
    """

    # NOTE(review): collected but never used or returned; kept for
    # symmetry with sibling builders.
    split_nodes = []

    nodes = [fr.InputNode(*img_dims, name='input')]

    # Reshape to the (identical) target shape, then Haar-downsample once.
    nodes.append(
        fr.Node(nodes[-1].out0,
                la.Reshape, {'target_dim': (img_dims[0], *img_dims[1:])},
                name='reshape'))
    nodes.append(
        fr.Node(nodes[-1].out0,
                la.HaarDownsampling, {'rebalance': 0.5},
                name='haar_down_1'))

    # First conv stage; every second block is followed by a fixed 1x1
    # convolution mixing 12 channels.
    for k in range(n_coupling_blocks_conv_0):
        nodes.append(
            fr.Node(nodes[-1],
                    la.GLOWCouplingBlock, {
                        'subnet_constructor': conv_constr_0,
                        'clamp': clamp
                    },
                    name=f'CB CONV_0_{k}'))
        if k % 2:
            nodes.append(
                fr.Node(nodes[-1], la.Fixed1x1Conv, {'M': random_orthog(12)}))

    # Second Haar downsampling before the second conv stage.
    nodes.append(
        fr.Node(nodes[-1].out0,
                la.HaarDownsampling, {'rebalance': 0.5},
                name='haar_down_2'))

    # Second conv stage; the 1x1 mixers now operate on 48 channels.
    for k in range(n_coupling_blocks_conv_1):
        nodes.append(
            fr.Node(nodes[-1],
                    la.GLOWCouplingBlock, {
                        'subnet_constructor': conv_constr_1,
                        'clamp': clamp
                    },
                    name=f'CB CONV_1_{k}'))
        if k % 2:
            nodes.append(
                fr.Node(nodes[-1], la.Fixed1x1Conv, {'M': random_orthog(48)}))

    # Flatten for the fully connected stage.
    nodes.append(fr.Node(nodes[-1].out0, la.Flatten, {}, name='flatten'))

    # FC stage; every second block is followed by a random permutation.
    for k in range(n_coupling_blocks_fc):
        nodes.append(
            fr.Node(nodes[-1],
                    la.GLOWCouplingBlock, {
                        'subnet_constructor': fc_constr,
                        'clamp': clamp
                    },
                    name=f'CB FC_{k}'))
        if k % 2:
            nodes.append(fr.Node(nodes[-1], la.PermuteRandom, {'seed': k}))

    nodes.append(fr.OutputNode(nodes[-1], name='output'))

    return fr.ReversibleGraphNet(nodes)
def get_nodes(ndim,
              nb_blocks=4,
              hidden_layer=128,
              small_block=False,
              permute=True):
    """Return a list of FrEIA nodes forming an RNVP chain.

    :param ndim: input dimensionality.
    :param nb_blocks: number of coupling blocks.
    :param hidden_layer: internal width of each subnet.
    :param small_block: use F_fully_connected_small instead of the
        regular fully connected subnet.
    :param permute: insert a random permutation after each block.
    :return: list of nodes (input, couplings/permutes, output).
    """

    # Subnet factory; captures only loop-invariant settings, so a single
    # definition serves every block.
    def _subnet(ch_in, ch_out):
        if small_block:
            return F_fully_connected_small(ch_in,
                                           ch_out,
                                           internal_size=hidden_layer)
        return freia_mod.F_fully_connected(ch_in,
                                           ch_out,
                                           internal_size=hidden_layer)

    chain = [freia_fw.InputNode(ndim, name='input')]

    for i in range(nb_blocks):
        chain.append(
            freia_fw.Node([chain[-1].out0],
                          freia_mod.RNVPCouplingBlock, {
                              'subnet_constructor': _subnet,
                          },
                          name='coupling_{}'.format(i)))

        if permute:
            chain.append(
                freia_fw.Node([chain[-1].out0],
                              freia_mod.PermuteRandom, {'seed': i},
                              name='permute_{}'.format(i)))

    chain.append(freia_fw.OutputNode([chain[-1].out0], name='output'))
    return chain
Beispiel #4
0
def one_step(node_list, subnet_mode, index, f_size):
    """Append one flow step (ActNorm -> permute -> affine coupling).

    The coupling subnet alternates by step parity: even steps use a 3x3
    conv subnet; odd steps use a 1x1 conv ('base' mode) or an attention
    subnet ('conv1_to_attn' mode).

    :param node_list: list of FrEIA nodes; extended in place.
    :param subnet_mode: 'base' or 'conv1_to_attn'.
    :param index: step index; also used as permutation seed.
    :param f_size: feature size handed to the attention subnet.
    :return: the same (extended) node list.
    """
    # Activation normalization.
    node_list.append(
        Ff.Node(node_list[-1], Fm.ActNorm, {}, name=F'actnorm_{index}'))

    # Fixed random channel permutation.
    node_list.append(
        Ff.Node(node_list[-1],
                Fm.PermuteRandom, {'seed': index},
                name=F'permute_{index}'))

    # Pick the coupling subnet by mode and step parity.
    even_step = index % 2 == 0
    if subnet_mode == 'base':
        if even_step:
            subnet, sn = subnet_conv_3x3, 'conv3'
        else:
            subnet, sn = subnet_conv_1x1, 'conv1'
    elif subnet_mode == 'conv1_to_attn':
        if even_step:
            subnet, sn = subnet_conv_3x3, 'conv3'
        else:
            subnet, sn = partial(subnet_attn, f_size=f_size), 'attn'
    else:
        raise NotImplementedError

    node_list.append(
        Ff.Node(node_list[-1],
                Fm.GLOWCouplingBlock,
                {'subnet_constructor': subnet, 'clamp': 1.2},
                name=F'affinecoupling_{index}_{sn}'))

    return node_list
Beispiel #5
0
def generate_rcINN_old():
    """Build the (old) conditional rc-INN plus optimizer and LR scheduler.

    Reads all hyperparameters from the module-level `config`, uses the
    module-level `subnet_func` and `device`.

    :return: tuple (model, optim, weight_scheduler).
    """
    cond = Ff.ConditionNode(config.n_cond_features, name='condition')
    nodes = [Ff.InputNode(config.n_x_features, name='input')]

    # Alternate conditioned GLOW coupling blocks with fixed random
    # permutations.
    for k in range(config.n_blocks):
        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.GLOWCouplingBlock, {
                        'subnet_constructor': subnet_func,
                        'clamp': config.clamp
                    },
                    conditions=cond,
                    name=F'coupling_{k}'))
        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.PermuteRandom, {'seed': k},
                    name=F'permute_{k}'))

    nodes.append(Ff.OutputNode(nodes[-1], name='output'))
    model = Ff.ReversibleGraphNet(nodes + [cond])
    model.to(device)

    # Re-initialize every trainable parameter with scaled Gaussian noise.
    params_trainable = list(
        filter(lambda p: p.requires_grad, model.parameters()))
    for p in params_trainable:
        p.data = config.init_scale * torch.randn_like(p).to(device)

    gamma = config.gamma
    optim = torch.optim.AdamW(params_trainable, lr=config.lr)
    # Multiply the LR by gamma after every scheduler step (step_size=1).
    weight_scheduler = torch.optim.lr_scheduler.StepLR(optim,
                                                       step_size=1,
                                                       gamma=gamma)

    return model, optim, weight_scheduler
Beispiel #6
0
    def build_inn(self):
        """Build a conditional INN of GLOW coupling blocks.

        One condition node is created per block; all hyperparameters
        (x_dim, y_dim_features, n_blocks, internal_width,
        exponent_clamping) come from the module-level config object `c`.

        :return: Ff.ReversibleGraphNet over nodes + condition nodes.
        """

        def fc_constr():
            # Three-layer fully connected subnet; normalization/dropout
            # variants are kept commented out for experimentation.
            return lambda ch_in, ch_out: nn.Sequential(nn.Linear(ch_in, c.internal_width), nn.LeakyReLU(),#nn.BatchNorm1d(c.internal_width),#nn.Dropout(p=c.fc_dropout),
                         nn.Linear(c.internal_width,  c.internal_width), nn.LeakyReLU(),#nn.BatchNorm1d(c.internal_width),#nn.Dropout(p=c.fc_dropout),
                         nn.Linear(c.internal_width,  ch_out))

        nodes = [Ff.InputNode(c.x_dim)]

        # outputs of the cond. net at different resolution levels
        conditions = []

        # One conditioned coupling block per condition node.
        for i in range(c.n_blocks):
            conditions.append(Ff.ConditionNode(c.y_dim_features))

            nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                                 {'subnet_constructor':fc_constr(), 'clamp':c.exponent_clamping},
                                 conditions=conditions[i]))
            #nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed':i}, name=f'PERM_FC_{i}'))
            #nodes.append(Ff.Node(nodes[-1], Fm.ActNorm, {}, name=f'ActN{i}'))

        nodes.append(Ff.OutputNode(nodes[-1]))

        return Ff.ReversibleGraphNet(nodes + conditions, verbose=False)
Beispiel #7
0
    def _inn(self):
        """Build a GLOW-style INN over flat features of length
        ``self.feat_length`` with ``self.depth`` permute+coupling steps.
        """
        # Fully connected coupling subnet with one hidden layer of 1024.
        def _fc(ch_in, ch_out):
            return nn.Sequential(nn.Linear(ch_in, 1024), nn.ReLU(),
                                 nn.Linear(1024, ch_out))

        chain = [ff.InputNode(self.feat_length)]
        for step in range(self.depth):
            chain.append(ff.Node(chain[-1], fm.PermuteRandom,
                                 {'seed': step}))
            chain.append(ff.Node(chain[-1], fm.GLOWCouplingBlock,
                                 {'subnet_constructor': _fc,
                                  'clamp': 2.0}))
        chain.append(ff.OutputNode(chain[-1]))

        return ff.ReversibleGraphNet(chain, verbose=False)
Beispiel #8
0
    def __init__(self, *args):
        """Set up a conditioned GraphINN test fixture.

        The graph is: conv RNVP coupling (conditioned on c1) -> flatten ->
        fully connected RNVP coupling (conditioned on c2 and c3) -> output.
        Random test tensors are created for the input and all conditions.
        """
        super().__init__(*args)

        self.batch_size = 32
        self.inv_tol = 1e-4  # tolerance for invertibility checks
        # Seeding with the batch size keeps the fixture deterministic.
        torch.manual_seed(self.batch_size)

        # Shapes (per sample, without the batch dimension).
        self.inp_size = (3, 10, 10)
        self.c1_size = (1, 10, 10)
        self.c2_size = (50, )
        self.c3_size = (20, )

        # Random test data for input and conditions.
        self.x = torch.randn(self.batch_size, *self.inp_size)
        self.c1 = torch.randn(self.batch_size, *self.c1_size)
        self.c2 = torch.randn(self.batch_size, *self.c2_size)
        self.c3 = torch.randn(self.batch_size, *self.c3_size)

        # this is only used for the cuda variant of the tests.
        # if true, all tests are skipped.
        self.skip_all = False

        inp = Ff.InputNode(*self.inp_size, name='input')
        c1 = Ff.ConditionNode(*self.c1_size, name='c1')
        c2 = Ff.ConditionNode(*self.c2_size, name='c2')
        c3 = Ff.ConditionNode(*self.c3_size, name='c3')

        # Convolutional coupling conditioned on the image-shaped c1.
        conv = Ff.Node(inp,
                       Fm.RNVPCouplingBlock, {
                           'subnet_constructor': F_conv,
                           'clamp': 1.0
                       },
                       conditions=c1,
                       name='conv::c1')
        flatten = Ff.Node(conv, Fm.Flatten, {}, name='flatten')

        # Fully connected coupling conditioned on both flat vectors.
        linear = Ff.Node(flatten,
                         Fm.RNVPCouplingBlock, {
                             'subnet_constructor': F_fully_connected,
                             'clamp': 1.0
                         },
                         conditions=[c2, c3],
                         name='linear::c2|c3')

        outp = Ff.OutputNode(linear, name='output')
        self.test_net = Ff.GraphINN(
            [inp, c1, conv, flatten, c2, c3, linear, outp])
    def __init__(self, *args):
        """Set up a split/merge GraphINN test fixture.

        The input is split along the channel dimension; one branch is
        permuted (via flatten/reshape), the other passes through a
        conditioned conv coupling plus a fully connected coupling. Both
        branches are concatenated again and Haar-downsampled.
        """
        super().__init__(*args)

        # Per-sample shapes (without the batch dimension).
        self.inp_size = (3, 10, 10)
        self.cond_size = (1, 10, 10)

        inp = Ff.InputNode(*self.inp_size, name='input')
        cond = Ff.ConditionNode(*self.cond_size, name='cond')
        # Split 3 channels into a 1-channel and a 2-channel branch.
        split = Ff.Node(inp, Fm.Split, {'section_sizes': [1,2], 'dim': 0}, name='split1')

        # Branch 1: permute the flattened 1-channel part, then restore.
        flatten1 = Ff.Node(split.out0, Fm.Flatten, {}, name='flatten1')
        perm = Ff.Node(flatten1, Fm.PermuteRandom, {'seed': 123}, name='perm')
        unflatten1 = Ff.Node(perm, Fm.Reshape, {'output_dims': (1, 10, 10)}, name='unflatten1')

        # Branch 2: conditioned conv coupling on the 2-channel part.
        conv = Ff.Node(split.out1,
                       Fm.RNVPCouplingBlock,
                       {'subnet_constructor': F_conv, 'clamp': 1.0},
                       conditions=cond,
                       name='conv')

        flatten2 = Ff.Node(conv, Fm.Flatten, {}, name='flatten2')

        linear = Ff.Node(flatten2,
                         Fm.RNVPCouplingBlock,
                         {'subnet_constructor': F_fully_connected, 'clamp': 1.0},
                         name='linear')

        # Merge both branches and downsample.
        unflatten2 = Ff.Node(linear, Fm.Reshape, {'output_dims': (2, 10, 10)}, name='unflatten2')
        concat = Ff.Node([unflatten1.out0, unflatten2.out0], Fm.Concat, {'dim': 0}, name='concat')
        haar = Ff.Node(concat, Fm.HaarDownsampling, {}, name='haar')
        out = Ff.OutputNode(haar, name='output')

        self.test_net = Ff.GraphINN([inp, cond, split, flatten1, perm, unflatten1, conv,
                                     flatten2, linear, unflatten2, concat, haar, out])

        # this is only used for the cuda variant of the tests.
        # if true, all tests are skipped.
        self.skip_all = False

        self.batch_size = 32
        self.inv_tol = 1e-4  # tolerance for invertibility checks
        # Seeding with the batch size keeps the fixture deterministic.
        torch.manual_seed(self.batch_size)

        # Random test data for input and condition.
        self.x = torch.randn(self.batch_size, *self.inp_size)
        self.cond = torch.randn(self.batch_size, *self.cond_size)
Beispiel #10
0
def construct_net_10d(coupling_block, init_identity=True):
    """Build an 8-block INN over 10-dimensional inputs.

    :param coupling_block: 'gin' or 'glow'; selects the coupling type.
    :param init_identity: forwarded to subnet_fc_10d to initialize the
        subnets as identity mappings.
    :return: Ff.ReversibleGraphNet
    """
    assert coupling_block in ('gin', 'glow')
    if coupling_block == 'gin':
        block = Fm.GINCouplingBlock
    else:
        block = Fm.GLOWCouplingBlock

    # Named helper instead of an inline lambda; captures init_identity.
    def _subnet(c_in, c_out):
        return subnet_fc_10d(c_in, c_out, init_identity)

    chain = [Ff.InputNode(10, name='input')]

    for k in range(8):
        chain.append(Ff.Node(chain[-1], block,
                             {'subnet_constructor': _subnet,
                              'clamp': 2.0},
                             name=F'coupling_{k}'))
        # Random permutation with a freshly drawn seed each run.
        chain.append(Ff.Node(chain[-1], Fm.PermuteRandom,
                             {'seed': np.random.randint(2**31)},
                             name=F'permute_{k+1}'))

    chain.append(Ff.OutputNode(chain[-1], name='output'))
    return Ff.ReversibleGraphNet(chain)
Beispiel #11
0
    def build_inn(self):
        """Build a 20-step GLOW INN over flattened 28x28 inputs."""
        # Fully connected coupling subnet with one hidden layer of 512.
        def _fc(ch_in, ch_out):
            return nn.Sequential(nn.Linear(ch_in, 512), nn.ReLU(),
                                 nn.Linear(512, ch_out))

        chain = [ff.InputNode(28 * 28)]
        chain.append(ff.Node(chain[-1], fm.Flatten, {}))

        for seed in range(20):
            chain.append(ff.Node(chain[-1], fm.PermuteRandom,
                                 {'seed': seed}))
            chain.append(ff.Node(chain[-1], fm.GLOWCouplingBlock,
                                 {'subnet_constructor': _fc,
                                  'clamp': 1.0}))
        chain.append(ff.OutputNode(chain[-1]))

        return ff.ReversibleGraphNet(chain, verbose=False)
def mnist_inn_one(mask_size=[28, 28]):
    """
    Return an autoencoder.

    Built as a reversible graph: Haar multiplex -> flatten -> one
    multiplicative coupling -> reshape -> Haar restore.

    :param mask_size: size of the input. Default: Size of MNIST images
    :return:
    """

    img_dims = [1, mask_size[0], mask_size[1]]

    inp = fr.InputNode(*img_dims, name='input')

    # Haar multiplexing of the single input channel.
    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')

    # Flatten to a single vector for the fully connected coupling.
    r2 = fr.Node([r1.out0],
                 re.reshape_layer,
                 {'target_dim': (img_dims[0] * img_dims[1] * img_dims[2], )},
                 name='r2')

    fc = fr.Node([r2.out0],
                 la.rev_multiplicative_layer, {
                     'F_class': fu.F_fully_connected,
                     'F_args': {},
                     'clamp': 1
                 },
                 name='fc')

    # Back to the Haar-multiplexed shape, then restore the image.
    r3 = fr.Node([fc.out0],
                 re.reshape_layer, {'target_dim': (4, 14, 14)},
                 name='r3')

    r4 = fr.Node([r3.out0], re.haar_restore_layer, {}, name='r4')

    outp = fr.OutputNode([r4.out0], name='output')

    nodes = [inp, outp, r1, r2, r3, r4, fc]

    # Indices 0 and 1 mark the input/output nodes in the list above.
    coder = fr.ReversibleGraphNet(nodes, 0, 1)

    return coder
Beispiel #13
0
    def build_inn(self):
        """Build a 16-step conditional GLOW INN over flattened 28x28
        inputs, conditioned on a 10-dimensional vector.
        """
        # Fully connected coupling subnet with one hidden layer of 512.
        def _fc(ch_in, ch_out):
            return nn.Sequential(nn.Linear(ch_in, 512), nn.LeakyReLU(),
                                 nn.Linear(512, ch_out))

        cond = Ff.ConditionNode(10)
        chain = [Ff.InputNode(1, 28, 28)]
        chain.append(Ff.Node(chain[-1], Fm.Flatten, {}))

        for seed in range(16):
            chain.append(Ff.Node(chain[-1], Fm.PermuteRandom,
                                 {'seed': seed}))
            chain.append(Ff.Node(chain[-1],
                                 Fm.GLOWCouplingBlock,
                                 {'subnet_constructor': _fc,
                                  'clamp': 0.8},
                                 conditions=cond))

        out = Ff.OutputNode(chain[-1])
        return Ff.ReversibleGraphNet(chain + [cond, out], verbose=False)
Beispiel #14
0
def construct_net_emnist(coupling_block):
    """Build the EMNIST INN: two convolutional stages (each preceded by
    an i-RevNet downsampling) followed by a fully connected stage.

    Uses the module-level subnet constructors subnet_conv1, subnet_conv2
    and subnet_fc. Permutation seeds are drawn fresh on each call.

    :param coupling_block: 'gin' or 'glow'; selects the coupling type.
    :return: Ff.ReversibleGraphNet
    """
    if coupling_block == 'gin':
        block = Fm.GINCouplingBlock
    else:
        assert coupling_block == 'glow'
        block = Fm.GLOWCouplingBlock

    nodes = [Ff.InputNode(1, 28, 28, name='input')]
    nodes.append(Ff.Node(nodes[-1], Fm.IRevNetDownsampling, {}, name='downsample1'))

    # First conv stage.
    for k in range(4):
        nodes.append(Ff.Node(nodes[-1], block,
                             {'subnet_constructor':subnet_conv1, 'clamp':2.0},
                             name=F'coupling_conv1_{k}'))
        nodes.append(Ff.Node(nodes[-1],
                             Fm.PermuteRandom,
                             {'seed':np.random.randint(2**31)},
                             name=F'permute_conv1_{k}'))

    nodes.append(Ff.Node(nodes[-1], Fm.IRevNetDownsampling, {}, name='downsample2'))

    # Second conv stage at reduced resolution.
    for k in range(4):
        nodes.append(Ff.Node(nodes[-1], block,
                             {'subnet_constructor':subnet_conv2, 'clamp':2.0},
                             name=F'coupling_conv2_{k}'))
        nodes.append(Ff.Node(nodes[-1],
                             Fm.PermuteRandom,
                             {'seed':np.random.randint(2**31)},
                             name=F'permute_conv2_{k}'))

    nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten'))

    # Fully connected stage on the flattened features.
    for k in range(2):
        nodes.append(Ff.Node(nodes[-1], block,
                             {'subnet_constructor':subnet_fc, 'clamp':2.0},
                             name=F'coupling_fc_{k}'))
        nodes.append(Ff.Node(nodes[-1],
                             Fm.PermuteRandom,
                             {'seed':np.random.randint(2**31)},
                             name=F'permute_fc_{k}'))

    nodes.append(Ff.OutputNode(nodes[-1], name='output'))
    return Ff.ReversibleGraphNet(nodes)
def inn_model(img_dims=4):
    """
    Return INN model.

    :param img_dims: dimensionality of the flat model input. Default: 4.
        (The old docstring claimed MNIST image size, which did not match
        the default value.)
    :return: INN model (fr.ReversibleGraphNet)
    """

    inp = fr.InputNode(img_dims, name='input')

    # Three identical GLOW coupling blocks, previously written out three
    # times; built in a loop to avoid duplication. Names stay fc1..fc3.
    blocks = []
    prev = inp
    for i in range(1, 4):
        prev = fr.Node([prev.out0],
                       la.GLOWCouplingBlock, {
                           'subnet_constructor': fc_constr,
                           'clamp': 2
                       },
                       name='fc{}'.format(i))
        blocks.append(prev)

    outp = fr.OutputNode([blocks[-1].out0], name='output')

    # Preserve the original node ordering: input, output, then blocks.
    nodes = [inp, outp] + blocks

    model = fr.ReversibleGraphNet(nodes)

    return model
Beispiel #16
0
def build_inn(fe_c, fe_h, fe_w, subnet_mode, flow_blocks, verbose):
    """Build a (possibly multi-scale) normalizing flow.

    :param fe_c, fe_h, fe_w: input feature channels / height / width.
    :param subnet_mode: forwarded to one_step ('base' / 'conv1_to_attn').
    :param flow_blocks: flow steps per scale; len(flow_blocks) > 1
        enables the multiscale structure with split-off outputs.
    :param verbose: forwarded to Ff.ReversibleGraphNet.
    :return: Ff.ReversibleGraphNet including all side outputs.
    """
    nodes = [Ff.InputNode(fe_c, fe_h, fe_w, name='input')]

    # Collect intermediate output nodes in a list instead of writing into
    # the dict returned by locals() -- mutating a function's locals()
    # snapshot is implementation-defined and was only working by accident.
    side_outputs = []
    L, index = len(flow_blocks), 0  # multiscale structure design if L > 1
    for l in range(L):
        for k in range(flow_blocks[l]):
            nodes = one_step(nodes, subnet_mode, index, fe_h)
            index += 1
        if l == L - 1:
            # Last scale: everything remaining leaves through one output.
            nodes.append(Ff.OutputNode(nodes[-1], name=F'output{l}'))
        else:
            # split off half of the channels way to reduce dimension redundancy
            nodes.append(
                Ff.Node(nodes[-1], Fm.SplitChannel, {}, name=F'split{l}'))
            side_outputs.append(
                Ff.OutputNode(nodes[-1].out1, name=F'output{l}'))

    return Ff.ReversibleGraphNet(nodes + side_outputs, verbose=verbose)
Beispiel #17
0
def baseline_color(m_params):
    """Build the conditional baseline colorization INN from ``m_params``.

    The network is a conv pyramid (64 -> 128 -> 256 hidden channels,
    separated by Haar downsampling) followed by a fully connected part;
    channel groups can optionally be split off along the way and are
    concatenated again before the output.

    :param m_params: dict with keys 'permute', 'act_norm',
        'blocks_per_group', 'kernel_size_per_group',
        'hidden_layers_per_group', 'clamping_per_group', 'fc_size',
        'dropout_fc', and optional 'bn' / 'split' flags.
    :return: a SketchINN wrapping the assembled graph.
    :raises RuntimeError: if 'permute' or 'act_norm' has an unknown value.
    """
    # Validate options up front (error messages fixed: "Erros", stray
    # quote in 'permute'').
    if m_params['permute'] not in [
            'soft', 'random', 'none', 'false', None, False
    ]:
        raise RuntimeError(
            "Errors in model params: No 'permute' selected or not understood.")
    if m_params['act_norm'] not in [
            'learnednorm', 'movingavg', 'none', 'false', None, False
    ]:
        raise RuntimeError(
            "Errors in model params: No 'act_norm' selected or not understood.")

    cond = CondNet(m_params)

    nodes = [Ff.InputNode(2, 64, 64)]
    # outputs of the cond. net at different resolution levels
    conditions = [
        Ff.ConditionNode(64, 64, 64),
        Ff.ConditionNode(128, 32, 32),
        Ff.ConditionNode(128, 16, 16),
        Ff.ConditionNode(512)
    ]

    split_nodes = []

    # --- group 0: full resolution, 64 hidden channels --------------------
    for k in range(m_params['blocks_per_group'][0]):
        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.GLOWCouplingBlock, {
                        'subnet_constructor':
                        sub_conv(64,
                                 m_params['kernel_size_per_group'][0],
                                 m_params['hidden_layers_per_group'][0],
                                 batchnorm=m_params.get("bn", True)),
                        'clamp':
                        m_params['clamping_per_group'][0]
                    },
                    conditions=conditions[0],
                    name=F'block_{k}'))
        if m_params['permute'] == 'random':
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': k}))
        elif m_params['permute'] == 'soft':
            nodes.append(
                Ff.Node([nodes[-1].out0], Fm.conv_1x1,
                        {'M': random_orthog(2)}))
        if m_params['act_norm'] == 'learnednorm':
            nodes.append(
                Ff.Node(nodes[-1], LearnedActNorm, {
                    'M': torch.randn(1),
                    "b": torch.randn(1)
                }))
        elif m_params['act_norm'] == 'movingavg':
            nodes.append(Ff.Node(nodes[-1], Fm.ActNorm, {}))

    nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance': 0.5}))

    # --- group 1: half resolution, 128 hidden channels --------------------
    # NOTE(review): the name offsets (k + 2, k + 6, k + 10 below) assume
    # fixed sizes of the preceding groups -- confirm against configs.
    for k in range(m_params['blocks_per_group'][1]):
        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.GLOWCouplingBlock, {
                        'subnet_constructor':
                        sub_conv(128,
                                 m_params['kernel_size_per_group'][1],
                                 m_params['hidden_layers_per_group'][1],
                                 batchnorm=m_params.get("bn", True)),
                        'clamp':
                        m_params['clamping_per_group'][1],
                    },
                    conditions=conditions[1],
                    name=F'block_{k + 2}'))
        if m_params['permute'] == 'random':
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': k}))
        elif m_params['permute'] == 'soft':
            nodes.append(
                Ff.Node([nodes[-1].out0], Fm.conv_1x1,
                        {'M': random_orthog(8)}))
        if m_params['act_norm'] == 'learnednorm':
            nodes.append(
                Ff.Node(nodes[-1], LearnedActNorm, {
                    'M': torch.randn(1),
                    "b": torch.randn(1)
                }))
        elif m_params['act_norm'] == 'movingavg':
            nodes.append(Ff.Node(nodes[-1], Fm.ActNorm, {}))

    # split off 8/12 ch
    if m_params.get("split", True):
        nodes.append(
            Ff.Node(nodes[-1], Fm.Split1D, {
                'split_size_or_sections': [6, 2],
                'dim': 0
            }))
        split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))

    nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance': 0.5}))

    # --- group 2: quarter resolution, 256 hidden channels -----------------
    for k in range(m_params['blocks_per_group'][2]):
        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.GLOWCouplingBlock, {
                        'subnet_constructor':
                        sub_conv(256,
                                 m_params['kernel_size_per_group'][2][k],
                                 m_params['hidden_layers_per_group'][2],
                                 batchnorm=m_params.get("bn", True)),
                        'clamp':
                        m_params['clamping_per_group'][2],
                    },
                    conditions=conditions[2],
                    name=F'block_{k + 6}'))
        if m_params['permute'] == 'random':
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': k}))
        elif m_params['permute'] == 'soft':
            nodes.append(
                Ff.Node([nodes[-1].out0], Fm.conv_1x1, {
                    'M':
                    random_orthog(24 if m_params.get("split", True) else 32)
                }))
        if m_params['act_norm'] == 'learnednorm':
            nodes.append(
                Ff.Node(nodes[-1], LearnedActNorm, {
                    'M': torch.randn(1),
                    "b": torch.randn(1)
                }))
        elif m_params['act_norm'] == 'movingavg':
            nodes.append(Ff.Node(nodes[-1], Fm.ActNorm, {}))

    # split off 8/16 ch
    if m_params.get("split", True):
        nodes.append(
            Ff.Node(nodes[-1], Fm.Split1D, {
                'split_size_or_sections': [12, 12],
                'dim': 0
            }))
        split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))
    nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten'))

    # fully_connected part
    for k in range(m_params['blocks_per_group'][3]):
        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.GLOWCouplingBlock, {
                        'clamp':
                        m_params['clamping_per_group'][3],
                        'subnet_constructor':
                        sub_fc(m_params['fc_size'],
                               m_params['hidden_layers_per_group'][3],
                               dropout=m_params['dropout_fc'],
                               batchnorm=m_params.get("bn", True))
                    },
                    conditions=conditions[3],
                    name=F'block_{k + 10}'))

    # concat everything
    if m_params.get("split", True):
        nodes.append(
            Ff.Node([s.out0 for s in split_nodes] + [nodes[-1].out0],
                    Fm.Concat1d, {'dim': 0}))
    nodes.append(Ff.OutputNode(nodes[-1]))
    inn = SketchINN(cond)
    inn.build_inn(nodes, split_nodes, conditions)
    return inn
Beispiel #18
0
def get_flow(
    n_in: int,
    n_out: int,
    init_identity: bool = False,
    coupling_block: Union[Literal["gin", "glow"]] = "gin",
    num_nodes: int = 8,
    node_size_factor: int = 1,
):
    """
    Creates an flow-based network.

    Args:
        n_in: Dimensionality of the input data
        n_out: Dimensionality of the output data
        init_identity: Initialize weights to identity network.
        coupling_block: Coupling method to use to combine nodes.
        num_nodes: Depth of the flow network.
        node_size_factor: Multiplier for the hidden units per node.
    """

    # do lazy imports here such that the package is only
    # required if one wants to use the flow mixing
    import FrEIA.framework as Ff
    import FrEIA.modules as Fm

    def _invertible_subnet_fc(c_in, c_out, init_identity):
        # Bugfix: the hidden width previously referenced the undefined
        # name `node_size` (NameError at first use); it must come from
        # the `node_size_factor` parameter.
        hidden = c_in * node_size_factor
        subnet = nn.Sequential(
            nn.Linear(c_in, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, c_out),
        )
        if init_identity:
            # Zero the final layer so the coupling starts as the identity.
            subnet[-1].weight.data.fill_(0.0)
            subnet[-1].bias.data.fill_(0.0)
        return subnet

    # Flows require matching input/output dimensionality.
    assert n_in == n_out

    if coupling_block == "gin":
        block = Fm.GINCouplingBlock
    else:
        assert coupling_block == "glow"
        block = Fm.GLOWCouplingBlock

    nodes = [Ff.InputNode(n_in, name="input")]

    for k in range(num_nodes):
        nodes.append(
            Ff.Node(
                nodes[-1],
                block,
                {
                    "subnet_constructor":
                    lambda c_in, c_out: _invertible_subnet_fc(
                        c_in, c_out, init_identity),
                    "clamp":
                    2.0,
                },
                name=f"coupling_{k}",
            ))

    nodes.append(Ff.OutputNode(nodes[-1], name="output"))
    return Ff.ReversibleGraphNet(nodes, verbose=False)
def get_mnist_conv(mask_size=[28, 28]):
    """
    Return an autoencoder.

    Graph: Haar multiplex -> three conv coupling layers -> flatten ->
    one fully connected coupling -> reshape -> Haar restore.

    :param mask_size: size of the input. Default: Size of MNIST images
    :return:
    """

    img_dims = [1, mask_size[0], mask_size[1]]

    inp = fr.InputNode(*img_dims, name='input')

    # Haar multiplexing of the single input channel.
    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')

    # Convolutional coupling layers on the multiplexed representation.
    conv1 = fr.Node(
        [r1.out0],
        la.rev_multiplicative_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': False
            },
            'clamp': 1
        },
        name='conv1')
    conv2 = fr.Node(
        [conv1.out0],
        la.rev_multiplicative_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': False
            },
            'clamp': 1
        },
        name='conv2')
    # Third layer uses the additive rev_layer (no clamp parameter).
    conv3 = fr.Node(
        [conv2.out0],
        la.rev_layer, {
            'F_class': fu.F_conv,
            'F_args': {
                'channels_hidden': 128,
                'batch_norm': False
            }
        },
        name='conv3')

    # Flatten to a 784-vector for the fully connected coupling.
    r2 = fr.Node([conv3.out0],
                 re.reshape_layer, {'target_dim': (784, )},
                 name='r2')

    fc = fr.Node([r2.out0],
                 la.rev_multiplicative_layer, {
                     'F_class': fu.F_fully_connected,
                     'F_args': {},
                     'clamp': 1
                 },
                 name='fc')

    # Back to the Haar-multiplexed shape, then restore the image.
    r3 = fr.Node([fc.out0],
                 re.reshape_layer, {'target_dim': (4, 14, 14)},
                 name='r3')

    r4 = fr.Node([r3.out0], re.haar_restore_layer, {}, name='r4')

    outp = fr.OutputNode([r4.out0], name='output')

    nodes = [inp, outp, conv1, conv2, conv3, r1, r2, r3, r4, fc]

    # Indices 0 and 1 mark the input/output nodes in the list above.
    coder = fr.ReversibleGraphNet(nodes, 0, 1)

    return coder
Beispiel #20
0
    import pdb

    import FrEIA.framework as Ff
    import FrEIA.modules as Fm

    def strided_constr(cin, cout):
        layers = [ nn.Conv2d(cin, 16, 3, stride=2, padding=1),
                   nn.ReLU(),
                   nn.Conv2d(16, cout, 1, stride=1)]
        return nn.Sequential(*layers)

    def low_res_constr(cin, cout):
        layers = [nn.Conv2d(cin, cout, 1)]
        return nn.Sequential(*layers)
    inp = Ff.InputNode(3, 32, 32, name='in')
    node = Ff.Node(inp, HighPerfCouplingBlock, {'subnet_constructor':low_res_constr}, name='coupling')
    node2 = Ff.Node(node, DownsampleCouplingBlock, {'subnet_constructor_strided':strided_constr,
                                                  'subnet_constructor_low_res':low_res_constr}, name='down_coupling')
    out = Ff.OutputNode(node2, name='out')

    net = Ff.ReversibleGraphNet([inp, node, node2, out])

    x = torch.randn(4, 3, 32, 32)

    z = net(x)
    jac = net.log_jacobian(run_forward=False)

    x_inv = net(z, rev=True)

    diff = x - x_inv
    print('shape in')
def celeba_inn_glow(mask_size=[156, 128], batch_norm=False):
    """
    Return an INN autoencoder for CelebA-sized inputs.

    :param mask_size: spatial size (H, W) of the input.
        Default: size of CelebA images (156 x 128), not CIFAR.
    :param batch_norm: whether the F_conv subnetworks use batch norm
    :return: autoencoder (fr.ReversibleGraphNet)
    """

    img_dims = [3, mask_size[0], mask_size[1]]

    def _glow_conv(parent, name):
        # One GLOW coupling block with a convolutional subnetwork
        # (128 hidden channels, clamp 1). Deduplicates the six identical
        # node literals the original repeated verbatim.
        return fr.Node(
            [parent.out0],
            la.glow_coupling_layer, {
                'F_class': fu.F_conv,
                'F_args': {
                    'channels_hidden': 128,
                    'batch_norm': batch_norm
                },
                'clamp': 1
            },
            name=name)

    inp = fr.InputNode(*img_dims, name='input')

    # First Haar multiplexing level followed by three coupling blocks.
    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')
    conv11 = _glow_conv(r1, 'conv11')
    conv12 = _glow_conv(conv11, 'conv12')
    conv13 = _glow_conv(conv12, 'conv13')

    # Second Haar multiplexing level followed by three more coupling blocks.
    r2 = fr.Node([conv13.out0], re.haar_multiplex_layer, {}, name='r2')
    conv21 = _glow_conv(r2, 'conv21')
    conv22 = _glow_conv(conv21, 'conv22')
    conv23 = _glow_conv(conv22, 'conv23')

    # Flatten, one fully connected coupling stage, then undo the reshapes.
    r3 = fr.Node([conv23.out0],
                 re.reshape_layer,
                 {'target_dim': (img_dims[0] * img_dims[1] * img_dims[2], )},
                 name='r3')

    fc = fr.Node(
        [r3.out0],
        la.rev_multiplicative_layer, {
            'F_class': fu.F_small_connected,
            'F_args': {
                'internal_size': 500
            },
            'clamp': 1
        },
        name='fc')

    # NOTE(review): the hard-coded (48, 39, 32) shape only matches the
    # default mask_size of [156, 128] -- confirm before using other sizes.
    r4 = fr.Node([fc.out0],
                 re.reshape_layer, {'target_dim': (48, 39, 32)},
                 name='r4')

    r5 = fr.Node([r4.out0], re.haar_restore_layer, {}, name='r5')

    r6 = fr.Node([r5.out0], re.haar_restore_layer, {}, name='r6')

    outp = fr.OutputNode([r6.out0], name='output')

    # inp and outp must stay at indices 0 and 1: ReversibleGraphNet(nodes, 0, 1)
    # addresses the input and output nodes by position.
    nodes = [
        inp, outp, conv11, conv12, conv13, conv21, conv22, conv23, r1, r2, r3,
        r4, r5, r6, fc
    ]

    coder = fr.ReversibleGraphNet(nodes, 0, 1)

    return coder
Beispiel #22
0
# Build the network
#
print("Building network")

# First create the input node
nodes = [Ff.InputNode(n_dim, name="input")]

# Next create the node for internal coordinates
mixed_nodes = mixed_transform.build_mixed_transformation_layers(
    nodes[-1], training_data, t.topology, "backbone")
nodes.extend(mixed_nodes)

# Stack of permute + GLOW-coupling pairs; each permutation gets a distinct
# seed so the channel shuffles differ between blocks.
n_glow = 8
for i in range(n_glow):
    nodes.append(
        Ff.Node(nodes[-1], Fm.PermuteRandom, {"seed": i}, name=f"permute_{i}"))
    nodes.append(
        Ff.Node(
            nodes[-1],
            Fm.GLOWCouplingBlock,
            {
                "subnet_constructor": CreateFC(128),
                "clamp": 2
            },
            name=f"glow_{i}",
        ))
nodes.append(Ff.OutputNode(nodes[-1], name="output"))
net = Ff.ReversibleGraphNet(nodes, verbose=False)
# Move the whole reversible net to the globally configured device.
net = net.to(device=device)

print("Training")
Beispiel #23
0
    def build_inn(self):
        """Build a conditional convolutional INN over 3x64x64 inputs.

        The graph has two Haar downsampling stages; at each stage part of
        the channels is split off, flattened, and re-concatenated with the
        fully connected tail's output at the end. The fully connected
        coupling blocks are conditioned on a 40-dimensional attribute
        vector (cond_node).
        """

        nodes = [Ff.InputNode(3, 64, 64, name='inp')]
        cond_node = Ff.ConditionNode(40, name='attr_cond')
        split_nodes = []

        # ndim_x = 3 * 64 * 64

        # Resolution level 0: full resolution, deep conv subnetworks.
        for i in range(2):
            nodes.append(
                Ff.Node([nodes[-1].out0],
                        permute_layer, {'seed': i},
                        name=F'permute_0_{i}'))
            nodes.append(
                Ff.Node(
                    [nodes[-1].out0],
                    glow_coupling_layer, {
                        'clamp': 1.,
                        'F_class': F_fully_conv,
                        'F_args': {
                            'kernel_size': 3,
                            'channels_hidden': 256,
                            'leaky_slope': 0.2
                        }
                    },
                    name=F'conv_0_{i}'))

        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.HaarDownsampling, {'rebalance': 0.5},
                    name='haar_1'))

        # Resolution level 1: alternate 1x1 and 3x3 kernels between blocks.
        for i in range(4):
            kernel_size = 1 if i % 2 == 0 else 3
            nodes.append(
                Ff.Node([nodes[-1].out0],
                        permute_layer, {'seed': i},
                        name=F'permute_1_{i}'))
            nodes.append(
                Ff.Node(
                    [nodes[-1].out0],
                    glow_coupling_layer, {
                        'clamp': 1.,
                        'F_class': F_fully_conv_shallow,
                        'F_args': {
                            'kernel_size': kernel_size,
                            'channels_hidden': 64,
                            'leaky_slope': 0.2
                        }
                    },
                    name=F'conv_1_{i}'))

        # Split off 4 of 12 channels; the flattened branch rejoins at the
        # final concat. The split node's out0 continues as the main path.
        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.Split1D, {
                        'split_size_or_sections': (8, 4),
                        'dim': 0
                    },
                    name='split_1'))
        split_nodes.append(
            Ff.Node(nodes[-1].out1, Fm.Flatten, {}, name='flatten_1'))

        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.HaarDownsampling, {'rebalance': 0.5},
                    name='haar_2'))

        # Resolution level 2: eight shallow conv coupling blocks.
        for i in range(8):
            kernel_size = 1 if i % 2 == 0 else 3

            nodes.append(
                Ff.Node(nodes[-1],
                        permute_layer, {'seed': i},
                        name=F'permute_2_{i}'))
            nodes.append(
                Ff.Node(nodes[-1],
                        glow_coupling_layer, {
                            'clamp': 1.,
                            'F_class': F_fully_conv_shallow,
                            'F_args': {
                                'kernel_size': kernel_size,
                                'channels_hidden': 64,
                                'leaky_slope': 0.2
                            }
                        },
                        name=F'conv_2_{i}'))

        # Split the remaining channels in half; one half is flattened and
        # carried to the concat, the other half feeds the FC tail.
        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.Split1D, {
                        'split_size_or_sections': (16, 16),
                        'dim': 0
                    },
                    name='split_2'))
        split_nodes.append(
            Ff.Node(nodes[-1].out1, Fm.Flatten, {}, name='flatten_2'))
        nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten_3'))

        # Fully connected tail, conditioned on the 40-dim attribute vector.
        for i in range(2):
            nodes.append(
                Ff.Node(nodes[-1],
                        permute_layer, {'seed': i},
                        name=F'permute_3_{i}'))
            nodes.append(
                Ff.Node(nodes[-1],
                        glow_coupling_layer, {
                            'clamp': 0.8,
                            'F_class': F_fully_shallow,
                            'F_args': {
                                'internal_size': 256,
                                'dropout': 0.2
                            }
                        },
                        conditions=[cond_node],
                        name=F'fc_{i}'))

        # Reassemble every split branch with the main path into one vector.
        nodes.append(
            Ff.Node([s.out0 for s in split_nodes] + [nodes[-1].out0],
                    Fm.Concat1d, {'dim': 0},
                    name='concat'))
        nodes.append(Ff.OutputNode(nodes[-1], name='out'))

        conv_inn = Ff.ReversibleGraphNet(nodes + split_nodes + [cond_node],
                                         verbose=True)

        return conv_inn
def cifar_inn_glow(mask_size=[32, 32], batch_norm=False):
    """
    Build an invertible autoencoder for CIFAR-sized images.

    :param mask_size: size of the input. Default: Size of CIFAR images
    :param batch_norm: use batch norm inside the F_conv subnetworks
    :return: autoencoder
    """

    img_dims = [3, mask_size[0], mask_size[1]]

    inp = fr.InputNode(*img_dims, name='input')

    # Haar multiplexing in front of the convolutional coupling chain.
    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')

    # Three identical convolutional GLOW coupling blocks, chained in order.
    conv_nodes = []
    parent = r1
    for suffix in ('11', '12', '13'):
        parent = fr.Node(
            [parent.out0],
            la.glow_coupling_layer, {
                'F_class': fu.F_conv,
                'F_args': {
                    'channels_hidden': 128,
                    'batch_norm': batch_norm
                },
                'clamp': 1
            },
            name='conv' + suffix)
        conv_nodes.append(parent)
    conv11, conv12, conv13 = conv_nodes

    # Flatten to a 3072-vector for the fully connected coupling stage.
    r2 = fr.Node([conv13.out0],
                 re.reshape_layer, {'target_dim': (3072, )},
                 name='r2')

    fc = fr.Node([r2.out0],
                 la.glow_coupling_layer, {
                     'F_class': fu.F_fully_connected,
                     'clamp': 1
                 },
                 name='fc')

    # Restore the spatial layout and invert the Haar multiplexing.
    r3 = fr.Node([fc.out0],
                 re.reshape_layer, {'target_dim': (12, 16, 16)},
                 name='r3')

    r4 = fr.Node([r3.out0], re.haar_restore_layer, {}, name='r4')

    outp = fr.OutputNode([r4.out0], name='output')

    # Positions 0 and 1 are referenced by ReversibleGraphNet(nodes, 0, 1).
    nodes = [inp, outp, conv11, conv12, conv13, r1, r2, r3, r4, fc]

    coder = fr.ReversibleGraphNet(nodes, 0, 1)

    return coder
Beispiel #25
0
def constuct_inn(classifier, verbose=False):
    """Construct the INN described by `classifier.args['model']`.

    (The misspelled name is kept intact since callers depend on it.)

    :param classifier: object providing .args (nested config of strings),
        .dims, .input_channels and .dataset
    :param verbose: forwarded to Ff.ReversibleGraphNet
    :return: Ff.ReversibleGraphNet
    """

    fc_width = int(classifier.args['model']['fc_width'])
    n_coupling_blocks_fc = int(
        classifier.args['model']['n_coupling_blocks_fc'])

    # NOTE(review): eval() on config strings executes arbitrary code --
    # acceptable only if the config file is fully trusted. Consider
    # ast.literal_eval where the values are plain literals.
    use_dct = eval(classifier.args['model']['dct_pooling'])
    conv_widths = eval(classifier.args['model']['conv_widths'])
    n_coupling_blocks_conv = eval(
        classifier.args['model']['n_coupling_blocks_conv'])
    dropouts = eval(classifier.args['model']['dropout_conv'])
    dropouts_fc = float(classifier.args['model']['dropout_fc'])

    groups = int(classifier.args['model']['n_groups'])
    clamp = float(classifier.args['model']['clamp'])

    ndim_input = classifier.dims

    batchnorm_args = {
        'track_running_stats': True,
        'momentum': 0.999,
        'eps': 1e-4,
    }

    # Shared kwargs for every AIO_Block; subnet_constructor is filled in
    # per block via dict(coupling_args, subnet_constructor=...).
    coupling_args = {
        'subnet_constructor': None,
        'clamp': clamp,
        'act_norm': float(classifier.args['model']['act_norm']),
        'gin_block': False,
        'permute_soft': True,
    }

    def weights_init(m):
        # Kaiming init for convs/linears, identity-style init for batchnorm;
        # linear weights are additionally scaled down by 0.1.
        if type(m) == nn.Conv2d:
            torch.nn.init.kaiming_normal_(m.weight)
        if type(m) == nn.BatchNorm2d:
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        if type(m) == nn.Linear:
            torch.nn.init.kaiming_normal_(m.weight)
            m.weight.data *= 0.1

    def basic_residual_block(width, groups, dropout, relu_first, cin, cout):
        # Stride-1 conv subnetwork: optional leading ReLU, two 3x3 convs
        # (the second grouped), dropout, then a 1x1 projection to cout.
        width = width * groups
        layers = []
        if relu_first:
            layers = [nn.ReLU()]
        else:
            layers = []

        layers.extend([
            nn.Conv2d(cin, width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(width, **batchnorm_args),
            nn.ReLU(inplace=True),
            nn.Conv2d(width,
                      width,
                      kernel_size=3,
                      padding=1,
                      bias=False,
                      groups=groups),
            nn.BatchNorm2d(width, **batchnorm_args),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(width, cout, 1, padding=0)
        ])

        layers = nn.Sequential(*layers)
        layers.apply(weights_init)

        return layers

    def strided_residual_block(width, groups, cin, cout):
        # Downsampling conv subnetwork: second 3x3 conv uses stride 2.
        width = width * groups
        layers = nn.Sequential(
            nn.Conv2d(cin, width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(width, **batchnorm_args), nn.ReLU(),
            nn.Conv2d(width,
                      width,
                      kernel_size=3,
                      stride=2,
                      padding=1,
                      bias=False,
                      groups=groups), nn.BatchNorm2d(width, **batchnorm_args),
            nn.ReLU(inplace=True), nn.Conv2d(width, cout, 1, padding=0))

        layers.apply(weights_init)

        return layers

    def fc_constr(c_in, c_out):
        # Fully connected subnetwork for the coupling blocks after flattening.
        net = [
            nn.Linear(c_in, fc_width),
            nn.ReLU(),
            nn.Dropout(p=dropouts_fc),
            nn.Linear(fc_width, c_out)
        ]

        net = nn.Sequential(*net)
        net.apply(weights_init)
        return net

    nodes = [Ff.InputNode(*ndim_input, name='input')]
    channels = classifier.input_channels

    # MNIST inputs get an extra reshape + Haar downsampling up front, and
    # the first conv stage of the loop below is skipped instead.
    if classifier.dataset == 'MNIST':
        nodes.append(
            Ff.Node(nodes[-1].out0, Fm.Reshape,
                    {'target_dim': (1, *classifier.dims)}))
        nodes.append(
            Ff.Node(nodes[-1].out0, Fm.HaarDownsampling, {'rebalance': 1.}))
        channels *= 4

    # One stage per entry in conv_widths; each stage is n_blocks coupling
    # blocks followed by a downsampling coupling block (except the last).
    for i, (width,
            n_blocks) in enumerate(zip(conv_widths, n_coupling_blocks_conv)):
        if classifier.dataset == 'MNIST' and i == 0:
            continue

        drop = dropouts[i]
        conv_constr = partial(basic_residual_block, width, groups, drop, True)
        conv_strided = partial(strided_residual_block, width * 2, groups)
        conv_lowres = partial(basic_residual_block, width * 2, groups, drop,
                              False)

        # The very first block omits the leading ReLU (relu_first=False).
        if i == 0:
            conv_first = partial(basic_residual_block, width, groups, drop,
                                 False)
        else:
            conv_first = conv_constr

        nodes.append(
            Ff.Node(nodes[-1],
                    AIO_Block,
                    dict(coupling_args, subnet_constructor=conv_first),
                    name=f'CONV_{i}_0'))
        for k in range(1, n_blocks):
            nodes.append(
                Ff.Node(nodes[-1],
                        AIO_Block,
                        dict(coupling_args, subnet_constructor=conv_constr),
                        name=f'CONV_{i}_{k}'))

        if i < len(conv_widths) - 1:
            nodes.append(
                Ff.Node(nodes[-1],
                        DownsampleCouplingBlock, {
                            'subnet_constructor_low_res': conv_lowres,
                            'subnet_constructor_strided': conv_strided,
                            'clamp': clamp
                        },
                        name=f'DOWN_{i}'))
            channels *= 4

    # Flatten the conv features, either via DCT pooling or a plain flatten.
    if use_dct:
        nodes.append(
            Ff.Node(nodes[-1].out0,
                    DCTPooling2d, {'rebalance': 0.5},
                    name='DCT'))
    else:
        nodes.append(Ff.Node(nodes[-1].out0, Fm.Flatten, {}, name='Flatten'))

    # Fully connected tail: permute + GLOW coupling pairs.
    for k in range(n_coupling_blocks_fc):
        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.PermuteRandom, {'seed': k},
                    name=f'PERM_FC_{k}'))
        nodes.append(
            Ff.Node(nodes[-1].out0,
                    Fm.GLOWCouplingBlock, {
                        'subnet_constructor': fc_constr,
                        'clamp': 2.0
                    },
                    name=f'FC_{k}'))

    nodes.append(Ff.OutputNode(nodes[-1], name='output'))
    return Ff.ReversibleGraphNet(nodes, verbose=verbose)
Beispiel #26
0
def inn_model(img_dims=[1, 28, 28]):
    """
    Build the INN model for MNIST.

    :param img_dims: size of the model input images. Default: Size of MNIST images
    :return: INN model
    """

    inp = fr.InputNode(*img_dims, name='input')

    # Reshape to 4x the channels at half the spatial resolution.
    r1 = fr.Node(
        [inp.out0],
        re.reshape_layer,
        {'target_dim': (img_dims[0] * 4, img_dims[1] // 2, img_dims[2] // 2)},
        name='r1')

    # Three identical convolutional GLOW coupling blocks in sequence.
    conv_nodes = []
    previous = r1
    for idx in range(1, 4):
        previous = fr.Node([previous.out0],
                           la.glow_coupling_layer, {
                               'F_class': fu.F_conv,
                               'F_args': {
                                   'channels_hidden': 128
                               },
                               'clamp': 1
                           },
                           name='conv%d' % idx)
        conv_nodes.append(previous)
    conv1, conv2, conv3 = conv_nodes

    # Flatten for the fully connected coupling stage.
    r2 = fr.Node([conv3.out0],
                 re.reshape_layer,
                 {'target_dim': (img_dims[0] * img_dims[1] * img_dims[2], )},
                 name='r2')

    fc = fr.Node([r2.out0],
                 la.rev_multiplicative_layer, {
                     'F_class': fu.F_fully_connected,
                     'clamp': 1
                 },
                 name='fc')

    # Restore the original image shape.
    r3 = fr.Node([fc.out0],
                 re.reshape_layer,
                 {'target_dim': (img_dims[0], img_dims[1], img_dims[2])},
                 name='r3')

    outp = fr.OutputNode([r3.out0], name='output')

    # Positions 0 and 1 are referenced by ReversibleGraphNet(nodes, 0, 1).
    nodes = [inp, outp, conv1, conv2, conv3, r1, r2, r3, fc]

    model = fr.ReversibleGraphNet(nodes, 0, 1)

    return model
def celeba_inn_small(mask_size=[156, 128]):
    """
    Return CelebA INN autoencoder for comparison with classical autoencoder (same number of parameters).

    :param mask_size: size of the input. Default: Size of CelebA images
    :return: CelebA INN autoencoder
    """

    img_dims = [3, mask_size[0], mask_size[1]]

    inp = fr.InputNode(*img_dims, name='input')

    # Haar multiplexing in front of three conv coupling blocks.
    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')

    conv11 = fr.Node([r1.out0],
                     la.glow_coupling_layer, {
                         'F_class': fu.F_conv,
                         'F_args': {
                             'channels_hidden': 128
                         },
                         'clamp': 1
                     },
                     name='conv11')

    conv12 = fr.Node([conv11.out0],
                     la.glow_coupling_layer, {
                         'F_class': fu.F_conv,
                         'F_args': {
                             'channels_hidden': 128
                         },
                         'clamp': 1
                     },
                     name='conv12')

    conv13 = fr.Node([conv12.out0],
                     la.glow_coupling_layer, {
                         'F_class': fu.F_conv,
                         'F_args': {
                             'channels_hidden': 128
                         },
                         'clamp': 1
                     },
                     name='conv13')

    # Flatten for the fully connected coupling stage.
    r2 = fr.Node([conv13.out0],
                 re.reshape_layer,
                 {'target_dim': (img_dims[0] * img_dims[1] * img_dims[2], )},
                 name='r2')

    fc = fr.Node(
        [r2.out0],
        la.rev_multiplicative_layer, {
            'F_class': fu.F_small_connected,
            'F_args': {
                'internal_size': 5000
            },
            'clamp': 1
        },
        name='fc')

    # NOTE(review): the hard-coded (12, 78, 64) shape only matches the
    # default mask_size of [156, 128] -- confirm before using other sizes.
    r3 = fr.Node([fc.out0],
                 re.reshape_layer, {'target_dim': (12, 78, 64)},
                 name='r3')

    r4 = fr.Node([r3.out0], re.haar_restore_layer, {}, name='r4')

    outp = fr.OutputNode([r4.out0], name='output')

    # Positions 0 and 1 are referenced by ReversibleGraphNet(nodes, 0, 1).
    nodes = [inp, outp, conv11, conv12, conv13, fc, r1, r2, r3, r4]

    coder = fr.ReversibleGraphNet(nodes, 0, 1)

    return coder
Beispiel #28
0
    def build_inn(self):
        """Build a conditional INN over 2x64x64 inputs.

        Conditioning is injected at four resolution levels via the four
        ConditionNodes; channels split off at two points are flattened and
        re-concatenated with the fully connected tail's output.
        """

        def sub_conv(ch_hidden, kernel):
            # Two-layer conv subnetwork with 'same' padding for the given kernel.
            pad = kernel // 2
            return lambda ch_in, ch_out: nn.Sequential(
                                            nn.Conv2d(ch_in, ch_hidden, kernel, padding=pad),
                                            nn.ReLU(),
                                            nn.Conv2d(ch_hidden, ch_out, kernel, padding=pad))

        def sub_fc(ch_hidden):
            # Two-layer fully connected subnetwork.
            return lambda ch_in, ch_out: nn.Sequential(
                                            nn.Linear(ch_in, ch_hidden),
                                            nn.ReLU(),
                                            nn.Linear(ch_hidden, ch_out))

        nodes = [Ff.InputNode(2, 64, 64)]
        # outputs of the cond. net at different resolution levels
        conditions = [Ff.ConditionNode(64, 64, 64),
                      Ff.ConditionNode(128, 32, 32),
                      Ff.ConditionNode(128, 16, 16),
                      Ff.ConditionNode(512)]

        split_nodes = []

        # NOTE(review): conditions is passed as a single node below, but as
        # a list (conditions=[cond_node]) elsewhere in this file -- confirm
        # the FrEIA version in use accepts both forms.
        subnet = sub_conv(32, 3)
        for k in range(2):
            nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                                 {'subnet_constructor':subnet, 'clamp':1.0},
                                 conditions=conditions[0]))

        nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance':0.5}))

        # Level 1: alternate 1x1 and 3x3 kernels between coupling blocks.
        for k in range(4):
            subnet = sub_conv(64, 3 if k%2 else 1)

            nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                                 {'subnet_constructor':subnet, 'clamp':1.0},
                                 conditions=conditions[1]))
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed':k}))

        #split off 6/8 ch
        nodes.append(Ff.Node(nodes[-1], Fm.Split1D,
                             {'split_size_or_sections':[2,6], 'dim':0}))
        split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))

        nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance':0.5}))

        # Level 2: deeper subnetworks, tighter clamp.
        for k in range(4):
            subnet = sub_conv(128, 3 if k%2 else 1)

            nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                                 {'subnet_constructor':subnet, 'clamp':0.6},
                                 conditions=conditions[2]))
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed':k}))

        #split off 4/8 ch
        nodes.append(Ff.Node(nodes[-1], Fm.Split1D,
                             {'split_size_or_sections':[4,4], 'dim':0}))
        split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))
        nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten'))

        # fully_connected part
        subnet = sub_fc(512)
        for k in range(4):
            nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                                 {'subnet_constructor':subnet, 'clamp':0.6},
                                 conditions=conditions[3]))
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed':k}))

        # concat everything
        nodes.append(Ff.Node([s.out0 for s in split_nodes] + [nodes[-1].out0],
                             Fm.Concat1d, {'dim':0}))
        nodes.append(Ff.OutputNode(nodes[-1]))

        return Ff.ReversibleGraphNet(nodes + split_nodes + conditions, verbose=False)
Beispiel #29
0
def artset_inn_model(img_dims=[3, 224, 224]):
    """
    Return INN model for Painter by Numbers artset.

    :param img_dims: size of the model input images (channels, height,
        width). Default: [3, 224, 224].
    :return: INN model
    """

    inp = fr.InputNode(*img_dims, name='input')

    # Reshape to 4x the channels at half the spatial resolution.
    r1 = fr.Node(
        [inp.out0],
        re.reshape_layer,
        {'target_dim': (img_dims[0] * 4, img_dims[1] // 2, img_dims[2] // 2)},
        name='r1')

    def _glow_conv(parent, name, channels_hidden):
        # One GLOW coupling block with a conv subnetwork of the given width.
        # Deduplicates the twelve identical node literals of the original.
        return fr.Node([parent.out0],
                       la.glow_coupling_layer, {
                           'F_class': fu.F_conv,
                           'F_args': {
                               'channels_hidden': channels_hidden
                           },
                           'clamp': 1
                       },
                       name=name)

    # Four groups of three coupling blocks with decreasing hidden widths:
    # conv11..conv13 at 256, conv21..conv23 at 128, conv31..conv33 at 64,
    # conv41..conv43 at 32 -- exactly as in the original spelled-out nodes.
    conv_nodes = []
    parent = r1
    for group, width in enumerate((256, 128, 64, 32), start=1):
        for block in range(1, 4):
            parent = _glow_conv(parent, 'conv%d%d' % (group, block), width)
            conv_nodes.append(parent)

    # Flatten for the fully connected coupling stage.
    r2 = fr.Node([parent.out0],
                 re.reshape_layer,
                 {'target_dim': (img_dims[0] * img_dims[1] * img_dims[2], )},
                 name='r2')

    fc = fr.Node(
        [r2.out0],
        la.rev_multiplicative_layer, {
            'F_class': fu.F_small_connected,
            'F_args': {
                'internal_size': 128
            },
            'clamp': 1
        },
        name='fc')

    # Restore the original image shape.
    r3 = fr.Node([fc.out0],
                 re.reshape_layer,
                 {'target_dim': (img_dims[0], img_dims[1], img_dims[2])},
                 name='r3')

    outp = fr.OutputNode([r3.out0], name='output')

    # inp and outp must stay at indices 0 and 1: ReversibleGraphNet(nodes, 0, 1)
    # addresses the input and output nodes by position. The remaining order
    # matches the original listing (convs, fc, reshapes).
    nodes = [inp, outp] + conv_nodes + [fc, r1, r2, r3]

    coder = fr.ReversibleGraphNet(nodes, 0, 1)

    return coder
Beispiel #30
0
    def _residual_flow(self,
                       input,
                       planes,
                       blocks,
                       stride,
                       dilation,
                       create_skip_connection=False):
        """Build one residual-flow stage of `blocks` coupling blocks.

        :param input: output of the previous node to attach this stage to
        :param planes: width passed to the bottleneck subnetwork builder
        :param blocks: number of coupling blocks in this stage
        :param stride: 2 to start with a downsampling coupling block
            (quadrupling self.channels), anything else for a plain block
        :param dilation: dilation passed to the bottleneck subnetwork builder
        :param create_skip_connection: if True, split off 1/4 of the channels
            before the last block and expose them as a low-level skip output
        :return: (nodes, split_nodes, low_level_node); low_level_node is None
            unless create_skip_connection is True
        """

        self.block += 1

        nodes = []
        split_nodes = []

        def _bottleneck():
            # Stride-1 bottleneck subnetwork shared by every coupling block
            # in this stage (was constructed inline four separate times).
            return self._bottleneck_residual_block(
                planes,
                1,
                groups=self.groups,
                base_width=self.base_width,
                dilation=dilation)

        if stride == 2:
            # Downsampling entry block: strided + low-res subnets, 4x channels.
            bottleneck_strided = self._bottleneck_residual_block(
                planes,
                2,
                groups=self.groups,
                base_width=self.base_width,
                dilation=dilation)
            middle_flow = Ff.Node(
                input,
                self.downsampling_layer,
                dict(self.down_coupling_args,
                     subnet_constructor_low_res=_bottleneck(),
                     subnet_constructor_strided=bottleneck_strided),
                name='Strided entry_flow')
            self.channels *= 4
        else:
            # Fix: the original wrote f'... %s_1 ' % self.block -- a no-op
            # f-prefix combined with %-formatting. Proper f-string, same text.
            middle_flow = Ff.Node(input,
                                  self.coupling_layer,
                                  dict(self.coupling_args,
                                       subnet_constructor=_bottleneck()),
                                  name=f'middle_flow {self.block}_1 ')
        nodes.append(middle_flow)

        low_level_node = None
        for i in range(1, blocks):
            block_name = f'middle_flow {self.block}_{1 + i} '

            if create_skip_connection and i == blocks - 1:
                # Split off a quarter of the channels as a low-level skip
                # connection before the final coupling block of the stage.
                n_downstream_channels = int(self.channels * (3 / 4))
                n_skip_channels = int(self.channels - n_downstream_channels)
                self.channels = n_downstream_channels

                split = Ff.Node(
                    middle_flow,
                    Fm.Split1D, {
                        'dim': 0,
                        'split_size_or_sections':
                        (n_downstream_channels, n_skip_channels)
                    },
                    name=f'middle_flow split {self.block}_{1 + i} ')
                split_nodes.append(split)

                middle_flow = Ff.Node(split.out0,
                                      self.coupling_layer,
                                      dict(self.coupling_args,
                                           subnet_constructor=_bottleneck()),
                                      name=block_name)
                low_level_node = split.out1
            else:
                middle_flow = Ff.Node(middle_flow,
                                      self.coupling_layer,
                                      dict(self.coupling_args,
                                           subnet_constructor=_bottleneck()),
                                      name=block_name)
            nodes.append(middle_flow)

        return nodes, split_nodes, low_level_node