def __init__(self, in_channels, mid_channels, num_mixtures, num_blocks, dropout, checkerboard=False, flip=False):
    """Logistic-mixture coupling layer over 2d feature maps.

    With checkerboard masking the split is spatial (dim 3), so the
    transformer network sees every channel; otherwise the split is
    channel-wise (dim 1) and only half the channels feed the network.
    """
    # Spatial (checkerboard) vs. channel-wise split.
    num_in = in_channels if checkerboard else in_channels // 2
    split_dim = 3 if checkerboard else 1
    # Per element: 2 affine parameters plus 3 per mixture component.
    num_params = 2 + num_mixtures * 3
    net = nn.Sequential(
        TransformerNet(in_channels=num_in,
                       mid_channels=mid_channels,
                       num_blocks=num_blocks,
                       num_mixtures=num_mixtures,
                       dropout=dropout),
        ElementwiseParams2d(num_params, mode='sequential'))
    super(MixtureCoupling, self).__init__(coupling_net=net,
                                          num_mixtures=num_mixtures,
                                          scale_fn=scale_fn("tanh_exp"),
                                          split_dim=split_dim,
                                          flip=flip)
def __init__(self, x_size, y_size, mid_channels, num_blocks, num_mixtures, dropout, checkerboard=False, flip=False):
    """Super-resolution mixture coupling conditioned on a low-resolution input.

    y_size is the shape of the high-resolution variable being coupled;
    x_size is the shape of the conditioning (low-resolution) context,
    whose channels are fed to the transformer as context_channels.
    """
    if checkerboard:
        # Spatial split along width (dim 3): the net sees all channels.
        in_channels, split_dim = y_size[0], 3
        assert x_size[1] == y_size[1] and x_size[2] == y_size[2] // 2
    else:
        # Channel-wise split (dim 1): half the channels are transformed.
        in_channels, split_dim = y_size[0] // 2, 1
        assert x_size[1] == y_size[1] and x_size[2] == y_size[2]
    assert y_size[0] % 2 == 0, f"High-resolution has shape {y_size} with channels not evenly divisible"
    transformer = TransformerNet(in_channels=in_channels,
                                 context_channels=x_size[0],
                                 mid_channels=mid_channels,
                                 num_blocks=num_blocks,
                                 num_mixtures=num_mixtures,
                                 dropout=dropout)
    # Per element: 2 affine parameters plus 3 per mixture component.
    coupling_net = nn.Sequential(transformer,
                                 ElementwiseParams2d(2 + num_mixtures * 3, mode='sequential'))
    super(SRMixtureCoupling, self).__init__(coupling_net=coupling_net,
                                            num_mixtures=num_mixtures,
                                            scale_fn=scale_fn("tanh_exp"),
                                            split_dim=split_dim,
                                            flip=flip)
def __init__(self, in_channels, num_context, num_blocks, mid_channels, depth, growth, dropout, gated_conv, coupling_network):
    """Conditional affine coupling: half the channels plus the context
    are mapped to shift/scale parameters for the other half.

    coupling_network selects the parameter network: "densenet" or "conv".
    Raises ValueError for any other value.
    """
    assert in_channels % 2 == 0
    # The network input is the identity half concatenated with the context.
    num_in = in_channels // 2 + num_context
    if coupling_network == "densenet":
        backbone = DenseNet(in_channels=num_in,
                            out_channels=in_channels,
                            num_blocks=num_blocks,
                            mid_channels=mid_channels,
                            depth=depth,
                            growth=growth,
                            dropout=dropout,
                            gated_conv=gated_conv,
                            zero_init=True)
    elif coupling_network == "conv":
        backbone = ConvNet(in_channels=num_in,
                           out_channels=in_channels,
                           mid_channels=mid_channels,
                           num_layers=depth,
                           activation='relu')
    else:
        raise ValueError(f"Unknown coupling network {coupling_network}")
    # 2 elementwise parameters per output element (shift and log-scale).
    net = nn.Sequential(backbone, ElementwiseParams2d(2, mode='sequential'))
    super(ConditionalCoupling, self).__init__(coupling_net=net, scale_fn=scale_fn("tanh_exp"))
