示例#1
0
 def train(self,
           data_rdd,
           model_rdd,
           batch_size,
           epochs,
           model_dir,
           go_on=False):
     """Run a cluster training job fed from ``data_rdd`` and return its results.

     The number of steps per epoch is the row count of ``data_rdd`` divided
     by ``batch_size`` and by ``self.num_workers``, rounded up.

     Args:
         data_rdd: Input data; must provide ``count()`` and an ``rdd``
             attribute, which is what gets fed to the cluster.
         model_rdd: Passed through to ``TFTrainWorker`` as its first argument.
         batch_size: Batch size used for the step computation and handed to
             the worker.
         epochs: Number of epochs, given both to the worker and to
             ``cluster.train``.
         model_dir: Location wrapped by ``ModelDir`` with the ``'train*'``
             pattern.
         go_on: When true the model directory is set up with
             ``create_model_dir()``; otherwise it is replaced by the object
             returned from ``rebuild_model_dir()``.

     Returns:
         The DataFrame created by ``self.sqlc`` from ``ModelDir.read_result()``.

     Raises:
         AssertionError: If the computed steps per epoch is not positive.
     """
     sample_total = data_rdd.count()
     batches_per_worker = sample_total / batch_size / self.num_workers
     steps_per_epoch = math.ceil(batches_per_worker)
     assert steps_per_epoch > 0

     model_store = ModelDir(model_dir, 'train*')
     if not go_on:
         model_store = model_store.rebuild_model_dir()
     else:
         model_store.create_model_dir()

     train_worker = TFTrainWorker(
         model_rdd,
         go_on=go_on,
         batch_size=batch_size,
         epochs=epochs,
         steps_per_epoch=steps_per_epoch,
         **model_store.to_dict())
     tf_cluster = TFCluster.run(
         self.sc,
         train_worker,
         self.tf_args,
         self.cluster_size,
         self.num_ps,
         input_mode=self.input_mode)
     tf_cluster.train(data_rdd.rdd, num_epochs=epochs, feed_timeout=60000)
     tf_cluster.shutdown()

     training_rows = model_store.read_result()
     return self.sqlc.createDataFrame(training_rows)
示例#2
0
文件: tfos.py 项目: linxigal/tfos
 def evaluate(self, data_rdd, steps, model_dir):
     """Run an evaluation job on the cluster and return its results.

     Args:
         data_rdd: Input data; must provide ``count()`` and an ``rdd``
             attribute, which is what gets fed to the cluster.
         steps: Total number of steps. When it is ``<= 0`` the row count of
             ``data_rdd`` is used instead. The total is then divided by
             ``self.num_workers`` and rounded up.
         model_dir: Location wrapped by ``ModelDir`` with the ``'evaluate*'``
             pattern.

     Returns:
         The DataFrame created by ``self.sqlc`` from ``ModelDir.read_result()``.
     """
     model_store = ModelDir(model_dir, 'evaluate*')

     if steps <= 0:
         total_steps = data_rdd.count()
     else:
         total_steps = steps
     per_worker_steps = math.ceil(total_steps / self.num_workers)

     eval_worker = EvaluateWorker(steps_per_epoch=per_worker_steps,
                                  **model_store.to_dict())
     # The previous result file is removed before the cluster is launched.
     model_store.delete_result_file()

     tf_cluster = TFCluster.run(
         self.sc,
         eval_worker,
         self.tf_args,
         self.cluster_size,
         self.num_ps,
         input_mode=self.input_mode)
     tf_cluster.train(data_rdd.rdd, num_epochs=1)
     tf_cluster.shutdown()

     eval_rows = model_store.read_result()
     return self.sqlc.createDataFrame(eval_rows)
