Example #1
import numpy as np
import timeit
import os
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.utils.data import Dataset, DataLoader
from model import BLSTM
from dataset import SpeechDataset, SpeechDatasetMem, PadCollate
sys.path.append('../../src/ctc_crf')
import ctc_crf_base
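# NOTE: CTC_CRF_LOSS (used by Model below) is assumed to be provided elsewhere
# in this project, e.g. by the ctc_crf wrapper that accompanies ctc_crf_base;
# it is not imported explicitly in this snippet.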

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
TARGET_GPUS = [0, 1, 2, 3]
gpus = torch.IntTensor(TARGET_GPUS)
ctc_crf_base.init_env('data/den_meta/den_lm.fst', gpus)

os.system("mkdir -p models")


class Model(nn.Module):
    def __init__(self, idim, hdim, K, n_layers, dropout, lamb):
        super(Model, self).__init__()
        self.net = BLSTM(idim, hdim, n_layers, dropout=dropout)
        self.linear = nn.Linear(hdim * 2, K)
        self.loss_fn = CTC_CRF_LOSS(lamb=lamb)

    def forward(self, logits, labels_padded, input_lengths, label_lengths):
        # rearrange by input_lengths
        input_lengths, indices = torch.sort(input_lengths, descending=True)
        assert indices.dim() == 1, "input_lengths should have only 1 dim"
Example #2
def main_worker(gpu, ngpus_per_node, args):
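    # Per-process worker launched via torch.multiprocessing for distributed
    # training. Assumes module-level imports elsewhere in the script: torch,
    # torch.distributed as dist, DataLoader/DistributedSampler, OrderedDict,
    # and the project's utils, DataSet and ctc_crf_base modules.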
    args.gpu = gpu

    args.rank = args.rank * ngpus_per_node + gpu
    print(f"Use GPU: local[{args.gpu}] | global[{args.rank}]")

    dist.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)

    args.batch_size = args.batch_size // ngpus_per_node

    print("> Data prepare")
    if args.h5py:
        data_format = "hdf5"
        utils.highlight_msg("H5py reading may cause errors with multi-GPU training.")
        Dataset = DataSet.SpeechDataset
    else:
        data_format = "pickle"
        Dataset = DataSet.SpeechDatasetPickle

    tr_set = Dataset(
        f"{args.data}/{data_format}/tr.{data_format}")
    test_set = Dataset(
        f"{args.data}/{data_format}/cv.{data_format}")
    print("Data prepared.")

    train_sampler = DistributedSampler(tr_set)
    test_sampler = DistributedSampler(test_set)
    test_sampler.set_epoch(1)

    trainloader = DataLoader(
        tr_set, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True,
        sampler=train_sampler, collate_fn=DataSet.sortedPadCollate())

    testloader = DataLoader(
        test_set, batch_size=args.batch_size, shuffle=(test_sampler is None),
        num_workers=args.workers, pin_memory=True,
        sampler=test_sampler, collate_fn=DataSet.sortedPadCollate())

    logger = OrderedDict({
        'log_train': ['epoch,loss,loss_real,net_lr,time'],
        'log_eval': ['loss_real,time']
    })
    manager = utils.Manager(logger, build_model, args)

    # get GPU info
    gpu_info = utils.gather_all_gpu_info(args.gpu)

    if args.rank == 0:
        print("> Model built.")
        print("Model size:{:.2f}M".format(
            utils.count_parameters(manager.model)/1e6))

        utils.gen_readme(args.dir+'/readme.md',
                         model=manager.model, gpu_info=gpu_info)

    # init ctc-crf, args.iscrf is set in build_model
    if args.iscrf:
        gpus = torch.IntTensor([args.gpu])
        ctc_crf_base.init_env(f"{args.data}/den_meta/den_lm.fst", gpus)

    # training
    manager.run(train_sampler, trainloader, testloader, args)

    if args.iscrf:
        ctc_crf_base.release_env(gpus)
Example #3
def main_worker(gpu, ngpus_per_node, args):
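    # Distributed worker for the chunk-based twin-context model. Assumes the
    # surrounding script imports torch, torch.distributed as dist, optim, csv,
    # gc, random, datetime, DataLoader/DistributedSampler, and defines `header`
    # (the CSV header row) at module level.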
    args.gpu = gpu
    args.rank = args.start_rank + gpu
    TARGET_GPUS = [args.gpu]
    gpus = torch.IntTensor(TARGET_GPUS)
    logger = None
    ckpt_path = "models_chunk_twin_context"
    os.system("mkdir -p {}".format(ckpt_path))
    if args.rank == 0:
        logger = init_logging(
            "chunk_model", "{}/train.log".format("models_chunk_twin_context"))
        args_msg = [
            '  %s: %s' % (name, value) for (name, value) in vars(args).items()
        ]
        logger.info('args:\n' + '\n'.join(args_msg))

        csv_file = open(args.csv_file, 'w', newline='')
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(header)

    ctc_crf_base.init_env(args.den_lm_fst_path, gpus)
    #print("rank {} init process grop".format(args.rank),
    #      datetime.datetime.now(), flush=True)
    dist.init_process_group(backend='nccl',
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    torch.cuda.set_device(args.gpu)

    model = CAT_Chunk_Model(args.feature_size, args.hdim, args.output_unit,
                            args.dropout, args.lamb, args.reg_weight,
                            args.ctc_crf)
    if args.rank == 0:
        params_msg = params_num(model)
        logger.info('\n'.join(params_msg))

    lr = args.origin_lr
    optimizer = optim.Adam(model.parameters(), lr=lr)
    epoch = 0
    prev_cv_loss = np.inf
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint)
        epoch = checkpoint['epoch']
        lr = checkpoint['lr']
        prev_cv_loss = checkpoint['cv_loss']
        model.load_state_dict(checkpoint['model'])

    model.cuda(args.gpu)
    model = nn.parallel.DistributedDataParallel(model, device_ids=TARGET_GPUS)

    reg_model = CAT_RegModel(args.feature_size, args.hdim, args.output_unit,
                             args.dropout, args.lamb)
    loaded_reg_model = torch.load(args.regmodel_checkpoint)
    reg_model.load_state_dict(loaded_reg_model)
    reg_model.cuda(args.gpu)
    reg_model = nn.parallel.DistributedDataParallel(reg_model,
                                                    device_ids=TARGET_GPUS)

    model.train()
    reg_model.eval()
    prev_epoch_time = timeit.default_timer()
    while True:
        # training stage
        epoch += 1
        gc.collect()

        if epoch > 2:
            cate_list = list(range(1, args.cate, 1))
            random.shuffle(cate_list)
        else:
            cate_list = range(1, args.cate, 1)
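        # `cate` appears to index utterance-length buckets; the per-GPU batch
        # size below is scaled inversely with it, presumably to keep memory
        # use roughly constant for longer utterances.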

        for cate in cate_list:
            pkl_path = args.tr_data_path + "/" + str(cate) + ".pkl"
            if not os.path.exists(pkl_path):
                continue
            batch_size = int(args.gpu_batch_size * 2 / cate)
            if batch_size < 2:
                batch_size = 2
            #print("rank {} pkl path {} batch size {}".format(
            #    args.rank, pkl_path, batch_size))
            tr_dataset = SpeechDatasetMemPickel(pkl_path)
            if tr_dataset.__len__() < args.world_size:
                continue
            jitter = random.randint(-args.jitter_range, args.jitter_range)
            chunk_size = args.default_chunk_size + jitter
            tr_sampler = DistributedSampler(tr_dataset)
            tr_dataloader = DataLoader(tr_dataset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=PadCollateChunk(chunk_size),
                                       drop_last=True,
                                       sampler=tr_sampler)
            tr_sampler.set_epoch(epoch)  # important for data shuffle
            print(
                "rank {} lengths_cate: {}, chunk_size: {}, training epoch: {}".
                format(args.rank, cate, chunk_size, epoch))
            train_chunk_model(model, reg_model, tr_dataloader, optimizer,
                              epoch, chunk_size, TARGET_GPUS, args, logger)

        # cv stage
        model.eval()
        cv_losses_sum = []
        cv_cls_losses_sum = []
        count = 0
        cate_list = range(1, args.cate, 1)
        for cate in cate_list:
            pkl_path = args.dev_data_path + "/" + str(cate) + ".pkl"
            if not os.path.exists(pkl_path):
                continue
            batch_size = int(args.gpu_batch_size * 2 / cate)
            if batch_size < 2:
                batch_size = 2
            cv_dataset = SpeechDatasetMemPickel(pkl_path)
            cv_dataloader = DataLoader(cv_dataset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=PadCollateChunk(
                                           args.default_chunk_size),
                                       drop_last=True)
            validate_count = validate_chunk_model(model, reg_model,
                                                  cv_dataloader, epoch,
                                                  cv_losses_sum,
                                                  cv_cls_losses_sum, args,
                                                  logger)
            count += validate_count

        cv_loss = np.sum(np.asarray(cv_losses_sum)) / count
        cv_cls_loss = np.sum(np.asarray(cv_cls_losses_sum)) / count

        #print("mean_cv_loss:{} , mean_cv_cls_loss: {}".format(cv_loss, cv_cls_loss))
        if args.rank == 0:
            save_ckpt(
                {
                    'cv_loss': cv_loss,
                    'model': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr': lr,
                    'epoch': epoch
                }, epoch < args.min_epoch or cv_loss <= prev_cv_loss,
                ckpt_path, "model.epoch.{}".format(epoch))

            csv_row = [
                epoch, (timeit.default_timer() - prev_epoch_time) / 60, lr,
                cv_loss
            ]
            prev_epoch_time = timeit.default_timer()
            csv_writer.writerow(csv_row)
            csv_file.flush()
            plot_train_figure(args.csv_file, args.figure_file)

        if epoch < args.min_epoch or cv_loss <= prev_cv_loss:
            prev_cv_loss = cv_loss
        else:
            args.annealing_epoch = 0

        lr = adjust_lr_distribute(optimizer, args.origin_lr, lr, cv_loss,
                                  prev_cv_loss, epoch, args.annealing_epoch,
                                  args.gpu_batch_size, args.world_size)
        if lr < args.stop_lr:
            print("rank {} lr has fallen below stop_lr, finishing training".format(args.rank),
                  datetime.datetime.now(),
                  flush=True)
            break

        model.train()

    ctc_crf_base.release_env(gpus)
Example #4
def main_worker(gpu, ngpus_per_node, args):
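    # Distributed worker for the whole-utterance model. Assumes the same
    # module-level imports as Example #3 (torch, dist, optim, csv, gc,
    # datetime, DataLoader/DistributedSampler) plus the `header` row for the
    # CSV log.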
    csv_file = None
    csv_writer = None

    args.gpu = gpu
    args.rank = args.start_rank + gpu
    TARGET_GPUS = [args.gpu]
    logger = None
    ckpt_path = "models"
    os.system("mkdir -p {}".format(ckpt_path))

    if args.rank == 0:
        logger = init_logging(args.model, "{}/train.log".format(ckpt_path))
        args_msg = [
            '  %s: %s' % (name, value) for (name, value) in vars(args).items()
        ]
        logger.info('args:\n' + '\n'.join(args_msg))

        csv_file = open(args.csv_file, 'w', newline='')
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(header)

    gpus = torch.IntTensor(TARGET_GPUS)
    ctc_crf_base.init_env(args.den_lm_fst_path, gpus)
    dist.init_process_group(backend='nccl',
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)

    torch.cuda.set_device(args.gpu)

    model = CAT_Model(args.arch, args.feature_size, args.hdim,
                      args.output_unit, args.layers, args.dropout, args.lamb,
                      args.ctc_crf)
    if args.rank == 0:
        params_msg = params_num(model)
        logger.info('\n'.join(params_msg))

    lr = args.origin_lr
    optimizer = optim.Adam(model.parameters(), lr=lr)
    epoch = 0
    prev_cv_loss = np.inf
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint)
        epoch = checkpoint['epoch']
        lr = checkpoint['lr']
        prev_cv_loss = checkpoint['cv_loss']
        model.load_state_dict(checkpoint['model'])
    model.cuda(args.gpu)
    model = nn.parallel.DistributedDataParallel(model, device_ids=TARGET_GPUS)

    tr_dataset = SpeechDatasetPickel(args.tr_data_path)
    tr_sampler = DistributedSampler(tr_dataset)
    tr_dataloader = DataLoader(tr_dataset,
                               batch_size=args.gpu_batch_size,
                               shuffle=False,
                               num_workers=args.data_loader_workers,
                               pin_memory=True,
                               collate_fn=PadCollate(),
                               sampler=tr_sampler)
    cv_dataset = SpeechDatasetPickel(args.dev_data_path)
    cv_dataloader = DataLoader(cv_dataset,
                               batch_size=args.gpu_batch_size,
                               shuffle=False,
                               num_workers=args.data_loader_workers,
                               pin_memory=True,
                               collate_fn=PadCollate())

    prev_epoch_time = timeit.default_timer()

    while True:
        # training stage
        epoch += 1
        tr_sampler.set_epoch(epoch)  # important for data shuffle
        gc.collect()
        train(model, tr_dataloader, optimizer, epoch, args, logger)
        cv_loss = validate(model, cv_dataloader, epoch, args, logger)
        # save model
        if args.rank == 0:
            save_ckpt(
                {
                    'cv_loss': cv_loss,
                    'model': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr': lr,
                    'epoch': epoch
                }, cv_loss <= prev_cv_loss, ckpt_path,
                "model.epoch.{}".format(epoch))

            csv_row = [
                epoch, (timeit.default_timer() - prev_epoch_time) / 60, lr,
                cv_loss
            ]
            prev_epoch_time = timeit.default_timer()
            csv_writer.writerow(csv_row)
            csv_file.flush()
            plot_train_figure(args.csv_file, args.figure_file)

        if epoch < args.min_epoch or cv_loss <= prev_cv_loss:
            prev_cv_loss = cv_loss
        else:
            args.annealing_epoch = 0

        lr = adjust_lr_distribute(optimizer, args.origin_lr, lr, cv_loss,
                                  prev_cv_loss, epoch, args.annealing_epoch,
                                  args.gpu_batch_size, args.world_size)
        if lr < args.stop_lr:
            print("rank {} lr has fallen below stop_lr, finishing training".format(args.rank),
                  datetime.datetime.now(),
                  flush=True)
            break

    ctc_crf_base.release_env(gpus)
Example #5
        label_list = [
            labels_padded[i, :x] for i, x in enumerate(label_lengths)
        ]
        labels = torch.cat(label_list)
        # netout, _ = self.net(logits, input_lengths)
        # netout = self.linear(netout)
        netout = F.log_softmax(logits, dim=2)
        loss = self.loss_fn(netout, labels, input_lengths, label_lengths)
        return loss


if __name__ == '__main__':
    device = torch.device("cuda:0")
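    # LM_PATH and gpus are assumed to be defined earlier in the script (see
    # the module-level setup in Example #1).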

    ctc_crf_base.init_env(LM_PATH, gpus)

    # Softmax logits for the following inputs:
    logits = np.array([[0.1, 0.6, 0.6, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
                      dtype=np.float32)

    # This is a single instance of shape (t timesteps, p classes), so add a
    # batch dimension at the front; the resulting shape is (1, t, p).
    logits = np.expand_dims(logits, 0)
    labels = np.asarray([[1, 2]], dtype=np.int32)
    input_lengths = np.asarray([2], dtype=np.int32)
    label_lengths = np.asarray([2], dtype=np.int32)

    # print(logits.shape)
Example #6
def train():
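    # Single-node training entry point (no torch.distributed). Assumes
    # module-level definitions of gpus, TARGET_GPUS and header, plus the usual
    # imports (torch, optim, csv, gc, random, datetime, DataLoader).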
    args = parse_args()

    # initialize logging before it is first used
    ckpt_path = "models_chunk_twin_context"
    os.system("mkdir -p {}".format(ckpt_path))
    logger = init_logging("chunk_model", "{}/train.log".format(ckpt_path))

    args_msg = [
        '  %s: %s' % (name, value) for (name, value) in vars(args).items()
    ]
    logger.info('args:\n' + '\n'.join(args_msg))

    csv_file = open(args.csv_file, 'w', newline='')
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(header)

    batch_size = args.batch_size
    device = torch.device("cuda:0")

    reg_weight = args.reg_weight

    ctc_crf_base.init_env(args.den_lm_fst_path, gpus)

    model = CAT_Chunk_Model(args.feature_size, args.hdim, args.output_unit,
                            args.dropout, args.lamb, reg_weight)

    lr = args.origin_lr
    optimizer = optim.Adam(model.parameters(), lr=lr)
    epoch = 0
    prev_cv_loss = np.inf
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint)
        epoch = checkpoint['epoch']
        lr = checkpoint['lr']
        prev_cv_loss = checkpoint['cv_loss']
        model.load_state_dict(checkpoint['model'])

    model.cuda()
    model = nn.DataParallel(model)
    model.to(device)
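    # Note: this example wraps the model in nn.DataParallel (single process,
    # multiple GPUs) rather than DistributedDataParallel as in Examples #3-#4.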

    reg_model = CAT_RegModel(args.feature_size, args.hdim, args.output_unit,
                             args.dropout, args.lamb)

    loaded_reg_model = torch.load(args.regmodel_checkpoint)
    reg_model.load_state_dict(loaded_reg_model)

    reg_model.cuda()
    reg_model = nn.DataParallel(reg_model)
    reg_model.to(device)

    prev_epoch_time = timeit.default_timer()

    model.train()
    reg_model.eval()
    while True:
        # training stage
        epoch += 1
        gc.collect()

        if epoch > 2:
            cate_list = list(range(1, args.cate, 1))
            random.shuffle(cate_list)
        else:
            cate_list = range(1, args.cate, 1)

        for cate in cate_list:
            pkl_path = args.tr_data_path + "/" + str(cate) + ".pkl"
            if not os.path.exists(pkl_path):
                continue
            tr_dataset = SpeechDatasetMemPickel(pkl_path)

            jitter = random.randint(-args.jitter_range, args.jitter_range)
            chunk_size = args.default_chunk_size + jitter

            tr_dataloader = DataLoader(tr_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=0,
                                       collate_fn=PadCollateChunk(chunk_size))

            train_chunk_model(model, reg_model, tr_dataloader, optimizer,
                              epoch, chunk_size, TARGET_GPUS, args, logger)

        # cv stage
        model.eval()
        cv_losses_sum = []
        cv_cls_losses_sum = []
        count = 0
        cate_list = range(1, args.cate, 1)
        for cate in cate_list:
            pkl_path = args.dev_data_path + "/" + str(cate) + ".pkl"
            if not os.path.exists(pkl_path):
                continue
            cv_dataset = SpeechDatasetMemPickel(pkl_path)
            cv_dataloader = DataLoader(cv_dataset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=PadCollateChunk(
                                           args.default_chunk_size))
            validate_count = validate_chunk_model(model, reg_model,
                                                  cv_dataloader, epoch,
                                                  cv_losses_sum,
                                                  cv_cls_losses_sum, args,
                                                  logger)
            count += validate_count
        cv_loss = np.sum(np.asarray(cv_losses_sum)) / count
        cv_cls_loss = np.sum(np.asarray(cv_cls_losses_sum)) / count
        # save model
        save_ckpt(
            {
                'cv_loss': cv_loss,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr': lr,
                'epoch': epoch
            }, epoch < args.min_epoch or cv_loss <= prev_cv_loss, ckpt_path,
            "model.epoch.{}".format(epoch))

        csv_row = [
            epoch, (timeit.default_timer() - prev_epoch_time) / 60, lr, cv_loss
        ]
        prev_epoch_time = timeit.default_timer()
        csv_writer.writerow(csv_row)
        csv_file.flush()
        plot_train_figure(args.csv_file, args.figure_file)

        if epoch < args.min_epoch or cv_loss <= prev_cv_loss:
            prev_cv_loss = cv_loss

        lr = adjust_lr(optimizer, args.origin_lr, lr, cv_loss, prev_cv_loss,
                       epoch, args.min_epoch)
        if lr < args.stop_lr:
            print("rank {} lr has fallen below stop_lr, finishing training".format(args.rank),
                  datetime.datetime.now(),
                  flush=True)
            break
        model.train()

    ctc_crf_base.release_env(gpus)