示例#1
0
def create_transform(dim, num_flow_steps, num_bins):
    """Assemble the stacked transform used by the ``RQNSF_AG`` flow type.

    Each of the ``num_flow_steps`` steps is a ``CompositeTransform`` of an
    ``"lu"`` linear transform (see ``create_linear_transform``) followed by
    the ``"rqsf_ag"`` base transform for that step index. One more ``"lu"``
    linear transform is appended after the last step.

    Args:
        dim: Number of features the transform acts on.
        num_flow_steps: Number of (linear, base-transform) steps to stack.
        num_bins: Forwarded to ``create_base_transform`` as its bin count.

    Returns:
        A ``transforms.CompositeTransform`` wrapping all steps in order.
    """
    layers = []
    for step_idx in range(num_flow_steps):
        # Build the linear part before the base transform so construction
        # order (and hence any random initialisation) is preserved.
        linear_part = create_linear_transform("lu", dim)
        base_part = create_base_transform(step_idx, "rqsf_ag", dim, num_bins)
        layers.append(transforms.CompositeTransform([linear_part, base_part]))

    # Trailing linear transform after the final step.
    layers.append(create_linear_transform("lu", dim))
    return transforms.CompositeTransform(layers)
示例#2
0
def create_linear_transform(linear_transform_type, dim):
    """Return a linear (feature-mixing) transform over ``dim`` features.

    Args:
        linear_transform_type: Which transform to build:
            ``'permutation'`` -- a ``RandomPermutation`` only;
            ``'lu'`` -- a ``RandomPermutation`` followed by ``LULinear``;
            ``'svd'`` -- a ``RandomPermutation`` followed by ``SVDLinear``
            with 10 Householder reflections.
        dim: Number of features the transform acts on.

    Returns:
        The transform; for ``'lu'`` and ``'svd'`` it is a
        ``transforms.CompositeTransform`` of the two parts above.

    Raises:
        ValueError: If ``linear_transform_type`` is not one of the values
            listed above.
    """
    if linear_transform_type == 'permutation':
        return transforms.RandomPermutation(features=dim)
    elif linear_transform_type == 'lu':
        return transforms.CompositeTransform([
            transforms.RandomPermutation(features=dim),
            transforms.LULinear(dim, identity_init=True)
        ])
    elif linear_transform_type == 'svd':
        return transforms.CompositeTransform([
            transforms.RandomPermutation(features=dim),
            transforms.SVDLinear(dim, num_householder=10, identity_init=True)
        ])
    else:
        # Report the offending value and the accepted options instead of a
        # bare, message-less ValueError.
        raise ValueError(
            f"Unknown linear_transform_type {linear_transform_type!r}; "
            "expected one of 'permutation', 'lu', 'svd'"
        )
示例#3
0
    def __init__(self, args):
        """Build the iFlow model from a configuration mapping.

        Args:
            args: Configuration mapping. Keys read here: ``'batch_size'``,
                ``'latent_dim'``, ``'data_dim'``, ``'aux_dim'``,
                ``'flow_type'``, ``'flow_length'``, ``'nat_param_act'`` and,
                when ``flow_type == "RQNSF_AG"``, ``'num_bins'``. The whole
                mapping is also passed to ``FreeEnergyBound``.

        Raises:
            ValueError: If ``latent_dim != data_dim``; if ``flow_type`` is not
                ``"PlanarFlow"``, ``"RQNSF_C"`` or ``"RQNSF_AG"``; if
                ``nat_param_act`` is neither a known activation name nor of
                the form ``"Sigmoidx<value>"``; or if ``aux_dim`` has no
                configured natural-parameter network.
        """
        super(iFlow, self).__init__()

        self.args = args
        self.bs = args['batch_size']
        # x_dim and z_dim are set to the same value below, so the configured
        # latent and data dimensions must agree. Explicit raise instead of
        # `assert`, which is stripped under `python -O`.
        if args['latent_dim'] != args['data_dim']:
            raise ValueError(
                f"latent_dim ({args['latent_dim']}) must equal "
                f"data_dim ({args['data_dim']})"
            )
        self.x_dim = self.z_dim = args['latent_dim']
        self.u_dim = args['aux_dim']
        self.k = 2  # number of orders of sufficient statistics

        flow_type = args['flow_type']
        print(args)

        if flow_type == "PlanarFlow":
            self.nf = PlanarFlow(dim=self.x_dim,
                                 flow_length=args['flow_length'])

        elif flow_type == "RQNSF_C":
            # NOTE(review): the bin count is fixed at 64 here, whereas
            # RQNSF_AG reads args['num_bins'] -- confirm this is intended.
            transform = transforms.CompositeTransform([
                create_base_transform(i, "rqsf_c", self.z_dim, 64)
                for i in range(args['flow_length'])
            ])
            self.nf = SplineFlow(transform)

        elif flow_type == "RQNSF_AG":
            transform = create_transform(self.z_dim, args['flow_length'],
                                         args['num_bins'])
            self.nf = SplineFlow(transform)

        else:
            raise ValueError(
                f"Unknown flow_type {flow_type!r}; expected one of "
                "'PlanarFlow', 'RQNSF_C', 'RQNSF_AG'"
            )

        self.feb = FreeEnergyBound(args=args)

        str2act = {
            "Sigmoid": nn.Sigmoid(),
            "ReLU": nn.ReLU(inplace=True),
            # NOTE(review): nn.Softmax() without `dim` relies on PyTorch's
            # deprecated implicit-dim rule; consider passing dim explicitly.
            "Softmax": nn.Softmax(),
            "Softplus": nn.Softplus()
        }

        # Stays None unless nat_param_act is "Sigmoidx<value>", in which case
        # <value> is stored here (presumably used elsewhere to scale the
        # sigmoid output -- confirm against the forward pass).
        self.max_act_val = None
        act_str = args['nat_param_act']
        if act_str in str2act:
            nat_param_act = str2act[act_str]
        elif act_str.startswith("Sigmoidx"):
            nat_param_act = nn.Sigmoid()
            self.max_act_val = float(act_str.split("x")[-1])
        else:
            raise ValueError(
                f"Unknown nat_param_act {act_str!r}; expected one of "
                f"{sorted(str2act)} or 'Sigmoidx<value>'"
            )

        # Widths of the two hidden layers of the natural-parameter network
        # self._lambda, keyed by the auxiliary-variable dimension u_dim.
        lambda_hidden_sizes = {
            40: (30, 20),
            3: (6, 5),
            60: (45, 25),
            5: (4, 4),    # visualisation setting
            10: (8, 5),   # MNIST setting
        }
        # Fail fast: previously an unsupported u_dim left self._lambda unset,
        # surfacing only later as an AttributeError.
        if self.u_dim not in lambda_hidden_sizes:
            raise ValueError(
                f"Unsupported aux_dim {self.u_dim}; supported values are "
                f"{sorted(lambda_hidden_sizes)}"
            )
        hidden1, hidden2 = lambda_hidden_sizes[self.u_dim]
        # Linear -> ReLU -> Linear -> ReLU -> Linear -> activation, producing
        # 2 * z_dim outputs (one block per sufficient-statistic order, k = 2).
        self._lambda = nn.Sequential(
            nn.Linear(self.u_dim, hidden1),
            nn.ReLU(inplace=True),
            nn.Linear(hidden1, hidden2),
            nn.ReLU(inplace=True),
            nn.Linear(hidden2, 2 * self.z_dim),
            nat_param_act,
        )
        self.set_mask(self.bs)