コード例 #1
0
def validate(config_path, use_zoom_tta, f):
    """Run validation for every requested zoom factor, then search thresholds.

    Loads the model named in the config, evaluates it once per TTA zoom
    (flip TTA always enabled), writes per-zoom predictions to pickle files
    under the configured output dir, and finally runs the threshold search.
    """
    config = load_config(config_path)
    model = get_model(config.model.name, config.model.pretrained_model_path)
    make_output_dir(config, f)

    # Either sweep several zoom factors or use the single configured one.
    zoom_factors = [1.0, 0.9, 0.8] if use_zoom_tta else [config.data.tta_zoom]

    weight_tag = os.path.basename(config.model.pretrained_model_path)
    for zoom in zoom_factors:
        loader = ImetDataset(batch_size=config.eval.batch_size,
                             mode="valid",
                             img_size=config.data.img_size,
                             tta_zoom=zoom,
                             valid_csv=config.data.valid_csv).get_loader()

        pickle_name = "{}_{}_{}.pickle".format(config.model.name, weight_tag,
                                               zoom)
        valid_one_epoch(loader,
                        model,
                        flip_tta=True,
                        pickle_name=os.path.join(config.path.out,
                                                 pickle_name))

    search(config.path.out)
コード例 #2
0
ファイル: train.py プロジェクト: rskmoi/kaggle-imet
def train(config_path, f):
    """Train a model end to end from the configuration at ``config_path``."""
    config = load_config(config_path)
    make_output_dir(config, f)

    # One loader per phase; the eval batch size may differ from train.
    train_loader = ImetDataset(batch_size=config.train.batch_size,
                               mode="train",
                               img_size=config.data.img_size,
                               train_csv=config.data.train_csv).get_loader()
    valid_loader = ImetDataset(batch_size=config.eval.batch_size,
                               mode="valid",
                               img_size=config.data.img_size,
                               valid_csv=config.data.valid_csv).get_loader()

    model = get_model(config.model.name, config.model.pretrained_model_path,
                      config.model.multi)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.train.lr)
    scheduler = CyclicLRWithRestarts(optimizer,
                                     config.train.batch_size,
                                     len(train_loader.dataset),
                                     restart_period=2,
                                     t_mult=1.)
    criterion = FocalLoss()

    for epoch in range(config.train.num_epochs):
        # Epoch-level scheduler step happens before the batches, matching
        # the original call order for CyclicLRWithRestarts.
        scheduler.step()
        train_one_epoch(train_loader, model, optimizer, scheduler, criterion,
                        epoch, config.path.out)
        valid_one_epoch(valid_loader, model, epoch)
コード例 #3
0
ファイル: submit.py プロジェクト: rskmoi/kaggle-imet
def inference(config_path, use_zoom_tta, f):
    """Predict on the test split for each zoom factor and pickle the results."""
    config = load_config(config_path)
    model = get_model(config.model.name, config.model.pretrained_model_path)
    make_output_dir(config, f)

    # Either sweep several zoom factors or use the single configured one.
    zoom_factors = [1.0, 0.9, 0.8] if use_zoom_tta else [config.data.tta_zoom]

    weight_tag = os.path.basename(config.model.pretrained_model_path)
    for zoom in zoom_factors:
        test_loader = ImetDataset(batch_size=config.eval.batch_size,
                                  mode="test",
                                  img_size=config.data.img_size,
                                  tta_zoom=zoom,
                                  valid_csv=config.data.valid_csv).get_loader()

        pickle_name = "{}_{}_{}.pickle".format(config.model.name, weight_tag,
                                               zoom)
        inference_for_submit(test_loader,
                             model,
                             config.data.img_size,
                             pickle_name=os.path.join(config.path.out,
                                                      pickle_name))
コード例 #4
0
def do_calculation():
    """Classify a fixed iris sample with the random-forest model.

    Returns a ``(label, class_probabilities, model_name)`` tuple.
    """
    model_used = 'rf'
    model, test_val = get_model([.2, 1, .1, 1], model_used)
    # Map the numeric class index to its species name.
    labels = {0: 'Setosa', 1: 'Versicolor', 2: 'Virginica'}
    probabilities = model.predict_proba(test_val)
    predicted = model.predict(test_val)
    return labels[predicted[0]], probabilities, model_used
