def train(self, nb_epoch=1):
    """Train the model and update the model parameters.

    Runs on every Ray process; only worker processes (``self.is_worker``)
    actually train. Two code paths:

    * Imperative (gluon) path when ``self.trainer`` is set: manual
      epoch/batch loop over ``self.train_data`` with autograd.
    * Symbolic path otherwise: delegates to ``self.model.fit``.

    :param nb_epoch: Number of epochs to train for.
    :return: dict of training stats; on workers it contains the final
             training metric values plus ``"epoch_time"`` (total train
             seconds), otherwise it is empty.
    """
    stats = dict()
    if self.is_worker:
        start_time = time.time()
        if self.trainer:
            # Imperative API
            for epoch in range(nb_epoch):
                self.train_data.reset()
                if self.metrics:
                    self.metrics.reset()  # metrics will accumulate for one batch
                batch_start_time = time.time()
                epoch_start_time = time.time()
                for i, batch in enumerate(self.train_data):
                    # NOTE(review): single data/label field per batch is assumed
                    # (only batch.data[0] / batch.label[0] are used) — confirm
                    # against the data iterator used by callers.
                    data = gluon.utils.split_and_load(
                        batch.data[0].astype("float32"),
                        ctx_list=[mx.cpu()], batch_axis=0)
                    label = gluon.utils.split_and_load(
                        batch.label[0].astype("float32"),
                        ctx_list=[mx.cpu()], batch_axis=0)
                    outputs = []
                    Ls = []
                    from mxnet import autograd as ag
                    with ag.record():
                        for x, y in zip(data, label):
                            z = self.model(x)  # forward
                            L = self.loss(z, y)
                            # store the loss and do backward on a batch for better speed
                            Ls.append(L)
                            outputs.append(z)
                        ag.backward(Ls)
                    # Scale gradients by the actual number of samples in this batch.
                    self.trainer.step(batch.data[0].shape[0])
                    if self.metrics:
                        self.metrics.update(label, outputs)
                    # Log every config["log_interval"] batches.
                    if not (i + 1) % self.config["log_interval"]:
                        # This would be logged on driver for each worker process.
                        # Speed uses config["batch_size"] and the wall time since the
                        # last logged batch; loss shown is the first slice's mean.
                        iteration_log = \
                            "Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f" \
                            % (epoch, i,
                               self.config["batch_size"] / (time.time() - batch_start_time),
                               "loss", Ls[0].asnumpy().mean())
                        if self.metrics:
                            names, accs = self.metrics.get()
                            names, accs = to_list(names), to_list(accs)
                            for name, acc in zip(names, accs):
                                iteration_log += " %s=%f" % (name, acc)
                        self.logger.info(iteration_log)
                    batch_start_time = time.time()
                # Epoch time log
                self.logger.info("[Epoch %d] time cost: %f" %
                                 (epoch, time.time() - epoch_start_time))
                # Epoch metrics log on train data
                if self.metrics:
                    epoch_train_log = "[Epoch %d] training: " % epoch
                    names, accs = self.metrics.get()
                    names, accs = to_list(names), to_list(accs)
                    for name, acc in zip(names, accs):
                        epoch_train_log += "%s=%f " % (name, acc)
                    self.logger.info(epoch_train_log)
                # Epoch metrics log on validation data if any:
                if self.val_data:
                    # NOTE(review): the same self.metrics object is reused for
                    # validation, so the training metric values logged above are
                    # discarded here each epoch.
                    self.metrics.reset()
                    self.val_data.reset()
                    for batch in self.val_data:
                        data = gluon.utils.split_and_load(
                            batch.data[0].astype("float32", copy=False),
                            ctx_list=[mx.cpu()], batch_axis=0)
                        label = gluon.utils.split_and_load(
                            batch.label[0].astype("float32", copy=False),
                            ctx_list=[mx.cpu()], batch_axis=0)
                        outputs = [self.model(X) for X in data]
                        self.metrics.update(label, outputs)
                    epoch_val_log = "[Epoch %d] validation: " % epoch
                    names, accs = self.metrics.get()
                    names, accs = to_list(names), to_list(accs)
                    for name, acc in zip(names, accs):
                        epoch_val_log += "%s=%f " % (name, acc)
                    self.logger.info(epoch_val_log)
                # TODO: save checkpoints
            if self.metrics:
                # Final metric values (last epoch; validation values if
                # val_data was given, since the metrics object was reused).
                names, accs = self.metrics.get()
                names, accs = to_list(names), to_list(accs)
                for name, acc in zip(names, accs):
                    stats[name] = acc
        else:
            # Symbolic API
            # TODO: seems no history (i.e. validation accuracy) returned by fit?
            if "init" not in self.config:
                from mxnet.initializer import Uniform
                self.config["init"] = Uniform(0.01)  # This is the default value for MXNet
            self.model.fit(train_data=self.train_data,
                           num_epoch=nb_epoch,
                           initializer=self.config["init"],
                           kvstore=self.kv,
                           optimizer=self.config["optimizer"],
                           optimizer_params=self.config["optimizer_params"],
                           eval_data=self.val_data,
                           # TODO: eval and validation metrics could be different
                           eval_metric=self.metrics,
                           validation_metric=self.metrics,
                           batch_end_callback=mx.callback.Speedometer(
                               self.config["batch_size"], self.config["log_interval"]),
                           epoch_end_callback=None if "model" not in self.config
                           else mx.callback.do_checkpoint(self.config["model"]))
        epoch_time = time.time() - start_time
        stats["epoch_time"] = epoch_time
    return stats
