Example #1
    def _test_jacobian(self, input_dim, hidden_dim, multiplier):
        jacobian = torch.zeros(input_dim, input_dim)
        arn = AutoRegressiveNN(input_dim, hidden_dim, multiplier)

        def nonzero(x):
            return torch.sign(torch.abs(x))

        for output_index in range(multiplier):
            for j in range(input_dim):
                for k in range(input_dim):
                    x = Variable(torch.randn(1, input_dim))
                    epsilon_vector = torch.zeros(1, input_dim)
                    epsilon_vector[0, j] = self.epsilon
                    delta = (arn(x + Variable(epsilon_vector)) -
                             arn(x)) / self.epsilon
                    jacobian[j, k] = float(
                        delta[0, k +
                              output_index * input_dim].data.cpu().numpy()[0])

            permutation = arn.get_permutation()
            permuted_jacobian = jacobian.clone()
            for j in range(input_dim):
                for k in range(input_dim):
                    permuted_jacobian[j, k] = jacobian[permutation[j],
                                                       permutation[k]]

            lower_sum = torch.sum(
                torch.tril(nonzero(permuted_jacobian), diagonal=0))
            self.assertTrue(lower_sum == float(0.0))
Example #2
    def _test_jacobian(self, input_dim, hidden_dim, param_dim):
        jacobian = torch.zeros(input_dim, input_dim)
        arn = AutoRegressiveNN(input_dim, [hidden_dim], param_dims=[param_dim])

        def nonzero(x):
            return torch.sign(torch.abs(x))

        for output_index in range(param_dim):
            for j in range(input_dim):
                for k in range(input_dim):
                    x = torch.randn(1, input_dim)
                    epsilon_vector = torch.zeros(1, input_dim)
                    epsilon_vector[0, j] = self.epsilon
                    delta = (arn(x + 0.5 * epsilon_vector) - arn(x - 0.5 * epsilon_vector)) / self.epsilon
                    jacobian[j, k] = float(delta[0, output_index, k])

            permutation = arn.get_permutation()
            permuted_jacobian = jacobian.clone()
            for j in range(input_dim):
                for k in range(input_dim):
                    permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]]

            lower_sum = torch.sum(torch.tril(nonzero(permuted_jacobian), diagonal=0))

            assert lower_sum == float(0.0)
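The finite-difference check above can also be done in one shot with autograd. Below is a hedged sketch (not part of the original test) that verifies the same autoregressive property for a default pyro.nn.AutoRegressiveNN: once rows and columns are reordered by the learned permutation, the Jacobian of each output head must have a zero upper triangle, diagonal included.

import torch
from pyro.nn import AutoRegressiveNN

input_dim, hidden_dim = 5, 40
arn = AutoRegressiveNN(input_dim, [hidden_dim])     # default param_dims=[1, 1]
perm = arn.get_permutation()

x = torch.randn(input_dim)
# J[i, j] = d output_i / d input_j for the first parameter head of the network
J = torch.autograd.functional.jacobian(
    lambda v: arn(v.unsqueeze(0))[0].squeeze(0), x)

# In the permuted ordering each output may depend only on strictly earlier inputs,
# so the permuted Jacobian must have a zero upper triangle (diagonal included).
P = J[perm][:, perm]
assert torch.allclose(torch.triu(P), torch.zeros_like(P))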
Example #3
 def get_transform_w(self):
     prod = np.prod(self.shape)
     transform_w = nn.ModuleList([
         AffineAutoregressive(AutoRegressiveNN(prod, [prod]), stable=True)
         for _ in range(self.t)
     ])
     return transform_w
Example #4
def spline_autoregressive(input_dim, hidden_dims=None, count_bins=8, bound=3.0):
    """
    A helper function to create an
    :class:`~pyro.distributions.transforms.SplineAutoregressive` object that takes
    care of constructing an autoregressive network with the correct input/output
    dimensions.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param hidden_dims: The desired hidden dimensions of the autoregressive network.
        Defaults to using [input_dim * 10, input_dim * 10]
    :type hidden_dims: list[int]
    :param count_bins: The number of segments comprising the spline.
    :type count_bins: int
    :param bound: The quantity :math:`K` determining the bounding box,
        :math:`[-K,K]\times[-K,K]`, of the spline.
    :type bound: float

    """

    if hidden_dims is None:
        hidden_dims = [input_dim * 10, input_dim * 10]

    param_dims = [count_bins, count_bins, count_bins - 1, count_bins]
    arn = AutoRegressiveNN(input_dim, hidden_dims, param_dims=param_dims)
    return SplineAutoregressive(input_dim, arn, count_bins=count_bins, bound=bound, order='linear')
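A minimal usage sketch for the helper above, assuming it is in scope (the dimensions and sample counts are illustrative): the returned SplineAutoregressive is a TransformModule, so it can be dropped straight into a TransformedDistribution and trained through log_prob.

import torch
import pyro.distributions as dist

input_dim = 10
transform = spline_autoregressive(input_dim, count_bins=8, bound=3.0)
base = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
flow = dist.TransformedDistribution(base, [transform])

x = flow.rsample(torch.Size([32]))   # batch of 32 samples, shape (32, input_dim)
log_p = flow.log_prob(x)             # differentiable w.r.t. transform.parameters()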
Example #5
 def get_transform_b(self):
     transform_b = nn.ModuleList([
         AffineAutoregressive(AutoRegressiveNN(self.shape[0],
                                               [self.shape[0]]),
                              stable=True) for _ in range(self.t)
     ])
     return transform_b
Example #6
 def __init__(self,
              in_dim,
              h_dim,
              rand_perm=True,
              activation="ReLU",
              num_stable=False):
     super(IATransform, self).__init__()
     # resolve the activation before branching so it is defined in both cases
     nonlinearity = get_activation(activation)
     if rand_perm:
         self.permutation = torch.randperm(in_dim, dtype=torch.int64)
     else:
         self.permutation = torch.arange(in_dim, dtype=torch.int64)
     # if len(h_dim)==1:
     #     self.AR = nn.ModuleList([
     #         nn.Linear(in_features=in_dim, out_features=h_dim),
     #         nonlinearity(),
     #         nn.Linear(in_features=h_dim, out_features=in_dim)]
     #         )
     # else:
     #     self.AR = nn.ModuleList([
     #         nn.Sequential(
     #             nn.Linear(in_features=in_dim, out_features=out_dim),
     #             nonlinearity()) for in_dim, out_dim in zip([in_dim, h_dim[:-1]], [h_dim, in_dim])])# AutoRegressiveNN(input_dim=in_dim, hidden_dims=h_dim, nonlinearity=nonlinearity)  # nn.ELU(0.6))
     # TODO: make forward
     self.AR = AutoRegressiveNN(
         input_dim=in_dim,
         hidden_dims=h_dim,
         nonlinearity=nonlinearity,
         permutation=self.permutation)  # nn.ELU(0.6))
     self.num_stable = num_stable
Example #7
    def __init__(self,
                 input_dim=88,
                 z_dim=100,
                 emission_dim=100,
                 transition_dim=200,
                 rnn_dim=600,
                 num_layers=1,
                 rnn_dropout_rate=0.0,
                 num_iafs=0,
                 iaf_dim=50,
                 use_cuda=False):
        super(DMM, self).__init__()
        # instantiate PyTorch modules used in the model and guide below
        # if we're using normalizing flows, instantiate those too
        self.iafs = [
            InverseAutoregressiveFlow(AutoRegressiveNN(z_dim, [iaf_dim]))
            for _ in range(num_iafs)
        ]
        self.iafs_modules = nn.ModuleList(self.iafs)

        # define a (trainable) parameters z_0 and z_q_0 that help define the probability
        # distributions p(z_1) and q(z_1)
        # (since for t = 1 there are no previous latents to condition on)
        # define a (trainable) parameter for the initial hidden state of the rnn
        self.h_0 = nn.Parameter(torch.zeros(1, 1, rnn_dim))

        self.use_cuda = use_cuda
        # if on gpu cuda-ize all PyTorch (sub)modules
        if use_cuda:
            self.cuda()
Example #8
def affine_autoregressive(input_dim, hidden_dims=None, **kwargs):
    """
    A helper function to create an
    :class:`~pyro.distributions.transforms.AffineAutoregressive` object that takes
    care of constructing an autoregressive network with the correct input/output
    dimensions.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param hidden_dims: The desired hidden dimensions of the autoregressive network.
        Defaults to using [3*input_dim + 1]
    :type hidden_dims: list[int]
    :param log_scale_min_clip: The minimum value for clipping the log(scale) from
        the autoregressive NN
    :type log_scale_min_clip: float
    :param log_scale_max_clip: The maximum value for clipping the log(scale) from
        the autoregressive NN
    :type log_scale_max_clip: float
    :param sigmoid_bias: A term to add to the logit of the input when using the
        stable transform.
    :type sigmoid_bias: float
    :param stable: When true, uses the alternative "stable" version of the transform
        (see above).
    :type stable: bool

    """

    if hidden_dims is None:
        hidden_dims = [3 * input_dim + 1]
    arn = AutoRegressiveNN(input_dim, hidden_dims)
    return AffineAutoregressive(arn, **kwargs)
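A minimal training sketch for the helper above, assuming it is in scope (the data tensor is a stand-in for a real dataset):

import torch
import pyro.distributions as dist

input_dim = 8
transform = affine_autoregressive(input_dim)
base = dist.Normal(torch.zeros(input_dim), torch.ones(input_dim))
flow = dist.TransformedDistribution(base, [transform])

data = torch.randn(256, input_dim)          # stand-in for real training data
optimizer = torch.optim.Adam(transform.parameters(), lr=1e-3)
for _ in range(100):
    optimizer.zero_grad()
    loss = -flow.log_prob(data).mean()      # maximize likelihood of the data
    loss.backward()
    optimizer.step()
    flow.clear_cache()                      # drop cached (x, y) pairs between steps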
Example #9
    def __init__(self, input_dim=88, z_dim=100, emission_dim=100,
                 transition_dim=200, rnn_dim=600, num_layers=1, rnn_dropout_rate=0.0,
                 num_iafs=0, iaf_dim=50, use_cuda=False):
        super(DMM, self).__init__()
        # instantiate PyTorch modules used in the model and guide below
        self.emitter = Emitter(input_dim, z_dim, emission_dim)
        self.trans = GatedTransition(z_dim, transition_dim)
        self.combiner = Combiner(z_dim, rnn_dim)
        # dropout just takes effect on inner layers of rnn
        rnn_dropout_rate = 0. if num_layers == 1 else rnn_dropout_rate
        self.rnn = nn.RNN(input_size=input_dim, hidden_size=rnn_dim, nonlinearity='relu',
                          batch_first=True, bidirectional=False, num_layers=num_layers,
                          dropout=rnn_dropout_rate)

        # if we're using normalizing flows, instantiate those too
        self.iafs = [InverseAutoregressiveFlow(AutoRegressiveNN(z_dim, [iaf_dim])) for _ in range(num_iafs)]
        self.iafs_modules = nn.ModuleList(self.iafs)

        # define a (trainable) parameters z_0 and z_q_0 that help define the probability
        # distributions p(z_1) and q(z_1)
        # (since for t = 1 there are no previous latents to condition on)
        self.z_0 = nn.Parameter(torch.zeros(z_dim))
        self.z_q_0 = nn.Parameter(torch.zeros(z_dim))
        # define a (trainable) parameter for the initial hidden state of the rnn
        self.h_0 = nn.Parameter(torch.zeros(1, 1, rnn_dim))

        self.use_cuda = use_cuda
        # if on gpu cuda-ize all PyTorch (sub)modules
        if use_cuda:
            self.cuda()
Example #10
def neural_autoregressive(input_dim,
                          hidden_dims=None,
                          activation='sigmoid',
                          width=16):
    """
    A helper function to create a
    :class:`~pyro.distributions.transforms.NeuralAutoregressive` object that takes
    care of constructing an autoregressive network with the correct input/output
    dimensions.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param hidden_dims: The desired hidden dimensions of the autoregressive network.
        Defaults to using [3*input_dim + 1]
    :type hidden_dims: list[int]
    :param activation: Activation function to use. One of 'ELU', 'LeakyReLU',
        'sigmoid', or 'tanh'.
    :type activation: string
    :param width: The width of the "multilayer perceptron" in the transform (see
        paper). Defaults to 16
    :type width: int

    """

    if hidden_dims is None:
        hidden_dims = [3 * input_dim + 1]
    arn = AutoRegressiveNN(input_dim, hidden_dims, param_dims=[width] * 3)
    return NeuralAutoregressive(arn, hidden_units=width, activation=activation)
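For reference, a short sketch of what the autoregressive network built above returns; this matches the indexing used in the Jacobian tests earlier, but treat it as an assumption to check rather than documented behaviour: with param_dims=[width] * 3 the ARN yields a tuple of three tensors of shape (batch, width, input_dim), one per parameter group consumed by NeuralAutoregressive.

import torch
from pyro.nn import AutoRegressiveNN

input_dim, width = 6, 16
arn = AutoRegressiveNN(input_dim, [3 * input_dim + 1], param_dims=[width] * 3)
out = arn(torch.randn(4, input_dim))
assert isinstance(out, tuple) and len(out) == 3
assert all(t.shape == (4, width, input_dim) for t in out)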
Example #11
 def _make_poly(self, input_dim):
     count_degree = 4
     count_sum = 3
     arn = AutoRegressiveNN(input_dim, [input_dim * 10],
                            param_dims=[(count_degree + 1) * count_sum])
     return transforms.PolynomialFlow(arn,
                                      input_dim=input_dim,
                                      count_degree=count_degree,
                                      count_sum=count_sum)
Example #12
def init_affine_autoregressive(dim: int, device: str = "cpu", **kwargs):
    """Provides the default initial arguments for an affine autoregressive transform."""
    hidden_dims = kwargs.pop("hidden_dims", [3 * dim + 5, 3 * dim + 5])
    skip_connections = kwargs.pop("skip_connections", False)
    nonlinearity = kwargs.pop("nonlinearity", nn.ReLU())
    arn = AutoRegressiveNN(dim,
                           hidden_dims,
                           nonlinearity=nonlinearity,
                           skip_connections=skip_connections).to(device)
    return [arn], {"log_scale_min_clip": -3.0}
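The helper returns an (args, kwargs) pair rather than a finished transform. A sketch of how a caller would presumably consume it (the surrounding builder code is not shown in this example):

from pyro.distributions.transforms import AffineAutoregressive

dim = 4
transform_args, transform_kwargs = init_affine_autoregressive(dim)
transform = AffineAutoregressive(*transform_args, **transform_kwargs)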
Example #13
 def get_posterior(self, *args, **kwargs):
     """
     Returns a diagonal Normal posterior distribution transformed by
     :class:`~pyro.distributions.iaf.InverseAutoregressiveFlow`.
     """
     if self.latent_dim == 1:
         raise ValueError('latent dim = 1. Consider using AutoDiagonalNormal instead')
     if self.hidden_dim is None:
         self.hidden_dim = self.latent_dim
     iaf = dist.InverseAutoregressiveFlow(AutoRegressiveNN(self.latent_dim, [self.hidden_dim]))
     pyro.module("{}_iaf".format(self.prefix), iaf)
     iaf_dist = dist.TransformedDistribution(dist.Normal(0., 1.).expand([self.latent_dim]), [iaf])
     return iaf_dist.to_event(1)
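Guides that expose an IAF-transformed posterior like this are normally driven by SVI. A minimal, self-contained sketch using Pyro's stock AutoIAFNormal guide instead of the custom class above (the toy model, data, and step count are illustrative assumptions):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoIAFNormal
from pyro.optim import Adam

def model(data):
    loc = pyro.sample("loc", dist.Normal(torch.zeros(2), torch.ones(2)).to_event(1))
    with pyro.plate("data", data.shape[0]):
        pyro.sample("obs", dist.Normal(loc, 1.0).to_event(1), obs=data)

guide = AutoIAFNormal(model)            # IAF-transformed Normal posterior over `loc`
svi = SVI(model, guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())
data = torch.randn(100, 2) + 3.0
for _ in range(500):
    svi.step(data)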
Example #14
    def _test_jacobian(self, input_dim, hidden_dim, multiplier):
        jacobian = torch.zeros(input_dim, input_dim)
        arn = AutoRegressiveNN(input_dim, hidden_dim, multiplier)

        def nonzero(x):
            return torch.sign(torch.abs(x))

        for output_index in range(multiplier):
            for j in range(input_dim):
                for k in range(input_dim):
                    x = torch.randn(1, input_dim)
                    epsilon_vector = torch.zeros(1, input_dim)
                    epsilon_vector[0, j] = self.epsilon
                    delta = (arn(x + epsilon_vector) - arn(x)) / self.epsilon
                    jacobian[j, k] = float(delta[0, k + output_index * input_dim])

            permutation = arn.get_permutation()
            permuted_jacobian = jacobian.clone()
            for j in range(input_dim):
                for k in range(input_dim):
                    permuted_jacobian[j, k] = jacobian[permutation[j], permutation[k]]

            lower_sum = torch.sum(torch.tril(nonzero(permuted_jacobian), diagonal=0))
            assert lower_sum == float(0.0)
Example #15
    def get_posterior(self, *args, **kwargs):
        """
        Returns a diagonal Normal posterior distribution transformed by
        :class:`~pyro.distributions.transforms.iaf.InverseAutoregressiveFlow`.
        """
        if self.latent_dim == 1:
            raise ValueError('latent dim = 1. Consider using AutoDiagonalNormal instead')
        if self.hidden_dim is None:
            self.hidden_dim = self.latent_dim
        if self.arn is None:
            self.arn = AutoRegressiveNN(self.latent_dim, [self.hidden_dim])

        iaf = transforms.AffineAutoregressive(self.arn)
        iaf_dist = dist.TransformedDistribution(dist.Normal(0., 1.).expand([self.latent_dim]), [iaf])
        return iaf_dist
Example #16
 def __init__(self,
              input_dim,
              hidden_dim,
              sigmoid_bias=2.0,
              permutation=None):
     super(InverseAutoregressiveFlow, self).__init__()
     self.input_dim = input_dim
     self.hidden_dim = hidden_dim
     self.arn = AutoRegressiveNN(input_dim,
                                 hidden_dim,
                                 output_dim_multiplier=2,
                                 permutation=permutation)
     self.sigmoid = nn.Sigmoid()
     self.sigmoid_bias = Variable(torch.Tensor([sigmoid_bias]))
     self._intermediates_cache = {}
     self.add_inverse_to_cache = True
Example #17
    def model(self, *args, **kargs):

        self.inv_softplus_sigma = pyro.param("inv_softplus_sigma",
                                             torch.ones(self.rank))
        sigma = self.sigma  #torch.nn.functional.softplus(self.inv_softplus_sigma)

        #base_dist = dist.Normal(torch.zeros(self.rank), torch.ones(self.rank))
        # Pavel: introducing `sigma` in the IAF distribution makes training more
        # stable in terms of the scale of the distribution we are trying to learn
        base_dist = dist.Normal(torch.zeros(self.rank), sigma)
        ann = AutoRegressiveNN(self.rank, self.n_hid, skip_connections=True)
        iaf = dist.InverseAutoregressiveFlow(ann)
        iaf_module = pyro.module("my_iaf", iaf)
        iaf_dist = dist.TransformedDistribution(base_dist, [iaf])
        self.t = pyro.sample("t", iaf_dist.to_event(1))
        return self.t
Example #18
 def __init__(self,
              in_dim,
              h_dim,
              rand_perm=True,
              activation="ReLU",
              num_stable=False):
     super(IATransform, self).__init__()
     # resolve the activation before branching so it is defined in both cases
     nonlinearity = get_activation(activation)
     if rand_perm:
         self.permutation = torch.randperm(in_dim, dtype=torch.int64)
     else:
         self.permutation = torch.arange(in_dim, dtype=torch.int64)
     self.AR = AutoRegressiveNN(input_dim=in_dim,
                                hidden_dims=h_dim,
                                nonlinearity=nonlinearity)  # nn.ELU(0.6))
     self.num_stable = num_stable
Example #19
def polynomial(input_dim, hidden_dims=None):
    """
    A helper function to create a :class:`~pyro.distributions.transforms.Polynomial`
    object that takes care of constructing an autoregressive network with the
    correct input/output dimensions.

    :param input_dim: Dimension of input variable
    :type input_dim: int
    :param hidden_dims: The desired hidden dimensions of the autoregressive
        network. Defaults to using [input_dim * 10]

    """

    count_degree = 4
    count_sum = 3
    if hidden_dims is None:
        hidden_dims = [input_dim * 10]
    arn = AutoRegressiveNN(input_dim, hidden_dims, param_dims=[(count_degree + 1) * count_sum])
    return Polynomial(arn, input_dim=input_dim, count_degree=count_degree, count_sum=count_sum)
Example #20
 def __init__(self,
              input_dim,
              hidden_dim,
              permutation=None):  #sigmoid_bias=2
     super(InverseAutoregressiveFlow, self).__init__()
     if permutation is None:
         permutation = torch.arange(input_dim)  # default to the identity permutation
     self.input_dim = input_dim
     self.hidden_dim = hidden_dim
     self.module = nn.Module()
     self.module.arn = AutoRegressiveNN(input_dim,
                                        hidden_dim,
                                        permutation=permutation,
                                        param_dims=[1, 1],
                                        nonlinearity=torch.nn.ELU())
     self.module.activation = nn.Sigmoid()
     self.module.sigmoid_bias = torch.tensor(2).double()
     self._intermediates_cache = {}
     self.add_inverse_to_cache = True
Example #21
    def __init__(self,
                 flow_type,
                 num_flows,
                 hidden_dim=20,
                 need_permute=False):
        super(NormFlow, self).__init__()
        self.need_permute = need_permute
        if flow_type == 'IAF':
            self.flow = nn.ModuleList([
                AffineAutoregressive(AutoRegressiveNN(hidden_dim,
                                                      [2 * hidden_dim]),
                                     stable=True) for _ in range(num_flows)
            ])
        elif flow_type == 'BNAF':
            self.flow = nn.ModuleList([
                BlockAutoregressive(input_dim=hidden_dim)
                for _ in range(num_flows)
            ])
        elif flow_type == 'RNVP':
            split_dim = hidden_dim // 2
            param_dims = [hidden_dim - split_dim, hidden_dim - split_dim]
            hypernet = DenseNN(split_dim, [2 * hidden_dim], param_dims)
            self.flow = nn.ModuleList([
                AffineCoupling(split_dim, hypernet) for _ in range(num_flows)
            ])
        else:
            raise NotImplementedError

        even = [i for i in range(0, hidden_dim, 2)]
        odd = [i for i in range(1, hidden_dim, 2)]
        undo_eo = [
            i // 2 if i % 2 == 0 else (i // 2 + len(even))
            for i in range(hidden_dim)
        ]
        undo_oe = [(i // 2 + len(odd)) if i % 2 == 0 else i // 2
                   for i in range(hidden_dim)]
        self.register_buffer('eo', torch.tensor(even + odd, dtype=torch.int64))
        self.register_buffer('oe', torch.tensor(odd + even, dtype=torch.int64))
        self.register_buffer('undo_eo', torch.tensor(undo_eo,
                                                     dtype=torch.int64))
        self.register_buffer('undo_oe', torch.tensor(undo_oe,
                                                     dtype=torch.int64))
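A small standalone check of the intent behind the eo/undo_eo buffers registered above: indexing with undo_eo undoes the even-odd shuffle applied by eo (the concrete hidden_dim here is an arbitrary illustration).

import torch

hidden_dim = 6
even = list(range(0, hidden_dim, 2))
odd = list(range(1, hidden_dim, 2))
eo = torch.tensor(even + odd, dtype=torch.int64)
undo_eo = torch.tensor(
    [i // 2 if i % 2 == 0 else i // 2 + len(even) for i in range(hidden_dim)],
    dtype=torch.int64)

x = torch.randn(4, hidden_dim)
assert torch.equal(x[:, eo][:, undo_eo], x)   # shuffle, then un-shuffle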
Example #22
File: sos.py Project: talesa/lgf
    def __init__(
        self,
        num_input_channels,
        hidden_channels,
        activation,
        num_polynomials,
        polynomial_degree,
    ):
        super().__init__(x_shape=(num_input_channels, ),
                         z_shape=(num_input_channels, ))

        arn = AutoRegressiveNN(input_dim=int(num_input_channels),
                               hidden_dims=hidden_channels,
                               param_dims=[
                                   (polynomial_degree + 1) * num_polynomials
                               ],
                               nonlinearity=activation())

        self.flow = PolynomialFlow(autoregressive_nn=arn,
                                   input_dim=int(num_input_channels),
                                   count_degree=polynomial_degree,
                                   count_sum=num_polynomials)
Example #23
def init_spline_autoregressive(dim: int, device: str = "cpu", **kwargs):
    """Provides the default initial arguments for an spline autoregressive transform."""
    hidden_dims = kwargs.pop("hidden_dims", [3 * dim + 5, 3 * dim + 5])
    skip_connections = kwargs.pop("skip_connections", False)
    nonlinearity = kwargs.pop("nonlinearity", nn.ReLU())
    count_bins = kwargs.get("count_bins", 10)
    order = kwargs.get("order", "linear")
    bound = kwargs.get("bound", 10)
    if order == "linear":
        param_dims = [count_bins, count_bins, (count_bins - 1), count_bins]
    else:
        param_dims = [count_bins, count_bins, (count_bins - 1)]
    neural_net = AutoRegressiveNN(
        dim,
        hidden_dims,
        param_dims=param_dims,
        skip_connections=skip_connections,
        nonlinearity=nonlinearity,
    ).to(device)
    return [dim, neural_net], {
        "count_bins": count_bins,
        "bound": bound,
        "order": order
    }
Example #24
def affine_autoregressive(input_dim, hidden_dims=None, nonlinearity=nn.LeakyReLU(0.1), **kwargs):
    if hidden_dims is None:
        hidden_dims = [3 * input_dim + 1]
    arn = AutoRegressiveNN(input_dim, hidden_dims, nonlinearity=nonlinearity)
    return AffineAutoregressive(arn, **kwargs)
Example #25
 def loadIAFs(self, num_iafs, z_dim=100, iaf_dim=320):
     flow_fn = lambda: InverseAutoregressiveFlow(
         AutoRegressiveNN(z_dim, [iaf_dim, iaf_dim]))
     self.loadFlows(num_iafs, flow_fn)
Example #26
 def __init__(self, in_dim, h_dim, rand_perm=True):
     super(IATransform, self).__init__()
     # YOUR CODE HERE
     self.AR = AutoRegressiveNN(input_dim=in_dim, hidden_dims=[h_dim])
     self.rand_perm = rand_perm
Example #27
 def _make_iaf(self, input_dim):
     arn = AutoRegressiveNN(input_dim, [3 * input_dim + 1])
     return transforms.InverseAutoregressiveFlow(arn)
Example #28
 def _make_iaf_stable(self, input_dim):
     arn = AutoRegressiveNN(input_dim, [3 * input_dim + 1])
     return transforms.InverseAutoregressiveFlowStable(arn,
                                                       sigmoid_bias=0.5)
Example #29
 def _make_neural_autoregressive(self, input_dim, activation):
     arn = AutoRegressiveNN(input_dim, [3 * input_dim + 1],
                            param_dims=[16] * 3)
     return transforms.NeuralAutoregressive(arn,
                                            hidden_units=16,
                                            activation=activation)
Example #30
def train_vae(args):
    # pdb.set_trace()
    best_metric = -float("inf")

    prior_params = list([])
    varflow_params = list([])
    prior_flow = None
    variational_flow = None

    data = Dataset(args)
    if args.data in ['goodreads', 'big_dataset']:
        args.feature_shape = data.feature_shape

    if args.nf_prior:
        flows = []
        for i in range(args.num_flows_prior):
            if args.nf_prior == 'IAF':
                one_arn = AutoRegressiveNN(args.z_dim,
                                           [2 * args.z_dim]).to(args.device)
                one_flow = AffineAutoregressive(one_arn)
            elif args.nf_prior == 'RNVP':
                hypernet = DenseNN(
                    input_dim=args.z_dim // 2,
                    hidden_dims=[2 * args.z_dim, 2 * args.z_dim],
                    param_dims=[
                        args.z_dim - args.z_dim // 2,
                        args.z_dim - args.z_dim // 2
                    ]).to(args.device)
                one_flow = AffineCoupling(args.z_dim // 2,
                                          hypernet).to(args.device)
            flows.append(one_flow)
        prior_flow = nn.ModuleList(flows)
        prior_params = list(prior_flow.parameters())

    if args.data == 'mnist':
        encoder = Encoder(args).to(args.device)
    elif args.data in ['goodreads', 'big_dataset']:
        encoder = Encoder_rec(args).to(args.device)

    if args.nf_vardistr:
        flows = []
        for i in range(args.num_flows_vardistr):
            one_arn = AutoRegressiveNN(args.z_dim, [2 * args.z_dim],
                                       param_dims=[2 * args.z_dim] * 3).to(
                                           args.device)
            one_flows = NeuralAutoregressive(one_arn, hidden_units=256)
            flows.append(one_flows)
        variational_flow = nn.ModuleList(flows)
        varflow_params = list(variational_flow.parameters())

    if args.data == 'mnist':
        decoder = Decoder(args).to(args.device)
    elif args.data in ['goodreads', 'big_dataset']:
        decoder = Decoder_rec(args).to(args.device)

    params = list(encoder.parameters()) + list(
        decoder.parameters()) + prior_params + varflow_params
    optimizer = torch.optim.Adam(params=params)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=100,
                                                gamma=0.1)

    current_tolerance = 0
    # with torch.autograd.detect_anomaly():
    for ep in tqdm(range(args.num_epoches)):
        # training cycle
        for batch_num, batch_train in enumerate(data.next_train_batch()):
            batch_train_repeated = batch_train.repeat(
                *[[args.n_samples] + [1] * (len(batch_train.shape) - 1)])
            mu, sigma = encoder(batch_train_repeated)
            sum_log_sigma = torch.sum(torch.log(sigma), 1)
            sum_log_jacobian = 0.
            eps = args.std_normal.sample(mu.shape)
            z = mu + sigma * eps
            if not args.use_reparam:
                z = z.detach()
            if variational_flow:
                prev_v = z
                for flow_num in range(args.num_flows_vardistr):
                    u = variational_flow[flow_num](prev_v)
                    sum_log_jacobian += variational_flow[
                        flow_num].log_abs_det_jacobian(prev_v, u)
                    prev_v = u
                z = u
            logits = decoder(z)
            elbo = compute_objective(args=args,
                                     x_logits=logits,
                                     x_true=batch_train_repeated,
                                     sampled_noise=eps,
                                     inf_samples=z,
                                     sum_log_sigma=sum_log_sigma,
                                     prior_flow=prior_flow,
                                     sum_log_jacobian=sum_log_jacobian,
                                     mu=mu,
                                     sigma=sigma)
            (-elbo).backward()
            optimizer.step()
            optimizer.zero_grad()
        # scheduler step
        scheduler.step()

        # validation
        with torch.no_grad():
            metric = validate_vae(args=args,
                                  encoder=encoder,
                                  decoder=decoder,
                                  dataset=data,
                                  prior_flow=prior_flow,
                                  variational_flow=variational_flow)
            if (metric != metric).sum():
                print('NAN appeared!')
                raise ValueError
            if metric > best_metric:
                current_tolerance = 0
                best_metric = metric
                if not os.path.exists('./models/{}/'.format(args.data)):
                    os.makedirs('./models/{}/'.format(args.data))
                torch.save(
                    encoder,
                    './models/{}/best_encoder_data_{}_skips_{}_prior_{}_numflows_{}_varflow_{}_numvarflows_{}_samples_{}_zdim_{}_usereparam_{}.pt'
                    .format(args.data, args.data, args.use_skips,
                            args.nf_prior, args.num_flows_prior,
                            args.nf_vardistr, args.num_flows_vardistr,
                            args.n_samples, args.z_dim, args.use_reparam))
                torch.save(
                    decoder,
                    './models/{}/best_decoder_data_{}_skips_{}_prior_{}_numflows_{}_varflow_{}_numvarflows_{}_samples_{}_zdim_{}_usereparam_{}.pt'
                    .format(args.data, args.data, args.use_skips,
                            args.nf_prior, args.num_flows_prior,
                            args.nf_vardistr, args.num_flows_vardistr,
                            args.n_samples, args.z_dim, args.use_reparam))
                if args.nf_prior:
                    torch.save(
                        prior_flow,
                        './models/{}/best_prior_data_{}_skips_{}_prior_{}_numflows_{}_varflow_{}_numvarflows_{}_samples_{}_zdim_{}_usereparam_{}.pt'
                        .format(args.data, args.data, args.use_skips,
                                args.nf_prior, args.num_flows_prior,
                                args.nf_vardistr, args.num_flows_vardistr,
                                args.n_samples, args.z_dim, args.use_reparam))
                if args.nf_vardistr:
                    torch.save(
                        variational_flow,
                        './models/{}/best_varflow_data_{}_skips_{}_prior_{}_numflows_{}_varflow_{}_numvarflows_{}_samples_{}_zdim_{}_usereparam_{}.pt'
                        .format(args.data, args.data, args.use_skips,
                                args.nf_prior, args.num_flows_prior,
                                args.nf_vardistr, args.num_flows_vardistr,
                                args.n_samples, args.z_dim, args.use_reparam))
            else:
                current_tolerance += 1
                if current_tolerance >= args.early_stopping_tolerance:
                    print(
                        "Early stopping on epoch {} (effectively trained for {} epoches)"
                        .format(ep, ep - args.early_stopping_tolerance))
                    break
            print(
                'Current epoch: {}'.format(ep), '\t',
                'Current validation {}: {}'.format(args.metric_name, metric),
                '\t', 'Best validation {}: {}'.format(args.metric_name,
                                                      best_metric))

    # return best models:
    encoder = torch.load(
        './models/{}/best_encoder_data_{}_skips_{}_prior_{}_numflows_{}_varflow_{}_numvarflows_{}_samples_{}_zdim_{}_usereparam_{}.pt'
        .format(args.data, args.data, args.use_skips, args.nf_prior,
                args.num_flows_prior, args.nf_vardistr,
                args.num_flows_vardistr, args.n_samples, args.z_dim,
                args.use_reparam))
    decoder = torch.load(
        './models/{}/best_decoder_data_{}_skips_{}_prior_{}_numflows_{}_varflow_{}_numvarflows_{}_samples_{}_zdim_{}_usereparam_{}.pt'
        .format(args.data, args.data, args.use_skips, args.nf_prior,
                args.num_flows_prior, args.nf_vardistr,
                args.num_flows_vardistr, args.n_samples, args.z_dim,
                args.use_reparam))
    if args.nf_prior:
        prior_flow = torch.load(
            './models/{}/best_prior_data_{}_skips_{}_prior_{}_numflows_{}_varflow_{}_numvarflows_{}_samples_{}_zdim_{}_usereparam_{}.pt'
            .format(args.data, args.data, args.use_skips, args.nf_prior,
                    args.num_flows_prior, args.nf_vardistr,
                    args.num_flows_vardistr, args.n_samples, args.z_dim,
                    args.use_reparam))
    if args.nf_vardistr:
        variational_flow = torch.load(
            './models/{}/best_varflow_data_{}_skips_{}_prior_{}_numflows_{}_varflow_{}_numvarflows_{}_samples_{}_zdim_{}_usereparam_{}.pt'
            .format(args.data, args.data, args.use_skips, args.nf_prior,
                    args.num_flows_prior, args.nf_vardistr,
                    args.num_flows_vardistr, args.n_samples, args.z_dim,
                    args.use_reparam))
    return encoder, decoder, prior_flow, variational_flow, data
Example #31
def spline_autoregressive(input_dim, hidden_dims=None, count_bins=8, bound=3.0, order='linear', nonlinearity=nn.LeakyReLU(0.1)):
    if hidden_dims is None:
        hidden_dims = [3 * input_dim + 1]
    param_dims = [count_bins, count_bins, count_bins - 1, count_bins]
    arn = AutoRegressiveNN(input_dim, hidden_dims, param_dims=param_dims, nonlinearity=nonlinearity)
    return SplineAutoregressive(input_dim, arn, count_bins=count_bins, bound=bound, order=order)