def main():
    '''
    Parse the command line and spawn one training process per GPU.

    The argument parser returned by `get_parser()` is extended with the
    options of `LibriSpeechAsrDataModule`, then `run` is launched
    `args.world_size` times through `mp.spawn()`; this call blocks until
    all spawned processes have finished (`join=True`).
    '''
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    world_size = args.world_size
    # Validate explicitly instead of with `assert`: assertions are removed
    # when Python runs with -O, which would let an invalid value through
    # to mp.spawn().
    if world_size < 1:
        parser.error(f'world_size must be >= 1, got {world_size}')

    mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
# Example #2
# 0
def run(rank, world_size, args):
    '''
    Train one model replica on the GPU whose index equals `rank`.

    The function configures logging (and TensorBoard on rank 0 when
    `args.tensorboard` is set), loads or compiles the phone LM `P`,
    creates the MMI training graph compiler, the LibriSpeech data
    loaders, the model selected by `args.model_type`, an optional frozen
    alignment model, the Noam optimizer and an AMP `GradScaler`.  It then
    trains from `args.start_epoch` up to (but excluding)
    `args.num_epochs`, saving a checkpoint after every epoch and a
    separate "best" checkpoint whenever the validation objective
    decreases.

    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
        It is also used as the CUDA device index.
      world_size:
        Number of GPUs for DDP training.  Distributed setup, DDP
        wrapping and teardown only happen when it is greater than 1.
      args:
        The return value of get_parser().parse_args()

    Raises:
      NotImplementedError:
        If `args.model_type` is not "transformer", "conformer" or
        "contextnet".

    The process exits with status -1 if no CUDA device is available.
    '''
    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    accum_grad = args.accum_grad
    den_scale = args.den_scale
    att_rate = args.att_rate
    use_pruned_intersect = args.use_pruned_intersect

    fix_random_seed(42)
    if world_size > 1:
        setup_dist(rank, world_size, args.master_port)

    exp_dir = Path(f'exp-{model_type}-mmi-att-sa-vgg-normlayer')
    setup_logger(f'{exp_dir}/log/log-train-{rank}')

    # Fail fast: check for a GPU before any CUDA device is handed to the
    # graph compiler / model and before the (expensive) data loaders are
    # built.
    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    # Only rank 0 writes TensorBoard summaries.
    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
    else:
        tb_writer = None

    logging.info("Loading lexicon and symbol tables")
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = rank
    device = torch.device('cuda', device_id)

    p_cache = lang_dir / 'P.pt'
    if not p_cache.is_file():
        logging.debug(f'Loading P from {lang_dir}/P.fst.txt')
        with open(lang_dir / 'P.fst.txt') as f:
            # P is not an acceptor because there is
            # a back-off state, whose incoming arcs
            # have label #0 and aux_label eps.
            P = k2.Fsa.from_openfst(f.read(), acceptor=False)

        phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
        first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)

        # P.aux_labels is not needed in later computations, so
        # remove it here.
        del P.aux_labels
        # CAUTION(fangjun): The following line is crucial.
        # Arcs entering the back-off state have label equal to #0.
        # We have to change it to 0 here.
        P.labels[P.labels >= first_phone_disambig_id] = 0

        P = k2.remove_epsilon(P)
        P = k2.arc_sort(P)

        # All ranks execute this branch concurrently when the cache is
        # missing.  Let only rank 0 write it, and publish it with an
        # atomic rename so that no process can ever observe (and load)
        # a partially written P.pt.
        if rank == 0:
            tmp_cache = lang_dir / 'P.pt.tmp'
            torch.save(P.as_dict(), tmp_cache)
            os.replace(tmp_cache, p_cache)
    else:
        logging.debug('Loading pre-compiled P')
        P = k2.Fsa.from_dict(torch.load(p_cache))

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        P=P,
        device=device,
    )
    phone_ids = lexicon.phone_symbols()

    librispeech = LibriSpeechAsrDataModule(args)
    train_dl = librispeech.train_dataloaders()
    valid_dl = librispeech.valid_dataloaders()

    if use_pruned_intersect:
        logging.info('Use pruned intersect for den_lats')
    else:
        logging.info("Don't use pruned intersect for den_lats")

    logging.info("About to create model")

    # The attention decoder is only instantiated when the attention loss
    # actually contributes (att_rate != 0).
    num_decoder_layers = 6 if att_rate != 0.0 else 0
    num_classes = len(phone_ids) + 1  # +1 for the blank symbol

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=num_classes,
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "conformer":
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=num_classes,
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True,
            is_espnet_structure=True)
    elif model_type == "contextnet":
        model = ContextNet(
            num_features=80,
            num_classes=num_classes)
    else:
        raise NotImplementedError(f"Model of type {model_type} is not implemented")

    if args.torchscript:
        logging.info('Applying TorchScript to model...')
        model = torch.jit.script(model)

    model.to(device)
    describe(model)

    if world_size > 1:
        model = DDP(model, device_ids=[rank])

    # Now for the alignment model, if any
    if args.use_ali_model:
        ali_model = TdnnLstm1b(
            num_features=80,
            num_classes=num_classes,
            subsampling_factor=4)

        ali_model_fname = Path(f'exp-lstm-adam-ctc-musan/epoch-{args.ali_model_epoch}.pt')
        assert ali_model_fname.is_file(), \
                f'ali model filename {ali_model_fname} does not exist!'
        ali_model.load_state_dict(torch.load(ali_model_fname, map_location='cpu')['state_dict'])
        ali_model.to(device)

        # The alignment model is used for inference only.
        ali_model.eval()
        ali_model.requires_grad_(False)
        logging.info(f'Use ali_model: {ali_model_fname}')
    else:
        ali_model = None
        logging.info('No ali_model')

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=args.lr_factor,
                     warm_step=args.warm_step,
                     weight_decay=args.weight_decay)

    scaler = GradScaler(enabled=args.amp)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    # Resume from the checkpoint written at the end of the previous epoch.
    if start_epoch > 0:
        model_path = os.path.join(exp_dir, f'epoch-{start_epoch - 1}.pt')
        ckpt = load_checkpoint(filename=model_path, model=model, optimizer=optimizer, scaler=scaler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}")

    for epoch in range(start_epoch, num_epochs):
        train_dl.sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate, global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info(f'epoch {epoch}, learning rate {curr_learning_rate}')
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            ali_model=ali_model,
            device=device,
            graph_compiler=graph_compiler,
            use_pruned_intersect=use_pruned_intersect,
            optimizer=optimizer,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
            world_size=world_size,
            scaler=scaler
        )

        # Checkpoints are requested in TorchScript form from epoch
        # `args.torchscript_epoch` onwards; -1 disables this.
        save_as_torchscript = (args.torchscript_epoch != -1
                               and epoch >= args.torchscript_epoch)

        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            # The "best" checkpoint stores the model only (no optimizer /
            # scaler state).
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            scaler=None,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train,
                            local_rank=rank,
                            torchscript=save_as_torchscript
                            )
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch,
                               local_rank=rank)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, f'epoch-{epoch}.pt')
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        scaler=scaler,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train,
                        local_rank=rank,
                        torchscript=save_as_torchscript
                        )
        epoch_info_filename = os.path.join(exp_dir, f'epoch-{epoch}-info')
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch,
                           local_rank=rank)

    logging.warning('Done')
    if world_size > 1:
        torch.distributed.barrier()
        cleanup_dist()
