def test_xshards_symbol_with_val(self):
    """Fit a symbolic MXNet model on XShards built from JSON files.

    Fits once without validation data, then re-reads the training data
    and fits again with validation data and an explicit batch size.
    """
    resource_path = os.path.join(
        os.path.split(__file__)[0], "../../../../resources")
    train_file_path = os.path.join(
        resource_path, "orca/learn/single_input_json/train")
    train_shards = zoo.orca.data.pandas.read_json(
        train_file_path, orient='records',
        lines=False).transform_shard(prepare_data_symbol)
    test_file_path = os.path.join(
        resource_path, "orca/learn/single_input_json/test")
    test_shards = zoo.orca.data.pandas.read_json(
        test_file_path, orient='records',
        lines=False).transform_shard(prepare_data_symbol)

    config = create_config(log_interval=1, seed=42)
    est = Estimator(config, get_symbol_model,
                    validation_metrics_creator=get_metrics,
                    eval_metrics_creator=get_metrics,
                    num_workers=2)
    # First fit: training data only.
    est.fit(train_shards, epochs=2)

    # Read the training data again for a second fit with validation data.
    train_shards2 = zoo.orca.data.pandas.read_json(
        train_file_path, orient='records',
        lines=False).transform_shard(prepare_data_symbol)
    est.fit(train_shards2, validation_data=test_shards,
            epochs=1, batch_size=32)
    est.shutdown()
def test_symbol(self):
    """Train a symbolic model via Estimator.from_mxnet using the
    iterator-based train/validation data creators."""
    config = create_config(log_interval=2, seed=42)
    trainer = Estimator.from_mxnet(
        config=config,
        model_creator=get_model,
        validation_metrics_creator=get_metrics,
        eval_metrics_creator=get_metrics)
    trainer.fit(get_train_data_iter,
                validation_data=get_test_data_iter,
                epochs=2,
                batch_size=16)
    trainer.shutdown()
def test_xshards_gluon(self):
    """Train a Gluon model on XShards built from JSON files.

    Reads train/test JSON into shards, applies ``prepare_data_gluon``,
    and fits an Estimator for two epochs with validation data.
    """
    # prepare data
    resource_path = os.path.join(
        os.path.split(__file__)[0], "../../../resources")
    self.ray_ctx = get_ray_ctx()
    train_file_path = os.path.join(
        resource_path, "orca/learn/single_input_json/train")
    train_data_shard = zoo.orca.data.pandas.read_json(
        train_file_path, self.ray_ctx, orient='records', lines=False)
    # BUG FIX: transform_shard returns the transformed shards (the sibling
    # symbol test chains it and uses the result); the original discarded the
    # return value here, so training would see untransformed data.
    train_data_shard = train_data_shard.transform_shard(prepare_data_gluon)
    test_file_path = os.path.join(
        resource_path, "orca/learn/single_input_json/test")
    test_data_shard = zoo.orca.data.pandas.read_json(
        test_file_path, self.ray_ctx, orient='records', lines=False)
    test_data_shard = test_data_shard.transform_shard(prepare_data_gluon)

    config = create_config(batch_size=32, log_interval=1, seed=42)
    estimator = Estimator(config, get_gluon_model, get_loss,
                          validation_metrics_creator=get_gluon_metrics,
                          eval_metrics_creator=get_gluon_metrics,
                          num_workers=2)
    estimator.fit(train_data_shard, test_data_shard, nb_epoch=2)
    estimator.shutdown()
def test_symbol(self):
    """Train a symbolic model with the Estimator constructor API
    (batch size supplied through the config, ``nb_epoch`` keyword)."""
    cfg = create_config(batch_size=32, log_interval=2, seed=42)
    est = Estimator(cfg, get_model,
                    validation_metrics_creator=get_metrics,
                    eval_metrics_creator=get_metrics)
    est.fit(get_train_data_iter, get_test_data_iter, nb_epoch=2)
    est.shutdown()
