Example #1
0
 def type(self):
     """Arrow type of this morphism: the tensor product of the input
     spaces mapped into the tensor product of the output spaces."""
     def tensor_all(tys):
         # Fold a sequence of types into their tensor product, starting
         # from the empty (unit) type.
         space = Ty()
         for ty in tys:
             space = space @ ty
         return space

     dom = tensor_all(types.tensor_type(torch.float, d)
                      for d in self._in_dims)
     cod = tensor_all(types.tensor_type(torch.float, d)
                      for d in self._out_dims)
     return dom >> cod
Example #2
0
def cat2ty(string):
    """Convert the string representation of a CCG category into a
    :class:`discopy.biclosed.Ty`."""
    def strip_parens(s):
        # Drop one pair of surrounding brackets, if present.
        return s[1:-1] if s[0] == '(' else s

    def drop_features(s):
        # Remove CCG feature annotations such as "[dcl]".
        return re.sub(r'\[[^]]*\]', '', s)

    def top_level_split(s):
        # Locate the first slash operator not nested inside parentheses.
        depth = 0
        for position, symbol in enumerate(s):
            if symbol == '(':
                depth += 1
            elif symbol == ')':
                depth -= 1
            elif depth == 0 and symbol in ('\\', '/'):
                return (strip_parens(s[:position]), symbol,
                        strip_parens(s[position + 1:]))
        # Atomic category: no top-level slash at all.
        return drop_features(s), None, None

    head, connective, tail = top_level_split(string)
    if connective == '\\':
        return cat2ty(tail) >> cat2ty(head)
    if connective == '/':
        return cat2ty(head) << cat2ty(tail)
    return Ty(head)
Example #3
0
def substitute(t, sub):
    """Apply the substitution map *sub* to every type variable in *t*."""
    if isinstance(t, Under):
        # Arrow type: substitute on both sides, then rebuild the arrow.
        left = substitute(t.left, sub)
        right = substitute(t.right, sub)
        return left >> right
    if t.objects:
        # Product type: substitute component-wise.
        parts = [substitute(component, sub) for component in t.objects]
        return Ty(*parts)
    if isinstance(t, TyVar) and t.name in sub:
        # Bound variable: replace by its image under the substitution.
        return sub[t.name]
    return t
Example #4
0
def base_elements(ty):
    """Collect the atomic (non-arrow) objects occurring anywhere in *ty*."""
    if not isinstance(ty, Ty):
        # Bare object: wrap it so the result is iterable like the rest.
        return Ty(ty)
    if isinstance(ty, Under):
        # Arrow type: atoms of both sides.
        return base_elements(ty.left) | base_elements(ty.right)
    atoms = set()
    nested = set()
    for ob in ty.objects:
        if not isinstance(ob, Under):
            atoms.add(ob)
        # .union accepts arbitrary iterables, matching the original's
        # set().union(...) over possibly non-set recursive results.
        nested = nested.union(base_elements(ob))
    return atoms | nested
Example #5
0
 def sample_morphism(self, obj, probs, temperature, min_depth=2, infer=None):
     """Sample a morphism whose codomain is *obj*.

     :param obj: target object, or an arrow type to sample between
     :param probs: node-indexed reachability probabilities
     :param temperature: softmax temperature for generator selection
     :param min_depth: minimum number of generators along the path
     :param infer: optional Pyro inference configuration dict

     :returns: the sampled morphism
     """
     # Fix for the mutable-default-argument pitfall: the old default
     # `infer={}` was a single dict shared by every call.
     if infer is None:
         infer = {}
     with name_count():
         if obj in self._graph.nodes:
             # obj is a plain object in the graph: sample a path into it
             # from the monoidal unit.
             return self.path_between(Ty(), obj, probs, temperature,
                                      min_depth, infer)
         # obj is an arrow type: split into domain product and codomain,
         # then sample a path between them.
         entries = unification.unfold_arrow(obj)
         src, dest = unification.fold_product(entries[:-1]), entries[-1]
         return self.path_between(src, dest, probs, temperature, min_depth,
                                  infer)
