コード例 #1
0
    def build_inn(self):
        """Build a conditional INN made of stacked GLOW coupling blocks.

        Hyper-parameters come from the config object ``c`` (defined outside
        this method): ``x_dim``, ``y_dim_features``, ``n_blocks``,
        ``internal_width`` and ``exponent_clamping``.  Every coupling block is
        fed by its own ``ConditionNode`` of size ``c.y_dim_features``.

        Returns:
            Ff.ReversibleGraphNet over the input, coupling and output nodes
            plus all condition nodes (built with ``verbose=False``).
        """

        def fc_constr():
            # Returns a subnet constructor: a 3-layer MLP
            # (ch_in -> width -> width -> ch_out) with LeakyReLU activations.
            # BatchNorm1d(c.internal_width) / Dropout(p=c.fc_dropout) after
            # the hidden activations are currently disabled.
            def make_subnet(ch_in, ch_out):
                width = c.internal_width
                return nn.Sequential(
                    nn.Linear(ch_in, width), nn.LeakyReLU(),
                    nn.Linear(width, width), nn.LeakyReLU(),
                    nn.Linear(width, ch_out))
            return make_subnet

        nodes = [Ff.InputNode(c.x_dim)]
        # One condition node per coupling block.
        conditions = []

        for block_idx in range(c.n_blocks):
            cond_node = Ff.ConditionNode(c.y_dim_features)
            conditions.append(cond_node)

            coupling_args = {'subnet_constructor': fc_constr(),
                             'clamp': c.exponent_clamping}
            nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                                 coupling_args, conditions=cond_node))
            # Currently disabled between blocks:
            #   Fm.PermuteRandom {'seed': block_idx}, name=f'PERM_FC_{block_idx}'
            #   Fm.ActNorm {}, name=f'ActN{block_idx}'

        nodes.append(Ff.OutputNode(nodes[-1]))

        return Ff.ReversibleGraphNet(nodes + conditions, verbose=False)
コード例 #2
0
def generate_rcINN_old():
    """Build the conditional INN and its optimiser / LR scheduler.

    The network alternates ``config.n_blocks`` GLOW coupling blocks, all fed
    by one shared ``ConditionNode``, with fixed random permutations.  After
    the model is moved to ``device``, every trainable parameter is replaced
    by Gaussian noise scaled by ``config.init_scale``.

    Returns:
        tuple: ``(model, optim, weight_scheduler)`` -- the model, an AdamW
        optimiser over its trainable parameters (``lr=config.lr``) and a
        StepLR scheduler with ``step_size=1`` and ``gamma=config.gamma``.
    """
    cond = Ff.ConditionNode(config.n_cond_features, name='condition')
    nodes = [Ff.InputNode(config.n_x_features, name='input')]

    for block in range(config.n_blocks):
        coupling_args = {
            'subnet_constructor': subnet_func,
            'clamp': config.clamp,
        }
        nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock, coupling_args,
                             conditions=cond, name=f'coupling_{block}'))
        nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': block},
                             name=f'permute_{block}'))

    nodes.append(Ff.OutputNode(nodes[-1], name='output'))
    model = Ff.ReversibleGraphNet(nodes + [cond])
    model.to(device)

    # Re-initialise every trainable parameter with scaled Gaussian noise.
    params_trainable = [p for p in model.parameters() if p.requires_grad]
    for param in params_trainable:
        param.data = config.init_scale * torch.randn_like(param).to(device)

    optim = torch.optim.AdamW(params_trainable, lr=config.lr)
    weight_scheduler = torch.optim.lr_scheduler.StepLR(
        optim, step_size=1, gamma=config.gamma)

    return model, optim, weight_scheduler
