def constuct_inn(img_dims=(3, 32, 32)):
    """Build a two-scale convolutional + fully connected GLOW-style INN.

    Layout: input -> reshape -> Haar downsampling -> ``n_coupling_blocks_conv_0``
    GLOW coupling blocks -> Haar downsampling -> ``n_coupling_blocks_conv_1``
    GLOW coupling blocks -> flatten -> ``n_coupling_blocks_fc`` GLOW coupling
    blocks -> output. In the convolutional stages a fixed random orthogonal
    1x1 convolution follows every second coupling block; in the fully
    connected stage a random permutation does.

    Relies on names defined elsewhere in the module: ``n_coupling_blocks_conv_0``,
    ``n_coupling_blocks_conv_1``, ``n_coupling_blocks_fc``, ``conv_constr_0``,
    ``conv_constr_1``, ``fc_constr``, ``clamp`` and ``random_orthog``.

    :param img_dims: input shape as (channels, height, width).
    :return: ``fr.ReversibleGraphNet`` over the assembled nodes.
    """
    # Each Haar downsampling multiplies the channel count by 4 (3 -> 12 -> 48
    # for the default RGB input). The 1x1 convolutions are sized from the
    # actual input instead of the previously hard-coded 12 / 48.
    channels_scale_1 = 4 * img_dims[0]
    channels_scale_2 = 4 * channels_scale_1

    nodes = [fr.InputNode(*img_dims, name='input')]
    # The target shape equals the input shape, so this reshape leaves the
    # tensor unchanged; it is kept so the graph layout stays the same.
    nodes.append(
        fr.Node(nodes[-1].out0, la.Reshape,
                {'target_dim': (img_dims[0], *img_dims[1:])}, name='reshape'))

    nodes.append(
        fr.Node(nodes[-1].out0, la.HaarDownsampling, {'rebalance': 0.5},
                name='haar_down_1'))
    for k in range(n_coupling_blocks_conv_0):
        nodes.append(
            fr.Node(nodes[-1], la.GLOWCouplingBlock,
                    {'subnet_constructor': conv_constr_0, 'clamp': clamp},
                    name=f'CB CONV_0_{k}'))
        if k % 2:
            nodes.append(
                fr.Node(nodes[-1], la.Fixed1x1Conv,
                        {'M': random_orthog(channels_scale_1)}))

    nodes.append(
        fr.Node(nodes[-1].out0, la.HaarDownsampling, {'rebalance': 0.5},
                name='haar_down_2'))
    for k in range(n_coupling_blocks_conv_1):
        nodes.append(
            fr.Node(nodes[-1], la.GLOWCouplingBlock,
                    {'subnet_constructor': conv_constr_1, 'clamp': clamp},
                    name=f'CB CONV_1_{k}'))
        if k % 2:
            nodes.append(
                fr.Node(nodes[-1], la.Fixed1x1Conv,
                        {'M': random_orthog(channels_scale_2)}))

    nodes.append(fr.Node(nodes[-1].out0, la.Flatten, {}, name='flatten'))
    for k in range(n_coupling_blocks_fc):
        nodes.append(
            fr.Node(nodes[-1], la.GLOWCouplingBlock,
                    {'subnet_constructor': fc_constr, 'clamp': clamp},
                    name=f'CB FC_{k}'))
        if k % 2:
            nodes.append(fr.Node(nodes[-1], la.PermuteRandom, {'seed': k}))

    nodes.append(fr.OutputNode(nodes[-1], name='output'))
    return fr.ReversibleGraphNet(nodes)
def construct_inn(self, input, backbone: InvertibleArchitecture,
                  head: InvertibleArchitecture):
    """Join a backbone and a head into one reversible network.

    The backbone is attached to ``input``. The head is attached to the last
    backbone node and also receives the backbone's skip connections when
    there are any. The resulting ``Ff.ReversibleGraphNet`` is stored on
    ``self.model`` and printed.

    Returns:
        The main-path node list: ``input``, the backbone nodes, an
        ``OutputNode`` named ``'out_fc'`` attached to the last head node,
        then the head nodes.
    """
    backbone_nodes, backbone_split_nodes, skip_connections = \
        backbone.construct_inn(input)
    main_path = [input, *backbone_nodes]

    extra_nodes = []
    if not skip_connections:
        head_nodes, head_split_nodes = head.construct_inn(main_path[-1])
    else:
        print("HAS SKIP CONNECTION")
        head_nodes, head_split_nodes = head.construct_inn(
            main_path[-1], skip_connections)
        # NOTE(review): backbone split nodes are only registered in this
        # branch; confirm the backbone produces none when there are no skips.
        extra_nodes.extend(backbone_split_nodes)

    main_path.append(Ff.OutputNode(head_nodes[-1], name='out_fc'))
    main_path.extend(head_nodes)
    extra_nodes.extend(head_split_nodes)

    self.model = Ff.ReversibleGraphNet(main_path + extra_nodes, verbose=True)
    print(self.model)
    return main_path
def build_inn(self):
    """Build a conditional, fully connected INN of ``c.n_blocks`` GLOW blocks.

    Every coupling block gets its own ``Ff.ConditionNode`` of size
    ``c.y_dim_features``. The subnets are three-layer MLPs of width
    ``c.internal_width`` with LeakyReLU activations. All hyper-parameters
    come from the module-level config object ``c``.
    """
    def fc_constr():
        def make_subnet(ch_in, ch_out):
            # Optional BatchNorm1d / Dropout layers were tried between these
            # layers in earlier experiments and are currently disabled.
            return nn.Sequential(
                nn.Linear(ch_in, c.internal_width),
                nn.LeakyReLU(),
                nn.Linear(c.internal_width, c.internal_width),
                nn.LeakyReLU(),
                nn.Linear(c.internal_width, ch_out))
        return make_subnet

    chain = [Ff.InputNode(c.x_dim)]
    # outputs of the cond. net at different resolution levels
    conditions = []
    for block_idx in range(c.n_blocks):
        condition = Ff.ConditionNode(c.y_dim_features)
        conditions.append(condition)
        coupling_args = {'subnet_constructor': fc_constr(),
                         'clamp': c.exponent_clamping}
        chain.append(Ff.Node(chain[-1], Fm.GLOWCouplingBlock, coupling_args,
                             conditions=condition))
        # PermuteRandom / ActNorm nodes after each block are currently disabled.

    chain.append(Ff.OutputNode(chain[-1]))
    return Ff.ReversibleGraphNet(chain + conditions, verbose=False)
