import torch
from torch import nn

import pt  # project-local layer utilities (PtLinear, Cond, Act, str_split)


class EdgeFeatureLayer(nn.Module):
    # Class name is a hypothetical stand-in; the original class statement
    # is not shown in this excerpt.
    def __init__(self, node_dim, cond_dim, edge_dim, method: str, params=None):
        super().__init__()
        self.method = method
        self.node_dim = node_dim
        self.cond_dim = cond_dim
        self.edge_dim = edge_dim
        self.params = params
        self.use_n_geo = params.e_f_use_nGeo
        if self.method in ('share', 'none'):
            return
        # method packs two choices, split on '_': n2n combines the two endpoint
        # node features, n2c fuses the result with the condition.
        self.n2n, self.n2c = method.split('_')
        # Node features are widened only when node geometry is concatenated upstream.
        if params.n_geo_method in ('sum', 'none') or not self.use_n_geo:
            n_dim = node_dim
        else:
            n_dim = node_dim + params.n_geo_out_dim
        # The original ternary ('... if self.n2n == "mul" else ...') had identical
        # branches, so it collapses to a plain assignment.
        orders = params.e_f_orders
        self.n_proj = pt.PtLinear(n_dim, edge_dim, drop=params.e_f_drop,
                                  norm=params.e_f_norm, orders=orders)
        if params.e_geo_method == 'linear':
            orders = params.e_geo_orders  # same redundant ternary, collapsed
            self.e_geo_linear = pt.PtLinear(params.e_geo_dim, params.e_geo_out_dim,
                                            orders=orders, norm=params.e_geo_norm)
        else:
            self.e_geo_linear = None
        # Elementwise n2n combines keep the edge feature at edge_dim;
        # concatenation doubles it.
        if self.n2n in ('sum', 'mul', 'max'):
            join_dim = edge_dim + params.e_geo_out_dim
        else:
            join_dim = edge_dim * 2 + params.e_geo_out_dim
        self.n2c_fusion = FilmFusion(join_dim, cond_dim, edge_dim, params=params)
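
# Hedged sketch (not from the original): an n2n combine step consistent with
# the join_dim arithmetic above. Elementwise 'sum'/'mul'/'max' keep the last
# dim at edge_dim, while the fallback concatenation doubles it, matching the
# edge_dim * 2 branch. The function name and the concat fallback are assumptions.
def _n2n_combine_sketch(src: torch.Tensor, dst: torch.Tensor, n2n: str) -> torch.Tensor:
    if n2n == 'sum':
        return src + dst
    if n2n == 'mul':
        return src * dst
    if n2n == 'max':
        return torch.max(src, dst)
    return torch.cat([src, dst], dim=-1)  # assumed fallback: doubles the last dim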

class NodeStem(nn.Module):
    # Class name is a hypothetical stand-in; the original class statement
    # is not shown in this excerpt.
    def __init__(self, node_dim: int, cond_dim: int, out_dim: int,
                 method: str = 'linear', params=None):
        super().__init__()
        self.params = params
        if params.n_geo_method in ('linear_cat', 'sum'):
            self.n_geo_layer = pt.PtLinear(params.n_geo_dim, params.n_geo_out_dim,
                                           orders=params.n_geo_orders, norm=params.n_geo_norm)
        else:
            self.n_geo_layer = None
        # 'sum' folds geometry into the node feature; other methods concatenate it.
        if params.n_geo_method == 'sum':
            n_dim = node_dim
        else:
            n_dim = node_dim + params.n_geo_out_dim
        norm, orders, drop = params.stem_norm, params.stem_orders, params.stem_drop
        if method == 'linear':
            self.linear_l = pt.PtLinear(n_dim, out_dim, norm=norm, orders=orders, drop=drop)
        elif method == 'double_linear':
            # Without stem_use_act, strip a trailing activation ('a') from the
            # first layer's orders so the two linears are not separated by it.
            if params.stem_use_act:
                new_orders = orders
            else:
                new_orders = orders[:-1] if orders.endswith('a') else orders
            self.linear_l = nn.Sequential(
                pt.PtLinear(n_dim, out_dim, norm=norm, orders=new_orders, drop=drop),
                pt.PtLinear(out_dim, out_dim, norm=norm, orders=orders))
        elif method == 'film':
            # The original call omitted params=..., but FilmFusion reads
            # params.f_c_orders unconditionally, so it is passed through here.
            self.linear_l = FilmFusion(node_dim, cond_dim, out_dim, dropout=drop, params=params)
        else:
            raise NotImplementedError()
        self.method = method
        self.node_dim = node_dim
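
# Illustrative note (assumption, not from the original code): with
# orders='lna' (linear -> norm -> act) and stem_use_act=False, the first
# 'double_linear' layer runs as 'ln', so the two stacked linears compose
# without an extra nonlinearity in between:
#
#     orders = 'lna'
#     new_orders = orders[:-1] if orders.endswith('a') else orders
#     assert new_orders == 'ln'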

class FilmFusion(nn.Module):
    # Name inferred from the call sites above; the original class statement is
    # not shown. norm_type and act_type are accepted but unused here: the
    # activation actually comes from params.f_act.
    def __init__(self, in_dim: int, cond_dim: int, out_dim: int,
                 norm_type='layer', dropout: float = 0., act_type: str = 'relu',
                 params=None):
        super().__init__()
        # The condition is projected to 2 * out_dim: one half scales, one half shifts.
        self.cond_proj = pt.PtLinear(cond_dim, out_dim * 2, orders=params.f_c_orders,
                                     norm=params.f_c_norm, drop=params.f_c_drop)
        # No bias and no norm affine parameters: FiLM supplies the scale and
        # shift from the condition instead.
        self.x_linear = pt.PtLinear(in_dim, out_dim, orders=params.f_x_orders,
                                    drop=dropout, norm=params.f_x_norm,
                                    norm_affine=False, bias=False)
        self.cond_l = pt.Cond()
        self.act_l = pt.Act(params.f_act)
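
# Minimal FiLM sketch in plain PyTorch (illustrative only; the class above
# delegates to the project's pt.PtLinear / pt.Cond helpers, whose exact
# orders/norm semantics are not shown). The condition is projected to
# 2 * out_dim and split into a per-feature scale (gamma) and shift (beta).
class _MiniFilmSketch(nn.Module):
    def __init__(self, in_dim: int, cond_dim: int, out_dim: int):
        super().__init__()
        self.x_proj = nn.Linear(in_dim, out_dim, bias=False)
        self.cond_proj = nn.Linear(cond_dim, out_dim * 2)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        gamma, beta = self.cond_proj(cond).chunk(2, dim=-1)
        return torch.relu(self.x_proj(x) * gamma + beta)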

class EdgeParam(nn.Module):
    # Class name is a hypothetical stand-in; the original class statement
    # is not shown in this excerpt.
    def __init__(self, edge_dim: int, out_dim: int, method: str, params=None):
        super().__init__()
        self.edge_dim = edge_dim
        self.out_dim = out_dim
        self.method = method
        self.params = params
        if method in ('share', 'none'):
            return
        self.params_l = nn.Sequential(
            pt.PtLinear(edge_dim, out_dim // 2, norm=params.e_p_norm, drop=params.e_p_drop,
                        orders=params.e_p_orders, act=params.e_p_act),
            pt.PtLinear(out_dim // 2, out_dim, norm=params.e_p_norm, orders='lna', act='tanh'))
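
# Rough plain-PyTorch analogue of params_l above (assumption: PtLinear with
# orders='lna' behaves like linear -> norm -> act; norms omitted here). The
# bottleneck halves the width, and the final tanh keeps the generated
# per-edge parameters bounded in [-1, 1]. Sizes are example values only.
_edge_param_sketch = nn.Sequential(
    nn.Linear(64, 128 // 2),   # e.g. edge_dim=64, out_dim=128
    nn.ReLU(),
    nn.Linear(128 // 2, 128),
    nn.Tanh(),
)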

class EdgeWeight(nn.Module):
    # Class name is a hypothetical stand-in; the original class statement
    # is not shown in this excerpt.
    def __init__(self, edge_dim: int, method: str, params=None):
        super().__init__()
        self.edge_dim = edge_dim
        self.params = params
        # method packs three settings (score method, normalization, reduce size),
        # split on '_'.
        self.score_method, self.norm_method, self.reduce_size = pt.str_split(method, '_')
        if self.score_method == 'share':
            self.score_l = None
        elif self.score_method == 'linear':
            self.score_l = nn.Sequential(
                pt.PtLinear(edge_dim, edge_dim // 2, norm=params.e_w_norm,
                            drop=params.e_w_drop, orders=params.e_w_orders),
                # Weight norm with orders='ln'; layer norm performed badly here.
                pt.PtLinear(edge_dim // 2, 1, norm='weight', orders='ln'))
        else:
            raise NotImplementedError()
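
# Hedged sketch: score_l maps each edge feature to one scalar; assuming
# norm_method == 'softmax' (one plausible value, not confirmed by this
# excerpt), those scalars become weights over the edge axis like so.
def _edge_weight_sketch(scores: torch.Tensor) -> torch.Tensor:
    # scores: [batch, n_edges, 1], raw outputs of a head shaped like score_l
    return torch.softmax(scores, dim=1)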

class Classifier(nn.Module):
    # Class name is a hypothetical stand-in; the original class statement
    # is not shown in this excerpt.
    def __init__(self, v_dim: int, q_dim: int, out_dim: int, method: str, params=None):
        super().__init__()
        self.method = method
        self.params = params
        if method == 'linear':
            self.cls_l = nn.Sequential(
                pt.PtLinear(v_dim, out_dim // 2, norm=params.cls_norm,
                            orders=params.cls_orders, drop=params.cls_drop,
                            act=params.cls_act),
                # orders='ln' on the output layer; layer norm performed worse here.
                pt.PtLinear(out_dim // 2, out_dim, norm=params.cls_norm, orders='ln'))
        else:
            raise NotImplementedError()
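
# Usage sketch (all names and sizes are assumptions, not from the original):
# the head bottlenecks the fused feature to out_dim // 2 before scoring
# out_dim classes.
#
#     cls = Classifier(v_dim=1024, q_dim=1024, out_dim=3000, method='linear', params=params)
#     logits = cls.cls_l(fused)   # [batch, 1024] -> [batch, 1500] -> [batch, 3000]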