def build_inn(self):
    """Assemble a fully connected conditional INN with c.n_blocks GLOW coupling blocks."""

    def subnet(ch_in, ch_out):
        # Three-layer MLP used inside every coupling block.
        return nn.Sequential(
            nn.Linear(ch_in, c.internal_width), nn.LeakyReLU(),
            nn.Linear(c.internal_width, c.internal_width), nn.LeakyReLU(),
            nn.Linear(c.internal_width, ch_out))

    nodes = [Ff.InputNode(c.x_dim)]

    # One condition node per coupling block, all with the same feature size.
    conditions = []
    for i in range(c.n_blocks):
        conditions.append(Ff.ConditionNode(c.y_dim_features))
        nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                             {'subnet_constructor': subnet,
                              'clamp': c.exponent_clamping},
                             conditions=conditions[i]))

    nodes.append(Ff.OutputNode(nodes[-1]))
    return Ff.ReversibleGraphNet(nodes + conditions, verbose=False)
def generate_rcINN_old():
    """Build the conditional INN together with its optimizer and step-LR scheduler.

    Returns a (model, optimizer, scheduler) triple. Trainable weights are
    re-initialized from a scaled normal distribution before the optimizer
    is created.
    """
    cond = Ff.ConditionNode(config.n_cond_features, name='condition')
    nodes = [Ff.InputNode(config.n_x_features, name='input')]

    # Alternate GLOW coupling blocks with random channel permutations.
    for k in range(config.n_blocks):
        coupling_args = {'subnet_constructor': subnet_func, 'clamp': config.clamp}
        nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock, coupling_args,
                             conditions=cond, name=f'coupling_{k}'))
        nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': k},
                             name=f'permute_{k}'))

    nodes.append(Ff.OutputNode(nodes[-1], name='output'))
    model = Ff.ReversibleGraphNet(nodes + [cond])
    model.to(device)

    # Re-initialize every trainable parameter with scaled Gaussian noise.
    params_trainable = [p for p in model.parameters() if p.requires_grad]
    for p in params_trainable:
        p.data = config.init_scale * torch.randn_like(p).to(device)

    optim = torch.optim.AdamW(params_trainable, lr=config.lr)
    weight_scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1,
                                                       gamma=config.gamma)
    return model, optim, weight_scheduler
def __init__(self, *args):
    """Set up fixtures: a conv+linear GraphINN with three condition nodes."""
    super().__init__(*args)

    self.batch_size = 32
    self.inv_tol = 1e-4
    torch.manual_seed(self.batch_size)

    self.inp_size = (3, 10, 10)
    self.c1_size = (1, 10, 10)
    self.c2_size = (50, )
    self.c3_size = (20, )

    self.x = torch.randn(self.batch_size, *self.inp_size)
    self.c1 = torch.randn(self.batch_size, *self.c1_size)
    self.c2 = torch.randn(self.batch_size, *self.c2_size)
    self.c3 = torch.randn(self.batch_size, *self.c3_size)

    # this is only used for the cuda variant of the tests.
    # if true, all tests are skipped.
    self.skip_all = False

    in_node = Ff.InputNode(*self.inp_size, name='input')
    cond1 = Ff.ConditionNode(*self.c1_size, name='c1')
    cond2 = Ff.ConditionNode(*self.c2_size, name='c2')
    cond3 = Ff.ConditionNode(*self.c3_size, name='c3')

    # Convolutional coupling conditioned on the image-shaped condition only.
    conv_block = Ff.Node(in_node, Fm.RNVPCouplingBlock,
                         {'subnet_constructor': F_conv, 'clamp': 1.0},
                         conditions=cond1, name='conv::c1')
    flat = Ff.Node(conv_block, Fm.Flatten, {}, name='flatten')
    # Fully connected coupling conditioned on both vector conditions.
    fc_block = Ff.Node(flat, Fm.RNVPCouplingBlock,
                       {'subnet_constructor': F_fully_connected, 'clamp': 1.0},
                       conditions=[cond2, cond3], name='linear::c2|c3')
    out_node = Ff.OutputNode(fc_block, name='output')

    self.test_net = Ff.GraphINN(
        [in_node, cond1, conv_block, flat, cond2, cond3, fc_block, out_node])
