Code Example #1
def main_worker(gpu, ngpus_per_node, args):
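    # One worker process per GPU: `gpu` is the local device index on this node;
    # the global rank is derived from it before joining the process group.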
    args.gpu = gpu

    args.rank = args.rank * ngpus_per_node + gpu
    print(f"Use GPU: local[{args.gpu}] | global[{args.rank}]")

    dist.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url,
        world_size=args.world_size, rank=args.rank)

    args.batch_size = args.batch_size // ngpus_per_node

    print("> Data prepare")
    if args.h5py:
        data_format = "hdf5"
        utils.highlight_msg("h5py reading may cause errors with multiple GPUs.")
        Dataset = DataSet.SpeechDataset
    else:
        data_format = "pickle"
        Dataset = DataSet.SpeechDatasetPickle

    tr_set = Dataset(
        f"{args.data}/{data_format}/tr.{data_format}")
    test_set = Dataset(
        f"{args.data}/{data_format}/cv.{data_format}")
    print("Data prepared.")

    train_sampler = DistributedSampler(tr_set)
    test_sampler = DistributedSampler(test_set)
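    # DistributedSampler shards each dataset across processes; fixing the test
    # sampler's epoch keeps the validation shard order identical every epoch.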
    test_sampler.set_epoch(1)

    trainloader = DataLoader(
        tr_set, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True,
        sampler=train_sampler, collate_fn=DataSet.sortedPadCollate())

    testloader = DataLoader(
        test_set, batch_size=args.batch_size, shuffle=(test_sampler is None),
        num_workers=args.workers, pin_memory=True,
        sampler=test_sampler, collate_fn=DataSet.sortedPadCollate())

    logger = OrderedDict({
        'log_train': ['epoch,loss,loss_real,net_lr,time'],
        'log_eval': ['loss_real,time']
    })
    manager = utils.Manager(logger, build_model, args)

    # get GPU info
    gpu_info = utils.gather_all_gpu_info(args.gpu)

    if args.rank == 0:
        print("> Model built.")
        print("Model size:{:.2f}M".format(
            utils.count_parameters(manager.model)/1e6))

        utils.gen_readme(args.dir+'/readme.md',
                         model=manager.model, gpu_info=gpu_info)

    # init ctc-crf, args.iscrf is set in build_model
    if args.iscrf:
        gpus = torch.IntTensor([args.gpu])
        ctc_crf_base.init_env(f"{args.data}/den_meta/den_lm.fst", gpus)

    # training
    manager.run(train_sampler, trainloader, testloader, args)

    if args.iscrf:
        ctc_crf_base.release_env(gpus)
Code Example #2
def train():
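    # TARGET_GPUS and gpus are defined at module level in the original script;
    # ctc_crf_base (the CTC-CRF extension) is imported there as well.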
    parser = argparse.ArgumentParser(description="recognition argument")
    parser.add_argument("--min_epoch", type=int, default=15)
    parser.add_argument("--output_unit", type=int)
    parser.add_argument("--lamb", type=float, default=0.1)
    parser.add_argument("--hdim", type=int, default=512)
    parser.add_argument("--layers", type=int, default=6)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--feature_size", type=int, default=120)
    parser.add_argument("--data_path")
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--stop_lr", type=float, default=0.00001)
    args = parser.parse_args()

    batch_size = args.batch_size

    model = Model(args.feature_size, args.hdim, args.output_unit, args.layers,
                  args.dropout, args.lamb)
    device = torch.device("cuda:0")
    model.cuda()
    model = nn.DataParallel(model)
    model.to(device)

    lr = args.lr
    optimizer = optim.Adam(model.parameters(), lr=lr)

    tr_dataset = SpeechDatasetMem(args.data_path + "/data/hdf5/tr.hdf5")
    tr_dataloader = DataLoader(tr_dataset,
                               batch_size=batch_size,
                               shuffle=True,
                               num_workers=16,
                               collate_fn=PadCollate())

    cv_dataset = SpeechDatasetMem(args.data_path + "/data/hdf5/cv.hdf5")
    cv_dataloader = DataLoader(cv_dataset,
                               batch_size=batch_size,
                               shuffle=False,
                               num_workers=16,
                               collate_fn=PadCollate())

    prev_t = timeit.default_timer()
    epoch = 0
    prev_cv_loss = np.inf
    model.train()
    while True:
        # training stage
        torch.save(model.module.state_dict(),
                   args.data_path + "/models/best_model")
        epoch += 1

        for i, minibatch in enumerate(tr_dataloader):
            print("training epoch: {}, step: {}".format(epoch, i))
            logits, input_lengths, labels_padded, label_lengths, path_weights = minibatch

            sys.stdout.flush()
            model.zero_grad()
            optimizer.zero_grad()

            loss = model(logits, labels_padded, input_lengths, label_lengths)
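            # Under DataParallel, `loss` holds one partial loss per GPU; the logged
            # "real" loss subtracts the mean of the pre-computed path weights, and
            # backward() below is given a vector of ones to sum the per-GPU losses.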
            partial_loss = torch.mean(loss.cpu())
            weight = torch.mean(path_weights)
            real_loss = partial_loss - weight

            loss.backward(loss.new_ones(len(TARGET_GPUS)))

            optimizer.step()
            t2 = timeit.default_timer()
            print("time: {}, tr_real_loss: {}, lr: {}".format(
                t2 - prev_t, real_loss.item(),
                optimizer.param_groups[0]['lr']))
            prev_t = t2

        # save model
        torch.save(model.module.state_dict(),
                   args.data_path + "/models/model.epoch.{}".format(epoch))

        # cv stage
        model.eval()
        cv_losses = []
        cv_losses_sum = []
        count = 0

        for i, minibatch in enumerate(cv_dataloader):
            print("cv epoch: {}, step: {}".format(epoch, i))
            logits, input_lengths, labels_padded, label_lengths, path_weights = minibatch

            loss = model(logits, labels_padded, input_lengths, label_lengths)
            loss_size = loss.size(0)
            count = count + loss_size
            partial_loss = torch.mean(loss.cpu())
            weight = torch.mean(path_weights)
            real_loss = partial_loss - weight
            real_loss_sum = real_loss * loss_size
            cv_losses_sum.append(real_loss_sum.item())
            print("cv_real_loss: {}".format(real_loss.item()))

        cv_loss = np.sum(np.asarray(cv_losses_sum)) / count
        print("mean_cv_loss: {}".format(cv_loss))
        if epoch < args.min_epoch or cv_loss <= prev_cv_loss:
            torch.save(model.module.state_dict(),
                       args.data_path + "/models/best_model")
            prev_cv_loss = cv_loss
        else:
            print(
                "cv loss does not improve, decay the learning rate from {} to {}"
                .format(lr, lr / 10.0))
            adjust_lr(optimizer, lr / 10.0)
            lr = lr / 10.0
            if (lr < args.stop_lr):
                print("learning rate is too small, finish training")
                break

        model.train()

    ctc_crf_base.release_env(gpus)
Code Example #3
def train():
    parser = argparse.ArgumentParser(description="recognition argument")
    parser.add_argument("dir", default="models")
    parser.add_argument("--arch",
                        choices=[
                            'BLSTM', 'LSTM', 'VGGBLSTM', 'VGGLSTM',
                            'LSTMrowCONV', 'TDNN_LSTM', 'BLSTMN'
                        ],
                        default='BLSTM')
    parser.add_argument("--min_epoch", type=int, default=15)
    parser.add_argument("--output_unit", type=int)
    parser.add_argument("--lamb", type=float, default=0.1)
    parser.add_argument("--hdim", type=int, default=512)
    parser.add_argument("--layers", type=int, default=6)
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--feature_size", type=int, default=120)
    parser.add_argument("--data_path")
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--stop_lr", type=float, default=0.00001)
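    # Note: argparse's type=bool treats any non-empty string (including "False") as True.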
    parser.add_argument("--resume", type=bool, default=False)
    parser.add_argument("--pretrained_model_path")
    args = parser.parse_args()

    os.makedirs(args.dir + '/board', exist_ok=True)
    writer = SummaryWriter(args.dir + '/board')
    # save configuration
    with open(args.dir + '/config.json', "w") as fout:
        config = {
            "arch": args.arch,
            "output_unit": args.output_unit,
            "hdim": args.hdim,
            "layers": args.layers,
            "dropout": args.dropout,
            "feature_size": args.feature_size,
        }
        json.dump(config, fout)

    model = Model(args.arch, args.feature_size, args.hdim, args.output_unit,
                  args.layers, args.dropout, args.lamb)

    if args.resume:
        print("resume from {}".format(args.pretrained_model_path))
        pretrained_dict = torch.load(args.pretrained_model_path)
        model.load_state_dict(pretrained_dict)

    device = torch.device("cuda:0")
    model.cuda()
    model = nn.DataParallel(model)
    model.to(device)

    lr = args.lr
    optimizer = optim.Adam(model.parameters(), lr=lr)

    tr_dataset = SpeechDatasetMem(args.data_path + "/tr.hdf5")
    tr_dataloader = DataLoader(tr_dataset,
                               batch_size=args.batch_size,
                               shuffle=True,
                               pin_memory=True,
                               num_workers=0,
                               collate_fn=PadCollate())

    cv_dataset = SpeechDatasetMem(args.data_path + "/cv.hdf5")
    cv_dataloader = DataLoader(cv_dataset,
                               batch_size=args.batch_size,
                               shuffle=False,
                               pin_memory=True,
                               num_workers=0,
                               collate_fn=PadCollate())

    prev_t = 0
    epoch = 0
    prev_cv_loss = np.inf
    model.train()
    while True:
        # training stage
        torch.save(model.module.state_dict(), args.dir + "/best_model")
        epoch += 1

        for i, minibatch in enumerate(tr_dataloader):
            print("training epoch: {}, step: {}".format(epoch, i))
            logits, input_lengths, labels_padded, label_lengths, path_weights = minibatch

            sys.stdout.flush()
            model.zero_grad()
            optimizer.zero_grad()

            loss = model(logits, labels_padded, input_lengths, label_lengths)
            partial_loss = torch.mean(loss.cpu())
            weight = torch.mean(path_weights)
            real_loss = partial_loss - weight

            loss.backward(loss.new_ones(len(TARGET_GPUS)))

            optimizer.step()
            t2 = timeit.default_timer()
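            # Log the per-step training loss to TensorBoard at the global step index.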
            writer.add_scalar('training loss', real_loss.item(),
                              (epoch - 1) * len(tr_dataloader) + i)
            prev_t = t2

        # save model
        torch.save(model.module.state_dict(),
                   args.dir + "/model.epoch.{}".format(epoch))

        # cv stage
        model.eval()
        cv_losses_sum = []
        count = 0

        for i, minibatch in enumerate(cv_dataloader):
            print("cv epoch: {}, step: {}".format(epoch, i))
            logits, input_lengths, labels_padded, label_lengths, path_weights = minibatch

            loss = model(logits, labels_padded, input_lengths, label_lengths)
            loss_size = loss.size(0)
            count = count + loss_size
            partial_loss = torch.mean(loss.cpu())
            weight = torch.mean(path_weights)
            real_loss = partial_loss - weight
            real_loss_sum = real_loss * loss_size
            cv_losses_sum.append(real_loss_sum.item())
            print("cv_real_loss: {}".format(real_loss.item()))

        cv_loss = np.sum(np.asarray(cv_losses_sum)) / count
        writer.add_scalar('mean_cv_loss', cv_loss, epoch)
        if epoch < args.min_epoch or cv_loss <= prev_cv_loss:
            torch.save(model.module.state_dict(), args.dir + "/best_model")
            prev_cv_loss = cv_loss
        else:
            print(
                "cv loss does not improve, decay the learning rate from {} to {}"
                .format(lr, lr / 10.0))
            adjust_lr(optimizer, lr / 10.0)
            lr = lr / 10.0
            if (lr < args.stop_lr):
                print("learning rate is too small, finish training")
                break

        model.train()

    ctc_crf_base.release_env(gpus)
Code Example #4
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    args.rank = args.start_rank + gpu
    TARGET_GPUS = [args.gpu]
    gpus = torch.IntTensor(TARGET_GPUS)
    logger = None
    ckpt_path = "models_chunk_twin_context"
    os.system("mkdir -p {}".format(ckpt_path))
    if args.rank == 0:
        logger = init_logging(
            "chunk_model", "{}/train.log".format("models_chunk_twin_context"))
        args_msg = [
            '  %s: %s' % (name, value) for (name, value) in vars(args).items()
        ]
        logger.info('args:\n' + '\n'.join(args_msg))

        csv_file = open(args.csv_file, 'w', newline='')
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(header)

    ctc_crf_base.init_env(args.den_lm_fst_path, gpus)
    #print("rank {} init process grop".format(args.rank),
    #      datetime.datetime.now(), flush=True)
    dist.init_process_group(backend='nccl',
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)
    torch.cuda.set_device(args.gpu)

    model = CAT_Chunk_Model(args.feature_size, args.hdim, args.output_unit,
                            args.dropout, args.lamb, args.reg_weight,
                            args.ctc_crf)
    if args.rank == 0:
        params_msg = params_num(model)
        logger.info('\n'.join(params_msg))

    lr = args.origin_lr
    optimizer = optim.Adam(model.parameters(), lr=lr)
    epoch = 0
    prev_cv_loss = np.inf
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint)
        epoch = checkpoint['epoch']
        lr = checkpoint['lr']
        prev_cv_loss = checkpoint['cv_loss']
        model.load_state_dict(checkpoint['model'])

    model.cuda(args.gpu)
    model = nn.parallel.DistributedDataParallel(model, device_ids=TARGET_GPUS)

    reg_model = CAT_RegModel(args.feature_size, args.hdim, args.output_unit,
                             args.dropout, args.lamb)
    loaded_reg_model = torch.load(args.regmodel_checkpoint)
    reg_model.load_state_dict(loaded_reg_model)
    reg_model.cuda(args.gpu)
    reg_model = nn.parallel.DistributedDataParallel(reg_model,
                                                    device_ids=TARGET_GPUS)

    model.train()
    reg_model.eval()
    prev_epoch_time = timeit.default_timer()
    while True:
        # training stage
        epoch += 1
        gc.collect()

        if epoch > 2:
            cate_list = list(range(1, args.cate, 1))
            random.shuffle(cate_list)
        else:
            cate_list = range(1, args.cate, 1)

        for cate in cate_list:
            pkl_path = args.tr_data_path + "/" + str(cate) + ".pkl"
            if not os.path.exists(pkl_path):
                continue
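            # Scale the per-GPU batch size inversely with the bucket index, flooring at 2.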
            batch_size = int(args.gpu_batch_size * 2 / cate)
            if batch_size < 2:
                batch_size = 2
            #print("rank {} pkl path {} batch size {}".format(
            #    args.rank, pkl_path, batch_size))
            tr_dataset = SpeechDatasetMemPickel(pkl_path)
            if len(tr_dataset) < args.world_size:
                continue
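            # Jitter the chunk size around the default by up to ±args.jitter_range.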
            jitter = random.randint(-args.jitter_range, args.jitter_range)
            chunk_size = args.default_chunk_size + jitter
            tr_sampler = DistributedSampler(tr_dataset)
            tr_dataloader = DataLoader(tr_dataset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=PadCollateChunk(chunk_size),
                                       drop_last=True,
                                       sampler=tr_sampler)
            tr_sampler.set_epoch(epoch)  # important for data shuffle
            print(
                "rank {} lengths_cate: {}, chunk_size: {}, training epoch: {}".
                format(args.rank, cate, chunk_size, epoch))
            train_chunk_model(model, reg_model, tr_dataloader, optimizer,
                              epoch, chunk_size, TARGET_GPUS, args, logger)

        # cv stage
        model.eval()
        cv_losses_sum = []
        cv_cls_losses_sum = []
        count = 0
        cate_list = range(1, args.cate, 1)
        for cate in cate_list:
            pkl_path = args.dev_data_path + "/" + str(cate) + ".pkl"
            if not os.path.exists(pkl_path):
                continue
            batch_size = int(args.gpu_batch_size * 2 / cate)
            if batch_size < 2:
                batch_size = 2
            cv_dataset = SpeechDatasetMemPickel(pkl_path)
            cv_dataloader = DataLoader(cv_dataset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=PadCollateChunk(
                                           args.default_chunk_size),
                                       drop_last=True)
            validate_count = validate_chunk_model(model, reg_model,
                                                  cv_dataloader, epoch,
                                                  cv_losses_sum,
                                                  cv_cls_losses_sum, args,
                                                  logger)
            count += validate_count

        cv_loss = np.sum(np.asarray(cv_losses_sum)) / count
        cv_cls_loss = np.sum(np.asarray(cv_cls_losses_sum)) / count

        #print("mean_cv_loss:{} , mean_cv_cls_loss: {}".format(cv_loss, cv_cls_loss))
        if args.rank == 0:
            save_ckpt(
                {
                    'cv_loss': cv_loss,
                    'model': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr': lr,
                    'epoch': epoch
                }, epoch < args.min_epoch or cv_loss <= prev_cv_loss,
                ckpt_path, "model.epoch.{}".format(epoch))

            csv_row = [
                epoch, (timeit.default_timer() - prev_epoch_time) / 60, lr,
                cv_loss
            ]
            prev_epoch_time = timeit.default_timer()
            csv_writer.writerow(csv_row)
            csv_file.flush()
            plot_train_figure(args.csv_file, args.figure_file)

        if epoch < args.min_epoch or cv_loss <= prev_cv_loss:
            prev_cv_loss = cv_loss
        else:
            args.annealing_epoch = 0

        lr = adjust_lr_distribute(optimizer, args.origin_lr, lr, cv_loss,
                                  prev_cv_loss, epoch, args.annealing_epoch,
                                  args.gpu_batch_size, args.world_size)
        if (lr < args.stop_lr):
            print("rank {} lr is too small, finish training".format(args.rank),
                  datetime.datetime.now(),
                  flush=True)
            break

        model.train()

    ctc_crf_base.release_env(gpus)
Code Example #5
def main_worker(gpu, ngpus_per_node, args):
    csv_file = None
    csv_writer = None

    args.gpu = gpu
    args.rank = args.start_rank + gpu
    TARGET_GPUS = [args.gpu]
    logger = None
    ckpt_path = "models"
    os.system("mkdir -p {}".format(ckpt_path))

    if args.rank == 0:
        logger = init_logging(args.model, "{}/train.log".format(ckpt_path))
        args_msg = [
            '  %s: %s' % (name, value) for (name, value) in vars(args).items()
        ]
        logger.info('args:\n' + '\n'.join(args_msg))

        csv_file = open(args.csv_file, 'w', newline='')
        csv_writer = csv.writer(csv_file)
        csv_writer.writerow(header)

    gpus = torch.IntTensor(TARGET_GPUS)
    ctc_crf_base.init_env(args.den_lm_fst_path, gpus)
    dist.init_process_group(backend='nccl',
                            init_method=args.dist_url,
                            world_size=args.world_size,
                            rank=args.rank)

    torch.cuda.set_device(args.gpu)

    model = CAT_Model(args.arch, args.feature_size, args.hdim,
                      args.output_unit, args.layers, args.dropout, args.lamb,
                      args.ctc_crf)
    if args.rank == 0:
        params_msg = params_num(model)
        logger.info('\n'.join(params_msg))

    lr = args.origin_lr
    optimizer = optim.Adam(model.parameters(), lr=lr)
    epoch = 0
    prev_cv_loss = np.inf
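    # Optionally resume the epoch counter, learning rate, best cv loss, and weights
    # from a checkpoint.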
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint)
        epoch = checkpoint['epoch']
        lr = checkpoint['lr']
        prev_cv_loss = checkpoint['cv_loss']
        model.load_state_dict(checkpoint['model'])
    model.cuda(args.gpu)
    model = nn.parallel.DistributedDataParallel(model, device_ids=TARGET_GPUS)

    tr_dataset = SpeechDatasetPickel(args.tr_data_path)
    tr_sampler = DistributedSampler(tr_dataset)
    tr_dataloader = DataLoader(tr_dataset,
                               batch_size=args.gpu_batch_size,
                               shuffle=False,
                               num_workers=args.data_loader_workers,
                               pin_memory=True,
                               collate_fn=PadCollate(),
                               sampler=tr_sampler)
    cv_dataset = SpeechDatasetPickel(args.dev_data_path)
    cv_dataloader = DataLoader(cv_dataset,
                               batch_size=args.gpu_batch_size,
                               shuffle=False,
                               num_workers=args.data_loader_workers,
                               pin_memory=True,
                               collate_fn=PadCollate())

    prev_epoch_time = timeit.default_timer()

    while True:
        # training stage
        epoch += 1
        tr_sampler.set_epoch(epoch)  # important for data shuffle
        gc.collect()
        train(model, tr_dataloader, optimizer, epoch, args, logger)
        cv_loss = validate(model, cv_dataloader, epoch, args, logger)
        # save model
        if args.rank == 0:
            save_ckpt(
                {
                    'cv_loss': cv_loss,
                    'model': model.module.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr': lr,
                    'epoch': epoch
                }, cv_loss <= prev_cv_loss, ckpt_path,
                "model.epoch.{}".format(epoch))

            csv_row = [
                epoch, (timeit.default_timer() - prev_epoch_time) / 60, lr,
                cv_loss
            ]
            prev_epoch_time = timeit.default_timer()
            csv_writer.writerow(csv_row)
            csv_file.flush()
            plot_train_figure(args.csv_file, args.figure_file)

        if epoch < args.min_epoch or cv_loss <= prev_cv_loss:
            prev_cv_loss = cv_loss
        else:
            args.annealing_epoch = 0

        lr = adjust_lr_distribute(optimizer, args.origin_lr, lr, cv_loss,
                                  prev_cv_loss, epoch, args.annealing_epoch,
                                  args.gpu_batch_size, args.world_size)
        if (lr < args.stop_lr):
            print("rank {} lr is too small, finish training".format(args.rank),
                  datetime.datetime.now(),
                  flush=True)
            break

    ctc_crf_base.release_env(gpus)
Code Example #6
    ctc_crf_base.init_env(LM_PATH, gpus)
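    # Minimal smoke test: a single two-frame utterance with the label sequence [1, 2].
    # LM_PATH, gpus, device, and Model are defined in the enclosing script.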

    # Logits for the following inputs:
    logits = np.array([[0.1, 0.6, 0.6, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]],
                      dtype=np.float32)

    # Add a batch dimension at the front: the array becomes
    # (1 utterance, 2 timesteps, 5 output classes).
    logits = np.expand_dims(logits, 0)
    labels = np.asarray([[1, 2]], dtype=np.int32)
    input_lengths = np.asarray([2], dtype=np.int32)
    label_lengths = np.asarray([2], dtype=np.int32)

    # print(logits.shape)

    model = Model(0.1)
    model.cuda()
    model = nn.DataParallel(model)
    model.to(device)

    # self.data_batch.append([torch.FloatTensor(mat), torch.IntTensor(label), torch.FloatTensor(weight)])
    loss = model(torch.FloatTensor(logits), torch.IntTensor(labels),
                 torch.IntTensor(input_lengths),
                 torch.IntTensor(label_lengths))
    print(loss)
    # loss.backward(loss.new_ones(len(TARGET_GPUS)))
    # print(x)
    ctc_crf_base.release_env(gpus)
Code Example #7
File: train_chunk_context.py  Project: thuspmi/CAT
def train():
    args = parse_args()

    ckpt_path = "models_chunk_twin_context"
    os.system("mkdir -p {}".format(ckpt_path))
    logger = init_logging("chunk_model", "{}/train.log".format(ckpt_path))

    args_msg = [
        '  %s: %s' % (name, value) for (name, value) in vars(args).items()
    ]
    logger.info('args:\n' + '\n'.join(args_msg))

    csv_file = open(args.csv_file, 'w', newline='')
    csv_writer = csv.writer(csv_file)
    csv_writer.writerow(header)

    batch_size = args.batch_size
    device = torch.device("cuda:0")

    reg_weight = args.reg_weight

    ctc_crf_base.init_env(args.den_lm_fst_path, gpus)

    model = CAT_Chunk_Model(args.feature_size, args.hdim, args.output_unit,
                            args.dropout, args.lamb, reg_weight)

    lr = args.origin_lr
    optimizer = optim.Adam(model.parameters(), lr=lr)
    epoch = 0
    prev_cv_loss = np.inf
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint)
        epoch = checkpoint['epoch']
        lr = checkpoint['lr']
        prev_cv_loss = checkpoint['cv_loss']
        model.load_state_dict(checkpoint['model'])

    model.cuda()
    model = nn.DataParallel(model)
    model.to(device)

    reg_model = CAT_RegModel(args.feature_size, args.hdim, args.output_unit,
                             args.dropout, args.lamb)

    loaded_reg_model = torch.load(args.regmodel_checkpoint)
    reg_model.load_state_dict(loaded_reg_model)

    reg_model.cuda()
    reg_model = nn.DataParallel(reg_model)
    reg_model.to(device)

    prev_epoch_time = timeit.default_timer()

    model.train()
    reg_model.eval()
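    # The regularization model stays frozen: it is never passed to the optimizer
    # and runs in eval mode throughout training.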
    while True:
        # training stage
        epoch += 1
        gc.collect()

        if epoch > 2:
            cate_list = list(range(1, args.cate, 1))
            random.shuffle(cate_list)
        else:
            cate_list = range(1, args.cate, 1)

        for cate in cate_list:
            pkl_path = args.tr_data_path + "/" + str(cate) + ".pkl"
            if not os.path.exists(pkl_path):
                continue
            tr_dataset = SpeechDatasetMemPickel(pkl_path)

            jitter = random.randint(-args.jitter_range, args.jitter_range)
            chunk_size = args.default_chunk_size + jitter

            tr_dataloader = DataLoader(tr_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=0,
                                       collate_fn=PadCollateChunk(chunk_size))

            train_chunk_model(model, reg_model, tr_dataloader, optimizer,
                              epoch, chunk_size, TARGET_GPUS, args, logger)

        # cv stage
        model.eval()
        cv_losses_sum = []
        cv_cls_losses_sum = []
        count = 0
        cate_list = range(1, args.cate, 1)
        for cate in cate_list:
            pkl_path = args.dev_data_path + "/" + str(cate) + ".pkl"
            if not os.path.exists(pkl_path):
                continue
            cv_dataset = SpeechDatasetMemPickel(pkl_path)
            cv_dataloader = DataLoader(cv_dataset,
                                       batch_size=batch_size,
                                       shuffle=False,
                                       num_workers=0,
                                       collate_fn=PadCollateChunk(
                                           args.default_chunk_size))
            validate_count = validate_chunk_model(model, reg_model,
                                                  cv_dataloader, epoch,
                                                  cv_losses_sum,
                                                  cv_cls_losses_sum, args,
                                                  logger)
            count += validate_count
        cv_loss = np.sum(np.asarray(cv_losses_sum)) / count
        cv_cls_loss = np.sum(np.asarray(cv_cls_losses_sum)) / count
        # save model
        save_ckpt(
            {
                'cv_loss': cv_loss,
                'model': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr': lr,
                'epoch': epoch
            }, epoch < args.min_epoch or cv_loss <= prev_cv_loss, ckpt_path,
            "model.epoch.{}".format(epoch))

        csv_row = [
            epoch, (timeit.default_timer() - prev_epoch_time) / 60, lr, cv_loss
        ]
        prev_epoch_time = timeit.default_timer()
        csv_writer.writerow(csv_row)
        csv_file.flush()
        plot_train_figure(args.csv_file, args.figure_file)

        if epoch < args.min_epoch or cv_loss <= prev_cv_loss:
            prev_cv_loss = cv_loss

        lr = adjust_lr(optimizer, args.origin_lr, lr, cv_loss, prev_cv_loss,
                       epoch, args.min_epoch)
        if (lr < args.stop_lr):
            print("rank {} lr is too small, finish training".format(args.rank),
                  datetime.datetime.now(),
                  flush=True)
            break
        model.train()

    ctc_crf_base.release_env(gpus)