def __init__(self, x_size, y_size, coupling_network, mid_channels, depth, num_blocks=None, dropout=None, gated_conv=None, checkerboard=False, flip=False):
    """Super-resolution affine coupling conditioned on a low-resolution input.

    y_size is the high-resolution shape being coupled; x_size is the
    conditioning shape, concatenated channel-wise into the network input.
    coupling_network selects "densenet" or "conv"; anything else raises
    ValueError.
    """
    if checkerboard:
        # Spatial split (dim 3): net sees all y-channels plus the context.
        in_channels = y_size[0] + x_size[0]
        out_channels = y_size[0] * 2
        split_dim = 3
        assert x_size[1] == y_size[1] and x_size[2] == y_size[2] // 2
    else:
        # Channel split (dim 1): net sees half the y-channels plus context.
        in_channels = y_size[0] // 2 + x_size[0]
        out_channels = y_size[0]
        split_dim = 1
        assert x_size[1] == y_size[1] and x_size[2] == y_size[2]
    assert y_size[0] % 2 == 0, f"High-resolution has shape {y_size} with channels not evenly divisible"
    if coupling_network == "densenet":
        backbone = DenseNet(in_channels=in_channels,
                            out_channels=out_channels,
                            num_blocks=num_blocks,
                            mid_channels=mid_channels,
                            depth=depth,
                            growth=mid_channels,
                            dropout=dropout,
                            gated_conv=gated_conv,
                            zero_init=True)
    elif coupling_network == "conv":
        backbone = ConvNet(in_channels=in_channels,
                           out_channels=out_channels,
                           mid_channels=mid_channels,
                           num_layers=depth,
                           weight_norm=True,
                           activation='relu')
    else:
        raise ValueError(f"Unknown coupling network {coupling_network}")
    # 2 elementwise parameters per output element (shift and log-scale).
    coupling_net = nn.Sequential(backbone, ElementwiseParams2d(2, mode='sequential'))
    super(SRCoupling, self).__init__(coupling_net=coupling_net,
                                     scale_fn=scale_fn("tanh_exp"),
                                     split_dim=split_dim,
                                     flip=flip)
def __init__(self, num_flows, actnorm, affine, scale_fn_str, hidden_units, activation, range_flow, augment_size, base_dist, cond_size):
    """Conditional 2-D flow with optional augmentation and range transforms.

    Builds `num_flows` conditional coupling layers (affine or additive),
    optionally interleaved with ActNorm, separated by Reverse permutations.
    """
    D = 2  # Number of data dimensions
    A = D + augment_size  # Number of augmented data dimensions
    P = 2 if affine else 1  # Number of elementwise parameters

    # initialize context. Only upsample context in ContextInit if latent shape doesn't change during the flow.
    context_init = MLP(input_size=cond_size,
                       output_size=D,
                       hidden_units=hidden_units,
                       activation=activation)

    # initialize flow with either augmentation or Abs surjection.
    # BUG FIX: previously `transforms` was reset to [SimpleAbsSurjection()]
    # after the if/else, which discarded the Augment transform whenever
    # augment_size > 0. Mirror UnconditionalFlow: use the Abs surjection
    # only in the non-augmented case.
    if augment_size > 0:
        assert augment_size % 2 == 0
        transforms = [Augment(StandardNormal((augment_size, )), x_size=D)]
    else:
        transforms = [SimpleAbsSurjection()]

    # Map the data into the support expected by the base distribution.
    if range_flow == 'logit':
        transforms += [
            ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 4]])),
            Logit()
        ]
    elif range_flow == 'softplus':
        transforms += [SoftplusInverse()]

    # apply coupling layer flows
    for _ in range(num_flows):
        # Input: half the (augmented) variable plus the D-dim context.
        net = nn.Sequential(
            MLP(A // 2 + D,
                P * A // 2,
                hidden_units=hidden_units,
                activation=activation), ElementwiseParams(P))
        if affine:
            transforms.append(
                ConditionalAffineCouplingBijection(
                    net, scale_fn=scale_fn(scale_fn_str)))
        else:
            transforms.append(ConditionalAdditiveCouplingBijection(net))
        if actnorm:
            # NOTE(review): ActNormBijection(D) on A-dim data matches the
            # sibling flows in this file — confirm D (not A) is intended.
            transforms.append(ActNormBijection(D))
        transforms.append(Reverse(A))
    # Drop the trailing Reverse so the flow does not end on a permutation.
    transforms.pop()

    if base_dist == "uniform":
        base = StandardUniform((A, ))
    else:
        base = StandardNormal((A, ))
    super(SRFlow, self).__init__(base_dist=base,
                                 transforms=transforms,
                                 context_init=context_init)
