def data_creator(config, batch_size):
    import ray  # needed for ray.get below; not imported in the original snippet
    from zoo.orca.data.utils import ray_partition_get_data_label, index_data, get_size
    from torch.utils.data import Dataset, DataLoader

    class NDArrayDataset(Dataset):
        def __init__(self, x, y):
            self.x = x  # features
            self.y = y  # labels

        def __len__(self):
            return get_size(self.y)

        def __getitem__(self, i):
            return index_data(self.x, i), index_data(self.y, i)

    params = {"batch_size": batch_size, "shuffle": True}
    for arg in ["shuffle", "sampler", "batch_sampler", "num_workers", "collate_fn",
                "pin_memory", "drop_last", "timeout", "worker_init_fn",
                "multiprocessing_context"]:
        if arg in config:
            params[arg] = config[arg]
    # `shards_ref` is a Ray object reference captured from the enclosing scope.
    data, label = ray_partition_get_data_label(ray.get(shards_ref),
                                               allow_tuple=False,
                                               allow_list=False)
    print("Data size on worker: ", len(label))
    dataset = NDArrayDataset(data, label)
    data_loader = DataLoader(dataset, **params)
    return data_loader

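# A hedged usage sketch (not part of the library): it shows one way the
# enclosing scope could supply the `shards_ref` closure variable that
# data_creator reads. The partition layout (a list of {"x": ..., "y": ...}
# dicts) is an assumption about what ray_partition_get_data_label consumes.
import numpy as np
import ray

ray.init(ignore_reinit_error=True)
shards_ref = ray.put([{"x": np.random.rand(128, 8).astype("float32"),
                       "y": np.random.randint(0, 2, size=(128,)).astype("float32")}])
loader = data_creator({"num_workers": 0}, batch_size=16)
for x_batch, y_batch in loader:
    pass  # each iteration yields a mini-batch of 16 samples
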
def _handle_xshards(self, dataset, steps, local_batch_size, shuffle):
    import tensorflow as tf
    # `ray` and `ray_partition_get_data_label` are assumed to be module-level
    # imports in the original file.
    data, label = ray_partition_get_data_label(ray.get(dataset),
                                               allow_tuple=True,
                                               allow_list=False)
    dataset = tf.data.Dataset.from_tensor_slices((data, label))
    dataset = dataset.repeat()
    dataset = dataset.take(steps * local_batch_size)
    if shuffle:
        dataset = dataset.shuffle(local_batch_size * min(steps, 10))
    dataset = dataset.batch(local_batch_size)
    return dataset

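# Hedged helper sketch: `steps` above must be chosen by the caller; a natural
# choice is one pass over the local shard. The helper name `local_steps` is
# illustrative, not part of the library.
import math

def local_steps(num_local_samples, local_batch_size):
    # round up so a final partial batch still counts as a step
    return math.ceil(num_local_samples / local_batch_size)
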
def data_creator(config, kv):
    import mxnet as mx
    assert "batch_size" in config, "batch_size must be set in config"
    # `shards_ref` and `shuffle` are captured from the enclosing scope.
    data, label = ray_partition_get_data_label(ray.get(shards_ref),
                                               allow_tuple=False,
                                               allow_list=False)
    train_data_iter = mx.io.NDArrayIter(data=data, label=label,
                                        batch_size=config["batch_size"],
                                        shuffle=shuffle)
    if "train_resize_batch_num" in config:
        train_data_iter = mx.io.ResizeIter(train_data_iter,
                                           config["train_resize_batch_num"])
    return train_data_iter

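# Hedged usage sketch: `shards_ref` and `shuffle` are normally bound by the
# MXNet runner's enclosing scope; here they are bound by hand for illustration.
import mxnet as mx
import numpy as np
import ray

ray.init(ignore_reinit_error=True)
shards_ref = ray.put([{"x": np.random.rand(64, 10).astype("float32"),
                       "y": np.random.randint(0, 2, size=(64,)).astype("float32")}])
shuffle = True
kv = mx.kv.create("local")  # a local kvstore stands in for the distributed one
train_iter = data_creator({"batch_size": 16, "train_resize_batch_num": 2}, kv)
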
def _handle_xshards(self, dataset, steps, local_batch_size, shuffle):
    import tensorflow as tf
    data, label = ray_partition_get_data_label(ray.get(dataset),
                                               allow_tuple=True,
                                               allow_list=False)
    dataset = tf.data.Dataset.from_tensor_slices((data, label))
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = \
        tf.data.experimental.AutoShardPolicy.OFF
    dataset = dataset.with_options(options)
    dataset = dataset.repeat()
    dataset = dataset.take(steps * local_batch_size)
    if shuffle:
        dataset = dataset.shuffle(local_batch_size * min(steps, 10))
    dataset = dataset.batch(local_batch_size)
    return dataset

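# The auto-sharding idiom in isolation (standard tf.data API): setting the
# policy to OFF tells tf.distribute not to split this dataset again, which is
# what you want when each Ray worker already holds its own shard.
import tensorflow as tf

ds = tf.data.Dataset.range(8)
opts = tf.data.Options()
opts.experimental_distribute.auto_shard_policy = \
    tf.data.experimental.AutoShardPolicy.OFF
ds = ds.with_options(opts)
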
def data_creator(config):
    import ray
    import torch  # needed for TensorDataset/from_numpy; missing in the original snippet
    from zoo.orca.data.utils import ray_partition_get_data_label
    from torch.utils.data import DataLoader
    assert "batch_size" in config, "batch_size must be set in config"
    params = {"batch_size": config["batch_size"], "shuffle": True}
    for arg in ["shuffle", "sampler", "batch_sampler", "num_workers", "collate_fn",
                "pin_memory", "drop_last", "timeout", "worker_init_fn",
                "multiprocessing_context"]:
        if arg in config:
            params[arg] = config[arg]
    # `shards_ref` is a Ray object reference captured from the enclosing scope.
    data, label = ray_partition_get_data_label(ray.get(shards_ref),
                                               allow_tuple=False,
                                               allow_list=False)
    print("Data size on worker: ", len(label))
    dataset = torch.utils.data.TensorDataset(torch.from_numpy(data),
                                             torch.from_numpy(label))
    data_loader = DataLoader(dataset, **params)
    return data_loader

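# Hedged note: unlike the NDArrayDataset variant above, TensorDataset needs
# `data` and `label` to be single numpy arrays convertible by torch.from_numpy,
# so a quick alignment check before building the loader catches size mismatches
# early. The standalone arrays below are illustrative.
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

data = np.random.rand(32, 4).astype("float32")
label = np.random.randint(0, 2, size=(32,)).astype("int64")
assert len(data) == len(label), "data and label must be the same length"
loader = DataLoader(TensorDataset(torch.from_numpy(data), torch.from_numpy(label)),
                    batch_size=8, shuffle=True)
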
def data_creator(config):
    assert "batch_size" in config, "batch_size must be set in config"
    import tensorflow as tf
    from zoo.orca.data.utils import ray_partition_get_data_label
    # `partition`, `max_length` and `shuffle` are captured from the enclosing scope.
    data, label = ray_partition_get_data_label(partition.get_data(),
                                               allow_tuple=True,
                                               allow_list=False)
    dataset = tf.data.Dataset.from_tensor_slices((data, label))
    if max_length is not None:
        # TODO: find a way to pad empty tensors?
        dataset = dataset.repeat()
        if shuffle:
            dataset = dataset.shuffle(max_length)
        dataset = dataset.take(max_length)
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = \
        tf.data.experimental.AutoShardPolicy.OFF
    dataset = dataset.with_options(options)
    dataset = dataset.batch(config["batch_size"])
    return dataset

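# Standalone illustration of the repeat+take idiom used above: repeating an
# uneven shard and taking a fixed count gives every worker the same number of
# elements, keeping synchronized training steps aligned across workers.
import tensorflow as tf

shard = tf.data.Dataset.range(5)     # one worker's uneven shard
equalized = shard.repeat().take(8)   # exactly 8 elements, wrapping around
print(list(equalized.as_numpy_iterator()))  # [0, 1, 2, 3, 4, 0, 1, 2]
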
def _handle_xshards(self, dataset, steps, local_batch_size, shuffle):
    import tensorflow as tf
    data, label = ray_partition_get_data_label(ray.get(dataset),
                                               allow_tuple=True,
                                               allow_list=False)

    def dataset_fn(input_context):
        dataset = tf.data.Dataset.from_tensor_slices((data, label))
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = \
            tf.data.experimental.AutoShardPolicy.OFF
        dataset = dataset.with_options(options)
        dataset = dataset.repeat()
        dataset = dataset.take(steps * local_batch_size)
        if shuffle:
            dataset = dataset.shuffle(local_batch_size * min(steps, 10))
        dataset = dataset.batch(local_batch_size)
        return dataset

    from tensorflow.python.distribute import distribution_strategy_context as ds_context
    strategy = ds_context.get_strategy()
    dataset = strategy.experimental_distribute_datasets_from_function(dataset_fn)
    return dataset

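# Hedged note: the snippet above reaches into tensorflow.python.distribute, a
# private module. In TF 2.x the same wiring is available through the public
# API, roughly as follows:
import tensorflow as tf

def toy_dataset_fn(input_context):
    return tf.data.Dataset.range(8).batch(4)

strategy = tf.distribute.get_strategy()
# distribute_datasets_from_function replaces the experimental_ variant in
# recent TF releases.
dist_ds = strategy.distribute_datasets_from_function(toy_dataset_fn)
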
def train(self, train_data, epochs=1, batch_size=32, validation_data=None,
          train_resize_batch_num=None):
    """Train the model and update the model parameters."""
    stats = dict()
    if self.is_worker:
        from zoo.orca.data.shard import RayPartition
        if isinstance(train_data, RayPartition):
            from zoo.orca.data.utils import ray_partition_get_data_label
            data, label = ray_partition_get_data_label(train_data.get_data(),
                                                       allow_tuple=False,
                                                       allow_list=False)
            train_data_iter = mx.io.NDArrayIter(data=data, label=label,
                                                batch_size=batch_size, shuffle=True)
            if train_resize_batch_num is not None:
                train_data_iter = mx.io.ResizeIter(train_data_iter,
                                                   train_resize_batch_num)
            if validation_data:
                data_val, label_val = ray_partition_get_data_label(
                    validation_data.get_data(), allow_tuple=False, allow_list=False)
                val_data_iter = mx.io.NDArrayIter(data=data_val, label=label_val,
                                                  batch_size=batch_size, shuffle=True)
            else:
                val_data_iter = None
        else:  # data_creator functions; should return Iter or DataLoader
            config = self.config
            if "batch_size" not in config:
                config["batch_size"] = batch_size
            train_data_iter = train_data(config, self.kv)
            val_data_iter = validation_data(config, self.kv) if validation_data else None

        start_time = time.time()
        if self.trainer:  # Imperative API
            for epoch in range(epochs):
                train_data_iter.reset()
                if self.eval_metrics:
                    self.eval_metrics.reset()  # metrics will accumulate for one batch
                batch_start_time = time.time()
                epoch_start_time = time.time()
                for i, batch in enumerate(train_data_iter):
                    data = gluon.utils.split_and_load(
                        batch.data[0].astype("float32"),
                        ctx_list=[mx.cpu()], batch_axis=0)
                    label = gluon.utils.split_and_load(
                        batch.label[0].astype("float32"),
                        ctx_list=[mx.cpu()], batch_axis=0)
                    outputs = []
                    Ls = []
                    from mxnet import autograd as ag
                    with ag.record():
                        for x, y in zip(data, label):
                            z = self.model(x)  # forward
                            L = self.loss(z, y)
                            # store the loss and do backward on a batch for better speed
                            Ls.append(L)
                            outputs.append(z)
                        ag.backward(Ls)
                    self.trainer.step(batch.data[0].shape[0])
                    if self.eval_metrics:
                        self.eval_metrics.update(label, outputs)
                    if not (i + 1) % self.config["log_interval"]:
                        # This would be logged on driver for each worker process.
                        iteration_log = \
                            "Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f" \
                            % (epoch, i,
                               batch_size / (time.time() - batch_start_time),
                               "loss", Ls[0].asnumpy().mean())
                        if self.eval_metrics:
                            names, accs = self.eval_metrics.get()
                            names, accs = to_list(names), to_list(accs)
                            for name, acc in zip(names, accs):
                                iteration_log += " %s=%f" % (name, acc)
                        self.logger.info(iteration_log)
                        batch_start_time = time.time()
                # Epoch time log
                self.logger.info("[Epoch %d] time cost: %f"
                                 % (epoch, time.time() - epoch_start_time))
                # Epoch metrics log on train data
                if self.eval_metrics:
                    epoch_train_log = "[Epoch %d] training: " % epoch
                    names, accs = self.eval_metrics.get()
                    names, accs = to_list(names), to_list(accs)
                    for name, acc in zip(names, accs):
                        epoch_train_log += "%s=%f " % (name, acc)
                    self.logger.info(epoch_train_log)
                # Epoch metrics log on validation data if any:
                if val_data_iter:
                    self.val_metrics.reset()
                    val_data_iter.reset()
                    for batch in val_data_iter:
                        data = gluon.utils.split_and_load(
                            batch.data[0].astype("float32", copy=False),
                            ctx_list=[mx.cpu()], batch_axis=0)
                        label = gluon.utils.split_and_load(
                            batch.label[0].astype("float32", copy=False),
                            ctx_list=[mx.cpu()], batch_axis=0)
                        outputs = [self.model(X) for X in data]
                        self.val_metrics.update(label, outputs)
                    epoch_val_log = "[Epoch %d] validation: " % epoch
                    names, accs = self.val_metrics.get()
                    names, accs = to_list(names), to_list(accs)
                    for name, acc in zip(names, accs):
                        epoch_val_log += "%s=%f " % (name, acc)
                    self.logger.info(epoch_val_log)
                # TODO: save checkpoints
            if self.eval_metrics:
                names, accs = self.eval_metrics.get()
                names, accs = to_list(names), to_list(accs)
                for name, acc in zip(names, accs):
                    stats[name] = acc
        else:  # Symbolic API
            # TODO: seems no history (i.e. validation accuracy) returned by fit?
            if "init" not in self.config:
                from mxnet.initializer import Uniform
                self.config["init"] = Uniform(0.01)  # This is the default value for MXNet
            if self.eval_metrics is None:
                self.eval_metrics = 'acc'
            self.model.fit(train_data=train_data_iter,
                           num_epoch=epochs,
                           initializer=self.config["init"],
                           kvstore=self.kv,
                           optimizer=self.config["optimizer"],
                           optimizer_params=self.config["optimizer_params"],
                           eval_data=val_data_iter,
                           eval_metric=self.eval_metrics,
                           validation_metric=self.val_metrics,
                           batch_end_callback=mx.callback.Speedometer(
                               batch_size, self.config["log_interval"]),
                           epoch_end_callback=None if "model" not in self.config
                           else mx.callback.do_checkpoint(self.config["model"]))
        epoch_time = time.time() - start_time
        stats["epoch_time"] = epoch_time
        if isinstance(train_data, RayPartition):
            del train_data
        if validation_data and isinstance(validation_data, RayPartition):
            del validation_data
    return stats

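# Distilled standalone version of the imperative training step above, using
# only the public Gluon API; the toy model, loss and trainer are stand-ins:
import mxnet as mx
from mxnet import autograd, gluon

net = gluon.nn.Dense(1)
net.initialize()
loss_fn = gluon.loss.L2Loss()
trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.01})

x = mx.nd.random.uniform(shape=(16, 4))
y = mx.nd.random.uniform(shape=(16, 1))
with autograd.record():
    loss = loss_fn(net(x), y)  # forward pass and loss, recorded for autograd
loss.backward()                # backward pass
trainer.step(x.shape[0])       # parameter update, normalized by the batch size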