def __init__(self, *args):
    """Set up fixtures: a split/concat GraphINN with one conditioned conv branch."""
    super().__init__(*args)

    self.inp_size = (3, 10, 10)
    self.cond_size = (1, 10, 10)

    in_node = Ff.InputNode(*self.inp_size, name='input')
    cond_node = Ff.ConditionNode(*self.cond_size, name='cond')

    # Branch A: 1 channel -> flatten -> permute -> back to (1, 10, 10).
    split = Ff.Node(in_node, Fm.Split,
                    {'section_sizes': [1, 2], 'dim': 0}, name='split1')
    flat_a = Ff.Node(split.out0, Fm.Flatten, {}, name='flatten1')
    perm = Ff.Node(flat_a, Fm.PermuteRandom, {'seed': 123}, name='perm')
    reshape_a = Ff.Node(perm, Fm.Reshape,
                        {'output_dims': (1, 10, 10)}, name='unflatten1')

    # Branch B: 2 channels -> conditioned conv coupling -> fc coupling -> reshape.
    conv = Ff.Node(split.out1, Fm.RNVPCouplingBlock,
                   {'subnet_constructor': F_conv, 'clamp': 1.0},
                   conditions=cond_node, name='conv')
    flat_b = Ff.Node(conv, Fm.Flatten, {}, name='flatten2')
    fc = Ff.Node(flat_b, Fm.RNVPCouplingBlock,
                 {'subnet_constructor': F_fully_connected, 'clamp': 1.0},
                 name='linear')
    reshape_b = Ff.Node(fc, Fm.Reshape,
                        {'output_dims': (2, 10, 10)}, name='unflatten2')

    # Merge both branches and downsample.
    merge = Ff.Node([reshape_a.out0, reshape_b.out0], Fm.Concat,
                    {'dim': 0}, name='concat')
    haar = Ff.Node(merge, Fm.HaarDownsampling, {}, name='haar')
    out_node = Ff.OutputNode(haar, name='output')

    self.test_net = Ff.GraphINN([in_node, cond_node, split, flat_a, perm,
                                 reshape_a, conv, flat_b, fc, reshape_b,
                                 merge, haar, out_node])

    # this is only used for the cuda variant of the tests.
    # if true, all tests are skipped.
    self.skip_all = False

    self.batch_size = 32
    self.inv_tol = 1e-4
    torch.manual_seed(self.batch_size)

    self.x = torch.randn(self.batch_size, *self.inp_size)
    self.cond = torch.randn(self.batch_size, *self.cond_size)
def build_inn(self):
    """Conditional INN over 1x28x28 inputs: flatten, then 16 permute+coupling stages."""

    def subnet(ch_in, ch_out):
        # Wide two-layer MLP used inside every coupling block.
        return nn.Sequential(nn.Linear(ch_in, 512),
                             nn.LeakyReLU(),
                             nn.Linear(512, ch_out))

    cond = Ff.ConditionNode(10)
    nodes = [Ff.InputNode(1, 28, 28)]
    nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}))

    for k in range(16):
        nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': k}))
        nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                             {'subnet_constructor': subnet, 'clamp': 0.8},
                             conditions=cond))

    return Ff.ReversibleGraphNet(nodes + [cond, Ff.OutputNode(nodes[-1])],
                                 verbose=False)