def run(rank, world_size, args):
    '''
    Train one DDP model replica on the GPU whose index equals `rank`.

    This variant builds a bigram phone LM `P` from the lexicon's phone
    symbols, zeroes its scores and registers a trainable copy of them on
    the model as `P_scores`.  It then creates the data loaders, the model
    selected by `args.model_type`, an optional frozen alignment model,
    the Noam optimizer and an AMP `GradScaler`, and trains from
    `args.start_epoch` up to (but excluding) `args.num_epochs`.  A
    checkpoint is saved after every epoch, plus a separate "best"
    checkpoint whenever the validation objective decreases.

    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
        It is also used as the CUDA device index.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()

    Raises:
      NotImplementedError:
        If `args.model_type` is not "transformer", "conformer" or
        "contextnet".

    The process exits with status -1 if no CUDA device is available.
    '''
    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    accum_grad = args.accum_grad
    den_scale = args.den_scale
    att_rate = args.att_rate

    fix_random_seed(42)
    setup_dist(rank, world_size, args.master_port)

    exp_dir = Path(f'exp-{model_type}-noam-mmi-att-musan-sa-vgg')
    setup_logger(f'{exp_dir}/log/log-train-{rank}')

    # Fail fast: this check must come before anything is moved to the
    # CUDA device (e.g. `P.to(device)` below); otherwise a machine
    # without a GPU fails before this message is ever logged.
    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    # Only rank 0 writes TensorBoard summaries.
    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
    else:
        tb_writer = None

    logging.info("Loading lexicon and symbol tables")
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = rank
    device = torch.device('cuda', device_id)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )
    phone_ids = lexicon.phone_symbols()
    # Start from an all-zero bigram phone LM; a trainable copy of the
    # scores is attached to the model further down.
    P = create_bigram_phone_lm(phone_ids)
    P.scores = torch.zeros_like(P.scores)
    P = P.to(device)

    librispeech = LibriSpeechAsrDataModule(args)
    train_dl = librispeech.train_dataloaders()
    valid_dl = librispeech.valid_dataloaders()

    logging.info("About to create model")

    # The attention decoder is only instantiated when the attention loss
    # actually contributes (att_rate != 0).
    num_decoder_layers = 6 if att_rate != 0.0 else 0
    num_classes = len(phone_ids) + 1  # +1 for the blank symbol

    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=num_classes,
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "conformer":
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=num_classes,
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers,
            vgg_frontend=True)
    elif model_type == "contextnet":
        model = ContextNet(
            num_features=80,
            num_classes=num_classes)
    else:
        raise NotImplementedError(f"Model of type {model_type} is not implemented")

    # Register the LM scores as a model parameter so that they are
    # included in `model.parameters()` handed to the optimizer.
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    model = DDP(model, device_ids=[rank])

    # Now for the alignment model, if any
    if args.use_ali_model:
        ali_model = TdnnLstm1b(
            num_features=80,
            num_classes=num_classes,
            subsampling_factor=4)

        ali_model_fname = Path(f'exp-lstm-adam-ctc-musan/epoch-{args.ali_model_epoch}.pt')
        assert ali_model_fname.is_file(), \
                f'ali model filename {ali_model_fname} does not exist!'
        ali_model.load_state_dict(torch.load(ali_model_fname, map_location='cpu')['state_dict'])
        ali_model.to(device)

        # The alignment model is used for inference only.
        ali_model.eval()
        ali_model.requires_grad_(False)
        logging.info(f'Use ali_model: {ali_model_fname}')
    else:
        ali_model = None
        logging.info('No ali_model')

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=args.lr_factor,
                     warm_step=args.warm_step,
                     weight_decay=args.weight_decay)

    scaler = GradScaler(enabled=args.amp)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    # Resume from the checkpoint written at the end of the previous epoch.
    if start_epoch > 0:
        model_path = os.path.join(exp_dir, f'epoch-{start_epoch - 1}.pt')
        ckpt = load_checkpoint(filename=model_path, model=model, optimizer=optimizer, scaler=scaler)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}")

    for epoch in range(start_epoch, num_epochs):
        train_dl.sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate, global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info(f'epoch {epoch}, learning rate {curr_learning_rate}')
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            ali_model=ali_model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
            world_size=world_size,
            scaler=scaler
        )
        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            # The "best" checkpoint stores the model only (no optimizer /
            # scaler state).
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            scaler=None,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train,
                            local_rank=rank)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch,
                               local_rank=rank)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, f'epoch-{epoch}.pt')
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        scaler=scaler,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train,
                        local_rank=rank)
        epoch_info_filename = os.path.join(exp_dir, f'epoch-{epoch}-info')
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch,
                           local_rank=rank)

    logging.warning('Done')
    torch.distributed.barrier()
    cleanup_dist()
