def __init__(
    self,
    n_conditionals,
    n_inputs,
    n_hiddens,
    n_mades,
    n_components=10,
    activation="relu",
    batch_norm=True,
    input_order="sequential",
    mode="sequential",
    alpha=0.1,
):
    """Build a conditional masked autoregressive flow ending in a mixture MADE.

    The flow consists of ``n_mades - 1`` ``ConditionalGaussianMADE`` layers
    followed by one ``ConditionalMixtureMADE`` with ``n_components``
    components. When ``batch_norm`` is true, ``n_mades`` ``BatchNorm``
    layers are also created.

    Parameters
    ----------
    n_conditionals :
        Forwarded to the base class and to every MADE (presumably the number
        of conditioning variables -- confirm against the base class).
    n_inputs :
        Forwarded to the base class, every MADE and every BatchNorm layer
        (presumably the dimensionality of the modelled variable).
    n_hiddens :
        Hidden-layer specification forwarded unchanged to every MADE
        (presumably a list of hidden-layer sizes -- confirm).
    n_mades :
        Total number of MADE layers, counting the final mixture MADE.
    n_components :
        Number of mixture components of the final MADE. Default 10.
    activation :
        Activation name forwarded to every MADE. Default "relu".
    batch_norm :
        If true, create one ``BatchNorm`` layer per MADE (``n_mades`` total).
    input_order :
        Input ordering given to the first MADE. Unless it is the string
        "random", every following MADE (including the mixture MADE) receives
        the reversed ordering of the previous MADE. If it is "random", every
        MADE receives "random". Default "sequential".
    mode :
        Mode forwarded to every MADE. Default "sequential".
    alpha :
        Forwarded as ``alpha`` to every ``BatchNorm`` layer. Default 0.1.
    """
    super(ConditionalMixtureMaskedAutoregressiveFlow, self).__init__(n_conditionals, n_inputs)

    # save input arguments
    self.n_conditionals = n_conditionals
    self.n_inputs = n_inputs
    self.n_hiddens = n_hiddens
    self.n_mades = n_mades
    self.activation = activation
    self.batch_norm = batch_norm
    self.mode = mode
    self.alpha = alpha
    self.n_components = n_components

    # Dtype and GPU / CPU management (placeholders, initialised to None here)
    self.to_args = None
    self.to_kwargs = None

    # Build the Gaussian MADEs (all but the last layer of the flow)
    self.mades = nn.ModuleList()
    for _ in range(n_mades - 1):
        made = ConditionalGaussianMADE(
            n_conditionals, n_inputs, n_hiddens, activation=activation, input_order=input_order, mode=mode
        )
        self.mades.append(made)
        # Reverse the ordering between consecutive MADEs so they do not all share
        # the same autoregressive direction. A "random" ordering is passed on
        # unchanged (presumably so each MADE draws its own random order).
        # Previously this test used `!=`, which inverted the logic.
        if not (isinstance(input_order, str) and input_order == "random"):
            input_order = made.input_order[::-1]

    # Last MADE: mixture of Gaussians, using the ordering left by the loop above
    self.made_mog = ConditionalMixtureMADE(
        n_conditionals,
        n_inputs,
        n_hiddens,
        n_components=n_components,
        activation=activation,
        input_order=input_order,
        mode=mode,
    )

    # Batch normalization: one layer per MADE, including the mixture MADE.
    # NOTE(review): for n_mades < 1 the loop above builds no Gaussian MADEs but the
    # mixture MADE is still created, so the number of BatchNorm layers no longer
    # matches the number of MADEs -- consider validating n_mades >= 1.
    self.bns = None
    if self.batch_norm:
        self.bns = nn.ModuleList()
        for _ in range(n_mades):
            bn = BatchNorm(n_inputs, alpha=self.alpha)
            self.bns.append(bn)
def __init__(self, n_inputs, n_hiddens, n_mades, activation='relu', batch_norm=True,
             input_order='sequential', mode='sequential', alpha=0.1):
    """Assemble a masked autoregressive flow from ``n_mades`` Gaussian MADEs.

    Parameters
    ----------
    n_inputs :
        Forwarded to the base class, to every ``GaussianMADE`` and to every
        ``BatchNorm`` layer (presumably the dimensionality of the data).
    n_hiddens :
        Hidden-layer specification passed unchanged to every MADE
        (presumably a list of hidden-layer sizes -- confirm).
    n_mades :
        Number of ``GaussianMADE`` layers (and of ``BatchNorm`` layers when
        ``batch_norm`` is true).
    activation :
        Activation name passed to every MADE. Default 'relu'.
    batch_norm :
        When true, ``self.bns`` holds one ``BatchNorm`` per MADE; otherwise
        ``self.bns`` is None.
    input_order :
        Ordering for the first MADE. If it is the string 'random', every MADE
        gets 'random'; otherwise each later MADE gets the reverse of the
        previous MADE's ``input_order``.
    mode :
        Mode passed to every MADE. Default 'sequential'.
    alpha :
        Passed as ``alpha`` to every ``BatchNorm`` layer. Default 0.1.
    """
    super(MaskedAutoregressiveFlow, self).__init__(n_inputs)

    # Keep a copy of the constructor configuration on the instance.
    self.n_inputs, self.n_hiddens, self.n_mades = n_inputs, n_hiddens, n_mades
    self.activation, self.batch_norm = activation, batch_norm
    self.mode, self.alpha = mode, alpha

    # Dtype / device bookkeeping; both start out unset.
    self.to_args = None
    self.to_kwargs = None

    # Chain the MADEs, threading the input ordering from one to the next.
    self.mades = nn.ModuleList()
    order = input_order
    for _ in range(n_mades):
        layer = GaussianMADE(n_inputs, n_hiddens, activation=activation, input_order=order, mode=mode)
        self.mades.append(layer)
        # A 'random' ordering is forwarded as-is; anything else is flipped so
        # consecutive MADEs alternate their autoregressive direction.
        if isinstance(order, str) and order == 'random':
            continue
        order = layer.input_order[::-1]

    # One batch-normalisation layer per MADE, or none at all.
    if self.batch_norm:
        self.bns = nn.ModuleList([BatchNorm(n_inputs, alpha=self.alpha) for _ in range(n_mades)])
    else:
        self.bns = None