コード例 #5
0
def main(cfg: DictConfig) -> None:
    """Build loggers, callbacks, model and data, then fit the Lightning system."""
    print(cfg.pretty())

    # Experiment tracking: Neptune (fed a flattened copy of the resolved
    # config) plus TensorBoard.
    neptune_logger = CustomNeptuneLogger(params=flatten_dict(
        OmegaConf.to_container(cfg, resolve=True)),
                                         **cfg.logging.neptune_logger)
    tb_logger = loggers.TensorBoardLogger(**cfg.logging.tb_logger)

    lr_logger = LearningRateLogger()

    # TODO change to cyclicLR per epochs
    my_callback = MyCallback(cfg)

    # Optionally restore weights: the config may hold a glob pattern, the
    # first match wins.
    model = get_model(cfg)
    if cfg.model.ckpt_path is not None:
        ckpt_candidates = glob.glob(utils.to_absolute_path(cfg.model.ckpt_path))
        model = load_pytorch_model(ckpt_candidates[0], model)

    seed_everything(2020)

    # TODO change to enable logging losses
    lit_model = O2UNetSystem(hparams=cfg, model=model)

    checkpoint_callback = ModelCheckpoint(**OmegaConf.to_container(
        cfg.callbacks.model_checkpoint, resolve=True))
    early_stop_callback = EarlyStopping(**OmegaConf.to_container(
        cfg.callbacks.early_stop, resolve=True))

    trainer = Trainer(checkpoint_callback=checkpoint_callback,
                      early_stop_callback=early_stop_callback,
                      logger=[tb_logger, neptune_logger],
                      callbacks=[lr_logger, my_callback],
                      **cfg.trainer)

    # TODO change to train with all data

    datasets = get_datasets(OmegaConf.to_container(cfg, resolve=True))
    trainer.fit(
        lit_model,
        train_dataloader=DataLoader(datasets["train"],
                                    **cfg["training"]["dataloader"]["train"]),
        val_dataloaders=DataLoader(datasets["valid"],
                                   **cfg["training"]["dataloader"]["valid"]))
コード例 #6
0
ファイル: prediction.py プロジェクト: singlinhaha/cassava
    def __init__(self, cfg):
        """Cache prediction settings from *cfg* and build the model.

        Creates ``<save_dir>/test`` when missing.  The model is constructed
        without weights (``model_weight_path=None``); weight handling is
        delegated to ``judge_model_weight_path``.
        """
        # Plain config lookups kept as attributes for later use.
        self.dataset_root = cfg["test_dataset"]
        self.batch_size = cfg["batch_size"]
        self.img_width = cfg["img_width"]
        self.img_hight = cfg["img_hight"]
        self.model_weights = cfg["model_weights"]
        self.phase = cfg["phase"]

        self.save_dir = os.path.join(cfg["save_dir"], "test")
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        # Architecture only -- weights are presumably applied inside
        # judge_model_weight_path (defined elsewhere).
        self.model = get_model(model_weight_path=None,
                               model_name=cfg["model_name"],
                               out_features=cfg["num_classes"],
                               img_width=cfg["img_width"],
                               img_hight=cfg["img_hight"],
                               verbose=True)
        self.judge_model_weight_path()
コード例 #7
0
def get_val():
    """Endpoint: score one iris sample and return class probabilities as JSON."""
    data_json = request.get_json()
    print(data_json)

    sepal_test = data_json['sepal']
    petal_test = data_json['petal']
    features = [
        sepal_test['length'], sepal_test['width'], petal_test['length'],
        petal_test['width']
    ]
    model, test_val = get_model(features, data_json['model_used'])

    # One probability per class, keyed by species name.
    flowers = ["setosa", "versicolor", "virginica"]
    probabilities = model.predict_proba(test_val)
    json_out = dict(zip(flowers, list(probabilities[0])))
    print(json_out)
    return jsonify(json_out)
