def get_init_ops(self):
    """Build the list of ops that initialize this model's state.

    Every variable in ``self.optimizer_params``, ``self.global_params`` and
    ``self.trainable_params`` is covered by one ``tf.variables_initializer``.
    When ``self.conf.embedding_init`` is set, ``self.embedding`` is left out
    of that initializer. Instead it is assigned the array loaded from that
    path with ``np.load``, so the default initializer cannot overwrite the
    pretrained values.

    When ``self.task_id == 0``, ops that fill the vocabulary lookup tables are
    appended as well:

    * ``self.in_table`` maps token -> int64 id.
    * ``self.out_table`` maps int64 id -> token.

    Both are built from the first ``self.conf.output_vocab_size`` lines of the
    ``vocab<N>.all`` file found in ``self.conf.data_dir``.

    Returns:
        list: the initialization ops.

    Raises:
        IndexError: if ``self.task_id == 0`` and no file in
            ``self.conf.data_dir`` matches ``vocab<N>.all``.
    """
    all_params = set(self.optimizer_params + self.global_params +
                     self.trainable_params)

    if self.conf.embedding_init:
        # Exclude the embedding from the generic initializer; otherwise it
        # could clobber the pretrained vectors assigned just below.
        init_ops = [tf.variables_initializer(all_params - {self.embedding})]
        w2v = np.load(self.conf.embedding_init)
        init_ops.append(self.embedding.assign(w2v))
    else:
        init_ops = [tf.variables_initializer(all_params)]

    # Only task 0 fills the lookup tables (presumably so they are populated
    # once in a distributed job -- confirm against the caller).
    if self.task_id == 0:
        data_dir = self.conf.data_dir
        # A list (not filter(), which is a non-indexable iterator on
        # Python 3); sorted so the choice is deterministic if several match.
        vocab_pattern = re.compile(r"vocab\d+\.all")
        vocab_files = sorted(name for name in os.listdir(data_dir)
                             if vocab_pattern.match(name) is not None)
        if not vocab_files:
            raise IndexError("no vocab<N>.all file found in %s" % data_dir)

        with codecs.open(os.path.join(data_dir, vocab_files[0])) as f:
            tokens = [line.strip() for line in f]
        tokens = tokens[:self.conf.output_vocab_size]
        ids = list(range(len(tokens)))

        op_in = self.in_table.insert(
            constant_op.constant(tokens),
            constant_op.constant(ids, dtype=tf.int64))
        op_out = self.out_table.insert(
            constant_op.constant(ids, dtype=tf.int64),
            constant_op.constant(tokens))
        init_ops.extend([op_in, op_out])
    return init_ops
def get_init_ops(self, job_type, task_id):
    """Build the list of ops that initialize this model's state.

    Every variable in ``self.optimizer_params``, ``self.global_params`` and
    ``self.trainable_params`` is covered by one ``tf.variables_initializer``,
    except the pretrained ones:

    * When ``self.conf.embedding_init`` is set, ``self.embedding`` is assigned
      the array loaded from that path.
    * When ``self.conf.topic_embedding_init`` is set,
      ``self.topic_embedding`` is assigned the array loaded from that path.

    Those variables are excluded from the generic initializer so it cannot
    overwrite the pretrained values.

    When not ``self.for_deploy`` and ``task_id == 0``, ops that fill the
    lookup tables are appended as well:

    * ``self.in_table`` maps token -> int64 id.
    * ``self.out_table`` maps int64 id -> token. Both use the first
      ``self.conf.input_vocab_size`` lines of the ``vocab<N>.all`` file in
      ``self.conf.data_dir``.
    * ``self.topic_in_table`` maps topic -> int64 id, built from
      ``vocab.topic`` in the same directory.

    Args:
        job_type: unused here; kept for interface compatibility.
        task_id: index of this task; only task 0 fills the lookup tables.

    Returns:
        list: the initialization ops, in a deterministic order.

    Raises:
        IndexError: if the tables are being filled and no file in
            ``self.conf.data_dir`` matches ``vocab<N>.all``.
    """
    all_params = set(self.optimizer_params + self.global_params +
                     self.trainable_params)

    # (variable, value) pairs restored from pretrained files.
    pretrained = []
    if self.conf.embedding_init:
        pretrained.append((self.embedding, np.load(self.conf.embedding_init)))
    if self.conf.topic_embedding_init:
        pretrained.append((self.topic_embedding,
                           np.load(self.conf.topic_embedding_init)))

    # Remove the pretrained *variables* from the initializer's var set, so
    # the default initializer cannot overwrite the assigned values.
    excluded = set(var for var, _ in pretrained)
    init_ops = [tf.variables_initializer(all_params - excluded)]
    init_ops.extend(var.assign(value) for var, value in pretrained)

    if not self.for_deploy and task_id == 0:
        data_dir = self.conf.data_dir
        # A list (not filter(), which is a non-indexable iterator on
        # Python 3); sorted so the choice is deterministic if several match.
        vocab_pattern = re.compile(r"vocab[0-9]+\.all")
        vocab_files = sorted(name for name in os.listdir(data_dir)
                             if vocab_pattern.match(name) is not None)
        if not vocab_files:
            raise IndexError("no vocab<N>.all file found in %s" % data_dir)

        with codecs.open(os.path.join(data_dir, vocab_files[0])) as f:
            tokens = [line.strip() for line in f]
        tokens = tokens[:self.conf.input_vocab_size]
        ids = list(range(len(tokens)))
        op_in = self.in_table.insert(
            constant_op.constant(tokens),
            constant_op.constant(ids, dtype=tf.int64))
        op_out = self.out_table.insert(
            constant_op.constant(ids, dtype=tf.int64),
            constant_op.constant(tokens))

        # Topic vocabulary: only the topic -> id direction is built.
        topic_vocab_file = "vocab.topic"
        with codecs.open(os.path.join(data_dir, topic_vocab_file)) as ft:
            topics = [line.strip() for line in ft]
        topic_ids = list(range(len(topics)))
        op_topic_in = self.topic_in_table.insert(
            constant_op.constant(topics),
            constant_op.constant(topic_ids, dtype=tf.int64))

        init_ops.extend([op_in, op_out, op_topic_in])
    return init_ops