def train(self, train_data, epochs=1, batch_size=32, validation_data=None,
          train_resize_batch_num=None):
    """Train the model and update the model parameters.

    Runs on every Ray process; only worker processes (``self.is_worker``)
    actually train. ``train_data``/``validation_data`` may each be either a
    ``RayPartition`` (converted to ``mx.io.NDArrayIter``) or a data-creator
    function ``f(config, kvstore) -> DataIter/DataLoader``.

    :param train_data: RayPartition or data-creator function for training data.
    :param epochs: Number of epochs to train for.
    :param batch_size: Batch size used when building iterators from a
                       RayPartition and for throughput logging.
    :param validation_data: Optional RayPartition or data-creator function.
    :param train_resize_batch_num: If not None, wrap the train iterator in
        ``mx.io.ResizeIter`` so each epoch yields this many batches
        (used to keep workers in lockstep).
    :return: dict of training stats; on workers it contains the final
             training metric values plus ``"epoch_time"``, otherwise empty.
    """
    stats = dict()
    if self.is_worker:
        from zoo.orca.data.shard import RayPartition
        if isinstance(train_data, RayPartition):
            from zoo.orca.data.utils import ray_partition_get_data_label
            data, label = ray_partition_get_data_label(
                train_data.get_data(), allow_tuple=False, allow_list=False)
            train_data_iter = mx.io.NDArrayIter(data=data, label=label,
                                                batch_size=batch_size,
                                                shuffle=True)
            if train_resize_batch_num is not None:
                train_data_iter = mx.io.ResizeIter(train_data_iter,
                                                   train_resize_batch_num)
            if validation_data:
                data_val, label_val = ray_partition_get_data_label(
                    validation_data.get_data(), allow_tuple=False, allow_list=False)
                # NOTE(review): shuffle=True on the validation iterator looks
                # unnecessary (metrics are order-independent, but it costs a
                # shuffle per epoch) — confirm whether it is intentional.
                val_data_iter = mx.io.NDArrayIter(data=data_val, label=label_val,
                                                  batch_size=batch_size,
                                                  shuffle=True)
            else:
                val_data_iter = None
        else:  # data_creator functions; should return Iter or DataLoader
            # NOTE(review): this mutates self.config in place (no copy) when
            # "batch_size" is absent — later calls will see the injected value.
            config = self.config
            if "batch_size" not in config:
                config["batch_size"] = batch_size
            train_data_iter = train_data(config, self.kv)
            val_data_iter = validation_data(
                config, self.kv) if validation_data else None
        start_time = time.time()
        if self.trainer:
            # Imperative API
            for epoch in range(epochs):
                train_data_iter.reset()
                if self.eval_metrics:
                    self.eval_metrics.reset()  # metrics will accumulate for one batch
                batch_start_time = time.time()
                epoch_start_time = time.time()
                for i, batch in enumerate(train_data_iter):
                    data = gluon.utils.split_and_load(
                        batch.data[0].astype("float32"),
                        ctx_list=[mx.cpu()], batch_axis=0)
                    label = gluon.utils.split_and_load(
                        batch.label[0].astype("float32"),
                        ctx_list=[mx.cpu()], batch_axis=0)
                    outputs = []
                    Ls = []
                    from mxnet import autograd as ag
                    with ag.record():
                        for x, y in zip(data, label):
                            z = self.model(x)  # forward
                            L = self.loss(z, y)
                            # store the loss and do backward on a batch for better speed
                            Ls.append(L)
                            outputs.append(z)
                        ag.backward(Ls)
                    # Scale gradients by the actual number of samples in this batch.
                    self.trainer.step(batch.data[0].shape[0])
                    if self.eval_metrics:
                        self.eval_metrics.update(label, outputs)
                    # Log every config["log_interval"] batches.
                    if not (i + 1) % self.config["log_interval"]:
                        # This would be logged on driver for each worker process.
                        iteration_log = \
                            "Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f" \
                            % (epoch, i,
                               batch_size / (time.time() - batch_start_time),
                               "loss", Ls[0].asnumpy().mean())
                        if self.eval_metrics:
                            names, accs = self.eval_metrics.get()
                            names, accs = to_list(names), to_list(accs)
                            for name, acc in zip(names, accs):
                                iteration_log += " %s=%f" % (name, acc)
                        self.logger.info(iteration_log)
                    batch_start_time = time.time()
                # Epoch time log
                self.logger.info("[Epoch %d] time cost: %f" %
                                 (epoch, time.time() - epoch_start_time))
                # Epoch metrics log on train data
                if self.eval_metrics:
                    epoch_train_log = "[Epoch %d] training: " % epoch
                    names, accs = self.eval_metrics.get()
                    names, accs = to_list(names), to_list(accs)
                    for name, acc in zip(names, accs):
                        epoch_train_log += "%s=%f " % (name, acc)
                    self.logger.info(epoch_train_log)
                # Epoch metrics log on validation data if any:
                if val_data_iter:
                    # Separate val_metrics object, so training metric state
                    # is preserved (unlike the earlier single-metrics version).
                    self.val_metrics.reset()
                    val_data_iter.reset()
                    for batch in val_data_iter:
                        data = gluon.utils.split_and_load(
                            batch.data[0].astype("float32", copy=False),
                            ctx_list=[mx.cpu()], batch_axis=0)
                        label = gluon.utils.split_and_load(
                            batch.label[0].astype("float32", copy=False),
                            ctx_list=[mx.cpu()], batch_axis=0)
                        outputs = [self.model(X) for X in data]
                        self.val_metrics.update(label, outputs)
                    epoch_val_log = "[Epoch %d] validation: " % epoch
                    names, accs = self.val_metrics.get()
                    names, accs = to_list(names), to_list(accs)
                    for name, acc in zip(names, accs):
                        epoch_val_log += "%s=%f " % (name, acc)
                    self.logger.info(epoch_val_log)
                # TODO: save checkpoints
            if self.eval_metrics:
                # Final training metric values from the last epoch.
                names, accs = self.eval_metrics.get()
                names, accs = to_list(names), to_list(accs)
                for name, acc in zip(names, accs):
                    stats[name] = acc
        else:  # Symbolic API
            # TODO: seems no history (i.e. validation accuracy) returned by fit?
            if "init" not in self.config:
                from mxnet.initializer import Uniform
                self.config["init"] = Uniform(
                    0.01)  # This is the default value for MXNet
            if self.eval_metrics is None:
                self.eval_metrics = 'acc'
            self.model.fit(
                train_data=train_data_iter,
                num_epoch=epochs,
                initializer=self.config["init"],
                kvstore=self.kv,
                optimizer=self.config["optimizer"],
                optimizer_params=self.config["optimizer_params"],
                eval_data=val_data_iter,
                eval_metric=self.eval_metrics,
                validation_metric=self.val_metrics,
                batch_end_callback=mx.callback.Speedometer(
                    batch_size, self.config["log_interval"]),
                epoch_end_callback=None if "model" not in self.config
                else mx.callback.do_checkpoint(self.config["model"]))
        epoch_time = time.time() - start_time
        stats["epoch_time"] = epoch_time
        # Drop references to the (potentially large) in-memory partitions.
        if isinstance(train_data, RayPartition):
            del train_data
        if validation_data and isinstance(validation_data, RayPartition):
            del validation_data
    return stats
