コード例 #1
0
 def train(self, nb_epoch=1):
     """Train the model and update the model parameters.

     Nothing is trained unless ``self.is_worker`` is truthy. When
     ``self.trainer`` is set, the imperative loop below runs forward and
     backward passes batch by batch; otherwise training is delegated to
     ``self.model.fit`` (symbolic API).

     Args:
         nb_epoch: Number of passes over ``self.train_data``.

     Returns:
         dict: Empty for non-workers. For workers it holds ``"epoch_time"``
         (wall-clock seconds spent in this call) and, on the imperative path
         with ``self.metrics`` set, one entry per metric name. The same
         metric object is reused for validation, so those entries reflect
         the last validation pass when ``self.val_data`` is set.
     """
     stats = dict()
     if not self.is_worker:
         return stats

     def load_on_cpu(array):
         # The result is consumed below as a list of slices.
         return gluon.utils.split_and_load(
             array, ctx_list=[mx.cpu()], batch_axis=0)

     def metric_text(template):
         # Render every (name, value) pair of self.metrics with `template`.
         names, accs = self.metrics.get()
         pairs = zip(to_list(names), to_list(accs))
         return "".join(template % pair for pair in pairs)

     start_time = time.time()
     if self.trainer:  # Imperative API
         for epoch in range(nb_epoch):
             self.train_data.reset()
             if self.metrics:
                 # Reset once per epoch; the per-batch updates below
                 # accumulate until the next reset.
                 self.metrics.reset()
             batch_start_time = time.time()
             epoch_start_time = time.time()
             for i, batch in enumerate(self.train_data):
                 # Imported only when a batch is actually processed.
                 from mxnet import autograd as ag
                 data = load_on_cpu(batch.data[0].astype("float32"))
                 label = load_on_cpu(batch.label[0].astype("float32"))
                 outputs = []
                 Ls = []
                 with ag.record():
                     for x, y in zip(data, label):
                         z = self.model(x)  # forward
                         # store the loss and do backward on a batch for better speed
                         Ls.append(self.loss(z, y))
                         outputs.append(z)
                     ag.backward(Ls)
                 self.trainer.step(batch.data[0].shape[0])
                 if self.metrics:
                     self.metrics.update(label, outputs)
                 if not (i + 1) % self.config["log_interval"]:
                     # This would be logged on driver for each worker process.
                     iteration_log = \
                         "Epoch[%d] Batch[%d]  Speed: %f samples/sec  %s=%f" \
                         % (epoch, i,
                            self.config["batch_size"] / (time.time() - batch_start_time),
                            "loss", Ls[0].asnumpy().mean())
                     if self.metrics:
                         iteration_log += metric_text("  %s=%f")
                     self.logger.info(iteration_log)
                 batch_start_time = time.time()
             # Epoch time log
             self.logger.info("[Epoch %d] time cost: %f" %
                              (epoch, time.time() - epoch_start_time))
             # Epoch metrics log on train data
             if self.metrics:
                 epoch_train_log = "[Epoch %d] training: " % epoch
                 self.logger.info(epoch_train_log + metric_text("%s=%f  "))
             # Epoch metrics log on validation data if any. Without a metric
             # there is nothing to measure, so the pass is skipped instead of
             # failing on ``None.reset()``.
             if self.val_data and self.metrics:
                 self.metrics.reset()
                 self.val_data.reset()
                 for batch in self.val_data:
                     data = load_on_cpu(
                         batch.data[0].astype("float32", copy=False))
                     label = load_on_cpu(
                         batch.label[0].astype("float32", copy=False))
                     outputs = [self.model(X) for X in data]
                     self.metrics.update(label, outputs)
                 epoch_val_log = "[Epoch %d] validation: " % epoch
                 self.logger.info(epoch_val_log + metric_text("%s=%f  "))
             # TODO: save checkpoints
         if self.metrics:
             names, accs = self.metrics.get()
             stats.update(zip(to_list(names), to_list(accs)))
     else:  # Symbolic API
         # TODO: seems no history (i.e. validation accuracy) returned by fit?
         if "init" not in self.config:
             from mxnet.initializer import Uniform
             self.config["init"] = Uniform(0.01)  # This is the default value for MXNet
         self.model.fit(train_data=self.train_data,
                        num_epoch=nb_epoch,
                        initializer=self.config["init"],
                        kvstore=self.kv,
                        optimizer=self.config["optimizer"],
                        optimizer_params=self.config["optimizer_params"],
                        eval_data=self.val_data,
                        # TODO: eval and validation metrics could be different
                        eval_metric=self.metrics,
                        validation_metric=self.metrics,
                        batch_end_callback=mx.callback.Speedometer(
                            self.config["batch_size"], self.config["log_interval"]),
                        epoch_end_callback=None if "model" not in self.config
                        else mx.callback.do_checkpoint(self.config["model"]))
     stats["epoch_time"] = time.time() - start_time
     return stats
