Example #1
    def eval_with_image_folder(self, eval_params):
        image_list = eval_params.image_list
        data_process_func = eval_params.data_process_func
        label_list = eval_params.label_list
        log_interval = eval_params.log_interval
        metrics = eval_params.metrics
        show = eval_params.show
        show_func = eval_params.show_func
        save = eval_params.save
        save_func = eval_params.save_func

        callback = mx.callback.Speedometer(1, log_interval)
        num_samples = len(image_list)
        for idx, image_file in enumerate(image_list):
            if label_list is not None:
                label_file = label_list[idx]
            else:
                label_file = None
            datas, labels = data_process_func(image=image_file,
                                              label=label_file)
            preds = self.net(datas)

            self.update_metrics(preds, labels, metrics)

            callback_params = BatchEndParam(epoch=1,
                                            nbatch=1,
                                            eval_metric=metrics,
                                            locals=locals())
            callback(callback_params)

            if show:
                show_func(image_file, datas, preds, labels)
            if save:
                save_func(image_file, datas, preds, labels)

    def _eval_process(self, mod, eval_iter, metric, callback):
        outputs = []
        metric.reset()
        eval_iter.reset()
        for nbatch, batch in enumerate(eval_iter):
            # forward
            mod.prepare(batch)
            mod.forward(batch, is_train=False)
            # remove padded samples
            num_pad = batch.pad
            batch_label = batch.label
            batch_out = mod.get_outputs()
            if num_pad > 0:
                batch_out = [
                    out[:-num_pad] if len(out) > num_pad else out
                    for out in batch_out
                ]  # outputs shorter than num_pad (e.g. scalar losses) are kept whole
                batch_label = [lab[:-num_pad] for lab in batch_label]
            # metric
            metric.update(batch_label, batch_out)
            # callback
            callback(BatchEndParam(0, nbatch, metric, locals()))
            # save outputs
            if nbatch == 0:
                outputs = [bout.copy() for bout in batch_out]
            else:
                outputs = [
                    mx.nd.concat(out, bout, dim=0)
                    for out, bout in zip(outputs, batch_out)
                ]

        criteria = metric.get_name_value()
        return criteria, outputs
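
Every example on this page builds a `BatchEndParam` and hands it to one or more batch-end callbacks such as `mx.callback.Speedometer`. As a minimal sketch (assuming `BatchEndParam` is the namedtuple from `mxnet.model` with fields `epoch`, `nbatch`, `eval_metric` and `locals`), a custom callback only needs to read those fields:

import logging

import mxnet as mx
from mxnet.model import BatchEndParam  # namedtuple: epoch, nbatch, eval_metric, locals

def simple_log_callback(param):
    """Minimal BatchEndParam consumer: log the current metric values."""
    if param.eval_metric is None:
        return
    values = param.eval_metric.get_name_value()
    msg = '\t'.join('%s=%f' % (name, val) for name, val in values)
    logging.info('Epoch[%d] Batch[%d]\t%s', param.epoch, param.nbatch, msg)

# usage sketch: call it exactly the way Speedometer would be called
metrics = mx.metric.Accuracy()
simple_log_callback(BatchEndParam(epoch=0, nbatch=10, eval_metric=metrics, locals=locals()))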
Example #3
def callback(epoch, nbatch, metric_list, callback_list):
    for metric in metric_list:
        batch_end_params = BatchEndParam(epoch=epoch,
                                         nbatch=nbatch,
                                         eval_metric=metric,
                                         locals=locals())
        for cb in _as_list(callback_list):
            cb(batch_end_params)
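
The `_as_list` helper used here (and in the later `fit` examples) normalizes a single callback into a list; it mirrors `mxnet.base._as_list`. A stand-alone sketch, in case that import is not available:

def _as_list(obj):
    """Return obj unchanged if it is already a list/tuple, otherwise wrap it in a list."""
    if isinstance(obj, (list, tuple)):
        return obj
    return [obj]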
Example #4
    def eval_with_dataloader(self, eval_params):
        data_process_func = eval_params.data_process_func
        dataloader = eval_params.dataloader
        metrics = eval_params.metrics
        log_interval = eval_params.log_interval
        batch_size = eval_params.batch_size
        batch_end_callback_list = eval_params.batch_end_callback_list
        show = eval_params.show
        show_func = eval_params.show_func
        save = eval_params.save
        save_func = eval_params.save_func
        save_dir = eval_params.save_dir
        ctxs = eval_params.ctxs
        dtype = eval_params.dtype

        dataloader.reset()

        if batch_end_callback_list is None:
            batch_end_callback_list = [
                mx.callback.Speedometer(batch_size, log_interval)
            ]

        # for nbatch, (datas, labels) in enumerate(dataloader):
        for nbatch, databatch in enumerate(dataloader):
            datas, labels = databatch.data[0], databatch.label[0]

            self.net.cast(dtype)
            gpu_datas = mx.gluon.utils.split_and_load(datas, ctxs)
            gpu_labels = mx.gluon.utils.split_and_load(labels, ctxs)

            for datai, labeli in zip(gpu_datas, gpu_labels):
                datai, labeli = data_process_func(datai, labeli)
                datai = datai.astype(dtype)
                predi = self.forward(datai)

                self.update_metrics(predi, labeli, metrics)

                if show:
                    show_func(datas, predi, labeli)
                if save:
                    save_func(datas,
                              predi,
                              labeli,
                              os.path.join(save_dir, 'eval'),
                              num_count=nbatch * datas.shape[0])

            for batch_end_callback in batch_end_callback_list:
                batch_end_params = BatchEndParam(epoch=1,
                                                 nbatch=nbatch,
                                                 eval_metric=metrics,
                                                 locals=locals())
                batch_end_callback(batch_end_params)
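
`eval_with_dataloader` slices each batch across devices with `mx.gluon.utils.split_and_load`. A self-contained sketch of that pattern, using CPU contexts so it runs anywhere and a stand-in network rather than the example's `self.net`:

import mxnet as mx

ctxs = [mx.cpu(0), mx.cpu(1)]               # stand-ins for a list of GPU contexts
net = mx.gluon.nn.Dense(10)
net.initialize(ctx=ctxs)

data = mx.nd.random.uniform(shape=(8, 16))  # one batch of 8 samples
labels = mx.nd.zeros((8,))

# split the batch into len(ctxs) slices, each copied onto its own context
gpu_datas = mx.gluon.utils.split_and_load(data, ctxs)
gpu_labels = mx.gluon.utils.split_and_load(labels, ctxs)

for datai, labeli in zip(gpu_datas, gpu_labels):
    predi = net(datai)                      # per-device forward pass
    # metric updates (and the show_func/save_func hooks) would go here, as in the example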
Example #5
    def train_epoch(self, epoch):
        # for nbatch, (datas, labels) in enumerate(self.train_dataloader):
        for nbatch, databatch in enumerate(self.train_dataloader):

            datas, labels = databatch.data[0], databatch.label[0]
            self.train_batch(datas, labels)

            # batch-end-callback
            if self.batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=self.metrics,
                                                 locals=locals())
                for callback in self.batch_end_callback:
                    callback(batch_end_params)

    def _train_process(self,
                       mod,
                       epoch,
                       train_iter,
                       valid_iter,
                       metric,
                       optimizer,
                       batch_end_callback,
                       epoch_end_callback,
                       early_stop=30):
        # params extracting
        checkpoint, adapted_lr = epoch_end_callback
        arg_params = self.net_params[
            'args'] if self.net_params is not None else None
        aux_params = self.net_params[
            'auxs'] if self.net_params is not None else None
        # model initialization
        mod.bind(train_iter.provide_data,
                 train_iter.provide_label,
                 for_training=True)
        mod.init_params(initializer=mx.init.Xavier(),
                        arg_params=arg_params,
                        aux_params=aux_params,
                        allow_missing=False)
        mod.init_optimizer(optimizer=optimizer)

        # training loops
        if checkpoint.rule == "greater":
            best = {'epoch': 0, 'name': checkpoint.criteria_name, 'value': 0}
        else:
            best = {
                'epoch': 0,
                'name': checkpoint.criteria_name,
                'value': 1e10
            }
        for nepoch in range(epoch):
            metric.reset()
            train_iter = iter(train_iter)
            for nbatch, batch in enumerate(train_iter):
                mod.forward(batch, is_train=True)
                mod.update_metric(metric, batch.label)
                mod.backward()
                mod.update()
                batch_end_callback(
                    BatchEndParam(nepoch, nbatch, metric, locals()))
            # training result present
            result = metric.get_name_value()
            logging_line = 'Epoch[%d]\t' % nepoch
            logging_line += '\t'.join(
                ['Train-%s=%f' % (n, v) for n, v in result])
            logging.info(logging_line)
            # sync aux params across devices
            arg_params, aux_params = mod.get_params()
            mod.set_params(arg_params, aux_params)
            # evaluation
            res, _ = self._eval_process(mod, valid_iter, metric,
                                        batch_end_callback)
            # epoch end callback
            best = checkpoint(best, res, nepoch,
                              (mod.symbol, arg_params, aux_params))
            adapted_lr(optimizer,
                       nepoch - best['epoch'])  # adapted learning rate
            # early stop
            logging.info('Validation Best performance: Epoch[%d] %s=%f' %
                         (best['epoch'], best['name'], best['value']))
            if nepoch - best['epoch'] > early_stop:
                break
            # reset train iter
            train_iter.reset()
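
`_train_process` unpacks `epoch_end_callback` into a `checkpoint` object and an `adapted_lr` function that are not shown on this page. Below is a hypothetical sketch that matches the call sites (`checkpoint.rule`, `checkpoint.criteria_name`, `checkpoint(best, result, nepoch, (symbol, arg_params, aux_params))`, and `adapted_lr(optimizer, epochs_since_best)`); the names and behaviour are assumptions, not part of the original code:

import mxnet as mx

class Checkpoint(object):
    """Hypothetical epoch-end helper: keep the best epoch according to one metric."""
    def __init__(self, prefix, criteria_name, rule='greater'):
        self.prefix = prefix
        self.criteria_name = criteria_name
        self.rule = rule

    def __call__(self, best, result, nepoch, model):
        value = dict(result)[self.criteria_name]
        improved = value > best['value'] if self.rule == 'greater' else value < best['value']
        if improved:
            best = {'epoch': nepoch, 'name': self.criteria_name, 'value': value}
            symbol, arg_params, aux_params = model
            mx.model.save_checkpoint(self.prefix, nepoch, symbol, arg_params, aux_params)
        return best

def adapted_lr(optimizer, epochs_since_best, factor=0.5, patience=5):
    """Hypothetical plateau-style decay: shrink the lr after `patience` epochs without improvement."""
    if epochs_since_best > 0 and epochs_since_best % patience == 0:
        optimizer.lr *= factor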
Example #7
    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01), ),
            eval_end_callback=None,
            eval_batch_end_callback=None,
            initializer=Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            best_model_callbacks=None,
            eval_interval=None,
            validation_metric=None,
            monitor=None):

        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind)

        if monitor is not None:
            self.install_monitor(monitor)

        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)

        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        if validation_metric is None:
            validation_metric = copy.deepcopy(eval_metric)

        epoch_metric = copy.deepcopy(eval_metric)

        swa_arg_params = None
        swa_aux_params = None
        swa_cnt = 0

        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic_epoch = time.time()
            eval_metric.reset()

            nbatch = 0
            end_of_batch = False
            data_iter = iter(train_data)
            next_data_batch = next(data_iter)
            name_values = []

            while not end_of_batch:
                data_batch = next_data_batch

                if monitor is not None:
                    monitor.tic()

                self.forward_backward(data_batch)
                self.update()

                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                self.update_metric(eval_metric, data_batch.label)
                if end_of_batch:
                    name_values = eval_metric.get_name_value()

                if monitor is not None:
                    monitor.toc_print()

                nbatch += 1

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)

                    eval_metric.reset()

                # ----------------------------------------
                # evaluation on validation set
                to_go = eval_interval is not None and nbatch % eval_interval == 0
                if to_go and eval_data:
                    res = self.score(
                        eval_data,
                        validation_metric,
                        score_end_callback=eval_end_callback,
                        batch_end_callback=eval_batch_end_callback,
                        epoch=epoch)
                    for name, val in res:
                        self.logger.info(
                            'Epoch[%d] Batch[%d] Validation-%s=%f', epoch,
                            nbatch, name, val)

                    if best_model_callbacks is not None:
                        for callback in _as_list(best_model_callbacks):
                            if callback.is_best(validation_metric):
                                # sync aux params across devices
                                arg_params, aux_params = self.get_params()
                                sync_made = True
                                callback.checkpoint_if_only_best(
                                    validation_metric, self.symbol, arg_params,
                                    aux_params)
                                break

            # one epoch of training is finished
            for name, val in name_values:
                self.logger.info('Epoch[%d] Train-%s=%f', epoch + 1, name, val)
            toc_epoch = time.time()
            elapsed = (toc_epoch - tic_epoch)
            avg_speed = float(len(train_data)) / (toc_epoch - tic_epoch)
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch + 1, elapsed)
            self.logger.info('Epoch[%d] Average speed=%.3f samples/sec',
                             epoch + 1, avg_speed)

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            # evaluation on validation set
            if eval_data:
                res = self.score(eval_data,
                                 validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback,
                                 epoch=epoch + 1)
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch + 1,
                                     name, val)

                if best_model_callbacks is not None:
                    for callback in _as_list(best_model_callbacks):
                        callback.checkpoint_if_only_best(
                            validation_metric, self.symbol, arg_params,
                            aux_params)

            # end of epoch, reset the data-iter for another epoch
            train_data.reset()
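
`best_model_callbacks` is only visible through its call sites, `is_best(validation_metric)` and `checkpoint_if_only_best(metric, symbol, arg_params, aux_params)`. A hypothetical minimal implementation of that interface (the class name and checkpoint prefix are assumptions):

import mxnet as mx

class BestModelCheckpoint(object):
    """Hypothetical best-model callback: checkpoint only when a chosen metric improves."""
    def __init__(self, prefix, metric_name, rule='greater'):
        self.prefix = prefix
        self.metric_name = metric_name
        self.rule = rule
        self.best_value = None

    def is_best(self, eval_metric):
        value = dict(eval_metric.get_name_value())[self.metric_name]
        if self.best_value is None:
            return True
        return value > self.best_value if self.rule == 'greater' else value < self.best_value

    def checkpoint_if_only_best(self, eval_metric, symbol, arg_params, aux_params):
        if self.is_best(eval_metric):
            self.best_value = dict(eval_metric.get_name_value())[self.metric_name]
            mx.model.save_checkpoint(self.prefix, 0, symbol, arg_params, aux_params)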
    def fit(self, train_data,
            plateau_lr, plateau_metric=None, fn_curr_model=None, plateau_backtrace=True,
            eval_data=None, eval_metric='acc',
            epoch_end_callback=None, batch_end_callback=None, kvstore='local',
            optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
            eval_end_callback=None,
            eval_batch_end_callback=None, initializer=mx.init.Uniform(0.01),
            arg_params=None, aux_params=None, allow_missing=False,
            force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
            validation_metric=None, validation_period=1, monitor=None):
        '''
        overrides fit() in base_module.
        '''
        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
                  for_training=True, force_rebind=force_rebind)
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                         allow_missing=allow_missing, force_init=force_init)
        self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, mx.metric.EvalMetric):
            eval_metric = mx.metric.create(eval_metric)

        # we will use plateau lr scheduler.
        self.logger.info('Initial base lr = %.5f', self._optimizer.lr)
        plateau_lr.reset(self._optimizer.lr)
        self._optimizer.lr_scheduler = None
        if plateau_metric is None:
            plateau_metric = eval_metric
        elif not isinstance(plateau_metric, mx.metric.EvalMetric):
            plateau_metric = mx.metric.create(plateau_metric)

        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            if plateau_metric is not eval_metric:
                plateau_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                self.forward_backward(data_batch)
                self.update()
                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                self.update_metric(eval_metric, data_batch.label)
                if plateau_metric is not eval_metric:
                    self.update_metric(plateau_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            #----------------------------------------
            # evaluation on validation set
            if eval_data and epoch % validation_period == 0:
                res = self.score(eval_data, validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback, epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)

            # update lr using plateau algorithm
            new_lr, is_good = plateau_lr.update_lr(plateau_metric)
            if is_good and fn_curr_model is not None:
                super(PlateauModule, self).save_params(fn_curr_model)
            if new_lr == 0.0:
                self.logger.info('Minimum LR reached. Terminate training.')
                train_data.reset()
                break
            if new_lr < self._optimizer.lr:
                self.logger.info('Update lr, from %.6f to %.6f', self._optimizer.lr, new_lr)
                self._optimizer.lr = new_lr
                if fn_curr_model is not None and plateau_backtrace:
                    self.logger.info('Reset network parameters to the previous best result.')
                    super(PlateauModule, self).load_params(fn_curr_model)
            else:
                self.logger.info('Current lr = %.6f', self._optimizer.lr)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
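
The `plateau_lr` object passed to this PlateauModule `fit` is assumed to expose `reset(lr)` and `update_lr(metric)`, where `update_lr` returns the (possibly reduced) learning rate and a flag for whether the monitored metric improved; returning `0.0` terminates training in the loop above. A hypothetical minimal scheduler with that interface:

class PlateauScheduler(object):
    """Hypothetical reduce-on-plateau scheduler matching the reset()/update_lr() calls above."""
    def __init__(self, factor=0.1, patience=3, min_lr=1e-6, rule='greater'):
        self.factor, self.patience, self.min_lr, self.rule = factor, patience, min_lr, rule
        self.lr = None
        self.best = None
        self.stale = 0

    def reset(self, lr):
        self.lr, self.best, self.stale = lr, None, 0

    def update_lr(self, eval_metric):
        # use the first reported metric value as the plateau criterion
        _, value = eval_metric.get_name_value()[0]
        improved = (self.best is None or
                    (value > self.best if self.rule == 'greater' else value < self.best))
        if improved:
            self.best, self.stale = value, 0
        else:
            self.stale += 1
            if self.stale >= self.patience:
                new_lr = self.lr * self.factor
                self.lr = new_lr if new_lr >= self.min_lr else 0.0  # 0.0 signals "stop training"
                self.stale = 0
        return self.lr, improved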
Example #9
    def fit(
        self,
        train_data,
        eval_data=None,
        eval_metric="acc",
        epoch_end_callback=None,
        batch_end_callback=None,
        kvstore="local",
        optimizer="sgd",
        optimizer_params=(("learning_rate", 0.01),),
        eval_end_callback=None,
        eval_batch_end_callback=None,
        initializer=Uniform(0.01),
        arg_params=None,
        aux_params=None,
        allow_missing=False,
        force_rebind=False,
        force_init=False,
        begin_epoch=0,
        num_epoch=None,
        validation_metric=None,
        monitor=None,
        prefix=None,
    ):
        """Train the module parameters.

        Parameters
        ----------
        train_data : DataIter
        eval_data : DataIter
            If not `None`, will be used as validation set and evaluate the performance
            after each epoch.
        eval_metric : str or EvalMetric
            Default `'acc'`. The performance measure used to display during training.
        epoch_end_callback : function or list of function
            Each callback will be called with the current `epoch`, `symbol`, `arg_params`
            and `aux_params`.
        batch_end_callback : function or list of function
            Each callback will be called with a `BatchEndParam`.
        kvstore : str or KVStore
            Default `'local'`.
        optimizer : str or Optimizer
            Default `'sgd'`
        optimizer_params : dict
            Default `(('learning_rate', 0.01),)`. The parameters for the optimizer constructor.
            The default value is not a `dict`, just to avoid pylint warning on dangerous
            default values.
        eval_end_callback : function or list of function
            These will be called at the end of each full evaluation, with the metrics over
            the entire evaluation set.
        eval_batch_end_callback : function or list of function
            These will be called at the end of each minibatch during evaluation.
        initializer : Initializer
            Will be called to initialize the module parameters if not already initialized.
        arg_params : dict
            Default `None`, if not `None`, should be existing parameters from a trained
            model or loaded from a checkpoint (previously saved model). In this case,
            the value here will be used to initialize the module parameters, unless they
            are already initialized by the user via a call to `init_params` or `fit`.
            `arg_params` has higher priority to `initializer`.
        aux_params : dict
            Default `None`. Similar to `arg_params`, except for auxiliary states.
        allow_missing : bool
            Default `False`. Indicate whether we allow missing parameters when `arg_params`
            and `aux_params` are not `None`. If this is `True`, then the missing parameters
            will be initialized via the `initializer`.
        force_rebind : bool
            Default `False`. Whether to force rebinding the executors if already bound.
        force_init : bool
            Default `False`. Indicate whether we should force initialization even if the
            parameters are already initialized.
        begin_epoch : int
            Default `0`. Indicate the starting epoch. Usually, if we are resuming from a
            checkpoint saved at a previous training phase at epoch N, then we should specify
            this value as N+1.
        num_epoch : int
            Number of epochs to run training.

        Examples
        --------
        An example of using fit for training::

            >>> # Assume training dataIter and validation dataIter are ready
            >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter,
            ...         optimizer_params={'learning_rate': 0.01, 'momentum': 0.9},
            ...         num_epoch=10)
        """
        assert num_epoch is not None, "please specify number of epochs"

        self.bind(
            data_shapes=train_data.provide_data,
            label_shapes=train_data.provide_label,
            for_training=True,
            force_rebind=force_rebind,
        )
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(
            initializer=initializer,
            arg_params=arg_params,
            aux_params=aux_params,
            allow_missing=allow_missing,
            force_init=force_init,
        )
        self.init_optimizer(
            kvstore=kvstore, optimizer=optimizer, optimizer_params=optimizer_params
        )

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        ################################################################################
        # training loop
        ################################################################################
        # epoch 0
        if epoch_end_callback is not None:
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)
            for callback in _as_list(epoch_end_callback):
                callback(-1, self.symbol, arg_params, aux_params)

        from lib.pair_matching.batch_updater_py_multi import batchUpdaterPyMulti

        config = self.config
        if config.TRAIN.TENSORBOARD_LOG:
            from mxboard import SummaryWriter

            tf_log_dir = os.path.join(
                os.path.dirname(prefix),
                "logs/{}".format(time.strftime("%Y-%m-%d-%H-%M")),
            )
            summ_writer = SummaryWriter(logdir=tf_log_dir)

        interBatchUpdater = batchUpdaterPyMulti(config, 480, 640)
        last_lr = 0
        cur_step = 0
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            for nbatch, data_batch in enumerate(train_data):
                if monitor is not None:
                    monitor.tic()
                # disp weights L2 norm
                cur_lr = self._curr_module._optimizer._get_lr(0)
                if nbatch % (4000 / train_data.batch_size) == 0:
                    all_params = self._curr_module.get_params()[0]
                    all_param_names = all_params.keys()
                    all_param_names = sorted(all_param_names)
                    print_and_log(prefix, self.logger)
                    weight_str = ""
                    for view_name in all_param_names:
                        weight_str += "{}: {} ".format(
                            view_name, nd.norm(all_params[view_name]).asnumpy()
                        )
                    print_and_log(weight_str, self.logger)
                    print_and_log(
                        "batch {}: lr: {}".format(nbatch, cur_lr), self.logger
                    )
                    if config.TRAIN.TENSORBOARD_LOG:
                        summ_writer.add_scalar(
                            tag="learning_rate", value=cur_lr, global_step=cur_step
                        )
                if cur_lr != last_lr:
                    print_and_log(
                        "batch {}: lr: {}".format(nbatch, cur_lr), self.logger
                    )
                    last_lr = cur_lr
                    if config.TRAIN.TENSORBOARD_LOG:
                        summ_writer.add_scalar(
                            tag="learning_rate", value=cur_lr, global_step=cur_step
                        )

                train_iter_size = config.network.TRAIN_ITER_SIZE
                for iter_idx in range(train_iter_size):
                    self.forward_backward(data_batch)
                    preds = self._curr_module.get_outputs(False)
                    self.update()
                    if iter_idx != train_iter_size - 1:
                        data_batch = interBatchUpdater.forward(
                            data_batch, preds, config
                        )
                cur_step += 1
                self.update_metric(eval_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(
                        epoch=epoch,
                        nbatch=nbatch,
                        eval_metric=eval_metric,
                        locals=locals(),
                    )
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                if config.TRAIN.TENSORBOARD_LOG:
                    for name, val in eval_metric.get_name_value():
                        summ_writer.add_scalar(
                            tag="BatchTrain-{}".format(name),
                            value=val,
                            global_step=cur_step,
                        )

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info("Epoch[%d] Train-%s=%f", epoch, name, val)
                if config.TRAIN.TENSORBOARD_LOG:
                    summ_writer.add_scalar(
                        tag="EpochTrain-{}".format(name), value=val, global_step=epoch
                    )

            toc = time.time()
            self.logger.info("Epoch[%d] Time cost=%.3f", epoch, (toc - tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            # ----------------------------------------
            # evaluation on validation set
            if eval_data:
                res = self.score(
                    eval_data,
                    validation_metric,
                    score_end_callback=eval_end_callback,
                    batch_end_callback=eval_batch_end_callback,
                    epoch=epoch,
                )
                # TODO: pull this into default
                for name, val in res:
                    self.logger.info("Epoch[%d] Validation-%s=%f", epoch, name, val)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
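
This variant mirrors batch-level metrics and the learning rate into TensorBoard via `mxboard`. A stand-alone sketch of that logging pattern (the tag names and dummy values are illustrative only):

import os
import time

from mxboard import SummaryWriter

tf_log_dir = os.path.join('logs', time.strftime('%Y-%m-%d-%H-%M'))
summ_writer = SummaryWriter(logdir=tf_log_dir)

# inside the training loop there would be one point per global step
for cur_step, (cur_lr, train_acc) in enumerate([(0.01, 0.10), (0.01, 0.35), (0.005, 0.52)]):
    summ_writer.add_scalar(tag='learning_rate', value=cur_lr, global_step=cur_step)
    summ_writer.add_scalar(tag='BatchTrain-accuracy', value=train_acc, global_step=cur_step)

summ_writer.close()  # flush the event files so TensorBoard can read them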
Example #10
    def fit(self, train_data, eval_data=None, eval_metric='acc',
            epoch_end_callback=None, batch_end_callback=None, kvstore='local',
            optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
            eval_end_callback=None,
            eval_batch_end_callback=None, initializer=Uniform(0.01),
            arg_params=None, aux_params=None, allow_missing=False,
            force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
            validation_metric=None, monitor=None, summary_writer = None):

        assert num_epoch is not None, 'please specify number of epochs'
        self.num_batch = 0
        self.writer = summary_writer

        self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
                  for_training=True, force_rebind=force_rebind)

        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                         allow_missing=allow_missing, force_init=force_init)
        self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                            optimizer_params=optimizer_params)

        acc_metric = IgnoreAccuracy(output_names=['softmax_output'], label_names=['softmax_label'])
        # acc_metric = metric.Accuracy(output_names=['softmax_output'], label_names=['softmax_label'])
        lmnn_metric = metric.Loss(output_names=['lmnn_output'], label_names=['softmax_label'])

        if validation_metric is None:
            validation_metric = lmnn_metric

        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            acc_metric.reset()
            lmnn_metric.reset()
            # eval_metric.reset()
            for nbatch, data_batch in enumerate(train_data):
                if monitor is not None:
                    monitor.tic()

                self.forward_backward(data_batch)

                self.update()

                self.update_metric(acc_metric, data_batch.label)
                self.update_metric(lmnn_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)

                # log running training metrics (note: this runs every batch, inside the loop)
                for name, val in acc_metric.get_name_value():
                    self.logger.info('Epoch[%d] Accuracy Train-%s=%f', epoch, name, val)
                for name, val in lmnn_metric.get_name_value():
                    self.logger.info('Epoch[%d] Lmnn Train-%s=%f', epoch, name, val)

                if self.num_batch % 10 == 0:
                    # print acc_metric.sum_metric, acc_metric.num_inst
                    self.writer.add_scalar('{}/cls_acc'.format('Train'),
                                           acc_metric.sum_metric / acc_metric.num_inst,
                                           self.num_batch)
                    self.writer.add_scalar('{}/lmnn_loss'.format('Train'),
                                           lmnn_metric.sum_metric / lmnn_metric.num_inst,
                                           self.num_batch)

                self.num_batch += 1

            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            #----------------------------------------
            # evaluation on validation set
            if eval_data:
                res = self.score(eval_data, validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback, epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
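
`IgnoreAccuracy` is a project-specific metric that is not defined on this page; presumably it is an accuracy that skips samples carrying a designated ignore label. A hypothetical sketch of such a metric built on `mx.metric.EvalMetric` (the ignore value of -1 is an assumption):

import mxnet as mx

class IgnoreAccuracy(mx.metric.EvalMetric):
    """Hypothetical accuracy metric that ignores samples with a given label value."""
    def __init__(self, ignore_label=-1, name='ignore_accuracy', **kwargs):
        super(IgnoreAccuracy, self).__init__(name, **kwargs)
        self.ignore_label = ignore_label

    def update(self, labels, preds):
        for label, pred in zip(labels, preds):
            label = label.asnumpy().astype('int32').ravel()
            pred_label = pred.asnumpy().argmax(axis=-1).astype('int32').ravel()
            keep = label != self.ignore_label           # mask out ignored samples
            self.sum_metric += (pred_label[keep] == label[keep]).sum()
            self.num_inst += int(keep.sum())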
Example #11
    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            validate_metric=None,
            work_load_list=None,
            epoch_end_callback=None,
            batch_end_callback=None,
            fixed_param_prefix=None,
            initializer=None,
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            optimizer=None,
            optimizer_params=None,
            begin_epoch=0,
            num_epoch=None,
            kvstore='device',
            teacher_modules=None):
        if type(teacher_modules) is not list:
            teacher_modules = [teacher_modules]
        self.module.bind(data_shapes=self.data_shapes,
                         label_shapes=self.label_shapes,
                         for_training=True)
        self.module.init_params(initializer=initializer,
                                arg_params=arg_params,
                                aux_params=aux_params,
                                allow_missing=allow_missing)
        self.module.init_optimizer(kvstore=kvstore,
                                   optimizer=optimizer,
                                   optimizer_params=optimizer_params)

        if validate_metric is None:
            validate_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        # training loop
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                data_batch = next_data_batch

                if teacher_modules[0] is not None:
                    for teacher_module in teacher_modules:
                        teacher_module.forward(data_batch=data_batch,
                                               is_train=True)
                        transfer_label = teacher_module.get_outputs()
                        data_batch.label = data_batch.label + transfer_label
                self.module.forward(data_batch, is_train=True)
                self.module.backward()
                self.module.update()

                try:
                    next_data_batch = next(data_iter)
                except StopIteration:
                    end_of_batch = True

                self.module.update_metric(eval_metric, data_batch.label)

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            arg_params, aux_params = self.module.get_params()
            self.module.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)
            if eval_data:
                res = self.module.score(eval_data,
                                        validate_metric,
                                        score_end_callback=None,
                                        batch_end_callback=None,
                                        reset=True,
                                        epoch=epoch)
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            train_data.reset()
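
Here each teacher module runs a forward pass on the batch and its outputs are appended to the batch labels, so the student can be supervised by the teacher's predictions (a basic knowledge-transfer setup). A hedged sketch of preparing such a teacher from a checkpoint (`teacher-prefix`, the epoch, and the data shape are placeholders):

import mxnet as mx

# load a previously trained teacher network (prefix/epoch are placeholders)
sym, arg_params, aux_params = mx.model.load_checkpoint('teacher-prefix', 0)
teacher = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names=None)
# bind for_training=True to match the `is_train=True` forward call in the loop above
teacher.bind(data_shapes=[('data', (8, 3, 224, 224))], for_training=True)
teacher.set_params(arg_params, aux_params)

# the trainer above would then be invoked roughly as:
# solver.fit(train_data, eval_data=val_data, teacher_modules=[teacher], ...)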
Example #12
    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01), ),
            eval_end_callback=None,
            eval_batch_end_callback=None,
            initializer=Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            validation_metric=None,
            monitor=None):

        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind)
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        ################################################################################
        # training loop
        ################################################################################

        last_grad_debug = None

        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                self.forward_backward(data_batch)

                # grad_array = [[grad.copyto(grad.context) if grad is not None else None for grad in grads] for grads in
                #             self._curr_module._exec_group.grad_arrays]
                #
                # for exec_ in self._curr_module._exec_group.execs:
                #     grad_dict = exec_.grad_dict
                #
                # grad_debug = dict()
                # for k, v in grad_dict.items():
                #     if v is not None:
                #         v_np = v.asnumpy()
                #         grad_debug[k] = (np.min(v_np), np.max(v_np))
                # print 'rpn_conv_cls_weight:', grad_debug['rpn_conv_cls_weight']
                # print 'rcnn_fc_cls_weight:', grad_debug['rcnn_fc_cls_weight']

                self.update()
                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                self.update_metric(eval_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            #----------------------------------------
            # evaluation on validation set
            if eval_data:
                res = self.score(eval_data,
                                 validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback,
                                 epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
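
The commented-out block in this example inspects per-parameter gradient ranges after `forward_backward`. A cleaned-up sketch of that debugging idea; it reads the internal `_exec_group`/`grad_dict` attributes of a bound `Module`, so treat it as version-dependent:

def dump_gradient_ranges(module, param_names=None):
    """Print (min, max) of the gradients held by a bound MXNet Module."""
    grad_debug = {}
    for exec_ in module._exec_group.execs:          # internal attribute, may change between versions
        for name, grad in exec_.grad_dict.items():
            if grad is None:
                continue
            g = grad.asnumpy()
            grad_debug[name] = (g.min(), g.max())
    for name in (param_names or sorted(grad_debug)):
        if name in grad_debug:
            print('%s: min=%g max=%g' % ((name,) + grad_debug[name]))

# usage sketch, right after self.forward_backward(data_batch):
# dump_gradient_ranges(self._curr_module, ['rpn_conv_cls_weight', 'rcnn_fc_cls_weight'])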
Example #13
    # train
    logging.info("Training started ... ")
    for epoch in range(args.epochs):
        # train
        total_loss = 0.0
        nbatch = 0
        for batch in train_data:
            module.forward(batch)
            module.backward()
            module.update(max_norm=args.clip * bptt * batch_size)
            # update metric
            outputs = module.get_loss()
            total_loss += mx.nd.sum(outputs[0]).asscalar()
            speedometer_param = BatchEndParam(epoch=epoch,
                                              nbatch=nbatch,
                                              eval_metric=None,
                                              locals=locals())
            speedometer(speedometer_param)
            if nbatch % args.log_interval == 0 and nbatch > 0:
                cur_loss = total_loss / bptt / batch_size / args.log_interval
                logging.info('Iter[%d] Batch [%d]\tLoss:  %.7f,\tPerplexity:\t%.7f' % \
                             (epoch, nbatch, cur_loss, math.exp(cur_loss)))
                total_loss = 0.0
            nbatch += 1
        # validation
        valid_loss = evaluate(module, valid_data, epoch, 'Valid', bptt,
                              batch_size)
        if valid_loss < best_loss:
            best_loss = valid_loss
            # test
            test_loss = evaluate(module, test_data, epoch, 'Test', bptt,
Example #14
    def fit(
            self,
            train_data,
            ogdb,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01),),  # ('rescale_grad', 1.0/8.0) was an 8 gpu attempt
            eval_end_callback=None,
            iter_size=1,
            eval_batch_end_callback=None,
            initializer=Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            validation_metric=None,
            monitor=None):
        """Ke's revision: add iter_size. Trains the module parameters.

        Check out the `Module Tutorial <http://mxnet.io/tutorials/basic/module.html>`_ to see
        an end-to-end use-case.

        Parameters
        ----------
        train_data : DataIter
            Train DataIter.
        eval_data : DataIter
            If not ``None``, will be used as validation set and the performance
            after each epoch will be evaluated.
        eval_metric : str or EvalMetric
            Defaults to 'accuracy'. The performance measure used to display during training.
            Other possible predefined metrics are:
            'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
        epoch_end_callback : function or list of functions
            Each callback will be called with the current `epoch`, `symbol`, `arg_params`
            and `aux_params`.
        batch_end_callback : function or list of function
            Each callback will be called with a `BatchEndParam`.
        kvstore : str or KVStore
            Defaults to 'local'.
        optimizer : str or Optimizer
            Defaults to 'sgd'.
        optimizer_params : dict
            Defaults to ``(('learning_rate', 0.01),)``. The parameters for
            the optimizer constructor.
            The default value is not a dict, just to avoid pylint warning on dangerous
            default values.
        eval_end_callback : function or list of function
            These will be called at the end of each full evaluation, with the metrics over
            the entire evaluation set.
        eval_batch_end_callback : function or list of function
            These will be called at the end of each mini-batch during evaluation.
        initializer : Initializer
            The initializer is called to initialize the module parameters when they are
            not already initialized.
        arg_params : dict
            Defaults to ``None``, if not ``None``, should be existing parameters from a trained
            model or loaded from a checkpoint (previously saved model). In this case,
            the value here will be used to initialize the module parameters, unless they
            are already initialized by the user via a call to `init_params` or `fit`.
            `arg_params` has a higher priority than `initializer`.
        aux_params : dict
            Defaults to ``None``. Similar to `arg_params`, except for auxiliary states.
        allow_missing : bool
            Defaults to ``False``. Indicates whether to allow missing parameters when `arg_params`
            and `aux_params` are not ``None``. If this is ``True``, then the missing parameters
            will be initialized via the `initializer`.
        force_rebind : bool
            Defaults to ``False``. Whether to force rebinding the executors if already bound.
        force_init : bool
            Defaults to ``False``. Indicates whether to force initialization even if the
            parameters are already initialized.
        begin_epoch : int
            Defaults to 0. Indicates the starting epoch. Usually, if resumed from a
            checkpoint saved at a previous training phase at epoch N, then this value should be
            N+1.
        num_epoch : int
            Number of epochs for training.

        Examples
        --------
        >>> # An example of using fit for training.
        >>> # Assume training dataIter and validation dataIter are ready
        >>> # Assume loading a previously checkpointed model
        >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
        >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd',
        ...     optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
        ...     arg_params=arg_params, aux_params=aux_params,
        ...     eval_metric='acc', num_epoch=10, begin_epoch=3)
        """
        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind,
                  grad_req='add')
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        annealing_steps = 0  # number of current annealing steps in current epoch
        redo_training = 0  # Flag to redo training / resample
        val_list = []  # list of validation results per annealing step
        cur_val = 0
        target_prec = 50
        #Note: we want to identify the best cluster of images / training sets with a low percentage
        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            if redo_training:
                annealing_steps = annealing_steps + 1
                self.logger.info('Redoing training to meet criteria = %d',
                                 annealing_steps)
                #sroidb = train_data.roidb #passthrough test

                atick = time.time()

                iterdiff = 1.0
                # Check if we've stagnated
                if len(val_list) > 2:
                    itermean = (val_list[-1] + val_list[-2] + val_list[-3]) / 3
                    iterdiff = abs(itermean - val_list[-1])
                    self.logger.info('Last 3 samples have diff of: %f',
                                     iterdiff)

                if iterdiff < 0.01:
                    self.logger.info(
                        'Reached a stagnated annealing criteria, dumping current samples'
                    )
                    # Do something drastic
                    # Lets try to instantly use the original db
                    sroidb = ogdb

                    # Try to read in another random subset
                    #sroidb = sample_roidb(ogdb, 25) # Sample with removal
                else:
                    # Continue as usual
                    # Select a new random subset
                    newroidb = sample_roidb(ogdb,
                                            15)  # Without removal, this is 10%

                    # Append old with new
                    sroidb = append_roidb(train_data.roidb, newroidb)

                # Create new training data instance by passing most of previous arguments and new random db
                train_data2 = AnchorLoader(
                    train_data.feat_sym,
                    sroidb,
                    train_data.batch_size,
                    train_data.shuffle,
                    train_data.ctx,
                    train_data.work_load_list,
                    train_data.feat_stride,
                    train_data.anchor_scales,
                    train_data.anchor_ratios,
                    train_data.aspect_grouping,
                    nThreads=default.prefetch_thread_num)

                # Overwrite old train_data with the new one
                train_data = train_data2
                data_iter = iter(train_data)

                atock = time.time()
                self.logger.info('Annealing[%d] Time cost=%.3f',
                                 annealing_steps, (atock - atick))
            else:
                data_iter = iter(train_data)
                annealing_steps = 0
                val_list = []
                #target_prec=cur_val+5
                target_prec = target_prec + 5
            end_of_batch = False
            next_data_batch = next(data_iter)

            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                # self.forward_backward(data_batch)
                self.forward(data_batch, is_train=True, grad_req='add')
                self.backward()
                if nbatch % iter_size == 0:  # update every iter_size batches (gradient accumulation)
                    self.update()
                    # gradients were accumulated with grad_req='add'; zero them after the optimizer step
                    for g in self._curr_module._exec_group.grad_arrays:
                        for g1 in g:
                            if g1 is not None:
                                g1[:] = 0.

                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                self.update_metric(eval_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))
            #print('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))
            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    cur_val = callback(epoch, self.symbol, arg_params,
                                       aux_params)

            self.logger.info('Returned Validation=%f', cur_val)
            val_list.append(cur_val)
            #----------------------------------------
            # evaluation on validation set
            if eval_data:
                self.logger.info('Evaluating data')
                res = self.score(eval_data,
                                 validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback,
                                 epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            #----------
            # Check epoch if it falls within the validation threshold
            if cur_val < target_prec:
                # Evaluate list of precision/validation results first
                #val_list
                print(eval_data)

                #else
                redo_training = 1
                self.logger.info('Retraining, validation=%f', cur_val)
            else:
                redo_training = 0

            self.logger.info('Annealing steps=%d', annealing_steps)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
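The loop above decides whether to resample by checking whether the last three validation results have stagnated. Stripped of the module plumbing, the decision can be sketched as below; this is only a sketch, and sample_roidb / append_roidb are the helpers assumed by the snippet, passed in with the call signatures used above.

# Sketch of the annealing decision above (hypothetical standalone form).
def next_training_roidb(val_list, ogdb, current_roidb,
                        sample_roidb, append_roidb,
                        stagnation_eps=0.01, sample_size=15):
    """Pick the roidb for the next annealing step."""
    if len(val_list) > 2:
        # compare the mean of the last three validation results with the latest one
        itermean = sum(val_list[-3:]) / 3.0
        if abs(itermean - val_list[-1]) < stagnation_eps:
            # stagnated: fall back to the full original database
            return ogdb
    # otherwise grow the current set with a fresh random subset
    return append_roidb(current_roidb, sample_roidb(ogdb, sample_size))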
def fit(self,
        train_data,
        eval_data=None,
        eval_metric='acc',
        epoch_end_callback=None,
        batch_end_callback=None,
        kvstore='local',
        optimizer='sgd',
        optimizer_params=(('learning_rate', 0.01), ),
        eval_end_callback=None,
        eval_batch_end_callback=None,
        initializer=Uniform(0.01),
        arg_params=None,
        aux_params=None,
        allow_missing=False,
        force_rebind=False,
        force_init=False,
        begin_epoch=0,
        num_epoch=None,
        validation_metric=None,
        monitor=None,
        sparse_row_id_fn=None,
        accuracy_target=1.0,
        eval_frequency=1,
        eval_offset=0,
        logger=None):
    assert num_epoch is not None, 'please specify number of epochs'

    if 'horovod' in kvstore:
        rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        rank = 0
        local_rank = 0

    profiler_on = os.getenv('RESNET50_PROFILING', False) and (rank == 0)
    if profiler_on:
        self.logger.info("Profiling is enabled")

    stop_iter = int(os.getenv('RESNET50_STOP_ITERATION', '0'))
    if stop_iter > 0:
        self.logger.info(
            "Training will stop at iteration {} of the first epoch".format(
                stop_iter))

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True,
              force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer,
                     arg_params=arg_params,
                     aux_params=aux_params,
                     allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    block_epoch_start = begin_epoch
    block_epoch_count = eval_offset + 1 - (begin_epoch % eval_frequency)
    if block_epoch_count < 0:
        block_epoch_count += eval_frequency
    mll.block_start(block_epoch_start + 1, count=block_epoch_count)

    if profiler_on:
        mx.profiler.set_config(profile_symbolic=True,
                               profile_imperative=True,
                               profile_memory=False,
                               profile_api=True,
                               filename='resnet50_profile.json',
                               aggregate_stats=True)
        mx.profiler.set_state('run')

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        mll.epoch_start(epoch + 1)
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        early_stop = False
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()

            if isinstance(data_batch, list):
                self.update_metric(eval_metric,
                                   [db.label for db in data_batch],
                                   pre_sliced=True)
            else:
                self.update_metric(eval_metric, data_batch.label)

            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch,
                             sparse_row_id_fn=sparse_row_id_fn)
            except StopIteration:
                end_of_batch = True

            if monitor is not None:
                monitor.toc_print()

            if end_of_batch:
                eval_name_vals = eval_metric.get_global_name_value()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1
            if stop_iter > 0 and nbatch >= stop_iter:
                early_stop = True
                self.logger.info(
                    "Training stopped at {} iteration. Clear RESNET50_STOP_ITERATION if it's not itended."
                    .format(stop_iter))
                break

        if early_stop:
            break

        mll.epoch_stop(epoch + 1)
        # one epoch of training is finished
        if rank == 0:
            for name, val in eval_name_vals:
                self.logger.info('Rank[%d] Epoch[%d] Train-%s=%f', rank, epoch,
                                 name, val)
            toc = time.time()
            self.logger.info('Rank[%d] Epoch[%d] Time cost=%.3f', rank, epoch,
                             (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data is not None and ((epoch % eval_frequency == eval_offset)
                                      or (epoch + 1 == num_epoch)):
            mll.eval_start(epoch + 1, sync=True)
            res = self.score(eval_data,
                             [validation_metric,
                              CorrectCount(),
                              TotalCount()],
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            if rank == 0:
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            # temporarily add these two metrics for debugging, can be removed before submission
            res = dict(res)
            correct_count = res['correct-count']
            total_count = res['total-count']
            if 'horovod' in kvstore:
                correct_count = allreduce(correct_count)
                total_count = allreduce(total_count)

            acc = correct_count / total_count
            mll.eval_stop(epoch + 1)
            mll.eval_accuracy(epoch + 1, acc)

            mll.block_stop(block_epoch_start + 1)
            if acc > accuracy_target:
                mll.run_stop(status='success')
                return

            if epoch < num_epoch - 1:
                block_epoch_start = epoch + 1
                block_epoch_count = num_epoch - epoch - 1
                if block_epoch_count > eval_frequency:
                    block_epoch_count = eval_frequency
                mll.block_start(block_epoch_start + 1, count=block_epoch_count)

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()

    if profiler_on:
        mx.profiler.set_state('stop')
        print(mx.profiler.dumps())

    mll.run_stop(status='aborted')
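In the loop above, eval_frequency and eval_offset decide which epochs run validation (the last epoch is always evaluated). A small sketch of that schedule, using the same modular condition as the code:

# Sketch: which epochs trigger evaluation under the convention used above.
def eval_epochs(begin_epoch, num_epoch, eval_frequency=1, eval_offset=0):
    return [epoch for epoch in range(begin_epoch, num_epoch)
            if (epoch % eval_frequency == eval_offset) or (epoch + 1 == num_epoch)]

# e.g. eval_frequency=4, eval_offset=3 over 16 epochs -> epochs 3, 7, 11, 15
assert eval_epochs(0, 16, eval_frequency=4, eval_offset=3) == [3, 7, 11, 15]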
Beispiel #16
0
    def __Train(self):
        begin_epoch = 0
        num_epoch = 1000

        eval_metric = mx.metric.np(Accuracy, allow_extra_outputs=True)
        validation_metric = None

        batch_end_callback = mx.callback.Speedometer(self.Model_Batch_Size,
                                                     frequent=10)
        epoch_end_callback = mx.callback.do_checkpoint(
            "Epoch_Saved/ocr_checkpoint", period=1)

        eval_end_callback = None
        eval_batch_end_callback = None

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, mx.metric.EvalMetric):
            eval_metric = mx.metric.create(eval_metric)

        ################################################################################
        # Training Loop                                                                #
        ################################################################################

        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()

            eval_metric.reset()
            nbatch = 0
            data_iter = iter(self.Data_Train_Iter)
            end_of_batch = False
            next_data_batch = next(data_iter)

            while not end_of_batch:
                data_batch = next_data_batch

                self.Model.forward_backward(data_batch)
                self.Model.update()

                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.Model.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                self.Model.update_metric(eval_metric, data_batch.label)

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in mx.module.base_module._as_list(
                            batch_end_callback):
                        callback(batch_end_params)

                nbatch += 1

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.Model.logger.info('Epoch[%d] Train-%s=%f', epoch, name,
                                       val)

            toc = time.time()
            self.Model.logger.info('Epoch[%d] Time cost=%.3f', epoch,
                                   (toc - tic))

            # sync aux params across devices
            arg_params, aux_params = self.Model.get_params()
            self.Model.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in mx.module.base_module._as_list(
                        epoch_end_callback):
                    callback(epoch, self.Model.symbol, arg_params, aux_params)

            # ----------------------------------------
            # evaluation on validation set
            if self.Data_Val_Iter:
                res = self.Model.score(
                    self.Data_Val_Iter,
                    validation_metric,
                    score_end_callback=eval_end_callback,
                    batch_end_callback=eval_batch_end_callback,
                    epoch=epoch)
                # TODO: pull this into default
                for name, val in res:
                    self.Model.logger.info('Epoch[%d] Validation-%s=%f', epoch,
                                           name, val)

            # end of 1 epoch, reset the data-iter for another epoch
            self.Data_Train_Iter.reset()
Beispiel #17
0
                        logging.info(msg, param.epoch, count, speed,
                                     *sum(name_value, ()))
                else:
                    logging.info(
                        "Iter[%d] Batch [%d]\tSpeed: %.2f samples/sec",
                        param.epoch, count, speed)
                self.tic = time.time()
        else:
            # print(f'[3]  init: {self.init}  {self.last_count}  {count}  {self.frequent}')
            self.init = True
            self.tic = time.time()


if __name__ == '__main__':

    import mxnet as mx
    from mxnet.model import BatchEndParam

    acc = mx.metric.Accuracy()
    metric = mx.metric.MSE()
    params = BatchEndParam(epoch=0, nbatch=1, eval_metric=acc, locals=locals())

    batch_size = 2
    acc.update(labels=mx.nd.uniform(shape=(batch_size, 512)),
               preds=mx.nd.uniform(shape=(batch_size, 512)))

    cb = Speedometer(batch_size=batch_size, frequent=1)
    cb(params)
    cb(params)
    cb(params)
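The class whose tail appears above is a Speedometer-style callback; since only its tail survives in this snippet, the following is a minimal, self-contained stand-in (hypothetical names, not the original class) showing how such a callback consumes a BatchEndParam:

import time

class SimpleSpeedometer(object):
    """Hypothetical minimal stand-in: logs samples/sec every `frequent` batches."""
    def __init__(self, batch_size, frequent=50):
        self.batch_size = batch_size
        self.frequent = frequent
        self.init = False
        self.tic = 0.0
        self.last_count = 0

    def __call__(self, param):
        count = param.nbatch
        if self.last_count > count:
            self.init = False  # a new epoch has started
        self.last_count = count
        if self.init:
            if count % self.frequent == 0:
                speed = self.frequent * self.batch_size / (time.time() - self.tic)
                if param.eval_metric is not None:
                    name_value = param.eval_metric.get_name_value()
                    print('Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec\t%s'
                          % (param.epoch, count, speed, name_value))
                else:
                    print('Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec'
                          % (param.epoch, count, speed))
                self.tic = time.time()
        else:
            self.init = True
            self.tic = time.time()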
    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01), ),
            eval_end_callback=None,
            eval_batch_end_callback=None,
            initializer=Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            validation_metric=None,
            monitor=None,
            sparse_row_id_fn=None,
            profile=False):
        """Trains the module parameters.
        Check out the `Module Tutorial <http://mxnet.io/tutorials/basic/module.html>`_ to see
        an end-to-end use-case.
        Parameters
        ----------
        train_data : DataIter
            Train DataIter.
        eval_data : DataIter
            If not ``None``, will be used as validation set and the performance
            after each epoch will be evaluated.
        eval_metric : str or EvalMetric
            Defaults to 'accuracy'. The performance measure used to display during training.
            Other possible predefined metrics are:
            'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
        epoch_end_callback : function or list of functions
            Each callback will be called with the current `epoch`, `symbol`, `arg_params`
            and `aux_params`.
        batch_end_callback : function or list of function
            Each callback will be called with a `BatchEndParam`.
        kvstore : str or KVStore
            Defaults to 'local'.
        optimizer : str or Optimizer
            Defaults to 'sgd'.
        optimizer_params : dict
            Defaults to ``(('learning_rate', 0.01),)``. The parameters for
            the optimizer constructor.
            The default value is not a dict, just to avoid pylint warning on dangerous
            default values.
        eval_end_callback : function or list of function
            These will be called at the end of each full evaluation, with the metrics over
            the entire evaluation set.
        eval_batch_end_callback : function or list of function
            These will be called at the end of each mini-batch during evaluation.
        initializer : Initializer
            The initializer is called to initialize the module parameters when they are
            not already initialized.
        arg_params : dict
            Defaults to ``None``, if not ``None``, should be existing parameters from a trained
            model or loaded from a checkpoint (previously saved model). In this case,
            the value here will be used to initialize the module parameters, unless they
            are already initialized by the user via a call to `init_params` or `fit`.
            `arg_params` has a higher priority than `initializer`.
        aux_params : dict
            Defaults to ``None``. Similar to `arg_params`, except for auxiliary states.
        allow_missing : bool
            Defaults to ``False``. Indicates whether to allow missing parameters when `arg_params`
            and `aux_params` are not ``None``. If this is ``True``, then the missing parameters
            will be initialized via the `initializer`.
        force_rebind : bool
            Defaults to ``False``. Whether to force rebinding the executors if already bound.
        force_init : bool
            Defaults to ``False``. Indicates whether to force initialization even if the
            parameters are already initialized.
        begin_epoch : int
            Defaults to 0. Indicates the starting epoch. Usually, if resumed from a
            checkpoint saved at a previous training phase at epoch N, then this value should be
            N+1.
        num_epoch : int
            Number of epochs for training.
        sparse_row_id_fn : A callback function
            The function  takes `data_batch` as an input and returns a dict of
            str -> NDArray. The resulting dict is used for pulling row_sparse
            parameters from the kvstore, where the str key is the name of the param,
            and the value is the row id of the param to pull.
        """
        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind)

        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                self.forward_backward(data_batch)
                self.update()

                if isinstance(data_batch, list):
                    self.update_metric(eval_metric,
                                       [db.label for db in data_batch],
                                       pre_sliced=True)
                else:
                    self.update_metric(eval_metric, data_batch.label)

                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch,
                                 sparse_row_id_fn=sparse_row_id_fn)
                except StopIteration:
                    end_of_batch = True

                if monitor is not None:
                    monitor.toc_print()

                if end_of_batch:
                    eval_name_vals = eval_metric.get_name_value()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

                if profile is True and nbatch == 10:
                    self.logger.info("Profiling ends")
                    import mxnet as mx
                    mx.profiler.dump()

            # one epoch of training is finished
            for name, val in eval_name_vals:
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None and self._kvstore.rank == 0:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
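All of these training loops share the same prefetch pattern: keep one batch in flight, call prepare() on the next batch while the current one trains, and use StopIteration to detect the end of the epoch. Stripped of the module details, the pattern is roughly as follows (a sketch, not library code; process_fn and prepare_fn stand in for the module calls):

# Sketch of the double-buffered iteration pattern used by the fit loops above.
def iterate_with_prefetch(data_iter, process_fn, prepare_fn=None):
    end_of_batch = False
    next_batch = next(data_iter)            # prime the pipeline
    nbatch = 0
    while not end_of_batch:
        batch = next_batch
        process_fn(batch)                   # forward/backward/update for this batch
        try:
            next_batch = next(data_iter)    # prefetch while the device is busy
            if prepare_fn is not None:
                prepare_fn(next_batch)      # e.g. module.prepare(next_batch)
        except StopIteration:
            end_of_batch = True
        nbatch += 1
    return nbatch

consumed = iterate_with_prefetch(iter(range(5)), process_fn=print)
assert consumed == 5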
Beispiel #19
0
def mlperf_fit(self,
               train_data,
               eval_data=None,
               eval_metric='acc',
               epoch_end_callback=None,
               batch_end_callback=None,
               kvstore='local',
               optimizer='sgd',
               optimizer_params=(('learning_rate', 0.01), ),
               eval_end_callback=None,
               eval_batch_end_callback=None,
               initializer=Uniform(0.01),
               arg_params=None,
               aux_params=None,
               allow_missing=False,
               force_rebind=False,
               force_init=False,
               begin_epoch=0,
               num_epoch=None,
               validation_metric=None,
               monitor=None,
               sparse_row_id_fn=None,
               eval_offset=0,
               eval_period=1,
               accuracy_threshold=1.0):

    assert num_epoch is not None, 'please specify number of epochs'

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True,
              force_rebind=force_rebind)

    if monitor is not None:
        self.install_monitor(monitor)

    self.init_params(initializer=initializer,
                     arg_params=arg_params,
                     aux_params=aux_params,
                     allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    ###########################################################################
    # Adding Correct and Total Count metrics
    ###########################################################################
    if not isinstance(validation_metric, list):
        validation_metric = [validation_metric]

    validation_metric = mx.metric.create(validation_metric)

    if not isinstance(validation_metric, mx.metric.CompositeEvalMetric):
        vm = mx.metric.CompositeEvalMetric()
        vm.append(validation_metric)
        validation_metric = vm

    for m in [CorrectCount(), TotalCount()]:
        validation_metric.metrics.append(m)
    ###########################################################################

    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    mx_resnet_print(key=mlperf_log.TRAIN_LOOP)
    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        mx_resnet_print(key=mlperf_log.TRAIN_EPOCH, val=epoch)
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()

            if isinstance(data_batch, list):
                self.update_metric(eval_metric,
                                   [db.label for db in data_batch],
                                   pre_sliced=True)
            else:
                self.update_metric(eval_metric, data_batch.label)

            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch,
                             sparse_row_id_fn=sparse_row_id_fn)
            except StopIteration:
                end_of_batch = True

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        toc = time.time()
        if kvstore:
            if kvstore.rank == 0:
                self.logger.info('Epoch[%d] Time cost=%.3f', epoch,
                                 (toc - tic))
        else:
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data and epoch % eval_period == eval_offset:
            mx_resnet_print(key=mlperf_log.EVAL_START)
            res = self.score(eval_data,
                             validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            if kvstore:
                if kvstore.rank == 0:
                    for name, val in res:
                        self.logger.info('Epoch[%d] Validation-%s=%f', epoch,
                                         name, val)
            else:
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)
            res = dict(res)

            acc = [res['correct-count'], res['total-count']]
            acc = all_reduce(acc)
            acc = acc[0] / acc[1]
            mx_resnet_print(key=mlperf_log.EVAL_ACCURACY,
                            val={
                                "epoch": epoch,
                                "value": acc
                            })
            mx_resnet_print(key=mlperf_log.EVAL_STOP)
            if acc > accuracy_threshold:
                return epoch

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()

    return num_epoch
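CorrectCount and TotalCount are not defined in this snippet; they are assumed to be Accuracy-like metrics whose get() returns the raw number of correct predictions and the raw number of evaluated samples, so that the rank-local counts can be summed across workers before dividing. A hedged sketch of what such metrics could look like:

import mxnet as mx

# Hypothetical sketch: Accuracy subclasses that expose raw counts instead of a ratio,
# so correct/total can be all-reduced across workers before computing the accuracy.
class CorrectCount(mx.metric.Accuracy):
    def __init__(self, name='correct-count', **kwargs):
        super(CorrectCount, self).__init__(name=name, **kwargs)

    def get(self):
        return (self.name, self.sum_metric)   # number of correct predictions


class TotalCount(mx.metric.Accuracy):
    def __init__(self, name='total-count', **kwargs):
        super(TotalCount, self).__init__(name=name, **kwargs)

    def get(self):
        return (self.name, self.num_inst)     # number of evaluated samples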
Beispiel #20
0
    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01), ),
            eval_end_callback=None,
            eval_batch_end_callback=None,
            initializer=Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            validation_metric=None,
            monitor=None):
        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind)
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)
        ####chris_arg
        if int(os.getenv("TASK_LIMIT",
                         0)) != 0:  #为0时不分task限制,为1时分task但是每轮更新,为2时分task并但固定
            get_task_cmd = "sh /home/ubuntu/tc.sh -l 1"
        else:
            self.logger.info("no_task_bandwidth_limit")
            get_task_cmd = "sh /home/ubuntu/tc.sh -l 0"
        os.system(get_task_cmd)
        delay_time = float(os.getenv("DELAY_TIME", 0.8))
        ps_upload_bandwidth_part1 = int(os.getenv("PS_UPLOAD_BANDWIDTH1",
                                                  2000))
        worker_upload_bandwidth_part1 = int(
            os.getenv("WORKER_UPLOAD_BANDWIDTH1", 2000))
        ps_upload_bandwidth_part2 = int(os.getenv("PS_UPLOAD_BANDWIDTH2",
                                                  2000))
        worker_upload_bandwidth_part2 = int(
            os.getenv("WORKER_UPLOAD_BANDWIDTH2", 2000))
        tc_command = "sudo tc class change dev {} parent 1: classid 1:3 htb rate {}mbit ceil {}mbit  && sudo tc class change dev {} parent 1: classid 1:4 htb rate {}mbit ceil {}mbit"
        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                self.forward(data_batch, is_train=True)
                if int(os.getenv("TASK_LIMIT", 0)) == 1:
                    ##first part bandwidth allocation
                    ndarray.waitall()
                    # self.logger.info("change bandwidth part1:, "+str(time.time()))
                    x = str(ps_upload_bandwidth_part1)
                    y = str(worker_upload_bandwidth_part1)
                    cmd_up = tc_command.format("ens3", x, x, "ens3", y, y)
                    cmd_down = tc_command.format("ifb0", y, y, "ifb0", x, x)
                    os.system(cmd_up)
                    # os.system(cmd_down)
                # self.logger.info("after forward, "+str(time.time()))
                self.backward()
                # self.logger.info("before update: "+str(time.time()))
                self.update()  # update runs asynchronously
                if int(os.getenv("TASK_LIMIT", 0)) == 1:
                    x = str(ps_upload_bandwidth_part2)
                    y = str(worker_upload_bandwidth_part2)
                    cmd_up = tc_command.format("ens3", x, x, "ens3", y, y)
                    cmd_down = tc_command.format("ifb0", y, y, "ifb0", x, x)
                    time.sleep(delay_time)
                    ##second part bandwidth allocation
                    # self.logger.info("change bandwidth part2:, "+str(time.time()))
                    os.system(cmd_up)
                    # os.system(cmd_down)
                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                self.update_metric(eval_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            #----------------------------------------
            # evaluation on validation set
            if eval_data:
                res = self.score(eval_data,
                                 validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback,
                                 epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
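The bandwidth-shaping snippet above drives `tc` through environment variables, applying one rate after the forward pass and another after the update. The helper below just shows how the two-part command is assembled from those variables; the template, device names, and env-var names are taken from the snippet, and the defaults are illustrative only.

import os

# Sketch: build the `tc class change` commands used above from the environment.
TC_TEMPLATE = ("sudo tc class change dev {} parent 1: classid 1:3 htb rate {}mbit ceil {}mbit"
               "  && sudo tc class change dev {} parent 1: classid 1:4 htb rate {}mbit ceil {}mbit")

def build_tc_commands(phase, dev='ens3', ifb_dev='ifb0'):
    """phase is 1 (after forward) or 2 (after update); returns (upload_cmd, download_cmd)."""
    ps_bw = os.getenv('PS_UPLOAD_BANDWIDTH%d' % phase, '2000')
    worker_bw = os.getenv('WORKER_UPLOAD_BANDWIDTH%d' % phase, '2000')
    cmd_up = TC_TEMPLATE.format(dev, ps_bw, ps_bw, dev, worker_bw, worker_bw)
    cmd_down = TC_TEMPLATE.format(ifb_dev, worker_bw, worker_bw, ifb_dev, ps_bw, ps_bw)
    return cmd_up, cmd_down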
Beispiel #21
0
def mlperf_fit(self,
               args,
               data_loader,
               epoch_size,
               eval_metric='acc',
               epoch_end_callback=None,
               batch_end_callback=None,
               kvstore='local',
               optimizer='sgd',
               optimizer_params=(('learning_rate', 0.01), ),
               explorer='linear',
               explorer_params=None,
               eval_end_callback=None,
               eval_batch_end_callback=None,
               initializer=Uniform(0.01),
               arg_params=None,
               aux_params=None,
               allow_missing=False,
               force_rebind=False,
               force_init=False,
               begin_epoch=0,
               num_epoch=None,
               validation_metric=None,
               monitor=None,
               sparse_row_id_fn=None,
               eval_offset=0,
               eval_period=1,
               accuracy_threshold=1.0):

    assert num_epoch is not None, 'please specify number of epochs'

    if monitor is not None:
        self.install_monitor(monitor)

    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    explorer = Explorer.create_explorer(name=explorer,
                                        optimizer=self._optimizer,
                                        explorer_params=explorer_params)
    # This MXNet cannot use several optimizers other than the SGD series
    explorer.set_best_coeff(0)
    explorer.set_best_wd_coeff(0)
    explorer.set_best_cg(0)
    exp_freq = explorer_params['explore_freq']
    exp_start_epoch = explorer_params['explore_start_epoch']

    if validation_metric is None:
        validation_metric = eval_metric
    ###########################################################################
    # Adding Correct and Total Count metrics
    ###########################################################################
    if not isinstance(validation_metric, list):
        validation_metric = [validation_metric]

    validation_metric = mx.metric.create(validation_metric)

    if not isinstance(validation_metric, mx.metric.CompositeEvalMetric):
        vm = mx.metric.CompositeEvalMetric()
        vm.append(validation_metric)
        validation_metric = vm

    for m in [CorrectCount(), TotalCount()]:
        validation_metric.metrics.append(m)
    ###########################################################################

    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    try:
        world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    except:
        world_rank = 0
        world_size = 1

    use_cval_data = (explorer_params['add_one_fwd_epoch'] < num_epoch
                     or explorer_params['no_augument_epoch'] < num_epoch)

    best_rank = 0
    self.prepare_states()

    mx_resnet_print(key=mlperf_constants.INIT_STOP, sync=True)
    mx_resnet_print(key=mlperf_constants.RUN_START, sync=True)

    # data iterators
    (train_data, eval_data, cval_data) = data_loader(args, kvstore)
    if 'dist' in args.kv_store and 'async' not in args.kv_store:
        logging.info('Resizing training data to %d batches per machine',
                     epoch_size)
        # resize train iter to ensure each machine has same number of batches per epoch
        # if not, dist_sync can hang at the end with one machine waiting for other machines
        if not args.use_dali:
            train_data = mx.io.ResizeIter(train_data, epoch_size)

    block_epoch_start = begin_epoch
    block_epoch_count = eval_offset + 1 - (begin_epoch % eval_period)
    if block_epoch_count < 0:
        block_epoch_count += eval_period
    mx_resnet_print(key=mlperf_constants.BLOCK_START,
                    metadata={
                        'first_epoch_num': block_epoch_start + 1,
                        'epoch_count': block_epoch_count
                    })
    ################################################################################
    # training loop
    ################################################################################

    for epoch in range(begin_epoch, num_epoch):
        mx_resnet_print(key=mlperf_constants.EPOCH_START,
                        metadata={'epoch_num': epoch + 1})
        tic = time.time()
        eval_metric.reset()
        nbatch = 0

        use_normal_data_batch = epoch < explorer_params['no_augument_epoch']
        if not use_normal_data_batch:
            if world_rank == 0:
                self.logger.info('use non-augmented batch')

        end_of_batch = False

        if use_normal_data_batch:
            data_iter = iter(train_data)
            next_data_batch = next(data_iter)
        else:
            cval_iter = iter(cval_data)
            next_cval_batch = next(cval_iter)

        smooth_decay = explorer_params['smooth_decay']

        if not smooth_decay:
            explorer.apply_lr_decay_epoch(epoch)
            explorer.apply_wd_decay_epoch(epoch)
        explorer.set_mom(epoch)

        while not end_of_batch:
            if use_normal_data_batch:
                data_batch = next_data_batch
            else:
                cval_batch = next_cval_batch
            if monitor is not None:
                monitor.tic()

            if use_normal_data_batch:
                self.forward_backward(data_batch)
            else:
                self.forward_backward(cval_batch)

            if smooth_decay:
                explorer.apply_lr_decay_iter()
                explorer.apply_wd_decay_iter()
            explorer.apply_wd_warmup()
            explorer.apply_burn_in()

            use_explorer = (epoch == 0
                            and nbatch == 0) or (epoch >= exp_start_epoch
                                                 and nbatch % exp_freq == 0)
            if use_explorer:
                explorer.set_tmp_coeff(world_rank)
                explorer.set_tmp_wd_coeff(world_rank)
                explorer.set_tmp_cg(world_rank)

            explorer.set_best_coeff(0)
            explorer.set_best_wd_coeff(0)
            explorer.set_best_cg(world_rank)
            self.update()

            if use_normal_data_batch:
                if isinstance(data_batch, list):
                    self.update_metric(eval_metric,
                                       [db.label for db in data_batch],
                                       pre_sliced=True)
                else:
                    self.update_metric(eval_metric, data_batch.label)
            else:
                if isinstance(cval_batch, list):
                    self.update_metric(eval_metric,
                                       [db.label for db in cval_batch],
                                       pre_sliced=True)
                else:
                    self.update_metric(eval_metric, cval_batch.label)

            if use_normal_data_batch:
                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                except StopIteration:
                    end_of_batch = True
            else:
                try:
                    # pre fetch next cval batch
                    next_cval_batch = next(cval_iter)
                except StopIteration:
                    end_of_batch = True

            if use_normal_data_batch:
                if not end_of_batch:
                    self.prepare(next_data_batch,
                                 sparse_row_id_fn=sparse_row_id_fn)
            else:
                if not end_of_batch:
                    self.prepare(next_cval_batch,
                                 sparse_row_id_fn=sparse_row_id_fn)

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        mx_resnet_print(key=mlperf_constants.EPOCH_STOP,
                        metadata={"epoch_num": epoch + 1})
        # one epoch of training is finished
        toc = time.time()
        if kvstore:
            if kvstore.rank == 0:
                self.logger.info('Epoch[%d] Time cost=%.3f', epoch,
                                 (toc - tic))
        else:
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        #arg_params, aux_params = self.get_params()
        #self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data and epoch >= eval_offset and (
                epoch - eval_offset) % eval_period == 0:
            mx_resnet_print(key=mlperf_constants.EVAL_START,
                            metadata={'epoch_num': epoch + 1})
            res = self.score(eval_data,
                             validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            if kvstore:
                if kvstore.rank == 0:
                    for name, val in res:
                        self.logger.info('Epoch[%d] Validation-%s=%f', epoch,
                                         name, val)
            else:
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)
            res = dict(res)

            acc = [res['correct-count'], res['total-count']]
            acc = all_reduce(acc)
            acc = acc[0] / acc[1]
            mx_resnet_print(key=mlperf_constants.EVAL_STOP,
                            metadata={'epoch_num': epoch + 1})

            mx_resnet_print(key=mlperf_constants.EVAL_ACCURACY,
                            val=acc,
                            metadata={'epoch_num': epoch + 1})

            mx_resnet_print(
                key=mlperf_constants.BLOCK_STOP,
                metadata={'first_epoch_num': block_epoch_start + 1})
            if acc > accuracy_threshold:
                mx_resnet_print(key=mlperf_constants.RUN_STOP,
                                metadata={'status': 'success'})

                return epoch

            if epoch < (num_epoch - 1):
                block_epoch_start = epoch + 1
                block_epoch_count = num_epoch - epoch - 1
                if block_epoch_count > eval_period:
                    block_epoch_count = eval_period
                mx_resnet_print(key=mlperf_constants.BLOCK_START,
                                metadata={
                                    'first_epoch_num': block_epoch_start + 1,
                                    'epoch_count': block_epoch_count
                                })

        # end of 1 epoch, reset the data-iter for another epoch
        if use_normal_data_batch:
            train_data.reset()
        else:
            cval_data.reset()

    mx_resnet_print(key=mlperf_constants.RUN_STOP,
                    metadata={'status': 'aborted'})
    return num_epoch
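The explorer above is triggered on a fixed schedule: always on the very first batch of training, and then every exp_freq batches once the epoch reaches exp_start_epoch. A tiny sketch of that condition, matching the use_explorer test in the loop:

# Sketch: when the explorer kicks in, following the condition used above.
def use_explorer(epoch, nbatch, exp_start_epoch, exp_freq):
    return (epoch == 0 and nbatch == 0) or (epoch >= exp_start_epoch and nbatch % exp_freq == 0)

# e.g. exp_start_epoch=2, exp_freq=100: batch 0 of epoch 0, then every 100th batch from epoch 2 on
assert use_explorer(0, 0, exp_start_epoch=2, exp_freq=100) is True
assert use_explorer(1, 100, exp_start_epoch=2, exp_freq=100) is False
assert use_explorer(3, 200, exp_start_epoch=2, exp_freq=100) is True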
Beispiel #22
0
    def fit(self, train_data, eval_data=None, eval_metric='acc',
            epoch_end_callback=None, batch_end_callback=None, kvstore='local',
            optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
            eval_end_callback=None,
            eval_batch_end_callback=None, initializer=Uniform(0.01),
            arg_params=None, aux_params=None, allow_missing=False,
            force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
            validation_metric=None, monitor=None, prefix=None,
            batches_checkpoint=None, num_batches_save_ckpt=2000):
        """Train the module parameters.

        Parameters
        ----------
        train_data : DataIter
        eval_data : DataIter
            If not `None`, will be used as validation set and evaluate the performance
            after each epoch.
        eval_metric : str or EvalMetric
            Default `'acc'`. The performance measure used to display during training.
        epoch_end_callback : function or list of function
            Each callback will be called with the current `epoch`, `symbol`, `arg_params`
            and `aux_params`.
        batch_end_callback : function or list of function
            Each callback will be called with a `BatchEndParam`.
        kvstore : str or KVStore
            Default `'local'`.
        optimizer : str or Optimizer
            Default `'sgd'`
        optimizer_params : dict
            Default `(('learning_rate', 0.01),)`. The parameters for the optimizer constructor.
            The default value is not a `dict`, just to avoid pylint warning on dangerous
            default values.
        eval_end_callback : function or list of function
            These will be called at the end of each full evaluation, with the metrics over
            the entire evaluation set.
        eval_batch_end_callback : function or list of function
            These will be called at the end of each mini-batch during evaluation.
        initializer : Initializer
            Will be called to initialize the module parameters if not already initialized.
        arg_params : dict
            Default `None`, if not `None`, should be existing parameters from a trained
            model or loaded from a checkpoint (previously saved model). In this case,
            the value here will be used to initialize the module parameters, unless they
            are already initialized by the user via a call to `init_params` or `fit`.
            `arg_params` has higher priority to `initializer`.
        aux_params : dict
            Default `None`. Similar to `arg_params`, except for auxiliary states.
        allow_missing : bool
            Default `False`. Indicate whether we allow missing parameters when `arg_params`
            and `aux_params` are not `None`. If this is `True`, then the missing parameters
            will be initialized via the `initializer`.
        force_rebind : bool
            Default `False`. Whether to force rebinding the executors if already bound.
        force_init : bool
            Default `False`. Indicate whether we should force initialization even if the
            parameters are already initialized.
        begin_epoch : int
            Default `0`. Indicate the starting epoch. Usually, if we are resuming from a
            checkpoint saved at a previous training phase at epoch N, then we should specify
            this value as N+1.
        num_epoch : int
            Number of epochs to run training.

        Examples
        --------
        An example of using fit for training::

            >>> # Assume training dataIter and validation dataIter are ready
            >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter,
            ...         optimizer_params={'learning_rate': 0.01, 'momentum': 0.9},
            ...         num_epoch=10)
        """
        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
                  for_training=True, force_rebind=force_rebind)
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                         allow_missing=allow_missing, force_init=force_init)
        self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            for nbatch, data_batch in enumerate(train_data):
                if monitor is not None:
                    monitor.tic()
                self.forward_backward(data_batch)
                self.update()
                self.update_metric(eval_metric, data_batch.label)
                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch, nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)

                if batches_checkpoint is not None and nbatch != 0 and nbatch % num_batches_save_ckpt == 0:
                    for callback in _as_list(epoch_end_callback):
                        callback(epoch, self.symbol, arg_params, aux_params)

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)
            if prefix is not None:
                self._curr_module.save_checkpoint(prefix, epoch + 1, save_optimizer_states=True)

            #----------------------------------------
            # evaluation on validation set
            if eval_data:
                res = self.score(eval_data, validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback, epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
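Besides the per-epoch checkpoint callback, this variant can also checkpoint every num_batches_save_ckpt batches inside an epoch. The trigger condition in isolation (a sketch of the test used above):

# Sketch: mid-epoch checkpoint trigger used above.
def should_checkpoint(nbatch, num_batches_save_ckpt=2000):
    return nbatch != 0 and nbatch % num_batches_save_ckpt == 0

assert [n for n in range(0, 8001, 1000) if should_checkpoint(n)] == [2000, 4000, 6000, 8000]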
Beispiel #23
0
    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01), ),
            eval_end_callback=None,
            iter_size=1,
            eval_batch_end_callback=None,
            initializer=Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            validation_metric=None,
            monitor=None):
        """Ke's revision: add iter_size. Trains the module parameters.

        Check out the `Module Tutorial <http://mxnet.io/tutorials/basic/module.html>`_ for
        an end-to-end use case.

        Parameters
        ----------
        train_data : DataIter
            Train DataIter.
        eval_data : DataIter
            If not ``None``, will be used as validation set and the performance
            after each epoch will be evaluated.
        eval_metric : str or EvalMetric
            Defaults to ``'acc'`` (accuracy). The performance measure reported during training.
            Other possible predefined metrics are:
            'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
        epoch_end_callback : function or list of functions
            Each callback will be called with the current `epoch`, `symbol`, `arg_params`
            and `aux_params`.
        batch_end_callback : function or list of functions
            Each callback will be called with a `BatchEndParam`.
        kvstore : str or KVStore
            Defaults to 'local'.
        optimizer : str or Optimizer
            Defaults to 'sgd'.
        optimizer_params : dict
            Defaults to ``(('learning_rate', 0.01),)``. The parameters for
            the optimizer constructor.
            The default value is not a dict, just to avoid pylint warning on dangerous
            default values.
        eval_end_callback : function or list of functions
            These will be called at the end of each full evaluation, with the metrics over
            the entire evaluation set.
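        iter_size : int
            Defaults to 1. Number of consecutive mini-batches whose gradients are
            accumulated (the module is bound with ``grad_req='add'``) before each
            optimizer update.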
        eval_batch_end_callback : function or list of functions
            These will be called at the end of each mini-batch during evaluation.
        initializer : Initializer
            The initializer is called to initialize the module parameters when they are
            not already initialized.
        arg_params : dict
            Defaults to ``None``, if not ``None``, should be existing parameters from a trained
            model or loaded from a checkpoint (previously saved model). In this case,
            the value here will be used to initialize the module parameters, unless they
            are already initialized by the user via a call to `init_params` or `fit`.
            `arg_params` has a higher priority than `initializer`.
        aux_params : dict
            Defaults to ``None``. Similar to `arg_params`, except for auxiliary states.
        allow_missing : bool
            Defaults to ``False``. Indicates whether to allow missing parameters when `arg_params`
            and `aux_params` are not ``None``. If this is ``True``, then the missing parameters
            will be initialized via the `initializer`.
        force_rebind : bool
            Defaults to ``False``. Whether to force rebinding the executors if already bound.
        force_init : bool
            Defaults to ``False``. Indicates whether to force initialization even if the
            parameters are already initialized.
        begin_epoch : int
            Defaults to 0. Indicates the starting epoch. Usually, if resumed from a
            checkpoint saved at a previous training phase at epoch N, then this value should be
            N+1.
        num_epoch : int
            Number of epochs for training.

        Examples
        --------
        >>> # An example of using fit for training.
        >>> # Assume training dataIter and validation dataIter are ready
        >>> # Assume loading a previously checkpointed model
        >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
        >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd',
        ...     optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
        ...     arg_params=arg_params, aux_params=aux_params,
        ...     eval_metric='acc', num_epoch=10, begin_epoch=3)
        """
        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind,
                  grad_req='add')
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)

            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                # self.forward_backward(data_batch) replaced with separate calls;
                # gradients accumulate across batches because of grad_req='add'
                self.forward(data_batch, is_train=True, grad_req='add')
                self.backward()
                if nbatch % iter_size == 0:  # update once every iter_size batches
                    self.update()
                    # grad_req='add' keeps summing gradients, so they must be
                    # zeroed manually after each optimizer update
                    for g in self._curr_module._exec_group.grad_arrays:
                        for g1 in g:
                            if g1 is not None:
                                g1[:] = 0.

                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                self.update_metric(eval_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            #----------------------------------------
            # evaluation on validation set
            if eval_data:
                res = self.score(eval_data,
                                 validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback,
                                 epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
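
The iter_size logic above is gradient accumulation: gradients from iter_size consecutive mini-batches are summed (the module is bound with grad_req='add') and the optimizer steps once per group, emulating an effective batch size of batch_size * iter_size. A minimal Gluon-style sketch of the same idea, assuming a hypothetical train_loader iterable of (data, label) batches:

    import mxnet as mx
    from mxnet import autograd, gluon

    net = gluon.nn.Dense(10)          # stand-in network, for illustration only
    net.initialize()
    for p in net.collect_params().values():
        p.grad_req = 'add'            # backward() now accumulates instead of overwriting
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.01})
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

    iter_size = 4                     # accumulate over 4 mini-batches per update
    for i, (data, label) in enumerate(train_loader):
        with autograd.record():
            loss = loss_fn(net(data), label)
        loss.backward()
        if (i + 1) % iter_size == 0:
            trainer.step(data.shape[0] * iter_size)   # one update per group of batches
            for p in net.collect_params().values():
                p.zero_grad()                          # clear the accumulated gradients

Passing the total number of samples in the group to trainer.step keeps the effective learning rate consistent with training on a single large batch.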