def __init__(self, in_channels, num_context, num_blocks, mid_channels, depth, dropout, gated_conv, coupling_network, checkerboard=False, flip=False):
    """Conditional affine coupling with optional checkerboard (spatial) split.

    coupling_network selects the parameter network: "densenet" or "conv";
    any other value raises ValueError.
    """
    # Checkerboard splits spatially (dim 3), so the net sees every channel
    # and emits parameters for all of them; otherwise split channel-wise.
    num_in = in_channels + num_context if checkerboard else in_channels // 2 + num_context
    num_out = in_channels * 2 if checkerboard else in_channels
    split_dim = 3 if checkerboard else 1
    # Even channel count only matters for the channel-wise split.
    assert in_channels % 2 == 0 or split_dim != 1, f"in_channels = {in_channels} not evenly divisible"
    if coupling_network == "densenet":
        backbone = DenseNet(in_channels=num_in,
                            out_channels=num_out,
                            num_blocks=num_blocks,
                            mid_channels=mid_channels,
                            depth=depth,
                            growth=mid_channels,
                            dropout=dropout,
                            gated_conv=gated_conv,
                            zero_init=True)
    elif coupling_network == "conv":
        backbone = ConvNet(in_channels=num_in,
                           out_channels=num_out,
                           mid_channels=mid_channels,
                           num_layers=depth,
                           activation='relu')
    else:
        raise ValueError(f"Unknown coupling network {coupling_network}")
    # 2 elementwise parameters per output element (shift and log-scale).
    net = nn.Sequential(backbone, ElementwiseParams2d(2, mode='sequential'))
    super(ConditionalCoupling, self).__init__(coupling_net=net,
                                              scale_fn=scale_fn("tanh_exp"),
                                              split_dim=split_dim,
                                              flip=flip)
def __init__(self, in_channels, mid_channels, num_mixtures, num_blocks, dropout):
    """Logistic-mixture coupling over a channel-wise split.

    Half the channels feed the transformer, which emits, per element,
    2 affine parameters plus 3 parameters per mixture component.
    """
    num_params = 2 + num_mixtures * 3
    transformer = TransformerNet(in_channels // 2,
                                 mid_channels=mid_channels,
                                 num_blocks=num_blocks,
                                 num_mixtures=num_mixtures,
                                 dropout=dropout)
    net = nn.Sequential(transformer, ElementwiseParams2d(num_params, mode='sequential'))
    super(MixtureCoupling, self).__init__(coupling_net=net,
                                          num_mixtures=num_mixtures,
                                          scale_fn=scale_fn("tanh_exp"))
def test_bijection_is_well_behaved(self):
    """Check ConditionalAffineCouplingBijection over scale functions, shapes
    and split sizes: invertibility plus the identity half staying fixed."""
    batch_size = 10
    self.eps = 5e-6
    for scale_str in ['exp', 'softplus', 'sigmoid', 'tanh_exp']:
        for shape in [(6, ), (6, 8, 8)]:
            for num_condition in [None, 1]:
                with self.subTest(shape=shape,
                                  num_condition=num_condition,
                                  scale_str=scale_str):
                    x = torch.randn(batch_size, *shape)
                    context = torch.randn(batch_size, *shape)
                    # num_condition=None splits the 6 dims in half (3/3);
                    # otherwise the first num_condition dims are kept.
                    keep = 3 if num_condition is None else 1
                    trans = 6 - keep
                    # Net input: identity part (keep) + context (6 channels);
                    # output: 2 elementwise params per transformed dim.
                    if len(shape) == 1:
                        net = nn.Sequential(nn.Linear(keep + 6, trans * 2),
                                            ElementwiseParams(2))
                    else:
                        net = nn.Sequential(
                            nn.Conv2d(keep + 6, trans * 2, kernel_size=3, padding=1),
                            ElementwiseParams2d(2))
                    bijection = ConditionalAffineCouplingBijection(
                        net,
                        num_condition=num_condition,
                        scale_fn=scale_fn(scale_str))
                    self.assert_bijection_is_well_behaved(
                        bijection, x, context, z_shape=(batch_size, *shape))
                    z, _ = bijection.forward(x, context)
                    # The identity (conditioning) slice must pass through.
                    self.assertEqual(x[:, :keep], z[:, :keep])
