Example #1
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    args.rank = args.start_rank + gpu
    TARGET_GPUS = [args.gpu]
    gpus = torch.IntTensor(TARGET_GPUS)
    logger = None
    ckpt_path = "models_chunk_twin_context"
    os.system("mkdir -p {}".format(ckpt_path))
    if args.rank == 0:
        logger = init_logging("chunk_model",
                              "{}/train.log".format(ckpt_path))
        args_msg = [
            '  %s: %s' % (name, value) for (name, value) in vars(args).items()
        ]
        logger.info('args:\n' + '\n'.join(args_msg))

        csv_file = open(args.csv_file, 'w', newline='')
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(header)

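    # Load the CTC-CRF denominator LM (FST) onto the target GPU and join the
    # NCCL process group so this worker takes part in distributed training.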
    ctc_crf_base.init_env(args.den_lm_fst_path, gpus)
    #print("rank {} init process grop".format(args.rank),
    #      datetime.datetime.now(), flush=True)
    dist.init_process_group(backend='nccl',
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    torch.cuda.set_device(args.gpu)

    model = CAT_Chunk_Model(args.feature_size, args.hdim, args.output_unit,
                            args.dropout, args.lamb, args.reg_weight,
                            args.ctc_crf)
    if args.rank == 0:
        params_msg = params_num(model)
        logger.info('\n'.join(params_msg))

    lr = args.origin_lr
    optimizer = optim.Adam(model.parameters(), lr=lr)
    epoch = 0
    prev_cv_loss = np.inf
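    # Optionally resume the epoch counter, learning rate, best CV loss and
    # model weights from a checkpoint.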
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint)
        epoch = checkpoint['epoch']
        lr = checkpoint['lr']
        prev_cv_loss = checkpoint['cv_loss']
        model.load_state_dict(checkpoint['model'])

    model.cuda(args.gpu)
    model = nn.parallel.DistributedDataParallel(model, device_ids=TARGET_GPUS)

    reg_model = CAT_RegModel(args.feature_size, args.hdim, args.output_unit,
                             args.dropout, args.lamb)
    loaded_reg_model = torch.load(args.regmodel_checkpoint)
    reg_model.load_state_dict(loaded_reg_model)
    reg_model.cuda(args.gpu)
    reg_model = nn.parallel.DistributedDataParallel(reg_model,
                                                    device_ids=TARGET_GPUS)

    model.train()
    reg_model.eval()
    prev_epoch_time = timeit.default_timer()
    while True:
        # training stage
        epoch += 1
        gc.collect()

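        # Training data is split into per-category pickle shards; from the
        # third epoch on, the shards are visited in random order.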
        if epoch > 2:
            cate_list = list(range(1, args.cate, 1))
            random.shuffle(cate_list)
        else:
            cate_list = range(1, args.cate, 1)

        for cate in cate_list:
            pkl_path = args.tr_data_path + "/" + str(cate) + ".pkl"
            if not os.path.exists(pkl_path):
                continue
            batch_size = int(args.gpu_batch_size * 2 / cate)
            if batch_size < 2:
                batch_size = 2
            #print("rank {} pkl path {} batch size {}".format(
            #    args.rank, pkl_path, batch_size))
            tr_dataset = SpeechDatasetMemPickel(pkl_path)
            if tr_dataset.__len__() < args.world_size:
                continue
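            # Jitter the chunk size per shard so training sees chunk lengths
            # around the configured default.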
            jitter = random.randint(-args.jitter_range, args.jitter_range)
            chunk_size = args.default_chunk_size + jitter
            tr_sampler = DistributedSampler(tr_dataset)
            tr_dataloader = DataLoader(tr_dataset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=PadCollateChunk(chunk_size),
                                       drop_last=True,
                                       sampler=tr_sampler)
            tr_sampler.set_epoch(epoch)  # important for data shuffle
            print(
                "rank {} lengths_cate: {}, chunk_size: {}, training epoch: {}".
                format(args.rank, cate, chunk_size, epoch))
            train_chunk_model(model, reg_model, tr_dataloader, optimizer,
                              epoch, chunk_size, TARGET_GPUS, args, logger)

        # cv stage
        model.eval()
        cv_losses_sum = []
        cv_cls_losses_sum = []
        count = 0
        cate_list = range(1, args.cate, 1)
        for cate in cate_list:
            pkl_path = args.dev_data_path + "/" + str(cate) + ".pkl"
            if not os.path.exists(pkl_path):
                continue
            batch_size = int(args.gpu_batch_size * 2 / cate)
            if batch_size < 2:
                batch_size = 2
            cv_dataset = SpeechDatasetMemPickel(pkl_path)
            cv_dataloader = DataLoader(cv_dataset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=PadCollateChunk(
                                           args.default_chunk_size),
                                       drop_last=True)
            validate_count = validate_chunk_model(model, reg_model,
                                                  cv_dataloader, epoch,
                                                  cv_losses_sum,
                                                  cv_cls_losses_sum, args,
                                                  logger)
            count += validate_count

        cv_loss = np.sum(np.asarray(cv_losses_sum)) / count
        cv_cls_loss = np.sum(np.asarray(cv_cls_losses_sum)) / count

        #print("mean_cv_loss:{} , mean_cv_cls_loss: {}".format(cv_loss, cv_cls_loss))
        if args.rank == 0:
            save_ckpt(
                {
                    'cv_loss': cv_loss,
                    'model': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr': lr,
                    'epoch': epoch
                }, epoch < args.min_epoch or cv_loss <= prev_cv_loss,
                ckpt_path, "model.epoch.{}".format(epoch))

            csv_row = [
                epoch, (timeit.default_timer() - prev_epoch_time) / 60, lr,
                cv_loss
            ]
            prev_epoch_time = timeit.default_timer()
            csv_writer.writerow(csv_row)
            csv_file.flush()
            plot_train_figure(args.csv_file, args.figure_file)

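        # Keep the best CV loss so far; once it stops improving after
        # min_epoch, reset annealing_epoch before adjusting the learning rate.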
        if epoch < args.min_epoch or cv_loss <= prev_cv_loss:
            prev_cv_loss = cv_loss
        else:
            args.annealing_epoch = 0

        lr = adjust_lr_distribute(optimizer, args.origin_lr, lr, cv_loss,
                                  prev_cv_loss, epoch, args.annealing_epoch,
                                  args.gpu_batch_size, args.world_size)
        if (lr < args.stop_lr):
            print("rank {} lr is too slow, finish training".format(args.rank),
                  datetime.datetime.now(),
                  flush=True)
            break

        model.train()

    ctc_crf_base.release_env(gpus)
Example #2
def main_worker(gpu, ngpus_per_node, args):
    csv_file = None
    csv_writer = None

    args.gpu = gpu
    args.rank = args.start_rank + gpu
    TARGET_GPUS = [args.gpu]
    logger = None
    ckpt_path = "models"
    os.system("mkdir -p {}".format(ckpt_path))

    if args.rank == 0:
        logger = init_logging(args.model, "{}/train.log".format(ckpt_path))
        args_msg = [
            '  %s: %s' % (name, value) for (name, value) in vars(args).items()
        ]
        logger.info('args:\n' + '\n'.join(args_msg))

        csv_file = open(args.csv_file, 'w', newline='')
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(header)

    gpus = torch.IntTensor(TARGET_GPUS)
    ctc_crf_base.init_env(args.den_lm_fst_path, gpus)
    dist.init_process_group(backend='nccl',
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)

    torch.cuda.set_device(args.gpu)

    model = CAT_Model(args.arch, args.feature_size, args.hdim,
                      args.output_unit, args.layers, args.dropout, args.lamb,
                      args.ctc_crf)
    if args.rank == 0:
        params_msg = params_num(model)
        logger.info('\n'.join(params_msg))

    lr = args.origin_lr
    optimizer = optim.Adam(model.parameters(), lr=lr)
    epoch = 0
    prev_cv_loss = np.inf
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint)
        epoch = checkpoint['epoch']
        lr = checkpoint['lr']
        prev_cv_loss = checkpoint['cv_loss']
        model.load_state_dict(checkpoint['model'])
    model.cuda(args.gpu)
    model = nn.parallel.DistributedDataParallel(model, device_ids=TARGET_GPUS)

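    # Whole-utterance data pipelines: the DistributedSampler shards the
    # training set across ranks, while every rank validates on the full dev set.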
    tr_dataset = SpeechDatasetPickel(args.tr_data_path)
    tr_sampler = DistributedSampler(tr_dataset)
    tr_dataloader = DataLoader(tr_dataset,
                               batch_size=args.gpu_batch_size,
                               shuffle=False,
                               num_workers=args.data_loader_workers,
                               pin_memory=True,
                               collate_fn=PadCollate(),
                               sampler=tr_sampler)
    cv_dataset = SpeechDatasetPickel(args.dev_data_path)
    cv_dataloader = DataLoader(cv_dataset,
                               batch_size=args.gpu_batch_size,
                               shuffle=False,
                               num_workers=args.data_loader_workers,
                               pin_memory=True,
                               collate_fn=PadCollate())

    prev_epoch_time = timeit.default_timer()

    while True:
        # training stage
        epoch += 1
        tr_sampler.set_epoch(epoch)  # important for data shuffle
        gc.collect()
        train(model, tr_dataloader, optimizer, epoch, args, logger)
        cv_loss = validate(model, cv_dataloader, epoch, args, logger)
        # save model
        if args.rank == 0:
            save_ckpt(
                {
                    'cv_loss': cv_loss,
                    'model': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr': lr,
                    'epoch': epoch
                }, cv_loss <= prev_cv_loss, ckpt_path,
                "model.epoch.{}".format(epoch))

            csv_row = [
                epoch, (timeit.default_timer() - prev_epoch_time) / 60, lr,
                cv_loss
            ]
            prev_epoch_time = timeit.default_timer()
            csv_writer.writerow(csv_row)
            csv_file.flush()
            plot_train_figure(args.csv_file, args.figure_file)

        if epoch < args.min_epoch or cv_loss <= prev_cv_loss:
            prev_cv_loss = cv_loss
        else:
            args.annealing_epoch = 0

        lr = adjust_lr_distribute(optimizer, args.origin_lr, lr, cv_loss,
                                  prev_cv_loss, epoch, args.annealing_epoch,
                                  args.gpu_batch_size, args.world_size)
        if (lr < args.stop_lr):
            print("rank {} lr is too slow, finish training".format(args.rank),
                  datetime.datetime.now(),
                  flush=True)
            break

    ctc_crf_base.release_env(gpus)
Example #3
def main():
    args = parse_args()
    ngpus_per_node = torch.cuda.device_count()
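    # mp.spawn launches ngpus_per_node processes and calls
    # main_worker(i, ngpus_per_node, args) with the local GPU index i.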
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    plot_train_figure(args.csv_file, args.figure_file)
Example #4
def train():
    args = parse_args()

    ckpt_path = "models_chunk_twin_context"
    os.system("mkdir -p {}".format(ckpt_path))
    logger = init_logging("chunk_model", "{}/train.log".format(ckpt_path))

    args_msg = [
        '  %s: %s' % (name, value) for (name, value) in vars(args).items()
    ]
    logger.info('args:\n' + '\n'.join(args_msg))

    csv_file = open(args.csv_file, 'w', newline='')
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(header)

    batch_size = args.batch_size
    device = torch.device("cuda:0")

    reg_weight = args.reg_weight

    # TARGET_GPUS and gpus were undefined here; assume all visible GPUs are
    # used, which matches nn.DataParallel's default below.
    TARGET_GPUS = list(range(torch.cuda.device_count()))
    gpus = torch.IntTensor(TARGET_GPUS)
    ctc_crf_base.init_env(args.den_lm_fst_path, gpus)

    model = CAT_Chunk_Model(args.feature_size, args.hdim, args.output_unit,
                            args.dropout, args.lamb, reg_weight)

    lr = args.origin_lr
    optimizer = optim.Adam(model.parameters(), lr=lr)
    epoch = 0
    prev_cv_loss = np.inf
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint)
        epoch = checkpoint['epoch']
        lr = checkpoint['lr']
        prev_cv_loss = checkpoint['cv_loss']
        model.load_state_dict(checkpoint['model'])

    model.cuda()
    model = nn.DataParallel(model)
    model.to(device)

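    # Companion regularization model: restored from its own checkpoint, not
    # updated by the optimizer, and run in eval mode during training.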
    reg_model = CAT_RegModel(args.feature_size, args.hdim, args.output_unit,
                             args.dropout, args.lamb)

    loaded_reg_model = torch.load(args.regmodel_checkpoint)
    reg_model.load_state_dict(loaded_reg_model)

    reg_model.cuda()
    reg_model = nn.DataParallel(reg_model)
    reg_model.to(device)

    prev_epoch_time = timeit.default_timer()

    model.train()
    reg_model.eval()
    while True:
        # training stage
        epoch += 1
        gc.collect()

        if epoch > 2:
            cate_list = list(range(1, args.cate, 1))
            random.shuffle(cate_list)
        else:
            cate_list = range(1, args.cate, 1)

        for cate in cate_list:
            pkl_path = args.tr_data_path + "/" + str(cate) + ".pkl"
            if not os.path.exists(pkl_path):
                continue
            tr_dataset = SpeechDatasetMemPickel(pkl_path)

            jitter = random.randint(-args.jitter_range, args.jitter_range)
            chunk_size = args.default_chunk_size + jitter

            tr_dataloader = DataLoader(tr_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=0,
                                       collate_fn=PadCollateChunk(chunk_size))

            train_chunk_model(model, reg_model, tr_dataloader, optimizer,
                              epoch, chunk_size, TARGET_GPUS, args, logger)

        # cv stage
        model.eval()
        cv_losses_sum = []
        cv_cls_losses_sum = []
        count = 0
        cate_list = range(1, args.cate, 1)
        for cate in cate_list:
            pkl_path = args.dev_data_path + "/" + str(cate) + ".pkl"
            if not os.path.exists(pkl_path):
                continue
            cv_dataset = SpeechDatasetMemPickel(pkl_path)
            cv_dataloader = DataLoader(cv_dataset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=PadCollateChunk(
                                           args.default_chunk_size))
            validate_count = validate_chunk_model(model, reg_model,
                                                  cv_dataloader, epoch,
                                                  cv_losses_sum,
                                                  cv_cls_losses_sum, args,
                                                  logger)
            count += validate_count
        cv_loss = np.sum(np.asarray(cv_losses_sum)) / count
        cv_cls_loss = np.sum(np.asarray(cv_cls_losses_sum)) / count
        # save model
        save_ckpt(
            {
                'cv_loss': cv_loss,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr': lr,
                'epoch': epoch
            }, epoch < args.min_epoch or cv_loss <= prev_cv_loss, ckpt_path,
            "model.epoch.{}".format(epoch))

        csv_row = [
            epoch, (timeit.default_timer() - prev_epoch_time) / 60, lr, cv_loss
        ]
        prev_epoch_time = timeit.default_timer()
        csv_writer.writerow(csv_row)
        csv_file.flush()
        plot_train_figure(args.csv_file, args.figure_file)

        if epoch < args.min_epoch or cv_loss <= prev_cv_loss:
            prev_cv_loss = cv_loss

        lr = adjust_lr(optimizer, args.origin_lr, lr, cv_loss, prev_cv_loss,
                       epoch, args.min_epoch)
        if (lr < args.stop_lr):
            print("rank {} lr is too slow, finish training".format(args.rank),
                  datetime.datetime.now(),
                  flush=True)
            break
        model.train()

    ctc_crf_base.release_env(gpus)