def test_save_for_non_scheduler_host():
    model = Mock()

    training_utils.save(MODEL_DIR, model,
                        current_host=WORKER_HOST,
                        hosts=[SCHEDULER_HOST, WORKER_HOST])

    # A non-scheduler host must not write any model artifacts.
    # (The assertion methods need to be called; bare attribute access is a no-op.)
    model.symbol.save.assert_not_called()
    model.save_params.assert_not_called()
@patch('json.dump')
def test_save_single_machine(json_dump):
    model = Mock()
    model.data_shapes = []

    # json.dump is patched so no shapes file is actually written to disk.
    with patch('six.moves.builtins.open', mock_open()):
        training_utils.save(MODEL_DIR, model)

    model.symbol.save.assert_called_with(
        os.path.join(MODEL_DIR, 'model-symbol.json'))
    model.save_params.assert_called_with(
        os.path.join(MODEL_DIR, 'model-0000.params'))
    json_dump.assert_called_once()
@patch('json.dump')
def test_save_distributed(json_dump):
    model = Mock()
    model.data_shapes = []

    with patch('six.moves.builtins.open', mock_open()):
        training_utils.save(MODEL_DIR, model,
                            current_host=SCHEDULER_HOST,
                            hosts=[SCHEDULER_HOST, WORKER_HOST])

    # The scheduler host writes artifacts exactly as in single-machine training.
    model.symbol.save.assert_called_with(
        os.path.join(MODEL_DIR, 'model-symbol.json'))
    model.save_params.assert_called_with(
        os.path.join(MODEL_DIR, 'model-0000.params'))
    json_dump.assert_called_once()
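# The tests above assume imports and fixtures along these lines (a sketch, not the
# actual test module header): Mock/patch/mock_open from the mock package, os and
# training_utils from sagemaker_mxnet_container, and MODEL_DIR / SCHEDULER_HOST /
# WORKER_HOST as arbitrary test constants.
#
# They also assume training_utils.save behaves roughly as sketched below: on the
# scheduler host (or in single-machine training) it writes the symbol, the
# parameters, and a JSON file describing the input data shapes; on any other host
# it does nothing. This is a hypothetical reconstruction from the assertions, not
# the library source.
def _save_sketch(model_dir, model, current_host=None, hosts=None):
    hosts = hosts or []
    if current_host is not None and hosts and current_host != hosts[0]:
        return  # non-scheduler hosts skip saving

    model.symbol.save(os.path.join(model_dir, 'model-symbol.json'))
    model.save_params(os.path.join(model_dir, 'model-0000.params'))

    # Record the input signature so the model can be re-bound at inference time.
    signature = [{'name': name, 'shape': list(shape)}
                 for name, shape in model.data_shapes]
    with open(os.path.join(model_dir, 'model-shapes.json'), 'w') as f:
        json.dump(signature, f)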
def train(
    batch_size,
    epochs,
    learning_rate,
    num_gpus,
    training_channel,
    testing_channel,
    hosts,
    current_host,
    model_dir,
):
    (train_labels, train_images) = load_data(training_channel)
    (test_labels, test_images) = load_data(testing_channel)

    # Data parallel training - shard the data so each host
    # only trains on a subset of the total data.
    shard_size = len(train_images) // len(hosts)
    for i, host in enumerate(hosts):
        if host == current_host:
            start = shard_size * i
            end = start + shard_size
            break

    train_iter = mx.io.NDArrayIter(
        train_images[start:end], train_labels[start:end], batch_size, shuffle=True
    )
    val_iter = mx.io.NDArrayIter(test_images, test_labels, batch_size)

    logging.getLogger().setLevel(logging.DEBUG)

    kvstore = "local" if len(hosts) == 1 else "dist_sync"

    mlp_model = mx.mod.Module(symbol=build_graph(),
                              context=get_train_context(num_gpus))
    mlp_model.fit(
        train_iter,
        eval_data=val_iter,
        kvstore=kvstore,
        optimizer="sgd",
        optimizer_params={"learning_rate": learning_rate},
        eval_metric="acc",
        batch_end_callback=mx.callback.Speedometer(batch_size, 100),
        num_epoch=epochs,
    )

    if len(hosts) == 1 or current_host == scheduler_host(hosts):
        save(model_dir, mlp_model)
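# The helpers referenced above (load_data, build_graph, get_train_context,
# scheduler_host, save) are defined elsewhere in the script. Sketches of the two
# simplest ones, under the assumption that the scheduler is the first host in the
# cluster and that training uses a single GPU when one is available (names and
# behaviour inferred from the call sites, not copied from the original script):
def get_train_context(num_gpus):
    # Prefer GPU 0 when the instance has GPUs, otherwise fall back to CPU.
    if num_gpus > 0:
        return mx.gpu()
    return mx.cpu()


def scheduler_host(hosts):
    # SageMaker passes hosts as a list such as ['algo-1', 'algo-2', ...];
    # treat the first entry as the parameter-server scheduler.
    return hosts[0]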
def train(args):
    logging.info(mx.__version__)

    # Get hyperparameters
    batch_size = args.batch_size
    epochs = args.epochs
    learning_rate = args.learning_rate
    model_dir = os.environ['SM_MODEL_DIR']
    num_gpus = int(os.environ['SM_NUM_GPUS'])
    current_host = args.current_host
    hosts = args.hosts
    beta1 = 0.9
    beta2 = 0.99
    num_workers = args.num_workers
    num_classes = 1

    # Set context for compute based on instance environment
    if num_gpus > 0:
        ctx = [mx.gpu(i) for i in range(num_gpus)]
    else:
        ctx = mx.cpu()

    # Locate compressed training/validation data
    root_data_dir = args.root_dir

    # Define custom iterators on extracted data locations.
    train_iter = DataLoaderIter(
        root_data_dir, num_classes, batch_size, True, num_workers, 'train')
    validation_iter = DataLoaderIter(
        root_data_dir, num_classes, batch_size, False, num_workers, 'validation')

    # Build network symbolic graph
    sym = build_unet(num_classes)
    logging.info("Sym loaded")

    # Load graph into Module
    net = mx.mod.Module(sym, context=ctx,
                        data_names=('data',), label_names=('label',))

    # Initialize custom metric
    dice_metric = mx.metric.CustomMetric(feval=avg_dice_coef_metric,
                                         allow_extra_outputs=True)
    logging.info("Starting model fit")

    # Start training the model
    net.fit(
        train_data=train_iter,
        eval_data=validation_iter,
        eval_metric=dice_metric,
        initializer=mx.initializer.Xavier(magnitude=6),
        optimizer='adam',
        optimizer_params={
            'learning_rate': learning_rate,
            'beta1': beta1,
            'beta2': beta2},
        num_epoch=epochs)

    # Save parameters from the training graph
    net.save_params('params')

    # Build inference-only graph, set parameters from training model
    sym = build_unet(num_classes, inference=True)
    net = mx.mod.Module(sym, context=ctx,
                        data_names=('data',), label_names=None)

    # Re-bind the model for a batch size of one
    net.bind(data_shapes=[('data', (1,) + train_iter.provide_data[0][1][1:])])
    net.load_params('params')

    # Save model
    from sagemaker_mxnet_container.training_utils import save
    save(os.environ['SM_MODEL_DIR'], net)
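# train(args) above reads a handful of attributes from args. A minimal argument
# parser consistent with that usage might look like this (flag names, defaults,
# and the specific SM_* environment variables used are assumptions, not the
# original script; it also assumes argparse, json, and os are imported):
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--learning-rate', type=float, default=1e-4)
    parser.add_argument('--num-workers', type=int, default=4)
    parser.add_argument('--root-dir', type=str,
                        default=os.environ.get('SM_CHANNEL_TRAINING',
                                               '/opt/ml/input/data/training'))
    parser.add_argument('--current-host', type=str,
                        default=os.environ.get('SM_CURRENT_HOST', 'algo-1'))
    parser.add_argument('--hosts', type=json.loads,
                        default=json.loads(os.environ.get('SM_HOSTS', '["algo-1"]')))
    return parser.parse_args()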
def _get_context(cpus, gpus):
    if gpus > 0:
        ctx = [mx.gpu(x) for x in range(gpus)]
    else:
        ctx = mx.cpu()
    logging.info("mxnet context: %s" % str(ctx))
    return ctx


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-dir', type=str,
                        default=os.environ['SM_MODEL_DIR'])
    parser.add_argument(
        '--input-channels', type=str,
        default=json.loads(os.environ['SM_TRAINING_ENV'])['channel_input_dirs'])
    args = parser.parse_args()

    num_cpus = int(os.environ['SM_NUM_CPUS'])
    num_gpus = int(os.environ['SM_NUM_GPUS'])

    model = train(num_cpus, num_gpus, args.input_channels)
    save(args.model_dir, model)
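# The train function this entry point calls is defined elsewhere in the script.
# A skeleton consistent with the call signature used above (hypothetical; the
# real body is omitted) shows how _get_context fits in:
def train(num_cpus, num_gpus, input_channels):
    ctx = _get_context(num_cpus, num_gpus)
    # ... build a Module on ctx, fit it on data read from input_channels ...
    # and return the trained Module so the caller can pass it to save().
    raise NotImplementedError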