Example No. 1: PyTorch data_creator building a DataLoader over a custom NDArrayDataset
    def data_creator(config, batch_size):
        from zoo.orca.data.utils import ray_partition_get_data_label, index_data, get_size
        from torch.utils.data import Dataset, DataLoader

        class NDArrayDataset(Dataset):
            def __init__(self, x, y):
                self.x = x  # features
                self.y = y  # labels

            def __len__(self):
                return get_size(self.y)

            def __getitem__(self, i):
                return index_data(self.x, i), index_data(self.y, i)

        params = {"batch_size": batch_size, "shuffle": True}
        for arg in ["shuffle", "sampler", "batch_sampler", "num_workers", "collate_fn",
                    "pin_memory", "drop_last", "timeout", "worker_init_fn",
                    "multiprocessing_context"]:
            if arg in config:
                params[arg] = config[arg]
        data, label = ray_partition_get_data_label(ray.get(shards_ref),
                                                   allow_tuple=False,
                                                   allow_list=False)
        print("Data size on worker: ", len(label))
        dataset = NDArrayDataset(data, label)
        data_loader = DataLoader(dataset, **params)
        return data_loader
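In this example, shards_ref and ray are captured from the enclosing scope of the original source; the creator turns a Ray partition into a plain PyTorch DataLoader. Below is a minimal, self-contained sketch of the same Dataset/DataLoader pattern, with toy NumPy arrays standing in for the partition data (no Ray or zoo.orca required; ToyNDArrayDataset is a name introduced here).

    import numpy as np
    from torch.utils.data import Dataset, DataLoader

    class ToyNDArrayDataset(Dataset):
        """Wraps a pair of in-memory arrays, mirroring NDArrayDataset above."""
        def __init__(self, x, y):
            self.x = x
            self.y = y

        def __len__(self):
            return len(self.y)

        def __getitem__(self, i):
            return self.x[i], self.y[i]

    x = np.random.rand(64, 10).astype("float32")
    y = np.random.randint(0, 2, size=(64,)).astype("int64")
    loader = DataLoader(ToyNDArrayDataset(x, y), batch_size=16, shuffle=True)
    for features, labels in loader:
        # The default collate function converts the NumPy slices to torch tensors.
        print(features.shape, labels.shape)  # torch.Size([16, 10]) torch.Size([16])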
Example No. 2: building a tf.data.Dataset from Ray partition data
 def _handle_xshards(self, dataset, steps, local_batch_size, shuffle):
     import tensorflow as tf
     data, label = ray_partition_get_data_label(ray.get(dataset),
                                                allow_tuple=True,
                                                allow_list=False)
     dataset = tf.data.Dataset.from_tensor_slices((data, label))
     dataset = dataset.repeat()
     dataset = dataset.take(steps * local_batch_size)
     if shuffle:
         dataset = dataset.shuffle(local_batch_size * min(steps, 10))
     dataset = dataset.batch(local_batch_size)
     return dataset
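The repeat/take/shuffle/batch chain above fixes the number of samples drawn per epoch to steps * local_batch_size. A minimal sketch of the same chain with toy arrays (the Ray partition lookup is replaced by in-memory NumPy data):

    import numpy as np
    import tensorflow as tf

    steps, local_batch_size = 5, 8
    data = np.random.rand(32, 4).astype("float32")
    label = np.random.randint(0, 2, size=(32,)).astype("int64")

    dataset = tf.data.Dataset.from_tensor_slices((data, label))
    dataset = dataset.repeat()                                # endless stream of samples
    dataset = dataset.take(steps * local_batch_size)          # exactly `steps` batches worth
    dataset = dataset.shuffle(local_batch_size * min(steps, 10))
    dataset = dataset.batch(local_batch_size)

    for x, y in dataset:
        print(x.shape, y.shape)  # (8, 4) (8,), printed `steps` times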
Example No. 3: MXNet data_creator returning an NDArrayIter
    def data_creator(config, kv):
        import mxnet as mx
        assert "batch_size" in config, "batch_size must be set in config"
        data, label = ray_partition_get_data_label(ray.get(shards_ref),
                                                   allow_tuple=False,
                                                   allow_list=False)

        train_data_iter = mx.io.NDArrayIter(data=data, label=label,
                                            batch_size=config["batch_size"],
                                            shuffle=shuffle)
        if "train_resize_batch_num" in config:
            train_data_iter = mx.io.ResizeIter(train_data_iter,
                                               config["train_resize_batch_num"])
        return train_data_iter
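Here shards_ref, shuffle, and ray are captured from the enclosing scope in the original source. A self-contained sketch of the same NDArrayIter/ResizeIter pattern on toy NumPy data (the resize value of 4 batches is arbitrary):

    import numpy as np
    import mxnet as mx

    data = np.random.rand(100, 4).astype("float32")
    label = np.random.randint(0, 2, size=(100,)).astype("float32")

    train_iter = mx.io.NDArrayIter(data=data, label=label, batch_size=16, shuffle=True)
    # Optionally cap each epoch at a fixed number of batches,
    # as train_resize_batch_num does in the example above.
    train_iter = mx.io.ResizeIter(train_iter, 4)

    for batch in train_iter:
        print(batch.data[0].shape, batch.label[0].shape)  # (16, 4) (16,)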
Example No. 4: the tf.data pipeline from Example No. 2 with automatic sharding disabled
 def _handle_xshards(self, dataset, steps, local_batch_size, shuffle):
     import tensorflow as tf
     data, label = ray_partition_get_data_label(ray.get(dataset),
                                                allow_tuple=True,
                                                allow_list=False)
     dataset = tf.data.Dataset.from_tensor_slices((data, label))
     options = tf.data.Options()
     options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
     dataset = dataset.with_options(options)
     dataset = dataset.repeat()
     dataset = dataset.take(steps * local_batch_size)
     if shuffle:
         dataset = dataset.shuffle(local_batch_size * min(steps, 10))
     dataset = dataset.batch(local_batch_size)
     return dataset
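The only difference from Example No. 2 is the tf.data.Options block: each worker already holds its own shard of the data, so the automatic sharding that tf.distribute would otherwise apply is switched off. A minimal sketch of just that option:

    import tensorflow as tf

    dataset = tf.data.Dataset.range(10)
    options = tf.data.Options()
    # Disable automatic sharding; the data is already split per worker.
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
    dataset = dataset.with_options(options)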
Example No. 5: PyTorch data_creator using TensorDataset
    def data_creator(config):
        from zoo.orca.data.utils import ray_partition_get_data_label
        import torch  # needed below for TensorDataset and from_numpy
        from torch.utils.data import DataLoader

        assert "batch_size" in config, "batch_size must be set in config"
        params = {"batch_size": config["batch_size"], "shuffle": True}
        for arg in ["shuffle", "sampler", "batch_sampler", "num_workers", "collate_fn",
                    "pin_memory", "drop_last", "timeout", "worker_init_fn",
                    "multiprocessing_context"]:
            if arg in config:
                params[arg] = config[arg]
        data, label = ray_partition_get_data_label(ray.get(shards_ref),
                                                   allow_tuple=False,
                                                   allow_list=False)
        print("Data size on worker: ", len(label))
        dataset = torch.utils.data.TensorDataset(torch.from_numpy(data), torch.from_numpy(label))
        data_loader = DataLoader(dataset, **params)
        return data_loader
Example No. 6: tf.data data_creator with optional repeat/shuffle/take
    def data_creator(config):
        assert "batch_size" in config, "batch_size must be set in config"
        import tensorflow as tf
        data, label = ray_partition_get_data_label(partition.get_data(),
                                                   allow_tuple=True,
                                                   allow_list=False)

        dataset = tf.data.Dataset.from_tensor_slices((data, label))
        if max_length is not None:
            # todo find a way to pad empty tensors?
            dataset = dataset.repeat()
            if shuffle:
                dataset = dataset.shuffle(max_length)
            dataset = dataset.take(max_length)
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
        dataset = dataset.with_options(options)

        dataset = dataset.batch(config["batch_size"])
        return dataset
Example No. 7: building one tf.data pipeline per replica via a dataset_fn
    def _handle_xshards(self, dataset, steps, local_batch_size, shuffle):
        import tensorflow as tf
        data, label = ray_partition_get_data_label(ray.get(dataset),
                                                   allow_tuple=True,
                                                   allow_list=False)

        def dataset_fn(input_context):
            dataset = tf.data.Dataset.from_tensor_slices((data, label))
            options = tf.data.Options()
            options.experimental_distribute.auto_shard_policy = \
                tf.data.experimental.AutoShardPolicy.OFF
            dataset = dataset.with_options(options)
            dataset = dataset.repeat()
            dataset = dataset.take(steps * local_batch_size)
            if shuffle:
                dataset = dataset.shuffle(local_batch_size * min(steps, 10))
            dataset = dataset.batch(local_batch_size)
            return dataset

        from tensorflow.python.distribute import distribution_strategy_context as ds_context
        strategy = ds_context.get_strategy()
        dataset = strategy.experimental_distribute_datasets_from_function(dataset_fn)
        return dataset
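This variant wraps the pipeline in a dataset_fn and lets the current distribution strategy build one dataset per replica; note that distribution_strategy_context is a private TensorFlow module. A minimal sketch of the same idea using only public APIs (assuming TF 2.x, where tf.distribute.get_strategy() and Strategy.distribute_datasets_from_function are available):

    import numpy as np
    import tensorflow as tf

    steps, local_batch_size = 5, 8
    data = np.random.rand(32, 4).astype("float32")
    label = np.random.randint(0, 2, size=(32,)).astype("int64")

    def dataset_fn(input_context):
        ds = tf.data.Dataset.from_tensor_slices((data, label))
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = \
            tf.data.experimental.AutoShardPolicy.OFF
        ds = ds.with_options(options)
        return ds.repeat().take(steps * local_batch_size).batch(local_batch_size)

    strategy = tf.distribute.get_strategy()  # default (single-replica) strategy here
    dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)
    for x, y in dist_dataset:
        print(x.shape, y.shape)  # (8, 4) (8,)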
Example No. 8: MXNet estimator train() consuming a RayPartition or a data_creator
 def train(self,
           train_data,
           epochs=1,
           batch_size=32,
           validation_data=None,
           train_resize_batch_num=None):
     """Train the model and update the model parameters."""
     stats = dict()
     if self.is_worker:
         from zoo.orca.data.shard import RayPartition
         if isinstance(train_data, RayPartition):
             from zoo.orca.data.utils import ray_partition_get_data_label
             data, label = ray_partition_get_data_label(
                 train_data.get_data(), allow_tuple=False, allow_list=False)
             train_data_iter = mx.io.NDArrayIter(data=data,
                                                 label=label,
                                                 batch_size=batch_size,
                                                 shuffle=True)
             if train_resize_batch_num is not None:
                 train_data_iter = mx.io.ResizeIter(train_data_iter,
                                                    train_resize_batch_num)
             if validation_data:
                 data_val, label_val = ray_partition_get_data_label(
                     validation_data.get_data(),
                     allow_tuple=False,
                     allow_list=False)
                 val_data_iter = mx.io.NDArrayIter(data=data_val,
                                                   label=label_val,
                                                   batch_size=batch_size,
                                                   shuffle=True)
             else:
                 val_data_iter = None
         else:  # data_creator functions; should return Iter or DataLoader
             config = self.config
             if "batch_size" not in config:
                 config["batch_size"] = batch_size
             train_data_iter = train_data(config, self.kv)
             val_data_iter = validation_data(
                 config, self.kv) if validation_data else None
         start_time = time.time()
         if self.trainer:  # Imperative API
             for epoch in range(epochs):
                 train_data_iter.reset()
                 if self.eval_metrics:
                      self.eval_metrics.reset()  # reset so metrics accumulate over this epoch
                 batch_start_time = time.time()
                 epoch_start_time = time.time()
                 for i, batch in enumerate(train_data_iter):
                     data = gluon.utils.split_and_load(
                         batch.data[0].astype("float32"),
                         ctx_list=[mx.cpu()],
                         batch_axis=0)
                     label = gluon.utils.split_and_load(
                         batch.label[0].astype("float32"),
                         ctx_list=[mx.cpu()],
                         batch_axis=0)
                     outputs = []
                     Ls = []
                     from mxnet import autograd as ag
                     with ag.record():
                         for x, y in zip(data, label):
                             z = self.model(x)  # forward
                             L = self.loss(z, y)
                             # store the loss and do backward on a batch for better speed
                             Ls.append(L)
                             outputs.append(z)
                         ag.backward(Ls)
                     self.trainer.step(batch.data[0].shape[0])
                     if self.eval_metrics:
                         self.eval_metrics.update(label, outputs)
                     if not (i + 1) % self.config["log_interval"]:
                         # This would be logged on driver for each worker process.
                         iteration_log = \
                             "Epoch[%d] Batch[%d]  Speed: %f samples/sec  %s=%f" \
                             % (epoch, i,
                                batch_size / (time.time() - batch_start_time),
                                "loss", Ls[0].asnumpy().mean())
                         if self.eval_metrics:
                             names, accs = self.eval_metrics.get()
                             names, accs = to_list(names), to_list(accs)
                             for name, acc in zip(names, accs):
                                 iteration_log += "  %s=%f" % (name, acc)
                         self.logger.info(iteration_log)
                     batch_start_time = time.time()
                 # Epoch time log
                 self.logger.info("[Epoch %d] time cost: %f" %
                                  (epoch, time.time() - epoch_start_time))
                 # Epoch metrics log on train data
                 if self.eval_metrics:
                     epoch_train_log = "[Epoch %d] training: " % epoch
                     names, accs = self.eval_metrics.get()
                     names, accs = to_list(names), to_list(accs)
                     for name, acc in zip(names, accs):
                         epoch_train_log += "%s=%f  " % (name, acc)
                     self.logger.info(epoch_train_log)
                 # Epoch metrics log on validation data if any:
                 if val_data_iter:
                     self.val_metrics.reset()
                     val_data_iter.reset()
                     for batch in val_data_iter:
                         data = gluon.utils.split_and_load(
                             batch.data[0].astype("float32", copy=False),
                             ctx_list=[mx.cpu()],
                             batch_axis=0)
                         label = gluon.utils.split_and_load(
                             batch.label[0].astype("float32", copy=False),
                             ctx_list=[mx.cpu()],
                             batch_axis=0)
                         outputs = [self.model(X) for X in data]
                         self.val_metrics.update(label, outputs)
                     epoch_val_log = "[Epoch %d] validation: " % epoch
                     names, accs = self.val_metrics.get()
                     names, accs = to_list(names), to_list(accs)
                     for name, acc in zip(names, accs):
                         epoch_val_log += "%s=%f  " % (name, acc)
                     self.logger.info(epoch_val_log)
                 # TODO: save checkpoints
             if self.eval_metrics:
                 names, accs = self.eval_metrics.get()
                 names, accs = to_list(names), to_list(accs)
                 for name, acc in zip(names, accs):
                     stats[name] = acc
         else:  # Symbolic API
             # TODO: seems no history (i.e. validation accuracy) returned by fit?
             if "init" not in self.config:
                 from mxnet.initializer import Uniform
                 self.config["init"] = Uniform(
                     0.01)  # This is the default value for MXNet
             if self.eval_metrics is None:
                 self.eval_metrics = 'acc'
             self.model.fit(
                 train_data=train_data_iter,
                 num_epoch=epochs,
                 initializer=self.config["init"],
                 kvstore=self.kv,
                 optimizer=self.config["optimizer"],
                 optimizer_params=self.config["optimizer_params"],
                 eval_data=val_data_iter,
                 eval_metric=self.eval_metrics,
                 validation_metric=self.val_metrics,
                 batch_end_callback=mx.callback.Speedometer(
                     batch_size, self.config["log_interval"]),
                 epoch_end_callback=None if "model" not in self.config else
                 mx.callback.do_checkpoint(self.config["model"]))
         epoch_time = time.time() - start_time
         stats["epoch_time"] = epoch_time
         if isinstance(train_data, RayPartition):
             del train_data
         if validation_data and isinstance(validation_data, RayPartition):
             del validation_data
     return stats
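The imperative (Gluon) branch above amounts to: forward and loss under autograd.record(), backward, then trainer.step(batch_size), with metrics updated per batch. A stripped-down sketch of that inner step on toy data (the model, loss, and trainer below are placeholders, not the ones configured by this estimator):

    import numpy as np
    import mxnet as mx
    from mxnet import gluon, autograd as ag

    net = gluon.nn.Dense(2)
    net.initialize(ctx=mx.cpu())
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.01})

    data = mx.nd.array(np.random.rand(16, 4).astype("float32"))
    label = mx.nd.array(np.random.randint(0, 2, size=(16,)))

    with ag.record():
        output = net(data)             # forward
        loss = loss_fn(output, label)  # per-sample loss
    loss.backward()                    # backward on the recorded graph
    trainer.step(data.shape[0])        # parameter update, scaled by the batch size
    print(loss.mean().asscalar())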