Exemplo n.º 1
0
class siamese:
    """Siamese RNN that classifies (question, answer) token-id pairs.

    Both sequences are embedded with one shared embedding matrix and encoded
    by the same RNN cell (weights shared through variable-scope reuse). The
    element-wise absolute difference of the two final states is projected to
    2 logits, trained with sparse softmax cross-entropy against integer
    targets.

    The last row of the embedding matrix is reserved as the padding id
    (``self.PAD``); ``transform_inputs`` pads every batch with it.
    """

    def __init__(self,
                 hidden_units,
                 embedding_size=None,
                 vocab_size=None,
                 lr=1e-1,
                 clipping='norm',
                 clip_val=1,
                 cell=LSTMCell,
                 bidirectional=False,
                 trainable_embed=False,
                 embedding_matrix=None):
        """Build the graph, open an InteractiveSession and initialise it.

        Args:
            hidden_units: number of units of the RNN cell (per direction).
            embedding_size: embedding width; used only with
                ``trainable_embed=True``.
            vocab_size: number of embedding rows; used only with
                ``trainable_embed=True``.
            lr: Adam learning rate.
            clipping: 'none', 'value' or 'norm' gradient clipping.
            clip_val: ``[min, max]`` list for 'value' clipping, or a positive
                int/float for 'norm' clipping.
            cell: RNN cell class, instantiated as ``cell(num_units=...)``.
            bidirectional: encode with a forward and a backward cell and
                concatenate their final states.
            trainable_embed: learn a randomly initialised embedding instead
                of using the fixed ``embedding_matrix``.
            embedding_matrix: pre-trained embedding used when
                ``trainable_embed`` is False.

        Note:
            Calls ``tf.reset_default_graph()``, discarding any graph that was
            built before.
        """
        assert clipping in ['none', 'value', 'norm']

        if clipping == 'value':
            assert isinstance(clip_val, list) and len(clip_val) == 2
        if clipping == 'norm':
            assert (isinstance(clip_val, int)
                    or isinstance(clip_val, float)) and clip_val > 0

        tf.reset_default_graph()
        self.sess = tf.InteractiveSession()

        self.vocab_size = vocab_size

        if trainable_embed:
            self.embedding = tf.Variable(tf.truncated_normal(
                [self.vocab_size, embedding_size]),
                                         name='embedding')
        else:
            self.embedding = tf.Variable(embedding_matrix,
                                         trainable=False,
                                         dtype=tf.float32)

        # The last embedding row is reserved for padding.
        self.PAD = self.embedding.get_shape()[0].value - 1

        self.question_ph = tf.placeholder(tf.int32, [None, None])
        self.answer_ph = tf.placeholder(tf.int32, [None, None])

        self.question = tf.nn.embedding_lookup(self.embedding,
                                               self.question_ph)
        self.answer = tf.nn.embedding_lookup(self.embedding, self.answer_ph)

        self.targets_ph = tf.placeholder(tf.int32, [None])

        self.lengths_q = tf.placeholder(tf.int32, [None])
        self.lengths_a = tf.placeholder(tf.int32, [None])

        if bidirectional:
            self.cell_fw = cell(num_units=hidden_units)
            self.cell_bw = cell(num_units=hidden_units)

            with tf.variable_scope('twins') as scope:
                self.outs_q, self.states_q = tf.nn.bidirectional_dynamic_rnn(
                    cell_fw=self.cell_fw,
                    cell_bw=self.cell_bw,
                    sequence_length=self.lengths_q,
                    inputs=self.question,
                    dtype=tf.float32)

                # Share the RNN weights between the two branches.
                scope.reuse_variables()

                # The second twin must read the answer and its own lengths;
                # otherwise both states are identical and the distance is 0.
                self.outs_a, self.states_a = tf.nn.bidirectional_dynamic_rnn(
                    cell_fw=self.cell_fw,
                    cell_bw=self.cell_bw,
                    sequence_length=self.lengths_a,
                    inputs=self.answer,
                    dtype=tf.float32)

                if isinstance(self.states_q[0], LSTMStateTuple):
                    # Merge forward/backward LSTM states component-wise.
                    sq_fw, sq_bw = self.states_q
                    sa_fw, sa_bw = self.states_a

                    self.states_q = LSTMStateTuple(
                        c=tf.concat([sq_fw.c, sq_bw.c], axis=1),
                        h=tf.concat([sq_fw.h, sq_bw.h], axis=1))
                    self.states_a = LSTMStateTuple(
                        c=tf.concat([sa_fw.c, sa_bw.c], axis=1),
                        h=tf.concat([sa_fw.h, sa_bw.h], axis=1))
                else:
                    self.states_q = tf.concat(self.states_q, axis=1)
                    self.states_a = tf.concat(self.states_a, axis=1)

        else:
            self.siamese_cell = cell(num_units=hidden_units)

            with tf.variable_scope('twins') as scope:
                self.outs_q, self.states_q = tf.nn.dynamic_rnn(
                    cell=self.siamese_cell,
                    sequence_length=self.lengths_q,
                    inputs=self.question,
                    dtype=tf.float32)

                # Share the RNN weights between the two branches.
                scope.reuse_variables()

                self.outs_a, self.states_a = tf.nn.dynamic_rnn(
                    cell=self.siamese_cell,
                    sequence_length=self.lengths_a,
                    inputs=self.answer,
                    dtype=tf.float32)

        # Flatten LSTM (c, h) states into one vector per example.
        if isinstance(self.states_q, LSTMStateTuple):
            self.states_q = tf.concat([self.states_q.c, self.states_q.h],
                                      axis=1)
            self.states_a = tf.concat([self.states_a.c, self.states_a.h],
                                      axis=1)

        state_size = self.states_q.get_shape()[1].value

        self.distance_weights = tf.Variable(tf.truncated_normal(
            [state_size, 2], stddev=1),
                                            name='dist_W')
        self.distance_bias = tf.Variable(tf.ones([2]), name='dist_B')

        # 2-class logits computed from |state_q - state_a|.
        self.intermediate = tf.matmul(
            tf.abs(self.states_q - self.states_a),
            self.distance_weights) + self.distance_bias

        self.loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=self.intermediate, labels=self.targets_ph)

        # Inference ops are created once here: building them inside the
        # infer_* methods would add new graph nodes on every call.
        self.mean_loss = tf.reduce_mean(self.loss)
        self.predictions = tf.argmax(self.intermediate, 1)
        self.probs = tf.nn.softmax(self.intermediate)
        # Tensor returned by ``predict``: per-pair class probabilities.
        self.distance = self.probs

        self.opt = tf.train.AdamOptimizer(lr)

        self.grads_vars = self.opt.compute_gradients(self.mean_loss)
        if clipping == 'none':
            self.optimize = self.opt.apply_gradients(self.grads_vars)
        else:

            def _clip(grad):
                """Clip one gradient according to ``clipping``; keep None."""
                if grad is None:
                    return None
                if clipping == 'value':
                    return tf.clip_by_value(grad, clip_val[0], clip_val[1])
                return tf.clip_by_norm(grad, clip_val)

            self.clipped_grads = [(_clip(grad), var)
                                  for grad, var in self.grads_vars]
            self.optimize = self.opt.apply_gradients(self.clipped_grads)

        self.sess.run(tf.global_variables_initializer())

    def transform_inputs(self, x, y, targets=None):
        """Pad a batch and build the feed dict for the graph placeholders.

        Args:
            x: sequence of question token-id sequences (lists, tuples or 1-D
                arrays).
            y: sequence of answer token-id sequences, one per question.
            targets: integer class labels, one per pair. When None (pure
                inference) the targets placeholder is left out of the feed.

        Returns:
            dict mapping placeholders to int32 arrays padded with
            ``self.PAD`` up to the longest sequence on each side, the true
            lengths, and the targets when given.

        Raises:
            ValueError: if ``x`` and ``y`` have different lengths, or the
                batch is empty.
        """
        if len(x) != len(y):
            raise ValueError('x and y must contain the same number of '
                             'sequences')

        lx = [len(seq) for seq in x]
        ly = [len(seq) for seq in y]
        n = len(x)

        # Pre-fill with PAD, then copy each sequence into its row prefix.
        padded_x = np.full((n, max(lx)), self.PAD, dtype=np.int32)
        padded_y = np.full((n, max(ly)), self.PAD, dtype=np.int32)
        for i, (qs, ans) in enumerate(zip(x, y)):
            padded_x[i, :lx[i]] = np.asarray(qs, dtype=np.int32)
            padded_y[i, :ly[i]] = np.asarray(ans, dtype=np.int32)

        feed = {
            self.question_ph: padded_x,
            self.answer_ph: padded_y,
            self.lengths_q: lx,
            self.lengths_a: ly,
        }
        if targets is not None:
            feed[self.targets_ph] = targets
        return feed

    def train(self, questions, answers, targets):
        """Run one optimisation step on a batch."""
        fd = self.transform_inputs(questions, answers, targets)
        self.sess.run(self.optimize, feed_dict=fd)

    def run_op(self, op, fd):
        """Evaluate ``op`` with an explicit feed dict."""
        return self.sess.run(op, feed_dict=fd)

    def reset(self):
        """Re-initialise every global variable (discards training)."""
        self.sess.run(tf.global_variables_initializer())

    def infer_class(self, questions, answers):
        """Return the predicted class (argmax of the logits) for each pair."""
        fd = self.transform_inputs(questions, answers)
        return self.sess.run(self.predictions, feed_dict=fd)

    def infer_loss(self, questions, answers, targets, **kwargs):
        """Return the mean cross-entropy loss over the batch."""
        fd = self.transform_inputs(questions, answers, targets)
        return self.sess.run(self.mean_loss, feed_dict=fd)

    def infer_accuracy(self, questions, answers, targets):
        """Return a boolean array: prediction == target for each pair."""
        fd = self.transform_inputs(questions, answers, targets)
        return self.sess.run(self.predictions, feed_dict=fd) == targets

    def infer_metrics(self, questions, answers, targets):
        """Return ``[predicted_classes, mean_loss]`` in a single run."""
        fd = self.transform_inputs(questions, answers, targets)
        return self.sess.run([self.predictions, self.mean_loss],
                             feed_dict=fd)

    def infer_probs(self, questions, answers, targets):
        """Return the softmax probability assigned to each pair's target."""
        n = len(questions)
        fd = self.transform_inputs(questions, answers, targets)
        probs = self.sess.run(self.probs, feed_dict=fd)
        return probs[np.arange(n), targets]

    def calc_gradients(self, questions, answers, targets):
        """Return the (unclipped) gradient values for every trainable var.

        Sparse gradients (IndexedSlicesValue, produced through
        ``tf.nn.embedding_lookup``) are replaced by their first field.
        """
        fd = self.transform_inputs(questions, answers, targets)

        pre_g = self.sess.run([x[0] for x in self.grads_vars], feed_dict=fd)
        for i, g in enumerate(pre_g):
            if not isinstance(g, np.ndarray):
                # counter-effort against IndexedSlicesValues, appearing
                # alongside nn.embedding_lookup
                pre_g[i] = g[0]
        return pre_g

    def predict(self, questions, answers, **kwargs):
        """Return class probabilities for each (question, answer) pair.

        Sequences are truncated to ``self.max_qL`` / ``self.max_aL`` when
        those attributes have been set on the instance, then padded.
        ``kwargs`` are forwarded to ``Session.run`` (e.g. ``options``).
        """
        max_q = getattr(self, 'max_qL', None)
        max_a = getattr(self, 'max_aL', None)
        questions = [u[:max_q] for u in questions]
        answers = [u[:max_a] for u in answers]

        fd = self.transform_inputs(questions, answers)
        return self.sess.run(self.distance, feed_dict=fd, **kwargs)