def generate_rcINN_old():
    """Build the conditional INN together with its optimiser and LR schedule.

    All hyper-parameters are read from the module-level ``config``. The
    network has ``config.n_blocks`` (GLOW coupling, random permutation)
    pairs, all conditioned on one ``ConditionNode``. Every trainable
    parameter is re-initialised with ``config.init_scale``-scaled Gaussian
    noise.

    Returns:
        ``(model, optim, weight_scheduler)`` — the network moved to
        ``device``, an AdamW optimiser over its trainable parameters, and a
        StepLR scheduler decaying the LR by ``config.gamma`` every step.
    """
    condition = Ff.ConditionNode(config.n_cond_features, name='condition')
    graph = [Ff.InputNode(config.n_x_features, name='input')]
    for block_idx in range(config.n_blocks):
        coupling = Ff.Node(graph[-1], Fm.GLOWCouplingBlock,
                           {'subnet_constructor': subnet_func,
                            'clamp': config.clamp},
                           conditions=condition, name=F'coupling_{block_idx}')
        graph.append(coupling)
        graph.append(Ff.Node(coupling, Fm.PermuteRandom, {'seed': block_idx},
                             name=F'permute_{block_idx}'))
    graph.append(Ff.OutputNode(graph[-1], name='output'))

    model = Ff.ReversibleGraphNet(graph + [condition])
    model.to(device)

    params_trainable = [param for param in model.parameters()
                        if param.requires_grad]
    for param in params_trainable:
        param.data = config.init_scale * torch.randn_like(param).to(device)

    decay = config.gamma
    optim = torch.optim.AdamW(params_trainable, lr=config.lr)
    weight_scheduler = torch.optim.lr_scheduler.StepLR(optim, step_size=1,
                                                       gamma=decay)
    return model, optim, weight_scheduler
def build_inn(self, nodes, split_nodes, conditions):
    """Assemble the reversible graph and re-initialise its trainable weights.

    Stores the network on ``self.model`` and the parameters that require
    gradients on ``self.trainable_parameters``. Each of those parameters is
    overwritten with Gaussian noise of standard deviation 0.02.
    """
    all_nodes = nodes + split_nodes + conditions
    self.model = Ff.ReversibleGraphNet(all_nodes, verbose=False)
    self.trainable_parameters = list(
        filter(lambda tensor: tensor.requires_grad, self.model.parameters()))
    for tensor in self.trainable_parameters:
        tensor.data = 0.02 * torch.randn_like(tensor)
def get_invnet(ndim, nb_blocks=4, hidden_layer=128, small_block=False,
               permute=True, verbose=False):
    """Return a FrEIA ``ReversibleGraphNet`` over the nodes from ``get_nodes``.

    All arguments except ``verbose`` are forwarded unchanged to ``get_nodes``.
    ``verbose`` is passed to the graph constructor.
    """
    node_list = get_nodes(ndim, nb_blocks, hidden_layer, small_block, permute)
    return freia_fw.ReversibleGraphNet(node_list, verbose=verbose)
def _inn(self):
    """Build an unconditional INN on ``self.feat_length``-dimensional vectors.

    The network has ``self.depth`` stages of (random permutation, GLOW
    coupling). Each coupling uses a Linear(1024) -> ReLU -> Linear subnet
    and clamp 2.0.
    """
    def subnet(ch_in, ch_out):
        hidden = 1024
        return nn.Sequential(nn.Linear(ch_in, hidden), nn.ReLU(),
                             nn.Linear(hidden, ch_out))

    chain = [ff.InputNode(self.feat_length)]
    for stage in range(self.depth):
        chain.append(ff.Node(chain[-1], fm.PermuteRandom, {'seed': stage}))
        chain.append(ff.Node(chain[-1], fm.GLOWCouplingBlock,
                             {'subnet_constructor': subnet, 'clamp': 2.0}))
    output = ff.OutputNode(chain[-1])
    return ff.ReversibleGraphNet(chain + [output], verbose=False)
def construct_net_10d(coupling_block, init_identity=True):
    """Build an 8-stage flow on 10-dimensional inputs.

    :param coupling_block: ``'gin'`` for GIN couplings or ``'glow'`` for GLOW
        couplings; anything else fails the assertion.
    :param init_identity: forwarded to ``subnet_fc_10d`` for every subnet.
    :return: ``Ff.ReversibleGraphNet`` of (coupling, random permutation)
        pairs. Permutation seeds are drawn from NumPy's global RNG.
    """
    if coupling_block != 'gin':
        assert coupling_block == 'glow'
        block = Fm.GLOWCouplingBlock
    else:
        block = Fm.GINCouplingBlock

    def make_subnet(c_in, c_out):
        return subnet_fc_10d(c_in, c_out, init_identity)

    nodes = [Ff.InputNode(10, name='input')]
    for stage in range(8):
        nodes.append(Ff.Node(nodes[-1], block,
                             {'subnet_constructor': make_subnet, 'clamp': 2.0},
                             name=F'coupling_{stage}'))
        seed = np.random.randint(2**31)
        nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': seed},
                             name=F'permute_{stage + 1}'))
    nodes.append(Ff.OutputNode(nodes[-1], name='output'))
    return Ff.ReversibleGraphNet(nodes)
def construct_net_emnist(coupling_block):
    """Build a multi-scale flow for 1x28x28 (EMNIST) images.

    Layout: IRevNet downsampling -> 4 conv couplings -> IRevNet downsampling
    -> 4 conv couplings -> flatten -> 2 fully connected couplings. Every
    coupling is followed by a random permutation whose seed is drawn from
    NumPy's global RNG.

    :param coupling_block: ``'gin'`` or ``'glow'``; anything else fails the
        assertion.
    :return: ``Ff.ReversibleGraphNet``.
    """
    if coupling_block != 'gin':
        assert coupling_block == 'glow'
        block = Fm.GLOWCouplingBlock
    else:
        block = Fm.GINCouplingBlock

    nodes = [Ff.InputNode(1, 28, 28, name='input')]

    def add_stage(n_couplings, subnet, tag):
        """Append ``n_couplings`` (coupling, permutation) pairs named by ``tag``."""
        for idx in range(n_couplings):
            nodes.append(Ff.Node(nodes[-1], block,
                                 {'subnet_constructor': subnet, 'clamp': 2.0},
                                 name=F'coupling_{tag}_{idx}'))
            nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom,
                                 {'seed': np.random.randint(2**31)},
                                 name=F'permute_{tag}_{idx}'))

    nodes.append(Ff.Node(nodes[-1], Fm.IRevNetDownsampling, {},
                         name='downsample1'))
    add_stage(4, subnet_conv1, 'conv1')
    nodes.append(Ff.Node(nodes[-1], Fm.IRevNetDownsampling, {},
                         name='downsample2'))
    add_stage(4, subnet_conv2, 'conv2')
    nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten'))
    add_stage(2, subnet_fc, 'fc')
    nodes.append(Ff.OutputNode(nodes[-1], name='output'))
    return Ff.ReversibleGraphNet(nodes)
