def run(self):
    self.logger.info('args = %s', self.cfg)
    # Setup Metrics
    self.metric_train = SegmentationMetric(self.n_classes)
    self.metric_val = SegmentationMetric(self.n_classes)
    self.metric_test = SegmentationMetric(self.n_classes)
    self.val_loss_meter = average_meter()
    self.test_loss_meter = average_meter()
    self.train_loss_meter = average_meter()
    self.train_dice_coeff_meter = average_meter()
    self.val_dice_coeff_meter = average_meter()
    self.patience = 0
    self.save_best = True
    run_start = time.time()

    # Set up results folder
    if not os.path.exists(self.save_image_path):
        os.makedirs(self.save_image_path)

    for epoch in range(self.start_epoch, self.cfg['training']['epoch']):
        self.epoch = epoch
        self.scheduler.step()
        self.logger.info('=> Epoch {}, lr {}'.format(
            self.epoch, self.scheduler.get_lr()[-1]))

        # train the model
        self.train()
        # valid the model
        self.val()

        self.logger.info('current best loss {}, pixAcc {}, mIoU {}'.format(
            self.best_loss, self.best_pixAcc, self.best_mIoU))
        if self.show_dice_coeff:
            self.logger.info('current best DSC {}'.format(self.best_dice_coeff))

        if self.save_best:
            save_checkpoint({
                'epoch': epoch + 1,
                'dur_time': self.dur_time + time.time() - run_start,
                'model_state': self.model.state_dict(),
                'model_optimizer': self.model_optimizer.state_dict(),
                'best_pixAcc': self.best_pixAcc,
                'best_mIoU': self.best_mIoU,
                'best_dice_coeff': self.best_dice_coeff,
                'best_loss': self.best_loss,
            }, True, self.save_path)
            self.logger.info('save checkpoint (epoch %d) in %s dur_time: %s',
                             epoch, self.save_path,
                             calc_time(self.dur_time + time.time() - run_start))
            self.save_best = False

        # if self.patience == self.cfg['training']['max_patience'] or epoch == self.cfg['training']['epoch']-1:
        if epoch == self.cfg['training']['epoch'] - 1:
            # load best model weights
            # self._check_resume(os.path.join(self.save_path, 'checkpoint.pth.tar'))
            # # Test
            # if len(self.test_queue) > 0:
            #     self.logger.info('Training ends \n Test')
            #     self.test()
            # else:
            #     self.logger.info('Training ends!')
            print('Early stopping')
            break
        else:
            self.logger.info('current patience :{}'.format(self.patience))

        # reset per-epoch meters and metrics
        self.val_loss_meter.reset()
        self.train_loss_meter.reset()
        self.train_dice_coeff_meter.reset()
        self.val_dice_coeff_meter.reset()
        self.metric_train.reset()
        self.metric_val.reset()

        self.logger.info('cost time: {}'.format(
            calc_time(self.dur_time + time.time() - run_start)))

    # export scalar data to JSON for external processing
    self.writer.export_scalars_to_json(self.save_tbx_log + "/all_scalars.json")
    self.writer.close()
    self.logger.info('cost time: {}'.format(
        calc_time(self.dur_time + time.time() - run_start)))
    self.logger.info('log dir in : {}'.format(self.save_path))
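# --- Hedged sketch, not the repository's actual utilities ---
# The training run() above calls average_meter, calc_time and save_checkpoint,
# which live elsewhere in the repo. The minimal versions below are assumptions
# inferred only from the call sites above (names, signatures and file names are
# illustrative, not the original implementation).
import os
import shutil

import torch


class average_meter:
    """Tracks a running average of a scalar, e.g. a per-batch loss."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)


# The search-phase run() below uses a CamelCase name; alias it for this sketch.
AverageMeter = average_meter


def calc_time(seconds):
    """Format a duration in seconds as 'Hh Mm Ss' for the log messages above."""
    seconds = int(seconds)
    h, rem = divmod(seconds, 3600)
    m, s = divmod(rem, 60)
    return '{}h {}m {}s'.format(h, m, s)


def save_checkpoint(state, is_best, save_path, filename='checkpoint.pth.tar'):
    """Persist the state dict; keep an extra copy when it is the current best."""
    path = os.path.join(save_path, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(save_path, 'model_best.pth.tar'))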
def run(self):
    self.logger.info('args = {}'.format(self.cfg))
    # Setup Metrics
    self.metric_train = SegmentationMetric(self.n_classes)
    self.metric_val = SegmentationMetric(self.n_classes)
    self.train_loss_meter = AverageMeter()
    self.val_loss_meter = AverageMeter()
    run_start = time.time()

    for epoch in range(self.start_epoch, self.cfg['searching']['epoch']):
        self.epoch = epoch
        # update scheduler
        self.scheduler.step()
        self.logger.info('Epoch %d / %d lr %e', self.epoch,
                         self.cfg['searching']['epoch'],
                         self.scheduler.get_lr()[-1])

        # get genotype
        genotype = self.model.genotype()
        self.logger.info('genotype = %s', genotype)
        print('alpha normal_down:', F.softmax(self.model.alphas_normal_down, dim=-1))
        print('alpha down:', F.softmax(self.model.alphas_down, dim=-1))
        print('alpha normal_up:', F.softmax(self.model.alphas_normal_up, dim=-1))
        print('alpha up:', F.softmax(self.model.alphas_up, dim=-1))

        # the architecture can be unstable early in the search, so only start
        # tracking genotype changes after the warm-up epochs
        if self.epoch >= self.cfg['searching']['alpha_begin']:
            # check whether the genotype has changed
            if self.geno_type == genotype:
                self.patience += 1
            else:
                self.patience = 0
                self.geno_type = genotype
            self.logger.info('Current patience :{}'.format(self.patience))
            if self.patience >= self.cfg['searching']['max_patience']:
                self.logger.info('Reach the max patience! \n best genotype {}'.format(genotype))
                break

        # train and search the model
        self.train()
        # valid the model
        self.infer()

        if self.epoch % self.cfg['searching']['report_freq'] == 0:
            self.logger.info('GPU memory total:{}, reserved:{}, allocated:{}, waiting:{}'
                             .format(*gpu_memory()))

        save_checkpoint({
            'epoch': epoch + 1,
            'dur_time': self.dur_time + time.time() - run_start,
            'cur_patience': self.patience,
            'geno_type': self.geno_type,
            'model_state': self.model.state_dict(),
            'arch_optimizer': self.arch_optimizer.state_dict(),
            'model_optimizer': self.model_optimizer.state_dict(),
            'alphas_dict': self.model.alphas_dict(),
            'scheduler': self.scheduler.state_dict(),
        }, False, self.save_path)
        self.logger.info('save checkpoint (epoch %d) in %s dur_time: %s',
                         epoch, self.save_path,
                         calc_time(self.dur_time + time.time() - run_start))

        # reset per-epoch meters and metrics
        self.metric_train.reset()
        self.metric_val.reset()
        self.val_loss_meter.reset()
        self.train_loss_meter.reset()

    # export scalar data to JSON for external processing
    self.writer.export_scalars_to_json(self.save_tbx_log + "/all_scalars.json")
    self.writer.close()
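# --- Hedged sketch, not the repository's actual utility ---
# The search-phase run() above logs GPU memory via gpu_memory(), which must
# yield four values matching 'total, reserved, allocated, waiting'. The helper
# below is an assumed implementation built only on documented torch.cuda calls;
# interpreting 'waiting' as memory reserved by the caching allocator but not
# currently allocated is a guess, not taken from the source.
import torch


def gpu_memory(device=0, unit=1024 ** 2):
    """Return (total, reserved, allocated, waiting) in MiB for one CUDA device."""
    total = torch.cuda.get_device_properties(device).total_memory
    reserved = torch.cuda.memory_reserved(device)
    allocated = torch.cuda.memory_allocated(device)
    waiting = reserved - allocated  # cached by the allocator, not yet in use
    return total // unit, reserved // unit, allocated // unit, waiting // unit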