Exemplo n.º 2
0
    def build(self, inputs, for_deploy):
        """Build the CVAE encoder-decoder graph and return its endpoints.

        The encoder (``DynRNN``) reads the encoder tokens mapped to ids via a
        hash table. ``PriorNet`` yields the latent ``z``. When training, a
        posterior ``z`` from ``CreateVAE`` (fed the encoder state and an
        encoding of the decoder inputs) replaces it and supplies ``KLD``.
        ``z`` gates and is concatenated to the encoder state; the result
        initialises an attention decoder cell (``AttnCell``).

        Args:
            inputs: dict of input tensors keyed by tensor name; this method
                reads "enc_inps:0", "dec_inps:0" (string tokens, looked up in
                ``in_table``), "enc_lens:0", "dec_lens:0" and "down_wgts:0".
            for_deploy: False builds the training graph (teacher-forced
                decoding plus a loss, beam size 1); True builds a beam-search
                inference graph with ``sum(conf.beam_splits)`` beams.

        Returns:
            dict with keys "loss", "inputs", "outputs" and "visualize"; the
            training graph additionally has "debug_outputs" (argmax tokens
            mapped back to strings).
        """
        # NOTE(review): `scope` is rebound by later `tf.name_scope(...) as
        # scope` statements; the binding from the "ShapeToBeam" name scope is
        # what reaches `decoder.dynamic_decode(scope=...)` in both branches.
        # Confirm this is intended.
        scope = ""
        conf = self.conf
        name = self.name
        job_type = self.job_type
        dtype = self.dtype
        self.beam_splits = conf.beam_splits
        self.beam_size = 1 if not for_deploy else sum(self.beam_splits)

        self.enc_str_inps = inputs["enc_inps:0"]
        self.dec_str_inps = inputs["dec_inps:0"]
        self.enc_lens = inputs["enc_lens:0"]
        self.dec_lens = inputs["dec_lens:0"]
        self.down_wgts = inputs["down_wgts:0"]

        with tf.name_scope("TableLookup"):
            # Input maps
            # string token -> int64 id; unknown tokens map to UNK_ID
            self.in_table = lookup.MutableHashTable(key_dtype=tf.string,
                                                    value_dtype=tf.int64,
                                                    default_value=UNK_ID,
                                                    shared_name="in_table",
                                                    name="in_table",
                                                    checkpoint=True)

            # int64 id -> string token; unknown ids map to "_UNK"
            self.out_table = lookup.MutableHashTable(key_dtype=tf.int64,
                                                     value_dtype=tf.string,
                                                     default_value="_UNK",
                                                     shared_name="out_table",
                                                     name="out_table",
                                                     checkpoint=True)
            # lookup
            self.enc_inps = self.in_table.lookup(self.enc_str_inps)
            self.dec_inps = self.in_table.lookup(self.dec_str_inps)

        graphlg.info("Preparing decoder inps...")
        # NOTE(review): this slice is recomputed identically inside the
        # "Embed" scope below, which shadows it; one of the two is redundant.
        dec_inps = tf.slice(self.dec_inps, [0, 0],
                            [-1, conf.output_max_len + 1])

        # Create encode graph and get attn states
        graphlg.info("Creating embeddings and embedding enc_inps.")
        # A single embedding of size output_vocab_size is shared by encoder
        # and decoder inputs.
        with ops.device("/cpu:0"):
            self.embedding = variable_scope.get_variable(
                "embedding", [conf.output_vocab_size, conf.embedding_size])
        with tf.name_scope("Embed") as scope:
            # Keep at most output_max_len + 1 decoder positions.
            dec_inps = tf.slice(self.dec_inps, [0, 0],
                                [-1, conf.output_max_len + 1])
            with ops.device("/cpu:0"):
                self.emb_inps = embedding_lookup_unique(
                    self.embedding, self.enc_inps)
                emb_dec_inps = embedding_lookup_unique(self.embedding,
                                                       dec_inps)

        graphlg.info("Creating dynamic x rnn...")
        self.enc_outs, self.enc_states, mem_size, enc_state_size = DynRNN(
            conf.cell_model,
            conf.num_units,
            conf.num_layers,
            self.emb_inps,
            self.enc_lens,
            keep_prob=1.0,
            bidi=conf.bidirectional,
            name_scope="DynRNNEncoder")

        batch_size = tf.shape(self.enc_outs)[0]

        # Output projection (w, b) to the vocabulary; its input width is
        # out_layer_size when an extra Dense layer is used, else mem_size.
        with tf.variable_scope("OutProj"):
            graphlg.info("Creating out_proj...")
            if conf.out_layer_size:
                w = tf.get_variable(
                    "proj_w", [conf.out_layer_size, conf.output_vocab_size],
                    dtype=dtype)
            else:
                w = tf.get_variable("proj_w",
                                    [mem_size, conf.output_vocab_size],
                                    dtype=dtype)
            b = tf.get_variable("proj_b", [conf.output_vocab_size],
                                dtype=dtype)
            self.out_proj = (w, b)

        # NOTE(review): the Luong attention mechanism is built only when
        # conf.attention is falsy, which looks inverted -- confirm.
        # Both branches read `.h`, so they assume the last encoder state is
        # an LSTMStateTuple even though that is only checked further below.
        if self.conf.attention:
            init_h = self.enc_states[-1].h
        else:
            mechanism = dynamic_attention_wrapper.LuongAttention(
                num_units=conf.num_units,
                memory=self.enc_outs,
                max_mem_size=self.conf.input_max_len,
                memory_sequence_length=self.enc_lens)
            init_h = mechanism(self.enc_states[-1].h)

        if isinstance(self.enc_states[-1], LSTMStateTuple):
            enc_state = LSTMStateTuple(self.enc_states[-1].c, init_h)
        else:
            enc_state = self.enc_states[-1]

        # Prior-net hidden width: geometric mean of mem_size and latent dim.
        hidden_units = int(math.sqrt(mem_size * self.conf.enc_latent_dim))
        z, mu_prior, logvar_prior = PriorNet([enc_state],
                                             hidden_units,
                                             self.conf.enc_latent_dim,
                                             stddev=1,
                                             prior_type=conf.prior_type)

        # Stays 0.0 in the deploy graph; replaced by CreateVAE when training.
        KLD = 0.0
        # Different graph for training and inference time
        if not for_deploy:
            # Y inputs for posterior z
            with tf.name_scope("YEncode"):
                # Drop the first decoder position (presumably the start
                # token -- TODO confirm) before encoding the target side.
                y_emb_inps = tf.slice(emb_dec_inps, [0, 1, 0], [-1, -1, -1])
                y_enc_outs, y_enc_states, y_mem_size, y_enc_state_size = DynRNN(
                    conf.cell_model,
                    conf.num_units,
                    conf.num_layers,
                    y_emb_inps,
                    self.dec_lens,
                    keep_prob=1.0,
                    bidi=False,
                    name_scope="y_enc")
                y_enc_state = y_enc_states[-1]
                # Posterior z overrides the prior sample; `l2` is unused here.
                z, KLD, l2 = CreateVAE([enc_state, y_enc_state],
                                       self.conf.enc_latent_dim, mu_prior,
                                       logvar_prior)

        # project z + x_thinking_state to decoder state
        # A sigmoid gate computed from z scales the encoder state, and z is
        # appended to it (to both c and h for LSTM states).
        if isinstance(enc_state, LSTMStateTuple):
            h_gate = tf.layers.dense(z,
                                     int(enc_state.h.get_shape()[1]),
                                     use_bias=True,
                                     name="z_gate_h",
                                     activation=tf.sigmoid)
            c_gate = tf.layers.dense(z,
                                     int(enc_state.c.get_shape()[1]),
                                     use_bias=True,
                                     name="z_gate_c",
                                     activation=tf.sigmoid)
            raw_dec_states = [
                LSTMStateTuple(tf.concat([c_gate * enc_state.c, z], 1),
                               tf.concat([h_gate * enc_state.h, z], 1))
            ]
        else:
            gate = tf.layers.dense(z,
                                   int(enc_state.get_shape()[1]),
                                   use_bias=True,
                                   name="z_gate",
                                   activation=tf.sigmoid)
            raw_dec_states = tf.concat([gate * enc_state, z], 1)

        # add BOW loss
        #num_hidden_units = int(math.sqrt(conf.output_vocab_size * int(decision_state.shape[1])))
        #bow_l1 = layers_core.Dense(num_hidden_units, use_bias=True, name="bow_hidden", activation=tf.tanh)
        #bow_l2 = layers_core.Dense(conf.output_vocab_size, use_bias=True, name="bow_out", activation=None)
        #bow = bow_l2(bow_l1(decision_state))

        #y_dec_inps = tf.slice(self.dec_inps, [0, 1], [-1, -1])
        #bow_y = tf.reduce_sum(tf.one_hot(y_dec_inps, on_value=1.0, off_value=0.0, axis=-1, depth=conf.output_vocab_size), axis=1)
        #batch_bow_losses = tf.reduce_sum(bow_y * (-1.0) * tf.nn.log_softmax(bow), axis=1)

        max_mem_size = self.conf.input_max_len + self.conf.output_max_len + 2

        def _to_beam(t):
            # Repeat each row of a rank-2 tensor beam_size times,
            # consecutively: [r0, r0, ..., r1, r1, ...].
            beam_t = tf.reshape(tf.tile(t, [1, self.beam_size]),
                                [-1, int(t.get_shape()[1])])
            return beam_t

        with tf.name_scope("ShapeToBeam") as scope:
            beam_raw_dec_states = tf.contrib.framework.nest.map_structure(
                _to_beam, raw_dec_states)
            # Repeat encoder memory per beam; assumes the time dimension of
            # enc_outs equals conf.input_max_len -- TODO confirm.
            beam_memory = tf.reshape(
                tf.tile(self.enc_outs, [1, 1, self.beam_size]),
                [-1, conf.input_max_len, mem_size])
            beam_memory_lens = tf.squeeze(
                tf.reshape(
                    tf.tile(tf.expand_dims(self.enc_lens, 1),
                            [1, self.beam_size]), [-1, 1]), 1)
            beam_z = tf.contrib.framework.nest.map_structure(_to_beam, z)

        cell = AttnCell(cell_model=conf.cell_model,
                        num_units=mem_size,
                        num_layers=conf.num_layers,
                        attn_type=self.conf.attention,
                        memory=beam_memory,
                        mem_lens=beam_memory_lens,
                        max_mem_size=max_mem_size,
                        addmem=self.conf.addmem,
                        z=beam_z,
                        keep_prob=1.0,
                        dtype=tf.float32,
                        name_scope="AttnCell")
        # Fit decision states to shape of attention decoder cell states
        zero_attn_states = DecStateInit(beam_raw_dec_states, cell,
                                        batch_size * self.beam_size,
                                        conf.dec_init_type, conf.use_init_proj)

        if not for_deploy:
            # The training graph exposes no extra input nodes.
            inputs = {}
            dec_init_state = zero_attn_states
            # sampling_probability=0.0: always feed the ground-truth token.
            hp_train = helper.ScheduledEmbeddingTrainingHelper(
                inputs=emb_dec_inps,
                sequence_length=self.dec_lens,
                embedding=self.embedding,
                sampling_probability=0.0,
                out_proj=self.out_proj)
            output_layer = layers_core.Dense(
                self.conf.out_layer_size,
                use_bias=True) if self.conf.out_layer_size else None
            my_decoder = basic_decoder.BasicDecoder(
                cell=cell,
                helper=hp_train,
                initial_state=dec_init_state,
                output_layer=output_layer)
            cell_outs, final_state = decoder.dynamic_decode(
                decoder=my_decoder,
                impute_finished=False,
                maximum_iterations=conf.output_max_len + 1,
                scope=scope)
            outputs = cell_outs.rnn_output

            # Apply out_proj manually to get vocabulary logits [batch, L, V].
            L = tf.shape(outputs)[1]
            outputs = tf.reshape(outputs, [-1, int(self.out_proj[0].shape[0])])
            outputs = tf.matmul(outputs, self.out_proj[0]) + self.out_proj[1]
            logits = tf.reshape(outputs,
                                [-1, L, int(self.out_proj[0].shape[1])])

            # branch 1 for debugging, doesn't have to be called
            #m = tf.shape(self.outputs)[0]
            #self.mask = tf.zeros([m, int(w.shape[1])])
            #for i in [3]:
            #	self.mask = self.mask + tf.one_hot(indices=tf.ones([m], dtype=tf.int32) * i, on_value=100.0, depth=int(w.shape[1]))
            #self.outputs = self.outputs - self.mask

            with tf.name_scope("DebugOutputs") as scope:
                self.outputs = tf.argmax(logits, axis=2)
                self.outputs = tf.reshape(self.outputs, [-1, L])
                self.outputs = self.out_table.lookup(
                    tf.cast(self.outputs, tf.int64))

            # branch 2 for loss
            with tf.name_scope("Loss") as scope:
                # Targets are the decoder ids shifted left by one position.
                tars = tf.slice(self.dec_inps, [0, 1], [-1, L])

                # wgts may be a more complicated form, for example a partial down-weighting of a sequence
                # but here i just use  1.0 weights for all no-padding label
                # Reverse cumsum of one_hot(dec_lens, L) is 1.0 at positions
                # 0..dec_lens inclusive and 0.0 after.
                # NOTE(review): tf.one_hot yields all zeros for an index
                # >= L, so any example with dec_lens >= L gets zero weight.
                wgts = tf.cumsum(tf.one_hot(self.dec_lens, L),
                                 axis=1,
                                 reverse=True)

                #wgts = wgts * tf.expand_dims(self.down_wgts, 1)
                loss_matrix = loss.sequence_loss(
                    logits=logits,
                    targets=tars,
                    weights=wgts,
                    average_across_timesteps=False,
                    average_across_batch=False)
                #bow_loss = tf.reduce_sum(batch_bow_losses * self.down_wgts) / batch_wgt

                example_total_wgts = tf.reduce_sum(wgts, 1)
                total_wgts = tf.reduce_sum(example_total_wgts)

                example_losses = tf.reduce_sum(loss_matrix, 1)
                # Per-token reconstruction loss, used only for the summary.
                see_loss = tf.reduce_sum(example_losses) / total_wgts

                # KLD becomes a token-weighted average (a scalar).
                # NOTE(review): the scalar is then added to every example
                # loss before the sum, so the KL term is multiplied by the
                # batch size before dividing by total_wgts -- confirm.
                KLD = tf.reduce_sum(KLD * example_total_wgts) / total_wgts
                self.loss = tf.reduce_sum(
                    example_losses + self.conf.kld_ratio * KLD) / total_wgts

            with tf.name_scope(self.model_kind):
                tf.summary.scalar("loss", see_loss)
                tf.summary.scalar("kld", KLD)
                #tf.summary.scalar("bow", bow_loss)
                for each in tf.trainable_variables():
                    tf.summary.histogram(each.name, each)

            graph_nodes = {
                "loss": self.loss,
                "inputs": inputs,
                "debug_outputs": self.outputs,
                "outputs": {},
                "visualize": None
            }
            return graph_nodes
        else:
            # Decoding starts from token id 1 for every beam and stops at
            # EOS_ID.
            hp_infer = helper.GreedyEmbeddingHelper(
                embedding=self.embedding,
                start_tokens=tf.ones(shape=[batch_size * self.beam_size],
                                     dtype=tf.int32),
                end_token=EOS_ID,
                out_proj=self.out_proj)
            output_layer = layers_core.Dense(
                self.conf.out_layer_size,
                use_bias=True) if self.conf.out_layer_size else None
            dec_init_state = beam_decoder.BeamState(
                tf.zeros([batch_size * self.beam_size]), zero_attn_states,
                tf.zeros([batch_size * self.beam_size], tf.int32))

            my_decoder = beam_decoder.BeamDecoder(
                cell=cell,
                helper=hp_infer,
                out_proj=self.out_proj,
                initial_state=dec_init_state,
                beam_splits=self.beam_splits,
                max_res_num=self.conf.max_res_num,
                output_layer=output_layer)
            cell_outs, final_state = decoder.dynamic_decode(
                decoder=my_decoder,
                scope=scope,
                maximum_iterations=self.conf.output_max_len)

            L = tf.shape(cell_outs.beam_ends)[1]
            beam_symbols = cell_outs.beam_symbols
            beam_parents = cell_outs.beam_parents

            beam_ends = cell_outs.beam_ends
            beam_end_parents = cell_outs.beam_end_parents
            beam_end_probs = cell_outs.beam_end_probs
            alignments = cell_outs.alignments

            # Swap the last two axes, then flatten to rows of length L.
            beam_ends = tf.reshape(tf.transpose(beam_ends, [0, 2, 1]), [-1, L])
            beam_end_parents = tf.reshape(
                tf.transpose(beam_end_parents, [0, 2, 1]), [-1, L])
            beam_end_probs = tf.reshape(
                tf.transpose(beam_end_probs, [0, 2, 1]), [-1, L])

            # Creating tail_ids
            # NOTE(review): debug tf.Print left in the deploy graph; it logs
            # the batch size on every run.
            batch_size = tf.Print(batch_size, [batch_size],
                                  message="CVAERNN batch")

            #beam_symbols = tf.Print(cell_outs.beam_symbols, [tf.shape(cell_outs.beam_symbols)], message="beam_symbols")
            #beam_parents = tf.Print(cell_outs.beam_parents, [tf.shape(cell_outs.beam_parents)], message="beam_parents")
            #beam_ends = tf.Print(cell_outs.beam_ends, [tf.shape(cell_outs.beam_ends)], message="beam_ends")
            #beam_end_parents = tf.Print(cell_outs.beam_end_parents, [tf.shape(cell_outs.beam_end_parents)], message="beam_end_parents")
            #beam_end_probs = tf.Print(cell_outs.beam_end_probs, [tf.shape(cell_outs.beam_end_probs)], message="beam_end_probs")
            #alignments = tf.Print(cell_outs.alignments, [tf.shape(cell_outs.alignments)], message="beam_attns")

            # Row b of these offsets holds b * beam_size. Subtracting it
            # turns flattened (batch * beam) parent indices into indices
            # local to each example.
            batch_offset = tf.expand_dims(
                tf.cumsum(
                    tf.ones([batch_size, self.beam_size], dtype=tf.int32) *
                    self.beam_size,
                    axis=0,
                    exclusive=True), 2)
            offset2 = tf.expand_dims(
                tf.cumsum(
                    tf.ones([batch_size, self.beam_size * 2], dtype=tf.int32) *
                    self.beam_size,
                    axis=0,
                    exclusive=True), 2)

            out_len = tf.shape(beam_symbols)[1]
            self.beam_symbol_strs = tf.reshape(
                self.out_table.lookup(tf.cast(beam_symbols, tf.int64)),
                [batch_size, self.beam_size, -1])
            self.beam_parents = tf.reshape(
                beam_parents, [batch_size, self.beam_size, -1]) - batch_offset

            # End-of-beam tensors are reshaped with beam_size * 2 slots per
            # example.
            self.beam_ends = tf.reshape(beam_ends,
                                        [batch_size, self.beam_size * 2, -1])
            self.beam_end_parents = tf.reshape(
                beam_end_parents,
                [batch_size, self.beam_size * 2, -1]) - offset2
            self.beam_end_probs = tf.reshape(
                beam_end_probs, [batch_size, self.beam_size * 2, -1])
            self.beam_attns = tf.reshape(
                alignments, [batch_size, self.beam_size, out_len, -1])

            #cell_outs.alignments
            #self.outputs = tf.concat([outputs_str, tf.cast(cell_outs.beam_parents, tf.string)], 1)

            #ones = tf.ones([batch_size, self.beam_size], dtype=tf.int32)
            #aux_matrix = tf.cumsum(ones * self.beam_size, axis=0, exclusive=True)

            #tm_beam_parents_reverse = tf.reverse(tf.transpose(cell_outs.beam_parents), axis=[0])
            #beam_probs = final_state[1]

            #def traceback(prev_out, curr_input):
            #	return tf.gather(curr_input, prev_out)
            #
            #tail_ids = tf.reshape(tf.cumsum(ones, axis=1, exclusive=True) + aux_matrix, [-1])
            #tm_symbol_index_reverse = tf.scan(traceback, tm_beam_parents_reverse, initializer=tail_ids)
            ## Create beam index for symbols, and other info
            #tm_symbol_index = tf.concat([tf.expand_dims(tail_ids, 0), tm_symbol_index_reverse], axis=0)
            #tm_symbol_index = tf.reverse(tm_symbol_index, axis=[0])
            #tm_symbol_index = tf.slice(tm_symbol_index, [1, 0], [-1, -1])
            #symbol_index = tf.expand_dims(tf.transpose(tm_symbol_index), axis=2)
            #symbol_index = tf.concat([symbol_index, tf.cumsum(tf.ones_like(symbol_index), exclusive=True, axis=1)], axis=2)

            ## index alignments and output symbols
            #alignments = tf.gather_nd(cell_outs.alignments, symbol_index)
            #symbol_ids = tf.gather_nd(cell_outs.beam_symbols, symbol_index)

            ## outputs and other info
            #self.others = [alignments, beam_probs]
            #self.outputs = self.out_table.lookup(tf.cast(symbol_ids, tf.int64))

            # Only encoder-side inputs are needed at inference time.
            inputs = {
                "enc_inps:0": self.enc_str_inps,
                "enc_lens:0": self.enc_lens
            }
            outputs = {
                "beam_symbols": self.beam_symbol_strs,
                "beam_parents": self.beam_parents,
                "beam_ends": self.beam_ends,
                "beam_end_parents": self.beam_end_parents,
                "beam_end_probs": self.beam_end_probs,
                "beam_attns": self.beam_attns
            }

            graph_nodes = {
                "loss": None,
                "inputs": inputs,
                "outputs": outputs,
                "visualize": {
                    "z": z
                }
            }

            return graph_nodes
