Ejemplo n.º 1
0
def main():
    """Entry point for downstream training / evaluation.

    Sets torch/torchaudio globals, parses the downstream arguments and config,
    initialises ``torch.distributed`` when a local rank is given, and, when
    resuming training from a past experiment, checks that the DDP setting (and
    world size) match the checkpoint. The leader process then dumps the
    arguments/config and backs up ``backup_files`` into ``args.expdir``. All
    RNGs are seeded, cuDNN is made deterministic, and
    ``Runner(args, config).<args.mode>()`` is invoked.

    Raises:
        ValueError: when resuming training under a different DDP setting or
            world size than the one stored in the checkpoint.
        AttributeError: when ``Runner`` has no method named ``args.mode``.
    """
    torch.multiprocessing.set_sharing_strategy('file_system')
    torchaudio.set_audio_backend('sox_io')
    hack_isinstance()

    # get config and arguments
    args, config, backup_files = get_downstream_args()
    if args.cache_dir is not None:
        torch.hub.set_dir(args.cache_dir)

    # When torch.distributed.launch is used
    if args.local_rank is not None:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(args.backend)

    if args.mode == 'train' and args.past_exp:
        # A resumed run must use the same DDP setup the checkpoint was saved with.
        ckpt = torch.load(args.init_ckpt, map_location='cpu')

        now_use_ddp = is_initialized()
        original_use_ddp = ckpt['Args'].local_rank is not None
        if now_use_ddp != original_use_ddp:
            raise ValueError(
                f'DDP setting differs from checkpoint: {now_use_ddp} != {original_use_ddp}')

        if now_use_ddp:
            now_world = get_world_size()
            original_world = ckpt['WorldSize']
            if now_world != original_world:
                raise ValueError(
                    f'World size differs from checkpoint: {now_world} != {original_world}')

        # The checkpoint was only needed for the checks above; do not keep it
        # alive (in host memory) for the whole run.
        del ckpt

    # Save command
    if is_leader_process():
        # Compute the tag once so the args and config dumps share the same tag.
        time_tag = get_time_tag()
        with open(os.path.join(args.expdir, f'args_{time_tag}.yaml'), 'w') as f:
            yaml.dump(vars(args), f)

        with open(os.path.join(args.expdir, f'config_{time_tag}.yaml'), 'w') as f:
            yaml.dump(config, f)

        for path in backup_files:
            backup(path, args.expdir)

    # Fix seed and make backends deterministic
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    runner = Runner(args, config)
    # Dispatch by attribute name instead of eval()-ing command-line text.
    getattr(runner, args.mode)()
