Ejemplo n.º 1
0
def dis_decoder(hparmas, parser, encoder_outputs, states, sequence, n_output):
    def cell_lstm(input_l, state_l):
        decoder_lstm1 = LSTM(hparmas.gen_rnn_size,
                             return_sequences=True,
                             return_state=True,
                             dropout=0.5,
                             recurrent_dropout=0.5)(input_l,
                                                    initial_state=state_l)
        decoder_outputs, state_h, state_c = LSTM(
            hparmas.gen_rnn_size,
            return_state=True,
            dropout=0.5,
            recurrent_dropout=0.5)(decoder_lstm1)
        state = [state_h, state_c]
        return decoder_outputs, state

        # print(initial_state[0])

    input = Lambda(spl, arguments={'index': 0})(sequence)
    rnn_out, state_gen = cell_lstm(input, states)
    # initial_state[0]=Lambda(reshap)(initial_state[0])
    (attention_keys, attention_values) = Lambda(pre_attein)(encoder_outputs)
    attention_score_fn = attention_utils._create_attention_score_fn(
        "attention_keys", hparmas.gen_rnn_size, "luong")
    # Attention construction function
    attention_construct_fn = attention_utils._create_attention_construct_fn(
        "attention_score", hparmas.gen_rnn_size, attention_score_fn)
    attention_option = parser.attention_option

    def atten(outputs):
        return attention_construct_fn(outputs, attention_keys,
                                      attention_values)

    rnn_out = Lambda(atten)(rnn_out)
    # out=rnn_out
    decoder_dense = Dense(n_output, activation='softmax')
    decoder_outputs = decoder_dense(rnn_out)
    # decoder_outputs = Reshape((-1,1,n_output))(decoder_outputs)
    decoder_outputs = Lambda(ex_pen)(decoder_outputs)
    out = decoder_outputs
    # out = Reshape((-1,n_units))(rnn_out)
    for i in range(1, n_steps_out):
        input_s = Lambda(spl, arguments={'index': i})(sequence)
        rnn_out, state_g = cell_lstm(input_s, state_gen)
        rnn_out = Lambda(atten)(rnn_out)
        state_gen = state_g
        # rnn_out = Reshape((-1,n_units))(rnn_out)
        # print(rnn_out)
        decoder_outputs = decoder_dense(rnn_out)
        # decoder_outputs = Reshape((-1,1,n_output))(decoder_outputs)
        decoder_outputs = Lambda(ex_pen)(decoder_outputs)
        # decoder_outputs=K.expand_dims(decoder_outputs,1)
        out = Concatenate(1)([out, decoder_outputs])

        # pri:nt(decoder_outputs)
        # decoder_outputs=AttentionDecoder(n_units,n_features,decoder_outputs.shape)(decoder_outputs,initial_state[0])
    return out, state_gen
