def evaluate(self):
    """evaluate function will always be called on a single process
    even during distributed training"""
    split = self.args.evaluate_split

    # fix seed to guarantee the same evaluation protocol across steps
    random.seed(self.args.seed)
    np.random.seed(self.args.seed)
    torch.manual_seed(self.args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(self.args.seed)
    with torch.cuda.device(self.args.device):
        torch.cuda.empty_cache()

    # set all models to eval
    self.downstream.eval()
    self.upstream.eval()

    # prepare data
    dataloader = self.downstream.get_dataloader(split)

    records = defaultdict(list)
    for batch_id, (wavs, *others) in enumerate(
            tqdm(dataloader, dynamic_ncols=True, desc=split)):
        wavs = [torch.FloatTensor(wav).to(self.args.device) for wav in wavs]
        with torch.no_grad():
            features = self.upstream(wavs)
            self.downstream(
                split,
                features, *others,
                records=records,
            )

    return records
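
# Note: the evaluate() above returns the raw `records` dict (metric name -> list
# of per-batch values) and leaves aggregation and logging to the caller. The
# helper below is only an illustrative sketch of how a caller might reduce that
# dict to scalar metrics; it is not part of the original runner, and the metric
# names inside `records` depend entirely on the downstream model.
def _summarize_records(records):
    """Average each list of per-batch values in `records` into one scalar."""
    return {
        key: sum(values) / len(values)
        for key, values in records.items()
        if len(values) > 0
    }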
def evaluate(self, split=None, logger=None, global_step=0):
    """evaluate function will always be called on a single process
    even during distributed training"""

    # When this member function is called directly by command line
    not_during_training = split is None and logger is None and global_step == 0
    if not_during_training:
        split = self.args.evaluate_split
        tempdir = tempfile.mkdtemp()
        logger = SummaryWriter(tempdir)

    # fix seed to guarantee the same evaluation protocol across steps
    random.seed(self.args.seed)
    np.random.seed(self.args.seed)
    torch.manual_seed(self.args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(self.args.seed)
    with torch.cuda.device(self.args.device):
        torch.cuda.empty_cache()

    # record original train/eval states and set all models to eval
    trainings = []
    for entry in self.all_entries:
        trainings.append(entry.model.training)
        entry.model.eval()

    # prepare data
    dataloader = self.downstream.model.get_dataloader(split)

    batch_ids = []
    records = defaultdict(list)
    for batch_id, (wavs, *others) in enumerate(
            tqdm(dataloader, dynamic_ncols=True, desc=split)):
        wavs = [torch.FloatTensor(wav).to(self.args.device) for wav in wavs]
        with torch.no_grad():
            features = self.upstream.model(wavs)
            features = self.featurizer.model(wavs, features)
            self.downstream.model(
                split,
                features, *others,
                records = records,
                batch_id = batch_id,
            )
            batch_ids.append(batch_id)

    save_names = self.downstream.model.log_records(
        split,
        records = records,
        logger = logger,
        global_step = global_step,
        batch_ids = batch_ids,
        total_batch_num = len(dataloader),
    )
    batch_ids = []
    records = defaultdict(list)

    # prepare back to training
    with torch.cuda.device(self.args.device):
        torch.cuda.empty_cache()

    for entry, training in zip(self.all_entries, trainings):
        if training:
            entry.model.train()

    if not_during_training:
        logger.close()
        shutil.rmtree(tempdir)

    return [] if type(save_names) is not list else save_names
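
# The evaluate() above assumes the downstream model implements
# log_records(split, records, logger, global_step, batch_ids, total_batch_num)
# and returns either None or a list of checkpoint names worth saving. The method
# below is a rough, hypothetical sketch of that contract (the 'acc' key and the
# best-score bookkeeping are illustrative, not the actual downstream code):
def log_records(self, split, records, logger, global_step, batch_ids, total_batch_num, **kwargs):
    save_names = []
    average_acc = sum(records['acc']) / len(records['acc'])
    logger.add_scalar(f'{split}-acc', average_acc, global_step=global_step)
    # ask the runner to save a checkpoint whenever a non-train split improves
    if split != 'train' and average_acc > getattr(self, 'best_score', 0.0):
        self.best_score = average_acc
        save_names.append(f'best-states-{split}.ckpt')
    return save_names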
def train(self):
    # trainable parameters and train/eval mode
    trainable_models = []
    trainable_paras = []
    for entry in self.all_entries:
        if entry.trainable:
            entry.model.train()
            trainable_models.append(entry.model)
            trainable_paras += list(entry.model.parameters())
        else:
            entry.model.eval()

    # optimizer
    optimizer = self._get_optimizer(trainable_models)

    # scheduler
    scheduler = None
    if self.config.get('scheduler'):
        scheduler = self._get_scheduler(optimizer)

    # specaug
    specaug = None
    if self.config.get('specaug'):
        from .specaug import SpecAug
        specaug = SpecAug(**self.config["specaug"])

    # progress bar
    tqdm_file = sys.stderr if is_leader_process() else open(os.devnull, 'w')
    pbar = tqdm(total=self.config['runner']['total_steps'], dynamic_ncols=True, desc='overall', file=tqdm_file)
    init_step = self.init_ckpt.get('Step')
    if init_step:
        pbar.n = init_step

    # Tensorboard logging
    if is_leader_process():
        logger = SummaryWriter(self.args.expdir)

    # prepare data
    dataloader = self.downstream.model.get_dataloader('train')

    batch_ids = []
    backward_steps = 0
    records = defaultdict(list)
    epoch = self.init_ckpt.get('Epoch', 0)
    while pbar.n < pbar.total:
        if is_initialized():
            dataloader.sampler.set_epoch(epoch)

        for batch_id, (wavs, *others) in enumerate(
                tqdm(dataloader, dynamic_ncols=True, desc='train', file=tqdm_file)):
            # try/except block for forward/backward
            try:
                if pbar.n >= pbar.total:
                    break
                global_step = pbar.n + 1

                wavs = [torch.FloatTensor(wav).to(self.args.device) for wav in wavs]
                if self.upstream.trainable:
                    features = self.upstream.model(wavs)
                else:
                    with torch.no_grad():
                        features = self.upstream.model(wavs)
                features = self.featurizer.model(wavs, features)

                if specaug:
                    features, _ = specaug(features)

                loss = self.downstream.model(
                    'train',
                    features, *others,
                    records = records,
                )
                batch_ids.append(batch_id)

                gradient_accumulate_steps = self.config['runner'].get('gradient_accumulate_steps')
                (loss / gradient_accumulate_steps).backward()
                del loss

            except RuntimeError as e:
                if 'CUDA out of memory' in str(e):
                    print(f'[Runner] - CUDA out of memory at step {global_step}')
                    if is_initialized():
                        raise
                    with torch.cuda.device(self.args.device):
                        torch.cuda.empty_cache()
                    optimizer.zero_grad()
                    continue
                else:
                    raise

            # whether to accumulate gradient
            backward_steps += 1
            if backward_steps % gradient_accumulate_steps > 0:
                continue

            # gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(
                trainable_paras, self.config['runner']['gradient_clipping'])

            # optimize
            if math.isnan(grad_norm):
                print(f'[Runner] - grad norm is NaN at step {global_step}')
            else:
                optimizer.step()
            optimizer.zero_grad()

            # adjust learning rate
            if scheduler:
                scheduler.step()

            if not is_leader_process():
                batch_ids = []
                records = defaultdict(list)
                continue

            # logging
            if global_step % self.config['runner']['log_step'] == 0:
                self.downstream.model.log_records(
                    'train',
                    records = records,
                    logger = logger,
                    global_step = global_step,
                    batch_ids = batch_ids,
                    total_batch_num = len(dataloader),
                )
                batch_ids = []
                records = defaultdict(list)

            # evaluation and save checkpoint
            save_names = []

            if global_step % self.config['runner']['eval_step'] == 0:
                for split in self.config['runner']['eval_dataloaders']:
                    save_names += self.evaluate(split, logger, global_step)

            if global_step % self.config['runner']['save_step'] == 0:
                def check_ckpt_num(directory):
                    max_keep = self.config['runner']['max_keep']
                    ckpt_pths = glob.glob(f'{directory}/states-*.ckpt')
                    if len(ckpt_pths) >= max_keep:
                        ckpt_pths = sorted(ckpt_pths, key=lambda pth: int(pth.split('-')[-1].split('.')[0]))
                        for ckpt_pth in ckpt_pths[:len(ckpt_pths) - max_keep + 1]:
                            os.remove(ckpt_pth)
                check_ckpt_num(self.args.expdir)
                save_names.append(f'states-{global_step}.ckpt')

            if len(save_names) > 0:
                all_states = {
                    'Optimizer': optimizer.state_dict(),
                    'Step': global_step,
                    'Epoch': epoch,
                    'Args': self.args,
                    'Config': self.config,
                }

                for entry in self.all_entries:
                    if entry.trainable:
                        all_states[entry.name] = get_model_state(entry.model)

                if scheduler:
                    all_states['Scheduler'] = scheduler.state_dict()

                if is_initialized():
                    all_states['WorldSize'] = get_world_size()

                save_paths = [os.path.join(self.args.expdir, name) for name in save_names]
                tqdm.write(f'[Runner] - Save the checkpoint to:')
                for i, path in enumerate(save_paths):
                    tqdm.write(f'{i + 1}. {path}')
                    torch.save(all_states, path)

            pbar.update(1)
        epoch += 1

    pbar.close()
    if is_leader_process():
        logger.close()
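
# The train() above serializes checkpoints as a plain dict with the keys
# 'Optimizer', 'Step', 'Epoch', 'Args', 'Config', optionally 'Scheduler' and
# 'WorldSize', plus one model state per trainable entry name. The helper below
# is only a sketch for inspecting such a file offline; the path argument is
# whatever states-{global_step}.ckpt file train() wrote under args.expdir.
def _inspect_checkpoint(ckpt_path):
    """Load a checkpoint written by train() and list what it contains."""
    ckpt = torch.load(ckpt_path, map_location='cpu')
    print(f"[Inspect] step: {ckpt['Step']}, epoch: {ckpt['Epoch']}")
    reserved = ('Optimizer', 'Step', 'Epoch', 'Args', 'Config', 'Scheduler', 'WorldSize')
    model_names = [key for key in ckpt if key not in reserved]
    print(f"[Inspect] trainable model states: {model_names}")
    return ckpt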