import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Dropout

# The schema utilities and equivariant layers referenced below (DataSchema,
# Activation, EquivariantLayer, SparseMatrixRelationLinear,
# SparseMatrixEquivariantLayer, SparseMatrixEntityPoolingLayer,
# SparseMatrixEntityBroadcastingLayer) are assumed importable from the
# surrounding codebase.


class EquivHGAE(nn.Module):
    """Equivariant heterogeneous graph autoencoder: a stack of sparse-matrix
    equivariant layers encodes relation data into per-entity embeddings, and
    a broadcasting layer decodes the embeddings back to relation space."""
    def __init__(self,
                 schema,
                 input_channels=1,
                 activation=F.relu,
                 layers=[64, 64, 64],
                 embedding_dim=50,
                 dropout=0,
                 norm=True,
                 pool_op='mean',
                 norm_affine=False,
                 final_activation=nn.Identity(),
                 embedding_entities=None,
                 output_relations=None,
                 in_fc_layer=True):
        super(EquivHGAE, self).__init__()
        self.schema = schema
        if output_relations is None:
            self.schema_out = schema
        else:
            self.schema_out = DataSchema(
                schema.entities, {rel.id: rel
                                  for rel in output_relations})
        self.input_channels = input_channels

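        # Wrap activation and dropout so they are applied per relation
        # across the schema's sparse relation matrices.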
        self.activation = activation
        self.rel_activation = Activation(schema,
                                         self.activation,
                                         is_sparse=True)

        self.dropout = Dropout(p=dropout)
        self.rel_dropout = Activation(schema, self.dropout, is_sparse=True)

        self.use_in_fc_layer = in_fc_layer
        # Equivariant layers
        self.equiv_layers = nn.ModuleList([])
        if self.use_in_fc_layer:
            # Simple fully connected layers for input attributes
            self.fc_in_layer = SparseMatrixRelationLinear(
                schema, self.input_channels, layers[0])
            self.n_equiv_layers = len(layers) - 1
        else:
            # Alternatively, use an equivariant layer
            self.equiv_layers.append(
                SparseMatrixEquivariantLayer(schema,
                                             input_channels,
                                             layers[0],
                                             pool_op=pool_op))
            self.n_equiv_layers = len(layers)
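        # Remaining equivariant layers map layers[i-1] -> layers[i] channels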
        self.equiv_layers.extend([
            SparseMatrixEquivariantLayer(schema,
                                         layers[i - 1],
                                         layers[i],
                                         pool_op=pool_op)
            for i in range(1, len(layers))
        ])
        if norm:
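            # One BatchNorm1d per relation and per layer; with
            # track_running_stats=False, batch statistics are recomputed on
            # every forward pass rather than tracked as running averages.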
            self.norms = nn.ModuleList()
            for channels in layers:
                norm_dict = nn.ModuleDict()
                for rel_id in self.schema.relations:
                    norm_dict[str(rel_id)] = nn.BatchNorm1d(
                        channels,
                        affine=norm_affine,
                        track_running_stats=False)
                norm_activation = Activation(schema,
                                             norm_dict,
                                             is_dict=True,
                                             is_sparse=True)
                self.norms.append(norm_activation)
        else:
            self.norms = nn.ModuleList([
                Activation(schema, nn.Identity(), is_sparse=True)
                for _ in layers
            ])

        # Entity embeddings: pooling encodes relation data into per-entity
        # embeddings; broadcasting decodes them back to relation space
        self.pooling = SparseMatrixEntityPoolingLayer(
            schema,
            layers[-1],
            embedding_dim,
            entities=embedding_entities,
            pool_op=pool_op)
        self.broadcasting = SparseMatrixEntityBroadcastingLayer(
            self.schema_out,
            embedding_dim,
            input_channels,
            entities=embedding_entities,
            pool_op=pool_op)

        self.final_activation = Activation(schema,
                                           final_activation,
                                           is_sparse=True)
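
# Usage sketch (illustrative only; `schema` denotes a DataSchema instance
# from this codebase, and the hyperparameter values are hypothetical):
#
#   ae = EquivHGAE(schema, input_channels=1, layers=[64, 64],
#                  embedding_dim=32, embedding_entities=schema.entities)
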

class EquivDecoder(nn.Module):
    """Decoder that broadcasts entity embeddings back into relation space
    and refines them with equivariant layers, ending in either a fully
    connected or an equivariant output layer."""
    def __init__(self,
                 schema,
                 activation=F.relu,
                 layers=[64, 64, 64],
                 embedding_dim=50,
                 dropout=0,
                 pool_op='mean',
                 norm_affine=False,
                 embedding_entities=None,
                 out_fc_layer=True,
                 out_dim=1):
        super(EquivDecoder, self).__init__()
        self.schema = schema
        self.out_dim = out_dim
        self.activation = activation
        self.rel_activation = Activation(schema,
                                         self.activation,
                                         is_sparse=True)

        self.dropout = Dropout(p=dropout)
        self.rel_dropout = Activation(schema, self.dropout, is_sparse=True)

        self.use_out_fc_layer = out_fc_layer

        # Broadcast entity embeddings into relation space, then refine them
        # with a stack of equivariant layers
        self.broadcasting = SparseMatrixEntityBroadcastingLayer(
            self.schema,
            embedding_dim,
            layers[0],
            entities=embedding_entities,
            pool_op=pool_op)

        self.equiv_layers = nn.ModuleList([])
        self.equiv_layers.extend([
            SparseMatrixEquivariantLayer(schema,
                                         layers[i - 1],
                                         layers[i],
                                         pool_op=pool_op)
            for i in range(1, len(layers))
        ])
        self.n_layers = len(layers) - 1
        if self.use_out_fc_layer:
            # Add fully connected layer to output
            self.fc_out_layer = SparseMatrixRelationLinear(
                schema, layers[-1], self.out_dim)
        else:
            # Alternatively, use an equivariant layer
            self.equiv_layers.append(
                SparseMatrixEquivariantLayer(schema,
                                             layers[-1],
                                             self.out_dim,
                                             pool_op=pool_op))

        self.norms = nn.ModuleList()
        for channels in layers:
            norm_dict = nn.ModuleDict()
            for rel_id in self.schema.relations:
                norm_dict[str(rel_id)] = nn.BatchNorm1d(
                    channels, affine=norm_affine, track_running_stats=False)
            norm_activation = Activation(schema,
                                         norm_dict,
                                         is_dict=True,
                                         is_sparse=True)
            self.norms.append(norm_activation)
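
# Usage sketch (illustrative only; `schema` is a DataSchema instance from
# this codebase, and the argument values are hypothetical):
#
#   dec = EquivDecoder(schema, layers=[64, 64], embedding_dim=32,
#                      embedding_entities=schema.entities, out_dim=1)
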

class EquivAlternatingLinkPredictor(nn.Module):
    """Link predictor that alternates entity-pooling and broadcasting layers
    `depth` times, with optional fully connected input/output adapters and
    optional residual connections."""
    def __init__(self,
                 schema,
                 input_channels,
                 width,
                 depth,
                 embedding_dim,
                 activation=F.relu,
                 final_activation=nn.Identity(),
                 output_dim=1,
                 dropout=0,
                 norm=True,
                 pool_op='mean',
                 in_fc_layer=True,
                 out_fc_layer=True,
                 norm_affine=False,
                 residual=False):
        super(EquivAlternatingLinkPredictor, self).__init__()

        self.schema = schema
        self.input_channels = input_channels

        self.width = width
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.activation = activation
        self.rel_activation = Activation(schema,
                                         self.activation,
                                         is_sparse=True)

        self.dropout = Dropout(p=dropout)
        self.rel_dropout = Activation(schema, self.dropout, is_sparse=True)

        self.use_in_fc_layer = in_fc_layer
        if self.use_in_fc_layer:
            self.in_fc_layer = SparseMatrixRelationLinear(
                schema, self.input_channels, width)
        self.use_out_fc_layer = out_fc_layer
        if self.use_out_fc_layer:
            self.out_fc_layer = SparseMatrixRelationLinear(
                schema, width, output_dim)

        # Equivariant Layers
        self.pool_layers = nn.ModuleList([])
        self.bcast_layers = nn.ModuleList([])

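        # Each block pools relation data into per-entity embeddings and then
        # broadcasts back to relation space; the first/last widths depend on
        # whether the fully connected input/output adapters are enabled.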
        for i in range(depth):
            if i == 0 and not self.use_in_fc_layer:
                in_dim = input_channels
            else:
                in_dim = width
            if i == depth - 1 and not self.use_out_fc_layer:
                out_dim = output_dim
            else:
                out_dim = width

            pool_i = SparseMatrixEntityPoolingLayer(schema,
                                                    in_dim,
                                                    embedding_dim,
                                                    entities=schema.entities,
                                                    pool_op=pool_op)
            self.pool_layers.append(pool_i)

            bcast_i = SparseMatrixEntityBroadcastingLayer(
                schema,
                embedding_dim,
                out_dim,
                entities=schema.entities,
                pool_op=pool_op)
            self.bcast_layers.append(bcast_i)

        if norm:
            self.norms = nn.ModuleList()
            for i in range(depth):
                norm_dict = nn.ModuleDict()
                for rel_id in self.schema.relations:
                    norm_dict[str(rel_id)] = nn.BatchNorm1d(
                        embedding_dim,
                        affine=norm_affine,
                        track_running_stats=False)
                norm_activation = Activation(schema,
                                             norm_dict,
                                             is_dict=True,
                                             is_sparse=True)
                self.norms.append(norm_activation)
        else:
            self.norms = nn.ModuleList([
                Activation(schema, nn.Identity(), is_sparse=True)
                for _ in range(depth)
            ])

        self.final_activation = final_activation
        self.residual = residual
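
# Usage sketch (illustrative only; argument values are hypothetical):
#
#   lp = EquivAlternatingLinkPredictor(schema, input_channels=1, width=64,
#                                      depth=3, embedding_dim=32)
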

class EquivariantNetwork(nn.Module):
    """Dense equivariant network: a stack of equivariant layers over the
    source schema followed by layers that map onto the target schema."""
    def __init__(self, schema, input_channels, source_layers=[32, 64, 32],
                 target_layers=[32], output_dim=1, schema_out=None,
                 activation=F.relu, dropout=0, pool_op='mean',
                 norm=True, norm_affine=True, final_activation=nn.Identity()):
        super(EquivariantNetwork, self).__init__()
        self.schema = schema
        if schema_out is None:
            self.schema_out = schema
        else:
            self.schema_out = schema_out
        self.input_channels = input_channels

        self.activation = activation
        self.source_activation = Activation(schema, self.activation)
        self.target_activation = Activation(self.schema_out, self.activation)

        self.dropout = Dropout(p=dropout)
        self.source_dropout = Activation(self.schema, self.dropout)
        self.target_dropout = Activation(self.schema_out, self.dropout)
        # Equivariant layers with source schema
        self.n_source_layers = len(source_layers)
        self.source_layers = nn.ModuleList([])
        self.source_layers.append(EquivariantLayer(
                self.schema, input_channels, source_layers[0], pool_op=pool_op))
        self.source_layers.extend([
                EquivariantLayer(self.schema, source_layers[i-1], source_layers[i], pool_op=pool_op)
                for i in range(1, len(source_layers))])
        if norm:
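            # GroupNorm with num_groups == channels normalizes each channel
            # independently (instance-norm style), here one module per relation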
            self.source_norms = nn.ModuleList()
            for channels in source_layers:
                norm_dict = nn.ModuleDict()
                for rel_id in self.schema.relations:
                    norm_dict[str(rel_id)] = nn.GroupNorm(channels, channels, affine=norm_affine)
                norm_activation = Activation(self.schema, norm_dict, is_dict=True)
                self.source_norms.append(norm_activation)
        else:
            self.source_norms = nn.ModuleList([Activation(schema, nn.Identity())
                                        for _ in source_layers])

        # Equivariant layers with target schema
        # Append the output width so the final target layer maps to output_dim
        target_layers = target_layers + [output_dim]
        self.n_target_layers = len(target_layers)
        self.target_layers = nn.ModuleList([])
        self.target_layers.append(EquivariantLayer(self.schema, source_layers[-1],
                                                   target_layers[0],
                                                   schema_out=self.schema_out,
                                                   pool_op=pool_op))
        self.target_layers.extend([
                EquivariantLayer(self.schema_out, target_layers[i-1],
                                 target_layers[i], pool_op=pool_op)
                for i in range(1, len(target_layers))])
        if norm:
            self.target_norms = nn.ModuleList()
            for channels in target_layers:
                norm_dict = nn.ModuleDict()
                for rel_id in self.schema_out.relations:
                    norm_dict[str(rel_id)] = nn.GroupNorm(channels, channels, affine=norm_affine)
                norm_activation = Activation(self.schema_out, norm_dict, is_dict=True)
                self.target_norms.append(norm_activation)
        else:
            self.target_norms = nn.ModuleList([Activation(self.schema_out, nn.Identity())
                                        for _ in target_layers])

        self.final_activation = final_activation
        self.final_rel_activation = Activation(self.schema_out, self.final_activation)
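
# Usage sketch (illustrative only; argument values are hypothetical):
#
#   net = EquivariantNetwork(schema, input_channels=1,
#                            source_layers=[32, 64, 32], target_layers=[32],
#                            output_dim=1)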