    def __init__(self,
                 xlnet_config,
                 input_ids,
                 seg_ids,
                 input_mask,
                 args,
                 mems=None,
                 perm_mask=None,
                 target_mapping=None,
                 inp_q=None):
        self._tie_weight = True

        self._d_head = xlnet_config['d_head']
        self._d_inner = xlnet_config['d_inner']
        self._d_model = xlnet_config['d_model']
        self._ff_activation = xlnet_config['ff_activation']
        self._n_head = xlnet_config['n_head']
        self._n_layer = xlnet_config['n_layer']
        self._n_token = xlnet_config['n_token']
        self._untie_r = xlnet_config['untie_r']

        self._mem_len = getattr(args, 'mem_len', None)
        self._reuse_len = getattr(args, 'reuse_len', None)
        self._bi_data = getattr(args, 'bi_data', False)
        self._clamp_len = args.clamp_len
        self._same_length = getattr(args, 'same_length', False)
        # Initialize all weights with the specified initializer; biases are
        # initialized to zero by default.
        self._param_initializer = _get_initializer(args)

        tfm_args = dict(n_token=self._n_token,
                        initializer=self._param_initializer,
                        attn_type="bi",
                        n_layer=self._n_layer,
                        d_model=self._d_model,
                        n_head=self._n_head,
                        d_head=self._d_head,
                        d_inner=self._d_inner,
                        ff_activation=self._ff_activation,
                        untie_r=self._untie_r,
                        use_bfloat16=args.use_fp16,  # fp16 flag drives transformer_xl's bfloat16 switch
                        dropout=args.dropout,
                        dropatt=args.dropatt,
                        mem_len=self._mem_len,
                        reuse_len=self._reuse_len,
                        bi_data=self._bi_data,
                        clamp_len=args.clamp_len,
                        same_length=self._same_length,
                        name='model_transformer')
        input_args = dict(inp_k=input_ids,
                          seg_id=seg_ids,
                          input_mask=input_mask,
                          mems=mems,
                          perm_mask=perm_mask,
                          target_mapping=target_mapping,
                          inp_q=inp_q)
        tfm_args.update(input_args)
        self.output, self.new_mems, self.lookup_table = modeling.transformer_xl(
            **tfm_args)
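
# A minimal construction sketch for the variant above, assuming TF 1.x graph
# mode. The enclosing class name (XLNetModel here) and the exact flag set are
# assumptions for illustration; only the constructor signature and the
# [len, bsz] tensor layout come from the snippet itself.
import argparse
import tensorflow as tf

xlnet_config = {  # hypothetical values; the constructor reads these by key
    'd_head': 64, 'd_inner': 3072, 'd_model': 768, 'ff_activation': 'gelu',
    'n_head': 12, 'n_layer': 12, 'n_token': 32000, 'untie_r': True}
args = argparse.Namespace(  # hypothetical flags; the optional ones may be omitted
    clamp_len=-1, use_fp16=False, dropout=0.1, dropatt=0.1,
    init='normal', init_range=0.1, init_std=0.02)

seq_len, bsz = 128, 4
input_ids = tf.placeholder(tf.int32, [seq_len, bsz])
seg_ids = tf.placeholder(tf.int32, [seq_len, bsz])
input_mask = tf.placeholder(tf.float32, [seq_len, bsz])  # 1 marks padding

model = XLNetModel(xlnet_config, input_ids, seg_ids, input_mask, args)
hidden = model.output  # [seq_len, bsz, d_model] hidden states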
# Example 2
    def __init__(
        self,
        xlnet_config,
        run_config,
        input_ids,
        seg_ids,
        input_mask,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        inp_q=None,
        **kw,
    ):

        initializer = _get_initializer(run_config)

        tfm_args = dict(
            n_token=xlnet_config.n_token,
            initializer=initializer,
            attn_type="bi",
            n_lays=xlnet_config.n_lays,
            d_model=xlnet_config.d_model,
            n_heads=xlnet_config.n_heads,
            d_head=xlnet_config.d_head,
            d_inner=xlnet_config.d_inner,
            ff_activation=xlnet_config.ff_activation,
            untie_r=xlnet_config.untie_r,
            is_training=run_config.is_training,
            use_bfloat16=run_config.use_bfloat16,
            use_tpu=run_config.use_tpu,
            drop=run_config.drop,
            dropatt=run_config.dropatt,
            mem_len=run_config.mem_len,
            reuse_len=run_config.reuse_len,
            bi_data=run_config.bi_data,
            clamp_len=run_config.clamp_len,
            same_length=run_config.same_length,
        )

        input_args = dict(
            inp_k=input_ids,
            seg_id=seg_ids,
            input_mask=input_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            inp_q=inp_q,
        )
        tfm_args.update(input_args)

        with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
            (self.output, self.new_mems, self.lookup_table) = modeling.transformer_xl(**tfm_args)

        self.input_mask = input_mask
        self.initializer = initializer
        self.xlnet_config = xlnet_config
        self.run_config = run_config
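
# The variant above wraps construction in tf.variable_scope("model",
# reuse=tf.AUTO_REUSE), so instantiating the class twice resolves to one
# shared set of weights. A minimal sketch, assuming the enclosing class is
# named XLNetModel (the name is not shown in the snippet) and that the two
# RunConfigs differ only in is_training:
#
#     train_model = XLNetModel(xlnet_config, train_config,
#                              train_ids, train_seg, train_mask)
#     eval_model = XLNetModel(xlnet_config, eval_config,
#                             eval_ids, eval_seg, eval_mask)
#     # Both builds read the same variables under the "model" scope.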
# Example 3
    def __init__(self,
                 xlnet_config,
                 run_config,
                 input_ids,
                 seg_ids,
                 input_mask,
                 mems=None,
                 perm_mask=None,
                 target_mapping=None,
                 inp_q=None,
                 **kwargs):
        """
		Args:
		  xlnet_config: XLNetConfig,
		  run_config: RunConfig,
		  input_ids: int32 Tensor in shape [len, bsz], the input token IDs.
		  seg_ids: int32 Tensor in shape [len, bsz], the input segment IDs.
		  input_mask: float32 Tensor in shape [len, bsz], the input mask.
			0 for real tokens and 1 for padding.
		  mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
			from previous batches. The length of the list equals n_layer.
			If None, no memory is used.
		  perm_mask: float32 Tensor in shape [len, len, bsz].
			If perm_mask[i, j, k] = 0, i attend to j in batch k;
			if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
			If None, each position attends to all the others.
		  target_mapping: float32 Tensor in shape [num_predict, len, bsz].
			If target_mapping[i, j, k] = 1, the i-th predict in batch k is
			on the j-th token.
			Only used during pretraining for partial prediction.
			Set to None during finetuning.
		  inp_q: float32 Tensor in shape [len, bsz].
			1 for tokens with losses and 0 for tokens without losses.
			Only used during pretraining for two-stream attention.
			Set to None during finetuning.
		"""

        initializer = _get_initializer(run_config)

        tfm_args = dict(n_token=xlnet_config.n_token,
                        initializer=initializer,
                        attn_type="bi",
                        n_layer=xlnet_config.n_layer,
                        d_model=xlnet_config.d_model,
                        n_head=xlnet_config.n_head,
                        d_head=xlnet_config.d_head,
                        d_inner=xlnet_config.d_inner,
                        ff_activation=xlnet_config.ff_activation,
                        untie_r=xlnet_config.untie_r,
                        is_training=run_config.is_training,
                        use_bfloat16=run_config.use_bfloat16,
                        use_tpu=run_config.use_tpu,
                        dropout=run_config.dropout,
                        dropatt=run_config.dropatt,
                        mem_len=run_config.mem_len,
                        reuse_len=run_config.reuse_len,
                        bi_data=run_config.bi_data,
                        clamp_len=run_config.clamp_len,
                        same_length=run_config.same_length)

        input_args = dict(inp_k=input_ids,
                          seg_id=seg_ids,
                          input_mask=input_mask,
                          mems=mems,
                          perm_mask=perm_mask,
                          target_mapping=target_mapping,
                          inp_q=inp_q)
        tfm_args.update(input_args)

        with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
            (self.output, self.new_mems,
             self.lookup_table) = modeling.transformer_xl(**tfm_args)

        self.input_mask = input_mask
        self.initializer = initializer
        self.xlnet_config = xlnet_config
        self.run_config = run_config
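
# A runnable numpy sketch of the pretraining-only inputs documented above.
# Shapes follow the docstring ([len, bsz], [len, len, bsz],
# [num_predict, len, bsz]); the concrete values are illustrative assumptions.
import numpy as np

seq_len, bsz, num_predict = 6, 1, 2

# input_mask: 0 for real tokens, 1 for padding.
input_mask = np.zeros([seq_len, bsz], dtype=np.float32)
input_mask[5, 0] = 1.0  # treat the last position as padding

# perm_mask[i, j, k] = 1 blocks position i from attending to j in batch k;
# here every position is blocked from attending to token 0.
perm_mask = np.zeros([seq_len, seq_len, bsz], dtype=np.float32)
perm_mask[:, 0, 0] = 1.0

# target_mapping[i, j, k] = 1 places the i-th prediction of batch k on
# token j; predict tokens 0 and 3.
target_mapping = np.zeros([num_predict, seq_len, bsz], dtype=np.float32)
target_mapping[0, 0, 0] = 1.0
target_mapping[1, 3, 0] = 1.0

# inp_q: 1 for tokens that receive a loss, matching target_mapping.
inp_q = np.zeros([seq_len, bsz], dtype=np.float32)
inp_q[0, 0] = inp_q[3, 0] = 1.0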