コード例 #3
0
ファイル: test_conditioning.py プロジェクト: tbung/FrEIA
    def __init__(self, *args):
        # Test fixture: a conv coupling block conditioned on an image-shaped
        # condition (c1), then a fully-connected coupling block conditioned on
        # two vector conditions (c2, c3).
        super().__init__(*args)

        self.batch_size = 32
        self.inv_tol = 1e-4
        # Seed the global torch RNG before anything random is created.
        torch.manual_seed(self.batch_size)

        self.inp_size = (3, 10, 10)
        self.c1_size = (1, 10, 10)
        self.c2_size = (50, )
        self.c3_size = (20, )

        # Samples are drawn in the fixed order x, c1, c2, c3.
        self.x, self.c1, self.c2, self.c3 = (
            torch.randn(self.batch_size, *shape)
            for shape in (self.inp_size, self.c1_size,
                          self.c2_size, self.c3_size))

        # Only used by the CUDA variant of the tests; if True, every test
        # is skipped.
        self.skip_all = False

        inp = Ff.InputNode(*self.inp_size, name='input')
        c1 = Ff.ConditionNode(*self.c1_size, name='c1')
        c2 = Ff.ConditionNode(*self.c2_size, name='c2')
        c3 = Ff.ConditionNode(*self.c3_size, name='c3')

        conv_args = {'subnet_constructor': F_conv, 'clamp': 1.0}
        conv = Ff.Node(inp, Fm.RNVPCouplingBlock, conv_args,
                       conditions=c1, name='conv::c1')
        flatten = Ff.Node(conv, Fm.Flatten, {}, name='flatten')

        fc_args = {'subnet_constructor': F_fully_connected, 'clamp': 1.0}
        linear = Ff.Node(flatten, Fm.RNVPCouplingBlock, fc_args,
                         conditions=[c2, c3], name='linear::c2|c3')

        outp = Ff.OutputNode(linear, name='output')
        graph_nodes = [inp, c1, conv, flatten, c2, c3, linear, outp]
        self.test_net = Ff.GraphINN(graph_nodes)
コード例 #4
0
    def __init__(self, *args):
        # Test fixture: the input is split along dim 0 into a 1-channel branch
        # (flatten -> fixed random permutation -> reshape) and a 2-channel
        # branch (conditional conv coupling -> flatten -> FC coupling ->
        # reshape).  The branches are concatenated and Haar-downsampled.
        super().__init__(*args)

        self.batch_size = 32
        self.inv_tol = 1e-4
        # Seed BEFORE building the graph.  In FrEIA's GraphINN API each
        # Ff.Node builds (and randomly initialises) its module when it is
        # constructed.  Seeding only afterwards left the coupling subnets'
        # weights unseeded, so the invertibility checks were not reproducible.
        torch.manual_seed(self.batch_size)

        self.inp_size = (3, 10, 10)
        self.cond_size = (1, 10, 10)

        inp = Ff.InputNode(*self.inp_size, name='input')
        cond = Ff.ConditionNode(*self.cond_size, name='cond')
        split = Ff.Node(inp, Fm.Split, {'section_sizes': [1, 2], 'dim': 0},
                        name='split1')

        # Branch 1: the single-channel part is only permuted.
        flatten1 = Ff.Node(split.out0, Fm.Flatten, {}, name='flatten1')
        perm = Ff.Node(flatten1, Fm.PermuteRandom, {'seed': 123}, name='perm')
        unflatten1 = Ff.Node(perm, Fm.Reshape, {'output_dims': (1, 10, 10)},
                             name='unflatten1')

        # Branch 2: conditional conv coupling, then an unconditioned FC coupling.
        conv = Ff.Node(split.out1,
                       Fm.RNVPCouplingBlock,
                       {'subnet_constructor': F_conv, 'clamp': 1.0},
                       conditions=cond,
                       name='conv')

        flatten2 = Ff.Node(conv, Fm.Flatten, {}, name='flatten2')

        linear = Ff.Node(flatten2,
                         Fm.RNVPCouplingBlock,
                         {'subnet_constructor': F_fully_connected, 'clamp': 1.0},
                         name='linear')

        unflatten2 = Ff.Node(linear, Fm.Reshape, {'output_dims': (2, 10, 10)},
                             name='unflatten2')
        concat = Ff.Node([unflatten1.out0, unflatten2.out0], Fm.Concat,
                         {'dim': 0}, name='concat')
        haar = Ff.Node(concat, Fm.HaarDownsampling, {}, name='haar')
        out = Ff.OutputNode(haar, name='output')

        self.test_net = Ff.GraphINN([inp, cond, split, flatten1, perm, unflatten1, conv,
                                     flatten2, linear, unflatten2, concat, haar, out])

        # this is only used for the cuda variant of the tests.
        # if true, all tests are skipped.
        self.skip_all = False

        self.x = torch.randn(self.batch_size, *self.inp_size)
        self.cond = torch.randn(self.batch_size, *self.cond_size)
