def main(cfg):
    torch.multiprocessing.set_sharing_strategy('file_system')
    seed_torch(seed=cfg.seed)

    output_dir = os.path.join(cfg.output_dir, cfg.desc)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    train_dataset = build_dataset(cfg, phase='train')
    test_dataset = build_dataset(cfg, phase='test')

    if cfg.DATA.weighted_sample:
        train_dl = DataLoader(train_dataset,
                              batch_size=32,
                              sampler=WeightedRandomSampler(
                                  train_dataset.get_label_weight(),
                                  num_samples=5000),
                              num_workers=0,
                              drop_last=True)
    else:
        train_dl = DataLoader(train_dataset,
                              batch_size=32,
                              shuffle=True,
                              num_workers=16,
                              drop_last=True)
    test_dl = DataLoader(test_dataset,
                         batch_size=32,
                         num_workers=8,
                         drop_last=True)

    solver = Solver(cfg)

    solver.train(train_dl, test_dl)
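
Every example on this page calls a project-local `seed_torch` helper (usually imported from a `utils` module) without showing its definition. A minimal sketch of what such a helper typically looks like is given below; the exact body and defaults vary between the repositories these snippets come from, so treat this as an assumption rather than the actual implementation.

import os
import random

import numpy as np
import torch


def seed_torch(seed=42):
    # Seed Python, NumPy and PyTorch (CPU and all visible GPUs) so runs are reproducible.
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernels trade some speed for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False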
Example 2
def load_emnist(val_size=10000, seed=None):
    """Return the train, val (val_size samples randomly drawn from the original training set) and test datasets for EMNIST letters."""
    image_transform = transforms.Compose([
        # EMNIST images are flipped and rotated by default, fix this here.
        transforms.RandomHorizontalFlip(1),
        transforms.RandomRotation((90, 90)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081, ))
    ])
    target_transform = lambda x: x - 1  # make labels start at 0 instead of 1

    raw_train_dataset = datasets.EMNIST('data/emnist',
                                        split='letters',
                                        train=True,
                                        download=True,
                                        transform=image_transform,
                                        target_transform=target_transform)
    test_dataset = datasets.EMNIST('data/emnist',
                                   split='letters',
                                   train=False,
                                   download=True,
                                   transform=image_transform,
                                   target_transform=target_transform)

    # Split val_size samples from the train dataset for validation (similar to Sacramento et al. 2018).
    utils.seed_torch(seed)
    train_dataset, val_dataset = torch.utils.data.dataset.random_split(
        raw_train_dataset, (len(raw_train_dataset) - val_size, val_size))

    return train_dataset, val_dataset, test_dataset
Example 3
def main():
    args = parse_args()

    # set random seed
    utils.seed_torch(42)

    model = PandaNet(arch=args.arch, pretrained=False)
    model_path = os.path.join(configure.MODEL_PATH,
                              f'{args.arch}_fold_{args.fold}_128_12.pth')

    model.load_state_dict(torch.load(model_path))
    model.cuda()

    df = pd.read_csv(configure.TRAIN_DF)

    dataset = PandaDataset(df=df, data_dir=configure.TRAIN_IMAGE_PATH)

    dataloader = DataLoader(dataset=dataset,
                            batch_size=args.batch_size,
                            num_workers=args.num_workers,
                            pin_memory=False,
                            shuffle=False)

    preds = predict(dataloader, model)
    score = utils.quadratic_weighted_kappa(preds, df['isup_grade'])
    print(score)
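
`utils.quadratic_weighted_kappa` is another helper that is not shown on this page. Assuming it simply wraps scikit-learn, a minimal sketch would be:

from sklearn.metrics import cohen_kappa_score


def quadratic_weighted_kappa(y_pred, y_true):
    # Cohen's kappa with quadratic weights, the metric used for ISUP grade scoring.
    return cohen_kappa_score(y_pred, y_true, weights='quadratic')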
Example 4
def main():
    with timer('load data'):
        df = pd.read_csv(TRAIN_PATH)
        df = df[df.Image != "ID_6431af929"].reset_index(drop=True)
        df.loc[df.pre_SOPInstanceUID == "ID_6431af929", "pre1_SOPInstanceUID"] = df.loc[
            df.pre_SOPInstanceUID == "ID_6431af929", "Image"]
        df.loc[df.post_SOPInstanceUID == "ID_6431af929", "post1_SOPInstanceUID"] = df.loc[
            df.post_SOPInstanceUID == "ID_6431af929", "Image"]
        df.loc[df.prepre_SOPInstanceUID == "ID_6431af929", "pre2_SOPInstanceUID"] = df.loc[
            df.prepre_SOPInstanceUID == "ID_6431af929", "pre1_SOPInstanceUID"]
        df.loc[df.postpost_SOPInstanceUID == "ID_6431af929", "post2_SOPInstanceUID"] = df.loc[
            df.postpost_SOPInstanceUID == "ID_6431af929", "post1_SOPInstanceUID"]
        y = df[TARGET_COLUMNS].values
        df = df[["Image", "pre1_SOPInstanceUID", "post1_SOPInstanceUID", "pre2_SOPInstanceUID", "post2_SOPInstanceUID"]]
        gc.collect()

    with timer('preprocessing'):
        train_augmentation = Compose([
            CenterCrop(512 - 50, 512 - 50, p=1.0),
            HorizontalFlip(p=0.5),
            OneOf([
                ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
                GridDistortion(p=0.5),
                OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5)
            ], p=0.5),
            Rotate(limit=30, border_mode=0, p=0.7),
            Resize(img_size, img_size, p=1)
        ])

        train_dataset = RSNADataset(df, y, img_size, IMAGE_PATH, id_colname=ID_COLUMNS,
                                    transforms=train_augmentation, black_crop=False, subdural_window=True, user_window=2)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
        del df, train_dataset
        gc.collect()

    with timer('create model'):
        model = CnnModel(num_classes=N_CLASSES, encoder="se_resnext50_32x4d", pretrained="imagenet", pool_type="avg")
        if model_path is not None:
            model.load_state_dict(torch.load(model_path))
        model.to(device)

        criterion = torch.nn.BCEWithLogitsLoss(weight=torch.FloatTensor([2, 1, 1, 1, 1, 1]).cuda())
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, eps=1e-4)
        model = torch.nn.DataParallel(model)

    with timer('train'):
        for epoch in range(1, epochs + 1):
            if epoch == 5:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = param_group['lr'] * 0.1
            seed_torch(SEED + epoch)

            LOGGER.info("Starting {} epoch...".format(epoch))
            tr_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
            LOGGER.info('Mean train loss: {}'.format(round(tr_loss, 5)))

            torch.save(model.module.state_dict(), 'models/{}_ep{}.pth'.format(EXP_ID, epoch))
Example 5
def main():
    args = parse_args()

    # set random seed
    utils.seed_torch(args.seed)

    # Setup CUDA, GPU
    if not torch.cuda.is_available():
        print("cuda is not available")
        exit(0)

    train_loader, valid_loader = datasets.get_dataloader(
        fold=args.fold,
        batch_size=args.batch_size,
        num_workers=args.num_workers)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.CrossEntropyLoss()

    model = PandaNet(arch=args.arch)
    model.to("cuda")

    metric = ArcMarginProduct(in_features=512, out_features=6).to("cuda")

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 30, 60, 90], gamma=0.5)

    """ Train the model """
    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    log_prefix = f'{current_time}_{args.arch}_fold_{args.fold}_{args.tile_size}_{args.num_tiles}'
    log_dir = os.path.join(configure.TRAINING_LOG_PATH,
                           log_prefix)

    tb_writer = None
    if args.log:
        tb_writer = SummaryWriter(log_dir=log_dir)

    best_score = 0.0
    model_path = os.path.join(configure.MODEL_PATH,
                              f'{args.arch}_fold_{args.fold}_{args.tile_size}_{args.num_tiles}.pth')

    print(f'training started: {current_time}')
    for epoch in range(args.epochs):
        train_loss = train(
            dataloader=train_loader,
            model=model,
            criterion=criterion,
            metric=metric,
            optimizer=optimizer)

    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    print(f'training finished: {current_time}')
Example 6
def process_image(image):
    """Process image."""
    seed_torch(seed=42)
    proc_image = transforms.Compose([
        Rescale(256),
        RandomCrop(224),
        ToTensor(),
    ])(image)
    proc_image = proc_image.unsqueeze(0).to(DEVICE)
    return proc_image
Example 7
def train():
    """Train"""
    client = storage.Client(PROJECT)
    raw_bucket = client.get_bucket(RAW_BUCKET)
    bucket = client.get_bucket(BUCKET)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"  Device found = {device}")

    metadata_df = (
        pd.read_csv(f"gs://{RAW_BUCKET}/{RAW_DATA_DIR}/metadata.csv").query(
            "view == 'PA'")  # taking only PA view
    )

    print("Split train and validation data")
    proc_data = ImageDataset(
        root_dir=BASE_DIR,
        image_dir=PREPROCESSED_DIR,
        df=metadata_df,
        bucket=bucket,
        transform=ToTensor(),
    )
    seed_torch(seed=42)
    valid_size = int(len(proc_data) * 0.2)
    train_data, valid_data = torch.utils.data.random_split(
        proc_data, [len(proc_data) - valid_size, valid_size])
    train_loader = DataLoader(train_data,
                              batch_size=CFG.batch_size,
                              shuffle=True,
                              drop_last=True)
    valid_loader = DataLoader(valid_data,
                              batch_size=CFG.batch_size,
                              shuffle=False)

    print("Train model")
    se_model_blob = raw_bucket.blob(CFG.pretrained_weights)
    model = CustomSEResNeXt(
        BytesIO(se_model_blob.download_as_string()),
        device,
        CFG.n_classes,
        save=CFG.pretrained_model_path,
    )
    train_fn(model, train_loader, valid_loader, device)

    print("Evaluate")
    y_probs, y_val = predict(model, valid_loader, device)
    y_preds = y_probs.argmax(axis=1)

    compute_log_metrics(y_val, y_probs[:, 1], y_preds)
Example 8
def main():
    args = parse_args()
    seed_torch(seed=42)

    model_save_path = os.path.join(SAVE_MODEL_PATH, args.model)
    training_history_path = os.path.join(TRAINING_HISTORY_PATH, args.model)

    df_train_path = os.path.join(SPLIT_FOLDER,
                                 "fold_{}_train.csv".format(args.fold))
    df_train = pd.read_csv(df_train_path)

    df_valid_path = os.path.join(SPLIT_FOLDER,
                                 "fold_{}_valid.csv".format(args.fold))
    df_valid = pd.read_csv(df_valid_path)

    print("Training on {} images, Fish: {}, Flower: {}, Gravel: {}, Sugar: {}".
          format(len(df_train), df_train['isFish'].sum(),
                 df_train['isFlower'].sum(), df_train['isGravel'].sum(),
                 df_train['isSugar'].sum()))
    print("Validate on {} images, Fish: {}, Flower: {}, Gravel: {}, Sugar: {}".
          format(len(df_valid), df_valid['isFish'].sum(),
                 df_valid['isFlower'].sum(), df_valid['isGravel'].sum(),
                 df_valid['isSugar'].sum()))
    model_trainer, best = None, None
    if args.model == "UResNet34":
        model_trainer = TrainerSegmentation(
            model=UResNet34(),
            num_workers=args.num_workers,
            batch_size=args.batch_size,
            num_epochs=args.num_epochs,
            model_save_path=model_save_path,
            training_history_path=training_history_path,
            model_save_name=args.model,
            fold=args.fold)
    elif args.model == "ResNet34":
        model_trainer = TrainerClassification(
            model=ResNet34(),
            num_workers=args.num_workers,
            batch_size=args.batch_size,
            num_epochs=args.num_epochs,
            model_save_path=model_save_path,
            training_history_path=training_history_path,
            model_save_name=args.model,
            fold=args.fold,
            mixup=args.mixup)

    best = model_trainer.start()

    print("Training is done! best is {}".format(best))
Example 9
def main(cfg):
    torch.multiprocessing.set_sharing_strategy('file_system')
    seed_torch(seed=cfg.seed)

    output_dir = os.path.join(cfg.output_dir, cfg.desc)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    test_dataset = build_dataset(cfg, phase='test')

    test_dl = DataLoader(test_dataset,
                         batch_size=32,
                         num_workers=8,
                         drop_last=True)

    solver = Solver(cfg, use_tensorboardx=False)
    with torch.no_grad():
        solver.val(test_dl, epoch=args.epoch)
Example 10
def load_mnist(val_size=5000, seed=None):
    """Return the train (55k), val (5k, randomly drawn from the original test set) and test (10k) dataset for MNIST."""
    image_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.1307, ), (0.3081, ))])

    raw_train_dataset = datasets.MNIST('data/mnist',
                                       train=True,
                                       download=True,
                                       transform=image_transform)
    test_dataset = datasets.MNIST('data/mnist',
                                  train=False,
                                  download=True,
                                  transform=image_transform)

    # Split 5k samples from the train dataset for validation (similar to Sacramento et al. 2018).
    utils.seed_torch(seed)
    train_dataset, val_dataset = torch.utils.data.dataset.random_split(
        raw_train_dataset, (len(raw_train_dataset) - val_size, val_size))

    return train_dataset, val_dataset, test_dataset
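
A short usage sketch for the loader above (batch size and worker count here are illustrative, not taken from the original project):

from torch.utils.data import DataLoader

train_ds, val_ds, test_ds = load_mnist(val_size=5000, seed=0)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=2)
val_dl = DataLoader(val_ds, batch_size=64, shuffle=False)
test_dl = DataLoader(test_ds, batch_size=64, shuffle=False)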
Example 11
def train_a_kfold(cfg: Dict, cfg_name: str, output_path: Path) -> None:
    # Checkpoint callback
    kfold = cfg.Data.dataset.kfold
    checkpoint_callback = MyModelCheckpoint(
        model_name=cfg.Model.base,
        kfold=kfold,
        cfg_name=cfg_name,
        filepath=str(output_path),
        verbose=True,  # print when save result, not must
    )

    # Logger
    logger_name = f"kfold_{str(kfold).zfill(2)}.csv"
    mylogger = MyLogger(logger_df_path=output_path / logger_name)

    # Trainer
    seed_torch(cfg.General.seed)
    seed_everything(cfg.General.seed)
    debug = cfg.General.debug
    trainer = Trainer(
        logger=mylogger,
        max_epochs=5 if debug else cfg.General.epoch,
        checkpoint_callback=checkpoint_callback,
        early_stop_callback=False,
        train_percent_check=0.02 if debug else 1.0,
        val_percent_check=0.06 if debug else 1.0,
        gpus=cfg.General.gpus,
        use_amp=cfg.General.fp16,
        amp_level=cfg.General.amp_level,
        distributed_backend=cfg.General.multi_gpu_mode,
        log_save_interval=5 if debug else 200,
        accumulate_grad_batches=cfg.General.grad_acc,
        deterministic=True,
    )

    # Lightning module and start training
    model = LightningModuleReg(cfg)
    trainer.fit(model)
Example 12
def main():
    args = parse_args()
    seed_torch(args.seed)
    config = get_config(args.config)
    folds = list(set(pd.read_csv(config['dataset']['df_path'])['fold']))
    if args.ignore_fold is not None:
        folds.remove(args.ignore_fold)
    if args.val_fold:
        avg_metrics = run_fold(config, args, args.val_fold)
    elif args.cv:
        avg_metrics = {}
        best_metrics = []
        best_epochs = {}
        for val_fold in folds:
            train_folds = [i for i in folds if i != val_fold]
            metrics = run_fold(config, args, train_folds, val_fold)
            metrics = {int(k.split("_")[1]) + 1: v for k, v in metrics.items()}
            if len(avg_metrics) == 0:
                avg_metrics = {k: [v] for k, v in metrics.items()}
            else:
                for k, v in metrics.items():
                    avg_metrics[k].append(v)
            best_metrics.append(np.max(list(metrics.values())))
            best_epochs[val_fold] = max(metrics, key=metrics.get)

        avg_metrics = {k: np.average(v) for k, v in avg_metrics.items()}
        avg_metrics['best_avg_epoch'] = max(avg_metrics, key=avg_metrics.get)
        avg_metrics['best'] = np.average(best_metrics)
        avg_metrics['best_epochs'] = best_epochs

        clear_checkpoints(os.path.join(args.logdir, args.model_name),
                          avg_metrics['best_avg_epoch'])
        path = os.path.join(args.logdir, args.model_name, "avg_metrics.json")
        with open(path, 'w') as f:
            json.dump(avg_metrics, f)
    print(avg_metrics)
Example 13
from utils import seed_torch

import torch
seed_torch()

from .base_loss import BaseLoss


class CrossEntropyLoss(BaseLoss):
    def __init__(self, **kwargs):
        assert kwargs['train_type'] == 'classification', \
            "Cross Entropy Loss can only be used for classification."
        super().__init__()
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, outputs, labels):
        return self.criterion(outputs, labels)
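
A minimal usage sketch for the wrapper above, assuming `BaseLoss` subclasses `torch.nn.Module` and its `__init__` takes no required arguments:

criterion = CrossEntropyLoss(train_type='classification')
logits = torch.randn(4, 10)           # batch of 4 samples, 10 classes
targets = torch.tensor([1, 0, 3, 2])  # ground-truth class indices
loss = criterion(logits, targets)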
Example 14
import numpy as np
from data_prepare import data_manager
from spo_model import SPOModel, EntityLink_entity_vector
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from tokenize_pkg.tokenize import Tokenizer
from tqdm import tqdm as tqdm
import torch.nn as nn
from utils import seed_torch, read_data, load_glove, calc_f1, get_threshold
from pytorch_pretrained_bert import BertTokenizer,BertAdam
import logging
import time
from torch.nn import functional as F
from sklearn.model_selection import KFold, train_test_split
from sklearn.externals import joblib

file_name = 'data/raw_data/train.json'
train_part, valid_part = data_manager.read_entity_embedding(file_name=file_name, train_num=10000000)
seed_torch(2019)
print('train size %d, valid size %d' % (len(train_part), len(valid_part)))


data_all = np.array(train_part)
kfold = KFold(n_splits=5, shuffle=False, random_state=2019)
pred_vector = []
round = 0
for train_index, test_index in kfold.split(np.zeros(len(train_part))):
    train_part = data_all[train_index]
    valid_part = data_all[test_index]

    BERT_MODEL = 'bert-base-chinese'
    CASED = False
    t = BertTokenizer.from_pretrained(
        BERT_MODEL,
Example 15
    k_fold(features, labels, net, config)


def train(config):
    net = LinearNet(config)
    features, labels = load_data_for_linear(config)
    dataset = Data.TensorDataset(features, labels)
    data_iter = Data.DataLoader(dataset, shuffle=True)
    loss = torch.nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    for epoch in range(config.num_epochs):
        for X, y in data_iter:
            l = loss(net(X), y.view(-1, 1))
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
        print('{}/{} epoch, loss {}'.format(epoch + 1, config.num_epochs, l.item()))


if __name__ == '__main__':
    from config import config as conf
    seed_torch(conf.random_seed)
    train(conf)
    # train_linear(conf)
    # net = LinearNet(config=conf)
    # features, labels = load_data_for_linear(conf)
    # k_fold(features, labels, net, conf)
    # for param in net.parameters():
    #     torch.nn.init.normal_(param, mean=0, std=0.01)
    # train(net, myconfig)
Example 16
def main(args):

    #denoiser = VQ_CVAE(128, k=512, num_channels=3)
    #denoiser.load_state_dict(torch.load("/mnt/home2/dlongo/eegML/VQ-VAE-master/vq_vae/saved_models/train.pt"))
    #denoiser = torch.no_grad(denoiser)
    #denoiser.cuda()
    #denoiser = nn.DataParallel(denoiser)
    # Get device
    device, args.gpu_ids = utils.get_available_devices()
    args.train_batch_size *= max(1, len(args.gpu_ids))
    args.test_batch_size *= max(1, len(args.gpu_ids))

    # Set random seed
    utils.seed_torch(seed=SEED)

    # Get save directories
    train_save_dir = utils.get_save_dir(args.save_dir, training=True)
    args.train_save_dir = train_save_dir

    # Save args
    args_file = os.path.join(train_save_dir, ARGS_FILE_NAME)
    with open(args_file, 'w') as f:
        json.dump(vars(args), f, indent=4, sort_keys=True)

    # Set up logging and devices
    log = utils.get_logger(train_save_dir, 'train_denoised')
    tbx = SummaryWriter(train_save_dir)
    log.info('Args: {}'.format(dumps(vars(args), indent=4, sort_keys=True)))

    if args.cross_val:
        # Loop over folds
        for fold_idx in range(args.num_folds):
            log.info('Starting fold {}...'.format(fold_idx))

            # Train
            fold_save_dir = os.path.join(train_save_dir,
                                         'fold_' + str(fold_idx))
            if not os.path.exists(fold_save_dir):
                os.makedirs(fold_save_dir)

            # Training on current fold...
            train_fold(args,
                       device,
                       fold_save_dir,
                       log,
                       tbx,
                       cross_val=True,
                       fold_idx=fold_idx)
            best_path = os.path.join(fold_save_dir, 'best.pth.tar')

            # Predict on current fold with best model..
            if args.model_name == 'SeizureNet':
                model = SeizureNet(args)

            model = nn.DataParallel(model, args.gpu_ids)
            model, _ = utils.load_model(model, best_path, args.gpu_ids)

            model.to(device)
            results = evaluate_fold(model,
                                    args,
                                    fold_save_dir,
                                    device,
                                    cross_val=True,
                                    fold_idx=fold_idx,
                                    is_test=True,
                                    write_outputs=True)

            # Log to console
            results_str = ', '.join('{}: {:05.2f}'.format(k, v)
                                    for k, v in results.items())
            print('Fold {} test results: {}'.format(fold_idx, results_str))
            log.info('Finished fold {}...'.format(fold_idx))
    else:
        # no cross-validation
        # Train
        train_fold(args, device, train_save_dir, log, tbx, cross_val=False)
        best_path = os.path.join(train_save_dir, 'best.pth.tar')

        if args.model_name == 'SeizureNet':
            model = SeizureNet(args)

        model = nn.DataParallel(model, args.gpu_ids)
        model, _ = utils.load_model(model, best_path, args.gpu_ids)

        model.to(device)
        results = evaluate_fold(model,
                                args,
                                train_save_dir,
                                device,
                                cross_val=False,
                                fold_idx=None,
                                is_test=True,
                                write_outputs=True)

        # Log to console
        results_str = ', '.join('{}: {:05.2f}'.format(k, v)
                                for k, v in results.items())
        print('Test set prediction results: {}'.format(results_str))
Example 17
    print('Loading pre-trained models ...')
    model1, model2, model3, model4, model5, config = load_checkpoint(self_pretrained_model_path)
    config.data_root = '/home/cfang/works/COVID-19/PMP/data/'
    
    # finetune config
    config.finetune_repeat = 2
    config.finetune_epoch = 100
    config.batch_size = 20
    config.few_shot_num = 10
    config.init_tau = 5
    config.tau_lr = 0.1
    config.lr = 0.01
    config.cos_loss_margin = 0.2
    config.cos_loss_gamma = 0.5
    config.seed = 8888
    seed_torch(seed=config.seed)

    # load clinical data
    config.data = config.data_root + 'data3'
    raw_data3 = load_data(config, 'train')

    # for test_data_flag in range(1,-1,-2): # training on dataset2
    for test_data_flag in range(1): # training on dataset3
        for finetune_repeat in range(config.finetune_repeat):
            check_dir = time.strftime('%Y-%m-%d %H:%M:%S')
            check_dir = 'DA_on_data3'+'_'+check_dir+'_repeat_'+str(finetune_repeat)+'_few_shot_'+str(config.few_shot_num)
            os.mkdir(os.path.join('checkpoints', check_dir))
            os.mkdir(os.path.join('res', check_dir))

            # get training and test data index
            if test_data_flag==0:
Example 18
                'lr_scheduler': scheduler.state_dict(),
                'results': results,
                'train_params': train_params,
                'best_acc': best_acc
            }, save_dir, is_best, save_epoch)


if __name__ == "__main__":
    import imp, argparse, shutil, copy
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default="configs/config_bnn.py")
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args()

    #### Load config file and set random seed
    seed_torch(seed=int(args.seed))
    cfg_fname = args.config.split('/')[-1]
    cfg = imp.load_source("configs", args.config)
    cfg = cfg.load_config()

    #### Use GPU or CPU
    if torch.cuda.is_available():
        device_ids = cfg["device_ids"]
        if device_ids is not None:
            cuda_idx = np.min(device_ids)
        else:
            cuda_idx = int(cfg["cuda"])
        device = torch.device("cuda:{}".format(cuda_idx))
        print "Device IDs in data parallel:", device_ids
    else:
        device = torch.device("cpu")
Example 19
def main():
    args = parse_args()

    # set random seed
    utils.seed_torch(args.seed)

    # Setup CUDA, GPU
    if not torch.cuda.is_available():
        print("cuda is not available")
        exit(0)
    else:
        args.device = torch.device("cuda")
        args.n_gpus = torch.cuda.device_count()
        print(f"available cuda: {args.n_gpus}")

    # Setup model
    model = MelanomaNet(arch=args.arch)
    if args.n_gpus > 1:
        model = torch.nn.DataParallel(module=model)
    model.to(args.device)
    model_path = f'{configure.MODEL_PATH}/{args.arch}_fold_{args.fold}.pth'

    # Setup data
    total_batch_size = args.per_gpu_batch_size * args.n_gpus
    train_loader, valid_loader = datasets.get_dataloader(
        image_dir=configure.TRAIN_IMAGE_PATH,
        fold=args.fold,
        batch_size=total_batch_size,
        num_workers=args.num_workers)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.BCEWithLogitsLoss()
    # criterion = MarginFocalBCEWithLogitsLoss()
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=args.learning_rate,
                                  weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=5,
                                                gamma=0.5)
    """ Train the model """
    current_time = datetime.now().strftime('%b%d_%H_%M_%S')
    log_dir = f'{configure.TRAINING_LOG_PATH}/{args.arch}_fold_{args.fold}_{current_time}'

    tb_writer = None
    if args.log:
        tb_writer = SummaryWriter(log_dir=log_dir)

    print(f'training started: {current_time}')
    best_score = 0.0
    for epoch in range(args.epochs):
        train_loss = train(dataloader=train_loader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           args=args)

        valid_loss, y_true, y_score = valid(dataloader=valid_loader,
                                            model=model,
                                            criterion=criterion,
                                            args=args)

        valid_score = roc_auc_score(y_true=y_true, y_score=y_score)

        learning_rate = scheduler.get_lr()[0]
        if args.log:
            tb_writer.add_scalar("learning_rate", learning_rate, epoch)
            tb_writer.add_scalar("Loss/train", train_loss, epoch)
            tb_writer.add_scalar("Loss/valid", valid_loss, epoch)
            tb_writer.add_scalar("Score/valid", valid_score, epoch)

            # Log the roc curve as an image summary.
            figure = utils.plot_roc_curve(y_true=y_true, y_score=y_score)
            figure = utils.plot_to_image(figure)
            tb_writer.add_image("ROC curve", figure, epoch)

        if valid_score > best_score:
            best_score = valid_score
            state = {
                'state_dict': model.module.state_dict(),
                'train_loss': train_loss,
                'valid_loss': valid_loss,
                'valid_score': valid_score
            }
            torch.save(state, model_path)

        current_time = datetime.now().strftime('%b%d_%H_%M_%S')
        print(
            f"epoch:{epoch:02d}, "
            f"train:{train_loss:0.3f}, valid:{valid_loss:0.3f}, "
            f"score:{valid_score:0.3f}, best:{best_score:0.3f}, date:{current_time}"
        )

        scheduler.step()

    current_time = datetime.now().strftime('%b%d_%H_%M_%S')
    print(f'training finished: {current_time}')

    if args.log:
        tb_writer.close()
Example 20
# Sampling mode, e.g random sampling
SAMPLING_MODE = args.sampling_mode
# Pre-computed weights to restore
CHECKPOINT = args.restore
# Learning rate for the SGD
LEARNING_RATE = args.lr
# Automated class balancing
CLASS_BALANCING = args.class_balancing
# Training ground truth file
TRAIN_GT = args.train_set
# Testing ground truth file
TEST_GT = args.test_set
TEST_STRIDE = args.test_stride

# set random seed
seed_torch(seed=args.seed)

if args.download is not None and len(args.download) > 0:
    for dataset in args.download:
        get_dataset(dataset, target_folder=FOLDER)
    quit()

viz = visdom.Visdom(env=DATASET + " " + MODEL)
if not viz.check_connection():
    print("Visdom is not connected. Did you run 'python -m visdom.server' ?")

hyperparams = vars(args)
# Load the dataset
img1, img2, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, palette = get_dataset(
    DATASET, FOLDER)
Example 21
def main():
    args = parse_args()
    seed_torch(seed=42)

    model_save_path = os.path.join(SAVE_MODEL_PATH, args.model)
    training_history_path = os.path.join(TRAINING_HISTORY_PATH, args.model)

    df_train_path = os.path.join(SPLIT_FOLDER,
                                 "fold_{}_train.csv".format(args.fold))
    df_train = pd.read_csv(df_train_path)

    df_valid_path = os.path.join(SPLIT_FOLDER,
                                 "fold_{}_valid.csv".format(args.fold))
    df_valid = pd.read_csv(df_valid_path)

    if args.model in ["UResNet34", "USeResNext50"]:
        df_train = df_train.loc[(df_train["defect1"] != 0) |
                                (df_train["defect2"] != 0) |
                                (df_train["defect3"] != 0) |
                                (df_train["defect4"] != 0)]
        df_valid = df_valid.loc[(df_valid["defect1"] != 0) |
                                (df_valid["defect2"] != 0) |
                                (df_valid["defect3"] != 0) |
                                (df_valid["defect4"] != 0)]

    print(
        "Training on {} images, class 1: {}, class 2: {}, class 3: {}, class 4: {}"
        .format(len(df_train), df_train['defect1'].sum(),
                df_train['defect2'].sum(), df_train['defect3'].sum(),
                df_train['defect4'].sum()))
    print(
        "Validate on {} images, class 1: {}, class 2: {}, class 3: {}, class 4: {}"
        .format(len(df_valid), df_valid['defect1'].sum(),
                df_valid['defect2'].sum(), df_valid['defect3'].sum(),
                df_valid['defect4'].sum()))
    model_trainer, best = None, None
    if args.model == "UResNet34":
        model_trainer = TrainerSegmentation(
            model=UResNet34(),
            num_workers=args.num_workers,
            batch_size=args.batch_size,
            num_epochs=200,
            model_save_path=model_save_path,
            training_history_path=training_history_path,
            model_save_name=args.model,
            fold=args.fold)

    elif args.model == "ResNet34":
        model_trainer = TrainerClassification(
            model=ResNet34(),
            num_workers=args.num_workers,
            batch_size=args.batch_size,
            num_epochs=100,
            model_save_path=model_save_path,
            training_history_path=training_history_path,
            model_save_name=args.model,
            fold=args.fold)

    best = model_trainer.start()

    print("Training is done, best: {}".format(best))
Example 22
import numpy as np
import torch, os
from data_prepare import data_manager, read_kb
from spo_dataset import QAPair, get_mask, qapair_collate_fn
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from tokenize_pkg.tokenize import Tokenizer
from tqdm import tqdm as tqdm
import torch.nn as nn
from utils import seed_torch, read_data, load_glove, calc_f1, get_threshold
from pytorch_pretrained_bert import BertTokenizer, BertAdam
from sklearn.model_selection import KFold
from sklearn.externals import joblib

file_name = 'data/raw_data/train.json'
data_all = data_manager.read_deep_match('data/final.pkl')
seed_torch(2020)

t = Tokenizer(max_feature=10000, segment=False, lowercase=True)
corpus = list(data_all['question']) + list(
    data_all['answer']) + data_manager.read_eval()
t.fit(corpus)
joblib.dump(t, 'data/deep_match_tokenizer.pkl')

# Prepare the embedding data
embedding_file = 'embedding/miniembedding_baike_deepmatch.npy'
# embedding_file = 'embedding/miniembedding_engineer_qq_att.npy'

if os.path.exists(embedding_file):
    embedding_matrix = np.load(embedding_file)
else:
    embedding = '/home/zhukaihua/Desktop/nlp/embedding/baike'
Example 23
import torch
from torch.utils import data
from dataset_brain import Dataset_gan
from model import netD, Unet, define_G
from utils import label2onehot, classification_loss, gradient_penalty, seed_torch
from loss import dice_loss, dice_score
import numpy as np
import time
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt

seed_torch(10)

print('*******************test_gan*******************')
file_path = './brats18_dataset/npy_test/test_t1.npy'
model_path = './weight/generator_t2_tumor_bw.pth'
train_data = Dataset_gan(file_path)
length = len(train_data)
print('len', length)
batch_size = 64
train_loader = data.DataLoader(dataset=train_data,
                               batch_size=batch_size,
                               num_workers=4)
gen = define_G(
    4,
    1,
    64,
    'unet_128',
    norm='instance',
)
Example 24
            optimizer.step()

        print('Val on source data ...')
        best_model_state_dict, best_acc, best_sen, best_val_epoch, ret_dict = val(
            model, val_dataloader, config, epoch, best_acc, best_sen,
            best_val_epoch, best_model_state_dict, ret_dict, 'val', num_fold)
        print('Accuracy %.3f sensitivity %.3f @ epoch %d' %
              (best_acc, best_sen, best_val_epoch + 1))
        scheduler_StepLR.step()
        print('Current lr: ', optimizer.param_groups[0]['lr'])

    return model, ret_dict


if __name__ == "__main__":
    seed_torch(seed=7777)
    remove_items = []
    config = edict()
    # path setting
    config.data_root = ''
    config.pretrained_model_path = ''

    # training setting
    config.lr = 0.05  # 1e-3
    config.fc_bias = False
    config.clinica_feat_dim = 61
    config.CT_feat_dim = 128
    config.lstm_indim = config.clinica_feat_dim + config.CT_feat_dim
    config.hidden_dim = config.lstm_indim * 2
    config.num_classes = 2
    config.seq_len = 7
Example 25
]
N_CLASSES = 18

# ===============
# Settings
# ===============
SEED = np.random.randint(100000)
device = "cuda"
img_size = 512
batch_size = 32
epochs = 5
EXP_ID = "exp34_seres"
model_path = None

setup_logger(out_file=LOGGER_PATH)
seed_torch(SEED)
LOGGER.info("seed={}".format(SEED))


@contextmanager
def timer(name):
    t0 = time.time()
    yield
    LOGGER.info('[{}] done in {} s'.format(name, round(time.time() - t0, 2)))


def main():
    with timer('load data'):
        df = pd.read_csv(TRAIN_PATH)
        df = df[df.Image != "ID_6431af929"].reset_index(drop=True)
        df.loc[df.pre_SOPInstanceUID == "ID_6431af929",
Example 26
def main():
    args = parse_args()

    # set random seed
    utils.seed_torch(args.seed)

    # Setup CUDA, GPU
    if not torch.cuda.is_available():
        print("cuda is not available")
        exit(0)
    else:
        args.device = torch.device("cuda")
        args.n_gpus = torch.cuda.device_count()
        print(f"available cuda: {args.n_gpus}")

    # Setup model
    model = PandaNet(arch=args.arch, num_classes=1)
    model_path = os.path.join(
        configure.MODEL_PATH,
        f'{args.arch}_fold_{args.fold}_{args.tile_size}_{args.num_tiles}.pth')
    if args.resume:
        assert os.path.exists(model_path), "checkpoint does not exist"
        state_dict = torch.load(model_path)
        valid_score = state_dict['valid_score']
        threshold = state_dict['threshold']
        print(
            f"load model from checkpoint, threshold: {threshold}, valid score: {state_dict['valid_score']:0.3f}"
        )
        model.load_state_dict(state_dict['state_dict'])
        best_score = valid_score
        args.learning_rate = 3e-05
    else:
        best_score = 0.0

    if args.n_gpus > 1:
        model = torch.nn.DataParallel(module=model)
    model.to(args.device)

    # Setup data
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    print(f"loading data: {current_time}")
    filename = f"train_images_level_{args.level}_{args.tile_size}_{args.num_tiles}.npy"
    data = np.load(os.path.join(configure.DATA_PATH, filename),
                   allow_pickle=True)
    print(f"data loaded: {datetime.now().strftime('%b%d_%H-%M-%S')}")

    total_batch_size = args.per_gpu_batch_size * args.n_gpus
    train_loader, valid_loader = datasets.get_dataloader(
        data=data,
        fold=args.fold,
        batch_size=total_batch_size,
        num_workers=args.num_workers)

    # define loss function (criterion) and optimizer
    if args.loss == "l1":
        criterion = torch.nn.L1Loss()
    elif args.loss == "mse":
        criterion = torch.nn.MSELoss()
    elif args.loss == "smooth_l1":
        criterion = torch.nn.SmoothL1Loss()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=15,
                                                gamma=0.5)
    """ Train the model """
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    log_prefix = f'{current_time}_{args.arch}_fold_{args.fold}_{args.tile_size}_{args.num_tiles}'
    log_dir = os.path.join(configure.TRAINING_LOG_PATH, log_prefix)

    tb_writer = None
    if args.log:
        tb_writer = SummaryWriter(log_dir=log_dir)

    print(f'training started: {current_time}')
    for epoch in range(args.epochs):
        train_loss = train(dataloader=train_loader,
                           model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           args=args)

        valid_loss, valid_score, valid_cm, threshold = valid(
            dataloader=valid_loader,
            model=model,
            criterion=criterion,
            args=args)

        learning_rate = scheduler.get_lr()[0]
        if args.log:
            tb_writer.add_scalar("learning_rate", learning_rate, epoch)
            tb_writer.add_scalar("Loss/train", train_loss, epoch)
            tb_writer.add_scalar("Loss/valid", valid_loss, epoch)
            tb_writer.add_scalar("Score/valid", valid_score, epoch)

            # Log the confusion matrix as an image summary.
            figure = utils.plot_confusion_matrix(
                valid_cm, class_names=[0, 1, 2, 3, 4, 5], score=valid_score)
            cm_image = utils.plot_to_image(figure)
            tb_writer.add_image("Confusion Matrix valid", cm_image, epoch)

        if valid_score > best_score:
            best_score = valid_score
            state = {
                'state_dict': model.module.state_dict(),
                'train_loss': train_loss,
                'valid_loss': valid_loss,
                'valid_score': valid_score,
                'threshold': np.sort(threshold),
                'mean': data.item().get('mean'),
                'std': data.item().get('std')
            }
            torch.save(state, model_path)

        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        print(
            f"epoch:{epoch:02d}, "
            f"train:{train_loss:0.3f}, valid:{valid_loss:0.3f}, "
            f"threshold: {np.sort(threshold)}, "
            f"score:{valid_score:0.3f}, best:{best_score:0.3f}, date:{current_time}"
        )

        scheduler.step()

    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    print(f'training finished: {current_time}')

    if args.log:
        tb_writer.close()
Example 27
def main():
    seed_torch(args.randseed)

    logger.info('Starting training with arguments')
    logger.info(vars(args))

    save_path = args.save
    save_pseudo_label_path = osp.join(
        save_path,
        'pseudo_label')  # in 'save_path'. Save labelIDs, not trainIDs.
    save_stats_path = osp.join(save_path, 'stats')  # in 'save_path'
    save_lst_path = osp.join(save_path, 'list')
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if not os.path.exists(save_pseudo_label_path):
        os.makedirs(save_pseudo_label_path)
    if not os.path.exists(save_stats_path):
        os.makedirs(save_stats_path)
    if not os.path.exists(save_lst_path):
        os.makedirs(save_lst_path)

    tgt_portion = args.init_tgt_port
    image_tgt_list, image_name_tgt_list, _, _ = parse_split_list(
        args.data_tgt_train_list.format(args.city))

    model = make_network(args).to(device)
    test(model, -1)
    for round_idx in range(args.num_rounds):
        save_round_eval_path = osp.join(args.save, str(round_idx))
        save_pseudo_label_color_path = osp.join(
            save_round_eval_path,
            'pseudo_label_color')  # in every 'save_round_eval_path'
        if not os.path.exists(save_round_eval_path):
            os.makedirs(save_round_eval_path)
        if not os.path.exists(save_pseudo_label_color_path):
            os.makedirs(save_pseudo_label_color_path)
        src_portion = args.init_src_port
        ########## pseudo-label generation
        conf_dict, pred_cls_num, save_prob_path, save_pred_path = validate_model(
            model, save_round_eval_path, round_idx, args)
        cls_thresh = label_selection.kc_parameters(conf_dict, pred_cls_num,
                                                   tgt_portion, round_idx,
                                                   save_stats_path, args)

        label_selection.label_selection(cls_thresh, round_idx, save_prob_path,
                                        save_pred_path, save_pseudo_label_path,
                                        save_pseudo_label_color_path,
                                        save_round_eval_path, args)

        tgt_portion = min(tgt_portion + args.tgt_port_step, args.max_tgt_port)
        tgt_train_lst = savelst_tgt(image_tgt_list, image_name_tgt_list,
                                    save_lst_path, save_pseudo_label_path)

        rare_id = np.load(save_stats_path + '/rare_id_round' + str(round_idx) +
                          '.npy')
        mine_id = np.load(save_stats_path + '/mine_id_round' + str(round_idx) +
                          '.npy')
        # mine_chance = args.mine_chance

        src_transforms, tgt_transforms = get_train_transforms(args, mine_id)
        srcds = CityscapesDataset(transforms=src_transforms)

        tgtds = CrossCityDataset(args.data_tgt_dir.format(args.city),
                                 tgt_train_lst,
                                 pseudo_root=save_pseudo_label_path,
                                 transforms=tgt_transforms)

        if args.no_src_data:
            mixtrainset = tgtds
        else:
            mixtrainset = torch.utils.data.ConcatDataset([srcds, tgtds])

        mix_loader = DataLoader(mixtrainset,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=args.batch_size,
                                pin_memory=torch.cuda.is_available())
        src_portion = min(src_portion + args.src_port_step, args.max_src_port)
        optimizer = optim.SGD(model.optim_parameters(args),
                              lr=args.learning_rate,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
        interp = nn.Upsample(size=args.input_size[::-1],
                             mode='bilinear',
                             align_corners=True)
        torch.backends.cudnn.enabled = True  # enable cudnn
        torch.backends.cudnn.benchmark = True
        start = time.time()
        for epoch in range(args.epr):
            train(mix_loader, model, interp, optimizer, args)
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(
                    args.save, '2nthy_round' + str(round_idx) + '_epoch' +
                    str(epoch) + '.pth'))
        end = time.time()

        logger.info(
            '###### Finish model retraining dataset in round {}! Time cost: {:.2f} seconds. ######'
            .format(round_idx, end - start))
        test(model, round_idx)
        cleanup(args.save)
    cleanup(args.save)
    shutil.rmtree(save_pseudo_label_path)
    test(model, args.num_rounds - 1)
Example 28
model_name = options.model
run_name = f'{model_name}_{now:%Y%m%d%H%M%S}'

compe_params = config.compe
data_params = config.data
train_params = config.train_params
setting_params = config.settings
notify_params = dp.load(options.notify)

logger_path = Path(f'../logs/{run_name}')

# ===============
# Main
# ===============
t = Timer()
seed_torch(compe_params.seed)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

logger_path.mkdir()
logging.basicConfig(filename=logger_path / 'train.log', level=logging.DEBUG)

dp.save(logger_path / 'config.yml', config)

with t.timer('load data'):
    root = Path(data_params.input_root)
    train_df = dp.load(root / data_params.train_file)

    train_img_df = pd.DataFrame()
    for file in data_params.img_file:
        df = dp.load(root / file)
        train_img_df = pd.concat([train_img_df, df],
Example 29
def run():
    seed_torch(seed=config.seed)
    os.makedirs(config.MODEL_PATH, exist_ok=True)
    setup_logger(config.MODEL_PATH + 'log.txt')
    writer = SummaryWriter(config.MODEL_PATH)

    folds = pd.read_csv(config.fold_csv)
    folds.head()
    if config.tile_stats_csv:
        attention_df = pd.read_csv(config.tile_stats_csv)
        attention_df.head()
    else:
        attention_df = None  # the datasets below always receive attention_df; avoid a NameError when no tile stats file is given

    #train val split
    if config.DEBUG:
        folds = folds.sample(
            n=50, random_state=config.seed).reset_index(drop=True).copy()

    logging.info(f"fold: {config.fold}")
    fold = config.fold
    #trn_idx = folds[folds['fold'] != fold].index
    #val_idx = folds[folds['fold'] == fold].index
    trn_idx = folds[folds[f'fold_{fold}'] == 0].index
    val_idx = folds[folds[f'fold_{fold}'] == 1].index

    df_train = folds.loc[trn_idx]
    df_val = folds.loc[val_idx]
    # #------single image------
    if config.strategy == 'stitched':
        train_dataset = PANDADataset(image_folder=config.DATA_PATH,
                                     df=df_train,
                                     image_size=config.IMG_SIZE,
                                     num_tiles=config.num_tiles,
                                     rand=False,
                                     transform=get_transforms(phase='train'),
                                     attention_df=attention_df)
        valid_dataset = PANDADataset(image_folder=config.DATA_PATH,
                                     df=df_val,
                                     image_size=config.IMG_SIZE,
                                     num_tiles=config.num_tiles,
                                     rand=False,
                                     transform=get_transforms(phase='valid'),
                                     attention_df=attention_df)

    #------image tiles------
    else:
        train_dataset = PANDADatasetTiles(
            image_folder=config.DATA_PATH,
            df=df_train,
            image_size=config.IMG_SIZE,
            num_tiles=config.num_tiles,
            transform=get_transforms(phase='train'),
            attention_df=attention_df)
        valid_dataset = PANDADatasetTiles(
            image_folder=config.DATA_PATH,
            df=df_val,
            image_size=config.IMG_SIZE,
            num_tiles=config.num_tiles,
            transform=get_transforms(phase='valid'),
            attention_df=attention_df)

    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              sampler=RandomSampler(train_dataset),
                              num_workers=multiprocessing.cpu_count(),
                              pin_memory=True)
    val_loader = DataLoader(valid_dataset,
                            batch_size=config.batch_size,
                            sampler=SequentialSampler(valid_dataset),
                            num_workers=multiprocessing.cpu_count(),
                            pin_memory=True)

    device = torch.device("cuda")
    #model=EnetNetVLAD(num_clusters=config.num_cluster,num_tiles=config.num_tiles,num_classes=config.num_class,arch=config.backbone)
    #model = EnetV1(backbone=config.backbone, num_classes=config.num_class)
    #------Model use for Generate Tile Weights--------
    #model = EfficientModel(c_out=config.num_class,n_tiles=config.num_tiles,
    #                       tile_size=config.IMG_SIZE,
    #                       name=config.backbone,
    #                       strategy='bag',
    #                       head='attention')
    #--------------------------------------------------
    #model = Regnet(num_classes=config.num_class,ckpt=config.pretrain_model)
    model = RegnetNetVLAD(num_clusters=config.num_cluster,
                          num_tiles=config.num_tiles,
                          num_classes=config.num_class,
                          ckpt=config.pretrain_model)
    model = model.to(device)
    if config.multi_gpu:
        model = torch.nn.DataParallel(model)
    if config.ckpt_path:
        model.load_state_dict(torch.load(config.ckpt_path))
    warmup_factor = 10
    warmup_epo = 1
    optimizer = Adam(model.parameters(), lr=config.lr / warmup_factor)
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, config.num_epoch - warmup_epo)
    scheduler = GradualWarmupScheduler(optimizer,
                                       multiplier=warmup_factor,
                                       total_epoch=warmup_epo,
                                       after_scheduler=scheduler_cosine)

    if config.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)

    best_score = 0.
    best_loss = 100.
    if config.model_type == 'reg':
        optimized_rounder = OptimizedRounder()
    optimizer.zero_grad()
    optimizer.step()
    for epoch in range(1, config.num_epoch + 1):
        if scheduler:
            scheduler.step(epoch - 1)
        if config.model_type != 'reg':
            train_fn(train_loader, model, optimizer, device, epoch, writer,
                     df_train)
            metric = eval_fn(val_loader, model, device, epoch, writer, df_val)
        else:
            coefficients = train_fn(train_loader, model, optimizer, device,
                                    epoch, writer, df_train, optimized_rounder)
            metric = eval_fn(val_loader, model, device, epoch, writer, df_val,
                             coefficients)
        score = metric['score']
        val_loss = metric['loss']
        if score > best_score:
            best_score = score
            logging.info(f"Epoch {epoch} - found best score {best_score}")
            save_model(model, config.MODEL_PATH + f"best_kappa_f{fold}.pth")
        if val_loss < best_loss:
            best_loss = val_loss
            logging.info(f"Epoch {epoch} - found best loss {best_loss}")
            save_model(model, config.MODEL_PATH + f"best_loss_f{fold}.pth")