예제 #1
0
 def generate(self, init_states, max_len=None, decode=str):
     """Greedily decode a sequence of ids from the LSTM.

     The initial states are wrapped in ``dy.nobackprop`` so decoding does not
     send gradients back into whatever produced them. At every step the id
     with the highest score under ``W * h + b`` is picked and fed back to the
     LSTM through ``self.lu``. Decoding stops when ``vocabulary.STOP_I`` is
     predicted.

     Args:
         init_states: expressions passed to ``self.lstm.initial_state``.
         max_len: optional cap on the number of predictions, including the
             stop id. ``None`` (default) means no cap: decoding runs until
             the stop id is produced.
         decode: callable mapping a predicted integer id to the string
             emitted for it. Defaults to ``str``. NOTE(review): callers most
             likely want the vocabulary's id-to-symbol mapping here; confirm.

     Returns:
         The concatenation of ``decode(id)`` over all predicted ids. The stop
         id is excluded.
     """
     init_states = [dy.nobackprop(s) for s in init_states]
     state = self.lstm.initial_state(init_states)
     pieces = []
     steps = 0
     while max_len is None or steps < max_len:
         # NOTE(review): W and b are not defined in this method; assumed to be
         # the output projection available in the enclosing scope. Confirm
         # (they may need to be self.W / self.b or dy.parameter(...)).
         unnormalized_distribution = W * state.output() + b
         scores = unnormalized_distribution.value()
         prediction = int(np.argmax(scores))
         steps += 1
         if prediction == vocabulary.STOP_I:
             break
         pieces.append(decode(prediction))
         # add_input returns a *new* RNN state rather than mutating this one,
         # so the result must be kept for the next step.
         state = state.add_input(self.lu[prediction])
     return "".join(pieces)
예제 #2
0
 def get_w_repr(self, word, train=False, update=True):
     """
     Get representation of word (word embedding).

     While training, the word is replaced by UNK when word dropout is
     enabled (``self.w_dropout_rate > 0``) and ``drop(...)`` fires.
     Otherwise the word's id is looked up, falling back to UNK's id for
     unknown words.

     Outside training, if ``self.mimickx_model_path`` is set, words missing
     from ``self.w2i`` are embedded by ``self.mimickx_model`` instead of the
     embedding table.

     Args:
         word: token to embed.
         train: whether this is a training step (enables word dropout).
         update: if False, the table embedding is wrapped in
             ``dynet.nobackprop`` so no gradient reaches it.

     Returns:
         A dynet expression holding the word's embedding.
     """
     if train:
         if self.w_dropout_rate > 0.0 and drop(word, self.wcount, self.w_dropout_rate):
             w_id = self.w2i[UNK]
         else:
             # Dropout disabled or not triggered: regular lookup, UNK if OOV.
             w_id = self.w2i.get(word, self.w2i[UNK])
     else:
         if self.mimickx_model_path and word not in self.w2i:
             # OOV word outside training: let MIMICKX predict an embedding.
             return dynet.inputVector(self.mimickx_model.predict(word).npvalue())
         w_id = self.w2i.get(word, self.w2i[UNK])
     embedding = self.wembeds[w_id]
     return embedding if update else dynet.nobackprop(embedding)
예제 #3
0
 def get_w_repr(self, word, train=False, update=True):
     """
     Get representation of word (word embedding).

     While training, the word is replaced by UNK when word dropout is
     enabled (``self.w_dropout_rate > 0``) and ``drop(...)`` fires.
     Otherwise the word's id is looked up, falling back to UNK's id for
     unknown words.

     Outside training, if ``self.mimickx_model_path`` is set, words missing
     from ``self.w2i`` are embedded by ``self.mimickx_model`` instead of the
     embedding table.

     Args:
         word: token to embed.
         train: whether this is a training step (enables word dropout).
         update: if False, the table embedding is wrapped in
             ``dynet.nobackprop`` so no gradient reaches it.

     Returns:
         A dynet expression holding the word's embedding.
     """
     if train:
         if self.w_dropout_rate > 0.0 and drop(word, self.wcount, self.w_dropout_rate):
             w_id = self.w2i[UNK]
         else:
             # Dropout disabled or not triggered: regular lookup, UNK if OOV.
             w_id = self.w2i.get(word, self.w2i[UNK])
     else:
         if self.mimickx_model_path and word not in self.w2i:
             # OOV word outside training: let MIMICKX predict an embedding.
             return dynet.inputVector(self.mimickx_model.predict(word).npvalue())
         w_id = self.w2i.get(word, self.w2i[UNK])
     embedding = self.wembeds[w_id]
     return embedding if update else dynet.nobackprop(embedding)