コード例 #5
0
ファイル: model.py プロジェクト: yqGANs/FrEIA
    def build_inn(self):
        """Build a class-conditional INN for 1x28x28 inputs.

        The input is flattened, then passed through 16 pairs of (random
        permutation, GLOW coupling block).  Every coupling block is fed by the
        same 10-dimensional ``ConditionNode``.

        Returns:
            Ff.ReversibleGraphNet over all nodes (built with ``verbose=False``).
        """
        def subnet(ch_in, ch_out):
            # Two-layer MLP with 512 hidden units, used inside every coupling.
            hidden = 512
            return nn.Sequential(nn.Linear(ch_in, hidden), nn.LeakyReLU(),
                                 nn.Linear(hidden, ch_out))

        cond = Ff.ConditionNode(10)
        nodes = [Ff.InputNode(1, 28, 28)]
        nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}))

        for seed in range(16):
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': seed}))
            coupling_args = {'subnet_constructor': subnet, 'clamp': 0.8}
            nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                                 coupling_args, conditions=cond))

        output = Ff.OutputNode(nodes[-1])
        return Ff.ReversibleGraphNet(nodes + [cond, output], verbose=False)
コード例 #6
0
def _append_conv_coupling(nodes, m_params, group, ch_hidden, kernel_size,
                          condition, name):
    """Append one conditional GLOW coupling block with a ``sub_conv`` subnet.

    Args:
        nodes: List of graph nodes; the new node takes ``nodes[-1]`` as input
            and is appended in place.
        m_params: Model parameter dict; reads ``hidden_layers_per_group`` and
            ``clamping_per_group`` (both indexed by ``group``) and ``bn``
            (default True).
        group: Index of the resolution group (0, 1 or 2).
        ch_hidden: First argument forwarded to ``sub_conv`` (presumably the
            hidden channel width -- confirm against ``sub_conv``).
        kernel_size: Kernel size forwarded to ``sub_conv``.
        condition: Condition node feeding this block.
        name: Node name.
    """
    subnet = sub_conv(ch_hidden,
                      kernel_size,
                      m_params['hidden_layers_per_group'][group],
                      batchnorm=m_params.get("bn", True))
    nodes.append(
        Ff.Node(nodes[-1],
                Fm.GLOWCouplingBlock, {
                    'subnet_constructor': subnet,
                    'clamp': m_params['clamping_per_group'][group],
                },
                conditions=condition,
                name=name))


def _append_permute_and_norm(nodes, m_params, seed, n_channels):
    """Append the optional permutation and activation-norm nodes after a block.

    Args:
        nodes: List of graph nodes; new nodes are appended in place.
        m_params: Model parameter dict; reads ``permute`` and ``act_norm``.
        seed: Seed for ``Fm.PermuteRandom`` (used when ``permute == 'random'``).
        n_channels: Size passed to ``random_orthog`` for the 1x1 convolution
            (used when ``permute == 'soft'``).  It must equal the channel count
            at this point of the network.
    """
    if m_params['permute'] == 'random':
        nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': seed}))
    elif m_params['permute'] == 'soft':
        nodes.append(
            Ff.Node([nodes[-1].out0], Fm.conv_1x1,
                    {'M': random_orthog(n_channels)}))

    if m_params['act_norm'] == 'learnednorm':
        nodes.append(
            Ff.Node(nodes[-1], LearnedActNorm, {
                'M': torch.randn(1),
                "b": torch.randn(1)
            }))
    elif m_params['act_norm'] == 'movingavg':
        nodes.append(Ff.Node(nodes[-1], Fm.ActNorm, {}))


