Ejemplo n.º 1
0
def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir,
                     checkpoint, num_threads, batch_size, max_epochs,
                     learning_rate, lr_drop_epochs, lr_drop_ratio, momentum,
                     boost_none, none_count_scale, max_load_candidates,
                     coverage_thr, save_freq, use_cuda, ensemble=None):
    """Train (or resume training) a NeuSomatic network on candidate TSVs.

    Checkpoints are written to ``{out_dir}/models/checkpoint_{tag}_epoch{N}.pth``:
    once before training starts, every ``save_freq`` epochs, and once after
    the last epoch.

    Args:
        candidates_tsv: training candidate TSV(s), passed as ``roots`` to
            ``NeuSomaticDataset``.
        validation_candidates_tsv: optional validation TSV(s). When truthy,
            ``test`` is run after every saved checkpoint.
        out_dir: output directory. ``models/`` is created below it.
        checkpoint: optional checkpoint to resume from. Its ``tag``, ``epoch``
            and ``coverage_thr`` entries override the local values.
        num_threads: torch CPU threads (CPU mode) and DataLoader workers.
        batch_size: mini-batch size.
        max_epochs: total epoch budget, including epochs counted from
            ``checkpoint``.
        learning_rate, momentum: SGD hyper-parameters.
        lr_drop_epochs, lr_drop_ratio: every ``lr_drop_epochs`` epochs the
            learning rate is multiplied by ``lr_drop_ratio`` and a new SGD
            optimizer is created.
        boost_none: multiplier applied to class weight ``weights_type[2]`` and
            ``weights_length[0]`` (presumably the NONE classes).
        none_count_scale: at most ``none_count_scale`` non-somatic candidates
            are sampled per somatic candidate in each epoch.
        max_load_candidates, coverage_thr: forwarded to ``NeuSomaticDataset``.
        save_freq: checkpoint period, in epochs.
        use_cuda: train on the GPU.
        ensemble: use the ensemble input (119 channels instead of 26).
            ``None`` keeps the legacy behaviour of reading ``args.ensemble``
            from a module-level ``args`` namespace, or ``False`` when no such
            namespace exists.

    Returns:
        Path of the final checkpoint file.
    """
    logger = logging.getLogger(train_neusomatic.__name__)

    logger.info("----------------Train NeuSomatic Network-------------------")

    if not use_cuda:
        torch.set_num_threads(num_threads)

    if ensemble is None:
        # Legacy entry points kept the parsed CLI options in a global ``args``.
        ensemble = getattr(globals().get("args"), "ensemble", False)

    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    num_channels = 119 if ensemble else 26
    net = NeuSomaticNet(num_channels)
    if use_cuda:
        net.cuda()

    if torch.cuda.device_count() > 1:
        logger.info("We use {} GPUs!".format(torch.cuda.device_count()))
        net = nn.DataParallel(net)

    models_dir = "{}/models/".format(out_dir)
    if not os.path.exists(models_dir):
        # makedirs also creates ``out_dir`` itself when it is missing.
        os.makedirs(models_dir)

    if checkpoint:
        logger.info(
            "Load pretrained model from checkpoint {}".format(checkpoint))
        pretrained_dict = torch.load(checkpoint,
                                     map_location=lambda storage, loc: storage)
        pretrained_state_dict = pretrained_dict["state_dict"]
        tag = pretrained_dict["tag"]
        sofar_epochs = pretrained_dict["epoch"]
        logger.info(
            "sofar_epochs from pretrained checkpoint: {}".format(sofar_epochs))
        coverage_thr = pretrained_dict["coverage_thr"]
        logger.info(
            "Override coverage_thr from pretrained checkpoint: {}".format(
                coverage_thr))
        prev_epochs = sofar_epochs + 1
        model_dict = net.state_dict()
        # nn.DataParallel prefixes parameter names with "module.". The
        # checkpoint may have been saved with or without that wrapper,
        # independently of how ``net`` is wrapped now, so add or strip the
        # leading prefix as needed.
        prefix = "module."
        ckpt_keys = list(pretrained_state_dict.keys())
        model_keys = list(model_dict.keys())
        ckpt_wrapped = bool(ckpt_keys) and ckpt_keys[0].startswith(prefix)
        model_wrapped = bool(model_keys) and model_keys[0].startswith(prefix)
        if ckpt_wrapped and not model_wrapped:
            renamed = [(k[len(prefix):] if k.startswith(prefix) else k, v)
                       for k, v in pretrained_state_dict.items()]
        elif model_wrapped and not ckpt_wrapped:
            renamed = [(prefix + k, v)
                       for k, v in pretrained_state_dict.items()]
        else:
            renamed = list(pretrained_state_dict.items())
        # 1. filter out keys the current model does not have
        pretrained_state_dict = {k: v for k, v in renamed if k in model_dict}
        missing = [k for k in model_dict if k not in pretrained_state_dict]
        if missing:
            logger.warning(
                "{} model tensors not found in checkpoint; keeping their "
                "initial values".format(len(missing)))
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_state_dict)
        # 3. load the merged state dict. Loading only the filtered
        #    checkpoint would make load_state_dict fail on missing keys.
        net.load_state_dict(model_dict)
    else:
        prev_epochs = 0
        time_now = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
        tag = "neusomatic_{}".format(time_now)
    logger.info("tag: {}".format(tag))

    def save_checkpoint(epoch):
        """Save the current weights for ``epoch`` and return the file path."""
        path = '{}/models/checkpoint_{}_epoch{}.pth'.format(
            out_dir, tag, epoch)
        torch.save(
            {
                "state_dict": net.state_dict(),
                "tag": tag,
                "epoch": epoch,
                "coverage_thr": coverage_thr
            }, path)
        return path

    train_set = NeuSomaticDataset(roots=candidates_tsv,
                                  max_load_candidates=max_load_candidates,
                                  transform=data_transform,
                                  is_test=False,
                                  num_threads=num_threads,
                                  coverage_thr=coverage_thr)
    none_indices = train_set.get_none_indices()
    var_indices = train_set.get_var_indices()
    # Shuffle both index lists. They must be real lists (not lazy map
    # objects), because len() is taken on them below.
    if none_indices:
        none_indices = [
            none_indices[j]
            for j in torch.randperm(len(none_indices)).tolist()
        ]
    logger.info("Non-somatic candidates: {}".format(len(none_indices)))
    if var_indices:
        var_indices = [
            var_indices[j] for j in torch.randperm(len(var_indices)).tolist()
        ]
    logger.info("Somatic candidates: {}".format(len(var_indices)))
    none_count = min(len(none_indices), len(var_indices) * none_count_scale)
    logger.info("Non-somatic considered in each epoch: {}".format(none_count))
    sampler = SubsetNoneSampler(none_indices, var_indices, none_count)

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               num_workers=num_threads,
                                               pin_memory=True,
                                               sampler=sampler)
    logger.info("#Train cadidates: {}".format(len(train_set)))

    if validation_candidates_tsv:
        validation_set = NeuSomaticDataset(
            roots=validation_candidates_tsv,
            max_load_candidates=max_load_candidates,
            transform=data_transform,
            is_test=True,
            num_threads=num_threads,
            coverage_thr=coverage_thr)
        validation_loader = torch.utils.data.DataLoader(
            validation_set,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_threads,
            pin_memory=True)
        logger.info("#Validation candidates: {}".format(len(validation_set)))

    weights_type, weights_length = make_weights_for_balanced_classes(
        train_set.count_class_t, train_set.count_class_l, 4, 4, none_count)

    weights_type[2] *= boost_none
    weights_length[0] *= boost_none

    logger.info("weights_type:{}, weights_length:{}".format(
        weights_type, weights_length))

    gradients = torch.FloatTensor(weights_type)
    gradients2 = torch.FloatTensor(weights_length)
    if use_cuda:
        gradients = gradients.cuda()
        gradients2 = gradients2.cuda()
    criterion_crossentropy = nn.CrossEntropyLoss(gradients)
    criterion_crossentropy2 = nn.CrossEntropyLoss(gradients2)
    criterion_smoothl1 = nn.SmoothL1Loss()
    optimizer = optim.SGD(net.parameters(),
                          lr=learning_rate,
                          momentum=momentum)

    def train_step(data):
        """Run one SGD step on a mini-batch and return its loss as a float."""
        (inputs, labels, var_pos_s, var_len_s, _), _ = data
        inputs, labels, var_pos_s, var_len_s = Variable(inputs), Variable(
            labels), Variable(var_pos_s), Variable(var_len_s)
        if use_cuda:
            inputs, labels, var_pos_s, var_len_s = inputs.cuda(
            ), labels.cuda(), var_pos_s.cuda(), var_len_s.cuda()

        optimizer.zero_grad()

        outputs, _ = net(inputs)
        outputs_classification, outputs_pos, outputs_len = outputs
        var_len_labels = Variable(
            torch.LongTensor(var_len_s.cpu().data.numpy()))
        if use_cuda:
            var_len_labels = var_len_labels.cuda()
        # outputs_pos has a trailing singleton dim; squeeze it so SmoothL1
        # compares element-wise instead of broadcasting (B, 1) vs (B,).
        loss = (criterion_crossentropy(outputs_classification, labels) +
                criterion_smoothl1(outputs_pos.squeeze(1), var_pos_s[:, 1]) +
                criterion_crossentropy2(outputs_len, var_len_labels))

        loss.backward()
        optimizer.step()
        return loss.item()

    net.train()
    len_train_set = none_count + len(var_indices)
    logger.info("Number of candidater per epoch: {}".format(len_train_set))
    print_freq = max(1, int(len_train_set / float(batch_size) / 4.0))
    # Count epochs explicitly. Deriving the epoch from the number of batches
    # drifts because the last batch of an epoch may be partial.
    curr_epoch = prev_epochs
    save_checkpoint(curr_epoch)
    # loop over the dataset multiple times
    for epoch in range(max_epochs - prev_epochs):
        curr_epoch = prev_epochs + epoch + 1
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            running_loss += train_step(data)
            if i % print_freq == print_freq - 1:
                logger.info('epoch: {}, i: {:>5}, lr: {}, loss: {:.5f}'.format(
                    curr_epoch, i + 1, learning_rate,
                    running_loss / print_freq))
                running_loss = 0.0
        if curr_epoch % save_freq == 0:
            save_checkpoint(curr_epoch)
            if validation_candidates_tsv:
                test(net, curr_epoch, validation_loader, use_cuda)
        if curr_epoch % lr_drop_epochs == 0:
            learning_rate *= lr_drop_ratio
            optimizer = optim.SGD(net.parameters(),
                                  lr=learning_rate,
                                  momentum=momentum)
    logger.info('Finished Training')

    final_path = save_checkpoint(curr_epoch)
    if validation_candidates_tsv:
        test(net, curr_epoch, validation_loader, use_cuda)
    logger.info("Total Epochs: {}".format(curr_epoch))
    return final_path
