def main():
    torch.multiprocessing.set_sharing_strategy('file_system')
    torchaudio.set_audio_backend('sox_io')
    hack_isinstance()

    # get config and arguments
    args, config, backup_files = get_downstream_args()
    if args.cache_dir is not None:
        torch.hub.set_dir(args.cache_dir)

    # When torch.distributed.launch is used
    if args.local_rank is not None:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(args.backend)

    if args.mode == 'train' and args.past_exp:
        ckpt = torch.load(args.init_ckpt, map_location='cpu')

        now_use_ddp = is_initialized()
        original_use_ddp = ckpt['Args'].local_rank is not None
        assert now_use_ddp == original_use_ddp, f'{now_use_ddp} != {original_use_ddp}'

        if now_use_ddp:
            now_world = get_world_size()
            original_world = ckpt['WorldSize']
            assert now_world == original_world, f'{now_world} != {original_world}'

    # Save command
    if is_leader_process():
        with open(os.path.join(args.expdir, f'args_{get_time_tag()}.yaml'), 'w') as file:
            yaml.dump(vars(args), file)
        with open(os.path.join(args.expdir, f'config_{get_time_tag()}.yaml'), 'w') as file:
            yaml.dump(config, file)
        for file in backup_files:
            backup(file, args.expdir)

    # Fix seed and make backends deterministic
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    runner = Runner(args, config)
    # dispatch to the requested mode, e.g. runner.train() or runner.evaluate()
    getattr(runner, args.mode)()
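# The seeding recipe inside main() can be reused standalone; below is a minimal
# sketch of the same steps factored into a helper. The name `make_deterministic`
# is ours, not part of the source:
import random

import numpy as np
import torch


def make_deterministic(seed: int):
    # seed all RNGs the training loop may touch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # trade cuDNN autotuning speed for reproducible kernels
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False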
def train(self):
    # set model train/eval modes
    self.downstream.train()
    self.upstream.eval()
    if self.args.upstream_trainable:
        self.upstream.train()

    # set optimizer
    model_params = [self.downstream]
    if self.args.upstream_trainable:
        model_params.append(self.upstream)
    optimizer = self._get_optimizer(model_params)

    # set scheduler
    scheduler = None
    if self.config.get('scheduler'):
        scheduler = self._get_scheduler(optimizer)

    # set specaug
    specaug = None
    if self.config.get('specaug'):
        from .specaug import SpecAug
        specaug = SpecAug(**self.config["specaug"])

    # set progress bar
    tqdm_file = sys.stderr if is_leader_process() else open(os.devnull, 'w')
    pbar = tqdm(total=self.config['runner']['total_steps'],
                dynamic_ncols=True, desc='overall', file=tqdm_file)
    init_step = self.init_ckpt.get('Step')
    if init_step:
        pbar.n = init_step

    # set Tensorboard logging
    if is_leader_process():
        logger = SummaryWriter(self.args.expdir)

    # prepare data
    dataloader = self.downstream.get_dataloader('train')

    batch_ids = []
    backward_steps = 0
    records = defaultdict(list)
    epoch = self.init_ckpt.get('Epoch', 0)
    while pbar.n < pbar.total:
        if is_initialized():
            dataloader.sampler.set_epoch(epoch)
        for batch_id, (wavs, *others) in enumerate(
                tqdm(dataloader, dynamic_ncols=True, desc='train', file=tqdm_file)):
            # try/except block for forward/backward
            try:
                if pbar.n >= pbar.total:
                    break
                global_step = pbar.n + 1

                wavs = [torch.FloatTensor(wav).to(self.args.device) for wav in wavs]
                if self.upstream.training:
                    features = self.upstream(wavs)
                else:
                    with torch.no_grad():
                        features = self.upstream(wavs)

                if specaug:
                    features, _ = specaug(features)

                loss = self.downstream(
                    'train',
                    features, *others,
                    records=records,
                )
                batch_ids.append(batch_id)

                # default to 1 when gradient accumulation is not configured
                gradient_accumulate_steps = self.config['runner'].get(
                    'gradient_accumulate_steps', 1)
                (loss / gradient_accumulate_steps).backward()
                del loss

            except RuntimeError as e:
                if 'CUDA out of memory' in str(e):
                    print(f'[Runner] - CUDA out of memory at step {global_step}')
                    if is_initialized():
                        raise
                    with torch.cuda.device(self.args.device):
                        torch.cuda.empty_cache()
                    optimizer.zero_grad()
                    continue
                else:
                    raise

            # whether to accumulate gradient
            backward_steps += 1
            if backward_steps % gradient_accumulate_steps > 0:
                continue

            # gradient clipping
            paras = list(self.downstream.parameters())
            if self.args.upstream_trainable:
                paras += list(self.upstream.parameters())
            grad_norm = torch.nn.utils.clip_grad_norm_(
                paras, self.config['runner']['gradient_clipping'])

            # optimize
            if math.isnan(grad_norm):
                print(f'[Runner] - grad norm is NaN at step {global_step}')
            else:
                optimizer.step()
            optimizer.zero_grad()

            # adjust learning rate
            if scheduler:
                scheduler.step()

            if not is_leader_process():
                batch_ids = []
                records = defaultdict(list)
                continue

            # logging
            if global_step % self.config['runner']['log_step'] == 0:
                self.downstream.log_records(
                    'train',
                    records=records,
                    logger=logger,
                    global_step=global_step,
                    batch_ids=batch_ids,
                    total_batch_num=len(dataloader),
                )
                batch_ids = []
                records = defaultdict(list)

            # evaluation and save checkpoint
            save_names = []

            if global_step % self.config['runner']['eval_step'] == 0:
                for split in self.config['runner']['eval_dataloaders']:
                    save_names += self.evaluate(split, logger, global_step)

            if global_step % self.config['runner']['save_step'] == 0:
                def check_ckpt_num(directory):
                    # keep at most `max_keep` checkpoints, removing the oldest first
                    max_keep = self.config['runner']['max_keep']
                    ckpt_pths = glob.glob(f'{directory}/states-*.ckpt')
                    if len(ckpt_pths) >= max_keep:
                        ckpt_pths = sorted(
                            ckpt_pths,
                            key=lambda pth: int(pth.split('-')[-1].split('.')[0]))
                        for ckpt_pth in ckpt_pths[:len(ckpt_pths) - max_keep + 1]:
                            os.remove(ckpt_pth)
                check_ckpt_num(self.args.expdir)
                save_names.append(f'states-{global_step}.ckpt')

            if len(save_names) > 0:
                all_states = {
                    'Downstream': get_model_state(self.downstream),
                    'Optimizer': optimizer.state_dict(),
                    'Step': global_step,
                    'Epoch': epoch,
                    'Args': self.args,
                    'Config': self.config,
                }
                if scheduler:
                    all_states['Scheduler'] = scheduler.state_dict()
                if self.args.upstream_trainable:
                    all_states['Upstream'] = get_model_state(self.upstream)
                if is_initialized():
                    all_states['WorldSize'] = get_world_size()

                save_paths = [os.path.join(self.args.expdir, name) for name in save_names]
                tqdm.write('[Runner] - Save the checkpoint to:')
                for i, path in enumerate(save_paths):
                    tqdm.write(f'{i + 1}. {path}')
                    torch.save(all_states, path)

            pbar.update(1)
        epoch += 1

    pbar.close()
    if is_leader_process():
        logger.close()
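# A checkpoint written by train() above can be inspected offline; a minimal
# sketch using only the keys saved in `all_states` (the 'expdir' path and the
# 'states-2000.ckpt' file name are illustrative assumptions):
import torch

ckpt = torch.load('expdir/states-2000.ckpt', map_location='cpu')
print(ckpt['Step'], ckpt['Epoch'])       # resume point recorded by train()
downstream_weights = ckpt['Downstream']  # state dict from get_model_state()
optimizer_state = ckpt['Optimizer']      # restores optimizer momentum etc.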
def __init__(self, upstream_dim, downstream_expert, evaluate_split, expdir, **kwargs):
    super(DownstreamExpert, self).__init__()

    # config
    self.upstream_dim = upstream_dim
    self.downstream = downstream_expert
    self.datarc = downstream_expert['datarc']
    self.modelrc = downstream_expert['modelrc']

    # dataset
    train_file_path = Path(self.datarc['file_path']) / "dev" / "wav"
    test_file_path = Path(self.datarc['file_path']) / "test" / "wav"

    train_config = {
        "vad_config": self.datarc['vad_config'],
        "file_path": [train_file_path],
        "key_list": ["Voxceleb1"],
        "meta_data": self.datarc['train_meta_data'],
        "max_timestep": self.datarc["max_timestep"],
    }
    self.train_dataset = SpeakerVerifi_train(**train_config)

    dev_config = {
        "vad_config": self.datarc['vad_config'],
        "file_path": train_file_path,
        "meta_data": self.datarc['dev_meta_data'],
    }
    self.dev_dataset = SpeakerVerifi_test(**dev_config)

    test_config = {
        "vad_config": self.datarc['vad_config'],
        "file_path": test_file_path,
        "meta_data": self.datarc['test_meta_data'],
    }
    self.test_dataset = SpeakerVerifi_test(**test_config)

    # module
    self.connector = nn.Linear(self.upstream_dim, self.modelrc['input_dim'])

    # downstream model
    agg_dim = self.modelrc["module_config"][self.modelrc['module']].get(
        "agg_dim", self.modelrc['input_dim'])

    ModelConfig = {
        "input_dim": self.modelrc['input_dim'],
        "agg_dim": agg_dim,
        "agg_module_name": self.modelrc['agg_module'],
        "module_name": self.modelrc['module'],
        "hparams": self.modelrc["module_config"][self.modelrc['module']],
        "utterance_module_name": self.modelrc["utter_module"],
    }
    # downstream model extractor, including the aggregation module
    self.model = Model(**ModelConfig)

    # SoftmaxLoss or AMSoftmaxLoss, resolved by class name from the config
    objective_config = {
        "speaker_num": self.train_dataset.speaker_num,
        "hidden_dim": self.modelrc['input_dim'],
        **self.modelrc['LossConfig'][self.modelrc['ObjectiveLoss']],
    }
    self.objective = eval(self.modelrc['ObjectiveLoss'])(**objective_config)

    # utils
    self.score_fn = nn.CosineSimilarity(dim=-1)
    self.eval_metric = EER
    self.register_buffer('best_score', torch.ones(1) * 100)

    if evaluate_split in ['train_plda', 'test_plda'] and is_leader_process():
        self.ark = open(f'{expdir}/{evaluate_split}.rep.ark', 'wb')
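# A sketch of the `downstream_expert` config shape consumed by __init__ above.
# Only the key names come from the code; every concrete value here is an
# illustrative assumption, not taken from the source:
downstream_expert = {
    'datarc': {
        'file_path': '/path/to/VoxCeleb1',    # expects dev/wav and test/wav below it
        'vad_config': {'min_sec': 2},         # assumed VAD option
        'train_meta_data': '/path/to/train_meta.txt',
        'dev_meta_data': '/path/to/dev_meta.txt',
        'test_meta_data': '/path/to/test_meta.txt',
        'max_timestep': 16000 * 8,            # assumed cap on waveform samples
    },
    'modelrc': {
        'input_dim': 512,
        'module': 'XVector',
        'agg_module': 'ASP',                  # assumed aggregation module name
        'utter_module': 'UtteranceExtractor', # assumed utterance module name
        'module_config': {'XVector': {'agg_dim': 512}},
        'ObjectiveLoss': 'AMSoftmaxLoss',
        'LossConfig': {'AMSoftmaxLoss': {'margin': 0.2, 'scale': 30}},
    },
}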
def forward(self, mode, features, utter_idx, labels, records, **kwargs):
    """
    Args:
        features:
            the features extracted by the upstream, already put on the device
            assigned by the command-line args
        labels:
            the speaker labels
        records:
            defaultdict(list); scalars appended to records are averaged and
            logged on Tensorboard
        logger:
            Tensorboard SummaryWriter, given here for logging/debugging
            convenience; please use "self.downstream/your_content_name" as the
            key name to log your customized contents
        global_step:
            the global step in the runner, helpful for Tensorboard logging

    Return:
        loss:
            the loss to be optimized; should not be detached
    """
    features_pad = pad_sequence(features, batch_first=True)

    if self.modelrc['module'] == "XVector":
        # TDNN layers in XVector shrink the total sequence length by a fixed 14 frames
        attention_mask = [torch.ones(feature.shape[0] - 14) for feature in features]
    else:
        attention_mask = [torch.ones(feature.shape[0]) for feature in features]
    attention_mask_pad = pad_sequence(attention_mask, batch_first=True)
    # padded positions get a large negative bias so attention ignores them
    attention_mask_pad = (1.0 - attention_mask_pad) * -100000.0

    features_pad = self.connector(features_pad)

    if mode == 'train':
        agg_vec = self.model(features_pad, attention_mask_pad.cuda())
        labels = torch.LongTensor(labels).to(features_pad.device)
        loss = self.objective(agg_vec, labels)
        records['loss'].append(loss.item())
        return loss

    elif mode in ['dev', 'test']:
        agg_vec = self.model.inference(features_pad, attention_mask_pad.cuda())
        agg_vec = agg_vec / torch.norm(agg_vec, dim=-1).unsqueeze(-1)

        # separate batched data into pair data
        vec1, vec2 = self.separate_data(agg_vec)

        scores = self.score_fn(vec1, vec2).cpu().detach().tolist()
        records['scores'].extend(scores)
        records['labels'].extend(labels)
        return torch.tensor(0)

    elif mode in ['train_plda', 'test_plda'] and is_leader_process():
        agg_vec = self.model.inference(features_pad, attention_mask_pad.cuda())
        agg_vec = agg_vec / torch.norm(agg_vec, dim=-1).unsqueeze(-1)
        for key, vec in zip(utter_idx, agg_vec):
            vec = vec.view(-1).detach().cpu().numpy()
            kaldi_io.write_vec_flt(self.ark, vec, key=key)
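# The dev/test branch above only collects records['scores'] and
# records['labels']; the equal error rate (self.eval_metric = EER) is computed
# from them later. Below is a minimal self-contained sketch of such an EER
# function, assuming binary labels (1 = same speaker) and higher score = more
# similar; this is not the source's exact implementation:
import numpy as np


def compute_eer(labels, scores):
    labels = np.asarray(labels)
    scores = np.asarray(scores)
    thresholds = np.sort(scores)
    # false acceptance rate: negative pairs scored at or above the threshold
    far = np.array([(scores[labels == 0] >= t).mean() for t in thresholds])
    # false rejection rate: positive pairs scored below the threshold
    frr = np.array([(scores[labels == 1] < t).mean() for t in thresholds])
    # EER is where the two error curves cross
    idx = int(np.argmin(np.abs(far - frr)))
    return (far[idx] + frr[idx]) / 2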