示例#3
0
文件: tfos.py 项目: linxigal/tfos
 def recurrent_predict(self, data_rdd, units, steps, feature_type, model_dir):
     """Run a recurrent prediction job on the cluster and return its results.

     Args:
         data_rdd: Input data; its ``rdd`` attribute is fed to the cluster.
         units: Forwarded to ``RecurrentPredictWorker``.
         steps: Forwarded to ``RecurrentPredictWorker``.
         feature_type: Forwarded to ``RecurrentPredictWorker``.
         model_dir: Location wrapped by ``ModelDir`` with the
             ``'recurrent_predict*'`` pattern.

     Returns:
         A DataFrame with one row per item of ``ModelDir.read_result(True)``,
         each row holding that item under the ``"result"`` key.
     """
     model_store = ModelDir(model_dir, 'recurrent_predict*')
     predict_worker = RecurrentPredictWorker(
         units=units,
         steps=steps,
         feature_type=feature_type,
         **model_store.to_dict())
     # The previous result file is removed before the cluster is launched.
     model_store.delete_result_file()

     tf_cluster = TFCluster.run(
         self.sc,
         predict_worker,
         self.tf_args,
         self.cluster_size,
         self.num_ps,
         input_mode=self.input_mode)
     tf_cluster.train(data_rdd.rdd, num_epochs=1, feed_timeout=6000)
     tf_cluster.shutdown()

     raw_results = model_store.read_result(True)
     rows = [{"result": item} for item in raw_results]
     return self.sqlc.createDataFrame(rows)
示例#4
0
文件: tfos.py 项目: linxigal/tfos
 def predict(self, data_rdd, steps, model_dir, output_prob=False):
     """Run a prediction job on the cluster and return its results.

     Args:
         data_rdd: Input data; must provide ``count()`` and an ``rdd``
             attribute, which is what gets fed to the cluster.
         steps: Total number of steps. When it is ``<= 0`` the row count of
             ``data_rdd`` is used instead. The total is then divided by
             ``self.num_workers`` and rounded up.
         model_dir: Location wrapped by ``ModelDir`` with the ``'predict*'``
             pattern.
         output_prob: Forwarded unchanged to ``PredictWorker``.

     Returns:
         The DataFrame created by ``self.sqlc`` from ``ModelDir.read_result()``.
     """
     model_store = ModelDir(model_dir, 'predict*')

     if steps <= 0:
         total_steps = data_rdd.count()
     else:
         total_steps = steps
     per_worker_steps = math.ceil(total_steps / self.num_workers)

     predict_worker = PredictWorker(
         steps_per_epoch=per_worker_steps,
         output_prob=output_prob,
         **model_store.to_dict())
     # The previous result file is removed before the cluster is launched.
     model_store.delete_result_file()

     tf_cluster = TFCluster.run(
         self.sc,
         predict_worker,
         self.tf_args,
         self.cluster_size,
         self.num_ps,
         input_mode=self.input_mode)
     tf_cluster.train(data_rdd.rdd, num_epochs=1, feed_timeout=6000)
     tf_cluster.shutdown()

     prediction_rows = model_store.read_result()
     return self.sqlc.createDataFrame(prediction_rows)
示例#5
0
文件: tfos.py 项目: linxigal/tfos
 def yolov3_tiny_train(self,
                       model_rdd,
                       batch_size,
                       epochs,
                       classes_path,
                       anchors_path,
                       train_path,
                       val_path,
                       image_size,
                       model_dir,
                       weights_path=None,
                       freeze_body=2,
                       go_on=False):
     """Run a YOLOv3-tiny training job on the cluster and return its results.

     The training data is read as a text file from ``train_path``; the number
     of steps per epoch is its line count divided by ``batch_size`` and by
     ``self.num_workers``, rounded up.

     Args:
         model_rdd: Must expose a ``columns`` collection containing
             ``"model_config"``; passed to the worker as its first argument.
         batch_size: Batch size used for the step computation and handed to
             the worker.
         epochs: Number of epochs, given both to the worker and to
             ``cluster.train``.
         classes_path: Forwarded to the worker.
         anchors_path: Forwarded to the worker.
         train_path: Path of the training text file; must exist according to
             ``tf.io.gfile.exists``.
         val_path: Forwarded to the worker.
         image_size: Forwarded to the worker.
         model_dir: Location wrapped by ``ModelDir`` with the ``'train*'``
             pattern.
         weights_path: Forwarded to the worker.
         freeze_body: Forwarded to the worker.
         go_on: When true the model directory is set up with
             ``create_model_dir()``; otherwise it is replaced by the object
             returned from ``rebuild_model_dir()``.

     Returns:
         The DataFrame created by ``self.sqlc`` from ``ModelDir.read_result()``.

     Raises:
         AssertionError: If ``"model_config"`` is missing from
             ``model_rdd.columns`` or ``train_path`` does not exist.
     """
     assert "model_config" in model_rdd.columns, "not exists model layer config!"
     assert tf.io.gfile.exists(train_path), "train dataset path not exists!"

     data_rdd = self.sc.textFile(train_path)
     sample_total = data_rdd.count()
     batches_per_worker = sample_total / batch_size / self.num_workers
     steps_per_epoch = math.ceil(batches_per_worker)

     model_store = ModelDir(model_dir, 'train*')
     if not go_on:
         model_store = model_store.rebuild_model_dir()
     else:
         model_store.create_model_dir()

     train_worker = YOLOV3TinyModelTrainWorker(
         model_rdd,
         go_on=go_on,
         batch_size=batch_size,
         epochs=epochs,
         classes_path=classes_path,
         anchors_path=anchors_path,
         weights_path=weights_path,
         val_path=val_path,
         image_size=image_size,
         steps_per_epoch=steps_per_epoch,
         freeze_body=freeze_body,
         **model_store.to_dict())
     tf_cluster = TFCluster.run(
         self.sc,
         train_worker,
         self.tf_args,
         self.cluster_size,
         self.num_ps,
         input_mode=self.input_mode)
     tf_cluster.train(data_rdd, num_epochs=epochs, feed_timeout=60000)
     tf_cluster.shutdown()

     training_rows = model_store.read_result()
     return self.sqlc.createDataFrame(training_rows)