Ejemplo n.º 2
0
    def train(self):
        """Run the downstream training loop until ``runner.total_steps`` steps.

        Builds the optimizer (plus the optional scheduler and SpecAug), then
        iterates over the ``'train'`` dataloader epoch after epoch. Each batch
        goes through the upstream model (under ``torch.no_grad`` unless the
        upstream is in training mode) and the downstream expert, and the loss is
        back-propagated scaled by ``1 / gradient_accumulate_steps``. Every
        ``gradient_accumulate_steps`` backward passes, gradients are clipped,
        the optimizer and scheduler are stepped, and the global step advances.
        Only the leader process logs records, runs evaluation and writes
        checkpoints; every rank advances the step counter so all ranks leave
        the loop together.
        """
        # set model train/eval modes
        self.downstream.train()
        self.upstream.eval()
        if self.args.upstream_trainable:
            self.upstream.train()

        # set optimizer
        model_params = [self.downstream]
        if self.args.upstream_trainable:
            model_params.append(self.upstream)
        optimizer = self._get_optimizer(model_params)

        # set scheduler
        scheduler = None
        if self.config.get('scheduler'):
            scheduler = self._get_scheduler(optimizer)

        # set specaug
        specaug = None
        if self.config.get('specaug'):
            from .specaug import SpecAug
            specaug = SpecAug(**self.config["specaug"])

        runner_config = self.config['runner']
        # A missing key used to yield None and crash at `loss / None`;
        # treat it (and 0) as "no accumulation".
        gradient_accumulate_steps = runner_config.get('gradient_accumulate_steps') or 1

        # set progress bar; non-leader ranks write to os.devnull (closed at the end)
        is_leader = is_leader_process()
        tqdm_file = sys.stderr if is_leader else open(os.devnull, 'w')
        pbar = tqdm(total=runner_config['total_steps'],
                    dynamic_ncols=True,
                    desc='overall',
                    file=tqdm_file)
        init_step = self.init_ckpt.get('Step')
        if init_step:
            pbar.n = init_step

        # set Tensorboard logging (leader only)
        logger = SummaryWriter(self.args.expdir) if is_leader else None

        def _prune_old_checkpoints(directory):
            # Remove the oldest `states-<step>.ckpt` files so that, after the
            # upcoming save, at most `runner.max_keep` of them remain.
            max_keep = runner_config['max_keep']
            ckpt_pths = glob.glob(f'{directory}/states-*.ckpt')
            if len(ckpt_pths) >= max_keep:
                ckpt_pths = sorted(
                    ckpt_pths,
                    key=lambda pth: int(pth.split('-')[-1].split('.')[0]))
                for ckpt_pth in ckpt_pths[:len(ckpt_pths) - max_keep + 1]:
                    os.remove(ckpt_pth)

        # prepare data
        dataloader = self.downstream.get_dataloader('train')

        batch_ids = []
        backward_steps = 0
        records = defaultdict(list)
        epoch = self.init_ckpt.get('Epoch', 0)
        while pbar.n < pbar.total:
            if is_initialized():
                dataloader.sampler.set_epoch(epoch)

            for batch_id, (wavs, *others) in enumerate(
                    tqdm(dataloader,
                         dynamic_ncols=True,
                         desc='train',
                         file=tqdm_file)):
                # try/except block for forward/backward
                try:
                    if pbar.n >= pbar.total:
                        break
                    global_step = pbar.n + 1

                    wavs = [
                        torch.FloatTensor(wav).to(self.args.device)
                        for wav in wavs
                    ]
                    if self.upstream.training:
                        features = self.upstream(wavs)
                    else:
                        with torch.no_grad():
                            features = self.upstream(wavs)

                    if specaug:
                        features, _ = specaug(features)

                    loss = self.downstream(
                        'train',
                        features,
                        *others,
                        records=records,
                    )
                    batch_ids.append(batch_id)

                    (loss / gradient_accumulate_steps).backward()
                    del loss

                except RuntimeError as e:
                    if 'CUDA out of memory' not in str(e):
                        raise
                    print(
                        f'[Runner] - CUDA out of memory at step {global_step}'
                    )
                    # Under DDP, skipping a batch on one rank would desync
                    # the ranks, so re-raise instead.
                    if is_initialized():
                        raise
                    with torch.cuda.device(self.args.device):
                        torch.cuda.empty_cache()
                    optimizer.zero_grad()
                    continue

                # whether to accumulate gradient
                backward_steps += 1
                if backward_steps % gradient_accumulate_steps > 0:
                    continue

                # gradient clipping
                paras = list(self.downstream.parameters())
                if self.args.upstream_trainable:
                    paras += list(self.upstream.parameters())
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    paras, runner_config['gradient_clipping'])

                # optimize
                if math.isnan(grad_norm):
                    print(f'[Runner] - grad norm is NaN at step {global_step}')
                else:
                    optimizer.step()
                optimizer.zero_grad()

                # adjust learning rate
                if scheduler:
                    scheduler.step()

                if not is_leader:
                    # Non-leader ranks skip logging/evaluation/saving but must
                    # still advance the step counter; otherwise pbar.n never
                    # reaches pbar.total and this rank never leaves the loop.
                    batch_ids = []
                    records = defaultdict(list)
                    pbar.update(1)
                    continue

                # logging
                if global_step % runner_config['log_step'] == 0:
                    self.downstream.log_records(
                        'train',
                        records=records,
                        logger=logger,
                        global_step=global_step,
                        batch_ids=batch_ids,
                        total_batch_num=len(dataloader),
                    )
                    batch_ids = []
                    records = defaultdict(list)

                # evaluation and save checkpoint; file names returned by
                # self.evaluate are saved as checkpoints as well
                save_names = []

                if global_step % runner_config['eval_step'] == 0:
                    for split in runner_config['eval_dataloaders']:
                        save_names += self.evaluate(split, logger, global_step)

                if global_step % runner_config['save_step'] == 0:
                    _prune_old_checkpoints(self.args.expdir)
                    save_names.append(f'states-{global_step}.ckpt')

                if save_names:
                    all_states = {
                        'Downstream': get_model_state(self.downstream),
                        'Optimizer': optimizer.state_dict(),
                        'Step': global_step,
                        'Epoch': epoch,
                        'Args': self.args,
                        'Config': self.config,
                    }

                    if scheduler:
                        all_states['Scheduler'] = scheduler.state_dict()

                    if self.args.upstream_trainable:
                        all_states['Upstream'] = get_model_state(self.upstream)

                    if is_initialized():
                        all_states['WorldSize'] = get_world_size()

                    save_paths = [
                        os.path.join(self.args.expdir, name)
                        for name in save_names
                    ]
                    tqdm.write(f'[Runner] - Save the checkpoint to:')
                    for i, path in enumerate(save_paths):
                        tqdm.write(f'{i + 1}. {path}')
                        torch.save(all_states, path)

                pbar.update(1)
            epoch += 1

        pbar.close()
        if is_leader:
            logger.close()
        else:
            # release the os.devnull handle opened for the silent progress bars
            tqdm_file.close()