def baseline_color(m_params):
    """Build the multi-scale conditional ``SketchINN`` for 2x64x64 inputs.

    Layout:
      1. Conv coupling group at 64x64 with 2 channels.
      2. Haar downsampling, then a conv group at 32x32 with 8 channels.
      3. Optional split.
      4. Haar downsampling, then a conv group at 16x16.
      5. Optional split, then flatten.
      6. Fully-connected coupling group.

    Group ``g`` is conditioned on ``conditions[g]``.  The split-off parts are
    flattened and concatenated back onto the main chain at the end.

    Args:
        m_params: Model parameter dict.

            Keys read here: ``permute``, ``act_norm``, ``blocks_per_group``,
            ``kernel_size_per_group``, ``hidden_layers_per_group``,
            ``clamping_per_group``, ``fc_size`` and ``dropout_fc``.

            Optional keys: ``bn`` and ``split``, both defaulting to True.

            ``kernel_size_per_group[2]`` is indexed per block.  The dict is
            also passed to ``CondNet``.

    Returns:
        A ``SketchINN(CondNet(m_params))`` on which
        ``build_inn(nodes, split_nodes, conditions)`` has been called.

    Raises:
        RuntimeError: If ``permute`` or ``act_norm`` is not a recognised value.
    """
    if m_params['permute'] not in [
            'soft', 'random', 'none', 'false', None, False
    ]:
        raise RuntimeError(
            "Errors in model params: No 'permute' selected or not understood.")
    if m_params['act_norm'] not in [
            'learnednorm', 'movingavg', 'none', 'false', None, False
    ]:
        raise RuntimeError(
            "Errors in model params: No 'act_norm' selected or not understood.")

    cond = CondNet(m_params)

    use_split = m_params.get("split", True)
    blocks = m_params['blocks_per_group']
    # Node names are numbered consecutively across groups so they stay unique
    # for any blocks_per_group.  For blocks_per_group[:3] == [2, 4, 4] this
    # yields exactly the former hard-coded offsets 0 / 2 / 6 / 10.  With other
    # sizes those offsets could produce duplicate names, e.g. when
    # blocks_per_group[0] > 2.
    name_offsets = [0,
                    blocks[0],
                    blocks[0] + blocks[1],
                    blocks[0] + blocks[1] + blocks[2]]

    nodes = [Ff.InputNode(2, 64, 64)]
    # outputs of the cond. net at different resolution levels
    conditions = [
        Ff.ConditionNode(64, 64, 64),
        Ff.ConditionNode(128, 32, 32),
        Ff.ConditionNode(128, 16, 16),
        Ff.ConditionNode(512)
    ]
    split_nodes = []

    # Group 0: 2 channels at 64x64.
    for k in range(blocks[0]):
        _append_conv_coupling(nodes, m_params, 0, 64,
                              m_params['kernel_size_per_group'][0],
                              conditions[0], f'block_{name_offsets[0] + k}')
        _append_permute_and_norm(nodes, m_params, k, 2)

    nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance': 0.5}))

    # Group 1: 8 channels at 32x32.
    for k in range(blocks[1]):
        _append_conv_coupling(nodes, m_params, 1, 128,
                              m_params['kernel_size_per_group'][1],
                              conditions[1], f'block_{name_offsets[1] + k}')
        _append_permute_and_norm(nodes, m_params, k, 8)

    if use_split:
        # Keep 6 of the 8 channels; the other 2 are split off and flattened
        # for the final concatenation.
        nodes.append(
            Ff.Node(nodes[-1], Fm.Split1D, {
                'split_size_or_sections': [6, 2],
                'dim': 0
            }))
        split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))

    nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance': 0.5}))

    # Group 2: 24 channels at 16x16 (32 channels without the split above).
    # The kernel size is given per block in this group.
    group2_channels = 24 if use_split else 32
    for k in range(blocks[2]):
        _append_conv_coupling(nodes, m_params, 2, 256,
                              m_params['kernel_size_per_group'][2][k],
                              conditions[2], f'block_{name_offsets[2] + k}')
        _append_permute_and_norm(nodes, m_params, k, group2_channels)

    if use_split:
        # Keep 12 of the 24 channels; the other 12 are split off and flattened.
        nodes.append(
            Ff.Node(nodes[-1], Fm.Split1D, {
                'split_size_or_sections': [12, 12],
                'dim': 0
            }))
        split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))
    nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten'))

    # Fully-connected part.  Unlike the conv groups, no permutation or
    # act-norm node follows these blocks.
    for k in range(blocks[3]):
        nodes.append(
            Ff.Node(nodes[-1],
                    Fm.GLOWCouplingBlock, {
                        'clamp':
                        m_params['clamping_per_group'][3],
                        'subnet_constructor':
                        sub_fc(m_params['fc_size'],
                               m_params['hidden_layers_per_group'][3],
                               dropout=m_params['dropout_fc'],
                               batchnorm=m_params.get("bn", True))
                    },
                    conditions=conditions[3],
                    name=f'block_{name_offsets[3] + k}'))

    # Concatenate the split-off parts with the main chain.
    if use_split:
        nodes.append(
            Ff.Node([s.out0 for s in split_nodes] + [nodes[-1].out0],
                    Fm.Concat1d, {'dim': 0}))
    nodes.append(Ff.OutputNode(nodes[-1]))
    inn = SketchINN(cond)
    inn.build_inn(nodes, split_nodes, conditions)
    return inn
