Example #1
def savemodel(self, epoch=999, epoch_end_callback=None):
    # sync aux params across devices
    arg_params, aux_params = self.get_params()
    self.set_params(arg_params, aux_params)
    if epoch_end_callback is not None and self._kvstore.rank == 0:
        for callback in _as_list(epoch_end_callback):
            callback(epoch, self.symbol, arg_params, aux_params)
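Every example on this page exercises MXNet's internal helper `_as_list`. For reference, a minimal sketch of the behavior these snippets rely on (in MXNet it lives in `mxnet.base`):

def _as_list(obj):
    """Wrap obj in a list unless it is already a list or tuple."""
    if isinstance(obj, (list, tuple)):
        return obj
    return [obj]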
Example #2
def decode_seq(self, inputs, states, valid_length=None):
    length = inputs.shape[1]
    output = []
    additional_outputs = []
    inputs = _as_list(mx.nd.split(inputs, num_outputs=length, axis=1, squeeze_axis=True))
    rnn_states_l = []
    attention_output_l = []
    fixed_states = states[2:]
    for i in range(length):
        ele_output, states, ele_additional_outputs = self.forward(inputs[i], states)
        rnn_states_l.append(states[0])
        attention_output_l.append(states[1])
        output.append(ele_output)
        additional_outputs.extend(ele_additional_outputs)
    output = mx.nd.stack(*output, axis=1)
    if valid_length is not None:
        states = [_nested_sequence_last(rnn_states_l, valid_length),
                  _nested_sequence_last(attention_output_l, valid_length)] + fixed_states
        output = mx.nd.SequenceMask(output,
                                    sequence_length=valid_length,
                                    use_sequence_length=True,
                                    axis=1)
    if self._output_attention:
        additional_outputs = [mx.nd.concat(*additional_outputs, dim=-2)]
    return output, states, additional_outputs
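The masking at the end of `decode_seq` relies on `mx.nd.SequenceMask`; a minimal sketch of its effect, with arbitrary shapes:

import mxnet as mx

x = mx.nd.ones((2, 3, 4))            # (batch_size, length, C)
valid_length = mx.nd.array([1, 2])   # valid steps per sample
y = mx.nd.SequenceMask(x, sequence_length=valid_length,
                       use_sequence_length=True, axis=1)
# y[0, 1:] and y[1, 2:] are zeroed; positions within each valid length are kept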
Example #3
def step5(state, state2):
    states = _as_list(state)
    states.append(state2)
    if isinstance(state, list):
        return state, states
    else:
        return [state], states
Example #4
def decode_seq(self, inputs, states, valid_length=None):
    length = inputs.shape[1]
    output = []
    additional_outputs = []
    inputs = _as_list(
        mx.nd.split(inputs, num_outputs=length, axis=1, squeeze_axis=True))
    rnn_states_l = []
    attention_output_l = []
    fixed_states = states[2:]
    for i in range(length):
        ele_output, states, ele_additional_outputs = self.forward(
            inputs[i], states)
        rnn_states_l.append(states[0])
        attention_output_l.append(states[1])
        output.append(ele_output)
        additional_outputs.extend(ele_additional_outputs)
    output = mx.nd.stack(*output, axis=1)
    if valid_length is not None:
        states = [
            _nested_sequence_last(rnn_states_l, valid_length),
            _nested_sequence_last(attention_output_l, valid_length)
        ] + fixed_states
        output = mx.nd.SequenceMask(output,
                                    sequence_length=valid_length,
                                    use_sequence_length=True,
                                    axis=1)
    if self._output_attention:
        additional_outputs = [mx.nd.concat(*additional_outputs, dim=-2)]
    return output, states, additional_outputs
Example #5
def callback(metric_list, callback_list, epoch, nbatch):
    # build a BatchEndParam per metric and notify every callback in callback_list
    for metric in metric_list:
        batch_end_params = BatchEndParam(epoch=epoch,
                                         nbatch=nbatch,
                                         eval_metric=metric,
                                         locals=locals())
        for cb in _as_list(callback_list):
            cb(batch_end_params)
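`BatchEndParam` is the namedtuple that MXNet's training loops hand to batch-end callbacks; it is defined in `mxnet.model` roughly as:

from collections import namedtuple

BatchEndParam = namedtuple('BatchEndParams',
                           ['epoch', 'nbatch', 'eval_metric', 'locals'])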
Example #6
def test_output_format_cond():
    class TestLayer1(gluon.HybridBlock):
        def __init__(self, func):
            super(TestLayer1, self).__init__()
            self.func = func

        def forward(self, data):
            def then_func(data):
                return self.func(data)

            def else_func(data):
                return self.func(data)

            return mx.npx.cond(lambda data: mx.npx.slice(data, begin=0, end=1),
                               then_func, else_func, data)

    def func1(data):
        return data

    def func2(data):
        return [data]

    def func3(data):
        return [data, data]

    funcs = [func1, func2, func3]
    data = mx.np.random.normal(loc=0, scale=1, size=(2,))
    for func in funcs:
        layer1 = TestLayer1(func)
        layer1.initialize(ctx=default_context())
        layer2 = TestLayer1(func)
        layer2.initialize(ctx=default_context())
        layer2.hybridize()
        out1 = layer1(data)
        out2 = layer2(data)
        func_out = func(data)
        assert type(out1) == type(func_out)
        assert type(out2) == type(func_out)
        out1 = _as_list(out1)
        out2 = _as_list(out2)
        for i in range(len(out1)):
            assert_almost_equal(out1[i].asnumpy(),
                                out2[i].asnumpy(),
                                rtol=0.001,
                                atol=0.0001)
Example #7
    def decode_seq(self, inputs, states, valid_length=None):
        """Decode the decoder inputs. This function is only used for training.

        Parameters
        ----------
        inputs : NDArray, Shape (batch_size, length, C_in)
        states : list of NDArrays or None
            Initial states. The list of initial decoder states
        valid_length : NDArray or None
            Valid lengths of each sequence. This is usually used when part of the sequence
            has been padded. Shape (batch_size,)

        Returns
        -------
        output : NDArray, Shape (batch_size, length, C_out)
        states : list
            The decoder states, which include:

            - rnn_states : NDArray
            - attention_vec : NDArray
            - mem_value : NDArray
            - mem_masks : NDArray, optional
        additional_outputs : list
            Either an empty list, or a list containing the attention weights of this step.
            The attention weights will have shape (batch_size, length, mem_length) or
            (batch_size, num_heads, length, mem_length).
        """
        length = inputs.shape[1]
        output = []
        additional_outputs = []
        inputs = _as_list(
            mx.nd.split(inputs, num_outputs=length, axis=1, squeeze_axis=True))
        rnn_states_l = []
        attention_output_l = []
        fixed_states = states[2:]
        for i in range(length):
            ele_output, states, ele_additional_outputs = self.forward(
                inputs[i], states)
            rnn_states_l.append(states[0])
            attention_output_l.append(states[1])
            output.append(ele_output)
            additional_outputs.extend(ele_additional_outputs)
        output = mx.nd.stack(*output, axis=1)
        if valid_length is not None:
            states = [
                _nested_sequence_last(rnn_states_l, valid_length),
                _nested_sequence_last(attention_output_l, valid_length)
            ] + fixed_states
            output = mx.nd.SequenceMask(output,
                                        sequence_length=valid_length,
                                        use_sequence_length=True,
                                        axis=1)
        if self._output_attention:
            additional_outputs = [mx.nd.concat(*additional_outputs, dim=-2)]
        return output, states, additional_outputs
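For orientation, a hypothetical call of this method; the `decoder` instance and the encoder-derived `states` are assumed, and the shapes are arbitrary:

import mxnet as mx

inputs = mx.nd.ones((4, 10, 32))            # (batch_size=4, length=10, C_in=32)
valid_length = mx.nd.array([10, 7, 10, 3])  # per-sample valid lengths
output, states, attn = decoder.decode_seq(inputs, states, valid_length)
# output: (4, 10, C_out); attn is [] unless the decoder was built to output attention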
Example #8
def step6(state, state2):
    states = _as_list(state)
    states.append(state2)
    return [], states
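`step5` and `step6` are step functions for `mx.sym.contrib.while_loop` / `mx.nd.contrib.while_loop`, which expects each step to return an `(outputs, states)` pair of lists; `_as_list` normalizes a bare state into that shape. A minimal sketch of how such a step is consumed (values are illustrative):

import mxnet as mx

def cond(i, s):
    return i < 5                  # loop while the counter is below 5

def step(i, s):
    return [s], [i + 1, s + i]    # ([outputs], [new loop states])

outputs, states = mx.nd.contrib.while_loop(
    cond, step,
    loop_vars=[mx.nd.array([0]), mx.nd.array([1])],
    max_iterations=5)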
Example #9
    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01), ),
            eval_end_callback=None,
            eval_batch_end_callback=None,
            initializer=Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            validation_metric=None,
            monitor=None,
            sparse_row_id_fn=None,
            profile=False):
        """Trains the module parameters.
        Checkout `Module Tutorial <http://mxnet.io/tutorials/basic/module.html>`_ to see
        a end-to-end use-case.
        Parameters
        ----------
        train_data : DataIter
            Train DataIter.
        eval_data : DataIter
            If not ``None``, will be used as validation set and the performance
            after each epoch will be evaluated.
        eval_metric : str or EvalMetric
            Defaults to 'accuracy'. The performance measure used to display during training.
            Other possible predefined metrics are:
            'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
        epoch_end_callback : function or list of functions
            Each callback will be called with the current `epoch`, `symbol`, `arg_params`
            and `aux_params`.
        batch_end_callback : function or list of function
            Each callback will be called with a `BatchEndParam`.
        kvstore : str or KVStore
            Defaults to 'local'.
        optimizer : str or Optimizer
            Defaults to 'sgd'.
        optimizer_params : dict
            Defaults to ``(('learning_rate', 0.01),)``. The parameters for
            the optimizer constructor.
            The default value is not a dict, just to avoid pylint warning on dangerous
            default values.
        eval_end_callback : function or list of function
            These will be called at the end of each full evaluation, with the metrics over
            the entire evaluation set.
        eval_batch_end_callback : function or list of function
            These will be called at the end of each mini-batch during evaluation.
        initializer : Initializer
            The initializer is called to initialize the module parameters when they are
            not already initialized.
        arg_params : dict
            Defaults to ``None``. If not ``None``, it should contain existing parameters from a
            trained model or loaded from a checkpoint (previously saved model). In this case,
            the value here will be used to initialize the module parameters, unless they
            are already initialized by the user via a call to `init_params` or `fit`.
            `arg_params` has a higher priority than `initializer`.
        aux_params : dict
            Defaults to ``None``. Similar to `arg_params`, except for auxiliary states.
        allow_missing : bool
            Defaults to ``False``. Indicates whether to allow missing parameters when `arg_params`
            and `aux_params` are not ``None``. If this is ``True``, then the missing parameters
            will be initialized via the `initializer`.
        force_rebind : bool
            Defaults to ``False``. Whether to force rebinding the executors if already bound.
        force_init : bool
            Defaults to ``False``. Indicates whether to force initialization even if the
            parameters are already initialized.
        begin_epoch : int
            Defaults to 0. Indicates the starting epoch. Usually, if resumed from a
            checkpoint saved at a previous training phase at epoch N, then this value should be
            N+1.
        num_epoch : int
            Number of epochs for training.
        sparse_row_id_fn : A callback function
            The function takes `data_batch` as an input and returns a dict of
            str -> NDArray. The resulting dict is used for pulling row_sparse
            parameters from the kvstore, where the str key is the name of the param,
            and the value is the row id of the param to pull.
        """
        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind)

        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                self.forward_backward(data_batch)
                self.update()

                if isinstance(data_batch, list):
                    self.update_metric(eval_metric,
                                       [db.label for db in data_batch],
                                       pre_sliced=True)
                else:
                    self.update_metric(eval_metric, data_batch.label)

                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch,
                                 sparse_row_id_fn=sparse_row_id_fn)
                except StopIteration:
                    end_of_batch = True

                if monitor is not None:
                    monitor.toc_print()

                if end_of_batch:
                    eval_name_vals = eval_metric.get_name_value()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

                if profile and nbatch == 10:
                    self.logger.info("Profiling ends")
                    import mxnet as mx
                    mx.profiler.dump()

            # one epoch of training is finished
            for name, val in eval_name_vals:
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None and self._kvstore.rank == 0:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
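A hypothetical invocation of this `fit`, assuming a constructed `Module` `mod`, ready train/validation DataIters, and a known `batch_size` (patterned on the usage shown in Example #10's docstring):

import mxnet as mx

mod.fit(train_data=train_iter,                 # assumed DataIter
        eval_data=val_iter,                    # assumed DataIter
        optimizer='sgd',
        optimizer_params={'learning_rate': 0.01, 'momentum': 0.9},
        eval_metric='acc',
        batch_end_callback=mx.callback.Speedometer(batch_size, 50),
        epoch_end_callback=mx.callback.do_checkpoint('model_prefix'),
        num_epoch=10)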
Example #10
    def fit(
            self,
            train_data,
            ogdb,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01), ),
            # e.g. ('rescale_grad', 1.0/8.0) was tried for an 8-GPU run
            eval_end_callback=None,
            iter_size=1,
            eval_batch_end_callback=None,
            initializer=Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            validation_metric=None,
            monitor=None):
        """Ke's revision: add iter_size. Trains the module parameters.

        Check out the `Module Tutorial <http://mxnet.io/tutorials/basic/module.html>`_ to see
        an end-to-end use-case.

        Parameters
        ----------
        train_data : DataIter
            Train DataIter.
        eval_data : DataIter
            If not ``None``, will be used as validation set and the performance
            after each epoch will be evaluated.
        eval_metric : str or EvalMetric
            Defaults to 'accuracy'. The performance measure used to display during training.
            Other possible predefined metrics are:
            'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
        epoch_end_callback : function or list of functions
            Each callback will be called with the current `epoch`, `symbol`, `arg_params`
            and `aux_params`.
        batch_end_callback : function or list of function
            Each callback will be called with a `BatchEndParam`.
        kvstore : str or KVStore
            Defaults to 'local'.
        optimizer : str or Optimizer
            Defaults to 'sgd'.
        optimizer_params : dict
            Defaults to ``(('learning_rate', 0.01),)``. The parameters for
            the optimizer constructor.
            The default value is not a dict, just to avoid pylint warning on dangerous
            default values.
        eval_end_callback : function or list of function
            These will be called at the end of each full evaluation, with the metrics over
            the entire evaluation set.
        eval_batch_end_callback : function or list of function
            These will be called at the end of each mini-batch during evaluation.
        initializer : Initializer
            The initializer is called to initialize the module parameters when they are
            not already initialized.
        arg_params : dict
            Defaults to ``None``. If not ``None``, it should contain existing parameters from a
            trained model or loaded from a checkpoint (previously saved model). In this case,
            the value here will be used to initialize the module parameters, unless they
            are already initialized by the user via a call to `init_params` or `fit`.
            `arg_params` has a higher priority than `initializer`.
        aux_params : dict
            Defaults to ``None``. Similar to `arg_params`, except for auxiliary states.
        allow_missing : bool
            Defaults to ``False``. Indicates whether to allow missing parameters when `arg_params`
            and `aux_params` are not ``None``. If this is ``True``, then the missing parameters
            will be initialized via the `initializer`.
        force_rebind : bool
            Defaults to ``False``. Whether to force rebinding the executors if already bound.
        force_init : bool
            Defaults to ``False``. Indicates whether to force initialization even if the
            parameters are already initialized.
        begin_epoch : int
            Defaults to 0. Indicates the starting epoch. Usually, if resumed from a
            checkpoint saved at a previous training phase at epoch N, then this value should be
            N+1.
        num_epoch : int
            Number of epochs for training.

        Examples
        --------
        >>> # An example of using fit for training.
        >>> # Assume training dataIter and validation dataIter are ready
        >>> # Assume loading a previously checkpointed model
        >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
        >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd',
        ...     optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
        ...     arg_params=arg_params, aux_params=aux_params,
        ...     eval_metric='acc', num_epoch=10, begin_epoch=3)
        """
        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind,
                  grad_req='add')
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        annealing_steps = 0  # number of current annealing steps in current epoch
        redo_training = 0  # Flag to redo training / resample
        val_list = []  # list of validation results per annealing step
        cur_val = 0
        target_prec = 50
        #Note: we want to identify the best cluster of images / training sets with a low percentage
        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            if redo_training:
                annealing_steps = annealing_steps + 1
                self.logger.info('Redoing training to meet criteria = %d',
                                 annealing_steps)
                #sroidb = train_data.roidb #passthrough test

                atick = time.time()

                iterdiff = 1.0
                # Check if we've stagnated
                if len(val_list) > 2:
                    itermean = (val_list[-1] + val_list[-2] + val_list[-3]) / 3
                    iterdiff = abs(itermean - val_list[-1])
                    self.logger.info('Last 3 samples have diff of: %f',
                                     iterdiff)

                if iterdiff < 0.01:
                    self.logger.info(
                        'Reached a stagnated annealing criteria, dumping current samples'
                    )
                    # Do something drastic
                    # Lets try to instantly use the original db
                    sroidb = ogdb

                    # Try to read in another random subset
                    #sroidb = sample_roidb(ogdb, 25) # Sample with removal
                else:
                    # Continue as usual
                    # Select a new random subset
                    newroidb = sample_roidb(ogdb, 15)  # Without removal, this is 10%

                    # Append old with new
                    sroidb = append_roidb(train_data.roidb, newroidb)

                # Create new training data instance by passing most of previous arguments and new random db
                train_data2 = AnchorLoader(
                    train_data.feat_sym,
                    sroidb,
                    train_data.batch_size,
                    train_data.shuffle,
                    train_data.ctx,
                    train_data.work_load_list,
                    train_data.feat_stride,
                    train_data.anchor_scales,
                    train_data.anchor_ratios,
                    train_data.aspect_grouping,
                    nThreads=default.prefetch_thread_num)

                # Overwrite old train_data with the new one
                train_data = train_data2
                data_iter = iter(train_data)

                atock = time.time()
                self.logger.info('Annealing[%d] Time cost=%.3f',
                                 annealing_steps, (atock - atick))
            else:
                data_iter = iter(train_data)
                annealing_steps = 0
                val_list = []
                #target_prec=cur_val+5
                target_prec = target_prec + 5
            end_of_batch = False
            next_data_batch = next(data_iter)

            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                # self.forward_backward(data_batch)
                self.forward(data_batch, is_train=True, grad_req='add')
                self.backward()
                if nbatch % iter_size == 0:  # update every iter_size batches
                    self.update()
                    for g in self._curr_module._exec_group.grad_arrays:
                        for g1 in g:
                            if g1 is not None:
                                g1[:] = 0.

                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                self.update_metric(eval_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))
            #print('Epoch[%d] Time cost=%.3f', epoch, (toc-tic))
            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    cur_val = callback(epoch, self.symbol, arg_params,
                                       aux_params)

            self.logger.info('Returned Validation=%f', cur_val)
            val_list.append(cur_val)
            #----------------------------------------
            # evaluation on validation set
            if eval_data:
                self.logger.info('Evaluating data')
                res = self.score(eval_data,
                                 validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback,
                                 epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            #----------
            # Check epoch if it falls within the validation threshold
            if cur_val < target_prec:
                # Evaluate list of precision/validation results first
                #val_list
                print(eval_data)

                #else
                redo_training = 1
                self.logger.info('Retraining data=%f', val)
            else:
                redo_training = 0

            self.logger.info('Annealing steps=%f', annealing_steps)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
Example #11
def fit(self,
        train_data,
        eval_data=None,
        eval_metric='acc',
        epoch_end_callback=None,
        batch_end_callback=None,
        kvstore='local',
        optimizer='sgd',
        optimizer_params=(('learning_rate', 0.01), ),
        eval_end_callback=None,
        eval_batch_end_callback=None,
        initializer=Uniform(0.01),
        arg_params=None,
        aux_params=None,
        allow_missing=False,
        force_rebind=False,
        force_init=False,
        begin_epoch=0,
        num_epoch=None,
        validation_metric=None,
        monitor=None,
        sparse_row_id_fn=None,
        accuracy_target=1.0,
        eval_frequency=1,
        eval_offset=0,
        logger=None):
    assert num_epoch is not None, 'please specify number of epochs'

    if 'horovod' in kvstore:
        rank = hvd.rank()
        local_rank = hvd.local_rank()
    else:
        rank = 0
        local_rank = 0

    profiler_on = os.getenv('RESNET50_PROFILING', False) and (rank == 0)
    if profiler_on:
        self.logger.info("Profiling is enabled")

    stop_iter = int(os.getenv('RESNET50_STOP_ITERATION', '0'))
    if stop_iter > 0:
        self.logger.info(
            "Training will stop at iteration {} of the first epoch".format(
                stop_iter))

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True,
              force_rebind=force_rebind)
    if monitor is not None:
        self.install_monitor(monitor)
    self.init_params(initializer=initializer,
                     arg_params=arg_params,
                     aux_params=aux_params,
                     allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    block_epoch_start = begin_epoch
    block_epoch_count = eval_offset + 1 - (begin_epoch % eval_frequency)
    if block_epoch_count < 0:
        block_epoch_count += eval_frequency
    mll.block_start(block_epoch_start + 1, count=block_epoch_count)

    if profiler_on:
        mx.profiler.set_config(profile_symbolic=True,
                               profile_imperative=True,
                               profile_memory=False,
                               profile_api=True,
                               filename='resnet50_profile.json',
                               aggregate_stats=True)
        mx.profiler.set_state('run')

    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        mll.epoch_start(epoch + 1)
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        early_stop = False
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()

            if isinstance(data_batch, list):
                self.update_metric(eval_metric,
                                   [db.label for db in data_batch],
                                   pre_sliced=True)
            else:
                self.update_metric(eval_metric, data_batch.label)

            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch,
                             sparse_row_id_fn=sparse_row_id_fn)
            except StopIteration:
                end_of_batch = True

            if monitor is not None:
                monitor.toc_print()

            if end_of_batch:
                eval_name_vals = eval_metric.get_global_name_value()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1
            if stop_iter > 0 and nbatch >= stop_iter:
                early_stop = True
                self.logger.info(
                    "Training stopped at iteration {}. Clear RESNET50_STOP_ITERATION if this is not intended."
                    .format(stop_iter))
                break

        if early_stop:
            break

        mll.epoch_stop(epoch + 1)
        # one epoch of training is finished
        if rank == 0:
            for name, val in eval_name_vals:
                self.logger.info('Rank[%d] Epoch[%d] Train-%s=%f', rank, epoch,
                                 name, val)
            toc = time.time()
            self.logger.info('Rank[%d] Epoch[%d] Time cost=%.3f', rank, epoch,
                             (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data is not None and ((epoch % eval_frequency == eval_offset)
                                      or (epoch + 1 == num_epoch)):
            mll.eval_start(epoch + 1, sync=True)
            res = self.score(eval_data,
                             [validation_metric,
                              CorrectCount(),
                              TotalCount()],
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            if rank == 0:
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            # temporarily add these two metrics for debugging, can be removed before submission
            res = dict(res)
            correct_count = res['correct-count']
            total_count = res['total-count']
            if 'horovod' in kvstore:
                correct_count = allreduce(correct_count)
                total_count = allreduce(total_count)

            acc = correct_count / total_count
            mll.eval_stop(epoch + 1)
            mll.eval_accuracy(epoch + 1, acc)

            mll.block_stop(block_epoch_start + 1)
            if acc > accuracy_target:
                mll.run_stop(status='success')
                return

            if epoch < num_epoch - 1:
                block_epoch_start = epoch + 1
                block_epoch_count = num_epoch - epoch - 1
                if block_epoch_count > eval_frequency:
                    block_epoch_count = eval_frequency
                mll.block_start(block_epoch_start + 1, count=block_epoch_count)

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()

    if profiler_on:
        mx.profiler.set_state('stop')
        print(mx.profiler.dumps())

    mll.run_stop(status='aborted')
Example #12
    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01), ),
            eval_end_callback=None,
            eval_batch_end_callback=None,
            initializer=Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            best_model_callbacks=None,
            eval_interval=None,
            validation_metric=None,
            monitor=None):

        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind)

        if monitor is not None:
            self.install_monitor(monitor)

        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)

        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        if validation_metric is None:
            validation_metric = copy.deepcopy(eval_metric)

        epoch_metric = copy.deepcopy(eval_metric)

        swa_arg_params = None
        swa_aux_params = None
        swa_cnt = 0

        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic_epoch = time.time()
            eval_metric.reset()

            nbatch = 0
            end_of_batch = False
            data_iter = iter(train_data)
            next_data_batch = next(data_iter)
            name_values = []

            while not end_of_batch:
                data_batch = next_data_batch

                if monitor is not None:
                    monitor.tic()

                self.forward_backward(data_batch)
                self.update()

                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                self.update_metric(eval_metric, data_batch.label)
                if end_of_batch:
                    name_values = eval_metric.get_name_value()

                if monitor is not None:
                    monitor.toc_print()

                nbatch += 1

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)

                    eval_metric.reset()

                # ----------------------------------------
                # evaluation on validation set
                to_go = eval_interval is not None and nbatch % eval_interval == 0
                if to_go and eval_data:
                    res = self.score(
                        eval_data,
                        validation_metric,
                        score_end_callback=eval_end_callback,
                        batch_end_callback=eval_batch_end_callback,
                        epoch=epoch)
                    for name, val in res:
                        self.logger.info(
                            'Epoch[%d] Batch[%d] Validation-%s=%f', epoch,
                            nbatch, name, val)

                    if best_model_callbacks is not None:
                        for callback in _as_list(best_model_callbacks):
                            if callback.is_best(validation_metric):
                                # sync aux params across devices
                                arg_params, aux_params = self.get_params()
                                sync_made = True
                                callback.checkpoint_if_only_best(
                                    validation_metric, self.symbol, arg_params,
                                    aux_params)
                                break

            # one epoch of training is finished
            for name, val in name_values:
                self.logger.info('Epoch[%d] Train-%s=%f', epoch + 1, name, val)
            toc_epoch = time.time()
            elapsed = (toc_epoch - tic_epoch)
            avg_speed = float(len(train_data)) / elapsed
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch + 1, elapsed)
            self.logger.info('Epoch[%d] Average speed=%.3f samples/sec',
                             epoch + 1, avg_speed)

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            # evaluation on validation set
            if eval_data:
                res = self.score(eval_data,
                                 validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback,
                                 epoch=epoch + 1)
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch + 1,
                                     name, val)

                if best_model_callbacks is not None:
                    for callback in _as_list(best_model_callbacks):
                        callback.checkpoint_if_only_best(
                            validation_metric, self.symbol, arg_params,
                            aux_params)

            # end of epoch, reset the data-iter for another epoch
            train_data.reset()
Example #13
def callback_for_metric(name_value, callback_list):
    for callback in _as_list(callback_list):
        callback(name_value)
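Thanks to `_as_list`, this helper accepts either a single callable or a list of callables, e.g.:

callback_for_metric(('acc', 0.93), print)           # single callable
callback_for_metric(('acc', 0.93), [print, print])  # list of callables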
Example #14
def test_while_loop_rnn():
    def _array(shape):
        return mx.nd.random.uniform(-1.0, 1.0, shape=shape)

    cell_types = [mx.rnn.LSTMCell]
    num_params = [2]

    batch_size = 2
    hidden_dim = 3
    input_dim = 4
    seq_len = 3

    for cell, n_param in zip(cell_types, num_params):
        # using while_loop
        params = mx.rnn.RNNParams()
        data = mx.sym.var("data")
        iter_i = mx.sym.var("i")

        def _cond(*states):
            i = states[0]
            return i < seq_len

        def _func(*states):
            i = states[0]
            states = states[1:]
            in_ = data.take(i).squeeze(axis=0)
            rnn = cell(hidden_dim, prefix='', params=params)
            next_hidden, next_states = rnn(in_, states)
            return [next_hidden], [i + 1] + list(next_states)

        states = [mx.sym.var("s_" + str(i)) for i in range(n_param)]
        result = mx.sym.contrib.while_loop(cond=_cond,
                                           func=_func,
                                           loop_vars=[iter_i] + states,
                                           max_iterations=seq_len)
        result = mx.sym.Group(result[0] + result[1][1:])
        arg_shapes, _, _ = result.infer_shape(
            data=(seq_len, batch_size, input_dim),
            s_0=(batch_size, hidden_dim),
        )
        rnn_inputs = result.list_inputs()
        args = {
            name: _array(arg_shapes[i])
            for i, name in enumerate(rnn_inputs) if name != "i"
        }
        args["i"] = mx.nd.zeros([1])
        args_grad = {
            name: _array(arg_shapes[i])
            for i, name in enumerate(rnn_inputs)
        }
        e_1 = result.bind(
            ctx=default_context(),
            args={name: array.copy()
                  for name, array in args.items()},
            args_grad={
                name: array.copy()
                for name, array in args_grad.items() if name != "i"
            },
        )
        # using unrolled rnn
        rnn = cell(hidden_dim, prefix='')
        unroll_outs = []
        for inputs in mx.sym.split(data,
                                   num_outputs=seq_len,
                                   axis=0,
                                   squeeze_axis=True):
            h, states = rnn(inputs, states)
            unroll_outs.append(mx.sym.expand_dims(h, axis=0))
        unroll_outs = _as_list(mx.sym.concat(*unroll_outs, dim=0))
        unroll_outs.extend(states)
        result = mx.sym.Group(unroll_outs)
        e_2 = result.bind(
            ctx=default_context(),
            args={
                name: array.copy()
                for name, array in args.items() if name != "i"
            },
            args_grad={
                name: array.copy()
                for name, array in args_grad.items() if name != "i"
            },
        )
        for case_id in range(100):
            out_grads = [_array(arr.shape) for arr in e_1.outputs]
            args = {name: array.copy() for name, array in args.items()}
            e_1.forward(is_train=True, **args)
            e_1.backward(out_grads)
            args = {
                name: array.copy()
                for name, array in args.items() if name != "i"
            }
            e_2.forward(is_train=True, **args)
            e_2.backward(out_grads)
            assert len(e_1.outputs) == len(e_2.outputs)
            for x, y in zip(e_1.outputs, e_2.outputs):
                x = x.asnumpy()
                y = y.asnumpy()
                assert_almost_equal(x, y, rtol=1e-4, atol=1e-4)
            grad_keys = list(e_2.grad_dict.keys())
            e_1_grad = [e_1.grad_dict[x] for x in grad_keys]
            e_2_grad = [e_2.grad_dict[x] for x in grad_keys]
            for x, y in zip(e_1_grad, e_2_grad):
                x = x.asnumpy()
                y = y.asnumpy()
                assert_almost_equal(x, y, rtol=1e-4, atol=1e-4)
Example #15
def mlperf_fit(self,
               args,
               data_loader,
               epoch_size,
               eval_metric='acc',
               epoch_end_callback=None,
               batch_end_callback=None,
               kvstore='local',
               optimizer='sgd',
               optimizer_params=(('learning_rate', 0.01), ),
               explorer='linear',
               explorer_params=None,
               eval_end_callback=None,
               eval_batch_end_callback=None,
               initializer=Uniform(0.01),
               arg_params=None,
               aux_params=None,
               allow_missing=False,
               force_rebind=False,
               force_init=False,
               begin_epoch=0,
               num_epoch=None,
               validation_metric=None,
               monitor=None,
               sparse_row_id_fn=None,
               eval_offset=0,
               eval_period=1,
               accuracy_threshold=1.0):

    assert num_epoch is not None, 'please specify number of epochs'

    if monitor is not None:
        self.install_monitor(monitor)

    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    explorer = Explorer.create_explorer(name=explorer,
                                        optimizer=self._optimizer,
                                        explorer_params=explorer_params)
    # This MXNet build cannot use optimizers other than the SGD series here
    explorer.set_best_coeff(0)
    explorer.set_best_wd_coeff(0)
    explorer.set_best_cg(0)
    exp_freq = explorer_params['explore_freq']
    exp_start_epoch = explorer_params['explore_start_epoch']

    if validation_metric is None:
        validation_metric = eval_metric
    ###########################################################################
    # Adding Correct and Total Count metrics
    ###########################################################################
    if not isinstance(validation_metric, list):
        validation_metric = [validation_metric]

    validation_metric = mx.metric.create(validation_metric)

    if not isinstance(validation_metric, mx.metric.CompositeEvalMetric):
        vm = mx.metric.CompositeEvalMetric()
        vm.append(validation_metric)
        validation_metric = vm

    for m in [CorrectCount(), TotalCount()]:
        validation_metric.metrics.append(m)
    ###########################################################################

    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    try:
        world_rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
    except (KeyError, ValueError):
        world_rank = 0
        world_size = 1

    use_cval_data = (explorer_params['add_one_fwd_epoch'] < num_epoch
                     or explorer_params['no_augument_epoch'] < num_epoch)

    best_rank = 0
    self.prepare_states()

    mx_resnet_print(key=mlperf_constants.INIT_STOP, sync=True)
    mx_resnet_print(key=mlperf_constants.RUN_START, sync=True)

    # data iterators
    (train_data, eval_data, cval_data) = data_loader(args, kvstore)
    if 'dist' in args.kv_store and 'async' not in args.kv_store:
        logging.info('Resizing training data to %d batches per machine',
                     epoch_size)
        # resize train iter to ensure each machine has same number of batches per epoch
        # if not, dist_sync can hang at the end with one machine waiting for other machines
        if not args.use_dali:
            train_data = mx.io.ResizeIter(train_data, epoch_size)

    block_epoch_start = begin_epoch
    block_epoch_count = eval_offset + 1 - (begin_epoch % eval_period)
    if block_epoch_count < 0:
        block_epoch_count += eval_period
    mx_resnet_print(key=mlperf_constants.BLOCK_START,
                    metadata={
                        'first_epoch_num': block_epoch_start + 1,
                        'epoch_count': block_epoch_count
                    })
    ################################################################################
    # training loop
    ################################################################################

    for epoch in range(begin_epoch, num_epoch):
        mx_resnet_print(key=mlperf_constants.EPOCH_START,
                        metadata={'epoch_num': epoch + 1})
        tic = time.time()
        eval_metric.reset()
        nbatch = 0

        use_normal_data_batch = epoch < explorer_params['no_augument_epoch']
        if not use_normal_data_batch:
            if world_rank == 0:
                self.logger.info('using non-augmented batch')

        end_of_batch = False

        if use_normal_data_batch:
            data_iter = iter(train_data)
            next_data_batch = next(data_iter)
        else:
            cval_iter = iter(cval_data)
            next_cval_batch = next(cval_iter)

        smooth_decay = explorer_params['smooth_decay']

        if not smooth_decay:
            explorer.apply_lr_decay_epoch(epoch)
            explorer.apply_wd_decay_epoch(epoch)
        explorer.set_mom(epoch)

        while not end_of_batch:
            if use_normal_data_batch:
                data_batch = next_data_batch
            else:
                cval_batch = next_cval_batch
            if monitor is not None:
                monitor.tic()

            if use_normal_data_batch:
                self.forward_backward(data_batch)
            else:
                self.forward_backward(cval_batch)

            if smooth_decay:
                explorer.apply_lr_decay_iter()
                explorer.apply_wd_decay_iter()
            explorer.apply_wd_warmup()
            explorer.apply_burn_in()

            use_explorer = (epoch == 0
                            and nbatch == 0) or (epoch >= exp_start_epoch
                                                 and nbatch % exp_freq == 0)
            if use_explorer:
                explorer.set_tmp_coeff(world_rank)
                explorer.set_tmp_wd_coeff(world_rank)
                explorer.set_tmp_cg(world_rank)

            explorer.set_best_coeff(0)
            explorer.set_best_wd_coeff(0)
            explorer.set_best_cg(world_rank)
            self.update()

            if use_normal_data_batch:
                if isinstance(data_batch, list):
                    self.update_metric(eval_metric,
                                       [db.label for db in data_batch],
                                       pre_sliced=True)
                else:
                    self.update_metric(eval_metric, data_batch.label)
            else:
                if isinstance(cval_batch, list):
                    self.update_metric(eval_metric,
                                       [db.label for db in cval_batch],
                                       pre_sliced=True)
                else:
                    self.update_metric(eval_metric, cval_batch.label)

            if use_normal_data_batch:
                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                except StopIteration:
                    end_of_batch = True
            else:
                try:
                    # pre fetch next cval batch
                    next_cval_batch = next(cval_iter)
                except StopIteration:
                    end_of_batch = True

            if use_normal_data_batch:
                if not end_of_batch:
                    self.prepare(next_data_batch,
                                 sparse_row_id_fn=sparse_row_id_fn)
            else:
                if not end_of_batch:
                    self.prepare(next_cval_batch,
                                 sparse_row_id_fn=sparse_row_id_fn)

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        mx_resnet_print(key=mlperf_constants.EPOCH_STOP,
                        metadata={"epoch_num": epoch + 1})
        # one epoch of training is finished
        toc = time.time()
        if kvstore:
            if kvstore.rank == 0:
                self.logger.info('Epoch[%d] Time cost=%.3f', epoch,
                                 (toc - tic))
        else:
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        #arg_params, aux_params = self.get_params()
        #self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data and epoch >= eval_offset and (
                epoch - eval_offset) % eval_period == 0:
            mx_resnet_print(key=mlperf_constants.EVAL_START,
                            metadata={'epoch_num': epoch + 1})
            res = self.score(eval_data,
                             validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            if kvstore:
                if kvstore.rank == 0:
                    for name, val in res:
                        self.logger.info('Epoch[%d] Validation-%s=%f', epoch,
                                         name, val)
            else:
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)
            res = dict(res)

            acc = [res['correct-count'], res['total-count']]
            acc = all_reduce(acc)
            acc = acc[0] / acc[1]
            mx_resnet_print(key=mlperf_constants.EVAL_STOP,
                            metadata={'epoch_num': epoch + 1})

            mx_resnet_print(key=mlperf_constants.EVAL_ACCURACY,
                            val=acc,
                            metadata={'epoch_num': epoch + 1})

            mx_resnet_print(
                key=mlperf_constants.BLOCK_STOP,
                metadata={'first_epoch_num': block_epoch_start + 1})
            if acc > accuracy_threshold:
                mx_resnet_print(key=mlperf_constants.RUN_STOP,
                                metadata={'status': 'success'})

                return epoch

            if epoch < (num_epoch - 1):
                block_epoch_start = epoch + 1
                block_epoch_count = num_epoch - epoch - 1
                if block_epoch_count > eval_period:
                    block_epoch_count = eval_period
                mx_resnet_print(key=mlperf_constants.BLOCK_START,
                                metadata={
                                    'first_epoch_num': block_epoch_start + 1,
                                    'epoch_count': block_epoch_count
                                })

        # end of 1 epoch, reset the data-iter for another epoch
        if use_normal_data_batch:
            train_data.reset()
        else:
            cval_data.reset()

    mx_resnet_print(key=mlperf_constants.RUN_STOP,
                    metadata={'status': 'aborted'})
    return num_epoch
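
The validation accuracy above is aggregated across workers before the threshold check: each rank contributes its local [correct-count, total-count] pair, the pairs are summed, and only then is the ratio taken. A minimal sketch of that aggregation, assuming an MPI-style sum all-reduce (the actual all_reduce helper is supplied by the surrounding training harness):

import numpy as np
from mpi4py import MPI

def all_reduce(values):
    # Sum the [correct, total] pair element-wise across all workers.
    buf = np.asarray(values, dtype=np.float64)
    out = np.empty_like(buf)
    MPI.COMM_WORLD.Allreduce(buf, out, op=MPI.SUM)
    return out

# Each worker passes its local counts; dividing after the reduce gives
# the global accuracy, matching acc = acc[0] / acc[1] above.
local_correct, local_total = 912.0, 1000.0
global_correct, global_total = all_reduce([local_correct, local_total])
accuracy = global_correct / global_total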
Example #16
    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01), ),
            eval_end_callback=None,
            eval_batch_end_callback=None,
            initializer=mx.initializer.Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            validation_metric=None,
            monitor=None):

        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind)
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, mx.metric.EvalMetric):
            eval_metric = mx.metric.create(eval_metric)

        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                self.forward_backward(data_batch)
                self.update()

                self.update_metric(eval_metric, data_batch.label)

                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = mx.model.BatchEndParam(
                        epoch=epoch,
                        nbatch=nbatch,
                        eval_metric=eval_metric,
                        locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            #----------------------------------------
            # evaluation on validation set
            if eval_data:
                res = self.score(eval_data,
                                 validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback,
                                 epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
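
All of these fit variants share the same double-buffered input loop: the next batch is fetched (and prepare()d) while the current batch is still being consumed, and StopIteration marks the end of the epoch. A stripped-down sketch of just that skeleton, with a hypothetical module object:

def train_one_epoch(module, train_data):
    data_iter = iter(train_data)
    end_of_batch = False
    next_data_batch = next(data_iter)
    while not end_of_batch:
        data_batch = next_data_batch
        module.forward_backward(data_batch)
        module.update()
        try:
            # prefetch one step ahead so prepare() overlaps with compute
            next_data_batch = next(data_iter)
            module.prepare(next_data_batch)
        except StopIteration:
            end_of_batch = True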
Example #17
    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01), ),
            eval_end_callback=None,
            eval_batch_end_callback=None,
            initializer=Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            validation_metric=None,
            monitor=None):

        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind)
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        ################################################################################
        # training loop
        ################################################################################

        last_grad_debug = None

        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                self.forward_backward(data_batch)

                # grad_array = [[grad.copyto(grad.context) if grad is not None else None for grad in grads] for grads in
                #             self._curr_module._exec_group.grad_arrays]
                #
                # for exec_ in self._curr_module._exec_group.execs:
                #     grad_dict = exec_.grad_dict
                #
                # grad_debug = dict()
                # for k, v in grad_dict.items():
                #     if v is not None:
                #         v_np = v.asnumpy()
                #         grad_debug[k] = (np.min(v_np), np.max(v_np))
                # print('rpn_conv_cls_weight:', grad_debug['rpn_conv_cls_weight'])
                # print('rcnn_fc_cls_weight:', grad_debug['rcnn_fc_cls_weight'])

                self.update()
                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                self.update_metric(eval_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            #----------------------------------------
            # evaluation on validation set
            if eval_data:
                res = self.score(eval_data,
                                 validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback,
                                 epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
Example #18
def callback_for_checkpoint(epoch, params, epoch_end_callback):
    symbol, arg_params, aux_params = params
    for callback in _as_list(epoch_end_callback):
        callback(epoch, symbol, arg_params, aux_params)
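
A small usage sketch for the helper above (the callback and params are hypothetical, and epoch is passed in explicitly): because the helper normalizes epoch_end_callback with _as_list, a single callable and a list of callables are both accepted.

def log_epoch(epoch, symbol, arg_params, aux_params):
    print('finished epoch', epoch)

params = (None, {}, {})  # (symbol, arg_params, aux_params) placeholders
callback_for_checkpoint(3, params, log_epoch)               # single callable
callback_for_checkpoint(3, params, [log_epoch, log_epoch])  # list of callables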
Example #19
def step4(state, state2):
    states = _as_list(state)
    states.append(state2)
    return state, states
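
Every example on this page leans on the same normalization: _as_list leaves lists (and tuples) alone and wraps anything else in a singleton list. A minimal sketch of that behavior, mirroring mxnet.base._as_list:

def _as_list(obj):
    # Lists/tuples pass through as lists; anything else becomes [obj].
    if isinstance(obj, (list, tuple)):
        return list(obj)
    return [obj]

assert _as_list(1) == [1]
assert _as_list([1, 2]) == [1, 2]
assert _as_list((1, 2)) == [1, 2]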
Example #20
def test_output_format_while():
    class TestLayer1(gluon.HybridBlock):
        def __init__(self, step, use_list, nested_list=False):
            super(TestLayer1, self).__init__()
            self.step = step
            self.use_list = use_list
            self.nested_list = nested_list

        def forward(self, states):
            def cond(state1):
                scalar = mx.npx.slice(state1, begin=0, end=1)
                return scalar == scalar

            cond_func = cond
            if self.use_list:
                states = [states]
            elif self.nested_list:

                def cond2(state1, state2):
                    scalar = mx.npx.slice(state1, begin=0, end=1)
                    return scalar == scalar

                cond_func = cond2
                states = [states, [states + 1]]
            out, states = mx.npx.while_loop(cond_func,
                                            self.step,
                                            states,
                                            max_iterations=5)
            return out, states

    def step1(state):
        return state, state

    def step2(state):
        if isinstance(state, list):
            return state, state
        else:
            return [state], state

    def step3(state):
        return [], state

    steps = [step1, step2, step3]
    state = mx.np.random.normal(loc=0, scale=1, size=(2,))
    for step in steps:
        layer1 = TestLayer1(step, False)
        layer1.initialize(ctx=default_context())
        layer2 = TestLayer1(step, False)
        layer2.initialize(ctx=default_context())
        layer2.hybridize()
        out1, state1 = layer1(state)
        out2, state2 = layer2(state)
        assert type(out1) == type(out2)
        assert type(state1) == type(state2)
        out1 = _as_list(out1)
        out2 = _as_list(out2)
        state1 = _as_list(state1)
        state2 = _as_list(state2)
        for i in range(len(out1)):
            assert_almost_equal(out1[i].asnumpy(),
                                out2[i].asnumpy(),
                                rtol=0.001,
                                atol=0.0001)
        for i in range(len(state1)):
            assert_almost_equal(state1[i].asnumpy(),
                                state2[i].asnumpy(),
                                rtol=0.001,
                                atol=0.0001)

        layer1 = TestLayer1(step, True)
        layer1.initialize(ctx=default_context())
        layer2 = TestLayer1(step, True)
        layer2.initialize(ctx=default_context())
        layer2.hybridize()
        out1, state1 = layer1(state)
        out2, state2 = layer2(state)
        assert type(out1) == type(out2)
        assert type(state1) == type(state2)
        out1 = _as_list(out1)
        out2 = _as_list(out2)
        state1 = _as_list(state1)
        state2 = _as_list(state2)
        for i in range(len(out1)):
            assert_almost_equal(out1[i].asnumpy(),
                                out2[i].asnumpy(),
                                rtol=0.001,
                                atol=0.0001)
        for i in range(len(state1)):
            assert_almost_equal(state1[i].asnumpy(),
                                state2[i].asnumpy(),
                                rtol=0.001,
                                atol=0.0001)

    def step4(state, state2):
        states = _as_list(state)
        states.append(state2)
        return state, states

    def step5(state, state2):
        states = _as_list(state)
        states.append(state2)
        if isinstance(state, list):
            return state, states
        else:
            return [state], states

    def step6(state, state2):
        states = _as_list(state)
        states.append(state2)
        return [], states

    steps = [step4, step5, step6]
    for step in steps:
        layer1 = TestLayer1(step, False, True)
        layer1.initialize(ctx=default_context())
        layer2 = TestLayer1(step, False, True)
        layer2.initialize(ctx=default_context())
        layer2.hybridize()
        out1, state1 = layer1(state)
        out2, state2 = layer2(state)
        assert type(out1) == type(out2)
        assert type(state1) == type(state2)
        out1 = _as_list(out1)
        out2 = _as_list(out2)
        state1 = _as_list(state1)
        state2 = _as_list(state2)
        for i in range(len(out1)):
            assert_almost_equal(out1[i].asnumpy(),
                                out2[i].asnumpy(),
                                rtol=0.001,
                                atol=0.0001)
        for i in range(len(state1)):
            if not isinstance(state1[i], list):
                assert_almost_equal(state1[i].asnumpy(),
                                    state2[i].asnumpy(),
                                    rtol=0.001,
                                    atol=0.0001)
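
The test above exercises the mx.npx.while_loop contract: cond(*loop_vars) returns a scalar truth value and step(*loop_vars) returns (step_output, new_loop_vars); since the test's cond returns scalar == scalar (always true), its loops always run for the full max_iterations. A hedged standalone sketch with a condition that actually terminates:

import mxnet as mx

def cond(i, total):
    return i < 5                           # stop once the counter hits 5

def body(i, total):
    return total + i, [i + 1, total + i]   # (step output, new loop vars)

i0 = mx.np.array([0.0])
t0 = mx.np.array([0.0])
outs, final_vars = mx.npx.while_loop(cond, body, [i0, t0], max_iterations=10)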
Example #21
def test_output_format_foreach():
    class TestLayer1(gluon.HybridBlock):
        def __init__(self, step):
            super(TestLayer1, self).__init__()
            self.step = step

        def forward(self, ins, states):
            out, states = mx.npx.foreach(self.step, ins, states)
            return out, states

    def step1(data, state):
        return data, state

    def step2(data, state):
        return [data], state

    def step3(data, state):
        if isinstance(state, list):
            return [], [state[0] + data]
        else:
            return [], state + data

    def step4(data, state):
        if isinstance(state, list):
            return [data, state[0]], state
        else:
            return [data, state], state

    steps = [step1, step2, step3, step4]
    data = mx.np.random.normal(loc=0, scale=1, size=(10, 2))
    state = mx.np.random.normal(loc=0, scale=1, size=(2,))
    for step in steps:
        layer1 = TestLayer1(step)
        layer1.initialize(ctx=default_context())
        layer2 = TestLayer1(step)
        layer2.initialize(ctx=default_context())
        layer2.hybridize()
        out1, state1 = layer1(data, [state])
        out2, state2 = layer2(data, [state])
        step_out, step_state = step(data, [state])
        assert type(out1) == type(step_out)
        assert type(out2) == type(step_out)
        assert type(state1) == type(step_state)
        assert type(state2) == type(step_state)
        out1 = _as_list(out1)
        out2 = _as_list(out2)
        state1 = _as_list(state1)
        state2 = _as_list(state2)
        for i in range(len(out1)):
            assert_almost_equal(out1[i].asnumpy(),
                                out2[i].asnumpy(),
                                rtol=0.001,
                                atol=0.0001)
        for i in range(len(state1)):
            assert_almost_equal(state1[i].asnumpy(),
                                state2[i].asnumpy(),
                                rtol=0.001,
                                atol=0.0001)

        layer1 = TestLayer1(step)
        layer1.initialize(ctx=default_context())
        layer2 = TestLayer1(step)
        layer2.initialize(ctx=default_context())
        layer2.hybridize()
        out1, state1 = layer1(data, state)
        out2, state2 = layer2(data, state)
        step_out, step_state = step(data, state)
        assert type(out1) == type(step_out)
        assert type(out2) == type(step_out)
        assert type(state1) == type(step_state)
        assert type(state2) == type(step_state)
        out1 = _as_list(out1)
        out2 = _as_list(out2)
        state1 = _as_list(state1)
        state2 = _as_list(state2)
        for i in range(len(out1)):
            assert_almost_equal(out1[i].asnumpy(),
                                out2[i].asnumpy(),
                                rtol=0.001,
                                atol=0.0001)
        for i in range(len(state1)):
            assert_almost_equal(state1[i].asnumpy(),
                                state2[i].asnumpy(),
                                rtol=0.001,
                                atol=0.0001)

        if step == step3:
            continue
        layer1 = TestLayer1(step)
        layer1.initialize(ctx=default_context())
        layer2 = TestLayer1(step)
        layer2.initialize(ctx=default_context())
        layer2.hybridize()
        out1, state1 = layer1(data, [state, [state + 1]])
        out2, state2 = layer2(data, [state, [state + 1]])
        step_out, step_state = step(data, [state, [state + 1]])
        assert type(out1) == type(step_out)
        assert type(out2) == type(step_out)
        assert type(state1) == type(step_state)
        assert type(state2) == type(step_state)
        out1 = _as_list(out1)
        out2 = _as_list(out2)
        state1 = _as_list(state1)
        state2 = _as_list(state2)
        for i in range(len(out1)):
            assert_almost_equal(out1[i].asnumpy(),
                                out2[i].asnumpy(),
                                rtol=0.001,
                                atol=0.0001)
        for i in range(len(state1)):
            if isinstance(state1[i], list):
                assert_almost_equal(state1[i][0].asnumpy(),
                                    state2[i][0].asnumpy(),
                                    rtol=0.001,
                                    atol=0.0001)
            else:
                assert_almost_equal(state1[i].asnumpy(),
                                    state2[i].asnumpy(),
                                    rtol=0.001,
                                    atol=0.0001)
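
For reference, the mx.npx.foreach contract these checks rely on: step(data_slice, states) consumes one slice of the input along axis 0 and returns (output, new_states), and the per-step outputs are stacked along axis 0. A small sketch computing a running sum over the rows:

import mxnet as mx

def running_sum(data_slice, state):
    new_state = state + data_slice
    return new_state, new_state      # (output, carried state)

data = mx.np.random.normal(size=(10, 2))
init = mx.np.zeros((2,))
outs, final_state = mx.npx.foreach(running_sum, data, init)
# outs has shape (10, 2); final_state equals data.sum(axis=0)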
Example #22
def mlperf_fit(self,
               train_data,
               eval_data=None,
               eval_metric='acc',
               epoch_end_callback=None,
               batch_end_callback=None,
               kvstore='local',
               optimizer='sgd',
               optimizer_params=(('learning_rate', 0.01), ),
               eval_end_callback=None,
               eval_batch_end_callback=None,
               initializer=Uniform(0.01),
               arg_params=None,
               aux_params=None,
               allow_missing=False,
               force_rebind=False,
               force_init=False,
               begin_epoch=0,
               num_epoch=None,
               validation_metric=None,
               monitor=None,
               sparse_row_id_fn=None,
               eval_offset=0,
               eval_period=1,
               accuracy_threshold=1.0):

    assert num_epoch is not None, 'please specify number of epochs'

    self.bind(data_shapes=train_data.provide_data,
              label_shapes=train_data.provide_label,
              for_training=True,
              force_rebind=force_rebind)

    if monitor is not None:
        self.install_monitor(monitor)

    self.init_params(initializer=initializer,
                     arg_params=arg_params,
                     aux_params=aux_params,
                     allow_missing=allow_missing,
                     force_init=force_init)
    self.init_optimizer(kvstore=kvstore,
                        optimizer=optimizer,
                        optimizer_params=optimizer_params)

    if validation_metric is None:
        validation_metric = eval_metric
    ###########################################################################
    # Adding Correct and Total Count metrics
    ###########################################################################
    if not isinstance(validation_metric, list):
        validation_metric = [validation_metric]

    validation_metric = mx.metric.create(validation_metric)

    if not isinstance(validation_metric, mx.metric.CompositeEvalMetric):
        vm = mx.metric.CompositeEvalMetric()
        vm.append(validation_metric)
        validation_metric = vm

    for m in [CorrectCount(), TotalCount()]:
        validation_metric.metrics.append(m)
    ###########################################################################

    if not isinstance(eval_metric, mx.metric.EvalMetric):
        eval_metric = mx.metric.create(eval_metric)

    mx_resnet_print(key=mlperf_log.TRAIN_LOOP)
    ################################################################################
    # training loop
    ################################################################################
    for epoch in range(begin_epoch, num_epoch):
        mx_resnet_print(key=mlperf_log.TRAIN_EPOCH, val=epoch)
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        data_iter = iter(train_data)
        end_of_batch = False
        next_data_batch = next(data_iter)
        while not end_of_batch:
            data_batch = next_data_batch
            if monitor is not None:
                monitor.tic()
            self.forward_backward(data_batch)
            self.update()

            if isinstance(data_batch, list):
                self.update_metric(eval_metric,
                                   [db.label for db in data_batch],
                                   pre_sliced=True)
            else:
                self.update_metric(eval_metric, data_batch.label)

            try:
                # pre fetch next batch
                next_data_batch = next(data_iter)
                self.prepare(next_data_batch,
                             sparse_row_id_fn=sparse_row_id_fn)
            except StopIteration:
                end_of_batch = True

            if monitor is not None:
                monitor.toc_print()

            if batch_end_callback is not None:
                batch_end_params = BatchEndParam(epoch=epoch,
                                                 nbatch=nbatch,
                                                 eval_metric=eval_metric,
                                                 locals=locals())
                for callback in _as_list(batch_end_callback):
                    callback(batch_end_params)
            nbatch += 1

        # one epoch of training is finished
        toc = time.time()
        if kvstore:
            if kvstore.rank == 0:
                self.logger.info('Epoch[%d] Time cost=%.3f', epoch,
                                 (toc - tic))
        else:
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # sync aux params across devices
        arg_params, aux_params = self.get_params()
        self.set_params(arg_params, aux_params)

        if epoch_end_callback is not None:
            for callback in _as_list(epoch_end_callback):
                callback(epoch, self.symbol, arg_params, aux_params)

        #----------------------------------------
        # evaluation on validation set
        if eval_data and epoch % eval_period == eval_offset:
            mx_resnet_print(key=mlperf_log.EVAL_START)
            res = self.score(eval_data,
                             validation_metric,
                             score_end_callback=eval_end_callback,
                             batch_end_callback=eval_batch_end_callback,
                             epoch=epoch)
            #TODO: pull this into default
            if kvstore:
                if kvstore.rank == 0:
                    for name, val in res:
                        self.logger.info('Epoch[%d] Validation-%s=%f', epoch,
                                         name, val)
            else:
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)
            res = dict(res)

            acc = [res['correct-count'], res['total-count']]
            acc = all_reduce(acc)
            acc = acc[0] / acc[1]
            mx_resnet_print(key=mlperf_log.EVAL_ACCURACY,
                            val={
                                "epoch": epoch,
                                "value": acc
                            })
            mx_resnet_print(key=mlperf_log.EVAL_STOP)
            if acc > accuracy_threshold:
                return epoch

        # end of 1 epoch, reset the data-iter for another epoch
        train_data.reset()

    return num_epoch
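
Note the evaluation schedule here, epoch % eval_period == eval_offset, differs in form from the earlier variant's epoch >= eval_offset and (epoch - eval_offset) % eval_period == 0; for 0 <= eval_offset < eval_period the two predicates select the same epochs. A quick check:

eval_offset, eval_period = 2, 4

def schedule_a(epoch):   # as in this example
    return epoch % eval_period == eval_offset

def schedule_b(epoch):   # as in the earlier mlperf variant
    return epoch >= eval_offset and (epoch - eval_offset) % eval_period == 0

epochs = range(12)
assert ([e for e in epochs if schedule_a(e)]
        == [e for e in epochs if schedule_b(e)] == [2, 6, 10])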
Example #23
    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01), ),
            eval_end_callback=None,
            eval_batch_end_callback=None,
            initializer=Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            validation_metric=None,
            monitor=None):
        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind)
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)
        #### chris_arg
        # TASK_LIMIT: 0 = no per-task bandwidth limit; 1 = per-task limit,
        # re-applied every round; 2 = per-task limit, fixed
        if int(os.getenv("TASK_LIMIT", 0)) != 0:
            get_task_cmd = "sh /home/ubuntu/tc.sh -l 1"
        else:
            self.logger.info("no_task_bandwidth_limit")
            get_task_cmd = "sh /home/ubuntu/tc.sh -l 0"
        os.system(get_task_cmd)
        delay_time = float(os.getenv("DELAY_TIME", 0.8))
        ps_upload_bandwidth_part1 = int(os.getenv("PS_UPLOAD_BANDWIDTH1",
                                                  2000))
        worker_upload_bandwidth_part1 = int(
            os.getenv("WORKER_UPLOAD_BANDWIDTH1", 2000))
        ps_upload_bandwidth_part2 = int(os.getenv("PS_UPLOAD_BANDWIDTH2",
                                                  2000))
        worker_upload_bandwidth_part2 = int(
            os.getenv("WORKER_UPLOAD_BANDWIDTH2", 2000))
        tc_command = "sudo tc class change dev {} parent 1: classid 1:3 htb rate {}mbit ceil {}mbit  && sudo tc class change dev {} parent 1: classid 1:4 htb rate {}mbit ceil {}mbit"
        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)
            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                self.forward(data_batch, is_train=True)
                if int(os.getenv("TASK_LIMIT", 0)) == 1:
                    ##first part bandwidth allocation
                    ndarray.waitall()
                    # self.logger.info("change bandwidth part1:, "+str(time.time()))
                    x = str(ps_upload_bandwidth_part1)
                    y = str(worker_upload_bandwidth_part1)
                    cmd_up = tc_command.format("ens3", x, x, "ens3", y, y)
                    cmd_down = tc_command.format("ifb0", y, y, "ifb0", x, x)
                    os.system(cmd_up)
                    # os.system(cmd_down)
                # self.logger.info("after forward, "+str(time.time()))
                self.backward()
                # self.logger.info("before update: "+str(time.time()))
                self.update()  # executed asynchronously
                if int(os.getenv("TASK_LIMIT", 0)) == 1:
                    x = str(ps_upload_bandwidth_part2)
                    y = str(worker_upload_bandwidth_part2)
                    cmd_up = tc_command.format("ens3", x, x, "ens3", y, y)
                    cmd_down = tc_command.format("ifb0", y, y, "ifb0", x, x)
                    time.sleep(delay_time)
                    ##second part bandwidth allocation
                    # self.logger.info("change bandwidth part2:, "+str(time.time()))
                    os.system(cmd_up)
                    # os.system(cmd_down)
                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                self.update_metric(eval_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            #----------------------------------------
            # evaluation on validation set
            if eval_data:
                res = self.score(eval_data,
                                 validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback,
                                 epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
Example #24
    def fit(self,
            train_data,
            eval_data=None,
            eval_metric='acc',
            epoch_end_callback=None,
            batch_end_callback=None,
            kvstore='local',
            optimizer='sgd',
            optimizer_params=(('learning_rate', 0.01), ),
            eval_end_callback=None,
            iter_size=1,
            eval_batch_end_callback=None,
            initializer=Uniform(0.01),
            arg_params=None,
            aux_params=None,
            allow_missing=False,
            force_rebind=False,
            force_init=False,
            begin_epoch=0,
            num_epoch=None,
            validation_metric=None,
            monitor=None):
        """Ke's revision: add iter_size. Trains the module parameters.

        Checkout `Module Tutorial <http://mxnet.io/tutorials/basic/module.html>`_ to see
        a end-to-end use-case.

        Parameters
        ----------
        train_data : DataIter
            Train DataIter.
        eval_data : DataIter
            If not ``None``, will be used as validation set and the performance
            after each epoch will be evaluated.
        eval_metric : str or EvalMetric
            Defaults to 'acc'. The performance measure displayed during training.
            Other possible predefined metrics are:
            'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
        epoch_end_callback : function or list of functions
            Each callback will be called with the current `epoch`, `symbol`, `arg_params`
            and `aux_params`.
        batch_end_callback : function or list of function
            Each callback will be called with a `BatchEndParam`.
        kvstore : str or KVStore
            Defaults to 'local'.
        optimizer : str or Optimizer
            Defaults to 'sgd'.
        optimizer_params : dict
            Defaults to ``(('learning_rate', 0.01),)``. The parameters for
            the optimizer constructor.
            The default value is not a dict, just to avoid pylint warning on dangerous
            default values.
        eval_end_callback : function or list of function
            These will be called at the end of each full evaluation, with the metrics over
            the entire evaluation set.
        eval_batch_end_callback : function or list of function
            These will be called at the end of each mini-batch during evaluation.
        initializer : Initializer
            The initializer is called to initialize the module parameters when they are
            not already initialized.
        arg_params : dict
            Defaults to ``None``, if not ``None``, should be existing parameters from a trained
            model or loaded from a checkpoint (previously saved model). In this case,
            the value here will be used to initialize the module parameters, unless they
            are already initialized by the user via a call to `init_params` or `fit`.
            `arg_params` has a higher priority than `initializer`.
        aux_params : dict
            Defaults to ``None``. Similar to `arg_params`, except for auxiliary states.
        allow_missing : bool
            Defaults to ``False``. Indicates whether to allow missing parameters when `arg_params`
            and `aux_params` are not ``None``. If this is ``True``, then the missing parameters
            will be initialized via the `initializer`.
        force_rebind : bool
            Defaults to ``False``. Whether to force rebinding the executors if already bound.
        force_init : bool
            Defaults to ``False``. Indicates whether to force initialization even if the
            parameters are already initialized.
        begin_epoch : int
            Defaults to 0. Indicates the starting epoch. Usually, if resumed from a
            checkpoint saved at a previous training phase at epoch N, then this value should be
            N+1.
        num_epoch : int
            Number of epochs for training.

        Examples
        --------
        >>> # An example of using fit for training.
        >>> # Assume training dataIter and validation dataIter are ready
        >>> # Assume loading a previously checkpointed model
        >>> sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, 3)
        >>> mod.fit(train_data=train_dataiter, eval_data=val_dataiter, optimizer='sgd',
        ...     optimizer_params={'learning_rate':0.01, 'momentum': 0.9},
        ...     arg_params=arg_params, aux_params=aux_params,
        ...     eval_metric='acc', num_epoch=10, begin_epoch=3)
        """
        assert num_epoch is not None, 'please specify number of epochs'

        self.bind(data_shapes=train_data.provide_data,
                  label_shapes=train_data.provide_label,
                  for_training=True,
                  force_rebind=force_rebind,
                  grad_req='add')
        if monitor is not None:
            self.install_monitor(monitor)
        self.init_params(initializer=initializer,
                         arg_params=arg_params,
                         aux_params=aux_params,
                         allow_missing=allow_missing,
                         force_init=force_init)
        self.init_optimizer(kvstore=kvstore,
                            optimizer=optimizer,
                            optimizer_params=optimizer_params)

        if validation_metric is None:
            validation_metric = eval_metric
        if not isinstance(eval_metric, metric.EvalMetric):
            eval_metric = metric.create(eval_metric)

        ################################################################################
        # training loop
        ################################################################################
        for epoch in range(begin_epoch, num_epoch):
            tic = time.time()
            eval_metric.reset()
            nbatch = 0
            data_iter = iter(train_data)
            end_of_batch = False
            next_data_batch = next(data_iter)

            while not end_of_batch:
                data_batch = next_data_batch
                if monitor is not None:
                    monitor.tic()
                # self.forward_backward(data_batch)
                self.forward(data_batch, is_train=True, grad_req='add')
                self.backward()
                if nbatch % iter_size == 0:  # update every iter_size batches
                    self.update()
                    for g in self._curr_module._exec_group.grad_arrays:
                        for g1 in g:
                            if g1 is not None:
                                g1[:] = 0.

                try:
                    # pre fetch next batch
                    next_data_batch = next(data_iter)
                    self.prepare(next_data_batch)
                except StopIteration:
                    end_of_batch = True

                self.update_metric(eval_metric, data_batch.label)

                if monitor is not None:
                    monitor.toc_print()

                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    for callback in _as_list(batch_end_callback):
                        callback(batch_end_params)
                nbatch += 1

            # one epoch of training is finished
            for name, val in eval_metric.get_name_value():
                self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
            toc = time.time()
            self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

            # sync aux params across devices
            arg_params, aux_params = self.get_params()
            self.set_params(arg_params, aux_params)

            if epoch_end_callback is not None:
                for callback in _as_list(epoch_end_callback):
                    callback(epoch, self.symbol, arg_params, aux_params)

            #----------------------------------------
            # evaluation on validation set
            if eval_data:
                res = self.score(eval_data,
                                 validation_metric,
                                 score_end_callback=eval_end_callback,
                                 batch_end_callback=eval_batch_end_callback,
                                 epoch=epoch)
                #TODO: pull this into default
                for name, val in res:
                    self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name,
                                     val)

            # end of 1 epoch, reset the data-iter for another epoch
            train_data.reset()
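
This variant implements gradient accumulation: binding with grad_req='add' makes backward() accumulate into the gradient arrays, the optimizer steps only every iter_size batches, and the accumulated gradients are then zeroed by hand (g1[:] = 0.). A minimal NumPy sketch of the same pattern (one idiomatic variant; the example above triggers the update when nbatch % iter_size == 0):

import numpy as np

iter_size = 4
lr = 0.01
weight = np.zeros(3)
grad_acc = np.zeros_like(weight)       # plays the role of grad_arrays

for nbatch, batch_grad in enumerate(np.random.randn(12, 3)):
    grad_acc += batch_grad             # grad_req='add': backward accumulates
    if (nbatch + 1) % iter_size == 0:  # step once per iter_size batches
        weight -= lr * grad_acc / iter_size
        grad_acc[:] = 0.0              # zero the accumulated gradients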