def train(model_name, image_size):
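    """Train `model_name` with 5-fold CV; per-epoch metrics are appended to
    <snapshot_path>/log.csv (snapshot_path, csv_path, batch sizes are module globals)."""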

    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)
    header = ['Epoch', 'Learning rate', 'Time', 'Train Loss', 'Val Loss']

    if not os.path.isfile(snapshot_path + '/log.csv'):
        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(header)
    df_all = pd.read_csv(csv_path)

    kfold_path_train = '../data/fold_5_by_study/'
    kfold_path_val = '../data/fold_5_by_study_image/'

    for num_fold in range(5):
        print('fold_num:', num_fold)

        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([num_fold])

        f_train = open(
            kfold_path_train + 'fold' + str(num_fold) + '/train.txt', 'r')
        f_val = open(kfold_path_val + 'fold' + str(num_fold) + '/val.txt', 'r')
        c_train = f_train.readlines()
        c_val = f_val.readlines()
        f_train.close()
        f_val.close()
        c_train = [s.replace('\n', '') for s in c_train]
        c_val = [s.replace('\n', '') for s in c_val]

        # for debug
        # c_train = c_train[0:1000]
        # c_val = c_val[0:4000]

        print('train dataset study num:', len(c_train),
              '  val dataset image num:', len(c_val))
        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(
                ['train dataset:',
                 len(c_train), '  val dataset:',
                 len(c_val)])
            writer.writerow([
                'train_batch_size:', train_batch_size, 'val_batch_size:',
                val_batch_size
            ])

        train_transform, val_transform = generate_transforms(image_size)
        train_loader, val_loader = generate_dataset_loader(
            df_all, c_train, train_transform, train_batch_size, c_val,
            val_transform, val_batch_size, workers)

        model = eval(model_name + '()')  # instantiate the model class by name
        model = model.cuda()

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=0.0005,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=0.00002)
        scheduler = WarmRestart(optimizer, T_max=5, T_mult=1, eta_min=1e-5)
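        # WarmRestart / warm_restart are assumed to be this repo's cosine-annealing
        # helpers; warm_restart(scheduler, T_mult=2) below doubles the restart
        # period after each cycle.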
        model = torch.nn.DataParallel(model)
        loss_cls = torch.nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor(
            [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).cuda())
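        # Multi-label BCE over six outputs; all pos_weights are 1.0, so this is
        # plain BCE with a hook left for per-class re-weighting.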

        trMaxEpoch = 80
        for epochID in range(0, trMaxEpoch):

            start_time = time.time()
            model.train()
            trainLoss = 0
            lossTrainNorm = 0  # count of batches accumulated into trainLoss

            # LR schedule: hold the initial LR for the first 10 epochs,
            # then cosine warm restarts with period doubling.
            if epochID < 10:
                pass
            elif epochID < 80:
                if epochID != 10:
                    scheduler.step()
                    scheduler = warm_restart(scheduler, T_mult=2)
            else:
                optimizer.param_groups[0]['lr'] = 1e-5

            for batchID, (input, target) in enumerate(train_loader):
                if batchID == 0:
                    ss_time = time.time()

                print(str(batchID) + '/' +
                      str(int(len(c_train) / train_batch_size)) + '     ' +
                      str((time.time() - ss_time) / (batchID + 1)),
                      end='\r')
                varInput = input  # DataParallel scatters the batch across GPUs
                varTarget = target.view(-1, 6).contiguous().cuda()
                varOutput = model(varInput)
                lossvalue = loss_cls(varOutput, varTarget)
                trainLoss = trainLoss + lossvalue.item()
                lossTrainNorm = lossTrainNorm + 1

                lossvalue.backward()
                optimizer.step()
                optimizer.zero_grad()
                del lossvalue

            trainLoss = trainLoss / lossTrainNorm

            if (epochID + 1) % 5 == 0 or epochID > 79 or epochID == 0:
                valLoss, auc, loss_list, loss_sum = epochVal(
                    model, val_loader, loss_cls, c_val, val_batch_size)

            epoch_time = time.time() - start_time

            if (epochID + 1) % 5 == 0 or epochID > 79:
                torch.save(
                    {
                        'epoch': epochID + 1,
                        'state_dict': model.state_dict(),
                        'valLoss': valLoss
                    }, snapshot_path + '/model_epoch_' + str(epochID) + '_' +
                    str(num_fold) + '.pth')

            result = [
                epochID,
                round(optimizer.state_dict()['param_groups'][0]['lr'], 6),
                round(epoch_time, 0),
                round(trainLoss, 5),
                round(valLoss, 5), 'auc:', auc, 'loss:', loss_list, loss_sum
            ]

            print(result)

            with open(snapshot_path + '/log.csv', 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(result)

        del model

# Example 2

def train_one_model(model_name, img_size, use_chexpert, path_data):
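    """Train one segmentation model with 5-fold CV; optionally mixes in
    CheXpert extra data when `use_chexpert` is set."""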
    RESIZE_SIZE = img_size
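    # Training-time augmentation: photometric jitter (gamma / brightness-contrast /
    # CLAHE), random blur, horizontal flip, shift-scale-rotate, then ImageNet
    # normalization. Validation only resizes and normalizes.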

    train_transform = albumentations.Compose([
        albumentations.Resize(RESIZE_SIZE, RESIZE_SIZE),
        albumentations.OneOf([
            albumentations.RandomGamma(gamma_limit=(60, 120), p=0.9),
            albumentations.RandomBrightnessContrast(brightness_limit=0.2,
                                                    contrast_limit=0.2,
                                                    p=0.9),
            albumentations.CLAHE(clip_limit=4.0, tile_grid_size=(4, 4), p=0.9),
        ]),
        albumentations.OneOf([
            albumentations.Blur(blur_limit=4, p=1),
            albumentations.MotionBlur(blur_limit=4, p=1),
            albumentations.MedianBlur(blur_limit=4, p=1)
        ],
                             p=0.5),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.ShiftScaleRotate(shift_limit=0.2,
                                        scale_limit=0.2,
                                        rotate_limit=20,
                                        interpolation=cv2.INTER_LINEAR,
                                        border_mode=cv2.BORDER_CONSTANT,
                                        p=1),
        albumentations.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225),
                                 max_pixel_value=255.0,
                                 p=1.0)
    ])
    val_transform = albumentations.Compose([
        albumentations.Resize(RESIZE_SIZE, RESIZE_SIZE, p=1),
        albumentations.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225),
                                 max_pixel_value=255.0,
                                 p=1.0)
    ])
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)
    header = [
        'Epoch', 'Learning rate', 'Time', 'Train Loss', 'Val Loss',
        'best_thr_with_no_mask', 'best_dice_with_no_mask',
        'best_thr_without_no_mask', 'best_dice_without_no_mask'
    ]
    if not os.path.isfile(snapshot_path + '/log.csv'):
        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(header)
    kfold_path = path_data['k_fold_path']
    extra_data = path_data['extra_img_csv']

    for f_fold in range(5):
        num_fold = f_fold
        print(num_fold)

        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([num_fold])

        if use_chexpert:  # append CheXpert extra data for this fold
            df1 = pd.read_csv(csv_path)
            df2 = pd.read_csv(extra_data +
                              'chexpert_mask_{}.csv'.format(num_fold + 1))
            df_all = pd.concat([df1, df2], ignore_index=True)

            f_train = open(kfold_path + 'fold' + str(num_fold) + '/train.txt',
                           'r')
            f_val = open(kfold_path + 'fold' + str(num_fold) + '/val.txt', 'r')

            f_fake = open(
                extra_data + '/chexpert_list_{}.txt'.format(num_fold + 1), 'r')

            c_train = f_train.readlines()
            c_val = f_val.readlines()
            c_fake = f_fake.readlines()
            c_train = c_fake + c_train

            f_train.close()
            f_val.close()
            f_fake.close()

        else:
            df_all = pd.read_csv(csv_path)
            f_train = open(kfold_path + 'fold' + str(num_fold) + '/train.txt',
                           'r')
            f_val = open(kfold_path + 'fold' + str(num_fold) + '/val.txt', 'r')
            c_train = f_train.readlines()
            c_val = f_val.readlines()
            f_train.close()
            f_val.close()

        c_train = [s.replace('\n', '') for s in c_train]
        c_val = [s.replace('\n', '') for s in c_val]

        print('train dataset:', len(c_train),
              '  val dataset c_val_without_no_mask:', 476,
              '  val dataset c_val_with_no_mask:', len(c_val))
        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([
                'train dataset:',
                len(c_train), '  val dataset c_val_without_no_mask:', 476,
                '  val dataset c_val_with_no_mask:',
                len(c_val)
            ])
            writer.writerow([
                'train_batch_size:', train_batch_size, 'val_batch_size:',
                val_batch_size
            ])

        train_loader, val_loader = generate_dataset_loader_cls_seg(
            df_all, c_train, train_transform, train_batch_size, c_val,
            val_transform, val_batch_size, workers)

        if model_name == 'deep_se50':
            from semantic_segmentation.network.deepv3 import DeepSRNX50V3PlusD_m1  # r
            model = DeepSRNX50V3PlusD_m1(1, SoftDiceLoss_binary())
        elif model_name == 'unet_ef3':
            from ef_unet import EfficientNet_3_unet
            model = EfficientNet_3_unet()
        elif model_name == 'unet_ef5':
            from ef_unet import EfficientNet_5_unet
            model = EfficientNet_5_unet()
        else:
            raise ValueError('Unknown model name: {}'.format(model_name))

        model = apex.parallel.convert_syncbn_model(model).cuda()
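        # apex: convert BatchNorm layers to sync-BN across GPUs; AMP "O1" mixed
        # precision is initialized below before wrapping in DataParallel.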

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=5e-4,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=0)
        scheduler = WarmRestart(optimizer, T_max=5, T_mult=1, eta_min=1e-6)

        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
        model = torch.nn.DataParallel(model)

        loss_seg = SoftDiceLoss_binary()
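        # SoftDiceLoss_binary is assumed to be a repo-local soft Dice loss for
        # the single-channel mask output.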

        trMaxEpoch = 44
        lossMIN = 100000
        val_dice_max = 0

        for epochID in range(0, trMaxEpoch):

            start_time = time.time()
            model.train()
            trainLoss = 0
            lossTrainNorm = 0
            trainLoss_cls = 0
            trainLoss_seg = 0

            # LR schedule: cosine warm restarts for 40 epochs, a fixed 1e-5
            # for epochs 40-41, then 5e-6.
            if epochID < 40:
                if epochID != 0:
                    scheduler.step()
                    scheduler = warm_restart(scheduler, T_mult=2)
            elif epochID < 42:
                optimizer.param_groups[0]['lr'] = 1e-5
            else:
                optimizer.param_groups[0]['lr'] = 5e-6

            for batchID, (input, target_seg,
                          target_cls) in enumerate(train_loader):

                if batchID == 0:
                    ss_time = time.time()
                print(str(batchID) + '/' +
                      str(int(len(c_train) / train_batch_size)) + '     ' +
                      str((time.time() - ss_time) / (batchID + 1)),
                      end='\r')
                varInput = input
                varTarget_seg = target_seg.contiguous().cuda(non_blocking=True)

                varOutput_seg = model(varInput)
                varTarget_seg = varTarget_seg.float()

                lossvalue_seg = loss_seg(varOutput_seg, varTarget_seg)
                trainLoss_seg = trainLoss_seg + lossvalue_seg.item()

                lossvalue = lossvalue_seg
                lossTrainNorm = lossTrainNorm + 1
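                # apex AMP: scale the loss before backward to avoid fp16 underflow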
                optimizer.zero_grad()
                with amp.scale_loss(lossvalue, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()

            trainLoss_seg = trainLoss_seg / lossTrainNorm
            trainLoss = trainLoss_seg

            # validate every epoch
            (valLoss_seg, best_thr_with_no_mask, best_dice_with_no_mask,
             best_thr_without_no_mask, best_dice_without_no_mask) = epochVal(
                 model, val_loader, loss_seg, c_val, val_batch_size)

            epoch_time = time.time() - start_time

            # checkpoint every epoch
            torch.save(
                {
                    'epoch': epochID + 1,
                    'state_dict': model.state_dict(),
                    'valLoss': 0,
                    'best_thr_with_no_mask': best_thr_with_no_mask,
                    'best_dice_with_no_mask': float(best_dice_with_no_mask),
                    'best_thr_without_no_mask': best_thr_without_no_mask,
                    'best_dice_without_no_mask': float(best_dice_without_no_mask)
                }, snapshot_path + '/model_epoch_' + str(epochID) + '_' +
                str(num_fold) + '.pth.tar')

            result = [
                epochID,
                round(optimizer.state_dict()['param_groups'][0]['lr'], 6),
                round(epoch_time, 0),
                round(trainLoss, 4),
                round(valLoss_seg, 4),
                round(best_thr_with_no_mask, 3),
                round(float(best_dice_with_no_mask), 3),
                round(best_thr_without_no_mask, 3),
                round(float(best_dice_without_no_mask), 3)
            ]
            print(result)
            with open(snapshot_path + '/log.csv', 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(result)
        del model

# Example 3

def train_one_model(model_name, image_size):
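    """Classification variant of train() using apex AMP mixed precision;
    note the debug-sized subsets and single epoch below."""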

    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)
    header = ['Epoch', 'Learning rate', 'Time', 'Train Loss', 'Val Loss']

    if not os.path.isfile(snapshot_path + '/log.csv'):
        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(header)
    df_all = pd.read_csv(csv_path)
    kfold_path_train = '../data/fold_5_by_study/'
    kfold_path_val = '../data/fold_5_by_study_image/'

    for num_fold in range(5):
        print(num_fold)
        # if num_fold in [0,1,2]:
        #    continue

        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([num_fold])

        f_train = open(
            kfold_path_train + 'fold' + str(num_fold) + '/train.txt', 'r')
        f_val = open(kfold_path_val + 'fold' + str(num_fold) + '/val.txt', 'r')
        c_train = f_train.readlines()
        c_val = f_val.readlines()
        f_train.close()
        f_val.close()
        c_train = [s.replace('\n', '') for s in c_train]
        c_val = [s.replace('\n', '') for s in c_val]

        # debug: train/validate on small subsets
        c_train = c_train[0:100]
        c_val = c_val[0:4000]
        print('train dataset:', len(c_train), '  val dataset:', len(c_val))
        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(
                ['train dataset:',
                 len(c_train), '  val dataset:',
                 len(c_val)])
            writer.writerow([
                'train_batch_size:', train_batch_size, 'val_batch_size:',
                val_batch_size
            ])

        train_transform, val_transform = generate_transforms(image_size)
        train_loader, val_loader = generate_dataset_loader_cls_seg(
            df_all, c_train, train_transform, train_batch_size, c_val,
            val_transform, val_batch_size, workers)

        model = eval(model_name + '()')

        # state = torch.load('/data/lanjun/kaggle_rsna2019/models_snapshot/DenseNet169_change_avg_test_context_256/model_epoch_59_'+str(num_fold)+'.pth')['state_dict']
        # new_state_dict = OrderedDict()
        # for k, v in state.items():
        #     name = k[7:]
        #     new_state_dict[name] = v
        # model.load_state_dict(new_state_dict)
        model = apex.parallel.convert_syncbn_model(model).cuda()

        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=0.0005,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=0.00002)
        scheduler = WarmRestart(optimizer, T_max=5, T_mult=1, eta_min=1e-5)
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
        model = torch.nn.DataParallel(model)

        def loss_cls_com(input, target):
            loss_1 = FocalLoss()
            loss_2 = torch.nn.BCEWithLogitsLoss()

            loss = loss_1(input, target) + loss_2(input, target)

            return loss
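        # Note: loss_cls_com (focal + BCE) is defined but unused; training
        # below uses the plain weighted BCEWithLogitsLoss instead.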

        loss_cls = torch.nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor(
            [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]).cuda())

        trMaxEpoch = 1  # debug: single epoch

        for epochID in range(0, trMaxEpoch):
            start_time = time.time()
            model.train()
            trainLoss = 0
            lossTrainNorm = 0  # count of batches accumulated into trainLoss

            if epochID < 10:
                pass
            elif epochID < 80:
                if epochID != 10:
                    scheduler.step()
                    scheduler = warm_restart(scheduler, T_mult=2)
            else:
                optimizer.param_groups[0]['lr'] = 1e-5

            for batchID, (input, target) in enumerate(train_loader):
                if batchID == 0:
                    ss_time = time.time()
                print(str(batchID) + '/' +
                      str(int(len(c_train) / train_batch_size)) + '     ' +
                      str((time.time() - ss_time) / (batchID + 1)),
                      end='\r')
                varInput = input
                varTarget = target.view(-1, 6).contiguous().cuda(non_blocking=True)
                varOutput = model(varInput)
                lossvalue = loss_cls(varOutput, varTarget)
                trainLoss = trainLoss + lossvalue.item()
                lossTrainNorm = lossTrainNorm + 1
                optimizer.zero_grad()
                with amp.scale_loss(lossvalue, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                optimizer.step()

                del lossvalue

            trainLoss = trainLoss / lossTrainNorm

            if (epochID + 1) % 5 == 0 or epochID > 79 or epochID == 0:
                valLoss, auc, loss_list, loss_sum = epochVal(
                    num_fold, model, val_loader, loss_cls, c_val,
                    val_batch_size)

            epoch_time = time.time() - start_time

            if (epochID + 1) % 5 == 0 or epochID > 79:
                torch.save(
                    {
                        'epoch': epochID + 1,
                        'state_dict': model.state_dict(),
                        'valLoss': valLoss
                    }, snapshot_path + '/model_epoch_' + str(epochID) + '_' +
                    str(num_fold) + '.pth')

            result = [
                epochID,
                round(optimizer.state_dict()['param_groups'][0]['lr'], 6),
                round(epoch_time, 0),
                round(trainLoss, 5),
                round(valLoss, 5), 'auc:', auc, loss_list, loss_sum
            ]
            print(result)

            with open(snapshot_path + '/log.csv', 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(result)

        del model

# Example 4

def main():
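    """Train a coastline segmentation model (FPN with a pretrained encoder)
    using segmentation_models_pytorch."""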
    data_dir = './input/another_cv_3/'
    epoch = 100
    batch_size = 1
    lr = 0.0001

    # Directories of the split image data.
    x_train_dir = os.path.join(data_dir, 'train_images_png')
    y_train_dir = os.path.join(data_dir, 'train_images_inpainted_labels')

    x_valid_dir = os.path.join(data_dir, 'val_images_png')
    y_valid_dir = os.path.join(data_dir, 'val_images_inpainted_labels')

    # Build file lists of the split image data (not used further below).
    x_train_files = glob.glob(x_train_dir + '/*')
    y_train_files = glob.glob(y_train_dir + '/*')

    # ENCODER = 'resnet18'
    ENCODER = 'inceptionv4'
    ENCODER_WEIGHTS = 'imagenet'
    CLASSES = ['coastline']
    ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation
    DEVICE = 'cuda'

    # create segmentation model with pretrained encoder
    # model = smp.Unet(
    #     encoder_name=ENCODER,
    #     encoder_weights=ENCODER_WEIGHTS,
    #     classes=len(CLASSES),
    #     activation=ACTIVATION,
    # )
    model = smp.FPN(
        encoder_name=ENCODER,
        encoder_weights=ENCODER_WEIGHTS,
        classes=len(CLASSES),
        activation=ACTIVATION,
    )

    # TensorBoard setup.
    writer = SummaryWriter(comment=f'_ENCODER_{ENCODER}_LR_{lr}')
    log_dir = writer.log_dir

    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        ENCODER, ENCODER_WEIGHTS)
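    # encoder-specific input normalization provided by segmentation_models_pytorch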

    train_dataset = Dataset(
        x_train_dir,
        y_train_dir,
        augmentation=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CLASSES,
    )

    valid_dataset = Dataset(
        x_valid_dir,
        y_valid_dir,
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CLASSES,
    )

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=0)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=1,
                              shuffle=False,
                              num_workers=0)

    # Loss setup.
    loss = smp.utils.losses.DiceLoss()
    # loss = smp.utils.losses.BCELoss()
    metrics = [
        smp.utils.metrics.IoU(threshold=0.5),
    ]

    # Optimizer setup.
    # optimizer = torch.optim.Adam([
    #     dict(params=model.parameters(), lr=lr),
    # ])

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=0.0005,
                                 betas=(0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=0)
    scheduler = WarmRestart(optimizer, T_max=5, T_mult=1, eta_min=1e-5)

    # Epoch runners for training and validation.
    train_epoch = smp.utils.train.TrainEpoch(
        model,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        device=DEVICE,
        verbose=True,
    )

    valid_epoch = smp.utils.train.ValidEpoch(
        model,
        loss=loss,
        metrics=metrics,
        device=DEVICE,
        verbose=True,
    )

    max_score = 0

    # Track train/val dice loss and IoU so they can be plotted.
    x_epoch_data = []
    train_dice_loss = []
    train_iou_score = []
    valid_dice_loss = []
    valid_iou_score = []

    for i in range(epoch):
        # LR schedule: cosine warm restarts for the first 30 epochs,
        # then a fixed 1e-5 for two epochs, then 5e-6.
        if i < 30:
            if i != 0:
                scheduler.step()
                scheduler = warm_restart(scheduler, T_mult=2)
        elif i < 32:
            optimizer.param_groups[0]['lr'] = 1e-5
        else:
            optimizer.param_groups[0]['lr'] = 5e-6

        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)
        valid_logs = valid_epoch.run(valid_loader)

        x_epoch_data.append(i)
        train_dice_loss.append(train_logs['dice_loss'])
        train_iou_score.append(train_logs['iou_score'])
        valid_dice_loss.append(valid_logs['dice_loss'])
        valid_iou_score.append(valid_logs['iou_score'])

        writer.add_scalar('Loss/train', train_logs['dice_loss'], i)
        writer.add_scalar('iou/train', train_logs['iou_score'], i)
        writer.add_scalar('Loss/valid', valid_logs['dice_loss'], i)
        writer.add_scalar('iou/valid', valid_logs['iou_score'], i)

        # do something (save model, change lr, etc.)
        if max_score < valid_logs['iou_score']:
            max_score = valid_logs['iou_score']
            torch.save(model, log_dir + '/best_model.pth')
            print('Model saved!')

# Example 5

def train_one_model(model_name):
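    """Train DenseNet121_change_avg on 25 labels with 5-fold CV, checkpointing
    the best-loss / best-F1 / best-AUC models per fold."""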

    snapshot_path = path_data['snapshot_path'] + model_name + '_' + str(
        Image_size) + '_25_local_val'
    if not os.path.exists(snapshot_path):
        os.makedirs(snapshot_path)

    df_all = pd.read_csv(csv_path)
    kfold_path = path_data['k_fold_path']

    for num_fold in range(5):
        print(num_fold)
        # if num_fold in [0,1,2]:
        #     continue

        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([num_fold])
            writer.writerow([
                'train_batch_size:',
                str(train_batch_size), 'val_batch_size:',
                str(val_batch_size), 'backbone', model_name, 'Image_size',
                Image_size
            ])

        f_train = open(kfold_path + 'fold' + str(num_fold) + '/train.txt', 'r')
        f_val = open(kfold_path + 'fold' + str(num_fold) + '/val.txt', 'r')
        c_train = f_train.readlines()
        c_val = f_val.readlines()
        f_train.close()
        f_val.close()
        c_train = [s.replace('\n', '') for s in c_train]
        c_val = [s.replace('\n', '') for s in c_val]

        print('train dataset:', len(c_train), '  val dataset:', len(c_val))
        with open(snapshot_path + '/log.csv', 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(
                ['train dataset:',
                 len(c_train), '  val dataset:',
                 len(c_val)])
        # c_train = c_train[0:500]
        # c_val = c_val[0:2000]

        train_loader, val_loader = generate_dataset_loader_25(
            df_all, c_train, train_transform, train_batch_size, c_val,
            val_transform, val_batch_size, workers)

        model = DenseNet121_change_avg(25, True)
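        # DenseNet121_change_avg(25, True): repo-local DenseNet-121 variant with
        # 25 outputs, assumed to end in a sigmoid (BCELoss below expects probabilities).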
        model = torch.nn.DataParallel(model).cuda()

        optimizer = torch.optim.Adamax(model.parameters(),
                                       lr=0.001,
                                       betas=(0.9, 0.999),
                                       eps=1e-08,
                                       weight_decay=0)
        scheduler = WarmRestart(optimizer, T_max=10, T_mult=1, eta_min=1e-5)

        loss = torch.nn.BCELoss(reduction='mean')

        trMaxEpoch = 42
        lossMIN = 100000
        val_f1_mean = 0
        val_auc_mean = 0

        for epochID in range(0, trMaxEpoch):

            start_time = time.time()
            model.train()
            trainLoss = 0
            lossTrainNorm = 0

            for batchID, (input, target) in enumerate(train_loader):

                target = target.view(-1, 25).contiguous().cuda(non_blocking=True)
                varInput = torch.autograd.Variable(input)
                varTarget = torch.autograd.Variable(target)
                varOutput = model(varInput)
                # print(varOutput.shape, varTarget.shape)
                lossvalue = loss(varOutput, varTarget)
                trainLoss = trainLoss + lossvalue.item()
                lossTrainNorm = lossTrainNorm + 1
                optimizer.zero_grad()
                lossvalue.backward()
                optimizer.step()
            if epochID < 39:
                scheduler.step()
                scheduler = warm_restart(scheduler, T_mult=2)
            else:
                optimizer = torch.optim.Adam(model.parameters(),
                                             lr=1e-5,
                                             betas=(0.9, 0.999),
                                             weight_decay=1e-5,
                                             eps=1e-08,
                                             amsgrad=True)

            trainLoss = trainLoss / lossTrainNorm
            if (epochID + 1) % 10 == 0 or epochID > 39 or epochID == 0:
                valLoss, val_auc, val_threshold, val_f1, precision_list, recall_list = epochVal(
                    model, val_loader, optimizer, scheduler, loss)

            epoch_time = time.time() - start_time
            if valLoss < lossMIN:
                lossMIN = valLoss
                torch.save(
                    {
                        'epoch': epochID + 1,
                        'state_dict': model.state_dict(),
                        'best_loss': lossMIN,
                        'optimizer': optimizer.state_dict(),
                        'val_threshold': val_threshold,
                        'val_f1': val_f1,
                        'val_f1_mean': np.mean(val_f1),
                        'val_auc': val_auc,
                        'val_auc_mean': np.mean(val_auc)
                    }, snapshot_path + '/model_min_loss_' + str(num_fold) +
                    '.pth.tar')
            if val_f1_mean < np.mean(val_f1):
                val_f1_mean = np.mean(val_f1)
                torch.save(
                    {
                        'epoch': epochID + 1,
                        'state_dict': model.state_dict(),
                        'best_loss': lossMIN,
                        'optimizer': optimizer.state_dict(),
                        'val_threshold': val_threshold,
                        'val_f1': val_f1,
                        'val_f1_mean': np.mean(val_f1),
                        'val_auc': val_auc,
                        'val_auc_mean': np.mean(val_auc)
                    }, snapshot_path + '/model_max_f1_' + str(num_fold) +
                    '.pth.tar')
            if val_auc_mean < np.mean(val_auc):
                val_auc_mean = np.mean(val_auc)
                torch.save(
                    {
                        'epoch': epochID + 1,
                        'state_dict': model.state_dict(),
                        'best_loss': lossMIN,
                        'optimizer': optimizer.state_dict(),
                        'val_threshold': val_threshold,
                        'val_f1': val_f1,
                        'val_f1_mean': np.mean(val_f1),
                        'val_auc': val_auc,
                        'val_auc_mean': np.mean(val_auc)
                    }, snapshot_path + '/model_max_auc_' + str(num_fold) +
                    '.pth.tar')

            if (epochID + 1) % 10 == 0:
                torch.save(
                    {
                        'epoch': epochID + 1,
                        'state_dict': model.state_dict(),
                        'best_loss': lossMIN,
                        'optimizer': optimizer.state_dict(),
                        'val_threshold': val_threshold,
                        'val_f1': val_f1,
                        'val_f1_mean': np.mean(val_f1),
                        'val_auc': val_auc,
                        'val_auc_mean': np.mean(val_auc)
                    }, snapshot_path + '/model_epoch_' + str(epochID) + '_' +
                    str(num_fold) + '.pth.tar')

            result = [
                epochID,
                round(optimizer.state_dict()['param_groups'][0]['lr'], 5),
                round(trainLoss, 4),
                round(valLoss, 4),
                round(epoch_time, 0),
                round(np.mean(val_f1), 3),
                round(np.mean(val_auc), 4)
            ]
            print(result)
            # print(val_f1)
            with open(snapshot_path + '/log.csv', 'a', newline='') as f:
                writer = csv.writer(f)
                writer.writerow(result + val_threshold + val_f1 + val_auc +
                                precision_list + recall_list)

        del model