Example #6
0
def tree2diagram(tree, dom=Ty()):
    """
    Takes a depccg.Tree in JSON format,
    returns a :class:`discopy.biclosed.Diagram`.
    """
    if 'word' in tree:
        # Leaf: a single word typed by its CCG category.
        return Word(tree['word'], cat2ty(tree['cat']), dom=dom)
    # Internal node: translate the children first, then combine them.
    children = [tree2diagram(child) for child in tree['children']]
    dom = Ty().tensor(*(child.cod for child in children))
    cod = cat2ty(tree['cat'])
    rule = tree['type']
    if rule == 'ba':
        box = BA(dom[1:])
    elif rule == 'fa':
        box = FA(dom[:1])
    elif rule == 'fc':
        box = FC(dom[:1], dom[1:])
    else:
        # Unknown combinator: fall back to an opaque box.
        box = Box(rule, dom, cod)
    return Id(Ty()).tensor(*children) >> box
Example #7
0
 def product_arrow(self, obj, probs, temperature, min_depth=0, infer=None):
     """Sample a morphism into the product type *obj*, one factor at a time.

     Each factor of the product is sampled independently (at an increased
     temperature) and the results are tensored together.

     :param obj: product type whose factors are sampled individually
     :param probs: node-indexed reachability probabilities
     :param temperature: softmax temperature (each factor uses +1)
     :param min_depth: minimum path length per factor
     :param infer: optional Pyro inference configuration dict

     :returns: the tensor product of the sampled factor morphisms
     """
     # Fix for the mutable-default-argument pitfall: the old default
     # `infer={}` was a single dict shared by every call.
     if infer is None:
         infer = {}
     product = None
     for ob in obj.objects:
         entry = self.sample_morphism(Ty(ob), probs, temperature + 1,
                                      min_depth, infer)
         product = entry if product is None else product @ entry
     return product
Example #8
0
 def _add_object(self, obj):
     """Register *obj* — and, recursively, its components — as graph nodes."""
     if obj in self._graph:
         # Already present; nothing to do.
         return
     if unification.type_compound(obj):
         if len(obj) > 1:
             # Product type: register each factor on its own first.
             for factor in obj:
                 self._add_object(Ty(factor))
         else:
             # Arrow type: register domain and codomain first.
             self._add_object(obj.left)
             self._add_object(obj.right)
     self._graph.add_node(obj, index=len(self._graph))
