def get_cfg(key=None):
    """Build the default configuration node (aggregator with an attention sub-config).

    Args:
        key: Name of a preset configuration. Only ``None`` (the default preset)
            is supported.

    Returns:
        A newly constructed ``CfgNode`` populated with the default settings.

    Raises:
        NotImplementedError: If ``key`` is anything other than ``None``. This
            includes falsy values such as ``''`` or ``0``, because the check is
            an identity test.
    """
    if key is not None:
        # No named presets exist; reject before building anything.
        raise NotImplementedError

    cfg = CfgNode()
    cfg.agg_type = 'concat'

    # Attention Aggregator
    cfg.attention_net = CfgNode()
    attn = cfg.attention_net
    attn.num_layers = 6
    attn.units = 64
    attn.num_heads = 4
    # Size of the FFN network used in attention.
    # NOTE(review): -1 looks like a "derive automatically" sentinel; confirm with the consumer.
    attn.hidden_size = -1
    attn.activation = 'gelu'  # Activation of the attention

    # Other parameters
    cfg.mid_units = 128
    cfg.feature_proj_num_layers = -1
    cfg.out_proj_num_layers = 1
    cfg.data_dropout = False
    cfg.dropout = 0.1
    cfg.activation = 'leaky'
    cfg.normalization = 'layer_norm'
    cfg.norm_eps = 1e-5

    # Initializer specs; these lists are interpreted by the consuming network code.
    cfg.initializer = CfgNode()
    init = cfg.initializer
    init.weight = ['xavier', 'uniform', 'avg', 3.0]
    init.bias = ['zeros']
    return cfg
def get_cfg(key=None):
    """Build the default configuration node (input centering + MLP-style settings).

    Args:
        key: Name of a preset configuration. Only ``None`` (the default preset)
            is supported.

    Returns:
        A newly constructed ``CfgNode`` populated with the default settings.

    Raises:
        NotImplementedError: If ``key`` is anything other than ``None``. This
            includes falsy values such as ``''`` or ``0``.
    """
    if key is not None:
        # Only the default preset exists.
        raise NotImplementedError

    cfg = CfgNode()
    cfg.input_centering = False

    # Layer sizing and regularization.
    cfg.mid_units = 128
    cfg.num_layers = 1
    cfg.data_dropout = False
    cfg.dropout = 0.1

    # Activation / normalization choices.
    cfg.activation = 'leaky'
    cfg.normalization = 'layer_norm'
    cfg.norm_eps = 1e-5

    # Initializer specs; these lists are interpreted by the consuming network code.
    cfg.initializer = CfgNode()
    init = cfg.initializer
    init.weight = ['xavier', 'uniform', 'avg', 3.0]
    init.bias = ['zeros']
    return cfg
def get_cfg(key=None):
    """Build the default configuration node (concat aggregator, tanh activation).

    Args:
        key: Name of a preset configuration. Only ``None`` (the default preset)
            is supported.

    Returns:
        A newly constructed ``CfgNode`` populated with the default settings.

    Raises:
        NotImplementedError: If ``key`` is anything other than ``None``. This
            includes falsy values such as ``''`` or ``0``.
    """
    if key is not None:
        # Only the default preset exists.
        raise NotImplementedError

    cfg = CfgNode()
    cfg.agg_type = 'concat'

    # Projection sizing.
    # NOTE(review): -1 / 0 layer counts may be sentinels; confirm with the consumer.
    cfg.mid_units = 256
    cfg.feature_proj_num_layers = -1
    cfg.out_proj_num_layers = 0

    # Regularization, activation and normalization.
    cfg.data_dropout = False
    cfg.dropout = 0.1
    cfg.activation = 'tanh'
    cfg.normalization = 'layer_norm'
    cfg.norm_eps = 1e-5

    # Initializer specs; these lists are interpreted by the consuming network code.
    cfg.initializer = CfgNode()
    init = cfg.initializer
    init.weight = ['xavier', 'uniform', 'avg', 3.0]
    init.bias = ['zeros']
    return cfg
def get_cfg(key=None):
    """Build the default configuration node (embedding + MLP-style settings).

    Args:
        key: Name of a preset configuration. Only ``None`` (the default preset)
            is supported.

    Returns:
        A newly constructed ``CfgNode`` populated with the default settings.

    Raises:
        NotImplementedError: If ``key`` is anything other than ``None``. This
            includes falsy values such as ``''`` or ``0``.
    """
    if key is not None:
        # Only the default preset exists.
        raise NotImplementedError

    cfg = CfgNode()

    # Layer sizing.
    cfg.emb_units = 32
    cfg.mid_units = 64
    cfg.num_layers = 1

    # Regularization, activation and normalization.
    cfg.data_dropout = False
    cfg.dropout = 0.1
    cfg.activation = 'leaky'
    cfg.normalization = 'layer_norm'
    cfg.norm_eps = 1e-5

    # Initializer specs (embedding, weight, bias); these lists are
    # interpreted by the consuming network code.
    cfg.initializer = CfgNode()
    init = cfg.initializer
    init.embed = ['xavier', 'gaussian', 'in', 1.0]
    init.weight = ['xavier', 'uniform', 'avg', 3.0]
    init.bias = ['zeros']
    return cfg