Ejemplo n.º 3
0
    def __init__(self, upstream_dim, downstream_expert, evaluate_split, expdir, **kwargs):
        """Build the speaker-verification downstream expert.

        Args:
            upstream_dim: feature size produced by the upstream model; it is the
                input size of ``self.connector``.
            downstream_expert: config dict holding ``'datarc'`` (dataset
                options) and ``'modelrc'`` (model and loss options).
            evaluate_split: split name; for ``'train_plda'`` / ``'test_plda'``
                the leader process opens ``{expdir}/{evaluate_split}.rep.ark``
                for binary writing and keeps it as ``self.ark``.
            expdir: experiment directory.
            **kwargs: accepted and ignored.
        """
        super(DownstreamExpert, self).__init__()

        # ---- configuration ----
        self.upstream_dim = upstream_dim
        self.downstream = downstream_expert
        self.datarc = downstream_expert['datarc']
        self.modelrc = downstream_expert['modelrc']
        datarc = self.datarc
        modelrc = self.modelrc

        # ---- datasets ----
        # "dev/wav" feeds both the training set and the dev trials;
        # "test/wav" feeds the test trials.
        data_root = Path(datarc['file_path'])
        dev_wav_dir = data_root / "dev" / "wav"
        test_wav_dir = data_root / "test" / "wav"
        vad_config = datarc['vad_config']

        self.train_dataset = SpeakerVerifi_train(
            vad_config=vad_config,
            file_path=[dev_wav_dir],
            key_list=["Voxceleb1"],
            meta_data=datarc['train_meta_data'],
            max_timestep=datarc["max_timestep"],
        )
        self.dev_dataset = SpeakerVerifi_test(
            vad_config=vad_config,
            file_path=dev_wav_dir,
            meta_data=datarc['dev_meta_data'],
        )
        self.test_dataset = SpeakerVerifi_test(
            vad_config=vad_config,
            file_path=test_wav_dir,
            meta_data=datarc['test_meta_data'],
        )

        # ---- modules (construction order kept: connector, model, objective) ----
        # linear projection from upstream features to the model input size
        self.connector = nn.Linear(self.upstream_dim, modelrc['input_dim'])

        module_hparams = modelrc["module_config"][modelrc['module']]
        # aggregation width falls back to the model input size
        agg_dim = module_hparams.get("agg_dim", modelrc['input_dim'])

        # downstream model extractor, including the aggregation module
        self.model = Model(
            input_dim=modelrc['input_dim'],
            agg_dim=agg_dim,
            agg_module_name=modelrc['agg_module'],
            module_name=modelrc['module'],
            hparams=module_hparams,
            utterance_module_name=modelrc["utter_module"],
        )

        # loss head (e.g. SoftmaxLoss or AMSoftmaxLoss) selected by name
        loss_kwargs = dict(
            speaker_num=self.train_dataset.speaker_num,
            hidden_dim=modelrc['input_dim'],
            **modelrc['LossConfig'][modelrc['ObjectiveLoss']],
        )
        # NOTE(review): the loss class is resolved with eval() on a config
        # string -- only load trusted configs.
        self.objective = eval(modelrc['ObjectiveLoss'])(**loss_kwargs)

        # ---- utils ----
        self.score_fn = nn.CosineSimilarity(dim=-1)
        self.eval_metric = EER
        self.register_buffer('best_score', torch.ones(1) * 100)

        wants_plda_dump = evaluate_split in ('train_plda', 'test_plda')
        if wants_plda_dump and is_leader_process():
            self.ark = open(f'{expdir}/{evaluate_split}.rep.ark', 'wb')
