Example #1
 def grad_with_element_cap(pv):
     g = grad_with_element_cap_orig_grad(pv)
     elems = xp.where(xp.abs(g) > self.cfg.gradient_element_cap)
     g[elems] = xp.sign(g[elems]) * self.cfg.gradient_element_cap
     return g
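
This helper caps each gradient element at the configured threshold while preserving its sign. As a minimal standalone sketch of the same capping using plain numpy (the `xp` module above is assumed to be a numpy-compatible array backend; `cap_gradient_elements` and the sample values below are illustrative, not part of the original code):

    import numpy as np

    def cap_gradient_elements(g, cap):
        # clip each element of g to [-cap, cap] while preserving its sign
        elems = np.where(np.abs(g) > cap)
        g[elems] = np.sign(g[elems]) * cap
        return g

    # elements beyond +/-2.0 are clipped: [0.5, -3.0, 4.0] -> [0.5, -2.0, 2.0]
    print(cap_gradient_elements(np.array([0.5, -3.0, 4.0]), 2.0))
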
Example #2
    def generic_training(self,
                         cfg_dir,
                         checkpoint=None,
                         checkpoint_handler=None,
                         loss_record_interval=10,
                         max_missed_val_improvements=200,
                         iteration_gain=1.25,
                         min_improvement=1e-7,
                         reset_termination_criteria=False,
                         desired_loss=None,
                         initialize=True,
                         large_gradient_threshold=0.0,
                         print_gradient_info=False,
                         print_gradient=False,
                         print_parameters=False,
                         log_parameters=[],
                         plot_logged_parameters=True,
                         print_logged_parameters=False,
                         check_gradient_finite=False,
                         log_modules=None):
        """
        Generic training procedure.
        :param cfg_dir: configuration directory
        :param checkpoint: checkpoint
        :param checkpoint_handler: checkpoint handler
        :param loss_record_interval: number of iterations between calculating and recording losses
        :param max_missed_val_improvements: maximum iterations without improvement of loss before training is stopped
        :param iteration_gain: If not set to 0, then training is performed up to iteration
                               (iteration_gain * iteration_of_last_improvement); e.g. with a
                               gain of 1.25 and the last improvement at iteration 1000,
                               training runs until iteration 1250.
        :param min_improvement: minimum loss change to count as improvement
        :param reset_termination_criteria: resets the termination criteria after loading a checkpoint
        :param desired_loss: if specified, training is terminated when this loss is reached
        :param initialize: if True, the model parameters are initialized using the init_parameters method.
        :param large_gradient_threshold: if specified, a check for large gradient elements that exceed the
                                         given threshold is performed every iteration and any found are printed.
        :param print_gradient_info: if True, this function prints diagnostic gradient information
        :param print_gradient: if True, this function prints the full gradient every minibatch
        :param print_parameters: if True, all parameter values are printed every minibatch
        :param log_parameters: list of parameter names (from self.ps) that should be logged every iteration
                               to file params.out
        :param plot_logged_parameters: if True, logged parameters are plotted to params.png
        :param print_logged_parameters: if True, logged parameters are also printed to standard output
        :param check_gradient_finite: if True, gradient is checked for infs and nans in every iteration
        :param log_modules: list of python modules for which version or latest git commit should be logged
        :return: ParameterHistory object of training
        """

        max_iters = self.cfg.max_iters if self.cfg.has('max_iters') else None

        # build gradient preprocessing chain
        grad_func = self.mb_loss_grad
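        # each stage below wraps the current grad_func in a closure; the wrapped
        # function is first saved under a unique name so that the closure chains
        # to it rather than capturing its own rebinding of grad_func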

        # gradient magnitude cap
        if self.cfg.has('gradient_cap'):
            grad_with_cap_orig_grad = grad_func

            def grad_with_cap(pv):
                g = grad_with_cap_orig_grad(pv)
                g_mag = xp.sqrt(xp.sum(g**2))
                if g_mag > self.cfg.gradient_cap:
                    print "gradient magnitude %f is being rescaled" % g_mag
                    g *= self.cfg.gradient_cap / g_mag
                return g

            grad_func = grad_with_cap

        # gradient element cap
        if self.cfg.has('gradient_element_cap'):
            grad_with_element_cap_orig_grad = grad_func

            def grad_with_element_cap(pv):
                g = grad_with_element_cap_orig_grad(pv)
                elems = xp.where(xp.abs(g) > self.cfg.gradient_element_cap)
                g[elems] = xp.sign(g[elems]) * self.cfg.gradient_element_cap
                return g

            grad_func = grad_with_element_cap

        # gradient of constants set to zero
        grad_without_const_orig_grad = grad_func

        def grad_without_const(pv):
            g = grad_without_const_orig_grad(pv)
            return self.ps.nullify_gradient_of_constants(g)

        grad_func = grad_without_const
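        # resulting chain: mb_loss_grad -> magnitude cap -> element cap -> constant gradients zeroed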

        # initialize or restore checkpoint, if available
        if not checkpoint:
            itr = 0
            if initialize:
                self.init_parameters()

            his = ParameterHistory(
                cfg=self.cfg,
                state_dir=cfg_dir,
                max_iters=max_iters,
                max_missed_val_improvements=max_missed_val_improvements,
                min_improvement=min_improvement,
                desired_loss=desired_loss,
                iteration_gain=iteration_gain)
            logger = ParameterLogger(out_dir=cfg_dir,
                                     parameters=log_parameters,
                                     plot=plot_logged_parameters,
                                     print_stdout=print_logged_parameters)
            git_log(modules=log_modules, log_dir=cfg_dir)

            # Record initial loss and parameters
            self.record_loss(his, itr)
            logger.log(itr, self.ps)
        else:
            itr = checkpoint['iter']
            self.ps.data[:] = post(checkpoint['data'])
            if 'optimizer_step_rate' in checkpoint:
                self.cfg.optimizer_step_rate = checkpoint[
                    'optimizer_step_rate']

            his = checkpoint['his']
            his.state_dir = cfg_dir
            his.max_missed_val_improvements = max_missed_val_improvements
            his.desired_loss = desired_loss
            his.iteration_gain = iteration_gain

            # start and end times in his should have the same length; this is
            # not the case for explicit auto-save checkpoints, therefore set
            # the end time to the checkpoint's save time
            if len(his.start_time) != len(his.end_time):
                his.end_time.append(checkpoint['save_time'])
            his.start()

            logger = checkpoint['logger']
            git_log(modules=log_modules, log_dir=cfg_dir, check=True)

        # reset termination criteria if requested
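        # touching a file named "2nd_chance" in cfg_dir grants one reset of the
        # termination criteria; the file is deleted once consumed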
        second_chance_file = join(cfg_dir, "2nd_chance")
        if exists(second_chance_file):
            print "Resetting termination criteria because %s is present" % second_chance_file
            reset_termination_criteria = True
            unlink(second_chance_file)
        if self.cfg.continue_training:
            print "Resetting termination criteria because --continue flag was specified"
            reset_termination_criteria = True
        if reset_termination_criteria:
            his.reset_best()

        if 'step_element_cap' in dir(self.cfg):
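            # remember the configured cap; it may be reduced temporarily below
            # (on nan_or_inf_loss) and restored after some iterations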
            step_element_cap_orig = self.cfg.step_element_cap
        step_element_cap_decrease_iteration = None

        restart = True
        while restart and (max_iters is None or max_iters > 0):
            # create optimizer
            if isinstance(self.cfg.optimizer, dict):
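                # a dict-valued cfg.optimizer assigns a separate optimizer to
                # each ParameterSet partition; coverage is verified below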

                def wrt_fprime_for_part(partition):
                    wrt_for_part = self.ps.num_partition(partition)

                    def fprime_for_part(pv_part):
                        # we assume that the optimizer updates the ParameterSet inplace and
                        # evaluates the gradient at the current values of the parameters
                        start, stop = self.ps.extents_of_partition(partition)
                        return grad_func(self.ps.num_data)[start:stop]

                    return wrt_for_part, fprime_for_part

                opts_obj = optimizers_from_cfg(self.cfg, wrt_fprime_for_part,
                                               self.mb_loss)
                opts = {
                    part: iter(opt_obj)
                    for part, opt_obj in opts_obj.iteritems()
                }
                partitioned_opt = True

                opt_parts = set(opts.keys())
                ps_parts = set(self.ps.partitions)
                if opt_parts != ps_parts:
                    raise ValueError(
                        "optimizer config does not cover all ParameterSet partitions or vice versa: %s"
                        % repr(opt_parts ^ ps_parts))
            else:
                opt = iter(
                    optimizer_from_cfg(self.cfg, self.ps.data, self.mb_loss,
                                       grad_func))
                partitioned_opt = False

            # do training
            self.ps.restore_constants()
            last_pars = xp.copy(self.ps.data)
            while not his.should_terminate:
                # call optimizer(s)
                if partitioned_opt:
                    for part, opt in opts.iteritems():
                        # print "optimizing %s" % part
                        opt.next()
                else:
                    opt.next()

                # element change cap
                if self.cfg.has('step_element_cap'):
                    d = self.ps.data - last_pars
                    if isinstance(self.cfg.step_element_cap, dict):
                        for par, lim in self.cfg.step_element_cap.iteritems():
                            start, stop = self.ps.extents_of_var(par)
                            dpar = d[start:stop]  # dpar is a subview of d
                            # print "parameter diff for %s is %s (limit is %.4f)" % (par, repr(dpar), lim)
                            elems = xp.where(xp.abs(dpar) > lim)
                            dpar[elems] = xp.sign(dpar[elems]) * lim
                    elif isinstance(self.cfg.step_element_cap, (float, int)):
                        lim = float(self.cfg.step_element_cap)
                        elems = xp.where(xp.abs(d) > lim)
                        d[elems] = xp.sign(d[elems]) * lim
                    else:
                        raise TypeError(
                            "cfg.step_element_cap must either be a dict or a number"
                        )
                    self.ps.data[:] = last_pars + d
                    last_pars = xp.copy(self.ps.data)

                # parameter printout
                if print_parameters:
                    pars = gather(self.ps.data)
                    pars_var = self.ps.split(pars)
                    print "parameters at iteration %d:" % itr
                    for name, value in pars_var.iteritems():
                        print "%10s: %s" % (name, repr(list(value)))

                # obtain gradient if required for debugging operations
                if (large_gradient_threshold > 0 or print_gradient_info or
                        print_gradient or check_gradient_finite):
                    gradient = gather(grad_func(self.ps.num_data))
                else:
                    gradient = None

                # check gradient for large elements
                if large_gradient_threshold > 0:
                    lgv = self.ps.find_large_elements(
                        gradient, threshold=large_gradient_threshold)
                    if len(lgv) > 0:
                        print "parameters with large gradient: "
                        for (var, idx), value in lgv.iteritems():
                            print "                                %s[%d] = %.3f" % (
                                var, idx, value)

                # gradient magnitude printout
                if print_gradient_info:
                    gradient_magnitude = np.sqrt(np.sum(gradient**2))
                    print "|gradient| = %.3f" % gradient_magnitude

                # gradient printout
                if print_gradient:
                    gradient_var = self.ps.split(gradient)
                    print "gradient at iteration %d:" % itr
                    for name, value in gradient_var.iteritems():
                        print "%10s: %s" % (name, repr(list(value)))

                # check gradient for NaNs and Infs
                if gradient is not None:
                    if not np.all(np.isfinite(gradient)):
                        his.should_terminate = True
                        his.termination_reason = 'inf_or_nan_gradient'
                        break

                if self.next_minibatch():
                    # iteration finished
                    self.after_iteration(his, itr)

                    itr += 1

                    # log parameters
                    logger.log(itr, self.ps)

                    # calculate losses
                    if itr % loss_record_interval == 0:
                        self.record_loss(his, itr)

                    if step_element_cap_decrease_iteration is not None:
                        if 'step_element_cap_restore_iterations' in dir(
                                self.cfg):
                            restore_itrs = self.cfg.step_element_cap_restore_iterations
                        else:
                            restore_itrs = 100
                        if itr >= step_element_cap_decrease_iteration + restore_itrs:
                            self.cfg.step_element_cap = step_element_cap_orig
                            print "Restored step element cap to %g" % self.cfg.step_element_cap
                            step_element_cap_decrease_iteration = None

                # save checkpoint if necessary
                if checkpoint_handler is not None:
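                    # a requested save (presumably triggered externally, e.g.
                    # by a signal) closes the current timing span via his.stop()
                    # before writing a resumable checkpoint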
                    if checkpoint_handler.requested:
                        his.stop()
                        checkpoint_handler.save(
                            data=gather(self.ps.data),
                            his=his,
                            iter=itr,
                            logger=logger,
                            optimizer_step_rate=self.cfg.optimizer_step_rate)
                    if his.should_save_checkpoint:
                        checkpoint_handler.save(
                            data=gather(self.ps.data),
                            his=his,
                            iter=itr,
                            logger=logger,
                            optimizer_step_rate=self.cfg.optimizer_step_rate,
                            explicit=True)
                        his.checkpoint_saved()

            # restore best parameters
            self.ps.data[:] = his.best_pars

            # check for retry conditions
            restart = False

            # temporarily reduce step element cap to move over regions with very large gradient
            if (his.should_terminate
                    and his.termination_reason == 'nan_or_inf_loss'
                    and 'step_element_cap' in dir(self.cfg)
                    and 'step_element_cap_min' in dir(self.cfg)
                    and self.cfg.step_element_cap >=
                    self.cfg.step_element_cap_min):
                self.cfg.step_element_cap /= 10.
                step_element_cap_decrease_iteration = itr
                print "Reduced step element cap to %g" % self.cfg.step_element_cap
                his.should_terminate = False
                restart = True

            # advance learning rate schedule
            if (his.should_terminate and his.termination_reason in [
                    'no_improvement', 'nan_or_inf_loss',
                    'user_learning_rate_decrease'
            ] and 'optimizer_step_rate_min' in dir(self.cfg)
                    and self.cfg.optimizer_step_rate / 10. >=
                    self.cfg.optimizer_step_rate_min):
                self.cfg.optimizer_step_rate /= 10.
                print "Decaying optimizer step rate to %g" % self.cfg.optimizer_step_rate
                his.should_terminate = False
                his.last_val_improvement = itr
                restart = True

        # training finished
        self.after_training(his)

        # save results and plot loss
        if checkpoint_handler:
            his.stop()
            checkpoint_handler.save(
                data=gather(self.ps.data),
                his=his,
                iter=itr,
                logger=logger,
                optimizer_step_rate=self.cfg.optimizer_step_rate,
                explicit=True)
        his.finish()
        logger.plot()

        return his
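
A hedged usage sketch follows; the model class, configuration object, and the parameter names passed to log_parameters are hypothetical stand-ins, and only generic_training and its keyword arguments come from the example above:

    # hypothetical: MyModel is assumed to provide generic_training (e.g. via a
    # mixin) and cfg is assumed to carry optimizer and step-rate settings
    model = MyModel(cfg)
    his = model.generic_training(
        cfg_dir='out/run1',                  # losses, params.out/png, git log go here
        loss_record_interval=10,             # record losses every 10 iterations
        max_missed_val_improvements=200,     # stop after 200 iterations w/o improvement
        check_gradient_finite=True,          # terminate on inf/nan gradient
        log_parameters=['w', 'b'])           # hypothetical entries of model.ps
    # his is the ParameterHistory recorded during training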