Example no. 1
0
    def __init__(self, rep_in, rep_out, group, ch=384, num_layers=3):
        """Build the PyTorch equivariant MLP.

        Args:
            rep_in: input representation; it is called with ``group`` to bind it.
            rep_out: output representation; it is called with ``group`` to bind it.
            group: the symmetry group, stored as ``self.G``.
            ch: hidden-layer specification. Accepted forms:
                - an int: ``uniform_rep(ch, group)``, repeated ``num_layers`` times.
                - a ``Rep``: ``ch(group)``, repeated ``num_layers`` times.
                - an iterable: one hidden layer per entry. A ``Rep`` entry is
                  bound with ``group``; any other entry is passed to
                  ``uniform_rep``. ``num_layers`` is ignored in this case.
            num_layers: number of hidden layers when ``ch`` is an int or a ``Rep``.

        The resulting ``self.network`` is an ``nn.Sequential`` made of one
        ``EMLPBlock`` per consecutive pair of representations, followed by a
        final ``Linear`` map into ``self.rep_out``.
        """
        super().__init__()
        logging.info("Initing EMLP (PyTorch)")
        self.rep_in = rep_in(group)
        self.rep_out = rep_out(group)
        self.G = group

        # Resolve the hidden representations. For the int / Rep forms, one
        # object is built and the same instance is shared by every layer.
        if isinstance(ch, int):
            hidden_reps = [uniform_rep(ch, group)] * num_layers
        elif isinstance(ch, Rep):
            hidden_reps = [ch(group)] * num_layers
        else:
            hidden_reps = []
            for spec in ch:
                if isinstance(spec, Rep):
                    hidden_reps.append(spec(group))
                else:
                    hidden_reps.append(uniform_rep(spec, group))

        chain = [self.rep_in, *hidden_reps]
        # One EMLPBlock per adjacent (input, output) pair along the chain.
        blocks = [
            EMLPBlock(src, dst) for src, dst in zip(chain[:-1], chain[1:])
        ]
        readout = Linear(chain[-1], self.rep_out)
        self.network = nn.Sequential(*blocks, readout)
Example no. 2
0
def EMLP(rep_in, rep_out, group, ch=384, num_layers=3):
    """Equivariant MultiLayer Perceptron (flax version).

    How ``ch`` is interpreted:
        - an int: the hidden representations come from the hands-off
          ``uniform_rep`` heuristic.
        - a representation: it is used for every hidden layer.
        - a list of ints and/or representations: each entry sets one hidden
          layer explicitly. ``num_layers`` is ignored in this case.

    Args:
        rep_in (Rep): input representation (bound to ``group`` by calling it)
        rep_out (Rep): output representation (bound to ``group`` by calling it)
        group (Group): symmetry group
        ch (int or list[int] or Rep or list[Rep]): number of channels in the hidden layers
        num_layers (int): number of hidden layers; only used when ``ch`` is an
            int or a single Rep

    Returns:
        Module: the EMLP flax module, a ``Sequential`` of ``EMLPBlock`` layers
        followed by a final ``Linear`` readout to ``rep_out``.
    """
    logging.info("Initing EMLP (flax)")
    rep_in = rep_in(group)
    rep_out = rep_out(group)
    if isinstance(ch, int):
        # A single uniform_rep instance is shared by every hidden layer.
        middle_layers = num_layers * [uniform_rep(ch, group)]
    elif isinstance(ch, Rep):
        middle_layers = num_layers * [ch(group)]
    else:
        # Explicit per-layer specification; the layer count follows len(ch).
        middle_layers = [
            (c(group) if isinstance(c, Rep) else uniform_rep(c, group))
            for c in ch
        ]
    reps = [rep_in] + middle_layers
    # Lazy %-formatting: the (possibly long) repr is only built when INFO is enabled.
    logging.info("Reps: %s", reps)
    return Sequential(
        *[EMLPBlock(rin, rout) for rin, rout in zip(reps, reps[1:])],
        Linear(reps[-1], rep_out))
Example no. 3
0
def test_large_representations(G):
    """Check equivariance of a projected weight matrix on a large representation of ``G``.

    Steps:
        1. Build the representation with ``uniform_rep(256, G)``.
        2. Draw a random dense matrix ``W`` and pass it through
           ``(rep >> rep).equivariant_projector()``.
        3. Over a batch of ``N`` sampled group elements, check that applying
           ``W`` after transforming the input ("Wgx") matches transforming
           the output of ``W`` ("gWx").
    """
    N = 5
    ch = 256
    rep = repin = repout = uniform_rep(ch, G)
    projector = (rep >> rep).equivariant_projector()

    # Random dense map, then flatten -> project -> restore its 2D shape.
    W = np.random.rand(repout.size(), repin.size())
    W = (projector @ W.reshape(-1)).reshape(*W.shape)

    x = np.random.rand(N, repin.size())
    gs = G.samples(N)
    rho_in = vmap(repin.rho_dense)(gs)
    rho_out = vmap(repout.rho_dense)(gs)

    # Path 1: act on the input with rho_in(g), then apply W.
    transformed_x = (rho_in @ x[..., None])[..., 0]
    Wgx = transformed_x @ W.T
    # Path 2: apply W, then act on the output with rho_out(g).
    Wx = x @ W.T
    gWx = (rho_out @ Wx[..., None])[..., 0]

    equiv_err = vmap(scale_adjusted_rel_error)(Wgx, gWx, gs).mean()
    assert equiv_err < 1e-4, f"Large Rep Equivariant gWx=Wgx fails err {equiv_err:.3e} with G={G}"
    logging.info(f"Success with G={G}")


# #print(dir(TestRepresentationSubspace))
# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--log", default="warning",help=("Logging Level Example --log debug', default='warning'"))
#     options,unknown_args = parser.parse_known_args()#["--log"])
#     levels = {'critical': logging.CRITICAL,'error': logging.ERROR,'warn': logging.WARNING,'warning': logging.WARNING,
#         'info': logging.INFO,'debug': logging.DEBUG}
#     level = levels.get(options.log.lower())
#     logging.getLogger().setLevel(level)
#     unit_argv = [sys.argv[0]] + unknown_args
#     unittest.main(argv=unit_argv)