Example #1
def val(args):
    list_threhold = [0.5]
    if config.top4_DeepNN:
        model = getattr(models,
                        config.model_name)(num_classes=config.num_classes,
                                           channel_size=config.channel_size)
        model.load_state_dict(
            torch.load(args.ckpt, map_location='cpu')['state_dict'])
        model = model.to(device)
        print(config.model_name, args.ckpt)
    else:
        model = None
        print('no', config.model_name)
    val_dataset = ECGDataset(data_path=config.train_data, train=False)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=4)
    if config.kind == 2 and config.top4_catboost:
        val_f1 = top4_val_epoch(model, val_dataloader)
        print('catboost val_f1:%.3f\n' % val_f1)
    else:
        criterion = nn.BCEWithLogitsLoss()
        for threshold in list_threhold:
            val_loss, val_f1 = val_epoch(model, criterion, val_dataloader,
                                         threshold)
            print('threshold %.2f val_loss:%0.3e val_f1:%.3f\n' %
                  (threshold, val_loss, val_f1))
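
Note: every example on this page delegates the actual evaluation loop to a val_epoch helper that is not shown, and each project defines its own variant (some also return precision/recall). A minimal sketch of the two-value form used in Examples #1, #4 and #5 follows; the signature, the device handling, and the use of sklearn's micro-F1 are assumptions for illustration, not the repositories' actual code.

import torch
from sklearn.metrics import f1_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def val_epoch(model, criterion, dataloader, threshold=0.5):
    # one pass over the validation set; returns mean loss and micro-F1
    model.eval()
    total_loss, outputs, targets = 0.0, [], []
    with torch.no_grad():
        for inputs, target in dataloader:
            inputs, target = inputs.to(device), target.to(device)
            output = model(inputs)
            total_loss += criterion(output, target).item()
            # sigmoid + threshold turns logits into binary label predictions
            outputs.append((torch.sigmoid(output) > threshold).float().cpu())
            targets.append(target.cpu())
    y_pred = torch.cat(outputs).numpy()
    y_true = torch.cat(targets).numpy()
    return total_loss / len(dataloader), f1_score(y_true, y_pred,
                                                  average='micro')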
Example #2
def check(args):
    config.model_name = args.model_name
    print(config.model_name)
    model_save_dir = '%s/%s' % (config.ckpt,
                                config.model_name + '_' + str(args.fold))
    args.ckpt = model_save_dir
    config.train_data = config.train_data + str(args.fold) + '.pth'
    list_threhold = [0.5]
    model = getattr(models, config.model_name)()
    if args.ckpt:
        model.load_state_dict(
            torch.load(os.path.join(model_save_dir, config.best_w),
                       map_location='cpu')['state_dict'])
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    dd = torch.load(config.train_data)
    filename = dd['train']
    val_dataset = ECGDataset(data_path=config.train_data, train=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                shuffle=False,
                                num_workers=6)

    for threshold in list_threhold:
        val_loss, val_f1, target, output = val_epoch(model, criterion,
                                                     val_dataloader, threshold)
        print('threshold %.2f val_loss:%0.3e val_f1:%.3f\n' %
              (threshold, val_loss, val_f1))

    return target, output, filename
Example #3
def val(mode, ckpt):
    threshold = 0.5
    model = getattr(resnet, config.model_name)(input_dim=config.input_dim)
    model.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict'])
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    val_dataset = ECGDataset(data_path=config.train_data, mode=mode)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=6)
    val_loss, val_p, val_r, val_f1, pr_df = val_epoch(model, criterion,
                                                      val_dataloader,
                                                      threshold, False)
    print(
        'threshold %.2f val_loss:%0.3e val_precision:%.4f val_recall:%.4f val_f1:%.4f\n'
        % (
            threshold,
            val_loss,
            val_p,
            val_r,
            val_f1,
        ))
    pr_df['arry'] = pr_df['arry'].map(val_dataset.idx2name)
    pr_df.to_csv('../user_data/%s_f1.csv' % mode, encoding='gbk')
    print(pr_df)
Example #4
File: main.py Project: white0531/SE-ECGNet
def val(args):
    list_threhold = [0.5]
    model = models.resnet34()
    if args.ckpt:
        model.load_state_dict(
            torch.load(args.ckpt, map_location='cpu')['state_dict'])
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    val_dataset = ECGDataset(data_path=config.train_data, train=False)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=4)
    for threshold in list_threhold:
        val_loss, val_f1, val_acc, val_recall, val_precision = val_epoch(
            model, criterion, val_dataloader, threshold)
        print(
            'threshold %.2f val_loss:%0.6e val_f1:%.8f val_acc:%.8f '
            'val_recall:%.8f val_precision:%.8f\n' %
            (threshold, val_loss, val_f1, val_acc, val_recall, val_precision))
Example #5
def val(args):
    list_threhold = [0.5]
    model = getattr(models, config.model_name)()
    if args.ckpt:
        model.load_state_dict(
            torch.load(args.ckpt, map_location='cpu')['state_dict'])
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    val_dataset = ECGDataset(data_path=config.train_data, train=False)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=4)
    for threshold in list_threhold:
        val_loss, val_f1 = val_epoch(model, criterion, val_dataloader,
                                     threshold)
        print('threshold %.2f val_loss:%0.3e val_f1:%.3f\n' %
              (threshold, val_loss, val_f1))
Example #6
def test():
    test_dataset = ECGDataset(data_path=config.test_data, mode='test')
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=config.batch_size,
                                 num_workers=6)
    # model
    model = getattr(resnet, config.model_name)(num_classes=config.num_classes,
                                               input_dim=config.input_dim)
    model = model.to(device)
    ckpt_dir = os.path.join(config.ckpt, os.listdir(config.ckpt)[0])
    # ckpts = [os.path.join(ckpt_dir, 'e%i'%i) for i in range(25, 30)]
    ckpts = [
        os.path.join('../user_data/ckpt/resnet34_201910101016', 'e%i' % i)
        for i in range(31, 35)
    ]

    # make prediction
    preds_ckpts = []
    for ckpt in ckpts:
        model.load_state_dict(torch.load(ckpt)['state_dict'])
        model.eval()
        outputs = []
        with torch.no_grad():
            for inputs, extra_info in test_dataloader:
                inputs = inputs.to(device)
                extra_info = extra_info.to(device)
                output = model(inputs, extra_info)
                outputs.append(output.detach().cpu().numpy())
        outputs = np.concatenate(outputs)
        preds = output2pred(outputs)
        preds_ckpts.append(preds)
    preds = np.median(np.array(preds_ckpts), axis=0)
    preds_dict = {id_[1]: preds[i] for i, id_ in enumerate(test_dataset.data)}

    # make submission
    sub_file = os.path.join(config.sub_dir, 'result.txt')
    with open(sub_file, 'w', encoding='utf-8') as fout:
        for line in open(config.test_label, encoding='utf-8'):
            fout.write(line.strip('\n'))
            id_ = line.split('\t')[0]
            pred = preds_dict[id_]
            ixs = [i for i, out in enumerate(pred) if out == 1]
            for i in ixs:
                fout.write("\t" + test_dataset.idx2name[i])
            fout.write('\n')
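
Note: Example #6 binarizes the raw model outputs with an output2pred helper before taking the per-checkpoint median. A plausible minimal version is sketched below; the sigmoid and the 0.5 default threshold are assumptions, not this project's confirmed code.

import numpy as np

def output2pred(outputs, threshold=0.5):
    # map raw logits to {0, 1} multilabel predictions: sigmoid, then threshold
    probs = 1.0 / (1.0 + np.exp(-outputs))
    return (probs > threshold).astype(np.int64)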
Example #7
def deep_predict():
    from torch import nn
    from torch.utils.data import DataLoader
    from dataset import ECGDataset
    list_threhold = [0.5]
    model = getattr(models, config.model_name)()
    checkpoint = torch.load(os.path.join(r'ckpt\resnet34', config.best_w),
                            map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    val_dataset = ECGDataset(data_path=config.train_data, train=False)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=6)
    for threshold in list_threhold:
        output_all = val_epoch(model, criterion, val_dataloader, threshold)
        # print('threshold %.2f val_loss:%0.3e val_f1:%.3f\n' % (threshold, val_loss, val_f1))
    return output_all.cpu().numpy()
Example #8
def val(mode, ckpt):
    model = getattr(resnet, config.model_name)(num_classes=config.num_classes,
                                               input_dim=config.input_dim)
    model.load_state_dict(torch.load(ckpt, map_location='cpu')['state_dict'])
    model = model.to(device)
    val_dataset = ECGDataset(data_path=config.train_data, mode=mode)
    groups = config.groups
    count = val_dataset.count
    criterion = utils.WeightedMultilabel(groups, count, device)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=6)
    val_loss, val_p, val_r, val_f1, pr_df = val_epoch(model, criterion,
                                                      val_dataloader, False)
    print('val_loss:%0.3e val_precision:%.4f val_recall:%.4f val_f1:%.4f\n' % (
        val_loss,
        val_p,
        val_r,
        val_f1,
    ))
    pr_df['arry'] = pr_df['arry'].map(val_dataset.idx2name)
    pr_df.to_csv('../user_data/%s_f1.csv' % mode, encoding='gbk')
    display.display(pr_df)
Example #9
File: main.py Project: hitachinsk/kkp
def val(args):
    list_threhold = [0.5]  # more threshold values could be added to test
    if config.fuse == 'False':
        model = getattr(models, config.model_name)()
    elif config.fuse == 'True':
        model = ResMlp(ResMlpParams)
    else:
        raise ValueError(
            'Not supported type of fuse item in train initialization phase!')
    if args.ckpt:
        model.load_state_dict(
            torch.load(args.ckpt, map_location='cpu')['state_dict'])
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss()
    val_dataset = ECGDataset(data_path=config.train_data, train=False)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=4)
    for threshold in list_threhold:
        val_loss, val_f1 = val_epoch(model, criterion, val_dataloader,
                                     threshold)
        print('threshold %.2f val_loss:%0.3e val_f1:%.3f\n' %
              (threshold, val_loss, val_f1))  # presumably different threshold values were tried here
Example #10
File: main.py Project: hitachinsk/kkp
def train(args):
    # model
    if config.fuse == 'False':
        model = getattr(models, config.model_name)()
    elif config.fuse == 'True':
        model = ResMlp(ResMlpParams)
    else:
        raise ValueError(
            'Not supported type of fuse item in train initialization phase!')
    if args.ckpt and not args.resume:
        state = torch.load(args.ckpt, map_location='cpu')
        model.load_state_dict(state['state_dict'])
        print('train with pretrained weight val_f1', state['f1'])
    model = model.to(device)
    # data
    train_dataset = ECGDataset(data_path=config.train_data, train=True)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=4)
    val_dataset = ECGDataset(data_path=config.train_data, train=False)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=2)
    print("train_datasize", len(train_dataset), "val_datasize",
          len(val_dataset))
    # optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    w = torch.tensor(train_dataset.wc, dtype=torch.float).to(device)
    criterion = utils.WeightedMultilabel(w)
    # directory for saving model checkpoints
    model_save_dir = '%s/%s_%s' % (args.output, config.model_name,
                                   time.strftime("%Y%m%d%H%M"))
    utils.mkdirs(model_save_dir)
    if args.ex: model_save_dir += args.ex
    best_f1 = -1
    lr = config.lr
    start_epoch = 1
    stage = 1
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    logdir = os.path.join(args.output, 'logs',
                          current_time + '_' + config.fuse)
    writer = SummaryWriter(logdir)
    # resume training from the last checkpoint
    if args.resume:
        if os.path.exists(args.ckpt):  # args.ckpt is the directory holding the weights
            current_w = torch.load(os.path.join(args.ckpt, config.current_w))
            best_w = torch.load(os.path.join(args.ckpt, config.best_w))
            best_f1 = best_w['f1']  # resume the best F1 so far
            start_epoch = current_w['epoch'] + 1
            lr = current_w['lr']
            stage = current_w['stage']
            model.load_state_dict(current_w['state_dict'])
            # if the interruption happened exactly at a stage-transition epoch
            if start_epoch - 1 in config.stage_epoch:
                stage += 1
                lr /= config.lr_decay
                utils.adjust_learning_rate(optimizer, lr)
                model.load_state_dict(best_w['state_dict'])  # after a resume, restart from the best model
            print("=> loaded checkpoint (epoch {})".format(start_epoch - 1))
    # =========> start training <=========
    for epoch in range(start_epoch, config.max_epoch + 1):
        since = time.time()
        train_loss, train_f1 = train_epoch(model,
                                           optimizer,
                                           criterion,
                                           train_dataloader,
                                           show_interval=100)
        val_loss, val_f1 = val_epoch(model, criterion, val_dataloader)
        print(
            '#epoch:%02d stage:%d train_loss:%.3e train_f1:%.3f  val_loss:%0.3e val_f1:%.3f time:%s\n'
            % (epoch, stage, train_loss, train_f1, val_loss, val_f1,
               utils.print_time_cost(since)))
        writer.add_scalar('scalar/train_loss', train_loss,
                          epoch)  # coarse view; could log per iteration instead
        writer.add_scalar('scalar/train_f1', train_f1, epoch)
        writer.add_scalar('scalar/val_loss', val_loss, epoch)
        writer.add_scalar('scalar/val_f1', val_f1, epoch)
        state = {
            "state_dict": model.state_dict(),
            "epoch": epoch,
            "loss": val_loss,
            'f1': val_f1,
            'lr': lr,
            'stage': stage
        }
        save_ckpt(state, best_f1 < val_f1, model_save_dir)
        best_f1 = max(best_f1, val_f1)
        if epoch in config.stage_epoch:
            stage += 1
            lr /= config.lr_decay
            # before entering each stage, reload the best model from the
            # previous stage (a greedy approach; is that really right?)
            best_w = os.path.join(model_save_dir, config.best_w)
            model.load_state_dict(torch.load(best_w)['state_dict'])
            print("*" * 10, "step into stage%02d lr %.3ef" % (stage, lr))
            utils.adjust_learning_rate(optimizer, lr)
    writer.close()
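
Note: the training loops in Examples #10-#12 call save_ckpt and utils.adjust_learning_rate without showing them. The sketch below is consistent with how they are used above; the checkpoint file names are assumptions standing in for config.current_w and config.best_w.

import os
import shutil
import torch

def save_ckpt(state, is_best, model_save_dir):
    # always write the latest weights; copy them to the best-weights file
    # whenever the validation F1 improved
    current_w = os.path.join(model_save_dir, 'current_w.pth')
    best_w = os.path.join(model_save_dir, 'best_w.pth')
    torch.save(state, current_w)
    if is_best:
        shutil.copyfile(current_w, best_w)

def adjust_learning_rate(optimizer, lr):
    # set the same learning rate on every parameter group
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr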
Example #11
File: main.py Project: ycd2016/HFECG
def train(args):
    model = models.myecgnet()
    if args.ckpt and not args.resume:
        state = torch.load(args.ckpt, map_location='cpu')
        model.load_state_dict(state['state_dict'])
        print('train with pretrained weight val_f1', state['f1'])
    model = model.to(device)
    train_dataset = ECGDataset(data_path=config.train_data, train=True)
    train_dataloader = DataLoader(train_dataset,
                                  collate_fn=my_collate_fn,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=8)
    val_dataset = ECGDataset(data_path=config.train_data, train=False)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=8)
    print("train_datasize", len(train_dataset), "val_datasize",
          len(val_dataset))
    optimizer = AdamW(model.parameters(), lr=config.lr)
    w = torch.tensor(train_dataset.wc, dtype=torch.float).to(device)
    criterion = utils.WeightedMultilabel(w)
    model_save_dir = '%s/%s_%s' % (config.ckpt, config.model_name,
                                   time.strftime("%Y%m%d%H%M"))
    os.mkdir(model_save_dir)
    if args.ex: model_save_dir += args.ex
    best_f1 = -1
    lr = config.lr
    start_epoch = 1
    stage = 1
    if args.resume:
        if os.path.exists(args.ckpt):
            model_save_dir = args.ckpt
            current_w = torch.load(os.path.join(args.ckpt, config.current_w))
            best_w = torch.load(os.path.join(model_save_dir, config.best_w))
            best_f1 = best_w['f1']  # resume the best F1 so far
            start_epoch = current_w['epoch'] + 1
            lr = current_w['lr']
            stage = current_w['stage']
            model.load_state_dict(current_w['state_dict'])
            if start_epoch - 1 in config.stage_epoch:
                stage += 1
                lr /= config.lr_decay
                utils.adjust_learning_rate(optimizer, lr)
                model.load_state_dict(best_w['state_dict'])
            print("=> loaded checkpoint (epoch {})".format(start_epoch - 1))
    for epoch in range(start_epoch, config.max_epoch + 1):
        since = time.time()
        train_loss, train_f1 = train_epoch(model,
                                           optimizer,
                                           criterion,
                                           train_dataloader,
                                           show_interval=10)
        val_loss, val_f1 = val_epoch(model, criterion, val_dataloader)
        print(
            '#epoch:%03d\tstage:%d\ttrain_loss:%.4f\ttrain_f1:%.3f\tval_loss:%0.4f\tval_f1:%.3f\ttime:%s\n'
            % (epoch, stage, train_loss, train_f1, val_loss, val_f1,
               utils.print_time_cost(since)))
        state = {
            "state_dict": model.state_dict(),
            "epoch": epoch,
            "loss": val_loss,
            'f1': val_f1,
            'lr': lr,
            'stage': stage
        }
        save_ckpt(state, best_f1 < val_f1, model_save_dir)
        best_f1 = max(best_f1, val_f1)
        if epoch in config.stage_epoch:
            stage += 1
            lr /= config.lr_decay
            best_w = os.path.join(model_save_dir, config.best_w)
            model.load_state_dict(torch.load(best_w)['state_dict'])
            print("*" * 10, "step into stage%02d lr %.3ef" % (stage, lr))
            utils.adjust_learning_rate(optimizer, lr)
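
Note: Example #11 passes a custom my_collate_fn to the training DataLoader. Its body is not shown; a common reason for a custom collate in ECG pipelines is batching variable-length records. The sketch below crops every record in the batch to a shared length with a random window; this is a hypothetical illustration, not this project's confirmed behavior.

import random
import torch

def my_collate_fn(batch):
    # crop each (signal, label) pair to the length of the shortest signal,
    # taking a random window so training sees varied segments
    min_len = min(signal.shape[-1] for signal, _ in batch)
    inputs, targets = [], []
    for signal, label in batch:
        start = random.randint(0, signal.shape[-1] - min_len)
        inputs.append(signal[..., start:start + min_len])
        targets.append(label)
    return torch.stack(inputs), torch.stack(targets)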
Example #12
File: main.py Project: yyyu200/ecg_pytorch
def train(args):
    # model
    model = getattr(models, config.model_name)()
    if args.ckpt and not args.resume:
        state = torch.load(args.ckpt, map_location='cpu')
        model.load_state_dict(state['state_dict'])
        print('train with pretrained weight val_f1', state['f1'])
    model = model.to(device)
    # data
    train_dataset = ECGDataset(data_path=config.train_data, train=True)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=6)
    val_dataset = ECGDataset(data_path=config.train_data, train=False)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=4)
    print("train_datasize", len(train_dataset), "val_datasize",
          len(val_dataset))
    # optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    w = torch.tensor(train_dataset.wc, dtype=torch.float).to(device)
    criterion = utils.WeightedMultilabel(w)
    # directory for saving model checkpoints
    model_save_dir = '%s/%s_%s' % (config.ckpt, config.model_name,
                                   time.strftime("%Y%m%d%H%M"))
    if args.ex: model_save_dir += args.ex
    best_f1 = -1
    lr = config.lr
    start_epoch = 1
    stage = 1
    # resume training from the last checkpoint
    if args.resume:
        if os.path.exists(args.ckpt):  # args.ckpt is the directory holding the weights
            model_save_dir = args.ckpt
            current_w = torch.load(os.path.join(args.ckpt, config.current_w))
            best_w = torch.load(os.path.join(model_save_dir, config.best_w))
            best_f1 = best_w['f1']  # resume the best F1 so far
            start_epoch = current_w['epoch'] + 1
            lr = current_w['lr']
            stage = current_w['stage']
            model.load_state_dict(current_w['state_dict'])
            # if the interruption happened exactly at a stage-transition epoch
            if start_epoch - 1 in config.stage_epoch:
                stage += 1
                lr /= config.lr_decay
                utils.adjust_learning_rate(optimizer, lr)
                model.load_state_dict(best_w['state_dict'])
            print("=> loaded checkpoint (epoch {})".format(start_epoch - 1))
    logger = Logger(logdir=model_save_dir, flush_secs=2)
    # =========> start training <=========
    for epoch in range(start_epoch, config.max_epoch + 1):
        since = time.time()
        train_loss, train_f1 = train_epoch(model,
                                           optimizer,
                                           criterion,
                                           train_dataloader,
                                           show_interval=100)
        val_loss, val_f1 = val_epoch(model, criterion, val_dataloader)
        print(
            '#epoch:%02d stage:%d train_loss:%.3e train_f1:%.3f  val_loss:%0.3e val_f1:%.3f time:%s\n'
            % (epoch, stage, train_loss, train_f1, val_loss, val_f1,
               utils.print_time_cost(since)))
        logger.log_value('train_loss', train_loss, step=epoch)
        logger.log_value('train_f1', train_f1, step=epoch)
        logger.log_value('val_loss', val_loss, step=epoch)
        logger.log_value('val_f1', val_f1, step=epoch)
        state = {
            "state_dict": model.state_dict(),
            "epoch": epoch,
            "loss": val_loss,
            'f1': val_f1,
            'lr': lr,
            'stage': stage
        }
        save_ckpt(state, best_f1 < val_f1, model_save_dir)
        best_f1 = max(best_f1, val_f1)
        if epoch in config.stage_epoch:
            stage += 1
            lr /= config.lr_decay
            best_w = os.path.join(model_save_dir, config.best_w)
            model.load_state_dict(torch.load(best_w)['state_dict'])
            print("*" * 10, "step into stage%02d lr %.3ef" % (stage, lr))
            utils.adjust_learning_rate(optimizer, lr)
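
Note: utils.WeightedMultilabel appears throughout these training loops, built from a per-class weight tensor (train_dataset.wc). Since utils.py is not shown, the sketch below is an assumed implementation matching that call pattern: per-class BCE-with-logits scaled by the class weights. (Example #15 uses a different variant taking groups and counts, which is not covered here.)

import torch
from torch import nn

class WeightedMultilabel(nn.Module):
    # BCE-with-logits computed per class, then weighted by a per-class tensor
    def __init__(self, weights):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss(reduction='none')
        self.weights = weights

    def forward(self, outputs, targets):
        # average over the batch per class, then weighted sum over classes
        per_class = self.loss(outputs, targets).mean(dim=0)
        return (per_class * self.weights).sum()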
Example #13
    else:
        leads = args.leads.split(',')
        nleads = len(leads)
    data_dir = args.data_dir
    label_csv = os.path.join(data_dir, 'labels.csv')

    if args.biGRU == 1:
        net = resnet34_GRU(input_channels=nleads).to(device)
    else:
        net = resnet34(input_channels=nleads).to(device)

    net.load_state_dict(torch.load(args.model_path, map_location=device))
    net.eval()

    train_folds, val_folds, test_folds = split_data(seed=args.seed)
    train_dataset = ECGDataset('train', data_dir, label_csv, train_folds,
                               leads, args.downsamp_rate)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)
    val_dataset = ECGDataset('val', data_dir, label_csv, val_folds, leads,
                             args.downsamp_rate)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers,
                            pin_memory=True)
    test_dataset = ECGDataset('test', data_dir, label_csv, test_folds, leads,
                              args.downsamp_rate)
    test_loader = DataLoader(test_dataset,
Example #14
def train(mode='train', ckpt=None, resume=False):
    # model
    model = getattr(resnet, config.model_name)(input_dim=config.input_dim)
    if ckpt is not None and not resume:
        state = torch.load(ckpt, map_location='cpu')
        model.load_state_dict(state['state_dict'])
        print('train with pretrained weight val_f1', state['f1'])
    model = model.to(device)
    # data
    train_dataset = ECGDataset(data_path=config.train_data, mode=mode)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=6)
    val_dataset = ECGDataset(data_path=config.train_data, mode='val')
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=6)
    print("train_datasize", len(train_dataset), "val_datasize",
          len(val_dataset))
    # optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    weights = torch.tensor(train_dataset.wc, dtype=torch.float).to(device)
    criterion = torch.nn.BCEWithLogitsLoss(weights)
    # criterion = torch.nn.BCEWithLogitsLoss()
    # directory for saving model checkpoints
    model_save_dir = '%s/%s_%s' % (config.ckpt, config.model_name,
                                   time.strftime("%Y%m%d%H%M"))
    best_f1 = -1
    lr = config.lr
    start_epoch = 1
    stage = 1
    # resume training from the last checkpoint
    if resume:
        if os.path.exists(ckpt):  # ckpt is the directory holding the weights
            model_save_dir = ckpt
            current_w = torch.load(os.path.join(ckpt, config.current_w))
            best_w = torch.load(os.path.join(model_save_dir, config.best_w))
            best_f1 = best_w['f1']  # resume the best F1 so far
            start_epoch = current_w['epoch'] + 1
            lr = current_w['lr']
            stage = current_w['stage']
            model.load_state_dict(current_w['state_dict'])
            # if the interruption happened exactly at a stage-transition epoch
            if start_epoch - 1 in config.stage_epoch:
                stage += 1
                lr /= config.lr_decay
                utils.adjust_learning_rate(optimizer, lr)
                model.load_state_dict(best_w['state_dict'])
            print("=> loaded checkpoint (epoch {})".format(start_epoch - 1))
    # exist_ok avoids a crash when resuming into an existing directory
    os.makedirs(model_save_dir, exist_ok=True)
    # =========> start training <=========
    for epoch in range(start_epoch, config.max_epoch + 1):
        since = time.time()
        train_loss, train_p, train_r, train_f1 = train_epoch(model,
                                                             optimizer,
                                                             criterion,
                                                             train_dataloader,
                                                             show_interval=50)
        val_loss, val_p, val_r, val_f1 = val_epoch(model, criterion,
                                                   val_dataloader)
        print('#epoch:%02d stage:%d time:%s' %
              (epoch, stage, utils.print_time_cost(since)))
        print(
            'train_loss:%.3e train_precision:%.4f train_recall:%.4f train_f1:%.4f'
            % (train_loss, train_p, train_r, train_f1))
        print(
            'val_loss:%.3e val_precision:%.4f val_recall:%.4f val_f1:%.4f \n' %
            (val_loss, val_p, val_r, val_f1))
        state = {
            "state_dict": model.state_dict(),
            "epoch": epoch,
            "loss": val_loss,
            'f1': val_f1,
            'lr': lr,
            'stage': stage
        }
        torch.save(state, os.path.join(model_save_dir, 'e%i' % (epoch)))
        best_f1 = max(best_f1, val_f1)

        if epoch in config.stage_epoch:
            stage += 1
            lr /= config.lr_decay
            print("*" * 10, "step into stage%02d lr %.3ef" % (stage, lr))
            utils.adjust_learning_rate(optimizer, lr)
Example #15
def train(mode='train', ckpt=None, resume=False):
    # model
    model = getattr(resnet, config.model_name)(num_classes=config.num_classes,
                                               input_dim=config.input_dim)
    if ckpt is not None and not resume:
        state = torch.load(ckpt, map_location='cpu')
        model.load_state_dict(state['state_dict'])
        print('train with pretrained weight val_f1', state['f1'])
    model = model.to(device)
    # data
    train_dataset = ECGDataset(data_path=config.train_data, mode=mode)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=6)
    val_dataset = ECGDataset(data_path=config.train_data, mode='val')
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=6)
    print("train_datasize", len(train_dataset), "val_datasize",
          len(val_dataset))
    # optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    groups = config.groups
    count = train_dataset.count
    criterion = utils.WeightedMultilabel(groups, count, device)
    # directory for saving model checkpoints
    model_save_dir = '%s/%s_%s' % (config.ckpt, config.model_name,
                                   time.strftime("%Y%m%d%H%M"))
    if not os.path.exists(config.ckpt):
        os.mkdir(config.ckpt)
    os.mkdir(model_save_dir)
    best_f1 = -1
    lr = config.lr
    start_epoch = 1
    stage = 1
    # resume training from the last checkpoint
    if resume:
        if os.path.exists(ckpt):  # ckpt is the path to the saved weights
            current_w = torch.load(ckpt)
            start_epoch = current_w['epoch'] + 1
            lr = current_w['lr']
            stage = current_w['stage']
            model.load_state_dict(current_w['state_dict'])
            print("=> loaded checkpoint (epoch {})".format(start_epoch - 1))
    # logger = Logger(logdir=model_save_dir, flush_secs=2)
    # =========> start training <=========
    val_loss, val_p, val_r, val_f1 = val_epoch(model, criterion,
                                               val_dataloader)
    print('start training')
    print('val_loss:%.3e val_precision:%.4f val_recall:%.4f val_f1:%.4f \n' %
          (val_loss, val_p, val_r, val_f1))
    for epoch in range(start_epoch, config.max_epoch + 1):
        since = time.time()
        train_loss, train_p, train_r, train_f1 = train_epoch(
            model,
            optimizer,
            criterion,
            train_dataloader,
            show_interval=config.show_interval)
        val_loss, val_p, val_r, val_f1, pr_df = val_epoch(model,
                                                          criterion,
                                                          val_dataloader,
                                                          simple_mode=False)
        pr_df['arry'] = pr_df['arry'].map(val_dataset.idx2name)
        print('#epoch:%02d stage:%d time:%s' %
              (epoch, stage, utils.print_time_cost(since)))
        print(
            'train_loss:%.3e train_precision:%.4f train_recall:%.4f train_f1:%.4f'
            % (train_loss, train_p, train_r, train_f1))
        print(
            'val_loss:%.3e val_precision:%.4f val_recall:%.4f val_f1:%.4f \n' %
            (val_loss, val_p, val_r, val_f1))
        display.display(pr_df)
        # logger.log_value('train_loss', train_loss, step=epoch)
        # logger.log_value('train_f1', train_f1, step=epoch)
        # logger.log_value('val_loss', val_loss, step=epoch)
        # logger.log_value('val_f1', val_f1, step=epoch)
        state = {
            "state_dict": model.state_dict(),
            "epoch": epoch,
            "loss": val_loss,
            'f1': val_f1,
            'lr': lr,
            'stage': stage
        }
        torch.save(state, os.path.join(model_save_dir, 'e%i' % (epoch)))
        best_f1 = max(best_f1, val_f1)

        if epoch in config.stage_epoch:
            stage += 1
            lr /= config.lr_decay
            print("*" * 10, "step into stage%02d lr %.3ef" % (stage, lr))
            utils.adjust_learning_rate(optimizer, lr)
Example #16
    if args.use_gpu and torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = 'cpu'

    if args.leads == 'all':
        leads = 'all'
        nleads = 12
    else:
        leads = args.leads.split(',')
        nleads = len(leads)

    label_csv = os.path.join(data_dir, 'labels.csv')

    train_folds, val_folds, test_folds = split_data(seed=args.seed)
    train_dataset = ECGDataset('train', data_dir, label_csv, train_folds,
                               leads)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              pin_memory=True)
    val_dataset = ECGDataset('val', data_dir, label_csv, val_folds, leads)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.num_workers,
                            pin_memory=True)
    test_dataset = ECGDataset('test', data_dir, label_csv, test_folds, leads)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
Example #17
def top4_catboost_train(args):
    print('top4_catboost_train begin')
    if config.top4_DeepNN:
        model = getattr(models,
                        config.model_name)(num_classes=config.num_classes,
                                           channel_size=config.channel_size)
        model.load_state_dict(
            torch.load(args.ckpt, map_location='cpu')['state_dict'])
        model = model.to(device)
        print(config.model_name, args.ckpt)
    else:
        model = None
        print('no', config.model_name)
    val_dataset = ECGDataset(data_path=config.train_data, train=False)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=4)
    train_dataset = ECGDataset(data_path=config.train_data, train=True)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.batch_size,
                                  num_workers=4)
    train_df = top4_make_dateset(model,
                                 train_dataloader,
                                 config.top4_data_train,
                                 save=False)
    X_train, y_train = top4_getXy(train_df)
    val_df = top4_make_dateset(model,
                               val_dataloader,
                               config.top4_data_val,
                               save=False)
    X_validation, y_validation = top4_getXy(val_df)
    print('train_catboost finish dataset')

    model_list = [None] * y_train.shape[1]
    for tagi, model in enumerate(model_list):
        if model is None:
            model = CatBoostClassifier(
                iterations=1000,
                random_seed=42,
                eval_metric='F1',
                learning_rate=0.03,  # TODO: tune this hyperparameter
                task_type={
                    'cuda': 'GPU',
                    'cpu': 'CPU'
                }[device.type],
                od_type='Iter',  # early stopping
                od_wait=40)
            model_list[tagi] = model

        model.fit(
            X_train,
            y_train.iloc[:, tagi],
            cat_features=config.top4_cat_features,
            eval_set=(X_validation, y_validation.iloc[:, tagi]),
            verbose=False,  # suppress per-iteration logging
            plot=False  # no training plots
        )
    y_pred_train = model_list_predict(model_list, X_train)
    # train_f1 = utils.calc_f1(y_train.values, y_pred_train)
    train_f1 = f1_score(y_train.values, y_pred_train, average='micro')
    y_pred = model_list_predict(model_list, X_validation)
    # val_f1 = utils.calc_f1(y_validation.values, y_pred)
    val_f1 = f1_score(y_validation.values, y_pred, average='micro')
    save_model_list(model_list,
                    os.path.join(config.ckpt, config.top4_catboost_model))
    print('catboost train_f1:%.3f\tval_f1:%.3f\n' % (train_f1, val_f1))
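
Note: Example #17 trains one CatBoostClassifier per label and then relies on model_list_predict and save_model_list, neither of which is shown. Minimal sketches consistent with the usage above follow; the column stacking and the one-file-per-label layout are assumptions.

import os
import numpy as np

def model_list_predict(model_list, X):
    # predict each label with its own binary classifier; stack as columns
    preds = [model.predict(X).astype(int) for model in model_list]
    return np.stack(preds, axis=1)

def save_model_list(model_list, path):
    # CatBoost models save natively; write one file per label index
    os.makedirs(path, exist_ok=True)
    for i, model in enumerate(model_list):
        model.save_model(os.path.join(path, 'catboost_%d.cbm' % i))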
Example #18
def transfer_train(args):
    print(args.model_name)
    config.train_data = config.train_data + 'trainsfer.pth'
    config.model_name = args.model_name
    model = getattr(models, config.model_name)()
    model = model.to(device)
    import dataset2
    train_dataset = dataset2.ECGDataset(data_path=config.train_data,
                                        train=True,
                                        transfer=True,
                                        transform=True)
    train_dataloader = DataLoader(train_dataset,
                                  collate_fn=my_collate_fn,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=6)
    val_dataset = ECGDataset(data_path=config.train_data,
                             train=False,
                             transfer=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=6)
    print("train_datasize", len(train_dataset), "val_datasize",
          len(val_dataset))
    # optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    # optimizer = optim.RMSprop(model.parameters(), lr=config.lr)
    w = torch.tensor(train_dataset.wc, dtype=torch.float).to(device)
    criterion = utils.WeightedMultilabel2(w)
    #    criterion = utils.My_loss(w)
    # directory for saving model checkpoints
    model_save_dir = '%s/%s' % (config.ckpt, config.model_name + '_transfer')
    args.ckpt = model_save_dir
    # if args.ex: model_save_dir += args.ex
    best_f1 = -1
    lr = 3e-4
    start_epoch = 1
    stage = 1
    # resume training from the last checkpoint
    if not os.path.exists(model_save_dir):
        os.mkdir(model_save_dir)

    if args.resume:
        if os.path.exists(args.ckpt):  # args.ckpt is the directory holding the weights
            # model_save_dir = args.ckpt
            current_w = torch.load(os.path.join(args.ckpt, config.best_w))
            best_w = torch.load(os.path.join(model_save_dir, config.best_w))
            best_f1 = best_w['best_f']
            start_epoch = current_w['epoch'] + 1
            lr = current_w['lr']
            stage = current_w['stage']
            model.load_state_dict(current_w['state_dict'])
            # if the interruption happened exactly at a stage-transition epoch
            if start_epoch - 1 in config.stage_epoch:
                stage += 1
                lr /= config.lr_decay
                utils.adjust_learning_rate(optimizer, lr)
                model.load_state_dict(best_w['state_dict'])
            print("=> loaded checkpoint (epoch {})".format(start_epoch - 1))
    # =========> start training <=========
    val_loss = 10
    val_f1 = -1
    state = {}
    for epoch in range(start_epoch, 25 + 1):
        since = time.time()
        train_loss, train_f1, best_f1 = train_epoch(
            model, optimizer, criterion, train_dataloader, epoch, lr, best_f1,
            val_dataloader, model_save_dir, state, 50)
        # if epoch % 2 == 1:
        val_loss, val_f1, _, _ = val_epoch(model, criterion, val_dataloader)
        print(
            '#epoch:%02d stage:%d train_loss:%.3e train_f1:%.3f  val_loss:%0.3e val_f1:%.3f time:%s'
            % (epoch, stage, train_loss, train_f1, val_loss, val_f1,
               utils.print_time_cost(since)))
        state = {
            "state_dict": model.state_dict(),
            "epoch": epoch,
            "loss": val_loss,
            'f1': val_f1,
            'lr': lr,
            'stage': stage,
            "best_f": val_f1
        }
        if best_f1 < val_f1:
            save_ckpt(state, best_f1 < val_f1, model_save_dir)
            print('save best')
        else:
            save_ckpt(state, False, model_save_dir)
        best_f1 = max(best_f1, val_f1)

        if epoch in config.stage_epoch:
            stage += 1
            lr /= config.lr_decay
            print("*" * 10, "step into stage%02d lr %.3ef" % (stage, lr))
            utils.adjust_learning_rate(optimizer, lr)
Example #19
def train(input_directory, output_directory):
    # model
    model = getattr(models, config.model_name)()

    # if args.ckpt and not args.resume:
    #     state = torch.load(args.ckpt, map_location='cpu')
    #     model.load_state_dict(state['state_dict'])
    #     print('train with pretrained weight val_f1', state['f1'])

    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, config.num_classes)

    model = model.to(device)
    # data
    train_dataset = ECGDataset(data_path=config.train_data,
                               data_dir=input_directory,
                               train=True)
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=6)
    val_dataset = ECGDataset(data_path=config.train_data,
                             data_dir=input_directory,
                             train=False)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=config.batch_size,
                                num_workers=4)

    print("train_datasize", len(train_dataset), "val_datasize",
          len(val_dataset))
    # optimizer and loss
    #optimizer = optim.Adam(model.parameters(), lr=config.lr)
    optimizer = radam.RAdam(model.parameters(),
                            lr=config.lr,
                            weight_decay=1e-4)  #config.lr
    #optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, dampening=0, weight_decay=1e-4, nesterov=False)
    w = torch.tensor(train_dataset.wc, dtype=torch.float).to(device)
    criterion = utils.WeightedMultilabel(w)  # alternative: utils.FocalLoss()

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        'max',
        verbose=True,
        factor=0.1,
        patience=5,
        min_lr=1e-06,
        eps=1e-08)  #CosineAnnealingLR  CosineAnnealingWithRestartsLR
    #scheduler = pytorchtools.CosineAnnealingWithRestartsLR(optimizer,T_max=30, T_mult = 1.2, eta_min=1e-6)

    # optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, nesterov=True)
    # scheduler = pytorchtools.CosineAnnealingLR_with_Restart(optimizer, T_max=12, T_mult=1, model=model, out_dir='./snapshot',take_snapshot=True, eta_min=1e-9)

    # directory for saving model checkpoints
    model_save_dir = '%s/%s_%s' % (config.ckpt, config.model_name,
                                   time.strftime("%Y%m%d%H%M"))

    # if args.ex: model_save_dir += args.ex

    best_f1 = -1
    best_cm = -1
    epoch_cum = 0  # epochs since the last val_cm improvement
    lr = config.lr
    start_epoch = 1
    stage = 1

    # resume training from the last checkpoint
    # if args.resume:
    #     if os.path.exists(args.ckpt):  # args.ckpt is the directory holding the weights
    #         model_save_dir = args.ckpt
    #         current_w = torch.load(os.path.join(args.ckpt, config.current_w))
    #         best_w = torch.load(os.path.join(model_save_dir, config.best_w))
    #         best_f1 = best_w['loss']
    #         start_epoch = current_w['epoch'] + 1
    #         lr = current_w['lr']
    #         stage = current_w['stage']
    #         model.load_state_dict(current_w['state_dict'])
    #         # if the interruption happened exactly at a stage-transition epoch
    #         if start_epoch - 1 in config.stage_epoch:
    #             stage += 1
    #             lr /= config.lr_decay
    #             utils.adjust_learning_rate(optimizer, lr)
    #             model.load_state_dict(best_w['state_dict'])
    #         print("=> loaded checkpoint (epoch {})".format(start_epoch - 1))

    logger = Logger(logdir=model_save_dir, flush_secs=2)
    # =========> start training <=========
    for epoch in range(start_epoch, config.max_epoch + 1):
        since = time.time()
        train_loss, train_acc, train_f1, train_f2, train_g2, train_cm = train_epoch(
            model, optimizer, criterion, train_dataloader, show_interval=100)
        val_loss, val_acc, val_f1, val_f2, val_g2, val_cm = val_epoch(
            model, criterion, val_dataloader)

        # train_loss, train_f1 = train_beat_epoch(model, optimizer, criterion, train_dataloader, show_interval=100)
        # val_loss, val_f1 = val_beat_epoch(model, criterion, val_dataloader)

        print(
            '#epoch:%02d, stage:%d, train_loss:%.3e, train_acc:%.3f, '
            'train_f1:%.3f, train_f2:%.3f, train_g2:%.3f, train_cm:%.3f,\n'
            'val_loss:%0.3e, val_acc:%.3f, val_f1:%.3f, val_f2:%.3f, '
            'val_g2:%.3f, val_cm:%.3f, time:%s\n' %
            (epoch, stage, train_loss, train_acc, train_f1, train_f2,
             train_g2, train_cm, val_loss, val_acc, val_f1, val_f2, val_g2,
             val_cm, utils.print_time_cost(since)))

        logger.log_value('train_loss', train_loss, step=epoch)
        logger.log_value('train_f1', train_f1, step=epoch)
        logger.log_value('val_loss', val_loss, step=epoch)
        logger.log_value('val_f1', val_f1, step=epoch)
        state = {
            "state_dict": model.state_dict(),
            "epoch": epoch,
            "loss": val_loss,
            'f1': val_f1,
            'lr': lr,
            'stage': stage
        }

        save_ckpt(state, best_cm < val_cm, model_save_dir, output_directory)
        best_cm = max(best_cm, val_cm)

        scheduler.step(val_cm)
        # scheduler.step()

        if val_cm < best_cm:
            epoch_cum += 1
        else:
            epoch_cum = 0


#         # if epoch in config.stage_epoch:
#         if epoch_cum == 5:
#             stage += 1
#             lr /= config.lr_decay
#             if lr < 1e-6:
#                 lr = 1e-6
#                 print("*" * 20, "step into stage%02d lr %.3ef" % (stage, lr))
#             best_w = os.path.join(model_save_dir, config.best_w)
#             model.load_state_dict(torch.load(best_w)['state_dict'])
#             print("*" * 10, "step into stage%02d lr %.3ef" % (stage, lr))
#             utils.adjust_learning_rate(optimizer, lr)

#         elif epoch_cum >= 12:
#             print("*" * 20, "step into stage%02d lr %.3ef" % (stage, lr))
#             break

        if epoch_cum >= 12:
            print("*" * 20, "early stopping at stage%02d lr %.3e" % (stage, lr))
            break
Example #20
def train_cv(input_directory, output_directory):
    # model
    # directory for saving model checkpoints
    model_save_dir = '%s/%s_%s' % (config.ckpt, config.model_name + "_cv",
                                   time.strftime("%Y%m%d%H%M"))
    for fold in range(config.kfold):
        print("***************************fold : {}***********************".
              format(fold))
        model = getattr(models, config.model_name)(fold=fold)
        # if args.ckpt and not args.resume:
        #     state = torch.load(args.ckpt, map_location='cpu')
        #     model.load_state_dict(state['state_dict'])
        #     print('train with pretrained weight val_f1', state['f1'])

        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, config.num_classes)

        #2019/11/11
        #save dense/fc weight for pretrain 55 classes
        # model = MyModel()
        # num_ftrs = model.classifier.out_features
        # model.fc = nn.Linear(55, config.num_classes)

        model = model.to(device)
        # data
        train_dataset = ECGDataset(data_path=config.train_data_cv.format(fold),
                                   data_dir=input_directory,
                                   train=True)

        train_dataloader = DataLoader(train_dataset,
                                      batch_size=config.batch_size,
                                      shuffle=True,
                                      drop_last=True,
                                      num_workers=6)

        val_dataset = ECGDataset(data_path=config.train_data_cv.format(fold),
                                 data_dir=input_directory,
                                 train=False)

        val_dataloader = DataLoader(val_dataset,
                                    batch_size=config.batch_size,
                                    drop_last=True,
                                    num_workers=4)

        print("fold_{}_train_datasize".format(fold), len(train_dataset),
              "fold_{}_val_datasize".format(fold), len(val_dataset))
        # optimizer and loss
        optimizer = radam.RAdam(
            model.parameters(),
            lr=config.lr)  #optim.Adam(model.parameters(), lr=config.lr)
        w = torch.tensor(train_dataset.wc, dtype=torch.float).to(device)
        criterion = utils.WeightedMultilabel(w)  # alternative: utils.FocalLoss()
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         'max',
                                                         verbose=True,
                                                         factor=0.1,
                                                         patience=5,
                                                         min_lr=1e-06,
                                                         eps=1e-08)

        # if args.ex: model_save_dir += args.ex
        # best_f1 = -1
        # lr = config.lr
        # start_epoch = 1
        # stage = 1

        best_f1 = -1
        best_cm = -1
        epoch_cum = 0  # epochs since the last val_cm improvement
        lr = config.lr
        start_epoch = 1
        stage = 1
        # resume training from the last checkpoint
        #         if args.resume:
        #             if os.path.exists(args.ckpt):  # args.ckpt is the directory holding the weights
        #                 model_save_dir = args.ckpt
        #                 current_w = torch.load(os.path.join(args.ckpt, config.current_w))
        #                 best_w = torch.load(os.path.join(model_save_dir, config.best_w))
        #                 best_f1 = best_w['loss']
        #                 start_epoch = current_w['epoch'] + 1
        #                 lr = current_w['lr']
        #                 stage = current_w['stage']
        #                 model.load_state_dict(current_w['state_dict'])
        #                 # if the interruption happened exactly at a stage-transition epoch
        #                 if start_epoch - 1 in config.stage_epoch:
        #                     stage += 1
        #                     lr /= config.lr_decay
        #                     utils.adjust_learning_rate(optimizer, lr)
        #                     model.load_state_dict(best_w['state_dict'])
        #                 print("=> loaded checkpoint (epoch {})".format(start_epoch - 1))
        logger = Logger(logdir=model_save_dir, flush_secs=2)
        # =========> start training <=========
        for epoch in range(start_epoch, config.max_epoch + 1):
            since = time.time()
            train_loss, train_acc, train_f1, train_f2, train_g2, train_cm = train_epoch(
                model,
                optimizer,
                criterion,
                train_dataloader,
                show_interval=100)
            val_loss, val_acc, val_f1, val_f2, val_g2, val_cm = val_epoch(
                model, criterion, val_dataloader)

            # train_loss, train_f1 = train_beat_epoch(model, optimizer, criterion, train_dataloader, show_interval=100)
            # val_loss, val_f1 = val_beat_epoch(model, criterion, val_dataloader)

            print(
                '#epoch:%02d, stage:%d, train_loss:%.3e, train_acc:%.3f, '
                'train_f1:%.3f, train_f2:%.3f, train_g2:%.3f, train_cm:%.3f,\n'
                'val_loss:%0.3e, val_acc:%.3f, val_f1:%.3f, val_f2:%.3f, '
                'val_g2:%.3f, val_cm:%.3f, time:%s\n' %
                (epoch, stage, train_loss, train_acc, train_f1, train_f2,
                 train_g2, train_cm, val_loss, val_acc, val_f1, val_f2,
                 val_g2, val_cm, utils.print_time_cost(since)))

            logger.log_value('fold{}_train_loss'.format(fold),
                             train_loss,
                             step=epoch)
            logger.log_value('fold{}_train_f1'.format(fold),
                             train_f1,
                             step=epoch)
            logger.log_value('fold{}_val_loss'.format(fold),
                             val_loss,
                             step=epoch)
            logger.log_value('fold{}_val_f1'.format(fold), val_f1, step=epoch)
            state = {
                "state_dict": model.state_dict(),
                "epoch": epoch,
                "loss": val_loss,
                'f1': val_f1,
                'lr': lr,
                'stage': stage
            }

            save_ckpt_cv(state, best_cm < val_cm, model_save_dir, fold,
                         output_directory)
            best_cm = max(best_cm, val_cm)

            scheduler.step(val_cm)
            # scheduler.step()

            if val_cm < best_cm:
                epoch_cum += 1
            else:
                epoch_cum = 0

            # save_ckpt_cv(state, best_f1 < val_f1, model_save_dir,fold)
            # best_f1 = max(best_f1, val_f1)

            # if val_f1 < best_f1:
            #     epoch_cum += 1
            # else:
            #     epoch_cum = 0

            # if epoch in config.stage_epoch:
            # if epoch_cum == 5:
            #     stage += 1
            #     lr /= config.lr_decay
            #     if lr < 1e-6:
            #         lr = 1e-6
            #         print("*" * 20, "step into stage%02d lr %.3ef" % (stage, lr))
            #     best_w = os.path.join(model_save_dir, config.best_w_cv.format(fold))
            #     model.load_state_dict(torch.load(best_w)['state_dict'])
            #     print("*" * 10, "step into stage%02d lr %.3ef" % (stage, lr))
            #     utils.adjust_learning_rate(optimizer, lr)

            # elif epoch_cum >= 12:
            #     print("*" * 20, "step into stage%02d lr %.3ef" % (stage, lr))
            #     break

            if epoch_cum >= 12:
                print("*" * 20, "early stopping at stage%02d lr %.3e" % (stage, lr))
                break