def test_shape(self):
    module = ElementwiseParams(3)
    y = module(self.x)
    expected_shape = (10, 2, 3)
    self.assertEqual(y.shape, expected_shape)

    module = ElementwiseParams(2)
    y = module(self.x)
    expected_shape = (10, 3, 2)
    self.assertEqual(y.shape, expected_shape)
def test_bijection_is_well_behaved(self):
    num_bins = 16
    num_mix = 8
    batch_size = 10
    elementwise_params = 3 * num_mix
    self.eps = 1e-6
    for shape in [(6,), (6, 4, 4)]:
        for num_condition in [None, 1]:
            with self.subTest(shape=shape, num_condition=num_condition):
                x = torch.rand(batch_size, *shape)
                if num_condition is None:
                    if len(shape) == 1:
                        net = nn.Sequential(nn.Linear(3, 3 * elementwise_params),
                                            ElementwiseParams(elementwise_params))
                    if len(shape) == 3:
                        net = nn.Sequential(nn.Conv2d(3, 3 * elementwise_params, kernel_size=3, padding=1),
                                            ElementwiseParams2d(elementwise_params))
                else:
                    if len(shape) == 1:
                        net = nn.Sequential(nn.Linear(1, 5 * elementwise_params),
                                            ElementwiseParams(elementwise_params))
                    if len(shape) == 3:
                        net = nn.Sequential(nn.Conv2d(1, 5 * elementwise_params, kernel_size=3, padding=1),
                                            ElementwiseParams2d(elementwise_params))
                bijection = CensoredLogisticMixtureCouplingBijection(net, num_mixtures=num_mix,
                                                                     num_bins=num_bins,
                                                                     num_condition=num_condition)
                self.assert_bijection_is_well_behaved(bijection, x, z_shape=(batch_size, *shape))
                z, _ = bijection.forward(x)
                if num_condition is None:
                    self.assertEqual(x[:, :3], z[:, :3])
                else:
                    self.assertEqual(x[:, :1], z[:, :1])
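# Note on the parameter count above: each of the num_mix logistic mixture
# components needs 3 elementwise parameters (an unnormalized mixture weight,
# a mean, and a log-scale), hence elementwise_params = 3 * num_mix. This is
# the standard logistic-mixture parameterization (as in Flow++); stated here
# as a reading aid, not taken from the source.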
def test_order(self):
    module = ElementwiseParams(2, mode='interleaved')
    y = module(self.x)
    self.assertEqual(y[:, 0], torch.stack([self.x[:, 0], self.x[:, 3]], dim=-1))
    self.assertEqual(y[:, 1], torch.stack([self.x[:, 1], self.x[:, 4]], dim=-1))
    self.assertEqual(y[:, 2], torch.stack([self.x[:, 2], self.x[:, 5]], dim=-1))

    module = ElementwiseParams(2, mode='sequential')
    y = module(self.x)
    self.assertEqual(y[:, 0], self.x[:, 0:2])
    self.assertEqual(y[:, 1], self.x[:, 2:4])
    self.assertEqual(y[:, 2], self.x[:, 4:6])
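# A minimal sketch of the two orderings exercised above, assuming
# ElementwiseParams is importable from survae.nn.layers (the import path is
# an assumption). With 6 features reshaped into 3 elements x 2 params:
#   interleaved: params for element d come from positions [d, d + 3]
#   sequential:  params for element d come from positions [2*d, 2*d + 1]
import torch
from survae.nn.layers import ElementwiseParams

x = torch.arange(6.).unsqueeze(0)                   # shape (1, 6): [[0, 1, 2, 3, 4, 5]]
print(ElementwiseParams(2, mode='interleaved')(x))  # [[[0, 3], [1, 4], [2, 5]]]
print(ElementwiseParams(2, mode='sequential')(x))   # [[[0, 1], [2, 3], [4, 5]]]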
def test_bijection_is_well_behaved(self):
    batch_size = 10
    self.eps = 5e-6
    for scale_str in ['exp', 'softplus', 'sigmoid', 'tanh_exp']:
        for shape in [(6,), (6, 8, 8)]:
            for num_condition in [None, 1]:
                with self.subTest(shape=shape, num_condition=num_condition, scale_str=scale_str):
                    x = torch.randn(batch_size, *shape)
                    context = torch.randn(batch_size, *shape)
                    if num_condition is None:
                        if len(shape) == 1:
                            net = nn.Sequential(nn.Linear(3 + 6, 3 * 2),
                                                ElementwiseParams(2))
                        if len(shape) == 3:
                            net = nn.Sequential(nn.Conv2d(3 + 6, 3 * 2, kernel_size=3, padding=1),
                                                ElementwiseParams2d(2))
                    else:
                        if len(shape) == 1:
                            net = nn.Sequential(nn.Linear(1 + 6, 5 * 2),
                                                ElementwiseParams(2))
                        if len(shape) == 3:
                            net = nn.Sequential(nn.Conv2d(1 + 6, 5 * 2, kernel_size=3, padding=1),
                                                ElementwiseParams2d(2))
                    bijection = ConditionalAffineCouplingBijection(net, num_condition=num_condition,
                                                                   scale_fn=scale_fn(scale_str))
                    self.assert_bijection_is_well_behaved(bijection, x, context, z_shape=(batch_size, *shape))
                    z, _ = bijection.forward(x, context)
                    if num_condition is None:
                        self.assertEqual(x[:, :3], z[:, :3])
                    else:
                        self.assertEqual(x[:, :1], z[:, :1])
def test_bijection_is_well_behaved(self):
    num_bins = 16
    batch_size = 10
    num_params = 2 * num_bins + 1
    self.eps = 5e-3
    for shape in [(6,), (6, 8, 8)]:
        for num_condition in [None, 1]:
            with self.subTest(shape=shape, num_condition=num_condition):
                x = torch.rand(batch_size, *shape)
                if num_condition is None:
                    if len(shape) == 1:
                        net = nn.Sequential(nn.Linear(3, 3 * num_params),
                                            ElementwiseParams(num_params))
                    if len(shape) == 3:
                        net = nn.Sequential(nn.Conv2d(3, 3 * num_params, kernel_size=3, padding=1),
                                            ElementwiseParams2d(num_params))
                else:
                    if len(shape) == 1:
                        net = nn.Sequential(nn.Linear(1, 5 * num_params),
                                            ElementwiseParams(num_params))
                    if len(shape) == 3:
                        net = nn.Sequential(nn.Conv2d(1, 5 * num_params, kernel_size=3, padding=1),
                                            ElementwiseParams2d(num_params))
                bijection = QuadraticSplineCouplingBijection(net, num_bins=num_bins,
                                                             num_condition=num_condition)
                self.assert_bijection_is_well_behaved(bijection, x, z_shape=(batch_size, *shape))
                z, _ = bijection.forward(x)
                if num_condition is None:
                    self.assertEqual(x[:, :3], z[:, :3])
                else:
                    self.assertEqual(x[:, :1], z[:, :1])
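# Note on num_params = 2 * num_bins + 1 above: a monotone quadratic spline is
# parameterized by num_bins unnormalized bin widths plus num_bins + 1
# unnormalized density values at the bin edges (as in Mueller et al., 2019,
# "Neural Importance Sampling"); stated here as a reading aid, not taken from
# the source.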
def __init__(self, num_flows, actnorm, affine, scale_fn_str, hidden_units,
             activation, range_flow, augment_size, base_dist, cond_size):
    D = 2                    # Number of data dimensions
    A = D + augment_size     # Number of augmented data dimensions
    P = 2 if affine else 1   # Number of elementwise parameters

    # Initialize context. Only upsample the context in ContextInit if the
    # latent shape doesn't change during the flow.
    context_init = MLP(input_size=cond_size, output_size=D,
                       hidden_units=hidden_units, activation=activation)

    # Initialize flow with either augmentation or an Abs surjection
    if augment_size > 0:
        assert augment_size % 2 == 0
        transforms = [Augment(StandardNormal((augment_size,)), x_size=D)]
    else:
        transforms = [SimpleAbsSurjection()]

    if range_flow == 'logit':
        transforms += [ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 4]])), Logit()]
    elif range_flow == 'softplus':
        transforms += [SoftplusInverse()]

    # Apply coupling-layer flows
    for _ in range(num_flows):
        net = nn.Sequential(MLP(A // 2 + D, P * A // 2,
                                hidden_units=hidden_units,
                                activation=activation),
                            ElementwiseParams(P))
        if affine:
            transforms.append(ConditionalAffineCouplingBijection(net, scale_fn=scale_fn(scale_fn_str)))
        else:
            transforms.append(ConditionalAdditiveCouplingBijection(net))
        if actnorm:
            transforms.append(ActNormBijection(D))
        transforms.append(Reverse(A))
    transforms.pop()  # Drop the trailing Reverse

    if base_dist == "uniform":
        base = StandardUniform((A,))
    else:
        base = StandardNormal((A,))
    super(SRFlow, self).__init__(base_dist=base, transforms=transforms, context_init=context_init)
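# A minimal usage sketch for the conditional flow above, assuming the
# ConditionalFlow parent exposes log_prob(x, context) (hyperparameter values
# are illustrative, not defaults from the source):
import torch

model = SRFlow(num_flows=4, actnorm=False, affine=True, scale_fn_str='exp',
               hidden_units=[50], activation='relu', range_flow='logit',
               augment_size=0, base_dist='normal', cond_size=2)
x = 8 * torch.rand(64, 2) - 4    # toy 2D data in (-4, 4); abs + scale maps it into (0, 1) for the logit
context = torch.randn(64, 2)     # conditioning input of size cond_size
loss = -model.log_prob(x, context).mean()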
def __init__(self, features, num_params, hidden_features, random_order=False,
             random_mask=False, random_seed=None, activation='relu',
             dropout_prob=0., batch_norm=False):
    layers = []

    # Build layers
    data_degrees = MaskedLinear.get_data_degrees(features, random_order=random_order,
                                                 random_seed=random_seed)
    in_degrees = copy.deepcopy(data_degrees)
    for i, out_features in enumerate(hidden_features):
        layers.append(MaskedLinear(in_degrees=in_degrees,
                                   out_features=out_features,
                                   data_features=features,
                                   random_mask=random_mask,
                                   random_seed=random_seed + i if random_seed else None,  # Change random seed to get different masks
                                   is_output=False))
        in_degrees = layers[-1].degrees
        if batch_norm:
            layers.append(nn.BatchNorm1d(out_features))
        layers.append(act_module(activation))
        if dropout_prob > 0.0:
            layers.append(nn.Dropout(dropout_prob))

    # Build output layer
    layers.append(MaskedLinear(in_degrees=in_degrees,
                               out_features=features * num_params,
                               data_features=features,
                               random_mask=random_mask,
                               random_seed=random_seed,
                               is_output=True,
                               data_degrees=data_degrees))
    layers.append(ElementwiseParams(num_params, mode='sequential'))

    super(MADE, self).__init__(*layers)
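# A quick usage sketch for the module above (constructor arguments are
# illustrative): MADE maps a batch of `features` inputs to `num_params`
# parameters per feature, with the masks ensuring that the parameters for
# feature d depend only on features preceding d in the (possibly random) order.
import torch

made = MADE(features=5, num_params=2, hidden_features=[64, 64])
params = made(torch.randn(10, 5))   # shape (10, 5, 2): e.g. a shift and a log-scale per feature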
###################
## Specify model ##
###################

assert args.augdim % 2 == 0

D = 2                         # Number of data dimensions
A = 2 + args.augdim           # Number of augmented data dimensions
P = 2 if args.affine else 1   # Number of elementwise parameters

transforms = [Augment(StandardNormal((args.augdim,)), x_size=D)]
for _ in range(args.num_flows):
    net = nn.Sequential(MLP(A // 2, P * A // 2,
                            hidden_units=args.hidden_units,
                            activation=args.activation),
                        ElementwiseParams(P))
    if args.affine:
        transforms.append(AffineCouplingBijection(net, scale_fn=scale_fn(args.scale_fn)))
    else:
        transforms.append(AdditiveCouplingBijection(net))
    if args.actnorm:
        transforms.append(ActNormBijection(D))
    transforms.append(Reverse(A))
transforms.pop()

model = Flow(base_dist=StandardNormal((A,)), transforms=transforms).to(args.device)

#######################
## Specify optimizer ##
#######################
D = args.num_dims             # Number of data dimensions
P = 2 if args.affine else 1   # Number of elementwise parameters
C = args.context_size         # Size of context
I = D // 2                    # Number of conditioning dimensions
O = D // 2 + D % 2            # Number of transformed dimensions

# Decoder
transforms = []
if args.num_bits is not None:
    transforms = [Logit()]
for _ in range(args.num_flows):
    net = nn.Sequential(MLP(C + I, P * O,
                            hidden_units=args.hidden_units,
                            activation=args.activation),
                        ElementwiseParams(P))
    context_net = nn.Sequential(LambdaLayer(lambda x: 2 * x.float() / (2**args.num_bits - 1) - 1),  # Rescale discrete context to [-1, 1]
                                MLP(D, C,
                                    hidden_units=args.hidden_units,
                                    activation=args.activation))
    if args.affine:
        transforms.append(ConditionalAffineCouplingBijection(coupling_net=net,
                                                             context_net=context_net,
                                                             scale_fn=scale_fn(args.scale_fn),
                                                             num_condition=I))
    else:
        transforms.append(ConditionalAdditiveCouplingBijection(coupling_net=net,
                                                               context_net=context_net,
                                                               num_condition=I))
    if args.actnorm:
        transforms.append(ActNormBijection(D))
    if args.permutation == 'reverse':
        transforms.append(Reverse(D))
    elif args.permutation == 'shuffle':
        transforms.append(Shuffle(D))
transforms.pop()
decoder = ConditionalFlow(base_dist=StandardNormal((D,)), transforms=transforms).to(args.device)

# Flow
transforms = []
for _ in range(args.num_flows):
def test_layer_is_well_behaved(self):
    module = ElementwiseParams(3)
    self.assert_layer_is_well_behaved(module, self.x)
def net():
    return nn.Sequential(nn.Linear(1, 200), nn.ReLU(),
                         nn.Linear(200, 100), nn.ReLU(),
                         nn.Linear(100, 2), ElementwiseParams(2))
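# A sketch of how a factory like net() plugs into a coupling layer, assuming
# survae's AffineCouplingBijection (the import path and the num_condition
# choice are assumptions): with a 2-dim input, dimension 0 is conditioned on
# (matching Linear(1, ...)) and dimension 1 receives a shift and a log-scale.
import torch
from survae.transforms import AffineCouplingBijection

bijection = AffineCouplingBijection(net(), num_condition=1)
x = torch.randn(32, 2)
z, ldj = bijection.forward(x)   # z: (32, 2), ldj: (32,)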
def __init__(self, num_flows, actnorm, affine, scale_fn_str, hidden_units,
             activation, range_flow, augment_size, base_dist):
    D = 2  # Number of data dimensions

    if base_dist == "uniform":
        classifier = MLP(D, D // 2,
                         hidden_units=hidden_units,
                         activation=activation,
                         out_lambda=lambda x: x.view(-1))
        transforms = [ElementAbsSurjection(classifier=classifier),
                      ShiftBijection(shift=torch.tensor([[0.0, 4.0]])),
                      ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 8]]))]
        base = StandardUniform((D,))
    else:
        A = D + augment_size     # Number of augmented data dimensions
        P = 2 if affine else 1   # Number of elementwise parameters

        # Initialize flow with either augmentation or an Abs surjection
        if augment_size > 0:
            assert augment_size % 2 == 0
            transforms = [Augment(StandardNormal((augment_size,)), x_size=D)]
        else:
            transforms = [SimpleAbsSurjection()]

        if range_flow == 'logit':
            transforms += [ScaleBijection(scale=torch.tensor([[1 / 4, 1 / 4]])), Logit()]
        elif range_flow == 'softplus':
            transforms += [SoftplusInverse()]

        # Apply coupling-layer flows
        for _ in range(num_flows):
            net = nn.Sequential(MLP(A // 2, P * A // 2,
                                    hidden_units=hidden_units,
                                    activation=activation),
                                ElementwiseParams(P))
            if affine:
                transforms.append(AffineCouplingBijection(net, scale_fn=scale_fn(scale_fn_str)))
            else:
                transforms.append(AdditiveCouplingBijection(net))
            if actnorm:
                transforms.append(ActNormBijection(D))
            transforms.append(Reverse(A))
        transforms.pop()  # Drop the trailing Reverse
        base = StandardNormal((A,))

    super(UnconditionalFlow, self).__init__(base_dist=base, transforms=transforms)
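# A minimal instantiation sketch (hyperparameter values are illustrative, not
# defaults from the source):
import torch

model = UnconditionalFlow(num_flows=4, actnorm=False, affine=True,
                          scale_fn_str='exp', hidden_units=[50],
                          activation='relu', range_flow='logit',
                          augment_size=0, base_dist='normal')
x = 8 * torch.rand(64, 2) - 4   # toy 2D data in (-4, 4); abs + scale maps it into (0, 1) for the logit
loss = -model.log_prob(x).mean()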