Exemplo n.º 3
0
def build_rnn(movies_cnt,
              cell_type='gru',
              user_aware=True,
              user_cnt=None,
              rating_aware=True,
              rnn_unit=300,
              user_embedding=300,
              movie_emb_dim=300,
              feed_previous=True,
              loss_weights=None,
              rating_with_user=False,
              batch_size=32):
    """Build a next-movie RNN graph with an optional rating head (TF1 graph mode).

    The RNN walks over a sequence of movie ids. At loop step ``idx`` the next
    cell input is either the ground-truth movie ``movie_idx_ph[:, idx + 1]``
    or, when ``feed_prev[idx]`` is true, the arg-max movie predicted from the
    current output.

    The classification loss is a sparse softmax cross-entropy masked to
    targets with id > 0. In validation mode (``if_training`` fed False) it is
    additionally multiplied by ``feed_prev[1:]``. When ``rating_aware`` is
    set, a squared-error rating loss masked to ratings > 0 is added.

    Args:
        movies_cnt: number of movies (embedding rows and chooser classes).
        cell_type: key into the module-level ``cells`` mapping. Only 'lstm'
            and 'gru' support user-initialised states.
        user_aware: if True, the initial RNN state is looked up per user from
            learned embeddings. Requires ``user_cnt``.
        user_cnt: number of users (required when ``user_aware``).
        rating_aware: if True, add a rating regression head and loss.
        rnn_unit: number of RNN units (width of every cell output).
        user_embedding: width of the user state embedding(s). Presumably it
            must equal ``rnn_unit`` for the cell to accept it; TODO confirm.
        movie_emb_dim: width of the movie embeddings fed as cell inputs.
        feed_previous: not used by this function.
        loss_weights: ``[classification_weight, rating_weight]``. Defaults to
            ``[10, 2]``.
        rating_with_user: not used by this function.
        batch_size: static batch size. It is used for the output buffer, the
            zero state and the final logits reshape.

    Returns:
        dict with the keys:
            'base': ``[movie_idx_ph, total_loss, clf_loss]``
            'feed_prev'
            'if_training'
            'movie_embeddings'
            'user': the user index placeholder (only when ``user_aware``)
            'ratings': ``[true_ratings, rat_loss]`` (only when ``rating_aware``)

    Raises:
        ValueError: if ``user_aware`` is set without ``user_cnt``, or
            ``user_aware`` is combined with a cell type other than
            'lstm'/'gru'.
    """
    loss_weights = loss_weights or [10, 2]
    movie_idx_ph = tf.placeholder(tf.int32, [None, None])
    _, maxlen = tf.unstack(tf.shape(movie_idx_ph))

    if_training = tf.placeholder_with_default(True, [])
    cell = cells[cell_type](num_units=rnn_unit)

    movie_embeddings = build_embedding(movies_cnt, movie_emb_dim,
                                       'movie_embedding')

    if user_aware and user_cnt is None:
        raise ValueError('user_cnt must be given when user_aware=True')

    if user_aware:
        user_idx_ph = tf.placeholder(tf.int32, [None])
        if cell_type == 'lstm':
            # Separate per-user tables for the LSTM cell (c) and hidden (h) states.
            c_user_embedding = build_embedding(user_cnt,
                                               user_embedding,
                                               name='user_c_embedding')
            h_user_embedding = build_embedding(user_cnt,
                                               user_embedding,
                                               name='user_h_embedding')
            state = LSTMStateTuple(
                c=tf.nn.embedding_lookup(c_user_embedding, user_idx_ph),
                h=tf.nn.embedding_lookup(h_user_embedding, user_idx_ph))
        elif cell_type == 'gru':
            user_state_embedding = build_embedding(user_cnt,
                                                   user_embedding,
                                                   name='user_embedding')
            state = tf.nn.embedding_lookup(user_state_embedding, user_idx_ph)
        else:
            raise ValueError(
                'user_aware=True supports cell_type "lstm" or "gru", got %r'
                % (cell_type,))
    else:
        state = cell.zero_state(batch_size=batch_size, dtype=tf.float32)

    def _choose_best(vec, reuse=False):
        """Project RNN outputs of width ``rnn_unit`` to per-movie logits."""
        with tf.variable_scope(name_or_scope='chooser', reuse=reuse):
            w = tf.get_variable(name='weights', shape=[rnn_unit, movies_cnt])
            b = tf.get_variable(name='bias', shape=[movies_cnt])
            return tf.matmul(vec, w) + b

    # A hand-written while_loop is used instead of dynamic_rnn so that the
    # model's own prediction can be fed back as the next input.
    def walker(idx, step_input, outputs, state, fprev):
        output, state = cell(step_input, state)

        next_movie = tf.cond(
            fprev[idx],
            lambda: tf.cast(tf.argmax(_choose_best(output), 1), tf.int32),
            lambda: movie_idx_ph[:, idx + 1])

        step_input = tf.nn.embedding_lookup(movie_embeddings, next_movie)
        outputs = tf.concat((outputs, tf.expand_dims(output, axis=1)), axis=1)
        return idx + 1, step_input, outputs, state, fprev

    def cond(idx, step_input, outputs, state, fprev):
        # The last movie is only ever a target, never an input.
        return idx < maxlen - 1

    # A plain int32 tensor, not a tf.Variable: the counter needs no storage,
    # no initialisation and must not appear in the trainable variables.
    step = tf.constant(0)
    first_input = tf.nn.embedding_lookup(movie_embeddings, movie_idx_ph[:, 0])

    feed_prev = tf.placeholder(tf.bool, [None], name='feed_prev_ph')

    # Shape invariants must mirror the loop-var structure. An LSTM state is
    # an LSTMStateTuple, which has no get_shape() of its own.
    if isinstance(state, LSTMStateTuple):
        state_shape = LSTMStateTuple(c=state.c.get_shape(),
                                     h=state.h.get_shape())
    else:
        state_shape = state.get_shape()

    loop_vars = [
        step, first_input,
        tf.zeros((batch_size, 0, rnn_unit), dtype=tf.float32), state,
        feed_prev
    ]

    shape_invs = [
        step.get_shape(),
        first_input.get_shape(),
        tf.TensorShape((batch_size, None, rnn_unit)),
        state_shape,
        feed_prev.get_shape()
    ]

    _, _, outputs, _, _ = tf.while_loop(
        cond, walker, loop_vars=loop_vars, shape_invariants=shape_invs)

    logits = tf.reshape(outputs, (-1, rnn_unit))
    logits = _choose_best(logits, reuse=True)
    logits = tf.reshape(logits, (batch_size, -1, movies_cnt))

    clf_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=movie_idx_ph[:, 1:])

    def training_mask():
        # Ignore padding targets (id 0).
        clf_mask = tf.greater(movie_idx_ph[:, 1:], tf.cast(0, dtype=tf.int32))
        return tf.cast(clf_mask, tf.float32)

    def val_mask():
        # Ignore padding targets, and also steps whose feed_prev flag is off.
        clf_mask = tf.greater(movie_idx_ph[:, 1:], tf.cast(0, dtype=tf.int32))
        clf_mask = tf.cast(clf_mask, tf.float32)
        return tf.multiply(clf_mask, tf.cast(feed_prev[1:], tf.float32))

    clf_mask = tf.cond(if_training, training_mask, val_mask)
    clf_loss *= clf_mask
    clf_loss = tf.reduce_sum(clf_loss) / tf.reduce_sum(clf_mask)
    total_loss = loss_weights[0] * clf_loss

    if rating_aware:
        true_ratings = tf.placeholder('float', [None, None])

        ratings = linear(outputs, 1)
        ratings = tf.squeeze(ratings, axis=2)

        # Squared error, counted only where a rating > 0 is given.
        rat_loss = tf.square(ratings - true_ratings)
        mask = tf.greater(true_ratings, tf.cast(0, dtype=tf.float32))
        mask = tf.cast(mask, tf.float32)
        rat_loss *= mask
        rat_loss = tf.reduce_sum(rat_loss) / tf.reduce_sum(mask)
        total_loss += loss_weights[1] * rat_loss

    bag = {
        'base': [movie_idx_ph, total_loss, clf_loss],
        'feed_prev': feed_prev,
        'if_training': if_training,
        'movie_embeddings': movie_embeddings
    }

    if user_aware:
        bag['user'] = user_idx_ph

    if rating_aware:
        bag['ratings'] = [true_ratings, rat_loss]

    return bag