コード例 #8
0
ファイル: train.py プロジェクト: hirune924/lightning-hydra
def main(cfg: DictConfig) -> None:
    """Set up loggers and callbacks, build the model, and train it with Lightning."""
    print(cfg.pretty())

    # Experiment tracking: Neptune (flattened resolved config) + TensorBoard.
    resolved = OmegaConf.to_container(cfg, resolve=True)
    neptune_logger = CustomNeptuneLogger(params=flatten_dict(resolved),
                                         **cfg.logging.neptune_logger)
    tb_logger = loggers.TensorBoardLogger(**cfg.logging.tb_logger)

    lr_logger = LearningRateLogger()
    my_callback = MyCallback(cfg)

    model = get_model(cfg)
    if cfg.model.ckpt_path is not None:
        # The config may hold a glob pattern; load the first match.
        ckpt_candidates = glob.glob(utils.to_absolute_path(cfg.model.ckpt_path))
        model = load_pytorch_model(ckpt_candidates[0], model)
    if cfg.trainer.distributed_backend == 'ddp':
        # Synchronise BatchNorm statistics across DDP workers.
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    seed_everything(2020)

    lit_model = PLRegressionImageClassificationSystem(hparams=cfg, model=model)

    checkpoint_callback = ModelCheckpoint(**OmegaConf.to_container(
        cfg.callbacks.model_checkpoint, resolve=True))
    early_stop_callback = EarlyStopping(**OmegaConf.to_container(
        cfg.callbacks.early_stop, resolve=True))

    trainer = Trainer(checkpoint_callback=checkpoint_callback,
                      early_stop_callback=early_stop_callback,
                      logger=[tb_logger, neptune_logger],
                      callbacks=[lr_logger, my_callback],
                      **cfg.trainer)

    trainer.fit(lit_model)
コード例 #9
0
# Draw as many prior samples per batch as requested for evaluation.
prior_args.batch_size = eval_args.samples

##################
## Specify data ##
##################

eval_loader, data_shape, cond_shape = get_data(args, eval_only=True)

