예제 #1
0
 def test_xshards_symbol(self):
     """Fit a symbolic MXNet model on JSON-backed XShards (legacy API).

     Reads the ``orca/learn/single_input_json`` train and test resources
     through ``zoo.orca.data.pandas.read_json`` using the Ray context from
     ``get_ray_ctx()``. It applies ``prepare_data_symbol`` to each shard set,
     then calls ``Estimator.fit`` for two epochs with two workers. The
     estimator is shut down in a ``finally`` clause, so ``shutdown()`` still
     runs when ``fit`` raises.
     """
     resource_path = os.path.join(
         os.path.dirname(__file__), "../../../resources")
     self.ray_ctx = get_ray_ctx()

     def _read_shard(subdir):
         # Read one JSON resource directory and apply prepare_data_symbol.
         shard = zoo.orca.data.pandas.read_json(
             os.path.join(resource_path, subdir),
             self.ray_ctx,
             orient='records',
             lines=False)
         # The return value of transform_shard is not used, as in the
         # original test; presumably this API version transforms in place.
         # TODO confirm.
         shard.transform_shard(prepare_data_symbol)
         return shard

     # prepare data
     train_data_shard = _read_shard("orca/learn/single_input_json/train")
     test_data_shard = _read_shard("orca/learn/single_input_json/test")

     config = create_config(batch_size=32, log_interval=1, seed=42)
     estimator = Estimator(config,
                           get_symbol_model,
                           validation_metrics_creator=get_metrics,
                           eval_metrics_creator=get_metrics,
                           num_workers=2)
     try:
         estimator.fit(train_data_shard, test_data_shard, nb_epoch=2)
     finally:
         # Always release the estimator, even when fit() fails.
         estimator.shutdown()
 def test_xshards_symbol_with_val(self):
     """Fit a symbolic MXNet model on XShards without, then with, validation.

     The first ``fit`` trains on the prepared train shards for two epochs
     with no validation data. The second ``fit`` trains for one more epoch,
     with ``batch_size=32``, on a freshly read copy of the train shards and
     passes the test shards as ``validation_data``. The estimator is shut
     down in a ``finally`` clause, so ``shutdown()`` still runs when either
     ``fit`` raises.
     """
     resource_path = os.path.join(
         os.path.dirname(__file__), "../../../../resources")
     train_file_path = os.path.join(resource_path,
                                    "orca/learn/single_input_json/train")
     train_data_shard = zoo.orca.data.pandas.read_json(
         train_file_path, orient='records',
         lines=False).transform_shard(prepare_data_symbol)
     test_file_path = os.path.join(resource_path,
                                   "orca/learn/single_input_json/test")
     test_data_shard = zoo.orca.data.pandas.read_json(
         test_file_path, orient='records',
         lines=False).transform_shard(prepare_data_symbol)

     config = create_config(log_interval=1, seed=42)
     estimator = Estimator.from_mxnet(
         config=config,
         model_creator=get_symbol_model,
         validation_metrics_creator=get_metrics,
         eval_metrics_creator=get_metrics,
         num_workers=2)
     try:
         estimator.fit(train_data_shard, epochs=2)
         # The second fit reads the training JSON again instead of reusing
         # train_data_shard; the reason is not visible here.
         # TODO confirm whether fit() consumes its input.
         train_data_shard2 = zoo.orca.data.pandas.read_json(
             train_file_path, orient='records',
             lines=False).transform_shard(prepare_data_symbol)
         estimator.fit(train_data_shard2,
                       validation_data=test_data_shard,
                       epochs=1,
                       batch_size=32)
     finally:
         # Always release the estimator, even when a fit() call fails.
         estimator.shutdown()
예제 #3
0
 def test_symbol(self):
     """Fit ``get_model`` through the legacy ``Estimator`` constructor.

     Passes the ``get_train_data_iter`` and ``get_test_data_iter`` callables
     (not their results) to ``fit`` and trains for two epochs. The estimator
     is shut down in a ``finally`` clause, so ``shutdown()`` still runs when
     ``fit`` raises.
     """
     config = create_config(batch_size=32, log_interval=2, seed=42)
     estimator = Estimator(config,
                           get_model,
                           validation_metrics_creator=get_metrics,
                           eval_metrics_creator=get_metrics)
     try:
         estimator.fit(get_train_data_iter, get_test_data_iter, nb_epoch=2)
     finally:
         # Always release the estimator, even when fit() fails.
         estimator.shutdown()