def baseline_color(m_params):
    """Build the color baseline INN described by the ``m_params`` dict.

    The network runs over 2x64x64 inputs through four groups of GLOW
    coupling blocks (two conv levels, a third conv level after further
    downsampling, and a fully connected tail), each conditioned on a
    resolution-matched condition node. Depending on ``m_params``, each
    coupling block is followed by a permutation ('random' or 'soft') and
    an activation normalization ('learnednorm' or 'movingavg'), and lanes
    of channels may be split off and concatenated back at the end.

    Returns the assembled ``SketchINN`` (built via ``inn.build_inn``).
    Raises RuntimeError when 'permute' or 'act_norm' is not one of the
    recognized choices.
    """
    if m_params['permute'] not in ['soft', 'random', 'none', 'false', None, False]:
        # Fixed typos in the original message ("Erros", doubled quote).
        raise RuntimeError(
            "Error in model params: No 'permute' selected or not understood.")
    if m_params['act_norm'] not in ['learnednorm', 'movingavg', 'none', 'false', None, False]:
        raise RuntimeError(
            "Error in model params: No 'act_norm' selected or not understood.")

    cond = CondNet(m_params)

    nodes = [Ff.InputNode(2, 64, 64)]
    # outputs of the cond. net at different resolution levels
    conditions = [
        Ff.ConditionNode(64, 64, 64),
        Ff.ConditionNode(128, 32, 32),
        Ff.ConditionNode(128, 16, 16),
        Ff.ConditionNode(512)
    ]
    split_nodes = []

    def append_perm_and_norm(seed, soft_dim):
        # Shared permutation / activation-norm tail appended after each
        # coupling block (previously triplicated inline).
        if m_params['permute'] == 'random':
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': seed}))
        elif m_params['permute'] == 'soft':
            nodes.append(Ff.Node([nodes[-1].out0], Fm.conv_1x1,
                                 {'M': random_orthog(soft_dim)}))
        if m_params['act_norm'] == 'learnednorm':
            nodes.append(Ff.Node(nodes[-1], LearnedActNorm,
                                 {'M': torch.randn(1), "b": torch.randn(1)}))
        elif m_params['act_norm'] == 'movingavg':
            nodes.append(Ff.Node(nodes[-1], Fm.ActNorm, {}))

    # Group 0: full resolution, 2 channels (soft permutation acts on 2 ch).
    for k in range(m_params['blocks_per_group'][0]):
        nodes.append(Ff.Node(
            nodes[-1], Fm.GLOWCouplingBlock,
            {'subnet_constructor': sub_conv(64,
                                            m_params['kernel_size_per_group'][0],
                                            m_params['hidden_layers_per_group'][0],
                                            batchnorm=m_params.get("bn", True)),
             'clamp': m_params['clamping_per_group'][0]},
            conditions=conditions[0], name=f'block_{k}'))
        append_perm_and_norm(k, 2)

    nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance': 0.5}))

    # Group 1: after first Haar downsampling, 8 channels.
    for k in range(m_params['blocks_per_group'][1]):
        nodes.append(Ff.Node(
            nodes[-1], Fm.GLOWCouplingBlock,
            {'subnet_constructor': sub_conv(128,
                                            m_params['kernel_size_per_group'][1],
                                            m_params['hidden_layers_per_group'][1],
                                            batchnorm=m_params.get("bn", True)),
             'clamp': m_params['clamping_per_group'][1]},
            conditions=conditions[1], name=f'block_{k + 2}'))
        append_perm_and_norm(k, 8)

    # split off 8/12 ch
    if m_params.get("split", True):
        nodes.append(Ff.Node(nodes[-1], Fm.Split1D,
                             {'split_size_or_sections': [6, 2], 'dim': 0}))
        split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))

    nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance': 0.5}))

    # Group 2: after second Haar downsampling; kernel size varies per block.
    for k in range(m_params['blocks_per_group'][2]):
        nodes.append(Ff.Node(
            nodes[-1], Fm.GLOWCouplingBlock,
            {'subnet_constructor': sub_conv(256,
                                            m_params['kernel_size_per_group'][2][k],
                                            m_params['hidden_layers_per_group'][2],
                                            batchnorm=m_params.get("bn", True)),
             'clamp': m_params['clamping_per_group'][2]},
            conditions=conditions[2], name=f'block_{k + 6}'))
        # Channel count here depends on whether the earlier split happened.
        append_perm_and_norm(k, 24 if m_params.get("split", True) else 32)

    # split off 8/16 ch
    if m_params.get("split", True):
        nodes.append(Ff.Node(nodes[-1], Fm.Split1D,
                             {'split_size_or_sections': [12, 12], 'dim': 0}))
        split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))

    nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten'))

    # fully_connected part (no permute/act-norm tail here, as before)
    for k in range(m_params['blocks_per_group'][3]):
        nodes.append(Ff.Node(
            nodes[-1], Fm.GLOWCouplingBlock,
            {'clamp': m_params['clamping_per_group'][3],
             'subnet_constructor': sub_fc(m_params['fc_size'],
                                          m_params['hidden_layers_per_group'][3],
                                          dropout=m_params['dropout_fc'],
                                          batchnorm=m_params.get("bn", True))},
            conditions=conditions[3], name=f'block_{k + 10}'))

    # concat everything
    if m_params.get("split", True):
        nodes.append(Ff.Node([s.out0 for s in split_nodes] + [nodes[-1].out0],
                             Fm.Concat1d, {'dim': 0}))
    nodes.append(Ff.OutputNode(nodes[-1]))

    inn = SketchINN(cond)
    inn.build_inn(nodes, split_nodes, conditions)
    return inn
