Example #1
    def model_graph(self, itemseq_input, train=True):
        model_para = self.model_para
        context_seq = itemseq_input[:, 0:-1]  # inputs: every item but the last
        label_seq = itemseq_input[:, 1:]      # targets: every item but the first

        self.context_embedding = tf.nn.embedding_lookup(
            self.allitem_embeddings, context_seq, name="context_embedding")
        dilate_input = self.context_embedding
        # Stack of dilated causal convolution blocks; the dilation schedule
        # comes from model_para['dilations'].
        for layer_id, dilation in enumerate(model_para['dilations']):
            dilate_input = ops.nextitnet_residual_block(
                dilate_input, dilation, layer_id,
                model_para['dilated_channels'],
                model_para['kernel_size'], causal=True, train=train)
        return label_seq, dilate_input
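
The two slices at the top implement the standard next-item training shift: the model reads items 1..n-1 of a sequence and is trained to predict items 2..n, so position t of context_seq is paired with position t of label_seq. A minimal illustration of the shift on a toy batch (NumPy stands in for the TensorFlow tensors):

    import numpy as np

    # Toy batch: two sessions of four item ids each.
    itemseq_input = np.array([[5, 9, 2, 7],
                              [3, 3, 8, 1]])

    context_seq = itemseq_input[:, 0:-1]  # [[5, 9, 2], [3, 3, 8]]
    label_seq = itemseq_input[:, 1:]      # [[9, 2, 7], [3, 8, 1]]

    # At step t the model sees context_seq[:, t] and must predict label_seq[:, t].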
Example #2

    def model_graph(self, itemseq_input, train=True):
        model_para = self.model_para
        context_seq = itemseq_input[:, 0:-1]  # all rows, every column but the last
        label_seq = itemseq_input[:, 1:]  # all rows, every column but the first

        self.context_embedding = tf.nn.embedding_lookup(
            self.allitem_embeddings, context_seq, name="context_embedding")
        # Row lookup: for each id in context_seq, fetch the matching row of
        # allitem_embeddings. E.g. ids [1, 3, 5] return the vectors at indices
        # 1, 3 and 5, stacked into a matrix.
        dilate_input = self.context_embedding
        for layer_id, dilation in enumerate(model_para['dilations']):
            dilate_input = ops.nextitnet_residual_block(
                dilate_input,
                dilation,
                layer_id,
                model_para['dilated_channels'],
                model_para['kernel_size'],
                causal=True,
                train=train)
        return label_seq, dilate_input
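
As the comment above notes, tf.nn.embedding_lookup is just a row gather over the embedding matrix. A self-contained sketch (the 4-item, 3-dimensional embedding matrix is made up for the demo):

    import tensorflow as tf

    # Hypothetical catalogue of 4 items with 3-dimensional embeddings.
    allitem_embeddings = tf.constant([[0.0, 0.0, 0.0],
                                      [0.1, 0.1, 0.1],
                                      [0.2, 0.2, 0.2],
                                      [0.3, 0.3, 0.3]])

    ids = tf.constant([[1, 3]])  # one batch row holding item ids 1 and 3
    vecs = tf.nn.embedding_lookup(allitem_embeddings, ids)
    # vecs has shape [1, 2, 3]: rows 1 and 3 of the matrix, one per id.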
Example #3

    def model_graph(self, itemseq_input, train):
        model_para = self.model_para

        # This variant embeds the full sequence; there is no context/label shift.
        self.context_embedding = tf.nn.embedding_lookup(
            self.allitem_embeddings, itemseq_input, name="context_embedding")

        # Positional embedding (optional): learned position vectors are
        # concatenated onto the item embeddings along the channel axis.
        if self.model_para['has_positionalembedding']:
            # [batch, seq_len] matrix whose rows are the positions 0..seq_len-1.
            positions = tf.tile(
                tf.expand_dims(tf.range(tf.shape(itemseq_input)[1]), 0),
                [tf.shape(itemseq_input)[0], 1])
            pos_emb = self.embedding(positions,
                                     max_position=model_para['max_position'],
                                     num_units=self.embedding_width,
                                     zero_pad=False,
                                     scale=False,
                                     l2_reg=0.0,
                                     scope="dec_pos",
                                     with_t=False)
            dilate_input = tf.concat([self.context_embedding, pos_emb], -1)
        else:
            dilate_input = self.context_embedding

        residual_channels = dilate_input.get_shape().as_list()[-1]  # channel count after the optional concat

        for layer_id, dilation in enumerate(model_para['dilations']):
            # Unlike the two variants above, this one passes causal=False, so
            # the convolutions are non-causal and each position can use both
            # earlier and later items.
            dilate_input = ops.nextitnet_residual_block(
                dilate_input,
                dilation,
                layer_id,
                residual_channels,
                model_para['kernel_size'],
                causal=False,
                train=train)

        return dilate_input
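
ops.nextitnet_residual_block is defined elsewhere in the repository. The sketch below is a hedged reconstruction of the block described in the NextItNet paper (layer norm, ReLU, dilated 1-D convolution, applied twice, with a skip connection), not the repo's actual implementation; the function name and exact layer order are assumptions. causal=True is emulated here by left-padding the time axis before a valid-padded convolution, while causal=False (as in Example #3) uses same-padding so each position sees both directions.

    import tensorflow as tf

    def residual_block_sketch(x, dilation, channels, kernel_size, causal=True):
        """Rough sketch of a NextItNet-style dilated residual block.

        x: float tensor of shape [batch, seq_len, channels]. Layers are
        created fresh on every call, so this is for illustration only.
        """
        def dilated_conv(inp, rate):
            if causal:
                # Left-pad so each step sees only itself and earlier steps.
                pad = (kernel_size - 1) * rate
                inp = tf.pad(inp, [[0, 0], [pad, 0], [0, 0]])
                conv = tf.keras.layers.Conv1D(channels, kernel_size,
                                              dilation_rate=rate, padding="valid")
            else:
                conv = tf.keras.layers.Conv1D(channels, kernel_size,
                                              dilation_rate=rate, padding="same")
            return conv(inp)

        h = tf.keras.layers.LayerNormalization()(x)
        h = tf.nn.relu(h)
        h = dilated_conv(h, dilation)      # first conv at the given dilation
        h = tf.keras.layers.LayerNormalization()(h)
        h = tf.nn.relu(h)
        h = dilated_conv(h, 2 * dilation)  # second conv at twice the dilation
        return x + h                       # residual (skip) connection

Because the block preserves the input's shape, the loop can feed dilate_input straight back in; doubling the dilation on the second convolution mirrors a common NextItNet configuration, but other schedules are possible.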