Ejemplo n.º 4
0
    def forward(self, mode, features, utter_idx, labels, records, **kwargs):
        """
        Args:
            mode:
                'train', 'dev' / 'test', or 'train_plda' / 'test_plda'

            features:
                sequence of per-utterance tensors extracted by upstream (first
                dimension is the frame count); they are padded into a batch

            utter_idx:
                utterance keys, used as ark keys in the *_plda modes

            labels:
                speaker labels in 'train'; trial labels appended to
                records['labels'] in 'dev' / 'test'

            records:
                defaultdict(list), by appending scalars into records,
                these scalars will be averaged and logged on Tensorboard

            **kwargs:
                accepted and ignored

        Return:
            'train': the loss to be optimized (not detached)
            'dev' / 'test': torch.tensor(0); cosine scores and labels are
                appended to records['scores'] / records['labels']
            '*_plda': None; on the leader process the normalized embeddings
                are written to self.ark
        """
        features_pad = pad_sequence(features, batch_first=True)

        # TDNN layers in XVector decrease the total sequence length by a fixed 14,
        # so the mask must be shortened to match the model's output length.
        trim = 14 if self.modelrc['module'] == "XVector" else 0
        attention_mask = [torch.ones(feature.shape[0] - trim) for feature in features]

        attention_mask_pad = pad_sequence(attention_mask, batch_first=True)
        # 0 at valid frames, -100000 at padded positions
        attention_mask_pad = (1.0 - attention_mask_pad) * -100000.0

        features_pad = self.connector(features_pad)
        # Keep the mask on the same device as the features instead of
        # hard-coding .cuda() (breaks CPU runs / non-current CUDA devices).
        attention_mask_pad = attention_mask_pad.to(features_pad.device)

        if mode == 'train':
            agg_vec = self.model(features_pad, attention_mask_pad)
            labels = torch.LongTensor(labels).to(features_pad.device)
            loss = self.objective(agg_vec, labels)
            records['loss'].append(loss.item())
            return loss

        elif mode in ['dev', 'test']:
            agg_vec = self.model.inference(features_pad, attention_mask_pad)
            agg_vec = agg_vec / (torch.norm(agg_vec, dim=-1).unsqueeze(-1))

            # separate batched data to pair data.
            vec1, vec2 = self.separate_data(agg_vec)

            scores = self.score_fn(vec1, vec2).cpu().detach().tolist()
            records['scores'].extend(scores)
            records['labels'].extend(labels)

            return torch.tensor(0)

        elif mode in ['train_plda', 'test_plda'] and is_leader_process():
            agg_vec = self.model.inference(features_pad, attention_mask_pad)
            agg_vec = agg_vec / (torch.norm(agg_vec, dim=-1).unsqueeze(-1))

            for key, vec in zip(utter_idx, agg_vec):
                vec = vec.view(-1).detach().cpu().numpy()
                kaldi_io.write_vec_flt(self.ark, vec, key=key)