Exemplo n.º 4
0
	def build(self, inputs, for_deploy):
		"""Build the latent-variable seq2seq graph for training or beam-search deployment.

		Args:
			inputs: dict of input tensors keyed by feed name.
				"enc_inps:0", "enc_lens:0" and "dec_inps:0" are always read.
				"dec_lens:0" is read only when training.
			for_deploy: if False, build the teacher-forced training graph with
				loss and summaries. If True, build the beam-search inference graph.

		Returns:
			dict of graph nodes with keys "loss", "inputs", "outputs" and
			"visualize". The training graph also has "debug_outputs".
		"""
		conf = self.conf
		dtype = self.dtype
		beam_size = 1 if not for_deploy else sum(conf.beam_splits)

		with tf.name_scope("WordEmbedding"):
			# Token string <-> id tables.
			self.in_table = lookup.MutableHashTable(key_dtype=tf.string,
				value_dtype=tf.int64,
				default_value=UNK_ID,
				shared_name="in_table",
				name="in_table",
				checkpoint=True)

			self.out_table = lookup.MutableHashTable(key_dtype=tf.int64,
				value_dtype=tf.string,
				default_value="_UNK",
				shared_name="out_table",
				name="out_table",
				checkpoint=True)
			enc_inps = self.in_table.lookup(inputs["enc_inps:0"])
			dec_inps = self.in_table.lookup(inputs["dec_inps:0"])

			graphlg.info("Creating embeddings and embedding enc_inps.")
			with tf.device("/cpu:0"):
				self.embedding = variable_scope.get_variable("embedding", [conf.output_vocab_size, conf.embedding_size])
				emb_inps = embedding_lookup_unique(self.embedding, enc_inps)
				emb_dec_inps = embedding_lookup_unique(self.embedding, dec_inps)
			emb_dec_next_inps = tf.slice(emb_dec_inps, [0, 0, 0], [-1, conf.output_max_len + 1, -1])

		# Encode the source sequence.
		graphlg.info("Creating dynamic x rnn...")
		enc_outs, enc_states, mem_size, enc_state_size = DynEncode(conf.cell_model, conf.num_units, conf.num_layers,
			emb_inps, inputs["enc_lens:0"], keep_prob=1.0,
			bidi=conf.bidirectional, name_scope="DynEncodeX")

		# Summarise the encoder. Score the encoder outputs with the last encoder
		# state as the query, and use the score-weighted sum of outputs as the new
		# state (only the h part of an LSTM state is replaced).
		with tf.variable_scope("AttnEncState") as scope2:
			mechanism = Luong1_2(num_units=conf.num_units, memory=enc_outs, max_mem_size=conf.input_max_len, memory_sequence_length=inputs["enc_lens:0"], name=scope2.original_name_scope)
			if isinstance(enc_states[-1], LSTMStateTuple):
				score = tf.expand_dims(mechanism(enc_states[-1].h, ()), 1)
				attention_h = tf.squeeze(tf.matmul(score, enc_outs), 1)
				enc_state = LSTMStateTuple(enc_states[-1].c, attention_h)
			else:
				score = tf.expand_dims(mechanism(enc_states[-1], ()), 1)
				enc_state = tf.squeeze(tf.matmul(score, enc_outs), 1)

		# Prior network: z plus prior mean / log-variance, computed from the encoder state.
		hidden_units = int(math.sqrt(mem_size * conf.enc_latent_dim))
		z, mu_prior, logvar_prior = Ptheta([enc_state], hidden_units, conf.enc_latent_dim, stddev=1, prior_type=conf.prior_type, name_scope="EncToPtheta")

		KLD = 0.0
		# Training only: encode the target sequence (its first token dropped) and
		# take z from the VAE helper, which also returns the KLD term.
		if not for_deploy:
			y_emb_inps = tf.slice(emb_dec_inps, [0, 1, 0], [-1, -1, -1])
			_, y_enc_states, _, _ = DynEncode(conf.cell_model, conf.num_units, conf.num_layers, y_emb_inps, inputs["dec_lens:0"],
				keep_prob=conf.keep_prob, bidi=False, name_scope="DynEncodeY")
			z, KLD, _ = VAE([enc_state, y_enc_states[-1]], conf.enc_latent_dim, mu_prior, logvar_prior, name_scope="VAE")

		# Gate the encoder state with sigmoid gates computed from z, then append z.
		with tf.name_scope("GatedZState"):
			if isinstance(enc_state, LSTMStateTuple):
				h_gate = tf.layers.dense(z, int(enc_state.h.get_shape()[1]), use_bias=True, name="z_gate_h", activation=tf.sigmoid)
				c_gate = tf.layers.dense(z, int(enc_state.c.get_shape()[1]), use_bias=True, name="z_gate_c", activation=tf.sigmoid)
				raw_dec_states = tf.concat([c_gate * enc_state.c, h_gate * enc_state.h, z], 1)
			else:
				gate = tf.layers.dense(z, int(enc_state.get_shape()[1]), use_bias=True, name="z_gate", activation=tf.sigmoid)
				raw_dec_states = tf.concat([gate * enc_state, z], 1)

		max_mem_size = conf.input_max_len + conf.output_max_len + 2
		# Expand decoder-side tensors for the beam. tile_batch presumably repeats
		# each batch entry beam_size times (beam_size is 1 when training).
		with tf.name_scope("ShapeToBeam"):
			beam_raw_dec_states = nest.map_structure(lambda x: tile_batch(x, beam_size), raw_dec_states)
			beam_memory = nest.map_structure(lambda x: tile_batch(x, beam_size), enc_outs)
			beam_memory_lens = tf.squeeze(nest.map_structure(lambda x: tile_batch(x, beam_size), tf.expand_dims(inputs["enc_lens:0"], 1)), 1)
			beam_z = nest.map_structure(lambda x: tile_batch(x, beam_size), z)

		decoder_multi_rnn_cells = CreateMultiRNNCell(conf.cell_model, num_units=mem_size, num_layers=conf.num_layers, output_keep_prob=conf.keep_prob)
		zero_cell_states = DecCellStateInit(beam_raw_dec_states, decoder_multi_rnn_cells, name="InitCell")

		attn_cell = AttnCellWrapper(cell=decoder_multi_rnn_cells, cell_init_states=zero_cell_states, attn_type=conf.attention,
			attn_size=mem_size, memory=beam_memory, mem_lens=beam_memory_lens, max_mem_size=max_mem_size,
			addmem=conf.addmem, z=beam_z, dtype=tf.float32, name="AttnWrapper")

		if conf.attention:
			dec_init_state = None
		else:
			dec_init_state = beam_decoder.BeamState(tf.zeros_like(beam_memory_lens, tf.float32), zero_cell_states, tf.zeros_like(beam_memory_lens))

		# Output projection to the vocabulary. Its input width is out_layer_size
		# when an extra output layer is used, otherwise mem_size.
		with tf.variable_scope("OutProj"):
			graphlg.info("Creating out_proj...")
			proj_in_size = conf.out_layer_size or mem_size
			w = tf.get_variable("proj_w", [proj_in_size, conf.output_vocab_size], dtype=dtype)
			b = tf.get_variable("proj_b", [conf.output_vocab_size], dtype=dtype)
			out_proj = (w, b)

		if not for_deploy:
			hp_train = helper1_2.ScheduledEmbeddingTrainingHelper(inputs=emb_dec_next_inps, sequence_length=inputs["dec_lens:0"], embedding=self.embedding,
				sampling_probability=0.0, out_proj=out_proj)
			output_layer = layers_core.Dense(conf.out_layer_size, use_bias=True) if conf.out_layer_size else None
			my_decoder = basic_decoder1_2.BasicDecoder(cell=attn_cell, helper=hp_train, initial_state=dec_init_state, output_layer=output_layer)
			cell_outs, _, _ = decoder1_2.dynamic_decode(decoder=my_decoder, impute_finished=True, maximum_iterations=conf.output_max_len + 1)

			with tf.name_scope("Logits"):
				L = tf.shape(cell_outs.rnn_output)[1]
				rnn_output = tf.reshape(cell_outs.rnn_output, [-1, int(out_proj[0].shape[0])])
				rnn_output = tf.matmul(rnn_output, out_proj[0]) + out_proj[1]
				logits = tf.reshape(rnn_output, [-1, L, int(out_proj[0].shape[1])])

			with tf.name_scope("DebugOutputs"):
				outputs = tf.argmax(logits, axis=2)
				outputs = tf.reshape(outputs, [-1, L])
				outputs = self.out_table.lookup(tf.cast(outputs, tf.int64))

			with tf.name_scope("Loss"):
				tars = tf.slice(dec_inps, [0, 1], [-1, L])
				# Weight 1.0 at time steps 0..dec_len (inclusive) and 0.0 afterwards.
				# NOTE(review): tf.one_hot yields all zeros when dec_len >= L. That
				# would drop an example entirely; confirm against the helper's decode length.
				wgts = tf.cumsum(tf.one_hot(inputs["dec_lens:0"], L), axis=1, reverse=True)
				loss_matrix = loss.sequence_loss(logits=logits, targets=tars, weights=wgts, average_across_timesteps=False, average_across_batch=False)
				example_total_wgts = tf.reduce_sum(wgts, 1)
				total_wgts = tf.reduce_sum(example_total_wgts)

				example_losses = tf.reduce_sum(loss_matrix, 1)
				see_loss = tf.reduce_sum(example_losses) / total_wgts

				KLD = tf.reduce_sum(KLD * example_total_wgts) / total_wgts
				# NOTE(review): KLD is a scalar here, so it is added once per example
				# before dividing by total_wgts. Confirm this scaling is intended.
				self.loss = tf.reduce_sum(example_losses + conf.kld_ratio * KLD) / total_wgts

			with tf.name_scope(self.model_kind):
				tf.summary.scalar("loss", see_loss)
				tf.summary.scalar("kld", KLD)
				for each in tf.trainable_variables():
					tf.summary.histogram(each.name, each)
			graph_nodes = {
				"loss": self.loss,
				"inputs": inputs,
				"debug_outputs": outputs,
				"outputs": {},
				"visualize": None
			}
			return graph_nodes
		else:
			beam_batch_size = tf.shape(beam_memory_lens)[0]
			hp_infer = helper1_2.GreedyEmbeddingHelper(embedding=self.embedding, start_tokens=tf.ones([beam_batch_size], dtype=tf.int32),
				end_token=EOS_ID, out_proj=out_proj)
			output_layer = layers_core.Dense(conf.out_layer_size, use_bias=True) if conf.out_layer_size else None

			my_decoder = beam_decoder.BeamDecoder(cell=attn_cell, helper=hp_infer, out_proj=out_proj, initial_state=dec_init_state, beam_splits=conf.beam_splits,
				max_res_num=conf.max_res_num, output_layer=output_layer)
			cell_outs, _, _ = decoder1_2.dynamic_decode(decoder=my_decoder, impute_finished=True, maximum_iterations=conf.output_max_len + 1)

			L = tf.shape(cell_outs.beam_ends)[1]
			beam_symbols = cell_outs.beam_symbols
			beam_parents = cell_outs.beam_parents
			alignments = cell_outs.alignments

			beam_ends = tf.reshape(tf.transpose(cell_outs.beam_ends, [0, 2, 1]), [-1, L])
			beam_end_parents = tf.reshape(tf.transpose(cell_outs.beam_end_parents, [0, 2, 1]), [-1, L])
			beam_end_probs = tf.reshape(tf.transpose(cell_outs.beam_end_probs, [0, 2, 1]), [-1, L])

			# Use integer division here. True division (/) of an int32 tensor
			# yields float64 under Python 3, and tf.ones / tf.reshape reject that as a shape.
			batch_size = beam_batch_size // beam_size

			# Row i of each offset matrix holds i * beam_size. It is subtracted from
			# the flat (batch * beam) parent indices below, presumably to make them
			# relative to each example's beam.
			batch_offset = tf.expand_dims(tf.cumsum(tf.ones([batch_size, beam_size], dtype=tf.int32) * beam_size, axis=0, exclusive=True), 2)
			offset2 = tf.expand_dims(tf.cumsum(tf.ones([batch_size, beam_size * 2], dtype=tf.int32) * beam_size, axis=0, exclusive=True), 2)

			out_len = tf.shape(beam_symbols)[1]
			self.beam_symbol_strs = tf.reshape(self.out_table.lookup(tf.cast(beam_symbols, tf.int64)), [batch_size, beam_size, -1])
			self.beam_parents = tf.reshape(beam_parents, [batch_size, beam_size, -1]) - batch_offset

			self.beam_ends = tf.reshape(beam_ends, [batch_size, beam_size * 2, -1])
			self.beam_end_parents = tf.reshape(beam_end_parents, [batch_size, beam_size * 2, -1]) - offset2
			self.beam_end_probs = tf.reshape(beam_end_probs, [batch_size, beam_size * 2, -1])
			self.beam_attns = tf.reshape(alignments, [batch_size, beam_size, out_len, -1])

			outputs = {
				"beam_symbols": self.beam_symbol_strs,
				"beam_parents": self.beam_parents,
				"beam_ends": self.beam_ends,
				"beam_end_parents": self.beam_end_parents,
				"beam_end_probs": self.beam_end_probs,
				"beam_attns": self.beam_attns
			}

			# Inference only needs the encoder-side feeds.
			infer_inputs = {
				"enc_inps:0": inputs["enc_inps:0"],
				"enc_lens:0": inputs["enc_lens:0"],
			}
			graph_nodes = {
				"loss": None,
				"inputs": infer_inputs,
				"outputs": outputs,
				"visualize": {"z": z}
			}
			return graph_nodes
