Example #1
	def _create_network(self):
		# Initialize autoencoder network weights and biases
		network_weights = self._initialize_weights(**self.network_architecture)
		# binary_dim and the vanilla/form*/use_ctc/lstm_stack flags are module-level settings
		start_token_tensor = tf.constant(np.zeros([self.batch_size, binary_dim]).astype(np.float32), dtype=tf.float32)
		self.network_weights = network_weights
		# Per-example sequence lengths, recovered from the padding mask
		seqlen = tf.cast(tf.reduce_sum(self.mask, axis=-1), tf.int32)
		
		# Variationally embed every caption position, then drop the start token
		# to form the encoder input
		embedded_input, embedded_input_KLD_loss = self._get_word_embedding(
			[network_weights['variational_encoding'], network_weights['biases_variational_encoding']],
			network_weights['input_meaning'],
			tf.reshape(self.caption_placeholder, [-1, self.network_architecture['n_input']]), logit=True)
		embedded_input = tf.reshape(embedded_input, [-1, self.network_architecture['maxlen'], self.network_architecture['n_lstm_input']])
		if not vanilla:
			embedded_input_KLD_loss = tf.reshape(embedded_input_KLD_loss, [-1, self.network_architecture['maxlen']])[:, 1:]
		encoder_input = embedded_input[:, 1:, :]
		# Encoder LSTM, optionally stacked and/or bidirectional
		cell = tf.contrib.rnn.BasicLSTMCell(self.network_architecture['n_lstm_input'])
		if lstm_stack > 1:
			# NOTE: repeating one cell instance shares weights across layers;
			# newer TF 1.x releases require a fresh cell per layer
			cell = tf.contrib.rnn.MultiRNNCell([cell] * lstm_stack)
		if not use_bdlstm:
			encoder_outs, encoder_states = rnn.dynamic_rnn(cell, encoder_input, sequence_length=seqlen - 1, dtype=tf.float32, time_major=False)
		else:
			backward_cell = tf.contrib.rnn.BasicLSTMCell(self.network_architecture['n_lstm_input'])
			if lstm_stack > 1:
				backward_cell = tf.contrib.rnn.MultiRNNCell([backward_cell] * lstm_stack)
			encoder_outs, encoder_states = rnn.bidirectional_dynamic_rnn(cell, backward_cell, encoder_input, sequence_length=seqlen - 1, dtype=tf.float32, time_major=False)
		# Gather each sequence's last valid encoder output: index seqlen-2,
		# because the start token was dropped from encoder_input
		ix_range = tf.range(0, self.batch_size, 1)
		ixs = tf.expand_dims(ix_range, -1)
		to_cat = tf.expand_dims(seqlen - 2, -1)
		gather_inds = tf.concat([ixs, to_cat], axis=-1)
		print(encoder_outs)
		outs = tf.gather_nd(encoder_outs, gather_inds)
		self.deb = tf.gather_nd(self.caption_placeholder[:, 1:, :], gather_inds)
		print(outs.shape)
		outs = tf.nn.dropout(outs, .75)
		input_embedding, input_embedding_KLD_loss = self._get_middle_embedding(
			[network_weights['middle_encoding'], network_weights['biases_middle_encoding']],
			network_weights['middle_encoding'], outs, logit=True)
		input_embedding = tf.nn.l2_normalize(input_embedding, dim=-1)
		self.other_loss = tf.constant(0, dtype=tf.float32)
		# Penalty weights ramped up with the training timestep
		KLD_penalty = (tf.cast(self.timestep, tf.float32) / (800000 / 18.0)) * 1e-3
		cos_penalty = tf.maximum(-0.1, tf.cast(self.timestep, tf.float32) / 18.0) * 1e-3

		input_KLD_loss = 0
		if form3:
			# Penalize cosine distance between the middle embedding and the
			# variational target embedding
			_x, input_KLD_loss = self._get_input_embedding([network_weights['embmap'], network_weights['embmap_biases']], network_weights['embmap'])
			input_KLD_loss = tf.reduce_mean(input_KLD_loss) * KLD_penalty
			normed_embedding = tf.nn.l2_normalize(input_embedding, dim=-1)
			normed_target = tf.nn.l2_normalize(_x, dim=-1)
			cos_sim = tf.reduce_sum(tf.multiply(normed_embedding, normed_target), axis=-1)
			self.other_loss += tf.reduce_mean(1 - cos_sim) * cos_penalty

		if not embeddings_trainable:
			input_embedding = tf.stop_gradient(input_embedding)
		state = self.lstm.zero_state(self.batch_size, dtype=tf.float32)
		loss = 0
		probs = []
		
		with tf.variable_scope("RNN"):
			for i in range(self.network_architecture['maxlen']):
				if i > 0:
					# Embed the previous ground-truth token (teacher forcing)
					if form4:
						current_embedding, KLD_loss = input_embedding, 0
					elif form2:
						current_embedding, KLD_loss = self._get_word_embedding(
							[network_weights['variational_encoding'], network_weights['biases_variational_encoding']],
							network_weights['LSTM'], self.caption_placeholder[:, i - 1, :], logit=True)
					else:
						current_embedding, KLD_loss = self._get_word_embedding(
							[network_weights['variational_encoding'], network_weights['biases_variational_encoding']],
							network_weights['LSTM'], self.caption_placeholder[:, i - 1])
					loss += tf.reduce_sum(KLD_loss * self.mask[:, i]) * KLD_penalty
				else:
					current_embedding = input_embedding
				if i > 0:
					tf.get_variable_scope().reuse_variables()

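				# One decoder LSTM step on the current embedding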
				out, state = self.lstm(current_embedding, state)

				
				if i > 0:
					if not form2:
						# Build dense one-hot targets from (row, label) index pairs
						labels = tf.expand_dims(self.caption_placeholder[:, i], 1)
						ix_range = tf.range(0, self.batch_size, 1)
						ixs = tf.expand_dims(ix_range, 1)
						concat = tf.concat([ixs, labels], 1)
						onehot = tf.sparse_to_dense(concat, tf.stack([self.batch_size, self.n_words]), 1.0, 0.0)
					else:
						onehot = self.caption_placeholder[:, i, :]

					logit = tf.matmul(out, network_weights['LSTM']['encoding_weight']) + network_weights['LSTM']['encoding_bias']
					if not use_ctc:
						if form2:
							# Multi-label targets: per-bit sigmoid cross-entropy, summed over bits
							xentropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=logit, labels=onehot)
							xentropy = tf.reduce_sum(xentropy, axis=-1)
						else:
							xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=logit, labels=onehot)
						# Zero out padded timesteps before accumulating
						xentropy = xentropy * self.mask[:, i]
						xentropy = tf.reduce_sum(xentropy)
						loss += xentropy
					else:
						probs.append(tf.expand_dims(tf.nn.sigmoid(logit), 1))
			self.debug = [self.other_loss, cos_penalty]
			if not use_ctc:
				loss_ctc = 0
			else:
				probs = tf.concat(probs, axis=1)
				probs = ctc_loss.get_output_probabilities(probs, self.caption_placeholder[:, 1:, :])
				loss_ctc = ctc_loss.loss(probs, self.caption_placeholder[:, 1:, :], self.network_architecture['maxlen'] - 2, self.batch_size, seqlen - 1)
				self.debug = loss_ctc
			# Average token loss over unmasked steps, then add the annealed KL terms,
			# the optional CTC loss, and the cosine penalty
			loss = (loss / tf.reduce_sum(self.mask[:, 1:])) \
				+ tf.reduce_sum(input_embedding_KLD_loss) / self.batch_size * KLD_penalty \
				+ tf.reduce_sum(embedded_input_KLD_loss * self.mask[:, 1:]) / tf.reduce_sum(self.mask[:, 1:]) * KLD_penalty \
				+ loss_ctc + input_KLD_loss + self.other_loss

			self.loss = loss
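
A note on the trickiest part of Example #1: the ix_range/gather_inds block pulls out the encoder output at each sequence's last valid timestep. Below is a minimal standalone sketch of the same tf.gather_nd pattern, assuming TensorFlow 1.x; the placeholder names are illustrative, not from the original class:

import numpy as np
import tensorflow as tf

batch_size, max_len, dim = 4, 7, 3
outputs = tf.placeholder(tf.float32, [batch_size, max_len, dim])  # e.g. dynamic_rnn outputs
seqlen = tf.placeholder(tf.int32, [batch_size])                   # valid length per example

rows = tf.expand_dims(tf.range(batch_size), -1)   # [[0], [1], [2], [3]]
cols = tf.expand_dims(seqlen - 1, -1)             # last valid time index per row
last_outputs = tf.gather_nd(outputs, tf.concat([rows, cols], axis=-1))  # [batch_size, dim]

with tf.Session() as sess:
    feed = {outputs: np.random.randn(batch_size, max_len, dim).astype(np.float32),
            seqlen: np.array([3, 7, 1, 5], dtype=np.int32)}
    print(sess.run(last_outputs, feed).shape)  # (4, 3)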
Example #2
    def _create_network(self):
        # Initialize autoencoder network weights and biases
        network_weights = self._initialize_weights(**self.network_architecture)
        start_token_tensor = tf.constant(
            (np.zeros([self.batch_size, binary_dim])).astype(np.float32),
            dtype=tf.float32)
        self.network_weights = network_weights
        seqlen = tf.cast(tf.reduce_sum(self.mask, axis=-1), tf.int32)

        # KL penalty warms up smoothly toward 1.0 over the first few thousand steps
        KLD_penalty = tf.tanh(tf.cast(self.global_step, tf.float32) / 1600.0)

        # Use recognition network to determine mean and
        # (log) variance of Gaussian distribution in latent
        # space
        if not same_embedding:
            input_embedding, input_embedding_KLD_loss = self._get_input_embedding(
                [
                    network_weights['variational_encoding'],
                    network_weights['biases_variational_encoding']
                ], network_weights['input_meaning'])
        else:
            input_embedding, input_embedding_KLD_loss = self._get_input_embedding(
                [
                    network_weights['variational_encoding'],
                    network_weights['biases_variational_encoding']
                ], network_weights['LSTM'])

        state = self.lstm.zero_state(self.batch_size, dtype=tf.float32)

        loss = 0
        self.debug = 0
        probs = []
        with tf.variable_scope("RNN"):
            for i in range(self.network_architecture['maxlen']):
                if i > 0:
                    # Embed the previous ground-truth token (teacher forcing)
                    if form2:
                        current_embedding, KLD_loss = self._get_word_embedding(
                            [
                                network_weights['variational_encoding'],
                                network_weights['biases_variational_encoding']
                            ],
                            network_weights['LSTM'],
                            self.caption_placeholder[:, i - 1, :],
                            logit=True)
                    else:
                        current_embedding, KLD_loss = self._get_word_embedding(
                            [
                                network_weights['variational_encoding'],
                                network_weights['biases_variational_encoding']
                            ], network_weights['LSTM'],
                            self.caption_placeholder[:, i - 1])
                    if transfertype2:
                        # Stop gradients from flowing back through this embedding
                        current_embedding = tf.stop_gradient(current_embedding)
                    loss += tf.reduce_sum(
                        KLD_loss * self.mask[:, i]) * KLD_penalty
                else:
                    current_embedding = input_embedding
                if i > 0:
                    tf.get_variable_scope().reuse_variables()

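                # One decoder step; the scope's variables are reused after i == 0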
                out, state = self.lstm(current_embedding, state)

                if i > 0:
                    if not form2:
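                        # Dense one-hot targets built from (row, label) index pairs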
                        labels = tf.expand_dims(self.caption_placeholder[:, i],
                                                1)
                        ix_range = tf.range(0, self.batch_size, 1)
                        ixs = tf.expand_dims(ix_range, 1)
                        concat = tf.concat([ixs, labels], 1)
                        onehot = tf.sparse_to_dense(
                            concat, tf.stack([self.batch_size, self.n_words]),
                            1.0, 0.0)
                    else:
                        onehot = self.caption_placeholder[:, i, :]

                    logit = tf.matmul(
                        out, network_weights['LSTM']['encoding_weight']
                    ) + network_weights['LSTM']['encoding_bias']
                    if not use_ctc:
                        if form2:
                            # Multi-label targets: per-bit sigmoid cross-entropy, summed over bits
                            xentropy = tf.nn.sigmoid_cross_entropy_with_logits(
                                logits=logit, labels=onehot)
                            xentropy = tf.reduce_sum(xentropy, axis=-1)
                        else:
                            xentropy = tf.nn.softmax_cross_entropy_with_logits(
                                logits=logit, labels=onehot)

                        xentropy = xentropy * self.mask[:, i]
                        xentropy = tf.reduce_sum(xentropy)
                        self.debug += xentropy
                        loss += xentropy

                    else:
                        probs.append(tf.expand_dims(tf.nn.sigmoid(logit), 1))
            if not use_ctc:
                loss_ctc = 0
            else:
                probs = tf.concat(probs, axis=1)
                probs = ctc_loss.get_output_probabilities(
                    probs, self.caption_placeholder[:, 1:, :])
                loss_ctc = ctc_loss.loss(
                    probs, self.caption_placeholder[:, 1:, :],
                    self.network_architecture['maxlen'] - 2, self.batch_size,
                    seqlen - 1)
                self.debug = loss_ctc
            # Average token loss over unmasked steps, then add the annealed
            # KL term and the optional CTC loss
            loss = (loss / tf.reduce_sum(self.mask[:, 1:])) + tf.reduce_sum(
                input_embedding_KLD_loss
            ) / self.batch_size * KLD_penalty + loss_ctc

            self.loss = loss
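
Both variants anneal the KL penalty rather than applying it at full weight from step 0, a common trick to keep a variational decoder from ignoring the latent code early in training: Example #1 uses a small linear ramp in self.timestep, while Example #2 saturates toward 1.0 via tanh of self.global_step. A quick NumPy-only sketch of the two schedules (the step values are arbitrary):

import numpy as np

steps = np.array([0, 400, 1600, 6400, 25600], dtype=np.float32)
linear_penalty = (steps / (800000 / 18.0)) * 1e-3  # Example #1's schedule
tanh_penalty = np.tanh(steps / 1600.0)             # Example #2's schedule

for s, lin, th in zip(steps, linear_penalty, tanh_penalty):
    print("step %6d: linear %.6f, tanh %.4f" % (s, lin, th))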