コード例 #2
0
 def train(self,
           train_data,
           epochs=1,
           batch_size=32,
           validation_data=None,
           train_resize_batch_num=None):
     """Train the model and update the model parameters.

     Nothing is trained unless ``self.is_worker`` is truthy. When
     ``self.trainer`` is set, the imperative loop below is used; otherwise
     training is delegated to ``self.model.fit`` (symbolic API).

     Args:
         train_data: Either a ``RayPartition`` or a data-creator callable
             invoked as ``train_data(config, self.kv)``.
         epochs: Number of passes over the training iterator.
         batch_size: Batch size for iterators built from a ``RayPartition``;
             also used as the ``config["batch_size"]`` default for data
             creators, for the speed log and for the ``Speedometer``.
         validation_data: Optional validation source of the same kind as
             ``train_data``.
         train_resize_batch_num: If given, a ``RayPartition`` iterator is
             wrapped in ``mx.io.ResizeIter`` with this size; for data
             creators it is passed as ``config["train_resize_batch_num"]``.

     Returns:
         dict: Empty for non-workers. For workers it holds ``"epoch_time"``
         (wall-clock seconds spent training) and, on the imperative path
         with ``self.eval_metrics`` set, one entry per training metric.
     """
     stats = dict()
     if not self.is_worker:
         return stats

     from zoo.orca.data.shard import RayPartition
     if isinstance(train_data, RayPartition):
         from zoo.orca.data.utils import ray_partition_get_data_label
         data, label = ray_partition_get_data_label(
             train_data.get_data(), allow_tuple=False, allow_list=False)
         train_data_iter = mx.io.NDArrayIter(data=data,
                                             label=label,
                                             batch_size=batch_size,
                                             shuffle=True)
         if train_resize_batch_num is not None:
             train_data_iter = mx.io.ResizeIter(train_data_iter,
                                                train_resize_batch_num)
         if validation_data:
             data_val, label_val = ray_partition_get_data_label(
                 validation_data.get_data(),
                 allow_tuple=False,
                 allow_list=False)
             val_data_iter = mx.io.NDArrayIter(data=data_val,
                                               label=label_val,
                                               batch_size=batch_size,
                                               shuffle=True)
         else:
             val_data_iter = None
     else:  # data_creator functions; should return Iter or DataLoader
         # Work on a copy so per-call values do not leak into self.config
         # and affect later calls made with different arguments.
         config = self.config.copy()
         if "batch_size" not in config:
             config["batch_size"] = batch_size
         if train_resize_batch_num is not None:
             config["train_resize_batch_num"] = train_resize_batch_num
         train_data_iter = train_data(config, self.kv)
         val_data_iter = validation_data(
             config, self.kv) if validation_data else None

     def load_on_cpu(array):
         # The result is consumed below as a list of slices.
         return gluon.utils.split_and_load(
             array, ctx_list=[mx.cpu()], batch_axis=0)

     def metric_text(metric, template):
         # Render every (name, value) pair of `metric` with `template`.
         names, accs = metric.get()
         pairs = zip(to_list(names), to_list(accs))
         return "".join(template % pair for pair in pairs)

     start_time = time.time()
     if self.trainer:  # Imperative API
         for epoch in range(epochs):
             train_data_iter.reset()
             if self.eval_metrics:
                 # Reset once per epoch; the per-batch updates below
                 # accumulate until the next reset.
                 self.eval_metrics.reset()
             batch_start_time = time.time()
             epoch_start_time = time.time()
             for i, batch in enumerate(train_data_iter):
                 # Imported only when a batch is actually processed.
                 from mxnet import autograd as ag
                 data = load_on_cpu(batch.data[0].astype("float32"))
                 label = load_on_cpu(batch.label[0].astype("float32"))
                 outputs = []
                 Ls = []
                 with ag.record():
                     for x, y in zip(data, label):
                         z = self.model(x)  # forward
                         # store the loss and do backward on a batch for better speed
                         Ls.append(self.loss(z, y))
                         outputs.append(z)
                     ag.backward(Ls)
                 self.trainer.step(batch.data[0].shape[0])
                 if self.eval_metrics:
                     self.eval_metrics.update(label, outputs)
                 if not (i + 1) % self.config["log_interval"]:
                     # This would be logged on driver for each worker process.
                     iteration_log = \
                         "Epoch[%d] Batch[%d]  Speed: %f samples/sec  %s=%f" \
                         % (epoch, i,
                            batch_size / (time.time() - batch_start_time),
                            "loss", Ls[0].asnumpy().mean())
                     if self.eval_metrics:
                         iteration_log += metric_text(
                             self.eval_metrics, "  %s=%f")
                     self.logger.info(iteration_log)
                 batch_start_time = time.time()
             # Epoch time log
             self.logger.info("[Epoch %d] time cost: %f" %
                              (epoch, time.time() - epoch_start_time))
             # Epoch metrics log on train data
             if self.eval_metrics:
                 epoch_train_log = "[Epoch %d] training: " % epoch
                 self.logger.info(
                     epoch_train_log + metric_text(self.eval_metrics, "%s=%f  "))
             # Epoch metrics log on validation data if any. Without a
             # validation metric there is nothing to measure, so the pass is
             # skipped instead of failing on ``None.reset()``.
             if val_data_iter and self.val_metrics:
                 self.val_metrics.reset()
                 val_data_iter.reset()
                 for batch in val_data_iter:
                     data = load_on_cpu(
                         batch.data[0].astype("float32", copy=False))
                     label = load_on_cpu(
                         batch.label[0].astype("float32", copy=False))
                     outputs = [self.model(X) for X in data]
                     self.val_metrics.update(label, outputs)
                 epoch_val_log = "[Epoch %d] validation: " % epoch
                 self.logger.info(
                     epoch_val_log + metric_text(self.val_metrics, "%s=%f  "))
             # TODO: save checkpoints
         if self.eval_metrics:
             names, accs = self.eval_metrics.get()
             stats.update(zip(to_list(names), to_list(accs)))
     else:  # Symbolic API
         # TODO: seems no history (i.e. validation accuracy) returned by fit?
         if "init" not in self.config:
             from mxnet.initializer import Uniform
             self.config["init"] = Uniform(
                 0.01)  # This is the default value for MXNet
         if self.eval_metrics is None:
             self.eval_metrics = 'acc'
         self.model.fit(
             train_data=train_data_iter,
             num_epoch=epochs,
             initializer=self.config["init"],
             kvstore=self.kv,
             optimizer=self.config["optimizer"],
             optimizer_params=self.config["optimizer_params"],
             eval_data=val_data_iter,
             eval_metric=self.eval_metrics,
             validation_metric=self.val_metrics,
             batch_end_callback=mx.callback.Speedometer(
                 batch_size, self.config["log_interval"]),
             epoch_end_callback=None if "model" not in self.config else
             mx.callback.do_checkpoint(self.config["model"]))
     stats["epoch_time"] = time.time() - start_time
     # Drop this frame's references to the partitions before returning.
     if isinstance(train_data, RayPartition):
         del train_data
     if validation_data and isinstance(validation_data, RayPartition):
         del validation_data
     return stats