예제 #4
0
    def generator_train(self, example):
        """Run one adversarial training step involving the text generator.

        The generator is scored twice on the example's coded text:
        - once on the encoder output detached with ``dy.nobackprop``
          (``real_loss``);
        - once on the attached encoder output (``fake_loss``, negated).

        Only ``fake_loss`` is back-propagated. The generator's gradients are
        then zeroed before ``self.trainer.update()``, so that update is driven
        by the remaining (non-generator) gradients. Presumably these belong to
        the encoder, trained to make the generator fail; confirm against
        ``get_input``.

        Args:
            example: object exposing ``sentence`` and accepted by
                ``self.get_input``.

        Returns:
            The value of ``real_loss``.
        """
        coded_text = self.vocabulary.code_chars(example.sentence)

        encoding = self.get_input(example, training=True, do_not_renew=False, backprop=True)
        detached_encoding = dy.nobackprop(encoding)

        # NOTE(review): real_loss is never .backward()-ed in this method;
        # presumably train_real updates the generator itself. Confirm.
        real_loss = self.generator.train_real(detached_encoding, coded_text)

        # Negated so that descending on it pushes the value returned by
        # train_fake up.
        fake_loss = -self.generator.train_fake(encoding, coded_text)
        fake_loss.backward()

        # Drop the generator's share of fake_loss's gradients before updating.
        self.generator.zero_gradient()

        self.trainer.update()

        return real_loss.value()
예제 #5
0
    def discriminator_train(self, example):
        """Run one adversarial training step involving the label discriminator.

        ``fake_labels`` holds every id in ``range(n_labels)`` that is not one
        of the example's real auxiliary labels. The discriminator is scored
        twice:
        - on the detached encoder output against the real labels
          (``real_loss``);
        - on the attached encoder output against the fake labels
          (``fake_loss``).

        Only ``fake_loss`` is back-propagated. The discriminator's gradients
        are zeroed before ``self.trainer.update()``, so that update is driven
        by the remaining (non-discriminator) gradients.

        Args:
            example: object exposing ``get_aux_labels()`` and accepted by
                ``self.get_input``.

        Returns:
            The value of ``real_loss``.
        """
        real_labels = example.get_aux_labels()
        n_labels = self.adversary_classifier.output_size()
        # Build the membership set once instead of scanning real_labels for
        # every candidate label.
        real_label_set = set(real_labels)
        fake_labels = {i for i in range(n_labels) if i not in real_label_set}

        encoding = self.get_input(example, training=True, do_not_renew=False, backprop=True)
        detached_encoding = dy.nobackprop(encoding)

        # NOTE(review): real_loss is never .backward()-ed in this method;
        # presumably train_real updates the discriminator itself. Confirm.
        real_loss = self.discriminator.train_real(detached_encoding, real_labels)
        fake_loss = self.discriminator.train_fake(encoding, fake_labels)
        fake_loss.backward()

        # Drop the discriminator's share of fake_loss's gradients before updating.
        self.discriminator.zero_gradient()

        self.trainer.update()

        return real_loss.value()
예제 #6
0
 def get_input(self, example, training, do_not_renew=False, backprop=True):
     """Encode ``example`` via ``self._get_input``.

     All four arguments are forwarded positionally to ``_get_input``. When
     ``backprop`` is false, the resulting encoding is additionally wrapped in
     ``dy.nobackprop`` so gradients stop there. Otherwise it is returned
     unchanged.
     """
     encoded = self._get_input(example, training, do_not_renew, backprop)
     if not backprop:
         return dy.nobackprop(encoded)
     return encoded