Example #1
0
    def __init__(self, hparams):
        """Build the network, loss and logger for this system.

        Args:
            hparams: hyper-parameter namespace; this method reads ``seed``
                and ``log_dir`` from it.
        """
        super().__init__()
        self.hparams = hparams

        # Re-seed on every construction so model initialisation is identical
        # each time; otherwise any re-initialisation part-way through a run
        # would immediately make results diverge.
        seed_reproducer(self.hparams.seed)

        # Network first, then loss -- both are created only after seeding.
        self.model, self.criterion = se_resnext50_32x4d(), CrossEntropyLossOneHot()
        self.logger_kun = init_logger("kun_in", hparams.log_dir)
Example #2
0
    def __init__(self, hparams):
        """Build the backbone, loss, logger and FMix augmenter for this system.

        Args:
            hparams: hyper-parameter namespace. Fields read here: ``seed``,
                ``backbone``, ``onehot``, ``log_dir``, ``version``,
                ``fmix_delta``, ``fmix_beta`` and ``image_size``.
        """
        super().__init__()
        self.hparams = hparams

        # Re-seed on every construction so model initialisation is identical
        # each time; otherwise any re-initialisation part-way through a run
        # would immediately make results diverge.
        seed_reproducer(self.hparams.seed)

        # Backbone names noted for this project: efficientnet-b1,
        # se_resnext50_32x4d, se_resnet50. Only se_resnext50_32x4d is built
        # via CassavaModel; every other name is handed to CassavaModelTimm.
        backbone = self.hparams.backbone
        if backbone == 'se_resnext50_32x4d':
            self.model = CassavaModel(backbone)
        else:
            self.model = CassavaModelTimm(backbone)

        # `onehot` selects the custom CrossEntropyLossOneHot (presumably for
        # one-hot / soft targets such as those FMix produces -- confirm);
        # otherwise torch's standard CrossEntropyLoss is used.
        if self.hparams.onehot:
            self.criterion = CrossEntropyLossOneHot()
        else:
            self.criterion = torch.nn.CrossEntropyLoss()

        # Logs go to a per-version sub-directory of log_dir.
        self.logger_kun = init_logger("kun_in",
                                      f'{hparams.log_dir}/{hparams.version}')

        # FMix mask size. Both entries are cast to int so FMix always gets
        # integer dimensions (presumably required by its FFT-based mask
        # sampling -- confirm against the FMix implementation).
        # NOTE(review): both entries read image_size[1], giving a square mask;
        # if image_size is (height, width) and images are non-square, the first
        # entry probably should be image_size[0] -- confirm before changing.
        mask_side = int(self.hparams.image_size[1])
        self.fmix = FMix(decay_power=self.hparams.fmix_delta,
                         alpha=self.hparams.fmix_beta,
                         size=(mask_side, mask_side),
                         max_soft=0.0,
                         reformulate=False)
Example #3
0
# Third party libraries
import torch
from dataset import generate_transforms
from sklearn.model_selection import KFold
from scipy.special import softmax
from torch.utils.data import DataLoader
from tqdm import tqdm

# User defined libraries
from train import CoolSystem
from utils import init_hparams, init_logger, seed_reproducer, load_data
from dataset import PlantDataset

if __name__ == "__main__":
    # Make experiment reproducible
    seed_reproducer(2020)

    # Init Hyperparameters
    hparams = init_hparams()

    # init logger
    logger = init_logger("kun_out", log_dir=hparams.log_dir)

    # Load data
    data, test_data = load_data(logger)

    # Generate transforms
    transforms = generate_transforms(hparams.image_size)

    early_stop_callback = EarlyStopping(monitor="val_roc_auc",
                                        patience=10,
from torchvision import datasets
from torch.autograd import Variable
import torch.nn.functional as F
from tqdm import tqdm
import pandas as pd
import numpy as np
import os
import sys
import csv

from utils import parse_args, seed_reproducer

# Training settings
args = parse_args(sys.argv[1:])
use_cuda = torch.cuda.is_available()
seed_reproducer()

# Create the experiment folder plus its 'checkpoints' and 'logs' sub-folders.
# exist_ok=True avoids the race between an isdir() check and the creation,
# and os.makedirs still raises FileExistsError if one of these paths already
# exists as a non-directory.
model_dir = os.path.join(args.experiment, 'checkpoints')
log_dir = os.path.join(args.experiment, 'logs')
os.makedirs(args.experiment, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

from sklearn.model_selection import KFold
from data import BirdDataSetUnlabeled, BirdDataSetLabeled, return_data_transforms, mixup_data, return_data_test_transforms
from loss_function import alpha_weight, mixup_criterion, CenterLoss, LabelSmoothingCrossEntropy, linear_combination, reduce_loss
from model import Net, Net2