def build(self, data, stop_batch=0):
    """Record training settings taken from ``data`` and build the graph.

    Stores ``ntypes``, ``stop_batch``, ``batch_size`` and ``type_map`` on
    the instance, lets the model gather statistics via
    ``self.model.data_stat(data)``, and then calls ``_build_lr``,
    ``_build_network`` and ``_build_training`` inside a ``tf.device`` scope
    driven by ``tf.train.replica_device_setter``.

    Parameters
    ----------
    data
        Data object; must provide ``get_ntypes()``, ``get_batch_size()``
        and ``get_type_map()``.
    stop_batch : int, optional
        Stored as ``self.stop_batch``. Defaults to 0.

    Raises
    ------
    AssertionError
        If the model has fewer types than the data.
    """
    self.ntypes = self.model.get_ntypes()
    # Usually, the type number of the model should be equal to that of the
    # data. However, nt_model > nt_data should be allowed, since users may
    # only want to train using a dataset that only has some of the elements.
    data_ntypes = data.get_ntypes()
    # The message states the actual requirement (>=, not ==) and reports both
    # counts so a failure can be diagnosed without a debugger.
    assert self.ntypes >= data_ntypes, (
        "ntypes of the model (%s) should not be smaller than that found in data (%s)"
        % (self.ntypes, data_ntypes)
    )
    self.stop_batch = stop_batch
    self.batch_size = data.get_batch_size()

    if self.numb_fparam > 0:
        self._message("training with %d frame parameter(s)" % self.numb_fparam)
    else:
        self._message("training without frame parameter")

    self.type_map = data.get_type_map()

    # Statistics are collected before any graph construction starts.
    self.model.data_stat(data)

    worker_device = "/job:%s/task:%d/%s" % (self.run_opt.my_job_name,
                                            self.run_opt.my_task_index,
                                            self.run_opt.my_device)

    # Everything built below is placed by the replica device setter.
    with tf.device(tf.train.replica_device_setter(
            worker_device=worker_device,
            cluster=self.run_opt.cluster_spec)):
        self._build_lr()
        self._build_network(data)
        self._build_training()
def build(self, data, stop_batch=0):
    """Record training settings taken from ``data`` and build the graph.

    This variant requires the model and the data to report exactly the same
    number of types. After the check it stores ``stop_batch``,
    ``batch_size`` and ``type_map`` on the instance, calls
    ``self.model.data_stat(data)``, and finally runs ``_build_lr``,
    ``_build_network`` and ``_build_training`` inside a ``tf.device`` scope
    driven by ``tf.train.replica_device_setter``.

    Parameters
    ----------
    data
        Data object; must provide ``get_ntypes()``, ``get_batch_size()``
        and ``get_type_map()``.
    stop_batch : int, optional
        Stored as ``self.stop_batch``. Defaults to 0.

    Raises
    ------
    AssertionError
        If the type counts of the model and the data differ.
    """
    self.ntypes = self.model.get_ntypes()
    assert self.ntypes == data.get_ntypes(), \
        "ntypes should match that found in data"

    self.stop_batch = stop_batch
    self.batch_size = data.get_batch_size()

    # Pick the frame-parameter notice first, then report it once.
    notice = (
        "training with %d frame parameter(s)" % self.numb_fparam
        if self.numb_fparam > 0
        else "training without frame parameter"
    )
    self._message(notice)

    self.type_map = data.get_type_map()
    self.model.data_stat(data)

    opts = self.run_opt
    device_name = "/job:%s/task:%d/%s" % (
        opts.my_job_name,
        opts.my_task_index,
        opts.my_device,
    )
    device_setter = tf.train.replica_device_setter(
        worker_device=device_name,
        cluster=opts.cluster_spec,
    )

    # All graph pieces are created under the same device scope.
    with tf.device(device_setter):
        self._build_lr()
        self._build_network(data)
        self._build_training()
def connect_done_queue(cluster_spec, task_index):
    """Build one enqueue op per "ps" task that pushes ``task_index``.

    For each "ps" task ``i`` a ``tf.FIFOQueue`` with shared name
    ``done_queue<i>`` is declared on ``/job:ps/task:<i>``, with capacity
    equal to the number of "worker" tasks and element type ``tf.int32``.

    Parameters
    ----------
    cluster_spec
        Object providing ``num_tasks(job_name)``.
    task_index
        Value enqueued by every returned op.

    Returns
    -------
    list
        The enqueue ops, ordered by "ps" task index; empty when there are
        no "ps" tasks.
    """
    def _enqueue_op_for(ps_index):
        # The queue is declared on the ps task it belongs to.
        with tf.device("/job:ps/task:%d" % ps_index):
            done_queue = tf.FIFOQueue(
                cluster_spec.num_tasks('worker'),
                tf.int32,
                shared_name='done_queue' + str(ps_index),
            )
            return done_queue.enqueue(task_index)

    ps_count = cluster_spec.num_tasks("ps")
    return [_enqueue_op_for(ps_index) for ps_index in range(ps_count)]
def create_done_queue(cluster_spec, task_index):
    """Declare the shared "done" queue that lives on one "ps" task.

    The queue is a ``tf.FIFOQueue`` of ``tf.int32`` whose capacity is the
    number of "worker" tasks. It is created under the device scope
    ``/job:ps/task:<task_index>`` with shared name
    ``done_queue<task_index>``.

    Parameters
    ----------
    cluster_spec
        Object providing ``num_tasks(job_name)``.
    task_index
        Index of the "ps" task; used for the device string and the shared
        queue name.

    Returns
    -------
    The created ``tf.FIFOQueue``.
    """
    ps_device = "/job:ps/task:%d" % (task_index)
    with tf.device(ps_device):
        capacity = cluster_spec.num_tasks("worker")
        queue_name = "done_queue" + str(task_index)
        return tf.FIFOQueue(capacity, tf.int32, shared_name=queue_name)