def build_inn(self):
    """Multi-resolution conditional INN: two conv levels with channel splits, then an FC tail."""

    def conv_subnet(hidden, kernel):
        # Two conv layers with 'same' padding and a ReLU in between.
        padding = kernel // 2
        return lambda ch_in, ch_out: nn.Sequential(
            nn.Conv2d(ch_in, hidden, kernel, padding=padding),
            nn.ReLU(),
            nn.Conv2d(hidden, ch_out, kernel, padding=padding))

    def fc_subnet(hidden):
        return lambda ch_in, ch_out: nn.Sequential(
            nn.Linear(ch_in, hidden),
            nn.ReLU(),
            nn.Linear(hidden, ch_out))

    nodes = [Ff.InputNode(2, 64, 64)]

    # outputs of the cond. net at different resolution levels
    conditions = [Ff.ConditionNode(64, 64, 64),
                  Ff.ConditionNode(128, 32, 32),
                  Ff.ConditionNode(128, 16, 16),
                  Ff.ConditionNode(512)]
    split_nodes = []

    constructor = conv_subnet(32, 3)
    for _ in range(2):
        nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                             {'subnet_constructor': constructor, 'clamp': 1.0},
                             conditions=conditions[0]))

    nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance': 0.5}))

    for k in range(4):
        # Alternate 1x1 and 3x3 conv subnets.
        constructor = conv_subnet(64, 3 if k % 2 else 1)
        nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                             {'subnet_constructor': constructor, 'clamp': 1.0},
                             conditions=conditions[1]))
        nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': k}))

    # split off 6/8 ch
    nodes.append(Ff.Node(nodes[-1], Fm.Split1D,
                         {'split_size_or_sections': [2, 6], 'dim': 0}))
    split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))

    nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance': 0.5}))

    for k in range(4):
        constructor = conv_subnet(128, 3 if k % 2 else 1)
        nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                             {'subnet_constructor': constructor, 'clamp': 0.6},
                             conditions=conditions[2]))
        nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': k}))

    # split off 4/8 ch
    nodes.append(Ff.Node(nodes[-1], Fm.Split1D,
                         {'split_size_or_sections': [4, 4], 'dim': 0}))
    split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))

    nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten'))

    # fully_connected part
    constructor = fc_subnet(512)
    for k in range(4):
        nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                             {'subnet_constructor': constructor, 'clamp': 0.6},
                             conditions=conditions[3]))
        nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': k}))

    # concat everything
    nodes.append(Ff.Node([s.out0 for s in split_nodes] + [nodes[-1].out0],
                         Fm.Concat1d, {'dim': 0}))
    nodes.append(Ff.OutputNode(nodes[-1]))

    return Ff.ReversibleGraphNet(nodes + split_nodes + conditions, verbose=False)
