def train(self):
    """Run the full training loop with early stopping on validation saturation.

    For each epoch: re-seed the RNGs (so training is exactly resumable from
    any epoch), train one epoch, step the LR scheduler, then validate and
    track the best value of ``self.best_metric``. If the metric has not
    improved for ``self.saturation_wait_epochs`` epochs, training stops:
    the best checkpoint is reloaded and evaluated once in ``'test'`` mode.
    Finally the logger service is told training is complete so it can
    persist the last state dict.

    Side effects: mutates model weights, scheduler state, and logger state;
    prints progress only when ``self.pilot`` is set.
    """
    epoch = self.epoch_start
    best_epoch = self.best_epoch
    accum_iter = self.accum_iter_start
    best_metric = self.best_metric_at_best_epoch
    stop_training = False
    for epoch in range(self.epoch_start, self.num_epochs):
        if self.pilot:
            print('epoch', epoch)
        # fix random seed at every epoch to make it perfectly resumable
        fix_random_seed_as(epoch)
        accum_iter = self.train_one_epoch(epoch, accum_iter, self.train_loader)
        # step before val because state_dict is saved at val;
        # it doesn't affect the val result
        self.lr_scheduler.step()
        val_log_data = self.validate(epoch, accum_iter, mode='val')
        metric = val_log_data[self.best_metric]
        if metric > best_metric:
            best_metric = metric
            best_epoch = epoch
        elif (self.saturation_wait_epochs is not None) and \
                (epoch - best_epoch >= self.saturation_wait_epochs):
            # stop training if val perf doesn't improve for saturation_wait_epochs
            stop_training = True
        if stop_training:
            # load best model before the final test evaluation
            # NOTE(review): assumes the BestModelLogger is always the last
            # val logger — confirm against logger construction.
            best_model_logger = self.val_loggers[-1]
            assert isinstance(best_model_logger, BestModelLogger)
            weight_path = best_model_logger.filepath()
            if self.use_parallel:
                self.model.module.load(weight_path)
            else:
                self.model.load(weight_path)
            # test result at best model
            self.validate(best_epoch, accum_iter, mode='test')
            break
    self.logger_service.complete({
        'state_dict': (self._create_state_dict(epoch, accum_iter)),
    })
def init_weights(self):
    """Deterministically initialize the module's weights.

    Seeds the RNGs with the configured model-init seed, then applies the
    per-submodule initialization hook ``self._init_weights`` recursively.
    """
    seed = self.args.model_init_seed
    fix_random_seed_as(seed)
    self.apply(self._init_weights)