예제 #4
0
 def test_symbol(self):
     """Fit ``get_model`` through ``Estimator.from_mxnet``.

     Passes the ``get_train_data_iter`` callable as training data and
     ``get_test_data_iter`` as ``validation_data``, training for two epochs
     with ``batch_size=16``. The estimator is shut down in a ``finally``
     clause, so ``shutdown()`` still runs when ``fit`` raises.
     """
     config = create_config(log_interval=2, seed=42)
     estimator = Estimator.from_mxnet(
         config=config,
         model_creator=get_model,
         validation_metrics_creator=get_metrics,
         eval_metrics_creator=get_metrics)
     try:
         estimator.fit(get_train_data_iter,
                       validation_data=get_test_data_iter,
                       epochs=2,
                       batch_size=16)
     finally:
         # Always release the estimator, even when fit() fails.
         estimator.shutdown()
예제 #5
0
 def test_gluon(self):
     """Fit a Gluon model with the Adam optimizer on two workers.

     First asserts that the active ``RayContext`` reports an
     ``object_store_address`` in its ``address_info``. It then trains for two
     epochs with ``get_test_data_iter`` as validation data. The estimator is
     shut down in a ``finally`` clause, so ``shutdown()`` still runs when
     ``fit`` raises.
     """
     current_ray_ctx = RayContext.get()
     address_info = current_ray_ctx.address_info
     assert "object_store_address" in address_info

     config = create_config(log_interval=2, optimizer="adam",
                            optimizer_params={'learning_rate': 0.02})
     estimator = Estimator(config, get_model, get_loss,
                           eval_metrics_creator=get_metrics,
                           validation_metrics_creator=get_metrics,
                           num_workers=2)
     try:
         estimator.fit(get_train_data_iter,
                       validation_data=get_test_data_iter,
                       epochs=2)
     finally:
         # Always release the estimator, even when fit() fails.
         estimator.shutdown()
 def test_gluon_multiple_input(self):
     """Fit a Gluon model with the Adagrad optimizer on four workers.

     Uses seed 1128 and a learning rate of 0.02, and trains for two epochs
     with ``get_test_data_iter`` as validation data. The estimator is shut
     down in a ``finally`` clause, so ``shutdown()`` still runs when ``fit``
     raises.
     """
     config = create_config(log_interval=2,
                            optimizer="adagrad",
                            seed=1128,
                            optimizer_params={'learning_rate': 0.02})
     estimator = Estimator(config,
                           get_model,
                           get_loss,
                           eval_metrics_creator=get_metrics,
                           validation_metrics_creator=get_metrics,
                           num_workers=4)
     try:
         estimator.fit(get_train_data_iter,
                       validation_data=get_test_data_iter,
                       epochs=2)
     finally:
         # Always release the estimator, even when fit() fails.
         estimator.shutdown()
예제 #7
0
    # Batches between throughput/metrics log lines; forwarded to
    # create_config(log_interval=...) below.
    parser.add_argument(
        '--log_interval',
        type=int,
        default=20,
        help='The number of batches to wait before logging throughput and '
        'metrics information during the training process.')
    opt = parser.parse_args()

    # Local mode runs on a single node; other cluster modes request one
    # node per worker.
    num_nodes = 1 if opt.cluster_mode == "local" else opt.num_workers
    init_orca_context(cluster_mode=opt.cluster_mode,
                      cores=opt.cores,
                      num_nodes=num_nodes)

    # SGD with the CLI learning rate and a fixed seed of 42.
    config = create_config(
        optimizer="sgd",
        optimizer_params={'learning_rate': opt.learning_rate},
        log_interval=opt.log_interval,
        seed=42)
    # Worker/server counts come from the CLI; the same metrics creator is
    # used for validation and evaluation.
    estimator = Estimator.from_mxnet(config=config,
                                     model_creator=get_model,
                                     loss_creator=get_loss,
                                     validation_metrics_creator=get_metrics,
                                     num_workers=opt.num_workers,
                                     num_servers=opt.num_servers,
                                     eval_metrics_creator=get_metrics)
    # The data-iterator functions are passed uncalled; presumably the
    # estimator invokes them on each worker. TODO confirm.
    estimator.fit(data=get_train_data_iter,
                  validation_data=get_test_data_iter,
                  epochs=opt.epochs,
                  batch_size=opt.batch_size)
    # NOTE(review): if fit() raises, neither shutdown() nor
    # stop_orca_context() runs; consider wrapping fit() in try/finally.
    estimator.shutdown()
    stop_orca_context()