コード例 #3
0
    def train(self,
              train_data,
              epochs=1,
              batch_size=32,
              validation_data=None,
              train_resize_batch_num=None):
        """Train the model and update the model parameters.

        Nothing is trained unless ``self.is_worker`` is truthy. When
        ``self.trainer`` is set, the imperative loop below is used; otherwise
        training is delegated to ``self.model.fit`` (symbolic API).

        Args:
            train_data: Data-creator callable invoked as
                ``train_data(config, self.kv)``; its result is iterated for
                batches exposing ``.data`` and ``.label``.
            epochs: Number of passes over the training data.
            batch_size: Used as the ``config["batch_size"]`` default, as the
                argument to ``self.trainer.step`` and for speed logging.
            validation_data: Optional data creator for validation, called
                like ``train_data``.
            train_resize_batch_num: If given, passed to the data creators as
                ``config["train_resize_batch_num"]``.

        Returns:
            list: A single-element list holding the stats dict. The dict is
            empty for non-workers; for workers it holds ``"epoch_time"``
            (wall-clock seconds spent training) and, on the imperative path
            with ``self.eval_metrics`` set, one entry per training metric.
        """
        stats = dict()
        if not self.is_worker:
            return [stats]

        # Per-call settings go into a copy so self.config stays untouched.
        config = self.config.copy()
        if "batch_size" not in config:
            config["batch_size"] = batch_size
        if train_resize_batch_num is not None:
            config["train_resize_batch_num"] = train_resize_batch_num
        train_data_iter = train_data(config, self.kv)
        val_data_iter = validation_data(
            config, self.kv) if validation_data else None

        def cpu_context(target_data):
            # Recursively move an array, or a (nested) list of arrays, to CPU.
            if isinstance(target_data, list):
                return [cpu_context(d) for d in target_data]
            return target_data.as_in_context(mx.cpu())

        def as_list(value):
            # Model inputs, labels and outputs are handled as lists below.
            return value if isinstance(value, list) else [value]

        def metric_text(metric, template):
            # Render every (name, value) pair of `metric` with `template`.
            names, accs = metric.get()
            pairs = zip(to_list(names), to_list(accs))
            return "".join(template % pair for pair in pairs)

        start_time = time.time()
        if self.trainer:  # Imperative API
            for epoch in range(epochs):
                # DataLoader doesn't need to be reset.
                if isinstance(train_data_iter, mx.io.DataIter):
                    train_data_iter.reset()
                if self.eval_metrics:
                    # Reset once per epoch; the per-batch updates below
                    # accumulate until the next reset.
                    self.eval_metrics.reset()
                batch_start_time = time.time()
                epoch_start_time = time.time()
                for i, batch in enumerate(train_data_iter):
                    # Imported only when a batch is actually processed.
                    from mxnet import autograd as ag
                    data = as_list(cpu_context(batch.data))
                    label = as_list(cpu_context(batch.label))
                    with ag.record():
                        output = as_list(self.model(*data))  # forward
                        Ls = self.loss(*output, *label)
                        ag.backward(Ls)
                    self.trainer.step(batch_size)
                    if self.eval_metrics:
                        self.eval_metrics.update(label, output)
                    if not (i + 1) % self.config["log_interval"]:
                        # This would be logged on driver for each worker process.
                        iteration_log = \
                            "Epoch[%d] Batch[%d]  Speed: %f samples/sec  %s=%f" \
                            % (epoch, i,
                               batch_size / (time.time() - batch_start_time),
                               "loss", Ls.asnumpy().mean())
                        if self.eval_metrics:
                            iteration_log += metric_text(
                                self.eval_metrics, "  %s=%f")
                        self.logger.info(iteration_log)
                    batch_start_time = time.time()
                # Epoch time log.
                self.logger.info("[Epoch %d] time cost: %f" %
                                 (epoch, time.time() - epoch_start_time))
                # Epoch metrics log on train data.
                if self.eval_metrics:
                    epoch_train_log = "[Epoch %d] training: " % epoch
                    self.logger.info(
                        epoch_train_log
                        + metric_text(self.eval_metrics, "%s=%f  "))
                # Epoch metrics log on validation data if any. Without a
                # validation metric there is nothing to measure, so the pass
                # is skipped instead of failing on ``None.reset()``.
                if val_data_iter and self.val_metrics:
                    if isinstance(val_data_iter, mx.io.DataIter):
                        val_data_iter.reset()
                    self.val_metrics.reset()
                    for batch in val_data_iter:
                        data = as_list(cpu_context(batch.data))
                        label = as_list(cpu_context(batch.label))
                        output = as_list(self.model(*data))
                        self.val_metrics.update(label, output)
                    epoch_val_log = "[Epoch %d] validation: " % epoch
                    self.logger.info(
                        epoch_val_log
                        + metric_text(self.val_metrics, "%s=%f  "))
                # TODO: save checkpoints
            if self.eval_metrics:
                names, accs = self.eval_metrics.get()
                stats.update(zip(to_list(names), to_list(accs)))
        else:  # Symbolic API
            # TODO: seems no history (i.e. validation accuracy) returned by fit?
            if "init" not in self.config:
                from mxnet.initializer import Uniform
                self.config["init"] = Uniform(
                    0.01)  # This is the default value for MXNet.
            if self.eval_metrics is None:
                self.eval_metrics = 'acc'  # This is the default value for MXNet.
            self.model.fit(
                train_data=train_data_iter,
                num_epoch=epochs,
                initializer=self.config["init"],
                kvstore=self.kv,
                optimizer=self.config["optimizer"],
                optimizer_params=self.config["optimizer_params"],
                eval_data=val_data_iter,
                eval_metric=self.eval_metrics,
                validation_metric=self.val_metrics,
                batch_end_callback=mx.callback.Speedometer(
                    batch_size, self.config["log_interval"]),
                epoch_end_callback=None if "model" not in self.config else
                mx.callback.do_checkpoint(self.config["model"]))
        stats["epoch_time"] = time.time() - start_time
        return [stats]