コード例 #7
0
    def build_inn(self):
        """Build the multi-scale conditional INN for 2x64x64 inputs.

        Four groups of conditional GLOW coupling blocks, group ``g`` fed by
        ``conditions[g]``:
          1. Two conv blocks at full resolution.
          2. After a Haar downsampling, four conv blocks.
          3. After a second Haar downsampling, four more conv blocks.
          4. Four fully-connected blocks on the flattened remainder.

        After groups 2 and 3 part of the channels is split off and flattened.
        These parts are concatenated with the main chain before the output.

        Returns:
            Ff.ReversibleGraphNet over the main chain, the split-off branches
            and the condition nodes (built with ``verbose=False``).
        """

        def sub_conv(ch_hidden, kernel):
            # Conv -> ReLU -> Conv subnet; padding kernel // 2 keeps the
            # spatial size for odd kernels.
            pad = kernel // 2

            def make_conv_subnet(ch_in, ch_out):
                return nn.Sequential(
                    nn.Conv2d(ch_in, ch_hidden, kernel, padding=pad),
                    nn.ReLU(),
                    nn.Conv2d(ch_hidden, ch_out, kernel, padding=pad))
            return make_conv_subnet

        def sub_fc(ch_hidden):
            # Linear -> ReLU -> Linear subnet.
            def make_fc_subnet(ch_in, ch_out):
                return nn.Sequential(
                    nn.Linear(ch_in, ch_hidden),
                    nn.ReLU(),
                    nn.Linear(ch_hidden, ch_out))
            return make_fc_subnet

        def add_coupling(subnet, clamp, condition):
            # Conditional GLOW coupling block after the current last node.
            nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                                 {'subnet_constructor': subnet, 'clamp': clamp},
                                 conditions=condition))

        def add_permutation(seed):
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': seed}))

        def add_haar():
            nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling,
                                 {'rebalance': 0.5}))

        def split_off(sections):
            # Split along dim 0.  The second part leaves the main chain and is
            # flattened for the final concatenation.
            nodes.append(Ff.Node(nodes[-1], Fm.Split1D,
                                 {'split_size_or_sections': sections, 'dim': 0}))
            split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))

        nodes = [Ff.InputNode(2, 64, 64)]
        # Outputs of the condition network at the different resolution levels.
        conditions = [Ff.ConditionNode(64, 64, 64),
                      Ff.ConditionNode(128, 32, 32),
                      Ff.ConditionNode(128, 16, 16),
                      Ff.ConditionNode(512)]
        split_nodes = []

        # Full resolution: two 3x3 conv couplings sharing one subnet
        # constructor; no permutation in between.
        full_res_subnet = sub_conv(32, 3)
        for _ in range(2):
            add_coupling(full_res_subnet, 1.0, conditions[0])

        add_haar()
        # Kernels alternate 1x1 (even k) and 3x3 (odd k).
        for k in range(4):
            add_coupling(sub_conv(64, 3 if k % 2 else 1), 1.0, conditions[1])
            add_permutation(k)
        # Keep 2 of the 8 channels; split off the other 6.
        split_off([2, 6])

        add_haar()
        for k in range(4):
            add_coupling(sub_conv(128, 3 if k % 2 else 1), 0.6, conditions[2])
            add_permutation(k)
        # Keep 4 of the 8 channels; split off the other 4.
        split_off([4, 4])
        nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten'))

        # Fully connected part.
        fc_subnet = sub_fc(512)
        for k in range(4):
            add_coupling(fc_subnet, 0.6, conditions[3])
            add_permutation(k)

        # Concatenate the split-off parts with the main chain.
        nodes.append(Ff.Node([s.out0 for s in split_nodes] + [nodes[-1].out0],
                             Fm.Concat1d, {'dim': 0}))
        nodes.append(Ff.OutputNode(nodes[-1]))

        return Ff.ReversibleGraphNet(nodes + split_nodes + conditions,
                                     verbose=False)