def dimwise(transforms):
    """Append a dim-wise coupling layer (split along dim 1) and a Reverse
    permutation to `transforms`; returns the same list for chaining.

    Relies on module-level D, P and args.
    """
    backbone = DenseTransformer(d_input=D // 2,
                                d_output=P * D // 2,
                                d_model=args.d_model,
                                nhead=args.nhead,
                                num_layers=args.num_layers,
                                dim_feedforward=4 * args.d_model,
                                dropout=args.dropout,
                                activation=args.activation,
                                checkpoint_blocks=args.checkpoint_blocks)
    net = nn.Sequential(backbone, ElementwiseParams1d(P))
    if args.affine:
        coupling = AffineCouplingBijection(net, split_dim=1, scale_fn=scale_fn(args.scale_fn))
    else:
        coupling = AdditiveCouplingBijection(net, split_dim=1)
    transforms.append(coupling)
    transforms.append(Reverse(D, dim=1))
    return transforms
def __init__(self, in_channels, num_context, mid_channels, num_mixtures, num_blocks, dropout, use_attn=True):
    """Conditional logistic-mixture coupling over a channel-wise split.

    Half the channels plus num_context context channels feed the
    transformer, which emits 2 affine parameters plus 3 per mixture
    component for every element.
    """
    num_params = 2 + num_mixtures * 3
    transformer = TransformerNet(in_channels // 2,
                                 context_channels=num_context,
                                 mid_channels=mid_channels,
                                 num_blocks=num_blocks,
                                 num_mixtures=num_mixtures,
                                 use_attn=use_attn,
                                 dropout=dropout)
    coupling_net = nn.Sequential(transformer,
                                 ElementwiseParams2d(num_params, mode='sequential'))
    super(ConditionalMixtureCoupling, self).__init__(coupling_net=coupling_net,
                                                     num_mixtures=num_mixtures,
                                                     scale_fn=scale_fn("tanh_exp"))
def lenwise(transforms):
    """Append a length-wise coupling layer (split along dim 2) and a
    permutation to `transforms`; returns the same list for chaining.

    Relies on module-level D, P, L and args.
    """
    backbone = DenseTransformer(d_input=D,
                                d_output=P * D,
                                d_model=args.d_model,
                                nhead=args.nhead,
                                num_layers=args.num_layers,
                                dim_feedforward=4 * args.d_model,
                                dropout=args.dropout,
                                activation=args.activation,
                                checkpoint_blocks=args.checkpoint_blocks)
    net = nn.Sequential(backbone, ElementwiseParams1d(P))
    if args.affine:
        coupling = AffineCouplingBijection(net, split_dim=2, scale_fn=scale_fn(args.scale_fn))
    else:
        coupling = AdditiveCouplingBijection(net, split_dim=2)
    transforms.append(coupling)
    # Permute along the length dimension between couplings.
    perm = StochasticPermutation(dim=2) if args.stochperm else Shuffle(L, dim=2)
    transforms.append(perm)
    return transforms
# Split sizes: I dims condition, O dims are transformed (O >= I when D is odd).
I = D // 2
O = D // 2 + D % 2

