def _get_model(**kwargs):
    """Build an SE3Transformer with a fixed set of default hyper-parameters.

    Defaults: ``num_layers=4``, ``fiber_in=Fiber.create(2, CHANNELS)``,
    ``fiber_hidden=Fiber.create(3, CHANNELS)``,
    ``fiber_out=Fiber.create(2, CHANNELS)``, an empty edge fiber
    (``Fiber({})``), ``num_heads=8`` and ``channels_div=2``.

    Args:
        **kwargs: Extra keyword arguments for ``SE3Transformer``. A key that
            matches one of the defaults above overrides it, instead of
            raising ``TypeError`` for a duplicate keyword argument.

    Returns:
        The constructed ``SE3Transformer`` instance.
    """
    params = dict(
        num_layers=4,
        fiber_in=Fiber.create(2, CHANNELS),
        fiber_hidden=Fiber.create(3, CHANNELS),
        fiber_out=Fiber.create(2, CHANNELS),
        fiber_edge=Fiber({}),
        num_heads=8,
        channels_div=2,
    )
    # Caller-supplied values take precedence over the defaults above.
    params.update(kwargs)
    return SE3Transformer(**params)
예제 #2
0
    def __init__(self, fiber_in: Fiber, fiber_out: Fiber, fiber_edge: Fiber,
                 num_degrees: int, num_channels: int, output_dim: int,
                 **kwargs):
        """Build an SE3Transformer backbone followed by a two-layer MLP head.

        Args:
            fiber_in: Input fiber passed to ``SE3Transformer``.
            fiber_out: Output fiber passed to ``SE3Transformer``; its
                ``num_features`` sets the input and hidden width of the MLP.
            fiber_edge: Edge fiber passed to ``SE3Transformer``.
            num_degrees: First argument to ``Fiber.create`` for the hidden
                fiber.
            num_channels: Second argument to ``Fiber.create`` for the hidden
                fiber.
            output_dim: Output width of the final ``nn.Linear`` layer.
            **kwargs: Forwarded to ``SE3Transformer``. ``pooling`` is set to
                ``'max'`` when it is absent or falsy (e.g. ``None``).
        """
        super().__init__()
        # Use .get so an omitted 'pooling' falls back to 'max' rather than
        # raising KeyError; an explicit falsy value also becomes 'max'.
        kwargs['pooling'] = kwargs.get('pooling') or 'max'
        self.transformer = SE3Transformer(
            fiber_in=fiber_in,
            fiber_hidden=Fiber.create(num_degrees, num_channels),
            fiber_out=fiber_out,
            fiber_edge=fiber_edge,
            # NOTE(review): presumably selects degree-0 (invariant) outputs
            # for the MLP head below — confirm against SE3Transformer.
            return_type=0,
            **kwargs
        )

        # Head: Linear(n, n) -> ReLU -> Linear(n, output_dim).
        n_out_features = fiber_out.num_features
        self.mlp = nn.Sequential(
            nn.Linear(n_out_features, n_out_features),
            nn.ReLU(),
            nn.Linear(n_out_features, output_dim),
        )