def train(self, train_data, epochs=1, batch_size=32, validation_data=None,
          train_resize_batch_num=None):
    """Train the model and update the model parameters.

    Runs on every Ray process; only worker processes (``self.is_worker``)
    actually train. Unlike the earlier revisions, ``train_data`` and
    ``validation_data`` must be data-creator functions
    ``f(config, kvstore) -> DataIter/DataLoader``; the model may take and
    return multiple arrays (data/label/output lists are supported).

    :param train_data: Data-creator function for training data.
    :param epochs: Number of epochs to train for.
    :param batch_size: Default batch size (injected into the config copy if
                       absent) and the per-step normalization for the trainer.
    :param validation_data: Optional data-creator function.
    :param train_resize_batch_num: If not None, forwarded to the creator via
        ``config["train_resize_batch_num"]``.
    :return: single-element list wrapping the stats dict (final training
             metric values plus ``"epoch_time"`` on workers).
    """
    stats = dict()
    if self.is_worker:
        # Copy so the injected keys don't leak into self.config.
        config = self.config.copy()
        if "batch_size" not in config:
            config["batch_size"] = batch_size
        if train_resize_batch_num is not None:
            config["train_resize_batch_num"] = train_resize_batch_num
        train_data_iter = train_data(config, self.kv)
        val_data_iter = validation_data(
            config, self.kv) if validation_data else None
        start_time = time.time()
        if self.trainer:
            # Imperative API
            def cpu_context(target_data):
                # Recursively move an NDArray (or list of them) onto the CPU
                # context.
                if isinstance(target_data, list):
                    return [cpu_context(d) for d in target_data]
                else:
                    return target_data.as_in_context(mx.cpu())

            for epoch in range(epochs):
                # DataLoader doesn't need to be reset.
                if isinstance(train_data_iter, mx.io.DataIter):
                    train_data_iter.reset()
                if self.eval_metrics:
                    self.eval_metrics.reset()  # metrics will accumulate for one batch.
                batch_start_time = time.time()
                epoch_start_time = time.time()
                for i, batch in enumerate(train_data_iter):
                    data = cpu_context(batch.data)
                    label = cpu_context(batch.label)
                    # Normalize to lists so multi-input models and single-input
                    # models share one code path.
                    if not isinstance(data, list):
                        data = [data]
                    if not isinstance(label, list):
                        label = [label]
                    from mxnet import autograd as ag
                    with ag.record():
                        output = self.model(*data)  # forward
                        if not isinstance(output, list):
                            output = [output]
                        # Loss takes all outputs followed by all labels.
                        Ls = self.loss(*output, *label)
                        ag.backward(Ls)
                    # NOTE(review): normalizes by the nominal batch_size, not
                    # the actual size of this batch — a trailing partial batch
                    # would be under-scaled. Confirm iterators always yield
                    # full batches.
                    self.trainer.step(batch_size)
                    if self.eval_metrics:
                        self.eval_metrics.update(label, output)
                    # Log every config["log_interval"] batches.
                    if not (i + 1) % self.config["log_interval"]:
                        # This would be logged on driver for each worker process.
                        iteration_log = \
                            "Epoch[%d] Batch[%d] Speed: %f samples/sec %s=%f" \
                            % (epoch, i,
                               batch_size / (time.time() - batch_start_time),
                               "loss", Ls.asnumpy().mean())
                        if self.eval_metrics:
                            names, accs = self.eval_metrics.get()
                            names, accs = to_list(names), to_list(accs)
                            for name, acc in zip(names, accs):
                                iteration_log += " %s=%f" % (name, acc)
                        self.logger.info(iteration_log)
                    batch_start_time = time.time()
                # Epoch time log.
                self.logger.info("[Epoch %d] time cost: %f" %
                                 (epoch, time.time() - epoch_start_time))
                # Epoch metrics log on train data.
                if self.eval_metrics:
                    epoch_train_log = "[Epoch %d] training: " % epoch
                    names, accs = self.eval_metrics.get()
                    names, accs = to_list(names), to_list(accs)
                    for name, acc in zip(names, accs):
                        epoch_train_log += "%s=%f " % (name, acc)
                    self.logger.info(epoch_train_log)
                # Epoch metrics log on validation data if any.
                if val_data_iter:
                    if isinstance(val_data_iter, mx.io.DataIter):
                        val_data_iter.reset()
                    self.val_metrics.reset()
                    for batch in val_data_iter:
                        data = cpu_context(batch.data)
                        label = cpu_context(batch.label)
                        if not isinstance(data, list):
                            data = [data]
                        if not isinstance(label, list):
                            label = [label]
                        output = self.model(*data)
                        if not isinstance(output, list):
                            output = [output]
                        self.val_metrics.update(label, output)
                    epoch_val_log = "[Epoch %d] validation: " % epoch
                    names, accs = self.val_metrics.get()
                    names, accs = to_list(names), to_list(accs)
                    for name, acc in zip(names, accs):
                        epoch_val_log += "%s=%f " % (name, acc)
                    self.logger.info(epoch_val_log)
                # TODO: save checkpoints
            if self.eval_metrics:
                # Final training metric values from the last epoch.
                names, accs = self.eval_metrics.get()
                names, accs = to_list(names), to_list(accs)
                for name, acc in zip(names, accs):
                    stats[name] = acc
        else:  # Symbolic API
            # TODO: seems no history (i.e. validation accuracy) returned by fit?
            if "init" not in self.config:
                from mxnet.initializer import Uniform
                self.config["init"] = Uniform(
                    0.01)  # This is the default value for MXNet.
            if self.eval_metrics is None:
                self.eval_metrics = 'acc'  # This is the default value for MXNet.
            self.model.fit(
                train_data=train_data_iter,
                num_epoch=epochs,
                initializer=self.config["init"],
                kvstore=self.kv,
                optimizer=self.config["optimizer"],
                optimizer_params=self.config["optimizer_params"],
                eval_data=val_data_iter,
                eval_metric=self.eval_metrics,
                validation_metric=self.val_metrics,
                batch_end_callback=mx.callback.Speedometer(
                    batch_size, self.config["log_interval"]),
                epoch_end_callback=None if "model" not in self.config
                else mx.callback.do_checkpoint(self.config["model"]))
        epoch_time = time.time() - start_time
        stats["epoch_time"] = epoch_time
    return [stats]