# Decoder
# NOTE(review): block nesting reconstructed from collapsed source — the
# decoder appears to be built only when quantized data (num_bits) is used;
# confirm against the original file.
if args.num_bits is not None:
    # Map quantized data into an unbounded space before coupling layers.
    transforms = [Logit()]
    for _ in range(args.num_flows):
        # Coupling net: conditioning half (I) plus C context features in,
        # P elementwise parameters per transformed dim (O) out.
        net = nn.Sequential(MLP(C+I, P*O, hidden_units=args.hidden_units, activation=args.activation), ElementwiseParams(P))
        # Context net rescales num_bits-quantized input to [-1, 1] before the MLP.
        context_net = nn.Sequential(LambdaLayer(lambda x: 2*x.float()/(2**args.num_bits-1) - 1), MLP(D, C, hidden_units=args.hidden_units, activation=args.activation))
        if args.affine:
            transforms.append(ConditionalAffineCouplingBijection(coupling_net=net, context_net=context_net, scale_fn=scale_fn(args.scale_fn), num_condition=I))
        else:
            transforms.append(ConditionalAdditiveCouplingBijection(coupling_net=net, context_net=context_net, num_condition=I))
        if args.actnorm:
            transforms.append(ActNormBijection(D))
        if args.permutation == 'reverse':
            transforms.append(Reverse(D))
        elif args.permutation == 'shuffle':
            transforms.append(Shuffle(D))
    # Drop the trailing permutation so the flow does not end on one.
    transforms.pop()
    decoder = ConditionalFlow(base_dist=StandardNormal((D,)), transforms=transforms).to(args.device)

# Flow
transforms = []
for _ in range(args.num_flows):
    # Unconditional coupling: I conditioning dims in, P params per O dims out.
    net = nn.Sequential(MLP(I, P*O, hidden_units=args.hidden_units, activation=args.activation), ElementwiseParams(P))
    if args.affine:
        transforms.append(AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn), num_condition=I))
