Example #1
    def test_mnist(self):
        utils.set_global_seed(42)
        x_train = np.random.random((100, 1, 28, 28)).astype(np.float32)
        y_train = _to_categorical(np.random.randint(10, size=(100, 1)),
                                  num_classes=10).astype(np.float32)
        x_valid = np.random.random((20, 1, 28, 28)).astype(np.float32)
        y_valid = _to_categorical(np.random.randint(10, size=(20, 1)),
                                  num_classes=10).astype(np.float32)

        x_train, y_train, x_valid, y_valid = \
            list(map(torch.tensor, [x_train, y_train, x_valid, y_valid]))

        bs = 32
        num_workers = 4

        loaders = collections.OrderedDict()

        trainset = torch.utils.data.TensorDataset(x_train, y_train)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=bs,
                                                  shuffle=True,
                                                  num_workers=num_workers)

        validset = torch.utils.data.TensorDataset(x_valid, y_valid)
        validloader = torch.utils.data.DataLoader(validset,
                                                  batch_size=bs,
                                                  shuffle=False,
                                                  num_workers=num_workers)

        loaders["train"] = trainloader
        loaders["valid"] = validloader

        # experiment setup
        num_epochs = 3
        logdir = "./logs"

        # model, criterion, optimizer
        model = Net()
        criterion = nn.BCEWithLogitsLoss()
        optimizer = torch.optim.Adam(model.parameters())

        # model runner
        runner = SupervisedRunner()

        # model training
        runner.train(model=model,
                     criterion=criterion,
                     optimizer=optimizer,
                     loaders=loaders,
                     logdir=logdir,
                     num_epochs=num_epochs,
                     verbose=False,
                     callbacks=[CheckpointCallback(save_n_best=3)])

        with open('./logs/checkpoints/_metrics.json') as f:
            metrics = json.load(f)
            self.assertTrue(
                metrics['train.3']['loss'] < metrics['train.1']['loss'])
            self.assertTrue(metrics['best']['loss'] < 0.35)
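For reference, the _metrics.json file read above has roughly this shape once parsed (a sketch inferred only from the keys the test touches; exact fields depend on the Catalyst version, and the loss values are illustrative):

metrics = {
    "train.1": {"loss": 0.61},
    "train.2": {"loss": 0.43},
    "train.3": {"loss": 0.30},
    "best": {"loss": 0.30},
}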
Example #2
def get_callbacks(config: dict):
    required_callbacks = config["callbacks"]
    callbacks = []
    for callback_conf in required_callbacks:
        name = callback_conf["name"]
        params = callback_conf["params"]
        callback_cls = globals().get(name)

        if callback_cls is not None:
            callbacks.append(callback_cls(**params))

    callbacks.append(CheckpointCallback(save_n_best=0))
    return callbacks
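get_callbacks looks each name up in the module's global namespace and silently skips names it cannot resolve, so the expected config is a list of name/params pairs along these lines (a minimal sketch; both callback names are hypothetical placeholders for classes visible in that module):

config = {
    "callbacks": [
        {"name": "EarlyStoppingCallback", "params": {"patience": 5}},
        {"name": "AccuracyCallback", "params": {"num_classes": 10}},
    ]
}
callbacks = get_callbacks(config)  # ends with CheckpointCallback(save_n_best=0) appended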
Example #3
    def predict_loader(
        self,
        model: Model,
        loader: DataLoader,
        resume: str = None,
        verbose: bool = False,
        state_kwargs: Dict = None,
        fp16: Union[Dict, bool] = None,
        check: bool = False,
    ) -> Any:
        """
        Makes a prediction on the whole loader with the specified model.

        Args:
            model (Model): model to infer
            loader (DataLoader): a single
                ``torch.utils.data.DataLoader`` used for inference
            resume (str): path to checkpoint for model
            verbose (bool): if True, displays the inference progress
                in the console.
            state_kwargs (dict): additional state params to ``State``
            fp16 (Union[Dict, bool]): if not None, runs inference in FP16.
                See https://nvidia.github.io/apex/amp.html#properties
                If ``fp16=True``, the params default to ``{"opt_level": "O1"}``.
            check (bool): if True, only checks that the pipeline works
                (3 epochs only)

        Returns:
            predictions gathered by ``InferCallback``; if ``self.output_key``
            is a string, only that key of the predictions is returned
        """
        loaders = OrderedDict([("infer", loader)])

        callbacks = OrderedDict([("inference", InferCallback())])
        if resume is not None:
            callbacks["loader"] = CheckpointCallback(resume=resume)

        self.infer(
            model=model,
            loaders=loaders,
            callbacks=callbacks,
            verbose=verbose,
            state_kwargs=state_kwargs,
            fp16=fp16,
            check=check,
        )

        output = callbacks["inference"].predictions
        if isinstance(self.output_key, str):
            output = output[self.output_key]

        return output
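A minimal usage sketch for predict_loader (model, dataset and checkpoint path are hypothetical; resume restores the weights through a CheckpointCallback before inference, as the method above shows):

infer_loader = DataLoader(my_dataset, batch_size=32, shuffle=False)
predictions = runner.predict_loader(
    model=model,
    loader=infer_loader,
    resume="./logs/checkpoints/best.pth",
    verbose=True,
)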
Example #4
    def train(
        self,
        model: Model,
        criterion: Criterion,
        optimizer: Optimizer,
        loaders: "OrderedDict[str, DataLoader]",
        logdir: str,
        callbacks: "Union[List[Callback], OrderedDict[str, Callback]]" = None,
        scheduler: Scheduler = None,
        resume: str = None,
        num_epochs: int = 1,
        valid_loader: str = "valid",
        main_metric: str = "loss",
        minimize_metric: bool = True,
        verbose: bool = False,
        state_kwargs: Dict = None,
        checkpoint_data: Dict = None,
        fp16: Union[Dict, bool] = None,
        monitoring_params: Dict = None,
        check: bool = False,
    ) -> None:
        """
        Starts the training process of the model.

        Args:
            model (Model): model to train
            criterion (Criterion): criterion function for training
            optimizer (Optimizer): optimizer for training
            loaders (dict): dictionary containing one or several
                ``torch.utils.data.DataLoader`` for training and validation
            logdir (str): path to output directory
            callbacks (List[catalyst.dl.Callback]): list of callbacks
            scheduler (Scheduler): scheduler for training
            resume (str): path to checkpoint for model
            num_epochs (int): number of training epochs
            valid_loader (str): loader name used to calculate
                the metrics and save the checkpoints. For example,
                you can pass ``train`` and then
                the metrics will be taken from the ``train`` loader.
            main_metric (str): name of the metric used to select
                the checkpoints.
            minimize_metric (bool): flag indicating whether
                ``main_metric`` should be minimized.
            verbose (bool): if True, displays the training progress
                in the console.
            state_kwargs (dict): additional state params to ``State``
            checkpoint_data (dict): additional data to save in checkpoint,
                for example: ``class_names``, ``date_of_training``, etc
            fp16 (Union[Dict, bool]): if not None, runs training in FP16.
                See https://nvidia.github.io/apex/amp.html#properties
                If ``fp16=True``, the params default to ``{"opt_level": "O1"}``.
            monitoring_params (dict): If not None, then create monitoring
                through Alchemy or Weights&Biases.
                For example,
                ``{"token": "api_token", "experiment": "experiment_name"}``
            check (bool): if True, only checks that the pipeline works
                (3 epochs only)
        """
        if len(loaders) == 1:
            valid_loader = list(loaders.keys())[0]
            logger.warning(
                "Attention, there is only one data loader - "
                + str(valid_loader)
            )
        if isinstance(fp16, bool) and fp16:
            fp16 = {"opt_level": "O1"}

        if model is not None:
            self.model = model

        if resume is not None:
            callbacks = utils.process_callbacks(callbacks)
            checkpoint_callback_flag = any(
                isinstance(x, CheckpointCallback) for x in callbacks.values()
            )
            if not checkpoint_callback_flag:
                callbacks["loader"] = CheckpointCallback(resume=resume)
            else:
                raise NotImplementedError("CheckpointCallback already exists")

        experiment = self._experiment_fn(
            stage="train",
            model=model,
            loaders=loaders,
            callbacks=callbacks,
            logdir=logdir,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            num_epochs=num_epochs,
            valid_loader=valid_loader,
            main_metric=main_metric,
            minimize_metric=minimize_metric,
            verbose=verbose,
            check_run=check,
            state_kwargs=state_kwargs,
            checkpoint_data=checkpoint_data,
            distributed_params=fp16,
            monitoring_params=monitoring_params,
        )
        self.run_experiment(experiment)
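Note the resume handling above: if the callbacks already contain a CheckpointCallback, passing resume raises NotImplementedError, so a resumed run should let train add that callback itself (a sketch; the checkpoint path is hypothetical):

runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    logdir="./logs",
    num_epochs=num_epochs,
    resume="./logs/checkpoints/last_full.pth",
)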
Example #5

dataset_length = len(testset)

loaders = collections.OrderedDict()
testloader = torch.utils.data.DataLoader(testset, shuffle=False)

model = SimpleNetRGB(11, channels_in=3)
runner = SupervisedRunner(device="cuda")

loaders = collections.OrderedDict([("infer", testloader)])
runner.infer(
    model=model,
    loaders=loaders,
    callbacks=[
        InferCallback(),
        CheckpointCallback(resume=f"{logdir}/checkpoints/best.pth"),
    ],
)

predictions = runner.callbacks[0].predictions["logits"].reshape(
    dataset_length, 9)
predictions = sigmoid(predictions)
# predictions = softmax(predictions, axis=1)
predictions = np.concatenate([np.expand_dims(ids, axis=1), predictions],
                             axis=1)

pred_frame = pd.DataFrame(
    predictions,
    # one probability column per crop class (9 classes), after the field id
    columns=["field_id"] + [f"crop_id_{i}" for i in range(1, 10)])
Example #6
    # train split (reconstructed to mirror the validation setup below)
    train_dataset = OcrDataset(DATASET_PATH + 'train/',
                               DATASET_PATH + 'train.csv',
                               transforms=ResizeToTensor(
                                   CV_CONFIG.data['ocr_image_size']))

    train_loader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=4)
    val_dataset = OcrDataset(DATASET_PATH + 'val/',
                             DATASET_PATH + 'val.csv',
                             transforms=ResizeToTensor(
                                 CV_CONFIG.data['ocr_image_size']))

    val_loader = DataLoader(val_dataset,
                            batch_size=BATCH_SIZE,
                            shuffle=False,
                            num_workers=4)

    model = CRNN(**MODEL_PARAMS)
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    callbacks = [CheckpointCallback(save_n_best=10)]
    runner = SupervisedRunner(input_key="image", input_target_key="targets")

    runner.train(model=model,
                 criterion=WrapCTCLoss(alphabet),
                 optimizer=optimizer,
                 scheduler=scheduler,
                 loaders={
                     'train': train_loader,
                     "valid": val_loader
                 },
                 logdir="./logs/ocr",
                 num_epochs=NUM_EPOCHS,
                 verbose=True,
                 callbacks=callbacks)
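After training, the best checkpoint under the logdir can be fed straight back into inference, for example with predict_loader from Example #3 (a sketch reusing the runner and val_loader above; the checkpoint path follows Catalyst's logdir/checkpoints layout):

predictions = runner.predict_loader(
    model=model,
    loader=val_loader,
    resume="./logs/ocr/checkpoints/best.pth",
)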
Example #7
def test_multiple_best_checkpoints():
    old_stdout = sys.stdout
    sys.stdout = str_stdout = StringIO()

    # experiment_setup
    logdir = "./logs/periodic_loader"
    checkpoint = logdir  # + "/checkpoints"
    logfile = checkpoint + "/_metrics.json"

    # data
    num_samples, num_features = int(1e4), int(1e1)
    X = torch.rand(num_samples, num_features)
    y = torch.randint(0, 5, size=[num_samples])
    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=32, num_workers=1)
    loaders = {
        "train": loader,
        "valid": loader,
    }

    # model, criterion, optimizer, scheduler
    model = torch.nn.Linear(num_features, 5)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    runner = SupervisedRunner()

    n_epochs = 12
    period = 2  # validate only every 2nd epoch (via PeriodicLoaderCallback below)
    # first stage
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir=logdir,
        num_epochs=n_epochs,
        verbose=False,
        valid_loader="valid",
        valid_metric="loss",
        minimize_valid_metric=True,
        callbacks=[
            PeriodicLoaderCallback(valid_loader_key="valid",
                                   valid_metric_key="loss",
                                   minimize=True,
                                   valid=period),
            CheckRunCallback(num_epoch_steps=n_epochs),
            CheckpointCallback(logdir=logdir,
                               loader_key="valid",
                               metric_key="loss",
                               minimize=True,
                               save_n_best=3),
        ],
    )

    sys.stdout = old_stdout
    exp_output = str_stdout.getvalue()

    # assert len(re.findall(r"\(train\)", exp_output)) == n_epochs
    # assert len(re.findall(r"\(valid\)", exp_output)) == (n_epochs // period)
    # assert len(re.findall(r".*/train\.\d{1,2}\.pth", exp_output)) == 3

    assert os.path.isfile(logfile)
    assert os.path.isfile(checkpoint + "/train.8.pth")
    assert os.path.isfile(checkpoint + "/train.8_full.pth")
    assert os.path.isfile(checkpoint + "/train.10.pth")
    assert os.path.isfile(checkpoint + "/train.10_full.pth")
    assert os.path.isfile(checkpoint + "/train.12.pth")
    assert os.path.isfile(checkpoint + "/train.12_full.pth")
    assert os.path.isfile(checkpoint + "/best.pth")
    assert os.path.isfile(checkpoint + "/best_full.pth")
    assert os.path.isfile(checkpoint + "/last.pth")
    assert os.path.isfile(checkpoint + "/last_full.pth")

    shutil.rmtree(logdir, ignore_errors=True)
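With valid=period the "valid" loader runs only on even epochs, so the three best checkpoints can only come from validated epochs; since the loss on this synthetic data keeps decreasing, they land on epochs 8, 10 and 12, which is exactly what the file assertions above check.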
Example #8
def train_model():

    tmp_list = []
    all_birds_dirs = Path(global_config.RESAMPLED_TRAIN_AUDIO_PATH)

    for ebird_d in all_birds_dirs.iterdir():
        if ebird_d.is_file():
            continue
        for wav_f in ebird_d.iterdir():
            tmp_list.append([ebird_d.name, wav_f.name, wav_f.as_posix()])

    print(f">>> Total training examples: {len(tmp_list)}\n\n", tmp_list[:3])

    train_wav_path_exist = pd.DataFrame(
        tmp_list, columns=["ebird_code", "resampled_filename", "file_path"])

    train_all = pd.merge(global_config.train_csv,
                         train_wav_path_exist,
                         on=["ebird_code", "resampled_filename"],
                         how="inner")

    ##############################################################
    #######   K-Fold split on each bird kind (ebird_code)  #######
    ##############################################################

    skf = StratifiedKFold(n_splits=9, shuffle=True, random_state=42)

    train_all["fold"] = -1
    for fold_id, (train_index, val_index) in enumerate(
            skf.split(train_all, train_all["ebird_code"])):
        # df["fold"] == fold_id
        train_all.iloc[val_index, -1] = fold_id

    # check the proportion
    fold_proportion = pd.pivot_table(train_all,
                                     index="ebird_code",
                                     columns="fold",
                                     values="xc_id",
                                     aggfunc=len)
    print(
        f">>> Number of bird kinds: {fold_proportion.shape[0]} \n>>> Number of folds: {fold_proportion.shape[1]}"
    )

    val_fold_num = 4

    train_file_list = train_all.query("fold != @val_fold_num")[[
        "file_path", "ebird_code"
    ]].values.tolist()
    val_file_list = train_all.query("fold == @val_fold_num")[[
        "file_path", "ebird_code"
    ]].values.tolist()

    print(">>> Valid_Fold: [fold {}] train: {}, val: {}".format(
        val_fold_num, len(train_file_list), len(val_file_list)))

    ##############################################
    ###########       Model Setup      ###########
    ##############################################

    device = torch.device("cuda:0")

    # loaders
    loaders = {
        "train":
        data.DataLoader(PANNsDataset(train_file_list, None),
                        batch_size=64,
                        shuffle=True,
                        num_workers=2,
                        pin_memory=True,
                        drop_last=True),
        "valid":
        data.DataLoader(PANNsDataset(val_file_list, None),
                        batch_size=64,
                        shuffle=False,
                        num_workers=2,
                        pin_memory=True,
                        drop_last=False)
    }

    # model
    global_config.model_config["classes_num"] = 527
    model = PANNsCNN14Att(**global_config.model_config)
    weights = torch.load(PRETRAIN_PANNS)

    # Load Pretrained Weight
    model.load_state_dict(weights["model"])
    model.att_block = AttBlock(2048, 264, activation='sigmoid')
    model.att_block.init_weights()

    ###################################################################
    # model.load_state_dict(torch.load("./fold0/checkpoints/train.14.pth")["model_state_dict"])
    ###################################################################

    model.to(device)
    print(f">>> Pretrained Model is loaded to {device}!")

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=global_config.INIT_LR)

    # Scheduler
    NUM_EPOCHS = global_config.NUM_EPOCHS
    T_MAX = NUM_EPOCHS // (2 * global_config.NUM_CYCLES)  # half-period of the cosine schedule, in epochs
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_MAX)

    # Loss
    criterion = PANNsLoss().to(device)

    # Resume
    if global_config.RESUME_WEIGHT:
        # callbacks
        callbacks = [
            F1Callback(input_key="targets", output_key="logits", prefix="f1"),
            mAPCallback(input_key="targets", output_key="logits",
                        prefix="mAP"),
            CheckpointCallback(
                save_n_best=5,
                resume=global_config.RESUME_WEIGHT)  # save 5 best models
        ]
    else:
        # callbacks
        callbacks = [
            F1Callback(input_key="targets", output_key="logits", prefix="f1"),
            mAPCallback(input_key="targets", output_key="logits",
                        prefix="mAP"),
            CheckpointCallback(save_n_best=5)  # save 5 best models
        ]

    # Model Training
    runner = SupervisedRunner(device=device,
                              input_key="waveform",
                              input_target_key="targets")

    runner.train(
        model=model,
        criterion=criterion,
        loaders=loaders,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=NUM_EPOCHS,
        verbose=True,
        logdir="fold0",
        callbacks=callbacks,
        main_metric="epoch_f1",  # metric to select the best ckpt
        minimize_metric=False)
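Because main_metric="epoch_f1" and minimize_metric=False, CheckpointCallback(save_n_best=5) ranks checkpoints here by the highest F1 reported by the custom F1Callback rather than by the loss.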
Example #9
def main(cfg):
    logger.info(cfg.pretty())
    logger.info(os.getcwd())
    logger.info(hydra.utils.get_original_cwd())
    utils.seed_everything()

    if cfg.debug:
        logger.info("running debug mode")
        EPOCH = 1
    else:
        EPOCH = cfg.epoch

    df = pd.read_csv(utils.DATA_DIR / cfg.train_csv)
    # remove row because XC195038.mp3 cannot load
    df = df.drop(df[df.filename == "XC195038.mp3"].index)
    df = df.drop(
        df[(df.filename == "XC575512.mp3") & (df.ebird_code == "swathr")].index
    )
    df = df.drop(
        df[(df.filename == "XC433319.mp3") & (df.ebird_code == "aldfly")].index
    )
    df = df.drop(
        df[(df.filename == "XC471618.mp3") & (df.ebird_code == "redcro")].index
    )
    train_audio_dir = utils.DATA_DIR / cfg.train_audio_dir
    print(df.shape)

    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

    df["fold"] = -1
    for fold_id, (train_index, val_index) in enumerate(skf.split(df, df["ebird_code"])):
        df.iloc[val_index, -1] = fold_id

    # check the proportion
    fold_proportion = pd.pivot_table(
        df, index="ebird_code", columns="fold", values="xc_id", aggfunc=len
    )
    print(fold_proportion.shape)

    use_fold = 0
    if cfg.gpu:
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    warnings.simplefilter("ignore")

    # loaders
    logging.info(f"fold: {use_fold}")
    loaders = {
        "train": data.DataLoader(
            # PANNsDataset(train_file_list, None),
            pann_utils.PANNsDataset(
                df=df.query("fold != @use_fold").reset_index(), datadir=train_audio_dir,
            ),
            shuffle=True,
            drop_last=True,
            **cfg.dataloader,
        ),
        "valid": data.DataLoader(
            # PANNsDataset(val_file_list, None),
            pann_utils.PANNsDataset(
                df=df.query("fold == @use_fold").reset_index(), datadir=train_audio_dir,
            ),
            shuffle=False,
            drop_last=False,
            **cfg.dataloader,
        ),
    }

    # model
    model_config = cfg.model
    model_config["classes_num"] = 527
    model = pann_utils.get_model(model_config)

    if cfg.multi and cfg.gpu:
        logger.info("Using parallel GPUs")
        model = nn.DataParallel(model)

    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    criterion = pann_utils.PANNsLoss().to(device)
    callbacks = [
        pann_utils.F1Callback(input_key="targets", output_key="logits", prefix="f1"),
        pann_utils.mAPCallback(input_key="targets", output_key="logits", prefix="mAP"),
        CheckpointCallback(save_n_best=0),
    ]

    runner = SupervisedRunner(
        device=device, input_key="waveform", input_target_key="targets"
    )
    runner.train(
        model=model,
        criterion=criterion,
        loaders=loaders,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=EPOCH,
        verbose=True,
        logdir="fold0",
        callbacks=callbacks,
        main_metric="epoch_f1",
        minimize_metric=False,
    )

    logging.info("train all...")
    loaders = {
        "train": data.DataLoader(
            # PANNsDataset(train_file_list, None),
            pann_utils.PANNsDataset(df=df.reset_index(), datadir=train_audio_dir,),
            shuffle=True,
            drop_last=True,
            **cfg.dataloader,
        ),
    }

    # model
    model_config = cfg.model
    model_config["classes_num"] = 527
    model = pann_utils.get_model(model_config)

    if cfg.multi and cfg.gpu:
        logger.info("Using parallel GPUs")
        model = nn.DataParallel(model)

    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    criterion = pann_utils.PANNsLoss().to(device)
    callbacks = [
        pann_utils.F1Callback(input_key="targets", output_key="logits", prefix="f1"),
        pann_utils.mAPCallback(input_key="targets", output_key="logits", prefix="mAP"),
        CheckpointCallback(save_n_best=0),
    ]

    runner = SupervisedRunner(
        device=device, input_key="waveform", input_target_key="targets"
    )
    runner.train(
        model=model,
        criterion=criterion,
        loaders=loaders,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=EPOCH,
        verbose=True,
        logdir="all",
        callbacks=callbacks,
        main_metric="epoch_f1",
        minimize_metric=False,
    )

    logger.info(os.getcwd())
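Note the two-stage scheme: the model is first trained and validated on a single fold (logdir "fold0"), then rebuilt and retrained on the full dataframe (logdir "all"); CheckpointCallback(save_n_best=0) keeps checkpointing to a minimum during both runs.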
Example #10
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')

    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        default='data',
                        help='Data directory')

    parser.add_argument('-l',
                        '--loss',
                        type=str,
                        default='label_smooth_cross_entropy')
    parser.add_argument('-t1', '--temper1', type=float, default=0.2)
    parser.add_argument('-t2', '--temper2', type=float, default=4.0)
    parser.add_argument('-optim', '--optimizer', type=str, default='adam')

    parser.add_argument('-prep', '--prep_function', type=str, default='none')

    parser.add_argument('--train_on_different_datasets', action='store_true')
    parser.add_argument('--use-current', action='store_true')
    parser.add_argument('--use-extra', action='store_true')
    parser.add_argument('--use-unlabeled', action='store_true')

    parser.add_argument('--fast', action='store_true')
    parser.add_argument('--mixup', action='store_true')
    parser.add_argument('--balance', action='store_true')
    parser.add_argument('--balance-datasets', action='store_true')

    parser.add_argument('--show', action='store_true')
    parser.add_argument('-v', '--verbose', action='store_true')

    parser.add_argument('-m',
                        '--model',
                        type=str,
                        default='efficientnet-b4',
                        help='')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=8,
                        help='Batch Size during training, e.g. -b 64')
    parser.add_argument('-e',
                        '--epochs',
                        type=int,
                        default=100,
                        help='Epoch to run')
    parser.add_argument('-s',
                        '--sizes',
                        default=380,
                        type=int,
                        help='Image size for training & inference')
    parser.add_argument('-f', '--fold', type=int, default=None)
    parser.add_argument('-t', '--transfer', default=None, type=str, help='')
    parser.add_argument('-lr',
                        '--learning_rate',
                        type=float,
                        default=1e-4,
                        help='Initial learning rate')
    parser.add_argument('-a',
                        '--augmentations',
                        default='medium',
                        type=str,
                        help='')
    parser.add_argument('-accum', '--accum-step', type=int, default=1)
    parser.add_argument('-metric', '--metric', type=str, default='accuracy01')

    args = parser.parse_args()

    diff_dataset_train = args.train_on_different_datasets

    data_dir = args.data_dir
    epochs = args.epochs
    batch_size = args.batch_size
    seed = args.seed

    loss_name = args.loss
    optim_name = args.optimizer

    prep_function = args.prep_function

    model_name = args.model
    size = args.sizes
    image_size = (size, size)
    fast = args.fast
    fold = args.fold
    mixup = args.mixup
    balance = args.balance
    balance_datasets = args.balance_datasets
    show_batches = args.show
    verbose = args.verbose
    use_current = args.use_current
    use_extra = args.use_extra
    use_unlabeled = args.use_unlabeled

    learning_rate = args.learning_rate
    augmentations = args.augmentations
    transfer = args.transfer
    accum_step = args.accum_step

    # metric used to select checkpoints, e.g. cosine_loss or accuracy01
    main_metric = args.metric

    print(data_dir)

    num_classes = 5

    assert use_current or use_extra

    print(fold)

    current_time = datetime.now().strftime('%b%d_%H_%M')
    random_name = get_random_name()


    # if folds is None or len(folds) == 0:
    #     folds = [None]

    torch.cuda.empty_cache()
    checkpoint_prefix = f'{model_name}_{size}_{augmentations}'

    if transfer is not None:
        checkpoint_prefix += '_pretrain_from_' + str(transfer)
    else:
        if use_current:
            checkpoint_prefix += '_current'
        if use_extra:
            checkpoint_prefix += '_extra'
        if use_unlabeled:
            checkpoint_prefix += '_unlabeled'
        if fold is not None:
            checkpoint_prefix += f'_fold{fold}'

    directory_prefix = f'{current_time}_{checkpoint_prefix}'
    log_dir = os.path.join('runs', directory_prefix)
    os.makedirs(log_dir, exist_ok=False)

    set_manual_seed(seed)
    model = get_model(model_name)

    if transfer is not None:
        print("Transfering weights from model checkpoint")
        model.load_state_dict(torch.load(transfer)['model_state_dict'])

    model = model.cuda()

    if diff_dataset_train:
        train_on = ['current_train', 'extra_train']
        valid_on = ['unlabeled']
        train_ds, valid_ds, train_sizes = get_datasets_universal(
            train_on=train_on,
            valid_on=valid_on,
            image_size=image_size,
            augmentation=augmentations,
            target_dtype=int,
            prep_function=prep_function)
    else:
        train_ds, valid_ds, train_sizes = get_datasets(
            data_dir=data_dir,
            use_current=use_current,
            use_extra=use_extra,
            image_size=image_size,
            prep_function=prep_function,
            augmentation=augmentations,
            target_dtype=int,
            fold=fold,
            folds=5)

    train_loader, valid_loader = get_dataloaders(train_ds,
                                                 valid_ds,
                                                 batch_size=batch_size,
                                                 train_sizes=train_sizes,
                                                 num_workers=6,
                                                 balance=balance,
                                                 balance_datasets=balance_datasets,
                                                 balance_unlabeled=False)

    loaders = collections.OrderedDict()
    loaders["train"] = train_loader
    loaders["valid"] = valid_loader

    runner = SupervisedRunner(input_key='image')

    criterions = get_loss(loss_name)
    # criterions_tempered = TemperedLogLoss()
    # optimizer = catalyst.contrib.nn.optimizers.radam.RAdam(model.parameters(), lr = learning_rate)
    optimizer = get_optim(optim_name, model, learning_rate)
    # optimizer = catalyst.contrib.nn.optimizers.Adam(model.parameters(), lr = learning_rate)
    # criterions = nn.CrossEntropyLoss()
    # optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25], gamma=0.8)
    # cappa = CappaScoreCallback()

    Q = math.floor(len(train_ds) / batch_size)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Q)
    if main_metric != 'accuracy01':
        callbacks = [
            AccuracyCallback(num_classes=num_classes),
            CosineLossCallback(),
            OptimizerCallback(accumulation_steps=accum_step),
            CheckpointCallback(save_n_best=epochs)
        ]
    else:
        callbacks = [
            AccuracyCallback(num_classes=num_classes),
            OptimizerCallback(accumulation_steps=accum_step),
            CheckpointCallback(save_n_best=epochs)
        ]

    # main_metric = 'accuracy01'

    runner.train(
        fp16=True,
        model=model,
        criterion=criterions,
        optimizer=optimizer,
        scheduler=scheduler,
        callbacks=callbacks,
        loaders=loaders,
        logdir=log_dir,
        num_epochs=epochs,
        verbose=verbose,
        main_metric=main_metric,
        minimize_metric=False,
    )
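Here fp16=True takes the shortcut documented in Example #4: it expands to {"opt_level": "O1"} for Apex mixed precision, while OptimizerCallback(accumulation_steps=accum_step) adds gradient accumulation on top.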