示例#6
0
文件: tfos.py 项目: linxigal/tfos
 def yolov3_train(self,
                  model_rdd,
                  data_dir,
                  batch_size,
                  epochs,
                  image_size,
                  model_dir,
                  weights_path=None,
                  freeze_body=2,
                  go_on=False):
     """Run a YOLOv3 training job on the cluster and return its results.

     The training data is read as a text file from ``train.txt`` inside
     ``data_dir``; the number of steps per epoch is its line count divided by
     ``batch_size`` and by ``self.num_workers``, rounded up.

     Args:
         model_rdd: Passed to the worker as its first argument.
         data_dir: Directory containing ``train.txt``; also passed to the
             worker as its second argument.
         batch_size: Batch size used for the step computation and handed to
             the worker.
         epochs: Number of epochs, given both to the worker and to
             ``cluster.train``.
         image_size: Forwarded to the worker.
         model_dir: Location wrapped by ``ModelDir`` with the ``'train*'``
             pattern.
         weights_path: Accepted but not used inside this method.
         freeze_body: Forwarded to the worker.
         go_on: When true the model directory is set up with
             ``create_model_dir()``; otherwise it is replaced by the object
             returned from ``rebuild_model_dir()``.

     Returns:
         The DataFrame created by ``self.sqlc`` from
         ``ModelDir.read_result()``, or ``None`` when that result is falsy.

     Raises:
         AssertionError: If the training file does not exist.
     """
     train_path = os.path.join(data_dir, 'train.txt')
     assert tf.io.gfile.exists(train_path), "train dataset path not exists!"

     data_rdd = self.sc.textFile(train_path)
     sample_total = data_rdd.count()
     batches_per_worker = sample_total / batch_size / self.num_workers
     steps_per_epoch = math.ceil(batches_per_worker)

     model_store = ModelDir(model_dir, 'train*')
     if not go_on:
         model_store = model_store.rebuild_model_dir()
     else:
         model_store.create_model_dir()

     train_worker = YOLOV3ModelTrainWorker(
         model_rdd,
         data_dir,
         go_on=go_on,
         batch_size=batch_size,
         epochs=epochs,
         image_size=image_size,
         steps_per_epoch=steps_per_epoch,
         freeze_body=freeze_body,
         **model_store.to_dict())
     tf_cluster = TFCluster.run(
         self.sc,
         train_worker,
         self.tf_args,
         self.cluster_size,
         self.num_ps,
         input_mode=self.input_mode)
     tf_cluster.train(data_rdd, num_epochs=epochs, feed_timeout=60000)
     tf_cluster.shutdown()

     training_rows = model_store.read_result()
     if not training_rows:
         # No results were read back; the caller receives None.
         return None
     return self.sqlc.createDataFrame(training_rows)