# Example #4
# 0
def run(rank, world_size, args):
    '''
    Train one DDP model replica on the GPU whose index equals `rank`.

    This variant builds a bigram phone LM `P` from the lexicon's phone
    symbols, zeroes its scores and registers a trainable copy of them on
    the model as `P_scores`.  It uses neither an alignment model nor AMP.
    Training runs from `args.start_epoch` up to (but excluding)
    `args.num_epochs`; a checkpoint is saved after every epoch, plus a
    separate "best" checkpoint whenever the validation objective
    decreases.

    Args:
      rank:
        It is a value between 0 and `world_size-1`, which is
        passed automatically by `mp.spawn()` in :func:`main`.
        The node with rank 0 is responsible for saving checkpoint.
        It is also used as the CUDA device index.
      world_size:
        Number of GPUs for DDP training.
      args:
        The return value of get_parser().parse_args()

    The process exits with status -1 if no CUDA device is available.
    '''
    model_type = args.model_type
    start_epoch = args.start_epoch
    num_epochs = args.num_epochs
    accum_grad = args.accum_grad
    den_scale = args.den_scale
    att_rate = args.att_rate

    fix_random_seed(42)
    setup_dist(rank, world_size, args.master_port)

    exp_dir = Path(f'exp-{model_type}-noam-mmi-att-musan-sa')
    setup_logger(f'{exp_dir}/log/log-train-{rank}')

    # Fail fast: this check must come before anything is moved to the
    # CUDA device (e.g. `P.to(device)` below); otherwise a machine
    # without a GPU fails before this message is ever logged.
    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    # Only rank 0 writes TensorBoard summaries.
    if args.tensorboard and rank == 0:
        tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')
    else:
        tb_writer = None

    logging.info("Loading lexicon and symbol tables")
    lang_dir = Path('data/lang_nosp')
    lexicon = Lexicon(lang_dir)

    device_id = rank
    device = torch.device('cuda', device_id)

    graph_compiler = MmiTrainingGraphCompiler(
        lexicon=lexicon,
        device=device,
    )
    phone_ids = lexicon.phone_symbols()
    # Start from an all-zero bigram phone LM; a trainable copy of the
    # scores is attached to the model further down.
    P = create_bigram_phone_lm(phone_ids)
    P.scores = torch.zeros_like(P.scores)
    P = P.to(device)

    librispeech = LibriSpeechAsrDataModule(args)
    train_dl = librispeech.train_dataloaders()
    valid_dl = librispeech.valid_dataloaders()

    logging.info("About to create model")

    # The attention decoder is only instantiated when the attention loss
    # actually contributes (att_rate != 0).
    num_decoder_layers = 6 if att_rate != 0.0 else 0
    num_classes = len(phone_ids) + 1  # +1 for the blank symbol

    # NOTE: any model_type other than "transformer" falls back to
    # Conformer here (no error is raised for unknown types).
    if model_type == "transformer":
        model = Transformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=num_classes,
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)
    else:
        model = Conformer(
            num_features=80,
            nhead=args.nhead,
            d_model=args.attention_dim,
            num_classes=num_classes,
            subsampling_factor=4,
            num_decoder_layers=num_decoder_layers)

    # Register the LM scores as a model parameter so that they are
    # included in `model.parameters()` handed to the optimizer.
    model.P_scores = nn.Parameter(P.scores.clone(), requires_grad=True)

    model.to(device)
    describe(model)

    model = DDP(model, device_ids=[rank])

    optimizer = Noam(model.parameters(),
                     model_size=args.attention_dim,
                     factor=1.0,
                     warm_step=args.warm_step)

    best_objf = np.inf
    best_valid_objf = np.inf
    best_epoch = start_epoch
    best_model_path = os.path.join(exp_dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(exp_dir, 'best-epoch-info')
    global_batch_idx_train = 0  # for logging only

    # Resume from the checkpoint written at the end of the previous epoch.
    if start_epoch > 0:
        model_path = os.path.join(exp_dir, f'epoch-{start_epoch - 1}.pt')
        ckpt = load_checkpoint(filename=model_path,
                               model=model,
                               optimizer=optimizer)
        best_objf = ckpt['objf']
        best_valid_objf = ckpt['valid_objf']
        global_batch_idx_train = ckpt['global_batch_idx_train']
        logging.info(
            f"epoch = {ckpt['epoch']}, objf = {best_objf}, valid_objf = {best_valid_objf}"
        )

    for epoch in range(start_epoch, num_epochs):
        train_dl.sampler.set_epoch(epoch)
        curr_learning_rate = optimizer._rate
        if tb_writer is not None:
            tb_writer.add_scalar('train/learning_rate', curr_learning_rate,
                                 global_batch_idx_train)
            tb_writer.add_scalar('train/epoch', epoch, global_batch_idx_train)

        logging.info(f'epoch {epoch}, learning rate {curr_learning_rate}')
        objf, valid_objf, global_batch_idx_train = train_one_epoch(
            dataloader=train_dl,
            valid_dataloader=valid_dl,
            model=model,
            P=P,
            device=device,
            graph_compiler=graph_compiler,
            optimizer=optimizer,
            accum_grad=accum_grad,
            den_scale=den_scale,
            att_rate=att_rate,
            current_epoch=epoch,
            tb_writer=tb_writer,
            num_epochs=num_epochs,
            global_batch_idx_train=global_batch_idx_train,
            world_size=world_size,
        )
        # the lower, the better
        if valid_objf < best_valid_objf:
            best_valid_objf = valid_objf
            best_objf = objf
            best_epoch = epoch
            # The "best" checkpoint stores the model only (no optimizer
            # state).
            save_checkpoint(filename=best_model_path,
                            optimizer=None,
                            scheduler=None,
                            model=model,
                            epoch=epoch,
                            learning_rate=curr_learning_rate,
                            objf=objf,
                            valid_objf=valid_objf,
                            global_batch_idx_train=global_batch_idx_train,
                            local_rank=rank)
            save_training_info(filename=best_epoch_info_filename,
                               model_path=best_model_path,
                               current_epoch=epoch,
                               learning_rate=curr_learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               valid_objf=valid_objf,
                               best_valid_objf=best_valid_objf,
                               best_epoch=best_epoch,
                               local_rank=rank)

        # we always save the model for every epoch
        model_path = os.path.join(exp_dir, f'epoch-{epoch}.pt')
        save_checkpoint(filename=model_path,
                        optimizer=optimizer,
                        scheduler=None,
                        model=model,
                        epoch=epoch,
                        learning_rate=curr_learning_rate,
                        objf=objf,
                        valid_objf=valid_objf,
                        global_batch_idx_train=global_batch_idx_train,
                        local_rank=rank)
        epoch_info_filename = os.path.join(exp_dir, f'epoch-{epoch}-info')
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=curr_learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           valid_objf=valid_objf,
                           best_valid_objf=best_valid_objf,
                           best_epoch=best_epoch,
                           local_rank=rank)

    logging.warning('Done')
    torch.distributed.barrier()
    # NOTE: The training process is very likely to hang at this point.
    # If you press ctrl + c, your GPU memory will not be freed.
    # To free your GPU memory, you can run:
    #
    #  $ ps aux | grep multi
    #
    # And it will print something like below:
    #
    # kuangfa+  430518 98.9  0.6 57074236 3425732 pts/21 Rl Apr02 639:01 /root/fangjun/py38/bin/python3 -c from multiprocessing.spawn
    #
    # You can kill the process manually by:
    #
    # $ kill -9 430518
    #
    # And you will see that your GPU is now not occupied anymore.
    cleanup_dist()