def _get_model(**kwargs):
    return SE3Transformer(num_layers=4,
                          fiber_in=Fiber.create(2, CHANNELS),
                          fiber_hidden=Fiber.create(3, CHANNELS),
                          fiber_out=Fiber.create(2, CHANNELS),
                          fiber_edge=Fiber({}),
                          num_heads=8,
                          channels_div=2,
                          **kwargs)
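
A minimal usage sketch for the helper above. It assumes CHANNELS is defined at module scope (as the snippet requires) and that SE3Transformer and Fiber come from NVIDIA's se3_transformer package; exact import paths may vary by version:

from se3_transformer.model import SE3Transformer
from se3_transformer.model.fiber import Fiber

CHANNELS = 32  # hypothetical value; _get_model reads this from module scope
model = _get_model(return_type=0, pooling='max')  # extra kwargs are forwarded to SE3Transformer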
Example #2
    def __init__(self,
                 fiber_in: Fiber,
                 fiber_out: Fiber,
                 fiber_edge: Optional[Fiber] = None,
                 num_heads: int = 4,
                 channels_div: int = 2,
                 use_layer_norm: bool = False,
                 max_degree: int = 4,
                 fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
                 low_memory: bool = False,
                 **kwargs):
        """
        :param fiber_in:         Fiber describing the input features
        :param fiber_out:        Fiber describing the output features
        :param fiber_edge:       Fiber describing the edge features (node distances excluded)
        :param num_heads:        Number of attention heads
        :param channels_div:     Divide the channels by this integer for computing values
        :param use_layer_norm:   Apply layer normalization between MLP layers
        :param max_degree:       Maximum degree used in the bases computation
        :param fuse_level:       Maximum fuse level to use in TFN convolutions
        """
        super().__init__()
        if fiber_edge is None:
            fiber_edge = Fiber({})
        self.fiber_in = fiber_in
        # value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
        value_fiber = Fiber([(degree, channels // channels_div)
                             for degree, channels in fiber_out])
        # key_query_fiber has the same structure as fiber_out, but keeps only the degrees
        # that also appear in fiber_in (queries are merely projected, so degrees must match the input)
        key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber
                                 if fe.degree in fiber_in.degrees])
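        # Worked example (hypothetical sizes): with fiber_out = {0: 32, 1: 32} and
        # channels_div = 2, value_fiber is {0: 16, 1: 16}; if fiber_in only holds
        # degree 0, key_query_fiber keeps just {0: 16}.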

        self.to_key_value = ConvSE3(fiber_in,
                                    value_fiber + key_query_fiber,
                                    pool=False,
                                    fiber_edge=fiber_edge,
                                    use_layer_norm=use_layer_norm,
                                    max_degree=max_degree,
                                    fuse_level=fuse_level,
                                    allow_fused_output=True,
                                    low_memory=low_memory)
        self.to_query = LinearSE3(fiber_in, key_query_fiber)
        self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
        self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
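
A hedged construction sketch for the block above (fiber sizes are illustrative, not from the source):

block = AttentionBlockSE3(fiber_in=Fiber({0: 32}),
                          fiber_out=Fiber({0: 32, 1: 32}),
                          num_heads=4,
                          channels_div=2)
# forward() (not shown in this excerpt) takes a DGL graph plus node/edge feature
# dicts keyed by degree strings such as '0' and '1'.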
Example #3
    def __init__(self, fiber_in: Fiber, fiber_out: Fiber, fiber_edge: Fiber,
                 num_degrees: int, num_channels: int, output_dim: int,
                 **kwargs):
        super().__init__()
        kwargs['pooling'] = kwargs.get('pooling') or 'max'  # default to max pooling when unspecified
        self.transformer = SE3Transformer(fiber_in=fiber_in,
                                          fiber_hidden=Fiber.create(
                                              num_degrees, num_channels),
                                          fiber_out=fiber_out,
                                          fiber_edge=fiber_edge,
                                          return_type=0,
                                          **kwargs)

        n_out_features = fiber_out.num_features
        self.mlp = nn.Sequential(nn.Linear(n_out_features, n_out_features),
                                 nn.ReLU(),
                                 nn.Linear(n_out_features, output_dim))
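
For context, a hedged sketch of the forward pass this __init__ implies (the actual method is not part of this excerpt and may differ):

    def forward(self, graph, node_feats, edge_feats):
        # pooled type-0 features from the transformer (return_type=0 + pooling),
        # then the MLP regression head
        feats = self.transformer(graph, node_feats, edge_feats).squeeze(-1)
        return self.mlp(feats).squeeze(-1)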
Example #4
    def __init__(self,
                 num_layers: int,
                 fiber_in: Fiber,
                 fiber_hidden: Fiber,
                 fiber_out: Fiber,
                 num_heads: int,
                 channels_div: int,
                 fiber_edge: Fiber = Fiber({}),
                 return_type: Optional[int] = None,
                 pooling: Optional[Literal['avg', 'max']] = None,
                 norm: bool = True,
                 use_layer_norm: bool = True,
                 tensor_cores: bool = False,
                 low_memory: bool = False,
                 **kwargs):
        """
        :param num_layers:          Number of attention layers
        :param fiber_in:            Input fiber description
        :param fiber_hidden:        Hidden fiber description
        :param fiber_out:           Output fiber description
        :param fiber_edge:          Input edge fiber description
        :param num_heads:           Number of attention heads
        :param channels_div:        Channels division before feeding to attention layer
        :param return_type:         Return only features of this type
        :param pooling:             'avg' or 'max' graph pooling before MLP layers
        :param norm:                Apply a normalization layer after each attention block
        :param use_layer_norm:      Apply layer normalization between MLP layers
        :param tensor_cores:        True if using Tensor Cores (affects the use of fully fused convolutions and padded bases)
        :param low_memory:          If True, will use slower ops that use less memory
        """
        super().__init__()
        self.num_layers = num_layers
        self.fiber_edge = fiber_edge
        self.num_heads = num_heads
        self.channels_div = channels_div
        self.return_type = return_type
        self.pooling = pooling
        self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees,
                              *fiber_out.degrees)
        self.tensor_cores = tensor_cores
        self.low_memory = low_memory

        if low_memory:
            self.fuse_level = ConvSE3FuseLevel.NONE
        else:
            # Fully fused convolutions when using Tensor Cores (and not low memory mode)
            self.fuse_level = ConvSE3FuseLevel.FULL if tensor_cores else ConvSE3FuseLevel.PARTIAL

        graph_modules = []
        for i in range(num_layers):
            graph_modules.append(
                AttentionBlockSE3(fiber_in=fiber_in,
                                  fiber_out=fiber_hidden,
                                  fiber_edge=fiber_edge,
                                  num_heads=num_heads,
                                  channels_div=channels_div,
                                  use_layer_norm=use_layer_norm,
                                  max_degree=self.max_degree,
                                  fuse_level=self.fuse_level,
                                  low_memory=low_memory))
            if norm:
                graph_modules.append(NormSE3(fiber_hidden))
            fiber_in = fiber_hidden

        graph_modules.append(
            ConvSE3(fiber_in=fiber_in,
                    fiber_out=fiber_out,
                    fiber_edge=fiber_edge,
                    self_interaction=True,
                    use_layer_norm=use_layer_norm,
                    max_degree=self.max_degree))
        self.graph_modules = Sequential(*graph_modules)

        if pooling is not None:
            assert return_type is not None, 'return_type must be specified when pooling'
            self.pooling_module = GPooling(pool=pooling, feat_type=return_type)
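
A hedged end-to-end sketch for the transformer above (graph construction is omitted; feature tensors for degree d carry a trailing dimension of 2d+1):

model = SE3Transformer(num_layers=2,
                       fiber_in=Fiber({0: 16}),
                       fiber_hidden=Fiber.create(3, 32),
                       fiber_out=Fiber({0: 16}),
                       num_heads=4,
                       channels_div=2)
# graph: a DGL graph with relative node positions available for the bases computation
# node_feats = {'0': torch.rand(graph.num_nodes(), 16, 1)}
# out = model(graph, node_feats)  # dict of output features keyed by degree string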
Example #5
    if args.seed is not None:
        logging.info(f'Using seed {args.seed}')
        seed_everything(args.seed)

    loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]
    if args.wandb:
        loggers.append(
            WandbLogger(name=f'QM9({args.task})',
                        save_dir=args.log_dir,
                        project='se3-transformer'))
    logger = LoggerCollection(loggers)

    datamodule = QM9DataModule(**vars(args))
    model = SE3TransformerPooled(
        fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
        fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
        fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
        output_dim=1,
        tensor_cores=using_tensor_cores(args.amp),  # use Tensor Cores more effectively
        **vars(args))
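    # QM9 targets are scored with mean absolute error, hence the L1 loss below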
    loss_fn = nn.L1Loss()

    if args.benchmark:
        logging.info('Running benchmark mode')
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
    else:
        callbacks = [
            QM9MetricCallback(logger,