def build_inn(self):
    """Convolutional INN over 3x64x64 inputs with two Haar downsamplings,
    two channel splits, and a fully connected tail conditioned on the
    40-dim attribute vector."""
    nodes = [Ff.InputNode(3, 64, 64, name='inp')]
    cond_node = Ff.ConditionNode(40, name='attr_cond')
    split_nodes = []
    # ndim_x = 3 * 64 * 64

    # Level 0: full resolution.
    for i in range(2):
        nodes.append(Ff.Node([nodes[-1].out0], permute_layer, {'seed': i},
                             name=f'permute_0_{i}'))
        nodes.append(Ff.Node([nodes[-1].out0], glow_coupling_layer,
                             {'clamp': 1.,
                              'F_class': F_fully_conv,
                              'F_args': {'kernel_size': 3,
                                         'channels_hidden': 256,
                                         'leaky_slope': 0.2}},
                             name=f'conv_0_{i}'))

    nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance': 0.5},
                         name='haar_1'))

    # Level 1: alternate 1x1 and 3x3 shallow conv subnets.
    for i in range(4):
        width = 1 if i % 2 == 0 else 3
        nodes.append(Ff.Node([nodes[-1].out0], permute_layer, {'seed': i},
                             name=f'permute_1_{i}'))
        nodes.append(Ff.Node([nodes[-1].out0], glow_coupling_layer,
                             {'clamp': 1.,
                              'F_class': F_fully_conv_shallow,
                              'F_args': {'kernel_size': width,
                                         'channels_hidden': 64,
                                         'leaky_slope': 0.2}},
                             name=f'conv_1_{i}'))

    nodes.append(Ff.Node(nodes[-1], Fm.Split1D,
                         {'split_size_or_sections': (8, 4), 'dim': 0},
                         name='split_1'))
    split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {},
                               name='flatten_1'))

    nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance': 0.5},
                         name='haar_2'))

    # Level 2.
    for i in range(8):
        width = 1 if i % 2 == 0 else 3
        nodes.append(Ff.Node(nodes[-1], permute_layer, {'seed': i},
                             name=f'permute_2_{i}'))
        nodes.append(Ff.Node(nodes[-1], glow_coupling_layer,
                             {'clamp': 1.,
                              'F_class': F_fully_conv_shallow,
                              'F_args': {'kernel_size': width,
                                         'channels_hidden': 64,
                                         'leaky_slope': 0.2}},
                             name=f'conv_2_{i}'))

    nodes.append(Ff.Node(nodes[-1], Fm.Split1D,
                         {'split_size_or_sections': (16, 16), 'dim': 0},
                         name='split_2'))
    split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {},
                               name='flatten_2'))

    nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten_3'))

    # Fully connected tail, conditioned on the attribute vector.
    for i in range(2):
        nodes.append(Ff.Node(nodes[-1], permute_layer, {'seed': i},
                             name=f'permute_3_{i}'))
        nodes.append(Ff.Node(nodes[-1], glow_coupling_layer,
                             {'clamp': 0.8,
                              'F_class': F_fully_shallow,
                              'F_args': {'internal_size': 256,
                                         'dropout': 0.2}},
                             conditions=[cond_node],
                             name=f'fc_{i}'))

    nodes.append(Ff.Node([s.out0 for s in split_nodes] + [nodes[-1].out0],
                         Fm.Concat1d, {'dim': 0}, name='concat'))
    nodes.append(Ff.OutputNode(nodes[-1], name='out'))

    conv_inn = Ff.ReversibleGraphNet(nodes + split_nodes + [cond_node],
                                     verbose=True)
    return conv_inn