def _get_model(**kwargs):
    return SE3Transformer(
        num_layers=4,
        fiber_in=Fiber.create(2, CHANNELS),
        fiber_hidden=Fiber.create(3, CHANNELS),
        fiber_out=Fiber.create(2, CHANNELS),
        fiber_edge=Fiber({}),
        num_heads=8,
        channels_div=2,
        **kwargs
    )
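
# Usage sketch: because _get_model forwards **kwargs to SE3Transformer, test variants can toggle
# optional settings without repeating the fiber configuration. CHANNELS is assumed to be a
# module-level constant of the enclosing test file (e.g. CHANNELS = 32); the flags below are
# illustrative, not taken from the source.
model = _get_model()                          # default: partially fused convolutions
model_low_mem = _get_model(low_memory=True)   # slower ops that use less memory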

class AttentionBlockSE3(nn.Module):
    def __init__(self,
                 fiber_in: Fiber,
                 fiber_out: Fiber,
                 fiber_edge: Optional[Fiber] = None,
                 num_heads: int = 4,
                 channels_div: int = 2,
                 use_layer_norm: bool = False,
                 max_degree: int = 4,
                 fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
                 low_memory: bool = False,
                 **kwargs):
        """
        :param fiber_in: Fiber describing the input features
        :param fiber_out: Fiber describing the output features
        :param fiber_edge: Fiber describing the edge features (node distances excluded)
        :param num_heads: Number of attention heads
        :param channels_div: Divide the channels by this integer for computing values
        :param use_layer_norm: Apply layer normalization between MLP layers
        :param max_degree: Maximum degree used in the bases computation
        :param fuse_level: Maximum fuse level to use in TFN convolutions
        :param low_memory: If True, will use slower ops that use less memory
        """
        super().__init__()
        if fiber_edge is None:
            fiber_edge = Fiber({})
        self.fiber_in = fiber_in
        # value_fiber has the same structure as fiber_out, but #channels divided by 'channels_div'
        value_fiber = Fiber([(degree, channels // channels_div) for degree, channels in fiber_out])
        # key_query_fiber has the same structure as fiber_out, but only degrees which are in fiber_in
        # (queries are merely projected, hence degrees have to match the input)
        key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber
                                 if fe.degree in fiber_in.degrees])

        self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False,
                                    fiber_edge=fiber_edge, use_layer_norm=use_layer_norm,
                                    max_degree=max_degree, fuse_level=fuse_level,
                                    allow_fused_output=True, low_memory=low_memory)
        self.to_query = LinearSE3(fiber_in, key_query_fiber)
        self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
        self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
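
# Construction sketch for the attention block (illustrative fiber sizes, not from the source):
# 32 degree-0 input channels, outputs of degrees 0 and 1. With channels_div=2, values are computed
# with {0: 16, 1: 8} channels, and keys/queries keep only degree 0, the sole degree in fiber_in.
block = AttentionBlockSE3(fiber_in=Fiber({0: 32}),
                          fiber_out=Fiber({0: 32, 1: 16}),
                          fiber_edge=Fiber({}),
                          num_heads=8,
                          channels_div=2,
                          max_degree=2)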

class SE3TransformerPooled(nn.Module):
    def __init__(self,
                 fiber_in: Fiber,
                 fiber_out: Fiber,
                 fiber_edge: Fiber,
                 num_degrees: int,
                 num_channels: int,
                 output_dim: int,
                 **kwargs):
        super().__init__()
        # Default to max pooling; .get avoids a KeyError when 'pooling' is not passed at all
        kwargs['pooling'] = kwargs.get('pooling') or 'max'
        self.transformer = SE3Transformer(
            fiber_in=fiber_in,
            fiber_hidden=Fiber.create(num_degrees, num_channels),
            fiber_out=fiber_out,
            fiber_edge=fiber_edge,
            return_type=0,
            **kwargs
        )

        n_out_features = fiber_out.num_features
        self.mlp = nn.Sequential(
            nn.Linear(n_out_features, n_out_features),
            nn.ReLU(),
            nn.Linear(n_out_features, output_dim)
        )
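
# Construction sketch (illustrative sizes; the QM9 entry point further below builds the real model
# from command-line args). fiber_out here is degree-0 only, so fiber_out.num_features equals
# num_degrees * num_channels, matching the MLP input width.
model = SE3TransformerPooled(fiber_in=Fiber({0: 6}),
                             fiber_out=Fiber({0: 4 * 32}),
                             fiber_edge=Fiber({0: 4}),
                             num_degrees=4,
                             num_channels=32,
                             output_dim=1,
                             num_layers=7,
                             num_heads=8,
                             channels_div=2,
                             pooling='max')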

class SE3Transformer(nn.Module):
    def __init__(self,
                 num_layers: int,
                 fiber_in: Fiber,
                 fiber_hidden: Fiber,
                 fiber_out: Fiber,
                 num_heads: int,
                 channels_div: int,
                 fiber_edge: Fiber = Fiber({}),
                 return_type: Optional[int] = None,
                 pooling: Optional[Literal['avg', 'max']] = None,
                 norm: bool = True,
                 use_layer_norm: bool = True,
                 tensor_cores: bool = False,
                 low_memory: bool = False,
                 **kwargs):
        """
        :param num_layers: Number of attention layers
        :param fiber_in: Input fiber description
        :param fiber_hidden: Hidden fiber description
        :param fiber_out: Output fiber description
        :param fiber_edge: Input edge fiber description
        :param num_heads: Number of attention heads
        :param channels_div: Channels division before feeding to attention layer
        :param return_type: Return only features of this type
        :param pooling: 'avg' or 'max' graph pooling before MLP layers
        :param norm: Apply a normalization layer after each attention block
        :param use_layer_norm: Apply layer normalization between MLP layers
        :param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
        :param low_memory: If True, will use slower ops that use less memory
        """
        super().__init__()
        self.num_layers = num_layers
        self.fiber_edge = fiber_edge
        self.num_heads = num_heads
        self.channels_div = channels_div
        self.return_type = return_type
        self.pooling = pooling
        self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
        self.tensor_cores = tensor_cores
        self.low_memory = low_memory

        if low_memory:
            self.fuse_level = ConvSE3FuseLevel.NONE
        else:
            # Fully fused convolutions when using Tensor Cores (and not low memory mode)
            self.fuse_level = ConvSE3FuseLevel.FULL if tensor_cores else ConvSE3FuseLevel.PARTIAL

        graph_modules = []
        for i in range(num_layers):
            graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
                                                   fiber_out=fiber_hidden,
                                                   fiber_edge=fiber_edge,
                                                   num_heads=num_heads,
                                                   channels_div=channels_div,
                                                   use_layer_norm=use_layer_norm,
                                                   max_degree=self.max_degree,
                                                   fuse_level=self.fuse_level,
                                                   low_memory=low_memory))
            if norm:
                graph_modules.append(NormSE3(fiber_hidden))
            fiber_in = fiber_hidden

        graph_modules.append(ConvSE3(fiber_in=fiber_in,
                                     fiber_out=fiber_out,
                                     fiber_edge=fiber_edge,
                                     self_interaction=True,
                                     use_layer_norm=use_layer_norm,
                                     max_degree=self.max_degree))
        self.graph_modules = Sequential(*graph_modules)

        if pooling is not None:
            assert return_type is not None, 'return_type must be specified when pooling'
            self.pooling_module = GPooling(pool=pooling, feat_type=return_type)
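
# Construction sketch (sizes illustrative). At runtime, node and edge features are assumed to be
# passed to the forward pass as dicts keyed by degree as a string, with tensors shaped
# (num_nodes, channels, 2 * degree + 1), e.g. '0' -> (N, 32, 1) and '1' -> (N, 32, 3).
model = SE3Transformer(num_layers=7,
                       fiber_in=Fiber({0: 32}),
                       fiber_hidden=Fiber.create(3, 32),
                       fiber_out=Fiber({0: 64}),
                       num_heads=8,
                       channels_div=2,
                       fiber_edge=Fiber({}),
                       return_type=0)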

if args.seed is not None:
    logging.info(f'Using seed {args.seed}')
    seed_everything(args.seed)

loggers = [DLLogger(save_dir=args.log_dir, filename=args.dllogger_name)]
if args.wandb:
    loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir,
                               project='se3-transformer'))
logger = LoggerCollection(loggers)

datamodule = QM9DataModule(**vars(args))
model = SE3TransformerPooled(
    fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
    fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
    fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
    output_dim=1,
    tensor_cores=using_tensor_cores(args.amp),  # use Tensor Cores more effectively
    **vars(args)
)
loss_fn = nn.L1Loss()

if args.benchmark:
    logging.info('Running benchmark mode')
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
else:
    callbacks = [QM9MetricCallback(logger,