Exemple #1
0
	def build(self, inputs, for_deploy):
		scope = ""
		conf = self.conf
		name = self.name
		job_type = self.job_type
		dtype = self.dtype
		self.beam_splits = conf.beam_splits
		self.beam_size = 1 if not for_deploy else sum(self.beam_splits)

		self.enc_str_inps = inputs["enc_inps:0"]
		self.dec_str_inps = inputs["dec_inps:0"]
		self.enc_lens = inputs["enc_lens:0"] 
		self.dec_lens = inputs["dec_lens:0"]
		self.down_wgts = inputs["down_wgts:0"]

		with tf.name_scope("TableLookup"):
			# Input maps
			self.in_table = lookup.MutableHashTable(key_dtype=tf.string,
														 value_dtype=tf.int64,
														 default_value=UNK_ID,
														 shared_name="in_table",
														 name="in_table",
														 checkpoint=True)

			self.out_table = lookup.MutableHashTable(key_dtype=tf.int64,
													 value_dtype=tf.string,
													 default_value="_UNK",
													 shared_name="out_table",
													 name="out_table",
													 checkpoint=True)
			# lookup
			self.enc_inps = self.in_table.lookup(self.enc_str_inps)
			self.dec_inps = self.in_table.lookup(self.dec_str_inps)

		graphlg.info("Preparing decoder inps...")
		dec_inps = tf.slice(self.dec_inps, [0, 0], [-1, conf.output_max_len + 1])

		# Create encode graph and get attn states
		graphlg.info("Creating embeddings and embedding enc_inps.")
		with ops.device("/cpu:0"):
			self.embedding = variable_scope.get_variable("embedding", [conf.output_vocab_size, conf.embedding_size])
		with tf.name_scope("Embed") as scope:
			dec_inps = tf.slice(self.dec_inps, [0, 0], [-1, conf.output_max_len + 1])
			with ops.device("/cpu:0"):
				self.emb_inps = embedding_lookup_unique(self.embedding, self.enc_inps)
				emb_dec_inps = embedding_lookup_unique(self.embedding, dec_inps)

		graphlg.info("Creating dynamic x rnn...")
		self.enc_outs, self.enc_states, mem_size, enc_state_size = DynRNN(conf.cell_model, conf.num_units, conf.num_layers,
																		self.emb_inps, self.enc_lens, keep_prob=1.0,
																		bidi=conf.bidirectional, name_scope="DynRNNEncoder")
		batch_size = tf.shape(self.enc_outs)[0]

		if self.conf.attention:
			init_h = self.enc_states[-1].h
		else:
			mechanism = dynamic_attention_wrapper.LuongAttention(num_units=conf.num_units, memory=self.enc_outs, 
																	max_mem_size=self.conf.input_max_len,
																	memory_sequence_length=self.enc_lens)
			init_h = mechanism(self.enc_states[-1].h)

		if isinstance(self.enc_states[-1], LSTMStateTuple):
			enc_state = LSTMStateTuple(self.enc_states[-1].c, init_h) 
		
		hidden_units = int(math.sqrt(mem_size * self.conf.enc_latent_dim))
		z, mu_prior, logvar_prior = PriorNet([enc_state], hidden_units, self.conf.enc_latent_dim, stddev=1.0, prior_type=conf.prior_type)

		KLD = 0.0
		# Different graph for training and inference time 
		if not for_deploy:
			# Y inputs for posterior z 
			with tf.name_scope("YEncode"):
				y_emb_inps = tf.slice(emb_dec_inps, [0, 1, 0], [-1, -1, -1])
				y_enc_outs, y_enc_states, y_mem_size, y_enc_state_size = DynRNN(conf.cell_model, conf.num_units, conf.num_layers,
																			y_emb_inps, self.dec_lens, keep_prob=1.0, bidi=False, name_scope="y_enc")
				y_enc_state = y_enc_states[-1]

				z, KLD, l2 = CreateVAE([enc_state, y_enc_state], self.conf.enc_latent_dim, mu_prior, logvar_prior)

		# project z + x_thinking_state to decoder state
		raw_dec_states = [z, enc_state]
		# add BOW loss
		#num_hidden_units = int(math.sqrt(conf.output_vocab_size * int(decision_state.shape[1])))
		#bow_l1 = layers_core.Dense(num_hidden_units, use_bias=True, name="bow_hidden", activation=tf.tanh)
		#bow_l2 = layers_core.Dense(conf.output_vocab_size, use_bias=True, name="bow_out", activation=None)
		#bow = bow_l2(bow_l1(decision_state)) 

		#y_dec_inps = tf.slice(self.dec_inps, [0, 1], [-1, -1])
		#bow_y = tf.reduce_sum(tf.one_hot(y_dec_inps, on_value=1.0, off_value=0.0, axis=-1, depth=conf.output_vocab_size), axis=1)
		#batch_bow_losses = tf.reduce_sum(bow_y * (-1.0) * tf.nn.log_softmax(bow), axis=1)

		max_mem_size = self.conf.input_max_len + self.conf.output_max_len + 2

		with tf.name_scope("ShapeToBeam") as scope: 
			def _to_beam(t):
				beam_t = tf.reshape(tf.tile(t, [1, self.beam_size]), [-1, int(t.get_shape()[1])])
				return beam_t 
			beam_raw_dec_states = tf.contrib.framework.nest.map_structure(_to_beam, raw_dec_states) 
	
			beam_memory = tf.reshape(tf.tile(self.enc_outs, [1, 1, self.beam_size]), [-1, conf.input_max_len, mem_size])
			beam_memory_lens = tf.squeeze(tf.reshape(tf.tile(tf.expand_dims(self.enc_lens, 1), [1, self.beam_size]), [-1, 1]), 1)
			
		cell = AttnCell(cell_model=conf.cell_model, num_units=mem_size, num_layers=conf.num_layers,
						attn_type=self.conf.attention, memory=beam_memory, mem_lens=beam_memory_lens,
						max_mem_size=max_mem_size, addmem=self.conf.addmem, keep_prob=1.0,
						dtype=tf.float32, name_scope="AttnCell")
		# Fit decision states to shape of attention decoder cell states 
		zero_attn_states = DecStateInit(beam_raw_dec_states, cell, batch_size * self.beam_size)
		
		# Output projection
		with tf.variable_scope("OutProj"):
			graphlg.info("Creating out_proj...") 
			if conf.out_layer_size:
				w = tf.get_variable("proj_w", [conf.out_layer_size, conf.output_vocab_size], dtype=dtype)
			else:
				w = tf.get_variable("proj_w", [mem_size, conf.output_vocab_size], dtype=dtype)
			b = tf.get_variable("proj_b", [conf.output_vocab_size], dtype=dtype)
			self.out_proj = (w, b)
		
		if not for_deploy: 
			inputs = {}
			dec_init_state = zero_attn_states
			hp_train = helper.ScheduledEmbeddingTrainingHelper(inputs=emb_dec_inps, sequence_length=self.dec_lens, 
															   embedding=self.embedding, sampling_probability=0.0,
															   out_proj=self.out_proj)
			output_layer = layers_core.Dense(self.conf.out_layer_size, use_bias=True) if self.conf.out_layer_size else None
			my_decoder = basic_decoder.BasicDecoder(cell=cell, helper=hp_train, initial_state=dec_init_state, output_layer=output_layer)
			cell_outs, final_state = decoder.dynamic_decode(decoder=my_decoder, impute_finished=False,
															maximum_iterations=conf.output_max_len + 1, scope=scope)
			outputs = cell_outs.rnn_output

			L = tf.shape(outputs)[1]
			outputs = tf.reshape(outputs, [-1, int(self.out_proj[0].shape[0])])
			outputs = tf.matmul(outputs, self.out_proj[0]) + self.out_proj[1] 
			logits = tf.reshape(outputs, [-1, L, int(self.out_proj[0].shape[1])])

			# branch 1 for debugging, doesn't have to be called
			#m = tf.shape(self.outputs)[0]
			#self.mask = tf.zeros([m, int(w.shape[1])])
			#for i in [3]:
			#	self.mask = self.mask + tf.one_hot(indices=tf.ones([m], dtype=tf.int32) * i, on_value=100.0, depth=int(w.shape[1]))
			#self.outputs = self.outputs - self.mask

			with tf.name_scope("DebugOutputs") as scope:
				self.outputs = tf.argmax(logits, axis=2)
				self.outputs = tf.reshape(self.outputs, [-1, L])
				self.outputs = self.out_table.lookup(tf.cast(self.outputs, tf.int64))

			# branch 2 for loss
			with tf.name_scope("Loss") as scope:
				tars = tf.slice(self.dec_inps, [0, 1], [-1, L])
				wgts = tf.cumsum(tf.one_hot(self.dec_lens, L), axis=1, reverse=True)

				#wgts = wgts * tf.expand_dims(self.down_wgts, 1)
				self.loss = loss.sequence_loss(logits=logits, targets=tars, weights=wgts, average_across_timesteps=False, average_across_batch=False)
				batch_wgt = tf.reduce_sum(self.down_wgts) + 1e-12 
				#bow_loss = tf.reduce_sum(batch_bow_losses * self.down_wgts) / batch_wgt

				example_losses = tf.reduce_sum(self.loss, 1)
				see_loss = tf.reduce_sum(example_losses / tf.cast(self.dec_lens, tf.float32) * self.down_wgts) / batch_wgt
				KLD = tf.reduce_sum(KLD * self.down_wgts) / batch_wgt
				self.loss = tf.reduce_sum((example_losses + self.conf.kld_ratio * KLD) / tf.cast(self.dec_lens, tf.float32) * self.down_wgts) / batch_wgt

			with tf.name_scope(self.model_kind):
				tf.summary.scalar("loss", see_loss)
				tf.summary.scalar("kld", KLD) 
				#tf.summary.scalar("bow", bow_loss)

			graph_nodes = {
				"loss":self.loss,
				"inputs":inputs,
				"debug_outputs":self.outputs,
				"outputs":{},
				"visualize":None
			}
			return graph_nodes
		else:
			hp_infer = helper.GreedyEmbeddingHelper(embedding=self.embedding,
													start_tokens=tf.ones(shape=[batch_size * self.beam_size], dtype=tf.int32),
													end_token=EOS_ID, out_proj=self.out_proj)
			output_layer = layers_core.Dense(self.conf.out_layer_size, use_bias=True) if self.conf.out_layer_size else None
			dec_init_state = beam_decoder.BeamState(tf.zeros([batch_size * self.beam_size]), zero_attn_states, tf.zeros([batch_size * self.beam_size], tf.int32))

			my_decoder = beam_decoder.BeamDecoder(cell=cell, helper=hp_infer, out_proj=self.out_proj, initial_state=dec_init_state,
													beam_splits=self.beam_splits, max_res_num=self.conf.max_res_num, output_layer=output_layer)
			cell_outs, final_state = decoder.dynamic_decode(decoder=my_decoder, scope=scope, maximum_iterations=self.conf.output_max_len)

			L = tf.shape(cell_outs.beam_ends)[1]
			beam_symbols = cell_outs.beam_symbols
			beam_parents = cell_outs.beam_parents

			beam_ends = cell_outs.beam_ends
			beam_end_parents = cell_outs.beam_end_parents
			beam_end_probs = cell_outs.beam_end_probs
			alignments = cell_outs.alignments

			beam_ends = tf.reshape(tf.transpose(beam_ends, [0, 2, 1]), [-1, L])
			beam_end_parents = tf.reshape(tf.transpose(beam_end_parents, [0, 2, 1]), [-1, L])
			beam_end_probs = tf.reshape(tf.transpose(beam_end_probs, [0, 2, 1]), [-1, L])


			# Creating tail_ids 
			batch_size = tf.Print(batch_size, [batch_size], message="CVAERNN batch")

			#beam_symbols = tf.Print(cell_outs.beam_symbols, [tf.shape(cell_outs.beam_symbols)], message="beam_symbols")
			#beam_parents = tf.Print(cell_outs.beam_parents, [tf.shape(cell_outs.beam_parents)], message="beam_parents")
			#beam_ends = tf.Print(cell_outs.beam_ends, [tf.shape(cell_outs.beam_ends)], message="beam_ends") 
			#beam_end_parents = tf.Print(cell_outs.beam_end_parents, [tf.shape(cell_outs.beam_end_parents)], message="beam_end_parents") 
			#beam_end_probs = tf.Print(cell_outs.beam_end_probs, [tf.shape(cell_outs.beam_end_probs)], message="beam_end_probs") 
			#alignments = tf.Print(cell_outs.alignments, [tf.shape(cell_outs.alignments)], message="beam_attns")

			batch_offset = tf.expand_dims(tf.cumsum(tf.ones([batch_size, self.beam_size], dtype=tf.int32) * self.beam_size, axis=0, exclusive=True), 2)
			offset2 = tf.expand_dims(tf.cumsum(tf.ones([batch_size, self.beam_size * 2], dtype=tf.int32) * self.beam_size, axis=0, exclusive=True), 2)

			out_len = tf.shape(beam_symbols)[1]
			self.beam_symbol_strs = tf.reshape(self.out_table.lookup(tf.cast(beam_symbols, tf.int64)), [batch_size, self.beam_size, -1])
			self.beam_parents = tf.reshape(beam_parents, [batch_size, self.beam_size, -1]) - batch_offset

			self.beam_ends = tf.reshape(beam_ends, [batch_size, self.beam_size * 2, -1])
			self.beam_end_parents = tf.reshape(beam_end_parents, [batch_size, self.beam_size * 2, -1]) - offset2
			self.beam_end_probs = tf.reshape(beam_end_probs, [batch_size, self.beam_size * 2, -1])
			self.beam_attns = tf.reshape(alignments, [batch_size, self.beam_size, out_len, -1])

			#cell_outs.alignments
			#self.outputs = tf.concat([outputs_str, tf.cast(cell_outs.beam_parents, tf.string)], 1)

			#ones = tf.ones([batch_size, self.beam_size], dtype=tf.int32)
			#aux_matrix = tf.cumsum(ones * self.beam_size, axis=0, exclusive=True)

			#tm_beam_parents_reverse = tf.reverse(tf.transpose(cell_outs.beam_parents), axis=[0])
			#beam_probs = final_state[1] 

			#def traceback(prev_out, curr_input):
			#	return tf.gather(curr_input, prev_out) 
			#	
			#tail_ids = tf.reshape(tf.cumsum(ones, axis=1, exclusive=True) + aux_matrix, [-1])
			#tm_symbol_index_reverse = tf.scan(traceback, tm_beam_parents_reverse, initializer=tail_ids)
			## Create beam index for symbols, and other info  
			#tm_symbol_index = tf.concat([tf.expand_dims(tail_ids, 0), tm_symbol_index_reverse], axis=0)
			#tm_symbol_index = tf.reverse(tm_symbol_index, axis=[0])
			#tm_symbol_index = tf.slice(tm_symbol_index, [1, 0], [-1, -1])
			#symbol_index = tf.expand_dims(tf.transpose(tm_symbol_index), axis=2)
			#symbol_index = tf.concat([symbol_index, tf.cumsum(tf.ones_like(symbol_index), exclusive=True, axis=1)], axis=2)

			## index alignments and output symbols
			#alignments = tf.gather_nd(cell_outs.alignments, symbol_index)
			#symbol_ids = tf.gather_nd(cell_outs.beam_symbols, symbol_index)

			## outputs and other info
			#self.others = [alignments, beam_probs]
			#self.outputs = self.out_table.lookup(tf.cast(symbol_ids, tf.int64))

			inputs = { 
				"enc_inps:0":self.enc_str_inps,
				"enc_lens:0":self.enc_lens
			}
			outputs = {
				"beam_symbols":self.beam_symbol_strs,
				"beam_parents":self.beam_parents,
				"beam_ends":self.beam_ends,
				"beam_end_parents":self.beam_end_parents,
				"beam_end_probs":self.beam_end_probs,
				"beam_attns":self.beam_attns
			}

			graph_nodes = {
				"loss":None,
				"inputs":inputs,
				"outputs":outputs,
				"visualize":{"z":z}
			}

			return graph_nodes
    def build(self, inputs, for_deploy):
        conf = self.conf
        name = self.name
        job_type = self.job_type
        dtype = self.dtype
        self.beam_size = 1 if (not for_deploy or self.conf.variants
                               == "score") else sum(self.conf.beam_splits)
        conf.keep_prob = conf.keep_prob if not for_deploy else 1.0

        self.enc_str_inps = inputs["enc_inps:0"]
        self.dec_str_inps = inputs["dec_inps:0"]
        self.enc_lens = inputs["enc_lens:0"]
        self.dec_lens = inputs["dec_lens:0"]
        #self.down_wgts = inputs["down_wgts:0"]

        with tf.name_scope("TableLookup"):
            # lookup tables
            self.in_table = lookup.MutableHashTable(key_dtype=tf.string,
                                                    value_dtype=tf.int64,
                                                    default_value=UNK_ID,
                                                    shared_name="in_table",
                                                    name="in_table",
                                                    checkpoint=True)

            self.out_table = lookup.MutableHashTable(key_dtype=tf.int64,
                                                     value_dtype=tf.string,
                                                     default_value="_UNK",
                                                     shared_name="out_table",
                                                     name="out_table",
                                                     checkpoint=True)
            self.enc_inps = self.in_table.lookup(self.enc_str_inps)
            self.dec_inps = self.in_table.lookup(self.dec_str_inps)

        # Create encode graph and get attn states
        graphlg.info("Creating embeddings and embedding enc_inps.")
        with ops.device("/cpu:0"):
            self.embedding = variable_scope.get_variable(
                "embedding", [conf.output_vocab_size, conf.embedding_size])

        with tf.name_scope("Embed") as scope:
            dec_inps = tf.slice(self.dec_inps, [0, 0],
                                [-1, conf.output_max_len + 1])
            with ops.device("/cpu:0"):
                self.emb_inps = embedding_lookup_unique(
                    self.embedding, self.enc_inps)
                emb_dec_inps = embedding_lookup_unique(self.embedding,
                                                       dec_inps)
        # output projector (w, b)
        with tf.variable_scope("OutProj"):
            if conf.out_layer_size:
                w = tf.get_variable(
                    "proj_w", [conf.out_layer_size, conf.output_vocab_size],
                    dtype=dtype)
            elif conf.bidirectional:
                w = tf.get_variable(
                    "proj_w", [conf.num_units * 2, conf.output_vocab_size],
                    dtype=dtype)
            else:
                w = tf.get_variable("proj_w",
                                    [conf.num_units, conf.output_vocab_size],
                                    dtype=dtype)
            b = tf.get_variable("proj_b", [conf.output_vocab_size],
                                dtype=dtype)

        graphlg.info("Creating dynamic rnn...")
        self.enc_outs, self.enc_states, mem_size, enc_state_size = DynRNN(
            conf.cell_model,
            conf.num_units,
            conf.num_layers,
            self.emb_inps,
            self.enc_lens,
            keep_prob=conf.keep_prob,
            bidi=conf.bidirectional,
            name_scope="DynRNNEncoder")
        batch_size = tf.shape(self.enc_outs)[0]
        # to modify the output states of all encoder layers for dec init
        final_enc_states = self.enc_states

        with tf.name_scope("DynRNNDecode") as scope:
            with tf.name_scope("ShapeToBeam") as scope:
                beam_memory = tf.reshape(
                    tf.tile(self.enc_outs, [1, 1, self.beam_size]),
                    [-1, conf.input_max_len, mem_size])
                beam_memory_lens = tf.squeeze(
                    tf.reshape(
                        tf.tile(tf.expand_dims(self.enc_lens, 1),
                                [1, self.beam_size]), [-1, 1]), 1)

                def _to_beam(t):
                    return tf.reshape(tf.tile(t, [1, self.beam_size]),
                                      [-1, int(t.get_shape()[1])])

                beam_init_states = tf.contrib.framework.nest.map_structure(
                    _to_beam, final_enc_states)

            max_mem_size = self.conf.input_max_len + self.conf.output_max_len + 2
            cell = AttnCell(cell_model=conf.cell_model,
                            num_units=mem_size,
                            num_layers=conf.num_layers,
                            attn_type=self.conf.attention,
                            memory=beam_memory,
                            mem_lens=beam_memory_lens,
                            max_mem_size=max_mem_size,
                            addmem=self.conf.addmem,
                            keep_prob=conf.keep_prob,
                            dtype=tf.float32,
                            name_scope="AttnCell")

            dec_init_state = DecStateInit(all_enc_states=beam_init_states,
                                          decoder_cell=cell,
                                          batch_size=batch_size *
                                          self.beam_size,
                                          init_type=conf.dec_init_type,
                                          use_proj=conf.use_init_proj)

            if not for_deploy:
                hp_train = helper.ScheduledEmbeddingTrainingHelper(
                    inputs=emb_dec_inps,
                    sequence_length=self.dec_lens,
                    embedding=self.embedding,
                    sampling_probability=self.conf.sample_prob,
                    out_proj=(w, b))
                output_layer = layers_core.Dense(
                    self.conf.out_layer_size,
                    use_bias=True) if self.conf.out_layer_size else None
                my_decoder = basic_decoder.BasicDecoder(
                    cell=cell,
                    helper=hp_train,
                    initial_state=dec_init_state,
                    output_layer=output_layer)
                cell_outs, final_state = decoder.dynamic_decode(
                    decoder=my_decoder,
                    impute_finished=True,
                    maximum_iterations=conf.output_max_len + 1,
                    scope=scope)
            elif self.conf.variants == "score":
                hp_train = helper.ScheduledEmbeddingTrainingHelper(
                    inputs=emb_dec_inps,
                    sequence_length=self.dec_lens,
                    embedding=self.embedding,
                    sampling_probability=0.0,
                    out_proj=(w, b))
                output_layer = layers_core.Dense(
                    self.conf.out_layer_size,
                    use_bias=True) if self.conf.out_layer_size else None
                my_decoder = score_decoder.ScoreDecoder(
                    cell=cell,
                    helper=hp_train,
                    out_proj=(w, b),
                    initial_state=dec_init_state,
                    output_layer=output_layer)
                cell_outs, final_state = decoder.dynamic_decode(
                    decoder=my_decoder,
                    scope=scope,
                    maximum_iterations=self.conf.output_max_len,
                    impute_finished=True)
            else:
                hp_infer = helper.GreedyEmbeddingHelper(
                    embedding=self.embedding,
                    start_tokens=tf.ones(shape=[batch_size * self.beam_size],
                                         dtype=tf.int32),
                    end_token=EOS_ID,
                    out_proj=(w, b))

                output_layer = layers_core.Dense(
                    self.conf.out_layer_size,
                    use_bias=True) if self.conf.out_layer_size else None
                dec_init_state = beam_decoder.BeamState(
                    tf.zeros([batch_size * self.beam_size]), dec_init_state,
                    tf.zeros([batch_size * self.beam_size], tf.int32))
                my_decoder = beam_decoder.BeamDecoder(
                    cell=cell,
                    helper=hp_infer,
                    out_proj=(w, b),
                    initial_state=dec_init_state,
                    beam_splits=self.conf.beam_splits,
                    max_res_num=self.conf.max_res_num,
                    output_layer=output_layer)
                cell_outs, final_state = decoder.dynamic_decode(
                    decoder=my_decoder,
                    scope=scope,
                    maximum_iterations=self.conf.output_max_len,
                    impute_finished=True)

        if not for_deploy:
            outputs = cell_outs.rnn_output
            # Output ouputprojected to logits
            L = tf.shape(outputs)[1]
            outputs = tf.reshape(outputs, [-1, int(w.shape[0])])
            outputs = tf.matmul(outputs, w) + b
            logits = tf.reshape(outputs, [-1, L, int(w.shape[1])])

            # branch 1 for debugging, doesn't have to be called
            with tf.name_scope("DebugOutputs") as scope:
                self.outputs = tf.argmax(logits, axis=2)
                self.outputs = tf.reshape(self.outputs, [-1, L])
                self.outputs = self.out_table.lookup(
                    tf.cast(self.outputs, tf.int64))

            with tf.name_scope("Loss") as scope:
                tars = tf.slice(self.dec_inps, [0, 1], [-1, L])

                # wgts may be a more complicated form, for example a partial down-weighting of a sequence
                # but here i just use  1.0 weights for all no-padding label
                wgts = tf.cumsum(tf.one_hot(self.dec_lens, L),
                                 axis=1,
                                 reverse=True)

                #wgts = wgts * tf.expand_dims(self.down_wgts, 1)

                loss_matrix = loss.sequence_loss(
                    logits=logits,
                    targets=tars,
                    weights=wgts,
                    average_across_timesteps=False,
                    average_across_batch=False)

                self.loss = see_loss = tf.reduce_sum(
                    loss_matrix) / tf.reduce_sum(wgts)

            with tf.name_scope(self.model_kind):
                tf.summary.scalar("loss", see_loss)

            graph_nodes = {
                "loss": self.loss,
                "inputs": {},
                "outputs": {},
                "debug_outputs": self.outputs
            }

        elif self.conf.variants == "score":
            L = tf.shape(cell_outs.logprobs)[1]
            one_hot = tf.one_hot(tf.slice(self.dec_inps, [0, 1], [-1, L]),
                                 depth=self.conf.output_vocab_size,
                                 axis=-1,
                                 on_value=1.0,
                                 off_value=0.0)
            outputs = tf.reduce_sum(cell_outs.logprobs * one_hot, 2)
            outputs = tf.reduce_sum(outputs, axis=1)

            graph_nodes = {
                "loss": None,
                "inputs": {
                    "enc_inps:0": self.enc_str_inps,
                    "enc_lens:0": self.enc_lens,
                    "dec_inps:0": self.dec_str_inps,
                    "dec_lens:0": self.dec_lens
                },
                "outputs": {
                    "logprobs": outputs
                },
                "visualize": None
            }

        else:
            L = tf.shape(cell_outs.beam_ends)[1]
            beam_symbols = cell_outs.beam_symbols
            beam_parents = cell_outs.beam_parents

            beam_ends = cell_outs.beam_ends
            beam_end_parents = cell_outs.beam_end_parents
            beam_end_probs = cell_outs.beam_end_probs
            alignments = cell_outs.alignments

            beam_ends = tf.reshape(tf.transpose(beam_ends, [0, 2, 1]), [-1, L])
            beam_end_parents = tf.reshape(
                tf.transpose(beam_end_parents, [0, 2, 1]), [-1, L])
            beam_end_probs = tf.reshape(
                tf.transpose(beam_end_probs, [0, 2, 1]), [-1, L])

            ## Creating tail_ids
            batch_size = tf.Print(batch_size, [batch_size], message="BATCH")
            batch_offset = tf.expand_dims(
                tf.cumsum(
                    tf.ones([batch_size, self.beam_size], dtype=tf.int32) *
                    self.beam_size,
                    axis=0,
                    exclusive=True), 2)
            offset2 = tf.expand_dims(
                tf.cumsum(
                    tf.ones([batch_size, self.beam_size * 2], dtype=tf.int32) *
                    self.beam_size,
                    axis=0,
                    exclusive=True), 2)

            out_len = tf.shape(beam_symbols)[1]
            self.beam_symbol_strs = tf.reshape(
                self.out_table.lookup(tf.cast(beam_symbols, tf.int64)),
                [batch_size, self.beam_size, -1])
            self.beam_parents = tf.reshape(
                beam_parents, [batch_size, self.beam_size, -1]) - batch_offset

            self.beam_ends = tf.reshape(beam_ends,
                                        [batch_size, self.beam_size * 2, -1])
            self.beam_end_parents = tf.reshape(
                beam_end_parents,
                [batch_size, self.beam_size * 2, -1]) - offset2
            self.beam_end_probs = tf.reshape(
                beam_end_probs, [batch_size, self.beam_size * 2, -1])
            self.beam_attns = tf.reshape(
                alignments, [batch_size, self.beam_size, out_len, -1])

            graph_nodes = {
                "loss": None,
                "inputs": {
                    "enc_inps:0": self.enc_str_inps,
                    "enc_lens:0": self.enc_lens
                },
                "outputs": {
                    "beam_symbols": self.beam_symbol_strs,
                    "beam_parents": self.beam_parents,
                    "beam_ends": self.beam_ends,
                    "beam_end_parents": self.beam_end_parents,
                    "beam_end_probs": self.beam_end_probs,
                    "beam_attns": self.beam_attns
                },
                "visualize": {}
            }
        return graph_nodes