train_loader, test_loader = get_data(args) ################### ## Specify model ## ################### D = 2 # Number of data dimensions P = 2 if args.affine else 1 # Number of elementwise parameters transforms = [] for _ in range(args.num_flows): net = nn.Sequential(MLP(1, P, hidden_units=args.hidden_units, activation=args.activation), ElementwiseParams(P)) if args.affine: transforms.append(AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn))) else: transforms.append(AdditiveCouplingBijection(net)) if args.actnorm: transforms.append(ActNormBijection(D)) transforms.append(Reverse(D)) transforms.pop() model = Flow(base_dist=StandardNormal((2,)), transforms=transforms).to(args.device) ####################### ## Specify optimizer ## ####################### if args.optimizer == 'adam': optimizer = Adam(model.parameters(), lr=args.lr)
model = Flow(base_dist=ConvNormal2d((16,7,7)), transforms=[ UniformDequantization(num_bits=8), # Augment(StandardUniform((1, 28, 28)), x_size=1), # AffineCouplingBijection(net(2)), ActNormBijection2d(2), Conv1x1(2), # AffineCouplingBijection(net(2)), ActNormBijection2d(2), Conv1x1(2), # Squeeze2d(), Slice(StandardNormal((4, 14, 14)), num_keep=4), # AffineCouplingBijection(net(4)), ActNormBijection2d(4), Conv1x1(4), # AffineCouplingBijection(net(4)), ActNormBijection2d(4), Conv1x1(4), # Squeeze2d(), Slice(StandardNormal((8, 7, 7)), num_keep=8), # AffineCouplingBijection(net(8)), ActNormBijection2d(8), Conv1x1(8), # AffineCouplingBijection(net(8)), ActNormBijection2d(8), Conv1x1(8), #Logit(0.05), ScalarAffineBijection(shift=-0.5), Squeeze2d(), ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(4), Conv1x1(4), AffineCouplingBijection(net(4), scale_fn=scale_fn("tanh_exp")), Squeeze2d(), ActNormBijection2d(16), Conv1x1(16), AffineCouplingBijection(net(16), scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(16), Conv1x1(16), AffineCouplingBijection(net(16), scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(16), Conv1x1(16), AffineCouplingBijection(net(16), scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(16), Conv1x1(16), AffineCouplingBijection(net(16), scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(16), 
Conv1x1(16), AffineCouplingBijection(net(16), scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(16), Conv1x1(16), AffineCouplingBijection(net(16), scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(16), Conv1x1(16), AffineCouplingBijection(net(16), scale_fn=scale_fn("tanh_exp")),
dropout=0.2), ElementwiseParams2d(2 + k * 3)) #model = Flow(base_dist=StandardNormal((16,7,7)), model = Flow( base_dist=ConvNormal2d((16, 7, 7)), transforms=[ UniformDequantization(num_bits=8), #Logit(), ScalarAffineBijection(shift=-0.5), Squeeze2d(), ActNormBijection2d(4), Conv1x1(4), LogisticMixtureAffineCouplingBijection(net(4), num_mixtures=k, scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(4), Conv1x1(4), LogisticMixtureAffineCouplingBijection(net(4), num_mixtures=k, scale_fn=scale_fn("tanh_exp")), Squeeze2d(), ActNormBijection2d(16), Conv1x1(16), LogisticMixtureAffineCouplingBijection(net(16), num_mixtures=k, scale_fn=scale_fn("tanh_exp")), ActNormBijection2d(16), Conv1x1(16), LogisticMixtureAffineCouplingBijection(net(16), num_mixtures=k,
def __init__(self, num_flows, actnorm, affine, scale_fn_str, hidden_units, activation, range_flow, augment_size, base_dist): D = 2 # Number of data dimensions if base_dist == "uniform": classifier = MLP(D, D // 2, hidden_units=hidden_units, activation=activation, out_lambda=lambda x: x.view(-1)) transforms = [ ElementAbsSurjection(classifier=classifier), ShiftBijection(shift=torch.tensor([[0.0, 4.0]])), ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 8]])) ] base = StandardUniform((D, )) else: A = D + augment_size # Number of augmented data dimensions P = 2 if affine else 1 # Number of elementwise parameters # initialize flow with either augmentation or Abs surjection if augment_size > 0: assert augment_size % 2 == 0 transforms = [ Augment(StandardNormal((augment_size, )), x_size=D) ] else: transforms = [SimpleAbsSurjection()] if range_flow == 'logit': transforms += [ ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 4]])), Logit() ] elif range_flow == 'softplus': transforms += [SoftplusInverse()] # apply coupling layer flows for _ in range(num_flows): net = nn.Sequential( MLP(A // 2, P * A // 2, hidden_units=hidden_units, activation=activation), ElementwiseParams(P)) if affine: transforms.append( AffineCouplingBijection( net, scale_fn=scale_fn(scale_fn_str))) else: transforms.append(AdditiveCouplingBijection(net)) if actnorm: transforms.append(ActNormBijection(D)) transforms.append(Reverse(A)) transforms.pop() base = StandardNormal((A, )) super(UnconditionalFlow, self).__init__(base_dist=base, transforms=transforms)
# Augmentation dims must be even so the augmented vector splits in half.
assert args.augdim % 2 == 0
D = 2  # Number of data dimensions
A = 2 + args.augdim  # Number of augmented data dimensions
P = 2 if args.affine else 1  # Number of elementwise parameters

# Start by augmenting the data with noise drawn from a standard normal.
transforms = [Augment(StandardNormal((args.augdim, )), x_size=D)]
for _ in range(args.num_flows):
    # Coupling net: half the augmented dims in, P params per dim out.
    net = nn.Sequential(
        MLP(A // 2,
            P * A // 2,
            hidden_units=args.hidden_units,
            activation=args.activation), ElementwiseParams(P))
    if args.affine:
        transforms.append(
            AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn)))
    else:
        transforms.append(AdditiveCouplingBijection(net))
    if args.actnorm:
        # NOTE(review): ActNormBijection(D) on A-dim data — matches the
        # sibling scripts; confirm D (not A) is intended.
        transforms.append(ActNormBijection(D))
    # Reverse so both halves get transformed across layers.
    transforms.append(Reverse(A))
# Drop the trailing Reverse so the flow does not end on a permutation.
transforms.pop()
model = Flow(base_dist=StandardNormal((A, )),
             transforms=transforms).to(args.device)

#######################
## Specify optimizer ##
#######################

if args.optimizer == 'adam':
    optimizer = Adam(model.parameters(), lr=args.lr)