def create_transform(dim, num_flow_steps, num_bins):
    """Assemble the composite transform used by the "RQNSF_AG" flow.

    Each of the ``num_flow_steps`` steps is itself a composite of an "lu"
    linear transform followed by the base transform that
    ``create_base_transform`` returns for the "rqsf_ag" type. One extra
    "lu" linear transform is appended after the last step.

    Args:
        dim: feature dimension passed to every sub-transform.
        num_flow_steps: number of (linear, base) steps to stack.
        num_bins: forwarded unchanged to ``create_base_transform``.

    Returns:
        A ``transforms.CompositeTransform`` wrapping all steps in order.
    """
    stages = []
    for step_index in range(num_flow_steps):
        # Build the linear part first, then the base part: the "lu"
        # transform contains a RandomPermutation, so construction order
        # is kept stable with respect to the random number generator.
        linear_part = create_linear_transform("lu", dim)
        base_part = create_base_transform(step_index, "rqsf_ag", dim, num_bins)
        stages.append(transforms.CompositeTransform([linear_part, base_part]))

    # Closing linear transform after the final step.
    stages.append(create_linear_transform("lu", dim))
    return transforms.CompositeTransform(stages)
def create_linear_transform(linear_transform_type, dim):
    """Build a linear transform of the requested kind over ``dim`` features.

    Args:
        linear_transform_type: one of ``'permutation'``, ``'lu'`` or
            ``'svd'``.
        dim: number of features the transform operates on.

    Returns:
        * ``'permutation'``: a ``transforms.RandomPermutation``.
        * ``'lu'``: a composite of a random permutation followed by
          ``transforms.LULinear`` created with ``identity_init=True``.
        * ``'svd'``: a composite of a random permutation followed by
          ``transforms.SVDLinear`` created with ``num_householder=10`` and
          ``identity_init=True``.

    Raises:
        ValueError: if ``linear_transform_type`` is not one of the names
            listed above.
    """
    if linear_transform_type == 'permutation':
        return transforms.RandomPermutation(features=dim)

    if linear_transform_type == 'lu':
        return transforms.CompositeTransform([
            transforms.RandomPermutation(features=dim),
            transforms.LULinear(dim, identity_init=True),
        ])

    if linear_transform_type == 'svd':
        return transforms.CompositeTransform([
            transforms.RandomPermutation(features=dim),
            transforms.SVDLinear(dim, num_householder=10, identity_init=True),
        ])

    # Name the offending value so a misconfiguration is easy to diagnose.
    raise ValueError(
        "Unknown linear_transform_type {!r}; expected 'permutation', 'lu' "
        "or 'svd'".format(linear_transform_type)
    )
def __init__(self, args):
    """Build the flow, the free-energy bound and the ``_lambda`` network.

    Args:
        args: mapping of hyper-parameters. Keys read here are
            ``batch_size``, ``latent_dim``, ``data_dim``, ``aux_dim``,
            ``flow_type``, ``flow_length``, ``nat_param_act`` and, only
            for ``flow_type == "RQNSF_AG"``, ``num_bins``. The whole
            mapping is also stored on ``self.args`` and forwarded to
            ``FreeEnergyBound``.

    Raises:
        AssertionError: if ``latent_dim != data_dim``, or if
            ``nat_param_act`` is neither a known activation name nor a
            string starting with ``"Sigmoidx"``.
        ValueError: if ``flow_type`` is not recognised, or if ``aux_dim``
            has no configured ``_lambda`` network.
    """
    super(iFlow, self).__init__()
    self.args = args
    self.bs = args['batch_size']
    assert args['latent_dim'] == args['data_dim'], \
        "latent_dim and data_dim must be equal"
    self.x_dim = self.z_dim = args['latent_dim']
    self.u_dim = args['aux_dim']
    self.k = 2  # number of orders of sufficient statistics

    flow_type = args['flow_type']
    print(args)
    if flow_type == "PlanarFlow":
        self.nf = PlanarFlow(dim=self.x_dim, flow_length=args['flow_length'])
    elif flow_type == "RQNSF_C":
        # This variant always uses 64 bins; args['num_bins'] is not read.
        transform = transforms.CompositeTransform([
            create_base_transform(i, "rqsf_c", self.z_dim, 64)
            for i in range(args['flow_length'])
        ])
        self.nf = SplineFlow(transform)
    elif flow_type == "RQNSF_AG":
        transform = create_transform(
            self.z_dim, args['flow_length'], args['num_bins'])
        self.nf = SplineFlow(transform)
    else:
        raise ValueError(
            "Unknown flow_type {!r}; expected 'PlanarFlow', 'RQNSF_C' or "
            "'RQNSF_AG'".format(flow_type)
        )

    self.feb = FreeEnergyBound(args=args)

    # NOTE(review): nn.Softmax() is created without an explicit `dim`;
    # confirm the implicit dimension is the intended one.
    str2act = {
        "Sigmoid": nn.Sigmoid(),
        "ReLU": nn.ReLU(inplace=True),
        "Softmax": nn.Softmax(),
        "Softplus": nn.Softplus(),
    }
    self.max_act_val = None
    act_str = args['nat_param_act']
    if act_str in str2act:
        nat_param_act = str2act[act_str]
    else:
        # "Sigmoidx<number>": a sigmoid plus a numeric suffix stored in
        # max_act_val (presumably a scale applied elsewhere -- confirm).
        assert act_str.startswith("Sigmoidx"), \
            "nat_param_act must be one of {} or 'Sigmoidx<number>'".format(
                sorted(str2act))
        nat_param_act = nn.Sigmoid()
        self.max_act_val = float(act_str.split("x")[-1])

    # Hidden-layer widths of the _lambda MLP, keyed by auxiliary dimension.
    hidden_sizes_by_u_dim = {
        40: (30, 20),
        3: (6, 5),
        60: (45, 25),
        5: (4, 4),   # used for visualisation
        10: (8, 5),  # network configuration for the MNIST dataset
    }
    if self.u_dim not in hidden_sizes_by_u_dim:
        # Without this check self._lambda would silently stay undefined
        # and the failure would only surface on first use.
        raise ValueError(
            "Unsupported aux_dim {!r}; supported values are {}".format(
                self.u_dim, sorted(hidden_sizes_by_u_dim))
        )
    first_hidden, second_hidden = hidden_sizes_by_u_dim[self.u_dim]

    # Maps the auxiliary variable (u_dim features) to 2 * z_dim outputs,
    # passed through the configured activation.
    self._lambda = nn.Sequential(
        nn.Linear(self.u_dim, first_hidden),
        nn.ReLU(inplace=True),
        nn.Linear(first_hidden, second_hidden),
        nn.ReLU(inplace=True),
        nn.Linear(second_hidden, 2 * self.z_dim),
        nat_param_act,
    )

    self.set_mask(self.bs)