Ejemplo n.º 2
0
    def gen_decoder(self, input, encoder_outputs, states, n_output):
        # def init_f(x):
        #     init_z = initializers.Zeros()
        #     return
        #
        # # input=Lambda(init_f)(self.latent_dim)
        # init_z = initializers.Zeros()
        # inpt=init_z(shape=(1, 1, self.latent_dim))
        # inputg=Input(tensor=inpt)
        def cell_lstm(input_l, state_l):
            decoder_lstm1 = LSTM(self.gen_size,
                                 return_sequences=True,
                                 return_state=True,
                                 dropout=0.5,
                                 recurrent_dropout=0.5)(input_l,
                                                        initial_state=state_l)
            decoder_output, state_h, state_c = LSTM(
                self.gen_size,
                return_state=True,
                dropout=0.5,
                recurrent_dropout=0.5)(decoder_lstm1)
            # decoder_dense = Dense(n_output, activation='softmax')
            # decoder_outputs = decoder_dense(decoder_output)
            state = [state_h, state_c]
            return decoder_output, state

            # print(initial_state[0])

        rnn_out, state_gen = cell_lstm(input, states)
        # initial_state[0]=Lambda(reshap)(initial_state[0])
        (attention_keys,
         attention_values) = Lambda(pre_attein)(encoder_outputs)
        attention_score_fn = attention_utils._create_attention_score_fn(
            "attention_keys", self.gen_size, "luong")
        # Attention construction function
        attention_construct_fn = attention_utils._create_attention_construct_fn(
            "attention_score", self.gen_size, attention_score_fn)

        def categori(x):
            logit = x
            categorical = tf.distributions.Categorical(logits=logit)
            print(logit.shape)
            fake = categorical.sample(8)
            log_prob = categorical.log_prob(fake)
            print("wosjofake", fake)
            return ([fake, log_prob])

        def atten(outputs):
            return attention_construct_fn(outputs, attention_keys,
                                          attention_values)

        rnn_out = Lambda(atten)(rnn_out)
        if args.gen_train_strategy != 'cross_entory':
            # output, log_prob = Lambda(categori)(rnn_out)
            # out=rnn_out
            decoder_dense = Dense(n_output, activation='sigmoid')
            decoder_outputs = decoder_dense(rnn_out)
            # decoder_outputs = Reshape((-1,1,n_output))(decoder_outputs)
            # decoder_outputs = Lambda(ex_pen)(decoder_outputs)
            # out = decoder_outputs
            log_prob = Lambda(get_prob)(decoder_outputs)
            decoder_outputs = Lambda(ex_pen)(decoder_outputs)
            print("woshideout", decoder_outputs)
            out = decoder_outputs
            log_pro = Lambda(ex_pen)(log_prob)
            # out = Reshape((-1,n_units))(rnn_out)
            for i in range(0, self.n_steps_out - 1):
                rnn_out, state_g = cell_lstm(decoder_outputs, state_gen)
                print("woshiout", rnn_out)
                # print(rnn_out)
                rnn_out = Lambda(atten)(rnn_out)
                # output, log_prob = Lambda(categori)(rnn_out)
                state_gen = state_g
                decoder_dense = Dense(n_output, activation='sigmoid')
                decoder_outputs = decoder_dense(rnn_out)
                # rnn_out = Reshape((-1,n_units))(rnn_out)
                # print(rnn_out)
                log_prob = Lambda(get_prob)(decoder_outputs)
                decoder_outputs = Lambda(ex_pen)(decoder_outputs)
                # decoder_outputs = decoder_dense(rnn_out)
                # # decoder_outputs = Reshape((-1,1,n_output))(decoder_outputs)
                # decoder_outputs = Lambda(ex_pen)(decoder_outputs)
                # decoder_outputs=K.expand_dims(decoder_outputs,1)

                log_prob = Lambda(ex_pen)(log_prob)
                # decoder_outputs=K.expand_dims(decoder_outputs,1)
                out = Concatenate(1)([out, decoder_outputs])
                log_pro = Concatenate(1)([log_pro, log_prob])
            return out, log_pro
        else:
            # ?
            # decoder_outputs = Reshape((-1,1,n_output))(decoder_outputs)
            decoder_outputs = Lambda(ex_pen)(rnn_out)
            out = decoder_outputs
            # out = Reshape((-1,n_units))(rnn_out)
            for i in range(0, self.n_steps_out - 1):
                rnn_out, state_g = cell_lstm(decoder_outputs, state_gen)
                print("woshiout", rnn_out)
                rnn_out = Lambda(atten)(rnn_out)
                state_gen = state_g
                # rnn_out = Reshape((-1,n_units))(rnn_out)
                # print(rnn_out)
                # decoder_outputs = decoder_dense(rnn_out)
                # decoder_outputs = Reshape((-1,1,n_output))(decoder_outputs)
                decoder_outputs = Lambda(ex_pen)(rnn_out)
                # decoder_outputs=K.expand_dims(decoder_outputs,1)
                out = Concatenate(1)([out, decoder_outputs])

            # def init_f(ne):
            #
            #     init_z = initializers.Zeros()
            #     y=init_z(self.batch_size,self.n_steps_out,ne)
            #     return y

        def get_zero():
            d = np.zeros((self.batch_size, self.n_steps_out, self.n_output),
                         dtype="float32")
            return d

        log_pro = Lambda(get_zero)
        # print("woshiout",out)
        return out, log_pro