def test_xshards_symbol(self):
    """Train a symbolic MXNet model on XShards loaded from JSON resources,
    with a separate test shard used for validation."""
    # Locate the shared test-resources directory relative to this file.
    base_dir = os.path.join(
        os.path.split(__file__)[0], "../../../resources")
    self.ray_ctx = get_ray_ctx()

    # Load the training data into pandas-backed shards.
    train_path = os.path.join(base_dir, "orca/learn/train_data.json")
    train_shards = zoo.orca.data.pandas.read_json(
        train_path, self.ray_ctx, orient='records', lines=False)
    # NOTE(review): the return value of transform_shard is discarded here.
    # If transform_shard is not in-place, the transformation is lost —
    # confirm against the XShards API (a sibling test assigns the result).
    train_shards.transform_shard(prepare_data_symbol)

    # Load the evaluation data the same way.
    test_path = os.path.join(base_dir, "orca/learn/test_data.json")
    test_shards = zoo.orca.data.pandas.read_json(
        test_path, self.ray_ctx, orient='records', lines=False)
    test_shards.transform_shard(prepare_data_symbol)

    cfg = create_trainer_config(batch_size=32, log_interval=1, seed=42)
    trainer = MXNetTrainer(cfg, train_shards, get_symbol_model,
                           validation_metrics_creator=get_metrics,
                           test_data=test_shards,
                           eval_metrics_creator=get_metrics)
    trainer.train(nb_epoch=2)
def test_gluon(self):
    """Train a Gluon model with the Adam optimizer across two workers,
    evaluating against the iterator-based test data."""
    cfg = create_trainer_config(batch_size=32,
                                log_interval=2,
                                optimizer="adam",
                                optimizer_params={'learning_rate': 0.02})
    trainer = MXNetTrainer(cfg,
                           get_train_data_iter,
                           get_model,
                           get_loss,
                           eval_metrics_creator=get_metrics,
                           validation_metrics_creator=get_metrics,
                           num_workers=2,
                           test_data=get_test_data_iter)
    trainer.train(nb_epoch=2)
def test_xshards_symbol_without_val(self):
    """Train a symbolic model from XShards with no validation data supplied."""
    # Resolve the test-resources directory relative to this file.
    base_dir = os.path.join(
        os.path.split(__file__)[0], "../../../resources")
    train_path = os.path.join(base_dir, "orca/learn/single_input_json/train")

    # Read the JSON records and apply the symbol-preparation transform;
    # here (unlike the full test) the transformed shard is reassigned.
    shards = zoo.orca.data.pandas.read_json(
        train_path, get_spark_ctx(), orient='records', lines=False)
    shards = shards.transform_shard(prepare_data_symbol)

    cfg = create_trainer_config(batch_size=32, log_interval=1, seed=42)
    trainer = MXNetTrainer(cfg, shards, get_symbol_model,
                           eval_metrics_creator=get_metrics,
                           num_workers=2)
    trainer.train(nb_epoch=2)
def test_gluon(self):
    """Train a Gluon model after verifying the active RayContext exposes
    an object store address.

    NOTE(review): a method with the same name ``test_gluon`` also appears
    earlier in this source. If both definitions live in the same test
    class, the later one shadows the earlier and only one will ever run —
    confirm whether these come from different files or need renaming.
    """
    # Reuse the already-initialized, process-wide RayContext.
    current_ray_ctx = RayContext.get()
    address_info = current_ray_ctx.address_info
    # Sanity-check that Ray's object store was actually started.
    assert "object_store_address" in address_info
    config = create_trainer_config(
        batch_size=32, log_interval=2, optimizer="adam",
        optimizer_params={'learning_rate': 0.02})
    trainer = MXNetTrainer(config, get_train_data_iter, get_model, get_loss,
                           eval_metrics_creator=get_metrics,
                           validation_metrics_creator=get_metrics,
                           num_workers=2, test_data=get_test_data_iter)
    trainer.train(nb_epoch=2)
def test_symbol(self):
    """Train a symbolic model fed by iterator creators, with validation
    metrics computed on the test iterator."""
    cfg = create_trainer_config(batch_size=32, log_interval=2, seed=42)
    trainer = MXNetTrainer(cfg,
                           get_train_data_iter,
                           get_model,
                           validation_metrics_creator=get_metrics,
                           test_data=get_test_data_iter,
                           eval_metrics_creator=get_metrics)
    trainer.train(nb_epoch=2)
# Entry-point sequence: bring up Spark (YARN or local), start Ray on top
# of it, run distributed MXNet training, then tear everything down.
# Statement order matters here — contexts must be created and stopped in
# this exact sequence — so only comments are added.
if opt.hadoop_conf:
    # YARN mode requires a conda environment name for dependency packaging.
    assert opt.conda_name is not None, "conda_name must be specified for yarn mode"
    sc = init_spark_on_yarn(hadoop_conf=opt.hadoop_conf,
                            conda_name=opt.conda_name,
                            num_executor=opt.num_workers,
                            executor_cores=opt.executor_cores)
else:
    # Local mode: use all available cores.
    sc = init_spark_on_local(cores="*")

# Launch Ray on the Spark cluster before constructing the trainer.
ray_ctx = RayContext(sc=sc)
ray_ctx.init()

config = create_trainer_config(
    opt.batch_size,
    optimizer="sgd",
    optimizer_params={'learning_rate': opt.learning_rate},
    log_interval=opt.log_interval,
    seed=42)
trainer = MXNetTrainer(config,
                       train_data=get_train_data_iter,
                       model_creator=get_model,
                       loss_creator=get_loss,
                       validation_metrics_creator=get_metrics,
                       num_workers=opt.num_workers,
                       num_servers=opt.num_servers,
                       test_data=get_test_data_iter,
                       eval_metrics_creator=get_metrics)
trainer.train(nb_epoch=opt.epochs)

# Shut down Ray first, then the Spark context.
ray_ctx.stop()
sc.stop()