Ejemplo n.º 2
0
def call_neusomatic(candidates_tsv, ref_file, out_dir, checkpoint, num_threads,
                    batch_size, max_load_candidates, pass_threshold,
                    lowqual_threshold, ensemble, use_cuda):
    """Call somatic mutations on candidate TSVs with a trained NeuSomatic model.

    Steps:
      1. Build the network (119 input channels when ``ensemble``, else 26).
         Load its weights, ``tag`` and ``coverage_thr`` from ``checkpoint``.
      2. Any candidate TSV whose ``.idx`` index has more than
         ``max_load_candidates / 2`` entries is split with ``merge_tsvs`` into
         smaller TSVs under ``{out_dir}/split_tsvs``.
      3. TSVs are processed in increasing candidate-count order. They are
         grouped until a group holds more than ``max_load_candidates / 10``
         candidates (or the last TSV is reached), and ``call_variants`` runs
         on each group.
      4. Predictions are written to ``{out_dir}/pred.vcf`` and non-somatic
         predictions to ``{out_dir}/none.vcf``.

    The directories ``{out_dir}/split_tsvs`` and ``{out_dir}/matrices_{tag}``
    are recreated at the start and removed again at the end.

    Args:
        candidates_tsv: candidate TSV paths. Each must have a pickled
            ``<path>.idx`` index next to it.
        ref_file: reference FASTA, readable by pysam.
        out_dir: output directory, created if missing.
        checkpoint: trained model checkpoint, as written by train_neusomatic.
        num_threads: torch CPU threads (CPU mode) and DataLoader workers.
        batch_size: inference batch size.
        max_load_candidates: controls the TSV splitting and grouping sizes
            described above.
        pass_threshold, lowqual_threshold: forwarded to ``write_vcf``.
        ensemble: whether the candidates carry ensemble features.
        use_cuda: run on the GPU.

    Returns:
        Path of the predicted-variant VCF (``{out_dir}/pred.vcf``).
    """
    logger = logging.getLogger(call_neusomatic.__name__)

    logger.info("-----------------Call Somatic Mutations--------------------")

    logger.info("PyTorch Version: {}".format(torch.__version__))
    logger.info("Torchvision Version: {}".format(torchvision.__version__))
    if not use_cuda:
        torch.set_num_threads(num_threads)

    chroms_order = get_chromosomes_order(reference=ref_file)
    with pysam.FastaFile(ref_file) as rf:
        chroms = rf.references

    vartype_classes = ['DEL', 'INS', 'NONE', 'SNP']
    data_transform = transforms.Compose(
        [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    num_channels = 119 if ensemble else 26
    net = NeuSomaticNet(num_channels)
    if use_cuda:
        logger.info("GPU calling!")
        net.cuda()
    else:
        logger.info("CPU calling!")

    if torch.cuda.device_count() > 1:
        logger.info("We use {} GPUs!".format(torch.cuda.device_count()))
        net = nn.DataParallel(net)

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)
    logger.info("Load pretrained model from checkpoint {}".format(checkpoint))
    pretrained_dict = torch.load(checkpoint,
                                 map_location=lambda storage, loc: storage)
    pretrained_state_dict = pretrained_dict["state_dict"]
    model_tag = pretrained_dict["tag"]
    logger.info("tag: {}".format(model_tag))

    # Scratch directory, presumably filled by call_variants (which receives
    # out_dir and model_tag). It is removed again at the end of this function.
    matrices_dir = "{}/matrices_{}".format(out_dir, model_tag)
    if os.path.exists(matrices_dir):
        logger.warning("Remove matrices directory: {}".format(matrices_dir))
        shutil.rmtree(matrices_dir)
    os.mkdir(matrices_dir)
    coverage_thr = pretrained_dict["coverage_thr"]

    model_dict = net.state_dict()

    # nn.DataParallel prefixes parameter names with "module.". The checkpoint
    # may have been saved with or without that wrapper, independently of how
    # ``net`` is wrapped now, so add or strip the leading prefix as needed.
    prefix = "module."
    ckpt_keys = list(pretrained_state_dict.keys())
    model_keys = list(model_dict.keys())
    ckpt_wrapped = bool(ckpt_keys) and ckpt_keys[0].startswith(prefix)
    model_wrapped = bool(model_keys) and model_keys[0].startswith(prefix)
    if ckpt_wrapped and not model_wrapped:
        renamed = [(k[len(prefix):] if k.startswith(prefix) else k, v)
                   for k, v in pretrained_state_dict.items()]
    elif model_wrapped and not ckpt_wrapped:
        renamed = [(prefix + k, v) for k, v in pretrained_state_dict.items()]
    else:
        renamed = list(pretrained_state_dict.items())
    # 1. filter out keys the current model does not have
    pretrained_state_dict = {k: v for k, v in renamed if k in model_dict}
    missing = [k for k in model_dict if k not in pretrained_state_dict]
    if missing:
        logger.warning(
            "{} model tensors not found in checkpoint; keeping their "
            "initial values".format(len(missing)))

    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_state_dict)
    # 3. load the merged state dict. Loading only the filtered checkpoint
    #    would make load_state_dict fail on missing keys.
    net.load_state_dict(model_dict)

    new_split_tsvs_dir = os.path.join(out_dir, "split_tsvs")
    if os.path.exists(new_split_tsvs_dir):
        logger.warning(
            "Remove split candidates directory: {}".format(new_split_tsvs_dir))
        shutil.rmtree(new_split_tsvs_dir)
    os.mkdir(new_split_tsvs_dir)
    Ls = []
    candidates_tsv_ = []
    split_i = 0
    # Integer chunk size: plain "/" would hand merge_tsvs a float on Python 3.
    split_chunk_size = max(1, max_load_candidates // 2)
    for candidate_file in candidates_tsv:
        with open(candidate_file + ".idx", "rb") as idx_file:
            idx = pickle.load(idx_file)
        if len(idx) > max_load_candidates / 2:
            logger.info("Splitting {} of lenght {}".format(
                candidate_file, len(idx)))
            new_split_tsvs_dir_i = os.path.join(new_split_tsvs_dir,
                                                "split_{}".format(split_i))
            if os.path.exists(new_split_tsvs_dir_i):
                logger.warning("Remove split candidates directory: {}".format(
                    new_split_tsvs_dir_i))
                shutil.rmtree(new_split_tsvs_dir_i)
            os.mkdir(new_split_tsvs_dir_i)
            candidate_file_splits = merge_tsvs(
                input_tsvs=[candidate_file],
                out=new_split_tsvs_dir_i,
                candidates_per_tsv=split_chunk_size,
                max_num_tsvs=100000,
                overwrite_merged_tsvs=True,
                keep_none_types=True)
            for candidate_file_split in candidate_file_splits:
                with open(candidate_file_split + ".idx", "rb") as idx_file:
                    idx_split = pickle.load(idx_file)
                candidates_tsv_.append(candidate_file_split)
                Ls.append(len(idx_split) - 1)
            split_i += 1
        else:
            candidates_tsv_.append(candidate_file)
            Ls.append(len(idx) - 1)

    current_L = 0
    candidate_files = []
    all_vcf_records = []
    all_vcf_records_none = []
    for i, (candidate_file, L) in enumerate(
            sorted(zip(candidates_tsv_, Ls), key=lambda x: x[1])):
        current_L += L
        candidate_files.append(candidate_file)
        if current_L > max_load_candidates / 10 or i == len(
                candidates_tsv_) - 1:
            logger.info("Run for candidate files: {}".format(candidate_files))
            call_set = NeuSomaticDataset(
                roots=candidate_files,
                max_load_candidates=max_load_candidates,
                transform=data_transform,
                is_test=True,
                num_threads=num_threads,
                coverage_thr=coverage_thr)
            call_loader = torch.utils.data.DataLoader(call_set,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      pin_memory=True,
                                                      num_workers=num_threads)

            batch_files = candidate_files
            current_L = 0
            candidate_files = []

            logger.info("N_dataset: {}".format(len(call_set)))
            if len(call_set) == 0:
                # Report the whole skipped group, not only its last file.
                logger.warning(
                    "Skip {} with 0 candidates".format(batch_files))
                continue

            final_preds_, none_preds_, true_path_ = call_variants(
                net, vartype_classes, call_loader, out_dir, model_tag,
                use_cuda)
            all_vcf_records.extend(
                pred_vcf_records(ref_file, final_preds_, true_path_, chroms,
                                 vartype_classes, num_threads))
            all_vcf_records_none.extend(
                pred_vcf_records_none(none_preds_, chroms))

    all_vcf_records = dict(all_vcf_records)
    all_vcf_records_none = dict(all_vcf_records_none)

    if os.path.exists(new_split_tsvs_dir):
        logger.warning(
            "Remove split candidates directory: {}".format(new_split_tsvs_dir))
        shutil.rmtree(new_split_tsvs_dir)

    logger.info("Prepare Output VCF")
    output_vcf = "{}/pred.vcf".format(out_dir)
    var_vcf_records = get_vcf_records(all_vcf_records)
    write_vcf(var_vcf_records, output_vcf, chroms_order, pass_threshold,
              lowqual_threshold)

    logger.info("Prepare Non-Somatics VCF")
    output_vcf_none = "{}/none.vcf".format(out_dir)
    vcf_records_none = get_vcf_records(all_vcf_records_none)
    write_vcf(vcf_records_none, output_vcf_none, chroms_order, pass_threshold,
              lowqual_threshold)

    if os.path.exists(matrices_dir):
        logger.warning("Remove matrices directory: {}".format(matrices_dir))
        shutil.rmtree(matrices_dir)
    return output_vcf
Ejemplo n.º 3
0
def train_neusomatic(candidates_tsv, validation_candidates_tsv, out_dir,
                     checkpoint, num_threads, batch_size, max_epochs,
                     learning_rate, lr_drop_epochs, lr_drop_ratio, momentum,
                     boost_none, none_count_scale, max_load_candidates,
                     coverage_thr, save_freq, ensemble,
                     merged_candidates_per_tsv, merged_max_num_tsvs,
                     overwrite_merged_tsvs, trian_split_len, use_cuda):
    """Train (or resume training) a NeuSomatic network on candidate TSVs.

    Data preparation:
      * A shuffled copy of ``candidates_tsv`` is taken. If it holds more than
        ``merged_max_num_tsvs`` files, they are merged with ``merge_tsvs``.
      * TSVs are sorted by candidate count, largest first, and grouped into
        splits of at least ``trian_split_len`` candidates. Each split becomes
        its own ``NeuSomaticDataset`` with its own ``SubsetNoneSampler``.
      * Each epoch iterates every split once.

    Checkpoints are written to ``{out_dir}/models/checkpoint_{tag}_epoch{N}.pth``:
    once before training, every ``save_freq`` epochs, and after the last epoch.

    Args:
        candidates_tsv: training candidate TSV paths. Each must have a pickled
            ``<path>.idx`` index. The caller's list is not modified.
        validation_candidates_tsv: optional validation TSV(s). When truthy,
            ``test`` runs after every saved checkpoint.
        out_dir: output directory, created if missing.
        checkpoint: optional checkpoint to resume from. Its ``tag``, ``epoch``
            and ``coverage_thr`` entries override the local values.
        num_threads: torch CPU threads (CPU mode) and DataLoader workers.
        batch_size: mini-batch size.
        max_epochs: total epoch budget, including epochs counted from
            ``checkpoint``.
        learning_rate, momentum: SGD hyper-parameters.
        lr_drop_epochs, lr_drop_ratio: every ``lr_drop_epochs`` epochs the
            learning rate is multiplied by ``lr_drop_ratio`` and a new SGD
            optimizer is created.
        boost_none: multiplier applied to ``weights_type[2]`` and
            ``weights_length[0]`` (presumably the NONE classes).
        none_count_scale: per split, at most ``none_count_scale`` non-somatic
            candidates are sampled per somatic candidate each epoch (minimum 1).
        max_load_candidates: total load budget, shared across splits in
            proportion to their number of TSVs.
        coverage_thr: forwarded to ``NeuSomaticDataset``.
        save_freq: checkpoint period, in epochs.
        ensemble: use the ensemble input (119 channels instead of 26).
        merged_candidates_per_tsv, merged_max_num_tsvs, overwrite_merged_tsvs:
            forwarded to ``merge_tsvs``.
        trian_split_len: minimum number of candidates per training split.
        use_cuda: train on the GPU.

    Raises:
        ValueError: if there are no candidate TSVs to train on.

    Returns:
        Path of the final checkpoint file.
    """
    logger = logging.getLogger(train_neusomatic.__name__)

    logger.info("----------------Train NeuSomatic Network-------------------")
    logger.info("PyTorch Version: {}".format(torch.__version__))
    logger.info("Torchvision Version: {}".format(torchvision.__version__))

    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

    if not use_cuda:
        torch.set_num_threads(num_threads)

    data_transform = transforms.Compose(
        [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    num_channels = 119 if ensemble else 26
    net = NeuSomaticNet(num_channels)
    if use_cuda:
        logger.info("GPU training!")
        net.cuda()
    else:
        logger.info("CPU training!")

    if torch.cuda.device_count() > 1:
        logger.info("We use {} GPUs!".format(torch.cuda.device_count()))
        net = nn.DataParallel(net)

    if not os.path.exists("{}/models/".format(out_dir)):
        os.mkdir("{}/models/".format(out_dir))

    if checkpoint:
        logger.info(
            "Load pretrained model from checkpoint {}".format(checkpoint))
        pretrained_dict = torch.load(checkpoint,
                                     map_location=lambda storage, loc: storage)
        pretrained_state_dict = pretrained_dict["state_dict"]
        tag = pretrained_dict["tag"]
        sofar_epochs = pretrained_dict["epoch"]
        logger.info(
            "sofar_epochs from pretrained checkpoint: {}".format(sofar_epochs))
        coverage_thr = pretrained_dict["coverage_thr"]
        logger.info(
            "Override coverage_thr from pretrained checkpoint: {}".format(
                coverage_thr))
        prev_epochs = sofar_epochs + 1
        model_dict = net.state_dict()
        # nn.DataParallel prefixes parameter names with "module.". The
        # checkpoint may have been saved with or without that wrapper,
        # independently of how ``net`` is wrapped now, so add or strip the
        # leading prefix as needed.
        prefix = "module."
        ckpt_keys = list(pretrained_state_dict.keys())
        model_keys = list(model_dict.keys())
        ckpt_wrapped = bool(ckpt_keys) and ckpt_keys[0].startswith(prefix)
        model_wrapped = bool(model_keys) and model_keys[0].startswith(prefix)
        if ckpt_wrapped and not model_wrapped:
            renamed = [(k[len(prefix):] if k.startswith(prefix) else k, v)
                       for k, v in pretrained_state_dict.items()]
        elif model_wrapped and not ckpt_wrapped:
            renamed = [(prefix + k, v)
                       for k, v in pretrained_state_dict.items()]
        else:
            renamed = list(pretrained_state_dict.items())
        # 1. filter out keys the current model does not have
        pretrained_state_dict = {k: v for k, v in renamed if k in model_dict}
        missing = [k for k in model_dict if k not in pretrained_state_dict]
        if missing:
            logger.warning(
                "{} model tensors not found in checkpoint; keeping their "
                "initial values".format(len(missing)))
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_state_dict)
        # 3. load the merged state dict. Loading only the filtered
        #    checkpoint would make load_state_dict fail on missing keys.
        net.load_state_dict(model_dict)
    else:
        prev_epochs = 0
        time_now = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")
        tag = "neusomatic_{}".format(time_now)
    logger.info("tag: {}".format(tag))

    def save_checkpoint(epoch):
        """Save the current weights for ``epoch`` and return the file path."""
        path = '{}/models/checkpoint_{}_epoch{}.pth'.format(
            out_dir, tag, epoch)
        torch.save(
            {
                "state_dict": net.state_dict(),
                "tag": tag,
                "epoch": epoch,
                "coverage_thr": coverage_thr
            }, path)
        return path

    # Shuffle a copy so the caller's list is not reordered as a side effect.
    candidates_tsv = list(candidates_tsv)
    shuffle(candidates_tsv)

    if len(candidates_tsv) > merged_max_num_tsvs:
        candidates_tsv = merge_tsvs(
            input_tsvs=candidates_tsv,
            out=out_dir,
            candidates_per_tsv=merged_candidates_per_tsv,
            max_num_tsvs=merged_max_num_tsvs,
            overwrite_merged_tsvs=overwrite_merged_tsvs,
            keep_none_types=True)
    if not candidates_tsv:
        raise ValueError("No candidate TSV files to train on")

    # Candidate count per TSV, taken from its pickled .idx (len(idx) - 1).
    Ls = []
    for tsv in candidates_tsv:
        with open(tsv + ".idx", "rb") as idx_file:
            Ls.append(len(pickle.load(idx_file)) - 1)

    by_size = sorted(zip(Ls, candidates_tsv),
                     key=lambda x: x[0],
                     reverse=True)
    Ls = [L for L, _ in by_size]
    candidates_tsv = [tsv for _, tsv in by_size]

    train_split_tsvs = []
    current_L = 0
    current_split_tsvs = []
    for L, tsv in zip(Ls, candidates_tsv):
        current_L += L
        current_split_tsvs.append(tsv)
        if current_L >= trian_split_len:
            logger.info("tsvs in split {}: {}".format(len(train_split_tsvs),
                                                      current_split_tsvs))
            train_split_tsvs.append(current_split_tsvs)
            current_L = 0
            current_split_tsvs = []
    if current_split_tsvs:
        if current_L == 0 and train_split_tsvs:
            # Only zero-candidate TSVs are left (sizes are sorted in
            # decreasing order). Attach them to the previous split instead of
            # dropping them, which would break the assertion below.
            train_split_tsvs[-1].extend(current_split_tsvs)
        else:
            logger.info("tsvs in split {}: {}".format(len(train_split_tsvs),
                                                      current_split_tsvs))
            train_split_tsvs.append(current_split_tsvs)

    assert sum(len(x) for x in train_split_tsvs) == len(candidates_tsv)
    train_sets = []
    none_counts = []
    var_counts = []
    samplers = []
    for split_i, tsvs in enumerate(train_split_tsvs):
        train_set = NeuSomaticDataset(
            roots=tsvs,
            max_load_candidates=int(max_load_candidates * len(tsvs) /
                                    float(len(candidates_tsv))),
            transform=data_transform,
            is_test=False,
            num_threads=num_threads,
            coverage_thr=coverage_thr)
        train_sets.append(train_set)
        none_indices = train_set.get_none_indices()
        var_indices = train_set.get_var_indices()
        if none_indices:
            none_indices = [
                none_indices[j]
                for j in torch.randperm(len(none_indices)).tolist()
            ]
        logger.info("Non-somatic candidates is split {}: {}".format(
            split_i, len(none_indices)))
        if var_indices:
            var_indices = [
                var_indices[j]
                for j in torch.randperm(len(var_indices)).tolist()
            ]
        logger.info("Somatic candidates in split {}: {}".format(
            split_i, len(var_indices)))
        none_count = max(
            min(len(none_indices),
                len(var_indices) * none_count_scale), 1)
        logger.info(
            "Non-somatic considered in each epoch of split {}: {}".format(
                split_i, none_count))

        samplers.append(
            SubsetNoneSampler(none_indices, var_indices, none_count))
        none_counts.append(none_count)
        var_counts.append(len(var_indices))
    logger.info("# Total Train cadidates: {}".format(
        sum(len(x) for x in train_sets)))

    if validation_candidates_tsv:
        validation_set = NeuSomaticDataset(
            roots=validation_candidates_tsv,
            max_load_candidates=max_load_candidates,
            transform=data_transform,
            is_test=True,
            num_threads=num_threads,
            coverage_thr=coverage_thr)
        validation_loader = torch.utils.data.DataLoader(
            validation_set,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_threads,
            pin_memory=True)
        logger.info("#Validation candidates: {}".format(len(validation_set)))

    count_class_t = [0] * 4
    count_class_l = [0] * 4
    for train_set in train_sets:
        for c in range(4):
            count_class_t[c] += train_set.count_class_t[c]
            count_class_l[c] += train_set.count_class_l[c]

    weights_type, weights_length = make_weights_for_balanced_classes(
        count_class_t, count_class_l, 4, 4, sum(none_counts))

    weights_type[2] *= boost_none
    weights_length[0] *= boost_none

    logger.info("weights_type:{}, weights_length:{}".format(
        weights_type, weights_length))

    gradients = torch.FloatTensor(weights_type)
    gradients2 = torch.FloatTensor(weights_length)
    if use_cuda:
        gradients = gradients.cuda()
        gradients2 = gradients2.cuda()
    criterion_crossentropy = nn.CrossEntropyLoss(gradients)
    criterion_crossentropy2 = nn.CrossEntropyLoss(gradients2)
    criterion_smoothl1 = nn.SmoothL1Loss()
    optimizer = optim.SGD(net.parameters(),
                          lr=learning_rate,
                          momentum=momentum)

    def make_loader(dataset, sampler):
        """Build the training DataLoader for one split."""
        return torch.utils.data.DataLoader(dataset,
                                           batch_size=batch_size,
                                           num_workers=num_threads,
                                           pin_memory=True,
                                           sampler=sampler)

    def train_step(data):
        """Run one SGD step on a mini-batch and return its loss as a float."""
        (inputs, labels, var_pos_s, var_len_s, _), _ = data
        inputs, labels, var_pos_s, var_len_s = Variable(inputs), Variable(
            labels), Variable(var_pos_s), Variable(var_len_s)
        if use_cuda:
            inputs, labels, var_pos_s, var_len_s = inputs.cuda(
            ), labels.cuda(), var_pos_s.cuda(), var_len_s.cuda()

        optimizer.zero_grad()

        outputs, _ = net(inputs)
        outputs_classification, outputs_pos, outputs_len = outputs
        var_len_labels = Variable(
            torch.LongTensor(var_len_s.cpu().data.numpy()))
        if use_cuda:
            var_len_labels = var_len_labels.cuda()
        loss = (criterion_crossentropy(outputs_classification, labels) +
                criterion_smoothl1(outputs_pos.squeeze(1), var_pos_s[:, 1]) +
                criterion_crossentropy2(outputs_len, var_len_labels))

        loss.backward()
        optimizer.step()
        # A Python float, so no tensors (possibly on the GPU) are kept alive.
        return loss.item()

    net.train()
    len_train_set = sum(none_counts) + sum(var_counts)
    logger.info("Number of candidater per epoch: {}".format(len_train_set))
    print_freq = max(1, int(len_train_set / float(batch_size) / 4.0))
    curr_epoch = prev_epochs
    save_checkpoint(curr_epoch)

    single_split = len(train_sets) == 1
    if single_split:
        # With a single split, keep the TSVs open and reuse one loader.
        train_sets[0].open_candidate_tsvs()
        train_loader = make_loader(train_sets[0], samplers[0])
    n_epoch = 0
    n_iter = 0
    try:
        # loop over the dataset multiple times
        for _ in range(max_epochs - prev_epochs):
            n_epoch += 1
            running_loss = 0.0
            i_ = 0  # 1-based batch counter within this epoch
            for split_i, train_set in enumerate(train_sets):
                if not single_split:
                    train_set.open_candidate_tsvs()
                    train_loader = make_loader(train_set, samplers[split_i])
                try:
                    for data in train_loader:
                        i_ += 1
                        n_iter += 1
                        running_loss += train_step(data)
                        # i_ is 1-based: report after every print_freq batches.
                        if i_ % print_freq == 0:
                            logger.info(
                                'epoch: {}, iter: {:>7}, lr: {}, loss: {:.5f}'.
                                format(n_epoch + prev_epochs, n_iter,
                                       learning_rate,
                                       running_loss / print_freq))
                            running_loss = 0.0
                finally:
                    if not single_split:
                        train_set.close_candidate_tsvs()

            curr_epoch = n_epoch + prev_epochs
            if curr_epoch % save_freq == 0:
                save_checkpoint(curr_epoch)
                if validation_candidates_tsv:
                    test(net, curr_epoch, validation_loader, use_cuda)
            if curr_epoch % lr_drop_epochs == 0:
                learning_rate *= lr_drop_ratio
                optimizer = optim.SGD(net.parameters(),
                                      lr=learning_rate,
                                      momentum=momentum)
        logger.info('Finished Training')
    finally:
        if single_split:
            train_sets[0].close_candidate_tsvs()

    curr_epoch = n_epoch + prev_epochs
    final_path = save_checkpoint(curr_epoch)
    if validation_candidates_tsv:
        test(net, curr_epoch, validation_loader, use_cuda)
    logger.info("Total Epochs: {}".format(curr_epoch))
    return final_path