Example #9
0
    def __init__(self, data_dim=28 * 28, hidden_dim=64, guide_hidden_dim=256):
        """Assemble the generators for a ladder VLAE over powers-of-two
        dimensionalities between *hidden_dim* and *data_dim*."""
        self._data_dim = data_dim

        # Dimensionalities: powers of two from hidden_dim up to a quarter
        # of the data dimension, plus the data dimension itself, ascending.
        dims = sorted(list(util.powers_of(2, hidden_dim, data_dim // 4))
                      + [data_dim])

        def gaussian_likelihood(dim):
            # Diagonal-Gaussian likelihood head for a given dimensionality.
            return DiagonalGaussian(dim, latent_name='X^{%d}' % dim,
                                    likelihood=True)

        generators = []
        for lower, higher in zip(dims, dims[1:]):
            # One ladder decoder per adjacent pair of dimensionalities; the
            # final stage (into data space) is convolutional and carries
            # the Gaussian likelihood.
            if higher == self._data_dim:
                decoder = LadderDecoder(lower, higher, noise_dim=2,
                                        conv=True,
                                        out_dist=gaussian_likelihood)
            else:
                decoder = LadderDecoder(lower, higher, noise_dim=2,
                                        conv=False, out_dist=None)
            generators.append(
                cart_closed.Box(decoder.name,
                                decoder.type.left,
                                decoder.type.right,
                                decoder,
                                data={'effect': decoder.effect}))

        # A ladder prior for every latent (non-data) dimensionality.
        for dim in set(dims) - {data_dim}:
            space = types.tensor_type(torch.float, dim)
            prior = LadderPrior(dim, None)
            generators.append(
                cart_closed.Box(prior.name,
                                Ty(),
                                space,
                                prior,
                                data={'effect': prior.effect}))

        super().__init__(generators, [], data_dim, guide_hidden_dim,
                         list(set(dims) - {data_dim}))
Example #10
0
def try_unify(a, b, subst={}):
    """
    Attempt to unify the types *a* and *b*.

    Returns a pair of the unified type and a substitution dict mapping
    variable names to types; raises :class:`UnificationException` when
    the types cannot be unified.

    NOTE(review): the *subst* parameter is overwritten before every use,
    so neither the caller-supplied value nor the mutable default ever
    affects the result — confirm whether it was meant to seed the
    unification.
    """
    # Arrow types unify side-by-side, merging the two substitutions.
    if isinstance(a, Under) and isinstance(b, Under):
        l, lsub = try_unify(a.left, b.left)
        r, rsub = try_unify(a.right, b.right)
        subst = try_merge_substitution(lsub, rsub)
        return l >> r, subst
    # Identical types unify trivially, producing no bindings.
    if a == b:
        return a, {}
    # A type variable unifies with anything, binding its name.
    if isinstance(a, TyVar):
        return b, {a.name: b}
    if isinstance(b, TyVar):
        return a, {b.name: a}
    # Non-empty product types unify element-wise; the per-element
    # substitutions are folded together (and must be mutually consistent).
    if a.objects and b.objects:
        results = [try_unify(ak, bk) for ak, bk in zip(a.objects, b.objects)]
        ty = Ty(*[ty for ty, _ in results])
        subst = functools.reduce(try_merge_substitution,
                                 [subst for _, subst in results])
        return ty, subst
    raise UnificationException(a, b)
Example #11
0
    def path_between(self,
                     src,
                     dest,
                     probs,
                     temperature,
                     min_depth=0,
                     infer=None):
        """Sample a path of generators from *src* to *dest* in the graph.

        Starting at *src*, repeatedly sample one outgoing generator,
        weighting candidates by their probability of reaching *dest*
        (tempered by *temperature*), until *dest* is reached.

        :param src: starting object
        :param dest: destination object; must not be the unit type
        :param probs: node-indexed reachability probabilities
        :param temperature: softmax temperature for generator selection
        :param min_depth: minimum path length before entering *dest*
        :param infer: optional Pyro inference configuration dict

        :returns: the composite morphism from *src* to *dest*
        """
        assert dest != Ty()
        # Fix for the mutable-default-argument pitfall: the old default
        # `infer={}` was one dict shared by every call.
        if infer is None:
            infer = {}

        location = src
        path = Id(src)
        dest_index = self._graph.nodes[dest]['index']
        with pyro.markov():
            while location != dest:
                generators = self._object_generators(location, True)
                # Exclude direct hops into dest until min_depth is reached.
                if len(path) + 1 < min_depth:
                    generators = [(g, cod) for (g, cod) in generators
                                  if cod != dest]
                gens = [self._graph.nodes[g]['index'] for (g, _) in generators]

                # Restrict to generators with nonzero probability of
                # reaching dest; temper the selection distribution.
                dest_probs = probs[gens][:, dest_index]
                viables = dest_probs.nonzero(as_tuple=True)[0]
                selection_probs = F.softmax(dest_probs[viables].log() /
                                            (temperature + 1e-10),
                                            dim=-1)
                generators_categorical = dist.Categorical(selection_probs)
                g_idx = pyro.sample('path_step_{%s -> %s}' % (location, dest),
                                    generators_categorical.to_event(0),
                                    infer=infer)

                gen, cod = generators[viables[g_idx.item()]]
                if isinstance(gen, callable.CallableBox):
                    morphism = gen
                else:
                    # Macro arrow: a closure that samples a sub-path or
                    # product, consuming the remaining depth budget.
                    morphism = gen(probs, temperature,
                                   min_depth - len(path) - 1, infer)
                path = path >> morphism
                location = cod

        return path
Example #12
0
 def type(self):
     """Arrow type from the monoidal unit into this morphism's space."""
     cod = types.tensor_type(torch.float, self._dim)
     return Ty() >> cod
Example #13
0
def unfold_product(ty):
    """Split a product type into the list of its factors.

    Arrow types are treated as atomic and returned as a singleton list.
    """
    if isinstance(ty, Under):
        return [ty]
    return list(map(Ty, ty.objects))
Example #14
0
def fold_product(ts):
    """Fold a list of types into a single type.

    A singleton list folds to its sole element; anything else becomes a
    product ``Ty`` over all the entries.
    """
    is_single = len(ts) == 1
    return ts[0] if is_single else Ty(*ts)
Example #15
0
    def __init__(self,
                 generators,
                 global_elements=None,
                 data_space=(784, ),
                 guide_hidden_dim=256,
                 no_prior_dims=None):
        """Free-category generative model over the given generators.

        :param generators: boxes generating the category
        :param global_elements: optional global elements (points); latent
            spaces not covered by one receive an automatic standard-normal
            prior
        :param data_space: shape (or flat size) of the observation space
        :param guide_hidden_dim: hidden width of the amortized guides
        :param no_prior_dims: latent dimensionalities that should NOT get
            an automatic prior
        """
        super().__init__()
        # Bug fix: the old mutable defaults (global_elements=[],
        # no_prior_dims=[]) were shared across instances, and
        # global_elements is appended to below — priors leaked from one
        # construction into the next (and into callers' lists). Work on
        # fresh copies instead.
        global_elements = list(global_elements) if global_elements else []
        no_prior_dims = list(no_prior_dims) if no_prior_dims else []
        if isinstance(data_space, int):
            data_space = (data_space, )
        self._data_space = data_space
        self._data_dim = math.prod(data_space)
        if len(self._data_space) == 1:
            self._observation_name = '$X^{%d}$' % self._data_dim
        else:
            self._observation_name = '$X^{%s}$' % str(self._data_space)

        # Base objects appearing in some generator's type but not already
        # covered by a global element.
        obs = set()
        for generator in generators:
            ty = generator.dom >> generator.cod
            obs = obs | unification.base_elements(ty)
        for element in global_elements:
            ty = element.dom >> element.cod
            obs = obs - unification.base_elements(ty)

        # Give each remaining latent space a standard-normal prior, except
        # for excluded dimensionalities and the data space itself.
        no_prior_dims = no_prior_dims + [self._data_dim]
        for ob in obs:
            dim = types.type_size(str(ob))
            if dim in no_prior_dims:
                continue

            space = types.tensor_type(torch.float, dim)
            prior = StandardNormal(dim)
            name = '$p(%s)$' % prior.effects
            effect = {'effect': prior.effect, 'dagger_effect': []}
            global_element = cart_closed.Box(name,
                                             Ty(),
                                             space,
                                             prior,
                                             data=effect)
            global_elements.append(global_element)

        self._category = freecat.FreeCategory(generators, global_elements)

        # Amortized guide for the sampling temperature (two Softplus
        # outputs: concentration and rate).
        self.guide_temperatures = nn.Sequential(
            nn.Linear(self._data_dim, guide_hidden_dim),
            nn.LayerNorm(guide_hidden_dim),
            nn.PReLU(),
            nn.Linear(guide_hidden_dim, 1 * 2),
            nn.Softplus(),
        )
        # Amortized guide for the arrow-weight distribution parameters.
        self.guide_arrow_weights = nn.Sequential(
            nn.Linear(self._data_dim, guide_hidden_dim),
            nn.LayerNorm(guide_hidden_dim),
            nn.PReLU(),
            nn.Linear(guide_hidden_dim,
                      self._category.arrow_weight_loc.shape[0] * 2),
        )

        self._random_variable_names = collections.defaultdict(int)

        # One encoder per generator dagger, wired through a functor that
        # doubles each latent with the data space.
        self.encoders = nn.ModuleDict()
        self.encoder_functor = wiring.Functor(
            lambda ty: util.double_latent(ty, self.data_space),
            lambda ar: self._encoder(ar.name),
            ob_factory=Ty,
            ar_factory=cart_closed.Box)

        for arrow in self._category.ars:
            effect = arrow.data['effect']

            cod_dims = util.double_latents(
                [types.type_size(ob.name) for ob in arrow.cod], self._data_dim)
            dom_dims = util.double_latents(
                [types.type_size(ob.name) for ob in arrow.dom], self._data_dim)
            self.encoders[arrow.name + '†'] = build_encoder(
                cod_dims, dom_dims, effect)
Example #16
0
    def __init__(self, data_dim=28 * 28, hidden_dim=64, guide_hidden_dim=256):
        """Assemble the generators for a spatial-attention VLAE.

        Combines glimpse-scale density decoders, data-scale ladder
        decoders, per-dimension priors, a spatial-transformer writer and a
        Gaussian likelihood into one generator list handed to the
        superclass.
        """
        self._data_dim = data_dim
        # The glimpse is a square patch with half the side of the (square)
        # data image, i.e. a quarter of its area.
        data_side = int(math.sqrt(self._data_dim))
        glimpse_side = data_side // 2
        glimpse_dim = glimpse_side**2

        generators = []

        # Build up a bunch of torch.Sizes for the powers of two between
        # hidden_dim and glimpse_dim.
        dims = list(util.powers_of(2, hidden_dim, glimpse_dim // 4)) +\
               [glimpse_dim]
        dims.sort()

        # Density decoders between every ordered pair of glimpse-scale
        # dimensionalities; the stage reaching glimpse_dim is convolutional.
        for dim_a, dim_b in itertools.combinations(dims, 2):
            lower, higher = sorted([dim_a, dim_b])
            # Construct the decoder and encoder
            if higher == glimpse_dim:
                decoder = DensityDecoder(lower,
                                         higher,
                                         DiagonalGaussian,
                                         convolve=True)
            else:
                decoder = DensityDecoder(lower, higher, DiagonalGaussian)
            data = {'effect': decoder.effect}
            generator = cart_closed.Box(decoder.name,
                                        decoder.type.left,
                                        decoder.type.right,
                                        decoder,
                                        data=data)
            generators.append(generator)

        # Build up a bunch of torch.Sizes for the powers of two between
        # hidden_dim and data_dim.
        dims = dims + [data_dim]
        dims.sort()

        # Ladder decoders between adjacent dimensionalities; the final
        # stage (into data space) is convolutional.
        for lower, higher in zip(dims, dims[1:]):
            # Construct the VLAE decoder and encoder
            if higher == self._data_dim:
                decoder = LadderDecoder(lower,
                                        higher,
                                        noise_dim=2,
                                        conv=True,
                                        out_dist=DiagonalGaussian)
            else:
                decoder = LadderDecoder(lower,
                                        higher,
                                        noise_dim=2,
                                        conv=False,
                                        out_dist=DiagonalGaussian)
            data = {'effect': decoder.effect}
            generator = cart_closed.Box(decoder.name,
                                        decoder.type.left,
                                        decoder.type.right,
                                        decoder,
                                        data=data)
            generators.append(generator)

        # For each dimensionality, construct a prior/posterior ladder pair
        for dim in set(dims) - {glimpse_dim, data_dim}:
            space = types.tensor_type(torch.float, dim)
            prior = LadderPrior(dim, DiagonalGaussian)

            data = {'effect': prior.effect}
            generator = cart_closed.Box(prior.name,
                                        Ty(),
                                        space,
                                        prior,
                                        data=data)
            generators.append(generator)

        # Construct writer/reader pair for spatial attention
        writer = SpatialTransformerWriter(data_side, glimpse_side)
        writer_l, writer_r = writer.type.left, writer.type.right

        data = {'effect': writer.effect}
        generator = cart_closed.Box(writer.name,
                                    writer_l,
                                    writer_r,
                                    writer,
                                    data=data)
        generators.append(generator)

        # Construct the likelihood
        likelihood = GaussianLikelihood(data_dim, 'X^{%d}' % data_dim)
        data = {'effect': likelihood.effect}
        generator = cart_closed.Box(likelihood.name,
                                    likelihood.type.left,
                                    likelihood.type.right,
                                    likelihood,
                                    data=data)
        generators.append(generator)

        # Glimpse and data dimensionalities already have their own
        # decoders/likelihood, so no automatic priors for them.
        super().__init__(generators, [],
                         data_dim,
                         guide_hidden_dim,
                         no_prior_dims=[glimpse_dim, data_dim])
Example #17
0
    def __init__(self, generators, global_elements):
        """Build the free-category graph over generators and global elements.

        Nodes are objects (types) and arrows; edges run dom -> arrow -> cod.
        Plain global elements are wrapped into dagger boxes; compound
        objects get "macro" arrows that sample paths or products on demand.
        Finally, Beta parameters for arrow weights/temperature and a
        diffusion-count buffer are registered.
        """
        super().__init__()
        self._graph = nx.DiGraph()
        self._add_object(Ty())
        for i, gen in enumerate(generators):
            assert isinstance(gen, callable.CallableBox)

            # Wire dom -> gen -> cod into the graph, adding any missing
            # object nodes first.
            if gen.dom not in self._graph:
                self._add_object(gen.dom)
            self._graph.add_node(gen, index=len(self._graph), arrow_index=i)
            if gen.cod not in self._graph:
                self._add_object(gen.cod)
            self._graph.add_edge(gen.dom, gen)
            self._graph.add_edge(gen, gen.cod)

            # Register parametric generators (and their daggers) so their
            # parameters are tracked by this module.
            if isinstance(gen.function, nn.Module):
                self.add_module('generator_%d' % i, gen.function)
            if isinstance(gen, callable.CallableDaggerBox):
                dagger = gen.dagger()
                if isinstance(dagger.function, nn.Module):
                    self.add_module('generator_%d_dagger' % i, dagger.function)

        for i, obj in enumerate(self.obs):
            self._graph.nodes[obj]['object_index'] = i

        for i, elem in enumerate(global_elements):
            assert isinstance(elem, callable.CallableBox)
            assert elem.dom == Ty()
            # Plain boxes are wrapped into dagger boxes with a trivial
            # (empty-tuple) dagger so every element is daggerable.
            if not isinstance(elem, callable.CallableDaggerBox):
                dagger_name = '%s$^{\\dagger}$' % elem.name
                elem = callable.CallableDaggerBox(elem.name, elem.dom,
                                                  elem.cod, elem.function,
                                                  lambda *args: (),
                                                  dagger_name)

            self._graph.add_node(elem,
                                 index=len(self._graph),
                                 arrow_index=len(generators) + i)
            self._graph.add_edge(Ty(), elem)
            self._graph.add_edge(elem, elem.cod)

            if isinstance(elem.function, nn.Module):
                self.add_module('global_element_%d' % i, elem.function)
            dagger = elem.dagger()
            if isinstance(dagger.function, nn.Module):
                self.add_module('global_element_%d_dagger' % i,
                                dagger.function)

        # Compound objects get "macro" arrows: closures that sample either
        # a path (arrow types) or a product of morphisms (product types).
        for i, obj in enumerate(self.compound_obs):
            if isinstance(obj, Under):
                src, dest = obj.left, obj.right

                # Default args bind src/dest now, avoiding the late-binding
                # closure pitfall inside this loop.
                def macro(probs, temp, min_depth, infer, l=src, r=dest):
                    return self.path_between(l, r, probs, temp, min_depth,
                                             infer)
            else:

                def macro(probs, temp, min_depth, infer, obj=obj):
                    return self.product_arrow(obj, probs, temp, min_depth,
                                              infer)

            arrow_index = len(generators) + len(global_elements) + i
            self._graph.add_node(macro,
                                 index=len(self._graph),
                                 arrow_index=arrow_index)
            self._graph.add_edge(Ty(), macro)
            self._graph.add_edge(macro, obj)

        # Beta-distribution parameters for arrow weights and temperature.
        self.arrow_weight_alphas = pnn.PyroParam(
            torch.ones(len(self.ars) + len(self.macros)),
            constraint=constraints.positive)
        self.arrow_weight_betas = pnn.PyroParam(
            torch.ones(len(self.ars) + len(self.macros)),
            constraint=constraints.positive)
        self.temperature_alpha = pnn.PyroParam(torch.ones(1),
                                               constraint=constraints.positive)
        self.temperature_beta = pnn.PyroParam(torch.ones(1),
                                              constraint=constraints.positive)

        # Diffusion counts: matrix exponential of the adjacency matrix with
        # each arrow's row discounted by its parameter count.
        # NOTE(review): nx.to_numpy_matrix is deprecated (removed in
        # NetworkX 3.0) — consider nx.to_numpy_array; confirm downstream
        # indexing/expm still behaves identically before switching.
        adjacency_weights = nx.to_numpy_matrix(self._graph)
        for arrow in self.ars:
            i = self._graph.nodes[arrow]['index']
            adjacency_weights[i] /= self._arrow_parameters(arrow) + 1
        self.register_buffer('diffusion_counts',
                             torch.from_numpy(
                                 scipy.linalg.expm(adjacency_weights)),
                             persistent=False)
Example #18
0
def unique_ty():
    """Create a fresh atomic type carrying a globally unique name."""
    name = unique_identifier()
    return Ty(name)
Example #19
0
 def wiring_diagram(self):
     """Anonymous wiring box from the unit into the data space, carrying
     a trivially-true effect predicate."""
     effect = {'effect': lambda e: True}
     return wiring.Box('', Ty(), self.data_space, data=effect)
Example #20
0
def tensor_type(dtype, size):
    """Return the type of tensors with entries of *dtype* and shape *size*.

    *size* may be an int or a tuple; a 1-tuple is treated as the plain
    integer it wraps.
    """
    label = _label_dtype(dtype)
    if isinstance(size, tuple) and len(size) > 1:
        # Multi-axis shapes keep their full tuple repr in the label.
        return Ty('$%s^{%s}$' % (label, str(size)))
    if isinstance(size, tuple):
        size = size[0]
    return Ty('$%s^{%d}$' % (label, size))