def build_inn(self):
    """Build an unconditional INN for 784-dimensional (28*28) inputs.

    The input is declared 1-D, so the Flatten node presumably leaves it
    unchanged. Twenty (random permutation, GLOW coupling) pairs follow, each
    coupling using a Linear(512) -> ReLU -> Linear subnet and clamp 1.0.
    """
    def subnet(ch_in, ch_out):
        layers = [nn.Linear(ch_in, 512), nn.ReLU(), nn.Linear(512, ch_out)]
        return nn.Sequential(*layers)

    chain = [ff.InputNode(28 * 28)]
    chain.append(ff.Node(chain[-1], fm.Flatten, {}))
    for seed in range(20):
        permute = ff.Node(chain[-1], fm.PermuteRandom, {'seed': seed})
        coupling = ff.Node(permute, fm.GLOWCouplingBlock,
                           {'subnet_constructor': subnet, 'clamp': 1.0})
        chain.extend([permute, coupling])
    return ff.ReversibleGraphNet(chain + [ff.OutputNode(chain[-1])],
                                 verbose=False)
def mnist_inn_one(mask_size=(28, 28)):
    """Return a single-channel INN autoencoder: Haar -> FC coupling -> inverse Haar.

    :param mask_size: (height, width) of the single-channel input.
        Default: size of MNIST images.
    :return: ``fr.ReversibleGraphNet`` whose input is ``inp`` (index 0) and
        whose output is ``outp`` (index 1).
    """
    img_dims = [1, mask_size[0], mask_size[1]]
    n_features = img_dims[0] * img_dims[1] * img_dims[2]
    # Shape produced by haar_multiplex_layer: 4x the channels at half the
    # resolution ((4, 14, 14) for MNIST). Derived from mask_size so that
    # non-default sizes work instead of failing at the r3 reshape.
    haar_dims = (4 * img_dims[0], img_dims[1] // 2, img_dims[2] // 2)

    inp = fr.InputNode(*img_dims, name='input')
    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')
    r2 = fr.Node([r1.out0], re.reshape_layer,
                 {'target_dim': (n_features, )}, name='r2')
    fc = fr.Node([r2.out0], la.rev_multiplicative_layer, {
        'F_class': fu.F_fully_connected,
        'F_args': {},
        'clamp': 1
    }, name='fc')
    r3 = fr.Node([fc.out0], re.reshape_layer, {'target_dim': haar_dims},
                 name='r3')
    r4 = fr.Node([r3.out0], re.haar_restore_layer, {}, name='r4')
    outp = fr.OutputNode([r4.out0], name='output')

    # Order matters: the positional 0, 1 below select nodes[0] (input) and
    # nodes[1] (output).
    nodes = [inp, outp, r1, r2, r3, r4, fc]
    coder = fr.ReversibleGraphNet(nodes, 0, 1)
    return coder
def build_inn(self):
    """Build a conditional INN on flattened 1x28x28 images.

    Sixteen (random permutation, GLOW coupling) pairs follow a Flatten node.
    Every coupling is conditioned on one shared 10-dimensional
    ``Ff.ConditionNode`` and uses a Linear(512) -> LeakyReLU -> Linear subnet
    with clamp 0.8.
    """
    def subnet(ch_in, ch_out):
        hidden = 512
        return nn.Sequential(nn.Linear(ch_in, hidden), nn.LeakyReLU(),
                             nn.Linear(hidden, ch_out))

    cond = Ff.ConditionNode(10)
    chain = [Ff.InputNode(1, 28, 28)]
    chain.append(Ff.Node(chain[-1], Fm.Flatten, {}))
    for seed in range(16):
        permute = Ff.Node(chain[-1], Fm.PermuteRandom, {'seed': seed})
        coupling = Ff.Node(permute, Fm.GLOWCouplingBlock,
                           {'subnet_constructor': subnet, 'clamp': 0.8},
                           conditions=cond)
        chain += [permute, coupling]
    return Ff.ReversibleGraphNet(chain + [cond, Ff.OutputNode(chain[-1])],
                                 verbose=False)
def build_inn(fe_c, fe_h, fe_w, subnet_mode, flow_blocks, verbose):
    """Build a (optionally multi-scale) normalizing flow on a feature map.

    :param fe_c, fe_h, fe_w: channels, height and width of the input features.
    :param subnet_mode: forwarded to ``one_step`` for every flow step.
    :param flow_blocks: number of flow steps per scale level. With more than
        one level, every level but the last ends in a ``SplitChannel`` whose
        second output becomes its own ``OutputNode``.
    :param verbose: forwarded to ``Ff.ReversibleGraphNet``.
    :return: the assembled ``Ff.ReversibleGraphNet``.
    """
    nodes = [Ff.InputNode(fe_c, fe_h, fe_w, name='input')]
    # Output nodes of the split-off halves, in level order. These used to be
    # stashed in the ``locals()`` dict under dynamic names, which is fragile
    # (its semantics changed in PEP 667); a plain list is explicit.
    split_outputs = []
    n_levels = len(flow_blocks)
    index = 0
    # multiscale structure design if n_levels > 1
    for level, n_steps in enumerate(flow_blocks):
        for _ in range(n_steps):
            nodes = one_step(nodes, subnet_mode, index, fe_h)
            index += 1
        if level == n_levels - 1:
            nodes.append(Ff.OutputNode(nodes[-1], name=F'output{level}'))
        else:
            # split off half of the channels to reduce dimension redundancy
            nodes.append(
                Ff.Node(nodes[-1], Fm.SplitChannel, {}, name=F'split{level}'))
            split_outputs.append(
                Ff.OutputNode(nodes[-1].out1, name=F'output{level}'))
    return Ff.ReversibleGraphNet(nodes + split_outputs, verbose=verbose)
def inn_model(img_dims=4):
    """Return a fully connected INN of three GLOW coupling blocks.

    :param img_dims: dimensionality of the flat input vector (default 4).
    :return: ``fr.ReversibleGraphNet`` with nodes fc1 -> fc2 -> fc3, each
        using the module-level ``fc_constr`` subnet and clamp 2.
    """
    input_node = fr.InputNode(img_dims, name='input')
    couplings = []
    previous = input_node
    for idx in (1, 2, 3):
        previous = fr.Node([previous.out0], la.GLOWCouplingBlock,
                           {'subnet_constructor': fc_constr, 'clamp': 2},
                           name='fc%d' % idx)
        couplings.append(previous)
    output_node = fr.OutputNode([previous.out0], name='output')
    return fr.ReversibleGraphNet([input_node, output_node] + couplings)
def get_mnist_conv(mask_size=(28, 28)):
    """Return a convolutional INN autoencoder for single-channel images.

    Layout: Haar multiplex -> two multiplicative conv couplings -> additive
    conv coupling -> flatten -> fully connected multiplicative coupling ->
    reshape -> Haar restore.

    :param mask_size: (height, width) of the input. Default: size of MNIST
        images.
    :return: ``fr.ReversibleGraphNet`` whose input is ``inp`` (index 0) and
        whose output is ``outp`` (index 1).
    """
    img_dims = [1, mask_size[0], mask_size[1]]
    # Previously hard-coded as (784,) and (4, 14, 14), which only matched
    # 28x28 inputs.
    n_features = img_dims[0] * img_dims[1] * img_dims[2]
    haar_dims = (4 * img_dims[0], img_dims[1] // 2, img_dims[2] // 2)

    def conv_f_args():
        # Fresh dict per node so no two layers share the same args object.
        return {'channels_hidden': 128, 'batch_norm': False}

    inp = fr.InputNode(*img_dims, name='input')
    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')
    conv1 = fr.Node([r1.out0], la.rev_multiplicative_layer, {
        'F_class': fu.F_conv,
        'F_args': conv_f_args(),
        'clamp': 1
    }, name='conv1')
    conv2 = fr.Node([conv1.out0], la.rev_multiplicative_layer, {
        'F_class': fu.F_conv,
        'F_args': conv_f_args(),
        'clamp': 1
    }, name='conv2')
    conv3 = fr.Node([conv2.out0], la.rev_layer, {
        'F_class': fu.F_conv,
        'F_args': conv_f_args()
    }, name='conv3')
    r2 = fr.Node([conv3.out0], re.reshape_layer,
                 {'target_dim': (n_features, )}, name='r2')
    fc = fr.Node([r2.out0], la.rev_multiplicative_layer, {
        'F_class': fu.F_fully_connected,
        'F_args': {},
        'clamp': 1
    }, name='fc')
    r3 = fr.Node([fc.out0], re.reshape_layer, {'target_dim': haar_dims},
                 name='r3')
    r4 = fr.Node([r3.out0], re.haar_restore_layer, {}, name='r4')
    outp = fr.OutputNode([r4.out0], name='output')

    # Order matters: positional 0, 1 select nodes[0] (input) and nodes[1] (output).
    nodes = [inp, outp, conv1, conv2, conv3, r1, r2, r3, r4, fc]
    coder = fr.ReversibleGraphNet(nodes, 0, 1)
    return coder
def celeba_inn_small(mask_size=(156, 128)):
    """Return CelebA INN autoencoder for comparison with a classical autoencoder.

    The model is sized to have the same number of parameters as the classical
    autoencoder. Layout: Haar multiplex -> three GLOW conv couplings ->
    flatten -> fully connected multiplicative coupling (``F_small_connected``
    with internal size 5000) -> reshape -> Haar restore.

    :param mask_size: (height, width) of the 3-channel input. Default: the
        156x128 CelebA crop used here.
    :return: CelebA INN autoencoder (``fr.ReversibleGraphNet``) whose input
        is ``inp`` (index 0) and whose output is ``outp`` (index 1).
    """
    img_dims = [3, mask_size[0], mask_size[1]]
    n_features = img_dims[0] * img_dims[1] * img_dims[2]
    # Shape after haar_multiplex_layer; previously hard-coded as (12, 78, 64).
    haar_dims = (4 * img_dims[0], img_dims[1] // 2, img_dims[2] // 2)

    inp = fr.InputNode(*img_dims, name='input')
    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')

    convs = []
    prev = r1
    for conv_name in ('conv11', 'conv12', 'conv13'):
        prev = fr.Node([prev.out0], la.glow_coupling_layer, {
            'F_class': fu.F_conv,
            'F_args': {'channels_hidden': 128},
            'clamp': 1
        }, name=conv_name)
        convs.append(prev)
    conv11, conv12, conv13 = convs

    r2 = fr.Node([conv13.out0], re.reshape_layer,
                 {'target_dim': (n_features, )}, name='r2')
    fc = fr.Node([r2.out0], la.rev_multiplicative_layer, {
        'F_class': fu.F_small_connected,
        'F_args': {'internal_size': 5000},
        'clamp': 1
    }, name='fc')
    r3 = fr.Node([fc.out0], re.reshape_layer, {'target_dim': haar_dims},
                 name='r3')
    r4 = fr.Node([r3.out0], re.haar_restore_layer, {}, name='r4')
    outp = fr.OutputNode([r4.out0], name='output')

    # Order matters: positional 0, 1 select nodes[0] (input) and nodes[1] (output).
    nodes = [inp, outp, conv11, conv12, conv13, fc, r1, r2, r3, r4]
    coder = fr.ReversibleGraphNet(nodes, 0, 1)
    return coder
def strided_constr(cin, cout):
    """Subnet for the strided branch: 3x3 stride-2 conv -> ReLU -> 1x1 conv."""
    layers = [
        nn.Conv2d(cin, 16, 3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, cout, 1, stride=1)]
    return nn.Sequential(*layers)


def low_res_constr(cin, cout):
    """Subnet for the low-resolution branch: a single 1x1 convolution."""
    layers = [nn.Conv2d(cin, cout, 1)]
    return nn.Sequential(*layers)


# Smoke test: push a random batch through a coupling block followed by a
# downsampling coupling block, invert it, and report shapes.
inp = Ff.InputNode(3, 32, 32, name='in')
node = Ff.Node(inp, HighPerfCouplingBlock,
               {'subnet_constructor': low_res_constr}, name='coupling')
node2 = Ff.Node(node, DownsampleCouplingBlock,
                {'subnet_constructor_strided': strided_constr,
                 'subnet_constructor_low_res': low_res_constr},
                name='down_coupling')
out = Ff.OutputNode(node2, name='out')
net = Ff.ReversibleGraphNet([inp, node, node2, out])

x = torch.randn(4, 3, 32, 32)
z = net(x)
jac = net.log_jacobian(run_forward=False)
x_inv = net(z, rev=True)
diff = x - x_inv

print('shape in')
print(x.shape)
print('shape out')
print(z.shape)
print('shape inv')
print(x_inv.shape)
# The round-trip difference used to be computed and then ignored; report it
# so a broken inverse is visible (it should be close to zero).
print('max abs reconstruction error')
print(diff.abs().max().item())
# Append n_glow (random permutation, GLOW coupling) pairs to the `nodes`
# list built earlier in this script (not visible in this excerpt), close the
# graph with an output node, and set up the training bookkeeping.
n_glow = 8
for i in range(n_glow):
    nodes.append(
        Ff.Node(nodes[-1], Fm.PermuteRandom, {"seed": i},
                name=f"permute_{i}"))
    nodes.append(
        Ff.Node(
            nodes[-1],
            Fm.GLOWCouplingBlock,
            {
                # CreateFC is defined elsewhere; presumably a fully connected
                # subnet factory with 128 hidden units — TODO confirm.
                "subnet_constructor": CreateFC(128),
                "clamp": 2
            },
            name=f"glow_{i}",
        ))
nodes.append(Ff.OutputNode(nodes[-1], name="output"))
net = Ff.ReversibleGraphNet(nodes, verbose=False)
net = net.to(device=device)
print("Training")
#
# Training
#
losses = []
val_losses = []
epochs = args.iterations
n_batch = args.batch_size
writer = SummaryWriter()
n = training_data_npy.shape[0]
# Presumably the validation-set size (10% of the samples) — confirm with the
# code that consumes n_val.
n_val = n // 10
def cifar_inn_glow(mask_size=(32, 32), batch_norm=False):
    """Return a GLOW-style INN autoencoder for 3-channel images.

    Layout: Haar multiplex -> three GLOW conv couplings -> flatten -> fully
    connected GLOW coupling -> reshape -> Haar restore.

    :param mask_size: (height, width) of the input. Default: size of CIFAR
        images.
    :param batch_norm: forwarded to every ``F_conv`` subnet.
    :return: autoencoder (``fr.ReversibleGraphNet``) whose input is ``inp``
        (index 0) and whose output is ``outp`` (index 1).
    """
    img_dims = [3, mask_size[0], mask_size[1]]
    # Previously hard-coded as (3072,) and (12, 16, 16), which only matched
    # 32x32 inputs.
    n_features = img_dims[0] * img_dims[1] * img_dims[2]
    haar_dims = (4 * img_dims[0], img_dims[1] // 2, img_dims[2] // 2)

    inp = fr.InputNode(*img_dims, name='input')
    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')

    convs = []
    prev = r1
    for conv_name in ('conv11', 'conv12', 'conv13'):
        prev = fr.Node([prev.out0], la.glow_coupling_layer, {
            'F_class': fu.F_conv,
            'F_args': {'channels_hidden': 128, 'batch_norm': batch_norm},
            'clamp': 1
        }, name=conv_name)
        convs.append(prev)
    conv11, conv12, conv13 = convs

    r2 = fr.Node([conv13.out0], re.reshape_layer,
                 {'target_dim': (n_features, )}, name='r2')
    fc = fr.Node([r2.out0], la.glow_coupling_layer, {
        'F_class': fu.F_fully_connected,
        'clamp': 1
    }, name='fc')
    r3 = fr.Node([fc.out0], re.reshape_layer, {'target_dim': haar_dims},
                 name='r3')
    r4 = fr.Node([r3.out0], re.haar_restore_layer, {}, name='r4')
    outp = fr.OutputNode([r4.out0], name='output')

    # Order matters: positional 0, 1 select nodes[0] (input) and nodes[1] (output).
    nodes = [inp, outp, conv11, conv12, conv13, r1, r2, r3, r4, fc]
    coder = fr.ReversibleGraphNet(nodes, 0, 1)
    return coder
def build_inn(self):
    """Build a multi-scale INN on 3x64x64 images, conditioned on a 40-dim vector.

    Stages:
      0. two (permute, GLOW conv) pairs at full resolution;
      1. Haar downsampling, four (permute, GLOW conv) pairs, then ``Split1D``
         into (8, 4) channels; the 4-channel part is flattened and set aside;
      2. Haar downsampling, eight (permute, GLOW conv) pairs, then ``Split1D``
         into (16, 16) channels; the second half is flattened and set aside;
      3. flatten, then two (permute, GLOW fully connected) pairs that are the
         only blocks conditioned on ``cond_node``.
    The set-aside parts and the final vector are concatenated into one output.

    :return: the assembled ``Ff.ReversibleGraphNet`` (built with verbose=True).
    """
    nodes = [Ff.InputNode(3, 64, 64, name='inp')]
    cond_node = Ff.ConditionNode(40, name='attr_cond')
    split_nodes = []
    # ndim_x = 3 * 64 * 64
    # Stage 0: full resolution, 3x3 convolutional subnets.
    # NOTE(review): permutation seeds restart at 0 in every stage; confirm
    # that reusing seeds across stages is intended.
    for i in range(2):
        nodes.append(
            Ff.Node([nodes[-1].out0],
                    permute_layer, {'seed': i},
                    name=F'permute_0_{i}'))
        nodes.append(
            Ff.Node(
                [nodes[-1].out0],
                glow_coupling_layer, {
                    'clamp': 1.,
                    'F_class': F_fully_conv,
                    'F_args': {
                        'kernel_size': 3,
                        'channels_hidden': 256,
                        'leaky_slope': 0.2
                    }
                },
                name=F'conv_0_{i}'))
    nodes.append(
        Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance': 0.5},
                name='haar_1'))
    # Stage 1: kernel size alternates between 1 (even i) and 3 (odd i).
    for i in range(4):
        kernel_size = 1 if i % 2 == 0 else 3
        nodes.append(
            Ff.Node([nodes[-1].out0],
                    permute_layer, {'seed': i},
                    name=F'permute_1_{i}'))
        nodes.append(
            Ff.Node(
                [nodes[-1].out0],
                glow_coupling_layer, {
                    'clamp': 1.,
                    'F_class': F_fully_conv_shallow,
                    'F_args': {
                        'kernel_size': kernel_size,
                        'channels_hidden': 64,
                        'leaky_slope': 0.2
                    }
                },
                name=F'conv_1_{i}'))
    # Split the 12 channels (3 input channels after one Haar step) into 8
    # kept + 4 set aside; the set-aside part is flattened for the final concat.
    nodes.append(
        Ff.Node(nodes[-1], Fm.Split1D, {
            'split_size_or_sections': (8, 4),
            'dim': 0
        },
                name='split_1'))
    split_nodes.append(
        Ff.Node(nodes[-1].out1, Fm.Flatten, {}, name='flatten_1'))
    nodes.append(
        Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance': 0.5},
                name='haar_2'))
    # Stage 2: same alternating kernel-size pattern as stage 1.
    for i in range(8):
        kernel_size = 1 if i % 2 == 0 else 3
        nodes.append(
            Ff.Node(nodes[-1], permute_layer, {'seed': i},
                    name=F'permute_2_{i}'))
        nodes.append(
            Ff.Node(nodes[-1], glow_coupling_layer, {
                'clamp': 1.,
                'F_class': F_fully_conv_shallow,
                'F_args': {
                    'kernel_size': kernel_size,
                    'channels_hidden': 64,
                    'leaky_slope': 0.2
                }
            },
                    name=F'conv_2_{i}'))
    # Split the 32 channels (8 kept channels after a second Haar step) in half.
    nodes.append(
        Ff.Node(nodes[-1], Fm.Split1D, {
            'split_size_or_sections': (16, 16),
            'dim': 0
        },
                name='split_2'))
    split_nodes.append(
        Ff.Node(nodes[-1].out1, Fm.Flatten, {}, name='flatten_2'))
    nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten_3'))
    # Stage 3: fully connected couplings, the only blocks using the condition.
    for i in range(2):
        nodes.append(
            Ff.Node(nodes[-1], permute_layer, {'seed': i},
                    name=F'permute_3_{i}'))
        nodes.append(
            Ff.Node(nodes[-1], glow_coupling_layer, {
                'clamp': 0.8,
                'F_class': F_fully_shallow,
                'F_args': {
                    'internal_size': 256,
                    'dropout': 0.2
                }
            },
                    conditions=[cond_node],
                    name=F'fc_{i}'))
    # Concatenate the set-aside split parts with the final vector.
    nodes.append(
        Ff.Node([s.out0 for s in split_nodes] + [nodes[-1].out0],
                Fm.Concat1d, {'dim': 0},
                name='concat'))
    nodes.append(Ff.OutputNode(nodes[-1], name='out'))
    conv_inn = Ff.ReversibleGraphNet(nodes + split_nodes + [cond_node],
                                     verbose=True)
    return conv_inn
def inn_model(img_dims=(1, 28, 28)):
    """Return INN model for MNIST.

    Layout: reshape to (4*C, H/2, W/2) -> three GLOW conv couplings ->
    flatten -> fully connected multiplicative coupling -> reshape back to
    ``img_dims``.

    :param img_dims: (channels, height, width) of the model input images.
        Default: size of MNIST images.
    :return: INN model (``fr.ReversibleGraphNet``) whose input is ``inp``
        (index 0) and whose output is ``outp`` (index 1).
    """
    n_features = img_dims[0] * img_dims[1] * img_dims[2]

    inp = fr.InputNode(*img_dims, name='input')
    # NOTE(review): this uses reshape_layer rather than a Haar/squeeze layer,
    # so pixels are regrouped by memory layout, not by 2x2 spatial blocks —
    # confirm that this is intended.
    r1 = fr.Node([inp.out0], re.reshape_layer, {
        'target_dim': (img_dims[0] * 4, img_dims[1] // 2, img_dims[2] // 2)
    }, name='r1')

    convs = []
    prev = r1
    for conv_name in ('conv1', 'conv2', 'conv3'):
        prev = fr.Node([prev.out0], la.glow_coupling_layer, {
            'F_class': fu.F_conv,
            'F_args': {'channels_hidden': 128},
            'clamp': 1
        }, name=conv_name)
        convs.append(prev)
    conv1, conv2, conv3 = convs

    r2 = fr.Node([conv3.out0], re.reshape_layer,
                 {'target_dim': (n_features, )}, name='r2')
    fc = fr.Node([r2.out0], la.rev_multiplicative_layer, {
        'F_class': fu.F_fully_connected,
        'clamp': 1
    }, name='fc')
    r3 = fr.Node([fc.out0], re.reshape_layer,
                 {'target_dim': (img_dims[0], img_dims[1], img_dims[2])},
                 name='r3')
    outp = fr.OutputNode([r3.out0], name='output')

    # Order matters: positional 0, 1 select nodes[0] (input) and nodes[1] (output).
    nodes = [inp, outp, conv1, conv2, conv3, r1, r2, r3, fc]
    model = fr.ReversibleGraphNet(nodes, 0, 1)
    return model
def celeba_inn_glow(mask_size=(156, 128), batch_norm=False):
    """Return a two-scale GLOW-style INN autoencoder for CelebA images.

    Layout: Haar multiplex -> three GLOW conv couplings -> Haar multiplex ->
    three GLOW conv couplings -> flatten -> fully connected multiplicative
    coupling (``F_small_connected`` with internal size 500) -> reshape ->
    two Haar restores.

    :param mask_size: (height, width) of the 3-channel input. Default: the
        156x128 CelebA crop used here.
    :param batch_norm: forwarded to every ``F_conv`` subnet.
    :return: autoencoder (``fr.ReversibleGraphNet``) whose input is ``inp``
        (index 0) and whose output is ``outp`` (index 1).
    """
    img_dims = [3, mask_size[0], mask_size[1]]
    n_features = img_dims[0] * img_dims[1] * img_dims[2]
    # Shape after two Haar multiplex steps: 16x channels, quarter resolution.
    # Previously hard-coded as (48, 39, 32), which only matched 156x128.
    haar2_dims = (16 * img_dims[0], img_dims[1] // 4, img_dims[2] // 4)

    def glow_stack(first, prefix):
        """Return three chained GLOW conv couplings named <prefix>1..3."""
        stack = []
        prev = first
        for j in range(1, 4):
            prev = fr.Node([prev.out0], la.glow_coupling_layer, {
                'F_class': fu.F_conv,
                'F_args': {'channels_hidden': 128, 'batch_norm': batch_norm},
                'clamp': 1
            }, name=f'{prefix}{j}')
            stack.append(prev)
        return stack

    inp = fr.InputNode(*img_dims, name='input')
    r1 = fr.Node([inp.out0], re.haar_multiplex_layer, {}, name='r1')
    conv11, conv12, conv13 = glow_stack(r1, 'conv1')
    r2 = fr.Node([conv13.out0], re.haar_multiplex_layer, {}, name='r2')
    conv21, conv22, conv23 = glow_stack(r2, 'conv2')
    r3 = fr.Node([conv23.out0], re.reshape_layer,
                 {'target_dim': (n_features, )}, name='r3')
    fc = fr.Node([r3.out0], la.rev_multiplicative_layer, {
        'F_class': fu.F_small_connected,
        'F_args': {'internal_size': 500},
        'clamp': 1
    }, name='fc')
    r4 = fr.Node([fc.out0], re.reshape_layer, {'target_dim': haar2_dims},
                 name='r4')
    r5 = fr.Node([r4.out0], re.haar_restore_layer, {}, name='r5')
    r6 = fr.Node([r5.out0], re.haar_restore_layer, {}, name='r6')
    outp = fr.OutputNode([r6.out0], name='output')

    # Order matters: positional 0, 1 select nodes[0] (input) and nodes[1] (output).
    nodes = [
        inp, outp, conv11, conv12, conv13, conv21, conv22, conv23, r1, r2, r3,
        r4, r5, r6, fc
    ]
    coder = fr.ReversibleGraphNet(nodes, 0, 1)
    return coder
def build_inn(self):
    """Build a multi-scale conditional INN on 2x64x64 inputs.

    Each resolution level is conditioned on its own ``Ff.ConditionNode``
    (64x64x64, 128x32x32, 128x16x16 and a 512-dim vector). Channels are split
    off after levels 1 and 2, flattened, and concatenated with the final
    fully connected output.

    :return: ``Ff.ReversibleGraphNet`` over main nodes, split nodes and the
        condition nodes.
    """

    # Conv subnet factory: conv(kernel) -> ReLU -> conv(kernel); the padding
    # kernel // 2 keeps the spatial size for odd kernels.
    def sub_conv(ch_hidden, kernel):
        pad = kernel // 2
        return lambda ch_in, ch_out: nn.Sequential(
                                        nn.Conv2d(ch_in, ch_hidden, kernel, padding=pad),
                                        nn.ReLU(),
                                        nn.Conv2d(ch_hidden, ch_out, kernel, padding=pad))

    # Fully connected subnet factory: Linear -> ReLU -> Linear.
    def sub_fc(ch_hidden):
        return lambda ch_in, ch_out: nn.Sequential(
                                        nn.Linear(ch_in, ch_hidden),
                                        nn.ReLU(),
                                        nn.Linear(ch_hidden, ch_out))

    nodes = [Ff.InputNode(2, 64, 64)]
    # outputs of the cond. net at different resolution levels
    conditions = [Ff.ConditionNode(64, 64, 64),
                  Ff.ConditionNode(128, 32, 32),
                  Ff.ConditionNode(128, 16, 16),
                  Ff.ConditionNode(512)]

    split_nodes = []

    # Level 0: full resolution, two couplings without permutations.
    subnet = sub_conv(32, 3)
    for k in range(2):
        nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                             {'subnet_constructor':subnet, 'clamp':1.0},
                             conditions=conditions[0]))

    nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance':0.5}))

    # Level 1: kernel size alternates 1 (even k) / 3 (odd k).
    for k in range(4):
        subnet = sub_conv(64, 3 if k%2 else 1)

        nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                             {'subnet_constructor':subnet, 'clamp':1.0},
                             conditions=conditions[1]))
        nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed':k}))

    #split off 6/8 ch
    nodes.append(Ff.Node(nodes[-1], Fm.Split1D,
                         {'split_size_or_sections':[2,6], 'dim':0}))
    split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))

    nodes.append(Ff.Node(nodes[-1], Fm.HaarDownsampling, {'rebalance':0.5}))

    # Level 2: same alternating kernel pattern, tighter clamp.
    for k in range(4):
        subnet = sub_conv(128, 3 if k%2 else 1)

        nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                             {'subnet_constructor':subnet, 'clamp':0.6},
                             conditions=conditions[2]))
        nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed':k}))

    #split off 4/8 ch
    nodes.append(Ff.Node(nodes[-1], Fm.Split1D,
                         {'split_size_or_sections':[4,4], 'dim':0}))
    split_nodes.append(Ff.Node(nodes[-1].out1, Fm.Flatten, {}))
    nodes.append(Ff.Node(nodes[-1], Fm.Flatten, {}, name='flatten'))

    # fully_connected part
    subnet = sub_fc(512)
    for k in range(4):
        nodes.append(Ff.Node(nodes[-1], Fm.GLOWCouplingBlock,
                             {'subnet_constructor':subnet, 'clamp':0.6},
                             conditions=conditions[3]))
        nodes.append(Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed':k}))

    # concat everything
    nodes.append(Ff.Node([s.out0 for s in split_nodes] + [nodes[-1].out0],
                         Fm.Concat1d, {'dim':0}))
    nodes.append(Ff.OutputNode(nodes[-1]))

    return Ff.ReversibleGraphNet(nodes + split_nodes + conditions, verbose=False)
def get_flow(
    n_in: int,
    n_out: int,
    init_identity: bool = False,
    coupling_block: Union[Literal["gin", "glow"]] = "gin",
    num_nodes: int = 8,
    node_size_factor: int = 1,
):
    """Create a flow-based (invertible) network.

    Args:
        n_in: Dimensionality of the input data.
        n_out: Dimensionality of the output data. Must equal ``n_in``
            because the flow is invertible.
        init_identity: Initialize weights to an identity network by zeroing
            the last layer of every coupling subnet.
        coupling_block: Coupling method to use to combine nodes ("gin" or
            "glow").
        num_nodes: Depth of the flow network (number of coupling blocks).
        node_size_factor: Multiplier for the hidden units per node; each
            subnet's hidden width is ``c_in * node_size_factor``.

    Returns:
        A ``FrEIA.framework.ReversibleGraphNet``.

    Raises:
        AssertionError: if ``n_in != n_out`` or ``coupling_block`` is neither
            "gin" nor "glow".
    """
    # do lazy imports here such that the package is only
    # required if one wants to use the flow mixing
    import FrEIA.framework as Ff
    import FrEIA.modules as Fm

    def _invertible_subnet_fc(c_in, c_out, init_identity):
        # Bug fix: this previously referenced an undefined name ``node_size``,
        # raising NameError when FrEIA built the first coupling block.
        hidden = c_in * node_size_factor
        subnet = nn.Sequential(
            nn.Linear(c_in, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, c_out),
        )
        if init_identity:
            subnet[-1].weight.data.fill_(0.0)
            subnet[-1].bias.data.fill_(0.0)
        return subnet

    assert n_in == n_out

    if coupling_block == "gin":
        block = Fm.GINCouplingBlock
    else:
        assert coupling_block == "glow"
        block = Fm.GLOWCouplingBlock

    nodes = [Ff.InputNode(n_in, name="input")]

    for k in range(num_nodes):
        nodes.append(
            Ff.Node(
                nodes[-1],
                block,
                {
                    "subnet_constructor": lambda c_in, c_out: _invertible_subnet_fc(
                        c_in, c_out, init_identity),
                    "clamp": 2.0,
                },
                name=f"coupling_{k}",
            ))

    nodes.append(Ff.OutputNode(nodes[-1], name="output"))
    return Ff.ReversibleGraphNet(nodes, verbose=False)
def artset_inn_model(img_dims=(3, 224, 224)):
    """Return INN model for the Painter by Numbers art set.

    Layout: reshape to (4*C, H/2, W/2) -> four stages of three GLOW conv
    couplings (hidden widths 256, 128, 64, 32) -> flatten -> fully connected
    multiplicative coupling (``F_small_connected`` with internal size 128) ->
    reshape back to ``img_dims``.

    :param img_dims: (channels, height, width) of the model input images.
        Default: 3x224x224.
    :return: INN model (``fr.ReversibleGraphNet``) whose input is ``inp``
        (index 0) and whose output is ``outp`` (index 1).
    """
    n_features = img_dims[0] * img_dims[1] * img_dims[2]

    inp = fr.InputNode(*img_dims, name='input')
    # NOTE(review): this uses reshape_layer rather than a Haar/squeeze layer,
    # so pixels are regrouped by memory layout, not by 2x2 spatial blocks —
    # confirm that this is intended.
    r1 = fr.Node([inp.out0], re.reshape_layer, {
        'target_dim': (img_dims[0] * 4, img_dims[1] // 2, img_dims[2] // 2)
    }, name='r1')

    # Four stages of three couplings, named conv<stage><layer> (conv11..conv43).
    convs = []
    prev = r1
    for stage, hidden in enumerate((256, 128, 64, 32), start=1):
        for layer in range(1, 4):
            prev = fr.Node([prev.out0], la.glow_coupling_layer, {
                'F_class': fu.F_conv,
                'F_args': {'channels_hidden': hidden},
                'clamp': 1
            }, name=f'conv{stage}{layer}')
            convs.append(prev)

    r2 = fr.Node([prev.out0], re.reshape_layer,
                 {'target_dim': (n_features, )}, name='r2')
    fc = fr.Node([r2.out0], la.rev_multiplicative_layer, {
        'F_class': fu.F_small_connected,
        'F_args': {'internal_size': 128},
        'clamp': 1
    }, name='fc')
    r3 = fr.Node([fc.out0], re.reshape_layer,
                 {'target_dim': (img_dims[0], img_dims[1], img_dims[2])},
                 name='r3')
    outp = fr.OutputNode([r3.out0], name='output')

    # Order matters: positional 0, 1 select nodes[0] (input) and nodes[1] (output).
    nodes = [inp, outp, *convs, fc, r1, r2, r3]
    coder = fr.ReversibleGraphNet(nodes, 0, 1)
    return coder
def constuct_inn(classifier, verbose=False):
    """Build the invertible classifier backbone from ``classifier.args['model']``.

    Layout: per resolution stage, a chain of ``AIO_Block`` couplings with
    residual conv subnets, and between stages a ``DownsampleCouplingBlock``.
    Then DCT pooling or flatten, then ``n_coupling_blocks_fc`` (random
    permutation, GLOW fully connected coupling) pairs. For ``'MNIST'`` the
    first stage is replaced by a reshape plus Haar downsampling.

    :param classifier: object providing ``args['model']`` (string-valued
        hyper-parameters), ``dims``, ``input_channels`` and ``dataset``.
    :param verbose: forwarded to ``Ff.ReversibleGraphNet``.
    :return: the assembled ``Ff.ReversibleGraphNet``.
    """

    fc_width = int(classifier.args['model']['fc_width'])
    n_coupling_blocks_fc = int(
        classifier.args['model']['n_coupling_blocks_fc'])

    # SECURITY NOTE(review): eval() executes arbitrary text from the config.
    # If the config can come from an untrusted source, use ast.literal_eval
    # (works for plain list/bool literals, not expressions like "[32] * 3").
    use_dct = eval(classifier.args['model']['dct_pooling'])
    conv_widths = eval(classifier.args['model']['conv_widths'])
    n_coupling_blocks_conv = eval(
        classifier.args['model']['n_coupling_blocks_conv'])
    dropouts = eval(classifier.args['model']['dropout_conv'])
    dropouts_fc = float(classifier.args['model']['dropout_fc'])

    groups = int(classifier.args['model']['n_groups'])
    clamp = float(classifier.args['model']['clamp'])

    ndim_input = classifier.dims

    batchnorm_args = {
        'track_running_stats': True,
        'momentum': 0.999,
        'eps': 1e-4,
    }

    # Shared AIO_Block arguments; subnet_constructor is filled in per node.
    coupling_args = {
        'subnet_constructor': None,
        'clamp': clamp,
        'act_norm': float(classifier.args['model']['act_norm']),
        'gin_block': False,
        'permute_soft': True,
    }

    # Kaiming init for conv/linear weights (linear weights additionally
    # scaled by 0.1); BatchNorm2d reset to weight 1, bias 0.
    def weights_init(m):
        if type(m) == nn.Conv2d:
            torch.nn.init.kaiming_normal_(m.weight)
        if type(m) == nn.BatchNorm2d:
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        if type(m) == nn.Linear:
            torch.nn.init.kaiming_normal_(m.weight)
            m.weight.data *= 0.1

    # Same-resolution residual subnet: [ReLU] -> conv3x3 -> BN -> ReLU ->
    # grouped conv3x3 -> BN -> ReLU -> Dropout2d -> conv1x1.
    def basic_residual_block(width, groups, dropout, relu_first, cin, cout):
        width = width * groups
        layers = []
        if relu_first:
            layers = [nn.ReLU()]
        else:
            layers = []

        layers.extend([
            nn.Conv2d(cin, width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(width, **batchnorm_args),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False, groups=groups),
            nn.BatchNorm2d(width, **batchnorm_args),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(width, cout, 1, padding=0)
        ])

        layers = nn.Sequential(*layers)
        layers.apply(weights_init)

        return layers

    # Like basic_residual_block, but the grouped 3x3 conv has stride 2, and
    # there is no leading ReLU and no dropout.
    def strided_residual_block(width, groups, cin, cout):
        width = width * groups
        layers = nn.Sequential(
            nn.Conv2d(cin, width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(width, **batchnorm_args),
            nn.ReLU(),
            nn.Conv2d(width, width, kernel_size=3, stride=2, padding=1, bias=False, groups=groups),
            nn.BatchNorm2d(width, **batchnorm_args),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, cout, 1, padding=0))

        layers.apply(weights_init)
        return layers

    # Fully connected subnet: Linear -> ReLU -> Dropout -> Linear.
    def fc_constr(c_in, c_out):
        net = [
            nn.Linear(c_in, fc_width),
            nn.ReLU(),
            nn.Dropout(p=dropouts_fc),
            nn.Linear(fc_width, c_out)
        ]

        net = nn.Sequential(*net)
        net.apply(weights_init)
        return net

    nodes = [Ff.InputNode(*ndim_input, name='input')]
    # NOTE(review): `channels` is updated below but never read in this function.
    channels = classifier.input_channels

    if classifier.dataset == 'MNIST':
        nodes.append(
            Ff.Node(nodes[-1].out0, Fm.Reshape,
                    {'target_dim': (1, *classifier.dims)}))
        nodes.append(
            Ff.Node(nodes[-1].out0, Fm.HaarDownsampling, {'rebalance': 1.}))
        channels *= 4

    for i, (width, n_blocks) in enumerate(zip(conv_widths, n_coupling_blocks_conv)):
        # For MNIST the Haar step above stands in for the first stage.
        if classifier.dataset == 'MNIST' and i == 0:
            continue

        drop = dropouts[i]

        conv_constr = partial(basic_residual_block, width, groups, drop, True)
        conv_strided = partial(strided_residual_block, width * 2, groups)
        conv_lowres = partial(basic_residual_block, width * 2, groups, drop, False)

        # The very first coupling of the network omits the leading ReLU.
        if i == 0:
            conv_first = partial(basic_residual_block, width, groups, drop, False)
        else:
            conv_first = conv_constr

        nodes.append(
            Ff.Node(nodes[-1], AIO_Block,
                    dict(coupling_args, subnet_constructor=conv_first),
                    name=f'CONV_{i}_0'))
        for k in range(1, n_blocks):
            nodes.append(
                Ff.Node(nodes[-1], AIO_Block,
                        dict(coupling_args, subnet_constructor=conv_constr),
                        name=f'CONV_{i}_{k}'))

        # Downsample between stages (not after the last one).
        if i < len(conv_widths) - 1:
            nodes.append(
                Ff.Node(nodes[-1], DownsampleCouplingBlock, {
                    'subnet_constructor_low_res': conv_lowres,
                    'subnet_constructor_strided': conv_strided,
                    'clamp': clamp
                },
                        name=f'DOWN_{i}'))
            channels *= 4

    if use_dct:
        nodes.append(
            Ff.Node(nodes[-1].out0, DCTPooling2d, {'rebalance': 0.5}, name='DCT'))
    else:
        nodes.append(Ff.Node(nodes[-1].out0, Fm.Flatten, {}, name='Flatten'))

    for k in range(n_coupling_blocks_fc):
        nodes.append(
            Ff.Node(nodes[-1], Fm.PermuteRandom, {'seed': k}, name=f'PERM_FC_{k}'))
        nodes.append(
            Ff.Node(nodes[-1].out0, Fm.GLOWCouplingBlock, {
                'subnet_constructor': fc_constr,
                'clamp': 2.0
            },
                    name=f'FC_{k}'))

    nodes.append(Ff.OutputNode(nodes[-1], name='output'))
    return Ff.ReversibleGraphNet(nodes, verbose=verbose)