####################
## Specify models ##
####################

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# conditional model
# Rebuild the conditional architecture, restore its checkpointed weights,
# and put it in eval mode on the chosen device.
model = get_model(args, data_shape=data_shape, cond_shape=cond_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)
checkpoint = torch.load(path_check, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model = model.eval()
print('Loaded weights for conditional model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

# prior model
# The prior works at the low-resolution size: spatial dims are divided by
# the super-resolution scale factor.  Its weight loading continues past
# this excerpt.
prior_model = get_model(prior_args,
                        data_shape=(data_shape[0],
                                    data_shape[1] // args.sr_scale_factor,
                                    data_shape[2] // args.sr_scale_factor))
if prior_args.parallel == 'dp':
コード例 #10
0
# Register optimiser-related CLI flags, parse, and fix all RNG seeds.
add_optim_args(parser)
args = parser.parse_args()
set_seeds(args.seed)

##################
## Specify data ##
##################

train_loader, eval_loader = get_data(args)
data_id = get_data_id(args)

###################
## Specify model ##
###################

model = get_model(args)
model_id = get_model_id(args)

#######################
## Specify optimizer ##
#######################

optimizer, scheduler_iter, scheduler_epoch = get_optim(args,
                                                       model.parameters())
optim_id = get_optim_id(args)

##############
## Training ##
##############

# NOTE: the TeacherExperiment construction continues past this excerpt.
exp = TeacherExperiment(args=args,
コード例 #11
0
ファイル: train.py プロジェクト: menegop/HPA_kaggle
    # (Continuation of a collate function whose signature lies above this
    # excerpt.)  Unpack each (image, weak_labels, info) sample into three
    # parallel lists rather than stacking them into batch tensors.
    image_tensors = []
    weak_labels_list = []
    info_dictionaries = []
    for image, weak_labels, info in batch:
        image_tensors += [image]
        weak_labels_list += [weak_labels]
        info_dictionaries += [info]
    return image_tensors, weak_labels_list, info_dictionaries


# Prefer the GPU when available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device(
    'cpu')
dataset = CellDataset("dataset/short_train.csv", "dataset/train",
                      "dataset/short_mask/hpa_cell_mask")

model = get_model()
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# NOTE(review): `collate_images` is presumably the collate function above;
# its def line is outside this excerpt.
data_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=7,
                                          collate_fn=collate_images)


def clip_gradient(model, clip_norm):
    """Computes a gradient clipping coefficient based on gradient norm."""
    # Accumulates squared norms of trainable gradients; the function body
    # continues past this excerpt.
    totalnorm = 0
    for p in model.parameters():
        if p.requires_grad and p.grad is not None:
            modulenorm = p.grad.norm()
コード例 #12
0
                    help='Path of pretrained model weights.')
# Input point-cloud file and destination folder for predictions.
parser.add_argument('--file',
                    action='store',
                    type=str,
                    help='Path of raw point cloud file.')
parser.add_argument('--output',
                    action='store',
                    type=str,
                    help='Path of prediction destination folder.')

config = parser.parse_args()

# Make sure the destination folder exists before writing into it.
if not os.path.exists(config.output):
    os.mkdir(config.output)

model = get_model(num_points=config.num_points, num_classes=config.num_classes)

model.load_weights(config.weights)

# The HDF5 file carries raw points, normalized points and per-point labels.
f = h5py.File(config.file, 'r')

points = f['points']
npoints = f['normalized_points']
labels = f['labels']

# Only the first three channels of the normalized points (presumably x, y, z)
# feed the network.
data = npoints[:, :, 0:3]

predictions = model.predict(data)

print(f'file: {config.file} - prediction shape: {predictions.shape}')
コード例 #13
0
# Draw as many samples as requested for evaluation in a single batch.
args.batch_size = eval_args.samples

##################
## Specify data ##
##################

eval_loader, data_shape, cond_shape = get_data(args, eval_only=True)

###################
## Specify model ##
###################

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Rebuild the architecture, then restore trained weights from the checkpoint
# and switch to eval mode on the chosen device.
model = get_model(args, data_shape=data_shape, cond_shape=cond_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)
checkpoint = torch.load(path_check, map_location=torch.device(device))
model.load_state_dict(checkpoint['model'])
model = model.to(device)
model = model.eval()
print('Loaded weights for model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

############
## Sample ##
############


# Body of save_images continues past this excerpt.
def save_images(imgs, file_path, num_bits=args.num_bits, nrow=eval_args.nrow):
コード例 #14
0
ファイル: train.py プロジェクト: linxin98/diff
    # (Continuation: read the config file named on the command line.)
    config.read(args.config)
    # Get logger.
    # NOTE(review): this rebinds the name `logger` from the module to the
    # instance it returns -- works once, but shadows the module.
    logger = logger.get_logger(config)
    logger.info('Finishing program initialization.')
    # Set device.
    if config['normal']['device'] == 'CUDA' and torch.cuda.is_available():
        device = 'cuda:' + config['normal']['gpu_id']
        # cuDNN autotuner: pick the fastest kernels for fixed input sizes.
        torch.backends.cudnn.benchmark = True
    else:
        device = 'cpu'
    logger.info('Set device:' + device)
    device = torch.device(device)
    setup_seed(0)

    # Get model.
    base_model, diff_attention_model = model.get_model(config)
    base_model = base_model.to(device)
    base_model.eval()
    # The diff-attention model is optional; move/eval it only when present.
    if diff_attention_model is not None:
        diff_attention_model = diff_attention_model.to(device)
        diff_attention_model.eval()
    logger.info('Get model.')

    # Get data loaders.
    train_loader, query_loader, gallery_loader = loader.get_data_loaders(config, base_model=base_model, device=device)
    logger.info('Get data loaders.')

    # Get loss.
    # NOTE(review): the `loss` module is shadowed by the returned loss object.
    loss, center_loss = loss.get_loss(config, device)
    logger.info('Get loss.')
コード例 #15
0
# Make evaluation reproducible for a given seed.
torch.manual_seed(eval_args.seed)

###############
## Load args ##
###############

# Training-time arguments were pickled next to the checkpoint; restore them.
with open(path_args, 'rb') as f:
    args = pickle.load(f)

################
## Experiment ##
################

if eval_args.model_type == "flow":
    # Distillation-style setup: the student is optimized; the teacher is
    # returned alongside it by get_model.
    student, teacher, data_id = get_model(args)
    model_id = get_model_id(args)
    args.dataset = data_id

    optimizer, scheduler_iter, scheduler_epoch = get_optim(
        args, student.parameters())
    optim_id = get_optim_id(args)

    # The StudentExperiment construction continues past this excerpt.
    exp = StudentExperiment(args=args,
                            data_id=data_id,
                            model_id=model_id,
                            optim_id=optim_id,
                            model=student,
                            teacher=teacher,
                            optimizer=optimizer,
                            scheduler_iter=scheduler_iter,
0
ファイル: test.py プロジェクト: yanzhicong/VAE-GAN
    # (Continuation.)  Restrict TensorFlow to the requested GPU(s) and start
    # from a clean default graph.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    tf.reset_default_graph()

    # load config file
    config = get_config(args.config_file, args.disp_config)

    # make the assets directory and copy the config file to it
    # so if you want to reproduce the result in assets dir
    # just copy the config_file.json to ./cfgs folder and run python3 train.py --config=(config_file)
    if not os.path.exists(config['assets dir']):
        os.makedirs(config['assets dir'])
    copyfile(os.path.join('./cfgs', args.config_file + '.json'),
             os.path.join(config['assets dir'], 'config_file.json'))

    # prepare dataset
    dataset = get_dataset(config['dataset'], config['dataset params'])

    # Let TF claim GPU memory on demand instead of reserving it all upfront.
    tfconfig = tf.ConfigProto()
    tfconfig.gpu_options.allow_growth = True

    with tf.Session(config=tfconfig) as sess:

        # build model
        config['model params']['assets dir'] = config['assets dir']
        model = get_model(config['model'], config['model params'])

        # start testing
        # NOTE(review): the tester object is driven via train(); presumably
        # the shared runner interface -- confirm it performs evaluation.
        config['tester params']['assets dir'] = config['assets dir']
        trainer = get_trainer(config['tester'], config['tester params'], model)
        trainer.train(sess, dataset, model)
コード例 #17
0
    # (Continuation of a collate function defined above this excerpt.)
    # Transpose a list of samples into per-field tuples.
    return tuple(zip(*batch))


train_loader = DataLoader(train_dataset,
                          batch_size=1,
                          shuffle=True,
                          num_workers=10,
                          collate_fn=collate_fn)

# NOTE(review): shuffle=True on the validation loader is unusual -- confirm
# it is intended.
test_loader = DataLoader(validation_dataset,
                         batch_size=1,
                         shuffle=True,
                         num_workers=10,
                         collate_fn=collate_fn)

model = get_model(VERSION_FAST)

# Optionally resume from a previous checkpoint.
if resume_net_path is not None:
    print("load weights from {}".format(resume_net_path))
    model.load_state_dict(torch.load(resume_net_path))

model.eval()
print('Finished loading model!')
print(model)

#device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# CPU is forced here; the CUDA selection above was deliberately commented out.
device = torch.device('cpu')
print("device = ", device)
model.to(device)

# Only parameters that require gradients are collected (for an optimizer
# constructed past this excerpt).
params = [param for param in model.parameters() if param.requires_grad]
コード例 #18
0
                    heatmap)
                cv2.imwrite(
                    os.path.join(
                        save_dir, "{}_{}_to_{}_superimposed.jpg".format(
                            name, label, idx)), superimposed_img)


if __name__ == "__main__":
    from model.model import get_model
    from config import cfg

    # Hard-coded validation images (class AD) and the output directory.
    root = "/media/biototem/Elements/lisen/haosen/competition/dataset/PET_data/val/AD"
    # Map every image stem under `root` to class index 0.
    img_to_idx = dict(
        zip([i.split(".")[0] for i in os.listdir(root)],
            [0] * len(os.listdir(root))))
    save_dir = "/media/biototem/Elements/lisen/haosen/competition/PET/output/densenet121_fold=3_epoch=80/AD_true"
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    # Build the architecture, then load the first configured weight file.
    model = get_model(model_weight_path=None,
                      model_name=cfg["model_name"],
                      out_features=cfg["num_classes"],
                      img_width=cfg["img_width"],
                      img_hight=cfg["img_hight"],
                      verbose=True)
    model.load_state_dict(torch.load(cfg["model_weights"][0]))
    # Grad-CAM visualiser hooked onto the model's feature extractor.
    cam = GradCAM(model=model,
                  feature_layer=model.features,
                  img_width=cfg["img_width"],
                  img_hight=cfg["img_hight"],
                  idx_to_class=cfg["idx_to_class"])
    cam.createCAM(root=root, img_to_idx=img_to_idx, save_dir=save_dir)
コード例 #19
0
                                                stratify=labels,
                                                random_state=42)

# construct the image generator for data augmentation
# Light geometric jitter only; horizontal flips are explicitly disabled.
aug = ImageDataGenerator(rotation_range=10,
                         zoom_range=0.05,
                         width_shift_range=0.1,
                         height_shift_range=0.1,
                         shear_range=0.15,
                         horizontal_flip=False,
                         fill_mode="nearest")

# initialize and compile our deep neural network
print("[INFO] Compiling model...")
# BUG FIX: `opt` (SGD with learning-rate decay) was constructed but never
# used -- the model was compiled with the string "adam", so INIT_LR and the
# decay schedule were silently ignored.  Pass the configured optimizer.
opt = SGD(lr=INIT_LR, decay=INIT_LR / EPOCHS)
model = model.get_model()
model.compile(loss="categorical_crossentropy",
              optimizer=opt,
              metrics=["accuracy"])

# train the network
print("[INFO] Training network...")
H = model.fit(aug.flow(trainX, trainY, batch_size=BS),
              validation_data=(testX, testY),
              steps_per_epoch=len(trainX) // BS,
              epochs=EPOCHS,
              verbose=1)

# define the list of label names (one single-character label per letter)
labelNames = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
labelNames = list(labelNames)
コード例 #20
0
def main_model():
    """Thin wrapper returning a model instance from ``model.get_model``."""
    return model.get_model()
コード例 #21
0
ファイル: train.py プロジェクト: singlinhaha/cassava
def main(cfg, step):
    """Two-phase training: brief transfer learning, then fine-tuning.

    Args:
        cfg: configuration dict with dataset paths/selections, model settings,
            learning rates and output directory.
        step: string suffix for the weights/visualisation output directories.
    """
    # Output directories for checkpoints and metric histories.
    model_save_dir = os.path.join(cfg["output_dir"], "weights" + step)
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)
    history_save_dir = os.path.join(cfg["output_dir"], "visual" + step)
    if not os.path.exists(history_save_dir):
        os.makedirs(history_save_dir)

    # time_mark = '2020_03_06_'
    # Use the current date as the file-name tag for saved artifacts.
    time_mark = time.strftime('%Y_%m_%d_', time.localtime(time.time()))
    file_path = os.path.join(model_save_dir,
                             time_mark + "epoch_{epoch}-model_weights.pth")
    history_path = os.path.join(history_save_dir, time_mark + "result.csv")
    callbacks_s = call_backs(file_path, history_path)

    # Training split with augmentation (crop/colour jitter/flips/rotation)
    # followed by resize and ImageNet normalisation.
    # NOTE(review): torchvision Resize expects (height, width) but
    # (img_width, img_hight) is passed in both transforms here -- confirm
    # the intended ordering when width != height.
    train_dataset = ImageSelectFolder(
        root=cfg["train_dataset"],
        label=cfg["label"],
        select_condition=cfg["train_select"],
        data_expansion=True,
        transform=transforms.Compose([
            transforms.RandomApply(
                [
                    transforms.RandomCrop(size=(448, 448)),
                    # transforms.RandomResizedCrop(size=cfg["img_width"]),
                ],
                p=0.3),
            transforms.ColorJitter(brightness=0.2,
                                   contrast=0.2,
                                   saturation=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(360),
            transforms.Resize((cfg["img_width"], cfg["img_hight"])),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))
        ]))
    # Validation split: resize + normalise only.
    val_dataset = ImageSelectFolder(
        root=cfg["val_dataset"],
        label=cfg["label"],
        select_condition=cfg["val_select"],
        transform=transforms.Compose([
            transforms.Resize((cfg["img_width"], cfg["img_hight"])),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))
        ]))

    train_dataload = DataLoader(dataset=train_dataset,
                                batch_size=cfg["batch_size"],
                                shuffle=True)
    val_dataload = DataLoader(dataset=val_dataset,
                              batch_size=cfg["batch_size"],
                              shuffle=False)

    model = get_model(model_weight_path=cfg["model_weight_path"],
                      model_name=cfg["model_name"],
                      out_features=cfg["num_classes"],
                      img_width=cfg["img_width"],
                      img_hight=cfg["img_hight"],
                      verbose=True)
    model.cuda()
    loss_function = nn.CrossEntropyLoss().cuda()

    # Define the extra evaluation metrics.
    recall = GetRecallScore(average="micro")
    precision = GetPrecisionScore(average="micro")
    f1 = GetF1Score(average="micro")
    metrics = {"recall": recall, "precision": precision, "f1 score": f1}

    # Phase 1: short transfer-learning warm-up at the transfer-learning rate.
    train_transfer_learning(model,
                            loss_function,
                            train_dataload,
                            val_dataload,
                            cfg["tl_lr"],
                            epochs=3,
                            metrics=metrics)
    # Phase 2: fine-tune for the configured number of epochs.
    # NOTE(review): `trrain_fine_tuning` keeps the (misspelled) name of the
    # external helper as defined elsewhere.
    trrain_fine_tuning(model,
                       loss_function,
                       train_dataload,
                       val_dataload,
                       history_save_dir,
                       lr=cfg["ft_lr"],
                       epochs=cfg["nepochs"],
                       callbacks=callbacks_s,
                       metrics=metrics)
    # Drop the reference when done.
    del model
