def __init__(self, rep_in, rep_out, group, ch=384, num_layers=3):
    """Equivariant MLP (PyTorch backend).

    Builds a stack of EMLPBlocks followed by a final equivariant Linear map.

    Args:
        rep_in (Rep): input representation (uninstantiated; called with group)
        rep_out (Rep): output representation
        group (Group): symmetry group
        ch (int | Rep | list[int] | list[Rep]): hidden-layer sizes or reps;
            ints are expanded via the hands-off ``uniform_rep`` heuristic
        num_layers (int): number of hidden layers (ignored when ch is a list)
    """
    super().__init__()
    logging.info("Initing EMLP (PyTorch)")
    self.rep_in = rep_in(group)
    self.rep_out = rep_out(group)
    self.G = group
    # Resolve the hidden-layer representations from whichever form `ch` takes:
    # a single int, a single Rep, or a sequence mixing ints and Reps.
    if isinstance(ch, int):
        hidden = num_layers * [uniform_rep(ch, group)]
    elif isinstance(ch, Rep):
        hidden = num_layers * [ch(group)]
    else:
        hidden = []
        for c in ch:
            hidden.append(c(group) if isinstance(c, Rep) else uniform_rep(c, group))
    all_reps = [self.rep_in] + hidden
    # One EMLPBlock per consecutive pair of reps, then a final equivariant linear layer.
    blocks = [EMLPBlock(rin, rout) for rin, rout in zip(all_reps, all_reps[1:])]
    self.network = nn.Sequential(*blocks, Linear(all_reps[-1], self.rep_out))
def EMLP(rep_in, rep_out, group, ch=384, num_layers=3):
    """Equivariant MultiLayer Perceptron (flax backend).

    When ``ch`` is an int, the hands-off ``uniform_rep`` heuristic chooses the
    hidden representation; when it is a Rep, that representation is reused for
    every hidden layer.  Per-layer control is available by passing a list of
    ints or a list of representations instead.

    Args:
        rep_in (Rep): input representation
        rep_out (Rep): output representation
        group (Group): symmetry group
        ch (int | Rep | list[int] | list[Rep]): hidden-layer channels or reps
        num_layers (int): number of hidden layers (ignored when ch is a list)

    Returns:
        Module: the EMLP flax module.
    """
    logging.info("Initing EMLP (flax)")
    rep_in = rep_in(group)
    rep_out = rep_out(group)
    # Resolve the hidden-layer representations from whichever form `ch` takes.
    if isinstance(ch, int):
        hidden = num_layers * [uniform_rep(ch, group)]
    elif isinstance(ch, Rep):
        hidden = num_layers * [ch(group)]
    else:
        hidden = [c(group) if isinstance(c, Rep) else uniform_rep(c, group) for c in ch]
    reps = [rep_in] + hidden
    logging.info(f"Reps: {reps}")
    # One EMLPBlock per consecutive pair of reps, then a final equivariant linear layer.
    hidden_blocks = [EMLPBlock(rin, rout) for rin, rout in zip(reps, reps[1:])]
    return Sequential(*hidden_blocks, Linear(reps[-1], rep_out))
def test_large_representations(G):
    """Check rho_out(g) @ W @ x == W @ rho_in(g) @ x for a large uniform rep on G."""
    n_samples, channels = 5, 256
    rep = repin = repout = uniform_rep(channels, G)
    # Project a random matrix onto the equivariant subspace of rep -> rep maps.
    rep_W = rep >> rep
    projector = rep_W.equivariant_projector()
    W = np.random.rand(repout.size(), repin.size())
    W = (projector @ W.reshape(-1)).reshape(*W.shape)
    # Random batch of inputs and group samples.
    x = np.random.rand(n_samples, repin.size())
    gs = G.samples(n_samples)
    rho_in = vmap(repin.rho_dense)(gs)
    rho_out = vmap(repout.rho_dense)(gs)
    # Act first then map, versus map first then act.
    gx = (rho_in @ x[..., None])[..., 0]
    Wgx = gx @ W.T
    gWx = (rho_out @ (x @ W.T)[..., None])[..., 0]
    equiv_err = vmap(scale_adjusted_rel_error)(Wgx, gWx, gs).mean()
    assert equiv_err < 1e-4, f"Large Rep Equivariant gWx=Wgx fails err {equiv_err:.3e} with G={G}"
    logging.info(f"Success with G={G}")

# #print(dir(TestRepresentationSubspace))
# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--log", default="warning",help=("Logging Level Example --log debug', default='warning'"))
#     options,unknown_args = parser.parse_known_args()#["--log"])
#     levels = {'critical': logging.CRITICAL,'error': logging.ERROR,'warn': logging.WARNING,'warning': logging.WARNING,
#               'info': logging.INFO,'debug': logging.DEBUG}
#     level = levels.get(options.log.lower())
#     logging.getLogger().setLevel(level)
#     unit_argv = [sys.argv[0]] + unknown_args
#     unittest.main(argv=unit_argv)