    # Assumes module-level `import tensorflow as tf` and the project's `toolbox` helpers.
    def define_updates(self, new_chars, emb_path, char2idx, new_grams=None, ng_emb_path=None, gram2idx=None):
        # Extend the vocabulary bookkeeping with the newly seen characters.
        self.nums_chars += len(new_chars)

        if self.word_vec and emb_path is not None:
            old_emb_weights = self.emb_layer.embeddings
            emb_dim = old_emb_weights.get_shape().as_list()[1]
            emb_len = old_emb_weights.get_shape().as_list()[0]
            # Look up pre-trained vectors for the new characters.
            new_emb = toolbox.get_new_embeddings(new_chars, emb_dim, emb_path)
            n_emb_sh = new_emb.get_shape().as_list()
            if len(n_emb_sh) > 1:
                # Splice the new rows in after the rows of the known vocabulary,
                # keeping the original row order for existing characters.
                new_emb_weights = tf.concat(axis=0,
                                            values=[old_emb_weights[:len(char2idx) - len(new_chars)],
                                                    new_emb,
                                                    old_emb_weights[len(char2idx):]])
                # Truncate in case the spliced matrix exceeds the variable's row count.
                if new_emb_weights.get_shape().as_list()[0] > emb_len:
                    new_emb_weights = new_emb_weights[:emb_len]
                assign_op = old_emb_weights.assign(new_emb_weights)
                self.updates.append(assign_op)

        if self.ngram is not None and ng_emb_path is not None:
            # Repeat the same splicing for each n-gram embedding table.
            old_gram_weights = [ng_layer.embeddings for ng_layer in self.gram_layers]
            ng_emb_dim = old_gram_weights[0].get_shape().as_list()[1]
            new_ng_emb = toolbox.get_new_ng_embeddings(new_grams, ng_emb_dim, ng_emb_path)
            for i in range(len(old_gram_weights)):
                new_ng_weight = tf.concat(axis=0,
                                          values=[old_gram_weights[i][:len(gram2idx[i]) - len(new_grams[i])],
                                                  new_ng_emb[i],
                                                  old_gram_weights[i][len(gram2idx[i]):]])
                assign_op = old_gram_weights[i].assign(new_ng_weight)
                self.updates.append(assign_op)
    # Variant of the same method written against the pre-1.0 TensorFlow API
    # (tf.pack and the positional tf.concat(dim, values) signature); it handles
    # only the character embeddings, without the n-gram tables.
    def define_updates(self, new_chars, emb_path, char2idx):
        self.nums_chars += len(new_chars)

        if emb_path is not None:
            old_emb_weights = self.emb_layer.embeddings
            emb_dim = old_emb_weights.get_shape().as_list()[1]
            emb_len = old_emb_weights.get_shape().as_list()[0]
            # Stack the per-character vectors into a single (n_new, emb_dim) tensor.
            new_emb = tf.pack(toolbox.get_new_embeddings(new_chars, emb_dim, emb_path))
            n_emb_sh = new_emb.get_shape().as_list()
            if len(n_emb_sh) > 1:
                # Splice the new rows in after the rows of the known vocabulary.
                new_emb_weights = tf.concat(0, [old_emb_weights[:len(char2idx) - len(new_chars)],
                                                new_emb,
                                                old_emb_weights[len(char2idx):]])
                if new_emb_weights.get_shape().as_list()[0] > emb_len:
                    new_emb_weights = new_emb_weights[:emb_len]
                assign_op = old_emb_weights.assign(new_emb_weights)
                self.updates.append(assign_op)
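    # Note: neither variant runs the assignment ops itself; they are only
    # collected in self.updates. A minimal usage sketch under TensorFlow 1.x
    # graph mode is given below; the `model` and `sess` handles and the
    # argument names are assumptions for illustration, not part of the
    # original code.
    #
    #     model.define_updates(new_chars, emb_path, char2idx,
    #                          new_grams=new_grams, ng_emb_path=ng_emb_path,
    #                          gram2idx=gram2idx)
    #     sess.run(model.updates)  # copy the spliced embedding matrices into the variables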