Example #1
 def test_xshards_symbol(self):
     # prepare data
     resource_path = os.path.join(
         os.path.split(__file__)[0], "../../../resources")
     self.ray_ctx = get_ray_ctx()
     train_file_path = os.path.join(resource_path,
                                    "orca/learn/train_data.json")
     train_data_shard = zoo.orca.data.pandas.read_json(train_file_path,
                                                       self.ray_ctx,
                                                       orient='records',
                                                       lines=False)
     train_data_shard.transform_shard(prepare_data_symbol)
     test_file_path = os.path.join(resource_path,
                                   "orca/learn/test_data.json")
     test_data_shard = zoo.orca.data.pandas.read_json(test_file_path,
                                                      self.ray_ctx,
                                                      orient='records',
                                                      lines=False)
     test_data_shard.transform_shard(prepare_data_symbol)
     config = create_trainer_config(batch_size=32, log_interval=1, seed=42)
     trainer = MXNetTrainer(config,
                            train_data_shard,
                            get_symbol_model,
                            validation_metrics_creator=get_metrics,
                            test_data=test_data_shard,
                            eval_metrics_creator=get_metrics)
     trainer.train(nb_epoch=2)
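The shard transform prepare_data_symbol is not shown on this page. As a rough illustration only (the "label" column name and the {"x", "y"} dict layout are assumptions, not taken from the project), a function passed to transform_shard receives each shard's pandas DataFrame and returns the converted data:

import numpy as np

# Hypothetical version of prepare_data_symbol: convert one shard's pandas
# DataFrame into numpy arrays. The column name and output layout are
# illustrative assumptions.
def prepare_data_symbol(df):
    features = df.drop(columns=["label"]).to_numpy(dtype="float32")
    labels = df["label"].to_numpy(dtype="float32")
    return {"x": features, "y": labels}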
Example #2
 def test_gluon(self):
     config = create_trainer_config(batch_size=32, log_interval=2, optimizer="adam",
                                    optimizer_params={'learning_rate': 0.02})
     trainer = MXNetTrainer(config, get_train_data_iter, get_model, get_loss,
                            eval_metrics_creator=get_metrics,
                            validation_metrics_creator=get_metrics,
                            num_workers=2, test_data=get_test_data_iter)
     trainer.train(nb_epoch=2)
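The creator functions used above (get_model, get_loss, get_metrics, get_train_data_iter, get_test_data_iter) are defined elsewhere in the test module. Below is a minimal sketch of what such creators could look like with plain MXNet/Gluon; the exact signatures (each creator receiving the config dict, data creators also receiving a kvstore handle) and the random toy data are assumptions for illustration, not the project's actual helpers.

import numpy as np
import mxnet as mx
from mxnet import gluon

def get_model(config):
    # A tiny Gluon network; any Block would do for the sketch.
    net = gluon.nn.Sequential()
    net.add(gluon.nn.Dense(64, activation="relu"))
    net.add(gluon.nn.Dense(10))
    return net

def get_loss(config):
    return gluon.loss.SoftmaxCrossEntropyLoss()

def get_metrics(config):
    return mx.metric.Accuracy()

def get_train_data_iter(config, kv):
    # Random toy data; a real creator would load the actual dataset.
    x = np.random.rand(256, 32).astype("float32")
    y = np.random.randint(0, 10, size=(256,)).astype("float32")
    return mx.io.NDArrayIter(x, y, batch_size=config["batch_size"], shuffle=True)

def get_test_data_iter(config, kv):
    x = np.random.rand(64, 32).astype("float32")
    y = np.random.randint(0, 10, size=(64,)).astype("float32")
    return mx.io.NDArrayIter(x, y, batch_size=config["batch_size"])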
Example #3
 def test_xshards_symbol_without_val(self):
     resource_path = os.path.join(
         os.path.split(__file__)[0], "../../../resources")
     train_file_path = os.path.join(resource_path,
                                    "orca/learn/single_input_json/train")
     train_data_shard = zoo.orca.data.pandas.read_json(
         train_file_path, get_spark_ctx(), orient='records',
         lines=False).transform_shard(prepare_data_symbol)
     config = create_trainer_config(batch_size=32, log_interval=1, seed=42)
     trainer = MXNetTrainer(config,
                            train_data_shard,
                            get_symbol_model,
                            eval_metrics_creator=get_metrics,
                            num_workers=2)
     trainer.train(nb_epoch=2)
Example #4
 def test_gluon(self):
     current_ray_ctx = RayContext.get()
     address_info = current_ray_ctx.address_info
     assert "object_store_address" in address_info
     config = create_trainer_config(
         batch_size=32,
         log_interval=2,
         optimizer="adam",
         optimizer_params={'learning_rate': 0.02})
     trainer = MXNetTrainer(config,
                            get_train_data_iter,
                            get_model,
                            get_loss,
                            eval_metrics_creator=get_metrics,
                            validation_metrics_creator=get_metrics,
                            num_workers=2,
                            test_data=get_test_data_iter)
     trainer.train(nb_epoch=2)
Example #5
 def test_symbol(self):
     config = create_trainer_config(batch_size=32, log_interval=2, seed=42)
     trainer = MXNetTrainer(config, get_train_data_iter, get_model,
                            validation_metrics_creator=get_metrics,
                            test_data=get_test_data_iter, eval_metrics_creator=get_metrics)
     trainer.train(nb_epoch=2)
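create_trainer_config is likewise defined outside these snippets. Judging only from how it is called in the examples (a batch size plus arbitrary keyword options), a stand-in could simply pack everything into a dict; this is an assumption for readability, not the helper's actual implementation:

# Hypothetical stand-in for create_trainer_config, matching the call sites
# above: a positional batch size plus keyword hyperparameters, returned as
# a plain config dict.
def create_trainer_config(batch_size, **kwargs):
    config = {"batch_size": batch_size}
    config.update(kwargs)
    return config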
Example #6
    opt = parser.parse_args()

    # Run on YARN when a Hadoop conf directory is given; otherwise run Spark locally.
    if opt.hadoop_conf:
        assert opt.conda_name is not None, "conda_name must be specified for yarn mode"
        sc = init_spark_on_yarn(hadoop_conf=opt.hadoop_conf,
                                conda_name=opt.conda_name,
                                num_executor=opt.num_workers,
                                executor_cores=opt.executor_cores)
    else:
        sc = init_spark_on_local(cores="*")

    # Launch Ray on top of the Spark cluster before creating the trainer.
    ray_ctx = RayContext(sc=sc)
    ray_ctx.init()

    config = create_trainer_config(
        opt.batch_size,
        optimizer="sgd",
        optimizer_params={'learning_rate': opt.learning_rate},
        log_interval=opt.log_interval,
        seed=42)
    trainer = MXNetTrainer(config,
                           train_data=get_train_data_iter,
                           model_creator=get_model,
                           loss_creator=get_loss,
                           validation_metrics_creator=get_metrics,
                           num_workers=opt.num_workers,
                           num_servers=opt.num_servers,
                           test_data=get_test_data_iter,
                           eval_metrics_creator=get_metrics)
    trainer.train(nb_epoch=opt.epochs)
    ray_ctx.stop()
    sc.stop()