Exemple #3
0
    def build(self, for_deploy, variants=""):
        conf = self.conf
        name = self.name
        job_type = self.job_type
        dtype = self.dtype
        self.beam_size = 1 if (not for_deploy or variants == "score") else sum(
            self.conf.beam_splits)

        graphlg.info("Creating placeholders...")
        self.enc_str_inps = tf.placeholder(tf.string,
                                           shape=(None, conf.input_max_len),
                                           name="enc_inps")
        self.enc_lens = tf.placeholder(tf.int32, shape=[None], name="enc_lens")
        self.dec_str_inps = tf.placeholder(
            tf.string, shape=[None, conf.output_max_len + 2], name="dec_inps")
        self.dec_lens = tf.placeholder(tf.int32, shape=[None], name="dec_lens")
        self.down_wgts = tf.placeholder(tf.float32,
                                        shape=[None],
                                        name="down_wgts")

        with tf.name_scope("TableLookup"):
            # lookup tables
            self.in_table = lookup.MutableHashTable(key_dtype=tf.string,
                                                    value_dtype=tf.int64,
                                                    default_value=UNK_ID,
                                                    shared_name="in_table",
                                                    name="in_table",
                                                    checkpoint=True)

            self.out_table = lookup.MutableHashTable(key_dtype=tf.int64,
                                                     value_dtype=tf.string,
                                                     default_value="_UNK",
                                                     shared_name="out_table",
                                                     name="out_table",
                                                     checkpoint=True)
            self.enc_inps = self.in_table.lookup(self.enc_str_inps)
            self.dec_inps = self.in_table.lookup(self.dec_str_inps)

        # Create encode graph and get attn states
        graphlg.info("Creating embeddings and embedding enc_inps.")
        with ops.device("/cpu:0"):
            self.embedding = variable_scope.get_variable(
                "embedding", [conf.output_vocab_size, conf.embedding_size])

        with tf.name_scope("Embed") as scope:
            dec_inps = tf.slice(self.dec_inps, [0, 0],
                                [-1, conf.output_max_len + 1])
            with ops.device("/cpu:0"):
                self.emb_inps = embedding_lookup_unique(
                    self.embedding, self.enc_inps)
                emb_dec_inps = embedding_lookup_unique(self.embedding,
                                                       dec_inps)
        # output projector (w, b)
        with tf.variable_scope("OutProj"):
            if conf.out_layer_size:
                w = tf.get_variable(
                    "proj_w", [conf.out_layer_size, conf.output_vocab_size],
                    dtype=dtype)
            elif conf.bidirectional:
                w = tf.get_variable(
                    "proj_w", [conf.num_units * 2, conf.output_vocab_size],
                    dtype=dtype)
            else:
                w = tf.get_variable("proj_w",
                                    [conf.num_units, conf.output_vocab_size],
                                    dtype=dtype)
            b = tf.get_variable("proj_b", [conf.output_vocab_size],
                                dtype=dtype)

        graphlg.info("Creating dynamic rnn...")
        self.enc_outs, self.enc_states, mem_size, enc_state_size = DynRNN(
            conf.cell_model,
            conf.num_units,
            conf.num_layers,
            self.emb_inps,
            self.enc_lens,
            keep_prob=1.0,
            bidi=conf.bidirectional,
            name_scope="DynRNNEncoder")
        batch_size = tf.shape(self.enc_outs)[0]
        # Do vae on the state of the last layer of the encoder
        final_enc_states = []
        KLDs = 0.0
        for each in self.enc_states:
            z, KLD, l2 = CreateVAE([each],
                                   self.conf.enc_latent_dim,
                                   name_scope="VAE")
            if isinstance(each, LSTMStateTuple):
                final_enc_states.append(
                    LSTMStateTuple(each.c, tf.concat([each.h, z], 1)))
            else:
                final_enc_state.append(tf.concat([z, each], 1))
            KLDs += KLD / self.conf.num_layers

        with tf.name_scope("DynRNNDecode") as scope:
            with tf.name_scope("ShapeToBeam") as scope:
                beam_memory = tf.reshape(
                    tf.tile(self.enc_outs, [1, 1, self.beam_size]),
                    [-1, conf.input_max_len, mem_size])
                beam_memory_lens = tf.squeeze(
                    tf.reshape(
                        tf.tile(tf.expand_dims(self.enc_lens, 1),
                                [1, self.beam_size]), [-1, 1]), 1)

                def _to_beam(t):
                    return tf.reshape(tf.tile(t, [1, self.beam_size]),
                                      [-1, int(t.get_shape()[1])])

                beam_init_states = tf.contrib.framework.nest.map_structure(
                    _to_beam, final_enc_states)
            max_mem_size = self.conf.input_max_len + self.conf.output_max_len + 2
            cell = AttnCell(cell_model=conf.cell_model,
                            num_units=mem_size,
                            num_layers=conf.num_layers,
                            attn_type=self.conf.attention,
                            memory=beam_memory,
                            mem_lens=beam_memory_lens,
                            max_mem_size=max_mem_size,
                            addmem=self.conf.addmem,
                            keep_prob=conf.keep_prob,
                            dtype=tf.float32,
                            name_scope="AttnCell")

            dec_init_state = DecStateInit(all_enc_states=beam_init_states,
                                          decoder_cell=cell,
                                          batch_size=batch_size *
                                          self.beam_size,
                                          init_type="each2each")

            if not for_deploy:
                hp_train = helper.ScheduledEmbeddingTrainingHelper(
                    inputs=emb_dec_inps,
                    sequence_length=self.dec_lens,
                    embedding=self.embedding,
                    sampling_probability=self.conf.sample_prob,
                    out_proj=(w, b))
                output_layer = layers_core.Dense(
                    self.conf.out_layer_size,
                    use_bias=True) if self.conf.out_layer_size else None
                my_decoder = basic_decoder.BasicDecoder(
                    cell=cell,
                    helper=hp_train,
                    initial_state=dec_init_state,
                    output_layer=output_layer)
                cell_outs, final_state = decoder.dynamic_decode(
                    decoder=my_decoder,
                    impute_finished=False,
                    maximum_iterations=conf.output_max_len + 1,
                    scope=scope)
            elif variants == "score":
                dec_init_state = zero_attn_states
                hp_train = helper.ScheduledEmbeddingTrainingHelper(
                    inputs=emb_dec_inps,
                    sequence_length=self.dec_lens,
                    embedding=self.embedding,
                    sampling_probability=0.0,
                    out_proj=(w, b))
                output_layer = layers_core.Dense(
                    self.conf.out_layer_size,
                    use_bias=True) if self.conf.out_layer_size else None
                my_decoder = score_decoder.ScoreDecoder(
                    cell=cell,
                    helper=hp_train,
                    out_proj=(w, b),
                    initial_state=dec_init_state,
                    output_layer=output_layer)
                cell_outs, final_state = decoder.dynamic_decode(
                    decoder=my_decoder,
                    scope=scope,
                    maximum_iterations=self.conf.output_max_len,
                    impute_finished=False)
            else:
                hp_infer = helper.GreedyEmbeddingHelper(
                    embedding=self.embedding,
                    start_tokens=tf.ones(shape=[batch_size * self.beam_size],
                                         dtype=tf.int32),
                    end_token=EOS_ID,
                    out_proj=(w, b))

                output_layer = layers_core.Dense(
                    self.conf.out_layer_size,
                    use_bias=True) if self.conf.out_layer_size else None
                my_decoder = beam_decoder.BeamDecoder(
                    cell=cell,
                    helper=hp_infer,
                    out_proj=(w, b),
                    initial_state=dec_init_state,
                    beam_splits=self.conf.beam_splits,
                    max_res_num=self.conf.max_res_num,
                    output_layer=output_layer)
                cell_outs, final_state = decoder.dynamic_decode(
                    decoder=my_decoder,
                    scope=scope,
                    maximum_iterations=self.conf.output_max_len,
                    impute_finished=True)

        if not for_deploy:
            outputs = cell_outs.rnn_output
            # Output ouputprojected to logits
            L = tf.shape(outputs)[1]
            outputs = tf.reshape(outputs, [-1, int(w.shape[0])])
            outputs = tf.matmul(outputs, w) + b
            logits = tf.reshape(outputs, [-1, L, int(w.shape[1])])

            # branch 1 for debugging, doesn't have to be called
            with tf.name_scope("DebugOutputs") as scope:
                self.outputs = tf.argmax(logits, axis=2)
                self.outputs = tf.reshape(self.outputs, [-1, L])
                self.outputs = self.out_table.lookup(
                    tf.cast(self.outputs, tf.int64))

            with tf.name_scope("Loss") as scope:
                tars = tf.slice(self.dec_inps, [0, 1], [-1, L])
                wgts = tf.cumsum(tf.one_hot(self.dec_lens, L),
                                 axis=1,
                                 reverse=True)
                #wgts = wgts * tf.expand_dims(self.down_wgts, 1)
                self.loss = loss.sequence_loss(logits=logits,
                                               targets=tars,
                                               weights=wgts,
                                               average_across_timesteps=False,
                                               average_across_batch=False)
                example_losses = tf.reduce_sum(self.loss, 1)

                batch_wgt = tf.reduce_sum(self.down_wgts)
                see_KLD = tf.reduce_sum(KLDs * self.down_wgts) / batch_wgt
                see_loss = tf.reduce_sum(example_losses / tf.cast(
                    self.dec_lens, tf.float32) * self.down_wgts) / batch_wgt

                # not average over length
                self.loss = tf.reduce_sum(
                    (example_losses + self.conf.kld_ratio * KLDs) *
                    self.down_wgts) / batch_wgt

            with tf.name_scope(self.model_kind):
                tf.summary.scalar("loss", see_loss)
                tf.summary.scalar("kld", see_KLD)

            graph_nodes = {
                "loss": self.loss,
                "inputs": {},
                "outputs": {},
                "debug_outputs": self.outputs
            }

        elif variants == "score":
            L = tf.shape(cell_outs.logprobs)[1]
            one_hot = tf.one_hot(tf.slice(self.dec_inps, [0, 1], [-1, L]),
                                 depth=self.conf.output_vocab_size,
                                 axis=-1,
                                 on_value=1.0,
                                 off_value=0.0)
            outputs = tf.reduce_sum(cell_outs.logprobs * one_hot, 2)
            outputs = tf.reduce_sum(outputs, axis=1)
            inputs = {
                "enc_inps:0": self.enc_str_inps,
                "enc_lens:0": self.enc_lens,
                "dec_inps:0": self.dec_str_inps,
                "dec_lens:0": self.dec_lens
            }
            graph_nodes = {
                "loss": None,
                "inputs": inputs,
                "outputs": {
                    "logprobs": outputs
                },
                "visualize": None
            }
        else:
            L = tf.shape(cell_outs.beam_ends)[1]
            beam_symbols = cell_outs.beam_symbols
            beam_parents = cell_outs.beam_parents

            beam_ends = cell_outs.beam_ends
            beam_end_parents = cell_outs.beam_end_parents
            beam_end_probs = cell_outs.beam_end_probs
            alignments = cell_outs.alignments

            beam_ends = tf.reshape(tf.transpose(beam_ends, [0, 2, 1]), [-1, L])
            beam_end_parents = tf.reshape(
                tf.transpose(beam_end_parents, [0, 2, 1]), [-1, L])
            beam_end_probs = tf.reshape(
                tf.transpose(beam_end_probs, [0, 2, 1]), [-1, L])

            ## Creating tail_ids
            batch_size = tf.Print(batch_size, [batch_size],
                                  message="VAERNN2 batch")
            batch_offset = tf.expand_dims(
                tf.cumsum(
                    tf.ones([batch_size, self.beam_size], dtype=tf.int32) *
                    self.beam_size,
                    axis=0,
                    exclusive=True), 2)
            offset2 = tf.expand_dims(
                tf.cumsum(
                    tf.ones([batch_size, self.beam_size * 2], dtype=tf.int32) *
                    self.beam_size,
                    axis=0,
                    exclusive=True), 2)

            out_len = tf.shape(beam_symbols)[1]
            self.beam_symbol_strs = tf.reshape(
                self.out_table.lookup(tf.cast(beam_symbols, tf.int64)),
                [batch_size, self.beam_size, -1])
            self.beam_parents = tf.reshape(
                beam_parents, [batch_size, self.beam_size, -1]) - batch_offset

            self.beam_ends = tf.reshape(beam_ends,
                                        [batch_size, self.beam_size * 2, -1])
            self.beam_end_parents = tf.reshape(
                beam_end_parents,
                [batch_size, self.beam_size * 2, -1]) - offset2
            self.beam_end_probs = tf.reshape(
                beam_end_probs, [batch_size, self.beam_size * 2, -1])
            self.beam_attns = tf.reshape(
                alignments, [batch_size, self.beam_size, out_len, -1])

            inputs = {
                "enc_inps:0": self.enc_str_inps,
                "enc_lens:0": self.enc_lens
            }
            outputs = {
                "beam_symbols": self.beam_symbol_strs,
                "beam_parents": self.beam_parents,
                "beam_ends": self.beam_ends,
                "beam_end_parents": self.beam_end_parents,
                "beam_end_probs": self.beam_end_probs,
                "beam_attns": self.beam_attns
            }
            graph_nodes = {
                "loss": None,
                "inputs": inputs,
                "outputs": outputs,
                "visualize": {
                    "z": z
                }
            }
        return graph_nodes