def __init__(self, data_dim=28 * 28, hidden_dim=64, guide_hidden_dim=256):
    """Build the ladder (VLAE-style) generative model.

    :param int data_dim: dimensionality of one flattened observation
    :param int hidden_dim: smallest latent dimensionality in the ladder
    :param int guide_hidden_dim: hidden width of the guide networks
    """
    self._data_dim = data_dim
    # Build up a bunch of torch.Sizes for the powers of two between
    # hidden_dim and data_dim.
    dims = list(util.powers_of(2, hidden_dim, data_dim // 4)) + [data_dim]
    dims.sort()

    # PEP 8 (E731): use a named function rather than a lambda bound to a
    # name; behavior is identical.
    def gaussian_likelihood(dim):
        return DiagonalGaussian(dim, latent_name='X^{%d}' % dim,
                                likelihood=True)

    generators = []
    for lower, higher in zip(dims, dims[1:]):
        # Construct the VLAE decoder and encoder.  Only the final step
        # into data space is convolutional and carries a likelihood.
        if higher == self._data_dim:
            decoder = LadderDecoder(lower, higher, noise_dim=2, conv=True,
                                    out_dist=gaussian_likelihood)
        else:
            decoder = LadderDecoder(lower, higher, noise_dim=2, conv=False,
                                    out_dist=None)
        data = {'effect': decoder.effect}
        generator = cart_closed.Box(decoder.name, decoder.type.left,
                                    decoder.type.right, decoder, data=data)
        generators.append(generator)

    # For each latent dimensionality, construct a prior/posterior ladder
    # pair.  Hoisted once: this set is needed both here and as the
    # superclass's no_prior_dims argument.
    latent_dims = set(dims) - {data_dim}
    for dim in latent_dims:
        space = types.tensor_type(torch.float, dim)
        prior = LadderPrior(dim, None)
        generator = cart_closed.Box(prior.name, Ty(), space, prior,
                                    data={'effect': prior.effect})
        generators.append(generator)

    super().__init__(generators, [], data_dim, guide_hidden_dim,
                     list(latent_dims))
def __init__(self, data_dim=28 * 28, hidden_dim=8, guide_hidden_dim=256):
    """Build generators connecting every pair of latent dimensionalities.

    :param int data_dim: dimensionality of one flattened observation
    :param int hidden_dim: smallest latent dimensionality considered
    :param int guide_hidden_dim: hidden width of the guide networks
    """
    self._data_dim = data_dim
    # Powers of two between hidden_dim and data_dim // 4, plus the data
    # dimensionality itself, in ascending order.
    dims = sorted(list(util.powers_of(2, hidden_dim, data_dim // 4)) +
                  [data_dim])
    generators = []
    for pair in itertools.combinations(dims, 2):
        lower, higher = sorted(pair)
        # The decoder mapping into data space is convolutional with a
        # continuous-Bernoulli output; all others are dense Gaussians.
        if higher == self._data_dim:
            dec = DensityDecoder(lower, higher, ContinuousBernoulliModel,
                                 convolve=True)
        else:
            dec = DensityDecoder(lower, higher, DiagonalGaussian)
        generators.append(cart_closed.Box(dec.name, dec.type.left,
                                          dec.type.right, dec,
                                          data={'effect': dec.effect}))
    super().__init__(generators, [], data_dim, guide_hidden_dim)
def _encoder(self, name):
    """Wrap the encoder module registered under ``name`` in a box.

    :param str name: key into ``self.encoders``
    :returns: a ``cart_closed.Box`` whose data carries the encoder's effect
    """
    enc = self.encoders[name]
    effect = {'effect': enc.effect}
    return cart_closed.Box(name, enc.type.left, enc.type.right, enc,
                           data=effect)
def __init__(self, max_len=120, guide_hidden_dim=256, charset_len=34):
    """Build encoder/decoder generator pairs for molecular sequences.

    For every (hidden, recurrent) size combination, two encoder styles are
    paired with a fresh decoder: a convolutional encoder and a recurrent
    one.  Each pairing yields a forward (decoder) generator and a dagger
    (encoder) generator.

    :param int max_len: maximum sequence length of one molecule
    :param int guide_hidden_dim: hidden width of the guide networks
    :param int charset_len: size of the molecular character set
    """
    hidden_dims = [196, 292, 435]
    recurrent_dims = [64, 128, 256]
    generators = []
    dagger_generators = []

    def add_pair(encoder, decoder):
        # DRY: the original duplicated this stanza verbatim for the conv
        # and recurrent encoders.  Forward box carries the decoder's
        # effect with the encoder's as its dagger; the encoder box swaps
        # the two roles.
        data = {'effect': decoder.effect, 'dagger_effect': encoder.effect}
        generators.append(cart_closed.Box(decoder.name, decoder.type.left,
                                          decoder.type.right, decoder,
                                          data=data))
        data = {'dagger_effect': decoder.effect, 'effect': encoder.effect}
        dagger_generators.append(cart_closed.Box(encoder.name,
                                                 encoder.type.left,
                                                 encoder.type.right,
                                                 encoder, data=data))

    for hidden in hidden_dims:
        for recurrent in recurrent_dims:
            # Convolutional encoder paired with its own decoder.
            add_pair(ConvMolecularEncoder(hidden, charset_len, max_len),
                     MolecularDecoder(hidden, recurrent_dim=recurrent,
                                      charset_len=charset_len,
                                      max_len=max_len))
            # Recurrent encoder paired with a fresh decoder.
            add_pair(RecurrentMolecularEncoder(hidden, recurrent,
                                               charset_len, max_len),
                     MolecularDecoder(hidden, recurrent_dim=recurrent,
                                      charset_len=charset_len,
                                      max_len=max_len))

    super().__init__(generators, [], data_space=(max_len, charset_len),
                     guide_hidden_dim=guide_hidden_dim,
                     no_prior_dims=[max_len, charset_len],
                     dagger_generators=dagger_generators)
def __init__(self, generators, global_elements=None, data_space=(784, ),
             guide_hidden_dim=256, no_prior_dims=None):
    """Build a free-category generative model over the given generators.

    :param generators: boxes generating the free category's morphisms
    :param global_elements: optional boxes for global elements (priors);
        a standard-normal prior is added for every base object not already
        covered by one of these
    :param data_space: shape of one observation (int treated as 1-tuple)
    :param int guide_hidden_dim: hidden width of the guide networks
    :param no_prior_dims: dimensionalities that should receive no
        automatic standard-normal prior (the data dimension is always
        excluded)
    """
    super().__init__()
    # Fix the mutable-default-argument pitfall: the original defaults
    # ([]) were appended to below, so state accumulated across separate
    # instantiations.  Caller-supplied lists are copied for the same
    # reason — this method must not mutate its arguments.
    global_elements = list(global_elements) if global_elements else []
    no_prior_dims = list(no_prior_dims) if no_prior_dims else []
    if isinstance(data_space, int):
        data_space = (data_space, )
    self._data_space = data_space
    self._data_dim = math.prod(data_space)
    if len(self._data_space) == 1:
        self._observation_name = '$X^{%d}$' % self._data_dim
    else:
        self._observation_name = '$X^{%s}$' % str(self._data_space)

    # Base objects mentioned by the generators, minus those already
    # covered by a caller-supplied global element.
    obs = set()
    for generator in generators:
        ty = generator.dom >> generator.cod
        obs = obs | unification.base_elements(ty)
    for element in global_elements:
        ty = element.dom >> element.cod
        obs = obs - unification.base_elements(ty)

    no_prior_dims = no_prior_dims + [self._data_dim]
    for ob in obs:
        dim = types.type_size(str(ob))
        if dim in no_prior_dims:
            continue
        # Supply a standard-normal prior for any remaining latent object.
        space = types.tensor_type(torch.float, dim)
        prior = StandardNormal(dim)
        name = '$p(%s)$' % prior.effects
        effect = {'effect': prior.effect, 'dagger_effect': []}
        global_element = cart_closed.Box(name, Ty(), space, prior,
                                         data=effect)
        global_elements.append(global_element)

    self._category = freecat.FreeCategory(generators, global_elements)

    # Amortized guide heads conditioned on the flattened observation:
    # one proposes temperatures, the other arrow weights (loc/scale pairs,
    # hence the "* 2" output widths).
    self.guide_temperatures = nn.Sequential(
        nn.Linear(self._data_dim, guide_hidden_dim),
        nn.LayerNorm(guide_hidden_dim), nn.PReLU(),
        nn.Linear(guide_hidden_dim, 1 * 2), nn.Softplus(),
    )
    self.guide_arrow_weights = nn.Sequential(
        nn.Linear(self._data_dim, guide_hidden_dim),
        nn.LayerNorm(guide_hidden_dim), nn.PReLU(),
        nn.Linear(guide_hidden_dim,
                  self._category.arrow_weight_loc.shape[0] * 2),
    )

    self._random_variable_names = collections.defaultdict(int)
    self.encoders = nn.ModuleDict()
    self.encoder_functor = wiring.Functor(
        lambda ty: util.double_latent(ty, self.data_space),
        lambda ar: self._encoder(ar.name),
        ob_factory=Ty, ar_factory=cart_closed.Box)
    # Build a dagger (encoder) for every arrow in the category.
    for arrow in self._category.ars:
        effect = arrow.data['effect']
        cod_dims = util.double_latents(
            [types.type_size(ob.name) for ob in arrow.cod],
            self._data_dim)
        dom_dims = util.double_latents(
            [types.type_size(ob.name) for ob in arrow.dom],
            self._data_dim)
        self.encoders[arrow.name + '†'] = build_encoder(cod_dims, dom_dims,
                                                        effect)
def __init__(self, data_dim=28 * 28, hidden_dim=64, guide_hidden_dim=256):
    """Assemble generators for a glimpse-based spatial-attention model.

    :param int data_dim: dimensionality of one flattened (square) image
    :param int hidden_dim: smallest latent dimensionality in the ladder
    :param int guide_hidden_dim: hidden width of the guide networks
    """
    self._data_dim = data_dim
    data_side = int(math.sqrt(self._data_dim))
    glimpse_side = data_side // 2
    glimpse_dim = glimpse_side ** 2
    generators = []

    # Powers of two between hidden_dim and glimpse_dim, ascending.
    dims = sorted(list(util.powers_of(2, hidden_dim, glimpse_dim // 4)) +
                  [glimpse_dim])
    for pair in itertools.combinations(dims, 2):
        lower, higher = sorted(pair)
        # Decoders landing in glimpse space are convolutional; the rest
        # are dense.  All are diagonal-Gaussian.
        if higher == glimpse_dim:
            dec = DensityDecoder(lower, higher, DiagonalGaussian,
                                 convolve=True)
        else:
            dec = DensityDecoder(lower, higher, DiagonalGaussian)
        generators.append(cart_closed.Box(dec.name, dec.type.left,
                                          dec.type.right, dec,
                                          data={'effect': dec.effect}))

    # Extend the ladder of adjacent dimensionalities up to data space.
    dims = sorted(dims + [data_dim])
    for lower, higher in zip(dims, dims[1:]):
        # Only the final step into data space is convolutional; the
        # output distribution is the same either way.
        conv = higher == self._data_dim
        dec = LadderDecoder(lower, higher, noise_dim=2, conv=conv,
                            out_dist=DiagonalGaussian)
        generators.append(cart_closed.Box(dec.name, dec.type.left,
                                          dec.type.right, dec,
                                          data={'effect': dec.effect}))

    # A prior/posterior ladder pair for each intermediate dimensionality.
    for dim in set(dims) - {glimpse_dim, data_dim}:
        space = types.tensor_type(torch.float, dim)
        prior = LadderPrior(dim, DiagonalGaussian)
        generators.append(cart_closed.Box(prior.name, Ty(), space, prior,
                                          data={'effect': prior.effect}))

    # Writer/reader pair implementing spatial attention.
    writer = SpatialTransformerWriter(data_side, glimpse_side)
    generators.append(cart_closed.Box(writer.name, writer.type.left,
                                      writer.type.right, writer,
                                      data={'effect': writer.effect}))

    # Gaussian likelihood over the full observation.
    likelihood = GaussianLikelihood(data_dim, 'X^{%d}' % data_dim)
    generators.append(cart_closed.Box(likelihood.name,
                                      likelihood.type.left,
                                      likelihood.type.right, likelihood,
                                      data={'effect': likelihood.effect}))

    super().__init__(generators, [], data_dim, guide_hidden_dim,
                     no_prior_dims=[glimpse_dim, data_dim])