コード例 #22
0
from torch.utils.data import DataLoader

from torchvision.ops import boxes as box_ops
from PIL import Image
import torch.nn.functional as F
import time
import numpy as np
import cv2

from model.model import get_model

import torch
import torchvision

# Checkpoint to export; 49 is presumably the class count passed to get_model.
NN_WEIGHT_FILE_PATH = 'D:\model\efficient_rcnn_9.pth'
torch_model = get_model(49)
torch_model.load_state_dict(torch.load(NN_WEIGHT_FILE_PATH))
#torch_model = torch.load(NN_WEIGHT_FILE_PATH) # load a whole pickled model instead
batch_size = 1  # batch size
input_shape = (3, 244, 244)  # input data shape

# set the model to inference mode
torch_model.eval()

# Two differently-sized dummy images exercise the detector's variable-size input.
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = torch_model(x)
# optionally, if you want to export the model to ONNX:
torch.onnx.export(torch_model, x, "faster_rcnn.onnx", opset_version=11)
'''
x = torch.randn(batch_size,*input_shape)		# 生成张量
export_onnx_file = "test.onnx"					# 目的ONNX文件名
コード例 #23
0
###############

# Restore the training-time arguments pickled alongside the checkpoint.
with open(path_args, 'rb') as f:
    args = pickle.load(f)

##################
## Specify data ##
##################

# Only the data shape is needed here; the loaders themselves are discarded.
_, _, data_shape = get_data(args)

###################
## Specify model ##
###################

# Rebuild the architecture and load trained weights from the checkpoint.
model = get_model(args, data_shape=data_shape)
if args.parallel == 'dp':
    model = DataParallelDistribution(model)
checkpoint = torch.load(path_check)
model.load_state_dict(checkpoint['model'])
print('Loaded weights for model at {}/{} epochs'.format(
    checkpoint['current_epoch'], args.epochs))

############
## Sample ##
############

# Samples are written under <model>/samples/, named by epoch and seed.
path_samples = '{}/samples/sample_ep{}_s{}.png'.format(
    eval_args.model, checkpoint['current_epoch'], eval_args.seed)
if not os.path.exists(os.path.dirname(path_samples)):
    os.mkdir(os.path.dirname(path_samples))
コード例 #24
0
ファイル: train.py プロジェクト: singlinhaha/cassava
def train(cfg, step):
    """Single-phase training: build datasets/loaders, train with Adam, plot curves.

    Args:
        cfg: configuration dict with dataset paths/selections, model settings,
            optimizer hyper-parameters and output directory.
        step: string suffix for the weights/visualisation output directories.
    """
    # Set up the save directories.
    model_save_dir = os.path.join(cfg["output_dir"], "weights" + step)
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)
    history_save_dir = os.path.join(cfg["output_dir"], "visual" + step)
    if not os.path.exists(history_save_dir):
        os.makedirs(history_save_dir)

    # time_mark = '2020_03_06_'
    # Use the current date as the file-name tag for saved artifacts.
    time_mark = time.strftime('%Y_%m_%d_', time.localtime(time.time()))
    file_path = os.path.join(model_save_dir,
                             time_mark + "epoch_{epoch}-model_weights.pth")
    history_path = os.path.join(history_save_dir, time_mark + "result.csv")
    callbacks_s = call_backs(file_path, history_path)

    # Load the datasets: augmentation on train only; both splits end with
    # resize + ImageNet normalisation.
    train_dataset = ImageSelectFolder(
        root=cfg["train_dataset"],
        label=cfg["label"],
        select_condition=cfg["train_select"],
        data_expansion=True,
        transform=transforms.Compose([
            transforms.RandomApply(
                [
                    transforms.RandomCrop(size=(448, 448)),
                    # transforms.RandomResizedCrop(size=cfg["img_width"]),
                ],
                p=0.3),
            transforms.ColorJitter(brightness=0.2,
                                   contrast=0.2,
                                   saturation=0.2),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(360),
            # BUG FIX: torchvision Resize expects (height, width).  The
            # train transform used (img_width, img_hight) while the val
            # transform used (img_hight, img_width), so the two splits were
            # resized to transposed shapes whenever width != height.  Both
            # now use the (height, width) convention.
            transforms.Resize((cfg["img_hight"], cfg["img_width"])),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))
        ]))
    val_dataset = ImageSelectFolder(
        root=cfg["val_dataset"],
        label=cfg["label"],
        select_condition=cfg["val_select"],
        transform=transforms.Compose([
            transforms.Resize((cfg["img_hight"], cfg["img_width"])),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))
        ]))

    train_dataload = DataLoader(dataset=train_dataset,
                                batch_size=cfg["batch_size"],
                                shuffle=True)
    val_dataload = DataLoader(dataset=val_dataset,
                              batch_size=cfg["batch_size"],
                              shuffle=False)

    model = get_model(model_weight_path=cfg["model_weight_path"],
                      model_name=cfg["model_name"],
                      out_features=cfg["num_classes"],
                      img_width=cfg["img_width"],
                      img_hight=cfg["img_hight"],
                      verbose=False)
    model.cuda()
    loss_function = nn.CrossEntropyLoss().cuda()
    optimizer = optim.Adam(model.parameters(), lr=cfg["lr"], weight_decay=1e-4)
    fit_generator = TrainFitGenerator(net=model,
                                      optimizer=optimizer,
                                      loss_function=loss_function,
                                      generator=train_dataload,
                                      epochs=cfg["nepochs"],
                                      validation_data=val_dataload,
                                      callbacks=callbacks_s)
    fit_generator.run()
    # Persist loss/accuracy curves for later inspection.
    plot_training_metrics(fit_generator.history,
                          history_save_dir,
                          "loss",
                          title=f"train and validation loss",
                          is_show=False)
    plot_training_metrics(fit_generator.history,
                          history_save_dir,
                          "acc",
                          title=f"train and validation accuracy",
                          is_show=False)