Example #1
0
 def train(self):
     epoch_time = AverageMeter()
     end = time.time()
     self.model.train()
     for epoch in range(self.start_epoch, self.run_config.epochs):
         self.logger.log('\n' + '-' * 30 +
                         'Train epoch: {}'.format(epoch + 1) + '-' * 30 +
                         '\n',
                         mode='retrain')
         epoch_str = 'epoch[{:03d}/{:03d}]'.format(epoch + 1,
                                                   self.run_config.epochs)
         self.scheduler.step(epoch)
         train_lr = self.scheduler.get_lr()
         time_left = epoch_time.average * (self.run_config.epochs - epoch)
         common_log = '[Train the {:}] Left={:} LR={:}'.format(
             epoch_str,
             str(timedelta(seconds=time_left)) if epoch != 0 else None,
             train_lr)
         self.logger.log(common_log, 'retrain')
         # the per-iteration train log is written inside train_one_epoch
         loss, acc, miou, fscore = self.train_one_epoch(epoch)
         epoch_time.update(time.time() - end)
         end = time.time()
         # perform validation at the end of each epoch.
         val_loss, val_acc, val_miou, val_fscore = self.validate(
             epoch, is_test=False, use_train_mode=False)
         val_monitor_metric = get_monitor_metric(self.monitor_metric,
                                                 val_loss, val_acc,
                                                 val_miou, val_fscore)
         is_best = val_monitor_metric > self.best_monitor
         self.best_monitor = max(self.best_monitor, val_monitor_metric)
         # update visdom
         if self.vis is not None:
             self.vis.visdom_update(epoch, 'loss', [loss, val_loss])
             self.vis.visdom_update(epoch, 'accuracy', [acc, val_acc])
             self.vis.visdom_update(epoch, 'miou', [miou, val_miou])
             self.vis.visdom_update(epoch, 'f1score', [fscore, val_fscore])
         # save checkpoint
         if (epoch + 1) % self.run_config.save_ckpt_freq == 0 \
                 or (epoch + 1) == self.run_config.epochs or is_best:
             checkpoint = {
                 'state_dict': self.model.state_dict(),
                 'weight_optimizer': self.optimizer.state_dict(),
                 'weight_scheduler': self.scheduler.state_dict(),
                 'best_monitor': (self.monitor_metric, self.best_monitor),
                 'start_epoch': epoch + 1,
                 'actual_path': self.run_config.actual_path,
                 'cell_genotypes': self.run_config.cell_genotypes
             }
             filename = self.logger.path(mode='retrain', is_best=is_best)
             save_checkpoint(checkpoint,
                             filename,
                             self.logger,
                             mode='retrain')
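
This example and the ones below lean on an AverageMeter helper for batch/epoch timing and running metric averages; the loops use .val/.avg in the per-iteration log strings and .average for the time-left estimate. A minimal sketch under that assumption (not the repository's actual implementation):

class AverageMeter:
    """Minimal running-average tracker (hypothetical sketch)."""

    def __init__(self):
        self.val = 0.0      # last value passed to update()
        self.sum = 0.0      # weighted sum of all values seen
        self.count = 0      # total weight seen so far
        self.avg = 0.0      # running mean
        self.average = 0.0  # alias of avg, used for the time-left estimate

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)
        self.average = self.avg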
Example #2
0
    def train(self, fix_net_weights=False):

        # valid_batch_size comes from the config; drop_last is ignored here.
        data_loader = self.run_manager.run_config.train_loader
        iter_per_epoch = len(data_loader)
        total_iteration = iter_per_epoch * self.run_manager.run_config.epochs

        if fix_net_weights:  # used for debugging
            data_loader = [(0, 0)] * iter_per_epoch
            print('Train phase disabled for debugging')

        # arch_parameter update frequency and count per iteration (currently unused).
        #update_schedule = self.arch_search_config.get_update_schedule(iter_per_epoch)

        # note: the total epoch count includes the warmup epochs
        epoch_time = AverageMeter()
        end_epoch = time.time()
        # TODO : use start_epochs
        for epoch in range(self.start_epoch,
                           self.run_manager.run_config.epochs):
            self.logger.log('\n' + '-' * 30 +
                            'Train Epoch: {}'.format(epoch + 1) + '-' * 30 +
                            '\n',
                            mode='search')

            self.run_manager.scheduler.step(epoch)
            train_lr = self.run_manager.scheduler.get_lr()
            arch_lr = self.arch_optimizer.param_groups[0]['lr']
            self.net.set_tau(self.arch_search_config.tau_max -
                             (self.arch_search_config.tau_max -
                              self.arch_search_config.tau_min) * (epoch) /
                             (self.run_manager.run_config.epochs))
            tau = self.net.get_tau()
            batch_time = AverageMeter()
            data_time = AverageMeter()
            losses = AverageMeter()
            accs = AverageMeter()
            mious = AverageMeter()
            fscores = AverageMeter()

            #valid_data_time = AverageMeter()
            valid_losses = AverageMeter()
            valid_accs = AverageMeter()
            valid_mious = AverageMeter()
            valid_fscores = AverageMeter()

            self.net.train()

            epoch_str = 'epoch[{:03d}/{:03d}]'.format(
                epoch + 1, self.run_manager.run_config.epochs)
            time_left = epoch_time.average * (
                self.run_manager.run_config.epochs - epoch)
            common_log = '[*Train-Search* the {:}] Left={:} WLR={:} ALR={:} tau={:}'\
                .format(epoch_str, str(timedelta(seconds=time_left)) if epoch != 0 else None, train_lr, arch_lr, tau)
            self.logger.log(common_log, 'search')

            end = time.time()
            for i, (datas, targets) in enumerate(data_loader):
                #if i == 29: break
                if not fix_net_weights:
                    if torch.cuda.is_available():
                        datas = datas.to(self.run_manager.device,
                                         non_blocking=True)
                        targets = targets.to(self.run_manager.device,
                                             non_blocking=True)
                    else:
                        raise ValueError('CPU-only execution is not supported')
                    data_time.update(time.time() - end)

                    # get single_path in each iteration
                    _, network_index = self.net.get_network_arch_hardwts_with_constraint()
                    _, aspp_index = self.net.get_aspp_hardwts_index()
                    single_path = self.net.sample_single_path(
                        self.run_manager.run_config.nb_layers, aspp_index,
                        network_index)
                    logits = self.net.single_path_forward(datas, single_path)

                    # loss
                    loss = self.run_manager.criterion(logits, targets)
                    # metrics and update
                    evaluator = Evaluator(
                        self.run_manager.run_config.nb_classes)
                    evaluator.add_batch(targets, logits)
                    acc = evaluator.Pixel_Accuracy()
                    miou = evaluator.Mean_Intersection_over_Union()
                    fscore = evaluator.Fx_Score()
                    losses.update(loss.data.item(), datas.size(0))
                    accs.update(acc.item(), datas.size(0))
                    mious.update(miou.item(), datas.size(0))
                    fscores.update(fscore.item(), datas.size(0))

                    self.net.zero_grad()
                    loss.backward()

                    self.run_manager.optimizer.step()

                    #end_valid = time.time()
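                    # architecture step: sample a fresh single path and update the
                    # architecture parameters on a held-out validation batch,
                    # alternating with the weight step above (first-order bi-level update).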
                    valid_datas, valid_targets = self.run_manager.run_config.valid_next_batch
                    if torch.cuda.is_available():
                        valid_datas = valid_datas.to(self.run_manager.device,
                                                     non_blocking=True)
                        valid_targets = valid_targets.to(
                            self.run_manager.device, non_blocking=True)
                    else:
                        raise ValueError('CPU-only execution is not supported')

                    _, network_index = self.net.get_network_arch_hardwts_with_constraint()
                    _, aspp_index = self.net.get_aspp_hardwts_index()
                    single_path = self.net.sample_single_path(
                        self.run_manager.run_config.nb_layers, aspp_index,
                        network_index)
                    logits = self.net.single_path_forward(
                        valid_datas, single_path)

                    loss = self.run_manager.criterion(logits, valid_targets)

                    # metrics and update
                    valid_evaluator = Evaluator(
                        self.run_manager.run_config.nb_classes)
                    valid_evaluator.add_batch(valid_targets, logits)
                    acc = valid_evaluator.Pixel_Accuracy()
                    miou = valid_evaluator.Mean_Intersection_over_Union()
                    fscore = valid_evaluator.Fx_Score()
                    valid_losses.update(loss.data.item(), valid_datas.size(0))
                    valid_accs.update(acc.item(), valid_datas.size(0))
                    valid_mious.update(miou.item(), valid_datas.size(0))
                    valid_fscores.update(fscore.item(), valid_datas.size(0))

                    self.net.zero_grad()
                    loss.backward()
                    #if (i+1) % 5 == 0:
                    #    print('network_arch_parameters.grad in search phase:\n',
                    #                         self.net.network_arch_parameters.grad)
                    self.arch_optimizer.step()

                    # batch_time covers one iteration of the train step plus the valid (arch) step.
                    batch_time.update(time.time() - end)
                    end = time.time()
                    if (i + 1) % self.run_manager.run_config.train_print_freq == 0 \
                            or i + 1 == iter_per_epoch:
                        Wstr = '|*Search*|' + time_string() + \
                            '[{:}][iter{:03d}/{:03d}]'.format(epoch_str, i + 1, iter_per_epoch)
                        Tstr = '|Time    | {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                            batch_time=batch_time, data_time=data_time)
                        Bstr = '|Base    | [Loss {loss.val:.3f} ({loss.avg:.3f}) Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                            loss=losses, acc=accs, miou=mious, fscore=fscores)
                        Astr = '|Arch    | [Loss {loss.val:.3f} ({loss.avg:.3f}) Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'.format(
                            loss=valid_losses,
                            acc=valid_accs,
                            miou=valid_mious,
                            fscore=valid_fscores)
                        self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr +
                                        '\n' + Astr,
                                        mode='search')

            # set self.hardwts again
            _, network_index = self.net.get_network_arch_hardwts_with_constraint()
            _, aspp_index = self.net.get_aspp_hardwts_index()
            single_path = self.net.sample_single_path(
                self.run_manager.run_config.nb_layers, aspp_index,
                network_index)
            cell_arch_entropy, network_arch_entropy, total_entropy = self.net.calculate_entropy(
                single_path)

            # update visdom
            if self.vis is not None:
                self.vis.visdom_update(epoch, 'loss',
                                       [losses.average, valid_losses.average])
                self.vis.visdom_update(epoch, 'accuracy',
                                       [accs.average, valid_accs.average])
                self.vis.visdom_update(epoch, 'miou',
                                       [mious.average, valid_mious.average])
                self.vis.visdom_update(
                    epoch, 'f1score', [fscores.average, valid_fscores.average])

                self.vis.visdom_update(epoch, 'cell_entropy',
                                       [cell_arch_entropy])
                self.vis.visdom_update(epoch, 'network_entropy',
                                       [network_arch_entropy])
                self.vis.visdom_update(epoch, 'entropy', [total_entropy])

            #torch.cuda.empty_cache()
            # update epoch_time
            epoch_time.update(time.time() - end_epoch)
            end_epoch = time.time()

            epoch_str = '{:03d}/{:03d}'.format(
                epoch + 1, self.run_manager.run_config.epochs)
            log = '[{:}] train :: loss={:.2f} accuracy={:.2f} miou={:.2f} f1score={:.2f}\n' \
                  '[{:}] valid :: loss={:.2f} accuracy={:.2f} miou={:.2f} f1score={:.2f}\n'.format(
                epoch_str, losses.average, accs.average, mious.average, fscores.average,
                epoch_str, valid_losses.average, valid_accs.average, valid_mious.average, valid_fscores.average
            )
            self.logger.log(log, mode='search')

            self.logger.log(
                '<<<---------->>> Super Network decoding <<<---------->>> ',
                mode='search')
            actual_path, cell_genotypes = self.net.network_cell_arch_decode()
            #print(cell_genotypes)
            new_genotypes = []
            for _index, genotype in cell_genotypes:
                xlist = []
                print(_index, genotype)
                for edge_genotype in genotype:
                    for (node_str, select_index) in edge_genotype:
                        xlist.append((node_str, self.run_manager.run_config.
                                      conv_candidates[select_index]))
                new_genotypes.append((_index, xlist))
            log_str = 'The {:} decoded network:\n' \
                      'actual_path = {:}\n' \
                      'genotype:\n'.format(epoch_str, actual_path)
            for _index, genotype in new_genotypes:
                log_str += 'index: {:} arch: {:}\n'.format(_index, genotype)
            self.logger.log(log_str, mode='network_space', display=False)

            # TODO: save the best network checkpoint
            # 1. save network_arch_parameters and cell_arch_parameters
            # 2. save weight_parameters
            # 3. weight_optimizer.state_dict
            # 4. arch_optimizer.state_dict
            # 5. training process
            # 6. monitor_metric and the best_value
            # get best_monitor in valid phase.
            val_monitor_metric = get_monitor_metric(
                self.run_manager.monitor_metric, valid_losses.average,
                valid_accs.average, valid_mious.average, valid_fscores.average)
            is_best = self.run_manager.best_monitor < val_monitor_metric
            self.run_manager.best_monitor = max(self.run_manager.best_monitor,
                                                val_monitor_metric)
            # 1. if is_best: save the current checkpoint
            # 2. if (epoch + 1) is divisible by save_ckpt_freq: save the current checkpoint

            #self.run_manager.save_model(epoch, {
            #    'arch_optimizer': self.arch_optimizer.state_dict(),
            #}, is_best=True, checkpoint_file_name=None)
            # TODO: revise the checkpoint-save semantics
            if (epoch + 1) % self.run_manager.run_config.save_ckpt_freq == 0 \
                    or (epoch + 1) == self.run_manager.run_config.epochs or is_best:
                checkpoint = {
                    'state_dict': self.net.state_dict(),
                    'weight_optimizer': self.run_manager.optimizer.state_dict(),
                    'weight_scheduler': self.run_manager.scheduler.state_dict(),
                    'arch_optimizer': self.arch_optimizer.state_dict(),
                    'best_monitor': (self.run_manager.monitor_metric,
                                     self.run_manager.best_monitor),
                    'warmup': False,
                    'start_epochs': epoch + 1,
                }
                checkpoint_arch = {
                    'actual_path': actual_path,
                    'cell_genotypes': cell_genotypes,
                }
                filename = self.logger.path(mode='search', is_best=is_best)
                filename_arch = self.logger.path(mode='arch', is_best=is_best)
                save_checkpoint(checkpoint,
                                filename,
                                self.logger,
                                mode='search')
                save_checkpoint(checkpoint_arch,
                                filename_arch,
                                self.logger,
                                mode='arch')
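
The set_tau(...) call near the top of the epoch loop anneals the Gumbel-softmax temperature linearly from tau_max toward tau_min as the search progresses. A standalone sketch of that schedule (the values in the usage comment are illustrative assumptions, not taken from the source):

def linear_tau(epoch, total_epochs, tau_max=10.0, tau_min=0.1):
    """Linearly decay tau from tau_max toward tau_min over total_epochs."""
    return tau_max - (tau_max - tau_min) * epoch / total_epochs

# For example, with 40 search epochs, tau_max=10.0, tau_min=0.1:
#   epoch 0  -> 10.0
#   epoch 20 -> 5.05
#   epoch 39 -> ~0.35 (a low temperature sharpens the architecture distribution)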
Example #3
0
    def warm_up(self, warmup_epochs):
        if warmup_epochs <= 0:
            self.logger.log('=> warmup close', mode='warm')
            #print('\twarmup close')
            return
        # set optimizer and scheduler in warm_up phase
        lr_max = self.arch_search_config.warmup_lr
        data_loader = self.run_manager.run_config.train_loader
        scheduler_params = self.run_manager.run_config.optimizer_config['scheduler_params']
        optimizer_params = self.run_manager.run_config.optimizer_config['optimizer_params']
        momentum = optimizer_params['momentum']
        nesterov = optimizer_params['nesterov']
        weight_decay = optimizer_params['weight_decay']
        eta_min = scheduler_params['eta_min']
        optimizer_warmup = torch.optim.SGD(self.net.weight_parameters(),
                                           lr_max,
                                           momentum,
                                           weight_decay=weight_decay,
                                           nesterov=nesterov)
        # set initial_learning_rate in weight_optimizer
        #for param_groups in self.run_manager.optimizer.param_groups:
        #    param_groups['lr'] = lr_max
        lr_scheduler_warmup = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer_warmup, warmup_epochs, eta_min)
        iter_per_epoch = len(data_loader)
        total_iteration = warmup_epochs * iter_per_epoch

        self.logger.log('=> warmup begin', mode='warm')

        epoch_time = AverageMeter()
        end_epoch = time.time()
        for epoch in range(self.warmup_epoch, warmup_epochs):
            self.logger.log('\n' + '-' * 30 +
                            'Warmup Epoch: {}'.format(epoch + 1) + '-' * 30 +
                            '\n',
                            mode='warm')

            lr_scheduler_warmup.step(epoch)
            warmup_lr = lr_scheduler_warmup.get_lr()
            self.net.train()

            batch_time = AverageMeter()
            data_time = AverageMeter()
            losses = AverageMeter()
            accs = AverageMeter()
            mious = AverageMeter()
            fscores = AverageMeter()

            epoch_str = 'epoch[{:03d}/{:03d}]'.format(epoch + 1, warmup_epochs)
            time_left = epoch_time.average * (warmup_epochs - epoch)
            common_log = '[Warmup the {:}] Left={:} LR={:}'.format(
                epoch_str,
                str(timedelta(seconds=time_left)) if epoch != 0 else None,
                warmup_lr)
            self.logger.log(common_log, mode='warm')
            end = time.time()

            for i, (datas, targets) in enumerate(data_loader):

                #if i == 29: break

                if torch.cuda.is_available():
                    datas = datas.to(self.run_manager.device,
                                     non_blocking=True)
                    targets = targets.to(self.run_manager.device,
                                         non_blocking=True)
                else:
                    raise ValueError('CPU-only execution is not supported')
                data_time.update(time.time() - end)

                # get single_path in each iteration
                _, network_index = self.net.get_network_arch_hardwts_with_constraint()
                _, aspp_index = self.net.get_aspp_hardwts_index()
                single_path = self.net.sample_single_path(
                    self.run_manager.run_config.nb_layers, aspp_index,
                    network_index)
                logits = self.net.single_path_forward(datas, single_path)

                # update without entropy reg in warm_up phase
                loss = self.run_manager.criterion(logits, targets)

                # measure metrics and update
                evaluator = Evaluator(self.run_manager.run_config.nb_classes)
                evaluator.add_batch(targets, logits)
                acc = evaluator.Pixel_Accuracy()
                miou = evaluator.Mean_Intersection_over_Union()
                fscore = evaluator.Fx_Score()
                losses.update(loss.data.item(), datas.size(0))
                accs.update(acc, datas.size(0))
                mious.update(miou, datas.size(0))
                fscores.update(fscore, datas.size(0))

                self.net.zero_grad()
                loss.backward()
                self.run_manager.optimizer.step()

                batch_time.update(time.time() - end)
                end = time.time()
                if (i + 1) % self.run_manager.run_config.train_print_freq == 0 \
                        or i + 1 == iter_per_epoch:
                    Wstr = '|*WARM-UP*|' + time_string() + \
                        '[{:}][iter{:03d}/{:03d}]'.format(epoch_str, i + 1, iter_per_epoch)
                    Tstr = '|Time     | [{batch_time.val:.2f} ({batch_time.avg:.2f})  Data {data_time.val:.2f} ({data_time.avg:.2f})]'.format(
                        batch_time=batch_time, data_time=data_time)
                    Bstr = '|Base     | [Loss {loss.val:.3f} ({loss.avg:.3f})  Accuracy {acc.val:.2f} ({acc.avg:.2f}) MIoU {miou.val:.2f} ({miou.avg:.2f}) F {fscore.val:.2f} ({fscore.avg:.2f})]'\
                        .format(loss=losses, acc=accs, miou=mious, fscore=fscores)
                    self.logger.log(Wstr + '\n' + Tstr + '\n' + Bstr, 'warm')

            #torch.cuda.empty_cache()
            epoch_time.update(time.time() - end_epoch)
            end_epoch = time.time()
            '''
            # TODO: whether to perform validation after each epoch in the warmup phase?
            valid_loss, valid_acc, valid_miou, valid_fscore = self.validate()
            valid_log = 'Warmup Valid\t[{0}/{1}]\tLoss\t{2:.6f}\tAcc\t{3:6.4f}\tMIoU\t{4:6.4f}\tF\t{5:6.4f}'\
                .format(epoch+1, warmup_epochs, valid_loss, valid_acc, valid_miou, valid_fscore)
                        #'\tflops\t{6:}M\tparams{7:}M'\
            valid_log += 'Train\t[{0}/{1}]\tLoss\t{2:.6f}\tAcc\t{3:6.4f}\tMIoU\t{4:6.4f}\tFscore\t{5:6.4f}'
            self.run_manager.write_log(valid_log, 'valid')
            '''

            # continue the warmup phase
            self.warmup = epoch + 1 < warmup_epochs
            self.warmup_epoch = self.warmup_epoch + 1
            #self.start_epoch = self.warmup_epoch
            # Save a checkpoint in the warmup phase at the configured frequency.
            if (epoch + 1) % self.run_manager.run_config.save_ckpt_freq == 0 \
                    or (epoch + 1) == warmup_epochs:
                state_dict = self.net.state_dict()
                # remove architecture parameters, since arch_parameters are not updated in the warm-up phase.
                #for key in list(state_dict.keys()):
                #    if 'cell_arch_parameters' in key or 'network_arch_parameters' in key or 'aspp_arch_parameters' in key:
                #        state_dict.pop(key)
                checkpoint = {
                    'state_dict': state_dict,
                    'weight_optimizer': self.run_manager.optimizer.state_dict(),
                    'weight_scheduler': self.run_manager.scheduler.state_dict(),
                    'warmup': self.warmup,
                    'warmup_epoch': epoch + 1,
                }
                filename = self.logger.path(mode='warm', is_best=False)
                save_path = save_checkpoint(checkpoint,
                                            filename,
                                            self.logger,
                                            mode='warm')
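
The checkpoint written above can be used to resume an interrupted warm-up run. A minimal sketch of the loading side, assuming the same key names as the dict saved in warm_up(); resume_warmup itself is a hypothetical helper, not part of the source:

import torch

def resume_warmup(net, weight_optimizer, filename, device='cuda'):
    """Restore the state saved by warm_up() and return its progress flags."""
    checkpoint = torch.load(filename, map_location=device)
    net.load_state_dict(checkpoint['state_dict'])
    weight_optimizer.load_state_dict(checkpoint['weight_optimizer'])
    # 'warmup' tells the caller whether warm-up is still unfinished;
    # 'warmup_epoch' is where range(warmup_epoch, warmup_epochs) should restart.
    return checkpoint['warmup'], checkpoint['warmup_epoch']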
Example #4
0
    def train(self):

        self.data_train = self.data_loader.read_data_to_tensor(
            self.data_loader._train_path,
            self.device,
            max_size=None,
            normalize_digits=True,
            has_start_node=True)
        self.data_dev = self.data_loader.read_data_to_tensor(
            self.data_loader._eval_paths[0],
            self.device,
            max_size=None,
            normalize_digits=True,
            has_start_node=True)
        self.data_test = self.data_loader.read_data_to_tensor(
            self.data_loader._eval_paths[1],
            self.device,
            max_size=None,
            normalize_digits=True,
            has_start_node=True)

        num_data = sum(self.data_train[1])

        num_batches = num_data // hparams.batch_size + 1

        best_epoch = 0
        best_dev_f1 = 0.0
        early_epoch = 0

        for epoch in range(1, hparams.num_epochs + 1):
            train_err = 0.
            train_data_er = 0.
            train_sample_er = 0.
            train_total = 0.

            start_time = time.time()
            num_back = 0

            for batch in range(1, num_batches + 1):
                word, char, _, _, labels, masks, lengths = self.data_loader.get_batch_tensor(
                    self.data_train,
                    hparams.batch_size,
                    unk_replace=hparams.unk_replace)

                samples = self.sampling(word, char, masks, lengths)
                self.net.train()

                sos_mat = labels.new_zeros(
                    (labels.size(0), 1)).fill_(hparams.sos).long()
                input_label = torch.cat([sos_mat, labels], dim=1)
                input_label = input_label[:, :-1]

                sample_sos_mat = sos_mat.repeat(1, hparams.num_samples).view(
                    -1, hparams.num_samples, 1)
                input_sample = torch.cat([sample_sos_mat, samples], dim=2)
                input_sample = input_sample[:, :, :-1]
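                # input_label / input_sample are now the gold and sampled sequences
                # shifted right by one with <sos> prepended (teacher-forcing decoder inputs)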

                self.optimizer.zero_grad()

                loss, data_loss, sample_loss = self.net.loss(
                    word, char, input_label, input_sample, labels, samples,
                    masks, lengths)
                loss.backward()

                if hparams.clip_grad:
                    nn.utils.clip_grad_norm_(self.net.parameters(),
                                             hparams.clip_grad_norm)
                self.optimizer.step()

                num_inst = masks.size(0)
                train_err += loss.item() * num_inst
                train_data_er += data_loss.item() * num_inst
                train_sample_er += sample_loss.item() * num_inst
                train_total += num_inst

                time_ave = (time.time() - start_time) / batch
                time_left = (num_batches - batch) * time_ave

                # update log
                if batch % 100 == 0:
                    sys.stdout.write("\b" * num_back)
                    sys.stdout.write(" " * num_back)
                    sys.stdout.write("\b" * num_back)
                    log_info = 'train: %d/%d loss: %.4f, time left (estimated): %.2fs' % (
                        batch, num_batches, train_err / train_total, time_left)
                    sys.stdout.write(log_info)
                    sys.stdout.flush()
                    num_back = len(log_info)

            sys.stdout.write("\b" * num_back)
            sys.stdout.write(" " * num_back)
            sys.stdout.write("\b" * num_back)
            self.logger.info(
                'epoch: %d, train: %d, loss: %.4f, data loss: %.4f, sample loss: %.4f, time: %.2fs, lr:%.4f'
                % (epoch, num_batches, train_err / train_total,
                   train_data_er / train_total, train_sample_er / train_total,
                   time.time() - start_time, self.lr))

            # evaluate performance on dev data
            tmp_filename = os.path.join(self.config.tmp_path,
                                        'dev%d' % (epoch))
            f1 = self.eval('dev', epoch, tmp_filename)

            if best_dev_f1 <= f1:

                best_dev_f1 = f1
                best_epoch = epoch
                early_epoch = 0

                save_checkpoint(
                    {
                        'epoch': best_epoch,
                        'state_dict': self.net.state_dict(),
                        'optimizer': self.optimizer.state_dict()
                    },
                    join(self.config.checkpoint_path,
                         "%s_%d" % (self.config.model_name, best_epoch)))

            else:
                if epoch > hparams.early_start:
                    early_epoch += 1

                if early_epoch > hparams.early_patience:
                    self.logger.info("early stopped, best epoch: %d" %
                                     best_epoch)
                    os._exit(0)

            # evaluate performance on test data
            tmp_filename = os.path.join(self.config.tmp_path,
                                        'test%d' % (epoch))
            _ = self.eval('test', epoch, tmp_filename)

            if epoch % hparams.schedule == 0:
                self.lr = hparams.learning_rate / (1.0 +
                                                   epoch * hparams.decay_rate)

                if hparams.optimizer == 'SGD':
                    self.optimizer = optim.SGD(self.net.parameters(),
                                               lr=self.lr,
                                               momentum=0.9,
                                               weight_decay=0.0,
                                               nesterov=True)

                elif hparams.optimizer == 'Adam':
                    self.optimizer = optim.Adam(self.net.parameters(),
                                                lr=self.lr)
                else:
                    raise NotImplementedError(
                        'unsupported optimizer: %s' % hparams.optimizer)
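
Every hparams.schedule epochs the loop above rebuilds the optimizer with an inverse-time-decayed learning rate, lr = learning_rate / (1 + epoch * decay_rate). A small sketch of just that schedule (the numbers in the usage comment are illustrative assumptions):

def decayed_lr(base_lr, epoch, decay_rate):
    """Inverse-time decay applied when the optimizer is rebuilt."""
    return base_lr / (1.0 + epoch * decay_rate)

# For example, with base_lr=0.015 and decay_rate=0.05:
#   epoch 10 -> 0.015 / 1.5 = 0.010
#   epoch 20 -> 0.015 / 2.0 = 0.0075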