def __repr__(self):
    """Return ``ClassName (Rs_1 x Rs_2 -> Rs_out)``.

    Each representation attribute is rendered with ``rs.format_Rs``.
    """
    template = "{name} ({Rs_1} x {Rs_2} -> {Rs_out})"
    fields = {"name": self.__class__.__name__}
    fields.update(
        (attr, rs.format_Rs(getattr(self, attr)))
        for attr in ("Rs_1", "Rs_2", "Rs_out")
    )
    return template.format(**fields)
def __init__(self, num_classes):
    """Build three gated equivariant blocks, a scalar-only block, and a linear head.

    Every block is a ``torch.nn.ModuleList([Convolution, GatedBlockParity])``.
    The convolution maps the current representation ``Rs`` into the gate's
    ``Rs_in``. Only the irreps (multiplicity 7, l in {0, 1}, both parities)
    for which ``haspath(Rs, l, p)`` holds are kept at each step. Even scalars
    use ``relu``, odd scalars use ``tanh``, and the gates use ``sigmoid``.
    The representation change of each of the three gated layers is printed.

    Args:
        num_classes: output size of the final ``torch.nn.Linear``.
    """
    super().__init__()

    radial = partial(CosineBasisModel, max_radius=3.0, number_of_basis=3, h=100, L=3, act=relu)
    make_kernel = partial(Kernel, RadialModel=radial)
    mul = 7  # multiplicity shared by every irrep channel below

    def conv_then_act(Rs_from, gate):
        # Pair a convolution into the (already built) gate's input with that gate;
        # the gate is constructed before the kernel, as in the original ordering.
        conv = Convolution(make_kernel(Rs_from, gate.Rs_in))
        return torch.nn.ModuleList([conv, gate])

    Rs = [(1, 0, +1)]  # entries are (multiplicity, l, parity)
    blocks = []
    for i in range(3):
        scalars = [(mul, 0, p) for p in (+1, -1) if haspath(Rs, 0, p)]
        act_scalars = [(m, relu if p == 1 else tanh) for m, _, p in scalars]
        nonscalars = [(mul, 1, p) for p in (+1, -1) if haspath(Rs, 1, p)]
        # One even scalar gate per non-scalar channel.
        gates = [(sum(m for m, _, _ in nonscalars), 0, +1)]
        act_gates = [(-1, sigmoid)]
        print("layer {}: from {} to {}".format(i, rs.format_Rs(Rs), rs.format_Rs(scalars + nonscalars)))
        gate = GatedBlockParity(scalars, act_scalars, gates, act_gates, nonscalars)
        blocks.append(conv_then_act(Rs, gate))
        Rs = gate.Rs_out

    # Final equivariant block: only even and odd scalars, no gated non-scalars.
    scalar_head = GatedBlockParity([(mul, 0, +1), (mul, 0, -1)], [(mul, relu), (mul, tanh)], [], [], [])
    blocks.append(conv_then_act(Rs, scalar_head))
    self.firstlayers = torch.nn.ModuleList(blocks)

    # The last layer is not equivariant, so it may mix the even and odd scalars
    # (2 * mul features in total).
    self.lastlayers = torch.nn.Sequential(AvgSpacial(), torch.nn.Linear(2 * mul, num_classes))
def __repr__(self):
    """Return ``ClassName (Rs_in1 x Rs_in2 -> Rs_out using N paths)``.

    ``N`` is ``self.nweight``. The representations are rendered with ``rs.format_Rs``.
    """
    template = "{name} ({Rs_in1} x {Rs_in2} -> {Rs_out} using {nw} paths)"
    rendered = {
        attr: rs.format_Rs(getattr(self, attr))
        for attr in ("Rs_in1", "Rs_in2", "Rs_out")
    }
    return template.format(name=self.__class__.__name__, nw=self.nweight, **rendered)
def __repr__(self):
    """Return ``ClassName (Rs_scalars + Rs_gates + Rs_nonscalars -> Rs_out)``.

    Each representation attribute is rendered with ``rs.format_Rs``.
    """
    template = "{name} ({Rs_scalars} + {Rs_gates} + {Rs_nonscalars} -> {Rs_out})"
    parts = {}
    for attr in ("Rs_scalars", "Rs_gates", "Rs_nonscalars", "Rs_out"):
        parts[attr] = rs.format_Rs(getattr(self, attr))
    return template.format(name=self.__class__.__name__, **parts)
def __repr__(self):
    """Return ``ClassName (Rs_in ^ 2 -> Rs_out)``.

    Both representations are rendered with ``rs.format_Rs``.
    """
    template = "{name} ({Rs_in} ^ 2 -> {Rs_out})"
    class_name = self.__class__.__name__
    rs_in_text = rs.format_Rs(self.Rs_in)
    rs_out_text = rs.format_Rs(self.Rs_out)
    return template.format(name=class_name, Rs_in=rs_in_text, Rs_out=rs_out_text)
def test_format():
    """Check that ``rs.format_Rs`` renders ``[]`` as "" and ``[2]`` as "2"."""
    cases = [
        ([], ""),
        ([2], "2"),
    ]
    for Rs, expected in cases:
        assert rs.format_Rs(Rs) == expected