def test_gluon(self):
    """Train a Gluon model with two workers; also sanity-checks that the
    active RayContext exposes an object-store address."""
    ctx = RayContext.get()
    info = ctx.address_info
    assert "object_store_address" in info

    cfg = create_config(log_interval=2,
                        optimizer="adam",
                        optimizer_params={'learning_rate': 0.02})
    est = Estimator.from_mxnet(config=cfg,
                               model_creator=get_model,
                               loss_creator=get_loss,
                               eval_metrics_creator=get_metrics,
                               validation_metrics_creator=get_metrics,
                               num_workers=2)
    est.fit(get_train_data_iter,
            validation_data=get_test_data_iter,
            epochs=2)
    est.shutdown()
def test_gluon_multiple_input(self):
    """Train a multi-input Gluon model across four workers with the
    adagrad optimizer and a fixed seed."""
    cfg = create_config(log_interval=2,
                        optimizer="adagrad",
                        seed=1128,
                        optimizer_params={'learning_rate': 0.02})
    est = Estimator.from_mxnet(config=cfg,
                               model_creator=get_model,
                               loss_creator=get_loss,
                               eval_metrics_creator=get_metrics,
                               validation_metrics_creator=get_metrics,
                               num_workers=4)
    est.fit(get_train_data_iter,
            validation_data=get_test_data_iter,
            epochs=2)
    est.shutdown()
# NOTE(review): this chunk begins mid-expression — it is the tail of a
# `parser.add_argument(` call whose opening appears above this view, so the
# code is left byte-identical rather than reconstructed. After the argument
# parsing it starts an Orca context, trains an Estimator built with
# Estimator.from_mxnet, and tears everything down.
'--log_interval', type=int, default=20, help='The number of batches to wait before logging throughput and ' 'metrics information during the training process.') opt = parser.parse_args() num_nodes = 1 if opt.cluster_mode == "local" else opt.num_workers init_orca_context(cluster_mode=opt.cluster_mode, cores=opt.cores, num_nodes=num_nodes) config = create_config( optimizer="sgd", optimizer_params={'learning_rate': opt.learning_rate}, log_interval=opt.log_interval, seed=42) estimator = Estimator.from_mxnet(config=config, model_creator=get_model, loss_creator=get_loss, validation_metrics_creator=get_metrics, num_workers=opt.num_workers, num_servers=opt.num_servers, eval_metrics_creator=get_metrics) estimator.fit(data=get_train_data_iter, validation_data=get_test_data_iter, epochs=opt.epochs, batch_size=opt.batch_size) estimator.shutdown() stop_orca_context()
# Create the Spark context: on YARN when a Hadoop conf dir is given,
# otherwise locally using all available cores.
if opt.hadoop_conf:
    assert opt.conda_name is not None, "conda_name must be specified for yarn mode"
    sc = init_spark_on_yarn(hadoop_conf=opt.hadoop_conf,
                            conda_name=opt.conda_name,
                            num_executors=opt.num_workers,
                            executor_cores=opt.executor_cores)
else:
    sc = init_spark_on_local(cores="*")

# Start Ray on top of the Spark cluster.
ray_ctx = RayContext(sc=sc)
ray_ctx.init()

config = create_config(optimizer="sgd",
                       optimizer_params={'learning_rate': opt.learning_rate},
                       log_interval=opt.log_interval,
                       seed=42)
estimator = Estimator(config,
                      model_creator=get_model,
                      loss_creator=get_loss,
                      validation_metrics_creator=get_metrics,
                      num_workers=opt.num_workers,
                      num_servers=opt.num_servers,
                      eval_metrics_creator=get_metrics)
estimator.fit(data=get_train_data_iter,
              validation_data=get_test_data_iter,
              epochs=opt.epochs,
              batch_size=opt.batch_size)

# Tear down in reverse order of creation: estimator, Ray, then Spark.
estimator.shutdown()
ray_ctx.stop()
sc.stop()