Exemplo n.º 5
0
def build_rnn(input_size, output_size, hidden, cell='lstm', average=False, bidirectional=True, time_major=False):
    """Build an RNN sequence classifier graph (TF1 graph mode).

    Args:
        input_size: feature width of each time step. The inputs placeholder
            is ``[None, None, input_size]`` and is fed batch-major.
        output_size: number of output classes (logit width).
        hidden: number of units per RNN cell.
        cell: 'lstm' or 'gru'.
        average: if True, classify from the mean of the RNN outputs over each
            sequence's valid steps. Otherwise classify from the final state
            (for LSTM, c and h concatenated).
        bidirectional: if True, use a forward and a backward cell and
            concatenate their outputs and states.
        time_major: NOTE(review): currently ignored; the graph is always
            batch-major.

    Returns:
        A 4-tuple:
            placeholders: dict with 'inputs', 'outputs' (int32 targets),
                'lengths' and 'training'. 'training' is not used in the graph.
            logits: ``[batch, output_size]``.
            extras: dict with 'weights' (``[weights, bias]``) and 'outputs'
                (the raw RNN outputs).
            output_size: the ``output_size`` argument, unchanged.
    """
    inputs = tf.placeholder(tf.float32, [None, None, input_size])
    targets = tf.placeholder(tf.int32, [None])
    training = tf.placeholder(tf.bool, [])

    # LSTMCell takes its weight initializer as ``initializer``. GRUCell calls
    # it ``kernel_initializer`` and rejects ``initializer``.
    cells = {
        'lstm': lambda: LSTMCell(hidden, initializer=xavier_initializer()),
        'gru': lambda: GRUCell(hidden, kernel_initializer=xavier_initializer()),
    }

    lengths = tf.placeholder(tf.int32, [None])
    make_cell = cells[cell]

    if bidirectional:
        (out_fw, out_bw), (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=make_cell(),
            cell_bw=make_cell(),
            sequence_length=lengths,
            inputs=inputs,
            dtype=tf.float32
        )

        if isinstance(state_fw, LSTMStateTuple):
            state = LSTMStateTuple(c=tf.concat((state_fw.c, state_bw.c), axis=1),
                                   h=tf.concat((state_fw.h, state_bw.h), axis=1))
        else:
            state = tf.concat((state_fw, state_bw), axis=1)

        outputs = tf.concat((out_fw, out_bw), axis=2)
    else:
        outputs, state = tf.nn.dynamic_rnn(
            cell=make_cell(),
            sequence_length=lengths,
            inputs=inputs,
            dtype=tf.float32
        )

    # Flatten an LSTM state into one [c, h] feature vector per example.
    if isinstance(state, LSTMStateTuple):
        state = tf.concat((state.c, state.h), axis=1)

    if average:
        # dynamic_rnn zero-fills outputs past each sequence's length, so the sum
        # over time divided by the length is the mean over the valid steps.
        output = tf.reduce_sum(outputs, axis=1) / tf.expand_dims(tf.cast(lengths, tf.float32), axis=1)
        weights = init_xavier([output.get_shape()[1].value, output_size])
        bias = init_xavier([output_size])
        logits = tf.matmul(output, weights) + bias
    else:
        weights = init_normal_var([state.get_shape()[1].value, output_size])
        bias = init_normal_var([output_size])
        logits = tf.matmul(state, weights) + bias

    placeholders = {'inputs': inputs, 'outputs': targets, 'lengths': lengths, 'training': training}
    return placeholders, logits, {'weights': [weights, bias], 'outputs': outputs}, output_size