def __init__(
    self, neox_args, init_method, output_layer_init_method, parallel_output=False
):
    super().__init__()

    self.activation_func = get_activation(neox_args)
    self.activation_type = neox_args.activation
    self.bias_gelu_fusion = neox_args.bias_gelu_fusion

    # Auto-scale the hidden multiplier so geglu has roughly the same parameter
    # count as a standard 4x MLP.
    ff_mult = 4 * 2 / 3 if self.activation_type == "geglu" else 4
    ff_dim = (
        int(ff_mult * neox_args.hidden_size) * 2
        if self.activation_type == "geglu"
        else ff_mult * neox_args.hidden_size
    )
    # Project h -> ff_dim, sharded across model-parallel ranks
    # (gather_output=False keeps the output split).
    self.dense_h_to_4h = mpu.ColumnParallelLinear(
        neox_args=neox_args,
        input_size=neox_args.hidden_size,
        output_size=ff_dim,
        gather_output=False,
        init_method=init_method,
        skip_bias_add=True,
    )
    # geglu gates away half of the features, so the return projection takes
    # ff_dim // 2 in that case.
    ff_dim_in = ff_dim // 2 if self.activation_type == "geglu" else ff_dim
    # Project back to h.
    self.dense_4h_to_h = mpu.RowParallelLinear(
        neox_args=neox_args,
        input_size=ff_dim_in,
        output_size=neox_args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True,
        parallel_output=parallel_output,
    )
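
# A minimal sketch of how these two projections are typically combined in the
# forward pass (column-parallel up-projection -> activation -> row-parallel
# down-projection). This is an illustrative assumption, not the module's actual
# forward method; in particular the fused bias-gelu path and the geglu bias
# handling in the real code may differ.
def mlp_forward_sketch(mlp, hidden_states):
    # [s, b, h] -> [s, b, ff_dim / mp_world_size]; bias is returned separately
    # because skip_bias_add=True.
    intermediate, bias = mlp.dense_h_to_4h(hidden_states)
    # Apply the configured activation (gelu, geglu, ...). geglu consumes the
    # doubled projection and halves the feature dim, matching ff_dim_in above.
    intermediate = mlp.activation_func(intermediate + bias)
    # [s, b, ff_dim_in / mp_world_size] -> [s, b, h], reduced across ranks.
    output, output_bias = mlp.dense_4h_to_h(intermediate)
    return output, output_bias
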
def __init__(
    self,
    neox_args,
    init_method,
    output_layer_init_method,
    layer_number,
    ff_mult=4,
    mask_fn=None,
):
    super().__init__()
    self.layer_number = layer_number

    ff_dim = neox_args.hidden_size * ff_mult
    norm, eps = get_norm(neox_args)
    self.norm = norm(neox_args.hidden_size, eps=eps)
    # Project h -> 2 * ff_dim, sharded across model-parallel ranks; the gating
    # unit later splits the doubled features into two halves.
    self.input_linear = mpu.ColumnParallelLinear(
        neox_args=neox_args,
        input_size=neox_args.hidden_size,
        output_size=ff_dim * 2,
        gather_output=False,
        init_method=init_method,
        skip_bias_add=True,
    )
    self.activation_func = get_activation(neox_args)

    # The spatial gating unit operates on each rank's local shard of ff_dim.
    ff_dim_parallel = mpu.divide(ff_dim, mpu.get_model_parallel_world_size())
    # "amlp" layers add a tiny attention branch of width gmlp_attn_dim to the
    # gating unit; plain gMLP layers pass d_attn=None.
    if neox_args.attention_config[layer_number] == "amlp":
        d_attn = neox_args.gmlp_attn_dim
    else:
        d_attn = None
    self.sgu = SpatialGatingUnit(
        neox_args, ff_dim_parallel, d_attn, causal=True, mask_fn=mask_fn
    )
    # Project back to h.
    self.output_linear = mpu.RowParallelLinear(
        neox_args=neox_args,
        input_size=ff_dim,
        output_size=neox_args.hidden_size,
        input_is_parallel=True,
        init_method=output_layer_init_method,
        skip_bias_add=True,
    )
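
# A minimal sketch of the block applied to a sequence, assuming the standard
# gMLP ordering (norm -> channel projection -> activation -> spatial gating ->
# output projection). Illustrative only: the real forward signature, residual,
# bias, and attention-mask handling may differ.
def gmlp_block_forward_sketch(block, x):
    # x: [s, b, h]
    x = block.norm(x)
    # h -> 2 * ff_dim (local shard); bias returned separately (skip_bias_add=True).
    x, bias = block.input_linear(x)
    x = block.activation_func(x + bias)
    # The gating unit splits the doubled features into content/gate halves and
    # mixes information along the sequence dimension.
    x = block.sgu(x)
    # Back to hidden_size, reduced across model-parallel ranks.
    x, bias = block.output_linear(x)
    return x + bias
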