def run(self):
        """Train for ``cfg['training']['epoch']`` epochs, validating and checkpointing.

        Starting at ``self.start_epoch``, each epoch:

        - steps ``self.scheduler``;
        - runs ``self.train()`` and then ``self.val()``;
        - logs the best loss, pixAcc and mIoU seen so far (and the best DSC
          when ``self.show_dice_coeff`` is set);
        - writes a checkpoint whenever ``self.save_best`` is True.

        ``self.save_best`` starts as True, so the first epoch always saves.
        Presumably ``self.val()`` sets it back to True on improvement (TODO
        confirm). After the loop, the TensorBoard scalars are exported to
        ``<save_tbx_log>/all_scalars.json`` and the writer is closed.
        """
        self.logger.info('args = %s', self.cfg)
        # Setup Metrics
        self.metric_train = SegmentationMetric(self.n_classes)
        self.metric_val = SegmentationMetric(self.n_classes)
        self.metric_test = SegmentationMetric(self.n_classes)
        self.val_loss_meter = average_meter()
        self.test_loss_meter = average_meter()
        self.train_loss_meter = average_meter()
        self.train_dice_coeff_meter = average_meter()
        self.val_dice_coeff_meter = average_meter()
        self.patience = 0
        self.save_best = True
        run_start = time.time()
        max_epoch = self.cfg['training']['epoch']

        def elapsed():
            # Seconds spent in this run plus self.dur_time. Presumably
            # dur_time is time restored from an earlier checkpoint; confirm.
            return self.dur_time + time.time() - run_start

        # Set up results folder; exist_ok avoids the exists()/makedirs() race.
        os.makedirs(self.save_image_path, exist_ok=True)

        for epoch in range(self.start_epoch, max_epoch):
            self.epoch = epoch

            # NOTE(review): since PyTorch 1.1 the documented order is
            # optimizer.step() (inside train()) before scheduler.step().
            # Stepping first skips the initial lr. Kept so the schedule matches
            # existing runs; confirm before changing.
            self.scheduler.step()

            # get_lr() called outside step() warns and can report a wrong value
            # for chainable schedulers; get_last_lr() (PyTorch >= 1.4) returns
            # the lr actually set by the last step().
            lr_getter = getattr(self.scheduler, 'get_last_lr', None)
            if lr_getter is None:  # older PyTorch only has get_lr()
                lr_getter = self.scheduler.get_lr
            self.logger.info('=> Epoch {}, lr {}'.format(
                self.epoch,
                lr_getter()[-1]))

            # train and search the model
            self.train()

            # valid the model
            self.val()

            self.logger.info('current best loss {}, pixAcc {}, mIoU {}'.format(
                self.best_loss,
                self.best_pixAcc,
                self.best_mIoU,
            ))

            if self.show_dice_coeff:
                self.logger.info('current best DSC {}'.format(
                    self.best_dice_coeff))

            if self.save_best:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'dur_time': elapsed(),
                        'model_state': self.model.state_dict(),
                        'model_optimizer': self.model_optimizer.state_dict(),
                        'best_pixAcc': self.best_pixAcc,
                        'best_mIoU': self.best_mIoU,
                        'best_dice_coeff': self.best_dice_coeff,
                        'best_loss': self.best_loss,
                    }, True, self.save_path)
                self.logger.info(
                    'save checkpoint (epoch %d) in %s  dur_time: %s', epoch,
                    self.save_path,
                    calc_time(elapsed()))
                self.save_best = False

            if epoch == max_epoch - 1:
                # Last scheduled epoch: this is the normal end of training, not
                # early stopping. Breaking here skips the per-epoch resets and
                # timing log below; the final timing is logged after the loop.
                self.logger.info('Training ends!')
                break
            self.logger.info('current patience :{}'.format(self.patience))

            self.val_loss_meter.reset()
            self.train_loss_meter.reset()
            self.train_dice_coeff_meter.reset()
            self.val_dice_coeff_meter.reset()
            self.metric_train.reset()
            self.metric_val.reset()
            self.logger.info('cost time: {}'.format(calc_time(elapsed())))

        # export scalar data to JSON for external processing
        self.writer.export_scalars_to_json(
            os.path.join(self.save_tbx_log, 'all_scalars.json'))
        self.writer.close()
        self.logger.info('cost time: {}'.format(calc_time(elapsed())))
        self.logger.info('log dir in : {}'.format(self.save_path))
