def init_optimizer(self, kvstore='local', optimizer='sgd',
                   optimizer_params=(('learning_rate', 0.01),), force_init=False):
    """Install and initialize optimizers.

    Parameters
    ----------
    kvstore : str or KVStore
        Default `'local'`.
    optimizer : str or Optimizer
        Default `'sgd'`.
    optimizer_params : dict
        Default `(('learning_rate', 0.01),)`. The default value is not a
        dictionary, just to avoid pylint warning of dangerous default values.
    force_init : bool
        Default `False`, indicating whether we should force re-initializing
        the optimizer in the case an optimizer is already installed.
    """
    assert self.binded and self.params_initialized

    if self.optimizer_initialized and not force_init:
        self.logger.warning('optimizer already initialized, ignoring...')
        return

    (kvstore, update_on_kvstore) = \
        _create_kvstore(kvstore, len(self._context), self._arg_params)

    batch_size = self._exec_group.batch_size
    if kvstore and 'dist' in kvstore.type and '_sync' in kvstore.type:
        batch_size *= kvstore.num_workers
    rescale_grad = 1.0 / batch_size

    if isinstance(optimizer, str):
        idx2name = {}
        if update_on_kvstore:
            idx2name.update(enumerate(self._exec_group.param_names))
        else:
            for k in range(len(self._context)):
                idx2name.update({i*len(self._context)+k: n
                                 for i, n in enumerate(self._exec_group.param_names)})
        optimizer_params = dict(optimizer_params)
        if 'rescale_grad' not in optimizer_params:
            optimizer_params['rescale_grad'] = rescale_grad
        optimizer = opt.create(optimizer,
                               sym=self.symbol, param_idx2name=idx2name,
                               **optimizer_params)
    else:
        assert isinstance(optimizer, opt.Optimizer)
        if optimizer.rescale_grad != rescale_grad:  # pylint: disable=no-member
            warnings.warn(
                "Optimizer created manually outside Module but rescale_grad " +
                "is not normalized to 1.0/batch_size/num_workers (%s vs. %s). " % (
                    optimizer.rescale_grad, rescale_grad) +
                "Is this intended?", stacklevel=2)

    self._optimizer = optimizer
    self._kvstore = kvstore
    self._update_on_kvstore = update_on_kvstore
    self._updater = None

    if kvstore:
        # copy initialized local parameters to kvstore
        _initialize_kvstore(kvstore=kvstore,
                            param_arrays=self._exec_group.param_arrays,
                            arg_params=self._arg_params,
                            param_names=self._param_names,
                            update_on_kvstore=update_on_kvstore)
    if update_on_kvstore:
        kvstore.set_optimizer(self._optimizer)
    else:
        self._updater = opt.get_updater(optimizer)

    self.optimizer_initialized = True

    if self._preload_opt_states is not None:
        self.load_optimizer_states(self._preload_opt_states)
        self._preload_opt_states = None
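# Hedged usage sketch for `init_optimizer` (illustrative only, not part of
# this module): the demo below assumes the standard MXNet Module API and
# builds a tiny throwaway network; the symbol, shapes, and names are
# assumptions chosen for the demo. `rescale_grad` defaults to 1.0/batch_size
# (times num_workers for dist-sync kvstores) when not passed explicitly.
def _demo_init_optimizer():
    """Minimal sketch: bind a Module, init params, then install SGD."""
    import mxnet as mx
    data = mx.sym.Variable('data')
    fc = mx.sym.FullyConnected(data=data, num_hidden=10, name='fc')
    net = mx.sym.SoftmaxOutput(data=fc, name='softmax')
    mod = mx.mod.Module(symbol=net, context=mx.cpu())
    mod.bind(data_shapes=[('data', (32, 100))],
             label_shapes=[('softmax_label', (32,))])
    mod.init_params()
    # `optimizer_params` is a tuple of pairs; it is converted to a dict
    # internally before being forwarded to `opt.create`.
    mod.init_optimizer(kvstore='local', optimizer='sgd',
                       optimizer_params=(('learning_rate', 0.01),))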
def _train_rnn(symbol, ctx, marks,
               arg_names, param_names, aux_names,
               arg_params, aux_params,
               begin_epoch, end_epoch, epoch_size, optimizer, kvstore,
               update_on_kvstore, train_data, e_marks=None,
               eval_data=None, eval_metric=None,
               epoch_end_callback=None, batch_end_callback=None,
               time_step_callback=None,
               logger=None, work_load_list=None, monitor=None,
               eval_batch_end_callback=None, sym_gen=None,
               mutable_data_shape=False, max_data_shape=None):
    """`marks` should be a list of length #SeriesLength, annotating whether
    each image has a label (1) or not (0)."""
    # TODO: marks does not work if the SAX labels differ within one batch
    if logger is None:
        logger = logging
    executor_manager = DataParallelExecutorManager(symbol=symbol,
                                                   sym_gen=sym_gen,
                                                   ctx=ctx,
                                                   train_data=train_data,
                                                   param_names=param_names,
                                                   arg_names=arg_names,
                                                   aux_names=aux_names,
                                                   work_load_list=work_load_list,
                                                   logger=logger,
                                                   mutable_data_shape=mutable_data_shape,
                                                   max_data_shape=max_data_shape)
    if monitor:
        executor_manager.install_monitor(monitor)

    executor_manager.set_params(arg_params, aux_params)

    # if not update_on_kvstore:
    updater = get_updater(optimizer)

    if kvstore:
        _initialize_kvstore(kvstore=kvstore,
                            param_arrays=executor_manager.param_arrays,
                            arg_params=arg_params,
                            param_names=executor_manager.param_names,
                            update_on_kvstore=update_on_kvstore)

    if update_on_kvstore:
        kvstore.set_optimizer(optimizer)

    # Now start training
    train_data.reset()

    for epoch in range(begin_epoch, end_epoch):
        # Training phase
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        # Iterate over training data.

        # Into Epoch #########################
        # record acc
        acc_hist = []
        logger.info('Starting New Epoch...')
        while True:
            do_reset = True
            # iterate over batches
            for data_batch_zoo in train_data:
                assert isinstance(data_batch_zoo, list), "Iter Error"
                if monitor is not None:
                    monitor.tic()

                # Start to iterate over time steps
                if isinstance(marks[nbatch], list):
                    M = marks[nbatch]
                else:
                    M = marks

                executor_manager, eval_metric, acc_hist = _run_sax(
                    data_batch_zoo, M, executor_manager, eval_metric,
                    updater, ctx, kvstore, acc_hist,
                    monitor=monitor,
                    logger=logger,
                    update_on_kvstore=update_on_kvstore,
                    is_train=True,
                    callback=time_step_callback)

                nbatch += 1
                # batch callback (for print purposes)
                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    if isinstance(batch_end_callback, list):
                        for call in batch_end_callback:
                            call(batch_end_params)
                    else:
                        batch_end_callback(batch_end_params)

                # this epoch may be done early
                if epoch_size is not None and nbatch >= epoch_size:
                    do_reset = False
                    break
            # end of batch iteration

            if do_reset is True:
                logger.debug('Epoch[%d] Resetting Data Iterator', epoch)
                train_data.reset()
                logger.debug('Epoch[%d] Resetting Eval Metric', epoch)
                eval_metric.reset()

            # this epoch is done
            if epoch_size is None or nbatch >= epoch_size:
                break

        toc = time.time()
        logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        if epoch_end_callback or epoch + 1 == end_epoch:
            executor_manager.copy_to(arg_params, aux_params)

        if epoch_end_callback is not None:
            if isinstance(epoch_end_callback, list):
                for call in epoch_end_callback:
                    call(epoch, symbol, arg_params, aux_params, acc_hist)
            else:
                epoch_end_callback(epoch, symbol, arg_params, aux_params,
                                   acc_hist)

        # evaluation
        if eval_data:
            assert e_marks is not None, 'e_marks cannot be None'
            eval_metric.reset()
            eval_data.reset()
            for b, eval_zoo in enumerate(eval_data):
                if isinstance(e_marks[b], list):
                    M = e_marks[b]
                else:
                    M = e_marks

                executor_manager, eval_metric, acc_hist = _run_sax(
                    eval_zoo, M, executor_manager, eval_metric,
                    updater, ctx, kvstore, acc_hist,
                    update_on_kvstore=update_on_kvstore,
                    is_train=False)
                # executor_manager.load_data_batch(eval_batch)
                # executor_manager.forward(is_train=False)
                # executor_manager.update_metric(eval_metric, eval_batch.label)

                if eval_batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=b,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    if isinstance(eval_batch_end_callback, list):
                        for call in eval_batch_end_callback:
                            call(batch_end_params)
                    else:
                        eval_batch_end_callback(batch_end_params)
            name_value = eval_metric.get_name_value()
            for name, value in name_value:
                logger.info('Epoch[%d] Validation-%s=%f', epoch, name, value)
    # end of all epochs
    return
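# Hedged sketch of the `marks` / `e_marks` arguments consumed by `_train_rnn`
# above (an illustration inferred from the docstring, not code from this
# repository): one 0/1 flag per time step, either as a single flat list reused
# for every batch, or nested with one list per batch.
def _demo_marks(series_length=5, num_batches=3):
    """Build both accepted shapes of the marks annotation."""
    # Flat form: a single mark list, applied to every batch.
    flat_marks = [1] * series_length           # every frame carries a label
    # Nested form: one mark list per batch; here frame 2 of batch 0
    # is marked as unlabelled.
    nested_marks = [[1] * series_length for _ in range(num_batches)]
    nested_marks[0][2] = 0
    return flat_marks, nested_marks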
def run(self):
    data = self.model._init_iter(self.data, None, is_train=True)

    arg_names, param_names, aux_names = \
        self.model._init_params(dict(data.provide_data + data.provide_label))

    # create kvstore
    (kvstore, update_on_kvstore) = _create_kvstore(
        self.kv, len(self.ctxs), self.model.arg_params)

    self.executor_manager = DataParallelExecutorManager(symbol=self.sym,
                                                        ctx=self.ctxs,
                                                        train_data=self.data,
                                                        param_names=param_names,
                                                        arg_names=arg_names,
                                                        aux_names=aux_names,
                                                        logger=logger)

    self.executor_manager.set_params(self.model.arg_params,
                                     self.model.aux_params)

    if not update_on_kvstore:
        updater = get_updater(self.optimizer)

    if kvstore:
        _initialize_kvstore(kvstore=kvstore,
                            param_arrays=self.executor_manager.param_arrays,
                            arg_params=self.model.arg_params,
                            param_names=self.executor_manager.param_names,
                            update_on_kvstore=update_on_kvstore)

    if update_on_kvstore:
        kvstore.set_optimizer(self.optimizer)

    for e in self.before_training_extensions:
        e(self)

    while True:
        self.metric.reset()
        nbatch = 0
        self.data.reset()

        for data_batch in self.data:
            self.executor_manager.load_data_batch(data_batch)

            self.executor_manager.forward(is_train=True)
            self.executor_manager.backward()

            if update_on_kvstore:
                _update_params_on_kvstore(self.executor_manager.param_arrays,
                                          self.executor_manager.grad_arrays,
                                          kvstore)
            else:
                _update_params(self.executor_manager.param_arrays,
                               self.executor_manager.grad_arrays,
                               updater=updater,
                               num_device=len(self.ctxs),
                               kvstore=kvstore)

            # evaluate at end, so out_cpu_array can do a lazy copy
            self.metric.update(data_batch.label,
                               self.executor_manager.cpu_output_arrays)

            self.status['iterations'] += 1
            self.status['epoch_iterations'] += 1
            self.log[self.status['iterations']] = dict(
                iterations=self.status['iterations'])
            self.current_log = self.log[self.status['iterations']]

            for e in self.batch_extensions:
                e(self)

            nbatch += 1

        self.status['epochs'] += 1
        self.status['epoch_iterations'] = 0

        for e in self.epoch_extensions:
            e(self)
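# Hedged sketch of an extension compatible with the `run` loop above. The
# only contract visible in this file is that extensions are callables invoked
# as `e(self)`; reading `status`, `metric`, and `current_log` below is an
# assumption based on how the loop populates them, and the logging interval
# is arbitrary.
def _demo_log_extension(trainer):
    """Every 50 iterations, copy current metric values into the log entry."""
    it = trainer.status['iterations']
    if it % 50 == 0:
        for name, value in trainer.metric.get_name_value():
            trainer.current_log['train-%s' % name] = value

# Usage (illustrative): registered alongside the other batch extensions,
# e.g. `trainer.batch_extensions.append(_demo_log_extension)`.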