def __init__(self, config: dict):
    super().__init__(config)

    # self.emb_dim = config['EMBEDDING_DIM']
    self.entities = get_param((self.num_ent, self.emb_dim))
    self.relations = get_param((2 * self.num_rel, self.emb_dim))

    self.model_name = 'Transformer_Statement'

    self.hid_drop2 = config['STAREARGS']['HID_DROP2']
    self.feat_drop = config['STAREARGS']['FEAT_DROP']

    self.num_transformer_layers = config['STAREARGS']['T_LAYERS']
    self.num_heads = config['STAREARGS']['T_N_HEADS']
    self.num_hidden = config['STAREARGS']['T_HIDDEN']
    self.d_model = config['EMBEDDING_DIM']
    self.positional = config['STAREARGS']['POSITIONAL']
    self.pooling = config['STAREARGS']['POOLING']  # min / avg / concat
    self.device = config['DEVICE']

    self.hidden_drop = torch.nn.Dropout(self.hid_drop)
    self.hidden_drop2 = torch.nn.Dropout(self.hid_drop2)
    self.feature_drop = torch.nn.Dropout(self.feat_drop)

    encoder_layers = TransformerEncoderLayer(self.d_model, self.num_heads, self.num_hidden,
                                             config['STAREARGS']['HID_DROP2'])
    self.encoder = TransformerEncoder(encoder_layers, config['STAREARGS']['T_LAYERS'])
    self.position_embeddings = nn.Embedding(config['MAX_QPAIRS'] - 1, self.d_model)
    self.layer_norm = torch.nn.LayerNorm(self.emb_dim)

    if self.pooling == "concat":
        self.flat_sz = self.emb_dim * (config['MAX_QPAIRS'] - 1)
        self.fc = torch.nn.Linear(self.flat_sz, self.emb_dim)
    else:
        self.fc = torch.nn.Linear(self.emb_dim, self.emb_dim)
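# Illustrative sketch (not from the original repo): the config keys read by the
# constructor above, with assumed placeholder values. Key names come from the code;
# the numbers are examples only. PyTorch's TransformerEncoderLayer additionally
# requires EMBEDDING_DIM to be divisible by T_N_HEADS.
#
#   example_config = {
#       'EMBEDDING_DIM': 200,              # -> self.d_model / self.emb_dim
#       'MAX_QPAIRS': 15,                  # statement length; positions = MAX_QPAIRS - 1
#       'DEVICE': 'cpu',
#       'STAREARGS': {
#           'HID_DROP2': 0.1,              # dropout inside the transformer encoder
#           'FEAT_DROP': 0.1,
#           'T_LAYERS': 2,                 # number of TransformerEncoder layers
#           'T_N_HEADS': 4,                # attention heads
#           'T_HIDDEN': 512,               # feed-forward width
#           'POSITIONAL': True,
#           'POOLING': 'avg',              # 'min' / 'avg' / 'concat'
#       },
#   }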
def __init__(self, in_channels, out_channels, num_rels, act=lambda x: x, config=None):
    super(self.__class__, self).__init__(flow='target_to_source', aggr='add')

    self.p = config
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.num_rels = num_rels
    self.act = act
    self.device = None

    self.w_loop = get_param((in_channels, out_channels))  # (100, 200)
    self.w_in = get_param((in_channels, out_channels))    # (100, 200)
    self.w_out = get_param((in_channels, out_channels))   # (100, 200)
    self.w_rel = get_param((in_channels, out_channels))   # (100, 200)

    if self.p['STATEMENT_LEN'] != 3:
        if self.p['STAREARGS']['QUAL_AGGREGATE'] == 'sum' or self.p['STAREARGS']['QUAL_AGGREGATE'] == 'mul':
            self.w_q = get_param((in_channels, in_channels))  # new for quals setup
        elif self.p['STAREARGS']['QUAL_AGGREGATE'] == 'concat':
            self.w_q = get_param((2 * in_channels, in_channels))  # need 2x size due to the concat operation

    self.loop_rel = get_param((1, in_channels))  # (1, 100)
    self.loop_ent = get_param((1, in_channels))  # new

    self.drop = torch.nn.Dropout(self.p['STAREARGS']['GCN_DROP'])
    self.bn = torch.nn.BatchNorm1d(out_channels)

    if self.p['STAREARGS']['ATTENTION']:
        assert self.p['STAREARGS']['GCN_DIM'] == self.p['EMBEDDING_DIM'], \
            "Current attn implementation requires those to be identical"
        assert self.p['EMBEDDING_DIM'] % self.p['STAREARGS']['ATTENTION_HEADS'] == 0, "should be divisible"
        self.heads = self.p['STAREARGS']['ATTENTION_HEADS']
        self.attn_dim = self.out_channels // self.heads
        self.negative_slope = self.p['STAREARGS']['ATTENTION_SLOPE']
        self.attn_drop = self.p['STAREARGS']['ATTENTION_DROP']
        self.att = get_param((1, self.heads, 2 * self.attn_dim))

    if self.p['STAREARGS']['BIAS']:
        self.register_parameter('bias', Parameter(torch.zeros(out_channels)))
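# Hedged usage sketch (assumptions, not from the original code): instantiating this
# layer with a minimal config. 'StarEConvLayer' is assumed to be the class this
# constructor belongs to (it is the name used by the encoder further below); note that
# GCN_DIM must equal EMBEDDING_DIM when ATTENTION is enabled.
#
#   conv_config = {
#       'STATEMENT_LEN': 5,
#       'EMBEDDING_DIM': 200,
#       'STAREARGS': {
#           'QUAL_AGGREGATE': 'sum',       # 'sum' / 'mul' / 'concat'
#           'GCN_DROP': 0.1,
#           'GCN_DIM': 200,
#           'ATTENTION': True,
#           'ATTENTION_HEADS': 4,
#           'ATTENTION_SLOPE': 0.2,
#           'ATTENTION_DROP': 0.1,
#           'BIAS': False,
#       },
#   }
#   layer = StarEConvLayer(in_channels=200, out_channels=200, num_rels=100, config=conv_config)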
def __init__(self, in_channels, out_channels, num_rels, fact_encoder=None, act=lambda x: x, params=None):
    # target_to_source -> edges that flow from x_i to x_j
    super(self.__class__, self).__init__(flow='target_to_source', aggr='add')

    self.p = params
    self.emb_dim = params['EMBEDDING_DIM']
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.num_rels = num_rels
    self.act = act
    self.opn = params['MODEL']['OPN']
    self.qual_comb = params['MODEL']['QUAL_COMB']
    self.device = None

    # Weight of both comp functions in the 'both' aggregate
    self.alpha = params['ALPHA']

    # Three weight matrices for CompGCN
    # In = standard / Out = inverse
    self.w_loop = get_param((in_channels, out_channels))
    self.w_in = get_param((in_channels, out_channels))
    self.w_out = get_param((in_channels, out_channels))

    # Qual pairs
    self.w_q = get_param((in_channels, in_channels))

    # Weight matrix for the relation update
    self.w_rel = get_param((in_channels, out_channels))

    # TODO: Move out of here?
    # Relation embedding for loop triplets
    self.loop_rel = get_param((1, in_channels))

    self.drop = torch.nn.Dropout(self.p['MODEL']['GCN_DROP'])
    self.bn = torch.nn.BatchNorm1d(out_channels)

    self.fact_encoder = fact_encoder
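# Hedged sketch (assumed names and values, not from the original repo): how this
# layer's weights map dimensions. With in_channels=100 and out_channels=200, w_loop,
# w_in, w_out and w_rel are (100, 200) projections, while w_q stays (100, 100) because
# qualifier embeddings are combined in the input space before projection.
#
#   params = {
#       'EMBEDDING_DIM': 100,
#       'ALPHA': 0.5,                      # weight of the two comp functions in 'both'
#       'MODEL': {'OPN': 'rotate', 'QUAL_COMB': 'sum', 'GCN_DROP': 0.1},
#   }
#   layer = QualConvLayer(in_channels=100, out_channels=200, num_rels=50,
#                         fact_encoder=None, params=params)   # class name is hypothetical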
def __init__(self, graph_repr: Dict[str, np.ndarray], config: dict, timestamps: dict = None):
    super().__init__(config)

    self.device = config['DEVICE']

    # Storing the KG
    self.edge_index = torch.tensor(graph_repr['edge_index'], dtype=torch.long, device=self.device)
    self.edge_type = torch.tensor(graph_repr['edge_type'], dtype=torch.long, device=self.device)

    if not self.triple_mode:
        if self.qual_mode == "full":
            self.qual_rel = torch.tensor(graph_repr['qual_rel'], dtype=torch.long, device=self.device)
            self.qual_ent = torch.tensor(graph_repr['qual_ent'], dtype=torch.long, device=self.device)
        elif self.qual_mode == "sparse":
            self.quals = torch.tensor(graph_repr['quals'], dtype=torch.long, device=self.device)

    self.gcn_dim = self.emb_dim if self.n_layer == 1 else self.gcn_dim

    if timestamps is None:
        self.init_embed = get_param((self.num_ent, self.emb_dim))
        self.init_embed.data[0] = 0  # padding

    if self.model_nm.endswith('transe'):
        self.init_rel = get_param((self.num_rel, self.emb_dim))
    elif config['STAREARGS']['OPN'] == 'rotate' or config['STAREARGS']['QUAL_OPN'] == 'rotate':
        phases = 2 * np.pi * torch.rand(self.num_rel, self.emb_dim // 2)
        self.init_rel = nn.Parameter(torch.cat([
            torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1),
            torch.cat([torch.cos(phases), -torch.sin(phases)], dim=-1)
        ], dim=0))
    else:
        self.init_rel = get_param((self.num_rel * 2, self.emb_dim))

    self.init_rel.data[0] = 0  # padding

    self.conv1 = StarEConvLayer(self.emb_dim, self.gcn_dim, self.num_rel, act=self.act, config=config)
    self.conv2 = StarEConvLayer(self.gcn_dim, self.emb_dim, self.num_rel, act=self.act,
                                config=config) if self.n_layer == 2 else None

    if self.conv1:
        self.conv1.to(self.device)
    if self.conv2:
        self.conv2.to(self.device)

    self.register_parameter('bias', Parameter(torch.zeros(self.num_ent)))
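# Hedged sketch (shapes are assumptions, not confirmed by the constructor above): the
# graph_repr dictionary it consumes. edge_index / edge_type typically follow the
# PyTorch Geometric convention of (2, num_edges) and (num_edges,); in 'full' qual_mode
# separate qual_rel / qual_ent arrays are stored, while 'sparse' mode keeps a single
# quals array (a (3, num_quals) layout is common in StarE-style code, assumed here).
#
#   graph_repr = {
#       'edge_index': np.zeros((2, num_edges), dtype=np.int64),
#       'edge_type':  np.zeros((num_edges,), dtype=np.int64),
#       'quals':      np.zeros((3, num_quals), dtype=np.int64),   # only when qual_mode == 'sparse'
#   }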