コード例 #8
0
    def build_inn(self):
        """Build the attribute-conditional INN for 3x64x64 inputs.

        Structure (node names in brackets):
          * stage 0: 2 x (permutation, conv coupling) at full resolution
            [permute_0_i / conv_0_i], then Haar downsampling [haar_1];
          * stage 1: 4 pairs with alternating 1x1 / 3x3 kernels
            [permute_1_i / conv_1_i]; 4 of the 12 channels are split off
            [split_1 / flatten_1]; then Haar downsampling [haar_2];
          * stage 2: 8 pairs with alternating kernels [permute_2_i / conv_2_i];
            16 of the 32 channels are split off [split_2 / flatten_2] and the
            rest is flattened [flatten_3];
          * stage 3: 2 pairs with fully-connected couplings [permute_3_i / fc_i],
            the only layers that receive the 40-dim attribute condition.
        Finally the split-off parts and the main chain are concatenated
        [concat].

        Returns:
            Ff.ReversibleGraphNet over the main chain, the split-off branches
            and the condition node (built with ``verbose=True``).
        """
        nodes = [Ff.InputNode(3, 64, 64, name='inp')]
        cond_node = Ff.ConditionNode(40, name='attr_cond')
        split_nodes = []

        # Input dimensionality: 3 * 64 * 64.

        def add_pair(seed, perm_name, coupling_args, coupling_name, wrap_input,
                     **coupling_kwargs):
            # Random permutation followed by a GLOW coupling layer.
            # `wrap_input` keeps each stage's original input form:
            # `[node.out0]` for stages 0-1, the bare node for stages 2-3.
            def last_output():
                return [nodes[-1].out0] if wrap_input else nodes[-1]

            nodes.append(Ff.Node(last_output(), permute_layer, {'seed': seed},
                                 name=perm_name))
            nodes.append(Ff.Node(last_output(), glow_coupling_layer,
                                 coupling_args, name=coupling_name,
                                 **coupling_kwargs))

        def conv_coupling_args(f_class, kernel_size, channels_hidden):
            # Arguments for a convolutional coupling layer (clamp 1.0).
            return {
                'clamp': 1.,
                'F_class': f_class,
                'F_args': {
                    'kernel_size': kernel_size,
                    'channels_hidden': channels_hidden,
                    'leaky_slope': 0.2
                }
            }

        def add_haar(name):
            nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling,
                                 {'rebalance': 0.5}, name=name))

        def split_off(sections, split_name, flatten_name):
            # Split along dim 0.  The second part leaves the main chain and is
            # flattened for the final concatenation.
            nodes.append(Ff.Node(nodes[-1], Fm.Split1D,
                                 {'split_size_or_sections': sections, 'dim': 0},
                                 name=split_name))
            split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {},
                                       name=flatten_name))

        # Stage 0: full resolution, F_fully_conv, 3x3 kernels, 256 hidden channels.
        for i in range(2):
            add_pair(i, f'permute_0_{i}',
                     conv_coupling_args(F_fully_conv, 3, 256),
                     f'conv_0_{i}', wrap_input=True)
        add_haar('haar_1')

        # Stage 1: kernels alternate 1x1 (even i) and 3x3 (odd i).
        for i in range(4):
            add_pair(i, f'permute_1_{i}',
                     conv_coupling_args(F_fully_conv_shallow,
                                        3 if i % 2 else 1, 64),
                     f'conv_1_{i}', wrap_input=True)
        split_off((8, 4), 'split_1', 'flatten_1')
        add_haar('haar_2')

        # Stage 2: same alternating kernels, eight pairs.
        for i in range(8):
            add_pair(i, f'permute_2_{i}',
                     conv_coupling_args(F_fully_conv_shallow,
                                        3 if i % 2 else 1, 64),
                     f'conv_2_{i}', wrap_input=False)
        split_off((16, 16), 'split_2', 'flatten_2')
        nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten_3'))

        # Stage 3: fully-connected couplings, conditioned on the attributes.
        for i in range(2):
            fc_args = {
                'clamp': 0.8,
                'F_class': F_fully_shallow,
                'F_args': {
                    'internal_size': 256,
                    'dropout': 0.2
                }
            }
            add_pair(i, f'permute_3_{i}', fc_args, f'fc_{i}',
                     wrap_input=False, conditions=[cond_node])

        # Concatenate the split-off branches with the main chain.
        nodes.append(Ff.Node([s.out0 for s in split_nodes] + [nodes[-1].out0],
                             Fm.Concat1d, {'dim': 0}, name='concat'))
        nodes.append(Ff.OutputNode(nodes[-1], name='out'))

        return Ff.ReversibleGraphNet(nodes + split_nodes + [cond_node],
                                     verbose=True)