def set_multi_head_attention_v2(spec, variables, scope, self_attention=False, relative=False):
    set_layer_norm(spec.layer_norm, variables, "%s/input_layer_norm" % scope)
    if self_attention:
        # Fuse the separate query/key/value projections into a single linear layer.
        split_layers = [common_spec.LinearSpec() for _ in range(3)]
        set_linear(split_layers[0], variables, "%s/layer/linear_queries" % scope)
        set_linear(split_layers[1], variables, "%s/layer/linear_keys" % scope)
        set_linear(split_layers[2], variables, "%s/layer/linear_values" % scope)
        utils.fuse_linear(spec.linear[0], split_layers)
        if relative:
            spec.relative_position_keys = variables[
                "%s/layer/relative_position_keys" % scope]
            spec.relative_position_values = variables[
                "%s/layer/relative_position_values" % scope]
    else:
        # In cross attention the queries come from a different input, so only
        # the key/value projections can be fused.
        set_linear(spec.linear[0], variables, "%s/layer/linear_queries" % scope)
        split_layers = [common_spec.LinearSpec() for _ in range(2)]
        set_linear(split_layers[0], variables, "%s/layer/linear_keys" % scope)
        set_linear(split_layers[1], variables, "%s/layer/linear_values" % scope)
        utils.fuse_linear(spec.linear[1], split_layers)
    set_linear(spec.linear[-1], variables, "%s/layer/linear_output" % scope)
def set_multi_head_attention(spec, variables, scope, self_attention=False):
    if self_attention:
        split_layers = [common_spec.LinearSpec() for _ in range(3)]
        set_linear(split_layers[0], variables, "%s.linear_query" % scope)
        set_linear(split_layers[1], variables, "%s.linear_keys" % scope)
        set_linear(split_layers[2], variables, "%s.linear_values" % scope)
        utils.fuse_linear(spec.linear[0], split_layers)
    else:
        set_linear(spec.linear[0], variables, "%s.linear_query" % scope)
        split_layers = [common_spec.LinearSpec() for _ in range(2)]
        set_linear(split_layers[0], variables, "%s.linear_keys" % scope)
        set_linear(split_layers[1], variables, "%s.linear_values" % scope)
        utils.fuse_linear(spec.linear[1], split_layers)
    set_linear(spec.linear[-1], variables, "%s.final_linear" % scope)
def set_multi_head_attention(spec, module, self_attention=False):
    if self_attention:
        split_layers = [common_spec.LinearSpec() for _ in range(3)]
        set_linear(split_layers[0], module.q_proj)
        set_linear(split_layers[1], module.k_proj)
        set_linear(split_layers[2], module.v_proj)
        utils.fuse_linear(spec.linear[0], split_layers)
    else:
        set_linear(spec.linear[0], module.q_proj)
        split_layers = [common_spec.LinearSpec() for _ in range(2)]
        set_linear(split_layers[0], module.k_proj)
        set_linear(split_layers[1], module.v_proj)
        utils.fuse_linear(spec.linear[1], split_layers)
    set_linear(spec.linear[-1], module.out_proj)
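# All of the converters above delegate the actual weight merging to
# utils.fuse_linear. The following is a minimal sketch of what such a helper
# could look like, assuming each LinearSpec stores its parameters as NumPy
# arrays named `weight` (output_dim x input_dim) and `bias` (output_dim,);
# these attribute names are an assumption for illustration, not a confirmed API.
import numpy as np

def fuse_linear(spec, layers):
    # Concatenating along the output dimension turns N projections of the
    # same input into a single matrix multiply; the fused output is split
    # back into Q/K/V (or K/V) downstream.
    spec.weight = np.concatenate([layer.weight for layer in layers], axis=0)
    spec.bias = np.concatenate([layer.bias for layer in layers], axis=0)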
def __init__(self, self_attention=False):
    self.layer_norm = common_spec.LayerNormSpec()
    if self_attention:
        # Fused QKV projection + output projection.
        num_projections = 2
    else:
        # Query projection + fused KV projection + output projection.
        num_projections = 3
    self.linear = [common_spec.LinearSpec() for _ in range(num_projections)]
def __init__(self, num_layers, pre_norm=True):
    self.embeddings = common_spec.EmbeddingsSpec()
    self.position_encodings = PositionEncoderSpec()
    # Pre-norm models apply a final layer norm after the last layer;
    # post-norm models do not need one.
    self.layer_norm = (
        common_spec.LayerNormSpec() if pre_norm else model_spec.OPTIONAL)
    self.projection = common_spec.LinearSpec()
    self.layer = [TransformerDecoderLayerSpec() for _ in range(num_layers)]
def set_multi_head_attention(spec, variables, scope, self_attention=False, relative=False):
    if self_attention:
        split_layers = [common_spec.LinearSpec() for _ in range(3)]
        set_linear(split_layers[0], variables, "%s.linear_query" % scope)
        set_linear(split_layers[1], variables, "%s.linear_keys" % scope)
        set_linear(split_layers[2], variables, "%s.linear_values" % scope)
        utils.fuse_linear(spec.linear[0], split_layers)
    else:
        set_linear(spec.linear[0], variables, "%s.linear_query" % scope)
        split_layers = [common_spec.LinearSpec() for _ in range(2)]
        set_linear(split_layers[0], variables, "%s.linear_keys" % scope)
        set_linear(split_layers[1], variables, "%s.linear_values" % scope)
        utils.fuse_linear(spec.linear[1], split_layers)
    set_linear(spec.linear[-1], variables, "%s.final_linear" % scope)
    if relative:
        # The checkpoint stores a single relative position embedding that is
        # shared between keys and values.
        spec.relative_position_keys = _get_variable(
            variables, "%s.relative_positions_embeddings.weight" % scope)
        spec.relative_position_values = spec.relative_position_keys
def __init__(self, self_attention=False):
    self.layer_norm = common_spec.LayerNormSpec()
    if self_attention:
        num_projections = 2
    else:
        num_projections = 3
    self.linear = [common_spec.LinearSpec() for _ in range(num_projections)]
    # Only filled in for models trained with relative position representations.
    self.relative_position_keys = model_spec.OPTIONAL
    self.relative_position_values = model_spec.OPTIONAL
def __init__(self):
    self.layer_norm = common_spec.LayerNormSpec()
    self.linear_0 = common_spec.LinearSpec()
    self.linear_1 = common_spec.LinearSpec()
def __init__(self, num_layers):
    self.embeddings = common_spec.EmbeddingsSpec()
    self.position_encodings = PositionEncoderSpec()
    self.layer_norm = common_spec.LayerNormSpec()
    self.projection = common_spec.LinearSpec()
    self.layer = [TransformerDecoderLayerSpec() for _ in range(num_layers)]
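# Taken together, a converter first builds the spec tree and then fills it
# from a checkpoint. A hypothetical driver is sketched below, assuming the
# __init__ above belongs to a TransformerDecoderSpec class, that each
# TransformerDecoderLayerSpec exposes a `self_attention` attribute, and that
# `variables` is a dict mapping checkpoint names to arrays; all three names
# and the scope format are illustrative assumptions.
decoder = TransformerDecoderSpec(num_layers=6)
for i, layer in enumerate(decoder.layer):
    set_multi_head_attention(
        layer.self_attention,  # hypothetical attribute name
        variables,
        "decoder.transformer_layers.%d.self_attn" % i,  # assumed scope format
        self_attention=True,
    )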