예제 #2
0
    def run(self):
        """Run the architecture-search loop for ``cfg['searching']['epoch']`` epochs.

        Each epoch steps the scheduler, then logs the current genotype and
        prints the softmax of the four alpha tensors. From epoch
        ``cfg['searching']['alpha_begin']`` onward, ``self.patience`` counts
        the consecutive epochs in which the genotype equals the stored
        ``self.geno_type``. When it reaches ``cfg['searching']['max_patience']``,
        the search stops before training that epoch.

        Otherwise the epoch:

        - runs ``self.train()`` and then ``self.infer()``;
        - logs GPU memory every ``report_freq`` epochs;
        - saves a (non-best) checkpoint.

        At the end, the TensorBoard scalars are exported to
        ``<save_tbx_log>/all_scalars.json`` and the writer is closed.
        """
        self.logger.info('args = {}'.format(self.cfg))
        # Setup Metrics
        self.metric_train = SegmentationMetric(self.n_classes)
        self.metric_val = SegmentationMetric(self.n_classes)
        self.train_loss_meter = AverageMeter()
        self.val_loss_meter = AverageMeter()
        run_start = time.time()
        search_cfg = self.cfg['searching']

        def elapsed():
            # Seconds spent in this run plus self.dur_time. Presumably
            # dur_time is time restored from an earlier checkpoint; confirm.
            return self.dur_time + time.time() - run_start

        for epoch in range(self.start_epoch, search_cfg['epoch']):
            self.epoch = epoch

            # update scheduler
            # NOTE(review): since PyTorch 1.1 the documented order is
            # optimizer.step() before scheduler.step(). Kept as-is so the
            # schedule matches existing runs; confirm before changing.
            self.scheduler.step()
            # get_lr() outside step() warns and can report a wrong value for
            # chainable schedulers; prefer get_last_lr() (PyTorch >= 1.4).
            lr_getter = getattr(self.scheduler, 'get_last_lr', None)
            if lr_getter is None:  # older PyTorch only has get_lr()
                lr_getter = self.scheduler.get_lr
            self.logger.info('Epoch %d / %d lr %e', self.epoch,
                             search_cfg['epoch'],
                             lr_getter()[-1])

            # get genotype
            genotype = self.model.genotype()
            self.logger.info('genotype = %s', genotype)
            for label, alphas in (
                    ('alpha normal_down:', self.model.alphas_normal_down),
                    ('alpha down:', self.model.alphas_down),
                    ('alpha normal_up:', self.model.alphas_normal_up),
                    ('alpha up:', self.model.alphas_up)):
                print(label, F.softmax(alphas, dim=-1))

            # the performance may be unstable, before train in a degree
            if self.epoch >= search_cfg['alpha_begin']:
                # check whether the genotype has changed
                if self.geno_type == genotype:
                    self.patience += 1
                else:
                    self.patience = 0
                    self.geno_type = genotype

                self.logger.info('Current patience :{}'.format(self.patience))

                if self.patience >= search_cfg['max_patience']:
                    self.logger.info(
                        'Reach the max patience! \n best genotype {}'.format(
                            genotype))
                    break

            # train and search the model
            self.train()

            # valid the model
            self.infer()

            if self.epoch % search_cfg['report_freq'] == 0:
                self.logger.info(
                    'GPU memory total:{}, reserved:{}, allocated:{}, waiting:{}'
                    .format(*gpu_memory()))

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'dur_time': elapsed(),
                    'cur_patience': self.patience,
                    'geno_type': self.geno_type,
                    'model_state': self.model.state_dict(),
                    'arch_optimizer': self.arch_optimizer.state_dict(),
                    'model_optimizer': self.model_optimizer.state_dict(),
                    'alphas_dict': self.model.alphas_dict(),
                    'scheduler': self.scheduler.state_dict()
                }, False, self.save_path)
            self.logger.info(
                'save checkpoint (epoch %d) in %s  dur_time: %s', epoch,
                self.save_path,
                calc_time(elapsed()))

            self.metric_train.reset()
            self.metric_val.reset()
            self.val_loss_meter.reset()
            self.train_loss_meter.reset()

        # export scalar data to JSON for external processing
        self.writer.export_scalars_to_json(self.save_tbx_log +
                                           "/all_scalars.json")
        self.writer.close()