def evaluate(self):
    """Run the upstream/downstream models over the evaluation split.

    evaluate function will always be called on a single process even during
    distributed training.

    The split name is read from ``self.args.evaluate_split``. The Python,
    NumPy and torch random generators are re-seeded with ``self.args.seed``
    before the data is iterated, and both models are switched to eval mode
    (they are not switched back afterwards).

    Returns:
        The ``defaultdict(list)`` that was handed to every
        ``self.downstream(...)`` call through the ``records`` keyword.
    """
    split = self.args.evaluate_split

    # fix seed to guarantee the same evaluation protocol across steps
    random.seed(self.args.seed)
    np.random.seed(self.args.seed)
    torch.manual_seed(self.args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(self.args.seed)
        # torch.cuda.device() raises ValueError for a non-CUDA device, so
        # only release cached memory when the run really targets a GPU.
        if torch.device(self.args.device).type == 'cuda':
            with torch.cuda.device(self.args.device):
                torch.cuda.empty_cache()

    # set all models to eval
    self.downstream.eval()
    self.upstream.eval()

    # prepare data
    dataloader = self.downstream.get_dataloader(split)

    records = defaultdict(list)
    for wavs, *others in tqdm(dataloader, dynamic_ncols=True, desc=split):
        wavs = [
            torch.FloatTensor(wav).to(self.args.device) for wav in wavs
        ]
        # no gradients are needed for evaluation
        with torch.no_grad():
            features = self.upstream(wavs)
            self.downstream(
                split,
                features,
                *others,
                records=records,
            )

    return records
def evaluate(self, split=None, logger=None, global_step=0):
    """Evaluate one split and log the collected records.

    evaluate function will always be called on a single process even during
    distributed training.

    Args:
        split: Name of the split passed to ``self.downstream.get_dataloader``.
            When ``split`` and ``logger`` are both ``None`` and
            ``global_step`` is 0, the call is treated as a direct
            (command-line) invocation and ``self.args.evaluate_split`` is
            used instead.
        logger: Writer forwarded to ``self.downstream.log_records``. For a
            direct invocation a ``SummaryWriter`` on a temporary directory
            is created here and removed again before returning.
        global_step: Step number forwarded to ``log_records``.

    Returns:
        The value returned by ``self.downstream.log_records`` if it is a
        list, otherwise an empty list.
    """
    # When this member function is called directly by command line
    not_during_training = split is None and logger is None and global_step == 0
    tempdir = None
    if not_during_training:
        split = self.args.evaluate_split
        tempdir = tempfile.mkdtemp()
        logger = SummaryWriter(tempdir)

    # torch.cuda.device() raises ValueError for a non-CUDA device, so the
    # cache is only released when the run really targets a GPU.
    on_cuda = (torch.cuda.is_available()
               and torch.device(self.args.device).type == 'cuda')

    # record original train/eval states so they can be restored afterwards
    downstream_training = self.downstream.training
    upstream_training = self.upstream.training

    try:
        # fix seed to guarantee the same evaluation protocol across steps
        random.seed(self.args.seed)
        np.random.seed(self.args.seed)
        torch.manual_seed(self.args.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(self.args.seed)
        if on_cuda:
            with torch.cuda.device(self.args.device):
                torch.cuda.empty_cache()

        # set all models to eval
        self.downstream.eval()
        self.upstream.eval()

        # prepare data
        dataloader = self.downstream.get_dataloader(split)

        batch_ids = []
        records = defaultdict(list)
        for batch_id, (wavs, *others) in enumerate(
                tqdm(dataloader, dynamic_ncols=True, desc=split)):
            wavs = [
                torch.FloatTensor(wav).to(self.args.device) for wav in wavs
            ]
            with torch.no_grad():
                features = self.upstream(wavs)
                self.downstream(
                    split,
                    features,
                    *others,
                    records=records,
                )
                batch_ids.append(batch_id)

        save_names = self.downstream.log_records(
            split,
            records=records,
            logger=logger,
            global_step=global_step,
            batch_ids=batch_ids,
            total_batch_num=len(dataloader),
        )

        # drop the accumulated records before releasing cached GPU memory
        del batch_ids, records
        if on_cuda:
            with torch.cuda.device(self.args.device):
                torch.cuda.empty_cache()
    finally:
        # prepare back to training, even if evaluation failed half-way
        if downstream_training:
            self.downstream.train()
        if upstream_training:
            self.upstream.train()

        # the temporary writer only exists for direct invocations
        if tempdir is not None:
            logger.close()
            shutil.rmtree(tempdir)

    return save_names if isinstance(save_names, list) else []
def train(self):
    """Run the training loop until ``config['runner']['total_steps']`` steps.

    One "step" is one optimizer update; the overall progress bar counts
    these steps and ``global_step`` is derived from it. Gradients are
    accumulated over ``config['runner']['gradient_accumulate_steps']``
    batches (1 when the key is absent or falsy) and clipped with
    ``config['runner']['gradient_clipping']`` before each update. An update
    whose gradient norm is NaN is skipped.

    On the leader process only, train records are logged every ``log_step``
    steps, ``self.evaluate`` is run for each split in ``eval_dataloaders``
    every ``eval_step`` steps, and a ``states-{global_step}.ckpt`` checkpoint
    is written every ``save_step`` steps (keeping at most ``max_keep`` such
    files in ``self.args.expdir``). Any file names returned by
    ``self.evaluate`` are saved with the same state dictionary.

    Raises:
        RuntimeError: re-raised from the forward/backward pass, except for
            a CUDA out-of-memory error outside distributed training, which
            skips the current batch instead.
    """
    # set model train/eval modes
    self.downstream.train()
    self.upstream.eval()
    if self.args.upstream_trainable:
        self.upstream.train()

    # set optimizer
    model_params = [self.downstream]
    if self.args.upstream_trainable:
        model_params.append(self.upstream)
    optimizer = self._get_optimizer(model_params)

    # set scheduler
    scheduler = None
    if self.config.get('scheduler'):
        scheduler = self._get_scheduler(optimizer)

    # set specaug
    specaug = None
    if self.config.get('specaug'):
        from .specaug import SpecAug
        specaug = SpecAug(**self.config["specaug"])

    runner_config = self.config['runner']

    # A missing (None) accumulation setting would make `loss / None` fail,
    # so fall back to updating on every batch.
    gradient_accumulate_steps = runner_config.get(
        'gradient_accumulate_steps') or 1

    def check_ckpt_num(directory):
        """Delete the oldest step checkpoints so one more fits in max_keep."""
        max_keep = runner_config['max_keep']
        ckpt_pths = glob.glob(f'{directory}/states-*.ckpt')
        if len(ckpt_pths) >= max_keep:
            ckpt_pths = sorted(
                ckpt_pths,
                key=lambda pth: int(pth.split('-')[-1].split('.')[0]))
            for ckpt_pth in ckpt_pths[:len(ckpt_pths) - max_keep + 1]:
                os.remove(ckpt_pth)

    # set progress bar; non-leader processes write their bars to devnull
    tqdm_file = sys.stderr if is_leader_process() else open(
        os.devnull, 'w')
    pbar = tqdm(total=runner_config['total_steps'],
                dynamic_ncols=True,
                desc='overall',
                file=tqdm_file)
    init_step = self.init_ckpt.get('Step')
    if init_step:
        pbar.n = init_step

    # set Tensorboard logging
    logger = None
    if is_leader_process():
        logger = SummaryWriter(self.args.expdir)

    try:
        # prepare data
        dataloader = self.downstream.get_dataloader('train')

        batch_ids = []
        backward_steps = 0
        records = defaultdict(list)
        epoch = self.init_ckpt.get('Epoch', 0)
        while pbar.n < pbar.total:
            if is_initialized():
                dataloader.sampler.set_epoch(epoch)

            for batch_id, (wavs, *others) in enumerate(
                    tqdm(dataloader,
                         dynamic_ncols=True,
                         desc='train',
                         file=tqdm_file)):
                # try/except block for forward/backward
                try:
                    if pbar.n >= pbar.total:
                        break
                    global_step = pbar.n + 1

                    wavs = [
                        torch.FloatTensor(wav).to(self.args.device)
                        for wav in wavs
                    ]
                    if self.upstream.training:
                        features = self.upstream(wavs)
                    else:
                        with torch.no_grad():
                            features = self.upstream(wavs)

                    if specaug:
                        features, _ = specaug(features)

                    loss = self.downstream(
                        'train',
                        features,
                        *others,
                        records=records,
                    )
                    batch_ids.append(batch_id)

                    (loss / gradient_accumulate_steps).backward()
                    del loss

                except RuntimeError as e:
                    if 'CUDA out of memory' in str(e):
                        print(
                            f'[Runner] - CUDA out of memory at step {global_step}'
                        )
                        # in distributed training the error is not skipped
                        if is_initialized():
                            raise
                        with torch.cuda.device(self.args.device):
                            torch.cuda.empty_cache()
                        optimizer.zero_grad()
                        continue
                    else:
                        raise

                # whether to accumulate gradient
                backward_steps += 1
                if backward_steps % gradient_accumulate_steps > 0:
                    continue

                # gradient clipping
                paras = list(self.downstream.parameters())
                if self.args.upstream_trainable:
                    paras += list(self.upstream.parameters())
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    paras, runner_config['gradient_clipping'])

                # optimize
                if math.isnan(grad_norm):
                    print(f'[Runner] - grad norm is NaN at step {global_step}')
                else:
                    optimizer.step()
                optimizer.zero_grad()

                # adjust learning rate
                if scheduler:
                    scheduler.step()

                if not is_leader_process():
                    batch_ids = []
                    records = defaultdict(list)
                    # The step counter must advance here too; otherwise
                    # pbar.n never reaches pbar.total on this process and
                    # the outer loop never terminates.
                    pbar.update(1)
                    continue

                # logging
                if global_step % runner_config['log_step'] == 0:
                    self.downstream.log_records(
                        'train',
                        records=records,
                        logger=logger,
                        global_step=global_step,
                        batch_ids=batch_ids,
                        total_batch_num=len(dataloader),
                    )
                    batch_ids = []
                    records = defaultdict(list)

                # evaluation and save checkpoint
                save_names = []

                if global_step % runner_config['eval_step'] == 0:
                    for split in runner_config['eval_dataloaders']:
                        save_names += self.evaluate(split, logger,
                                                    global_step)

                if global_step % runner_config['save_step'] == 0:
                    check_ckpt_num(self.args.expdir)
                    save_names.append(f'states-{global_step}.ckpt')

                if len(save_names) > 0:
                    all_states = {
                        'Downstream': get_model_state(self.downstream),
                        'Optimizer': optimizer.state_dict(),
                        'Step': global_step,
                        'Epoch': epoch,
                        'Args': self.args,
                        'Config': self.config,
                    }

                    if scheduler:
                        all_states['Scheduler'] = scheduler.state_dict()

                    if self.args.upstream_trainable:
                        all_states['Upstream'] = get_model_state(
                            self.upstream)

                    if is_initialized():
                        all_states['WorldSize'] = get_world_size()

                    save_paths = [
                        os.path.join(self.args.expdir, name)
                        for name in save_names
                    ]
                    tqdm.write(f'[Runner] - Save the checkpoint to:')
                    for i, path in enumerate(save_paths):
                        tqdm.write(f'{i + 1}. {path}')
                        torch.save(all_states, path)

                pbar.update(1)
            epoch += 1
    finally:
        pbar.close()
        if logger is not None:
            logger.close()
        # close the devnull handle opened for non-leader processes
        if tqdm_file is not sys.stderr:
            tqdm_file.close()