示例#1
0
def main():
    """Train the lane-segmentation network, checkpointing after every epoch.

    The output directory ``lane_config.SAVE_PATH`` is wiped and recreated,
    training/validation loaders are built, and the network selected by the
    module-level ``train_net`` key is trained for ``lane_config.EPOCHS``
    epochs. One checkpoint is written per epoch plus ``finalNet.pth.tar``
    at the end.
    """
    lane_config = Config()

    # Start from an empty output directory: earlier logs/checkpoints are removed.
    if os.path.exists(lane_config.SAVE_PATH):
        shutil.rmtree(lane_config.SAVE_PATH)
    os.makedirs(lane_config.SAVE_PATH, exist_ok=True)

    # Worker processes and pinned memory are only requested when CUDA exists.
    kwargs = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}

    # Augmentation is applied to the training set only.
    train_dataset = LaneDataset(
        "data_list/train.csv",
        transform=transforms.Compose(
            [ImageAug(), DeformAug(), ScaleAug(), CutOut(32, 0.5), ToTensor()]))
    train_data_batch = DataLoader(train_dataset, batch_size=8 * len(device_list),
                                  shuffle=True, drop_last=True, **kwargs)

    val_dataset = LaneDataset("data_list/val.csv",
                              transform=transforms.Compose([ToTensor()]))
    val_data_batch = DataLoader(val_dataset, batch_size=4 * len(device_list),
                                shuffle=False, drop_last=False, **kwargs)

    net = nets[train_net](lane_config)
    if torch.cuda.is_available():
        net = net.cuda(device=device_list[0])
        net = torch.nn.DataParallel(net, device_ids=device_list)

    optimizer = torch.optim.Adam(net.parameters(), lr=lane_config.BASE_LR,
                                 weight_decay=lane_config.WEIGHT_DECAY)

    # NOTE: when the net is wrapped in DataParallel, net.state_dict() keys carry
    # the wrapper's "module." prefix; loaders of these checkpoints must match.
    checkpoint_dir = os.path.join(os.getcwd(), lane_config.SAVE_PATH)

    # The context manager guarantees both logs are flushed and closed even if
    # an epoch raises part-way through training.
    with open(os.path.join(lane_config.SAVE_PATH, "train_log.csv"), 'w') as trainF, \
            open(os.path.join(lane_config.SAVE_PATH, "val_log.csv"), 'w') as testF:
        for epoch in range(lane_config.EPOCHS):
            adjust_lr(optimizer, epoch)
            train_epoch(net, epoch, train_data_batch, optimizer, trainF, lane_config)
            test(net, epoch, val_data_batch, testF, lane_config)
            torch.save({'state_dict': net.state_dict()},
                       os.path.join(checkpoint_dir, "laneNet{}.pth.tar".format(epoch)))

    torch.save({'state_dict': net.state_dict()},
               os.path.join(checkpoint_dir, "finalNet.pth.tar"))
示例#2
0
def train(args):
    """Train the network named by ``args.net`` and save checkpoints.

    Args:
        args: namespace-like object. This function reads ``net``,
            ``save_path``, ``num_works``, ``base_lr``, ``weight_decay`` and
            ``epochs`` from it, and passes the whole object on to the model
            constructor, ``train_epoch`` and ``val_epoch``.

    Raises:
        KeyError: if ``args.net`` is neither ``'deeplabv3p'`` nor ``'unet'``.
    """
    predict_net = args.net
    nets = {'deeplabv3p': DeepLab, 'unet': ResNetUNet}

    # Create the output directories up front so that neither the log files
    # nor the final checkpoint fail on a missing directory after training.
    os.makedirs(args.save_path, exist_ok=True)
    final_dir = os.path.join(os.getcwd(), "result")
    os.makedirs(final_dir, exist_ok=True)

    kwargs = {
        'num_workers': args.num_works,
        'pin_memory': True
    } if torch.cuda.is_available() else {}

    # Augmentation is applied to the training set only.
    train_dataset = LaneDataset("train.csv",
                                transform=transforms.Compose([
                                    ImageAug(),
                                    DeformAug(),
                                    ScaleAug(),
                                    CutOut(32, 0.5),
                                    ToTensor()
                                ]))
    train_data_batch = DataLoader(train_dataset,
                                  batch_size=2,
                                  shuffle=True,
                                  drop_last=True,
                                  **kwargs)
    val_dataset = LaneDataset("val.csv",
                              transform=transforms.Compose([ToTensor()]))
    # NOTE(review): drop_last=True discards a trailing partial validation
    # batch, so that sample never contributes to validation metrics --
    # confirm this is intended.
    val_data_batch = DataLoader(val_dataset,
                                batch_size=2,
                                shuffle=False,
                                drop_last=True,
                                **kwargs)

    net = nets[predict_net](args)
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=args.base_lr,
                                 weight_decay=args.weight_decay)

    # Training and validation; the context manager closes both logs even if
    # an epoch raises.
    with open(os.path.join(args.save_path, "train.csv"), 'w') as trainF, \
            open(os.path.join(args.save_path, "test.csv"), 'w') as valF:
        for epoch in range(args.epochs):
            train_epoch(net, epoch, train_data_batch, optimizer, trainF, args)
            val_epoch(net, epoch, val_data_batch, valF, args)
            # Keep a checkpoint every second epoch.
            if epoch % 2 == 0:
                torch.save({'state_dict': net.state_dict()},
                           os.path.join(os.getcwd(), args.save_path,
                                        "laneNet{}.pth.tar".format(epoch)))

    # NOTE(review): the final file name says "unet" regardless of args.net.
    torch.save({'state_dict': net.state_dict()},
               os.path.join(final_dir, "finalNet_unet.pth.tar"))
示例#3
0
def main():
    """Train the lane-segmentation network, checkpointing every second epoch.

    The output directory ``lane_config.SAVE_PATH`` is wiped and recreated,
    the loaders are built with one sample per device in ``device_list``, and
    the network selected by the module-level ``train_net`` key is trained for
    ``lane_config.EPOCHS`` epochs. ``finalNet.pth.tar`` is written at the end.
    """
    # Model/training parameters.
    lane_config = Config()
    if os.path.exists(lane_config.SAVE_PATH):
        shutil.rmtree(lane_config.SAVE_PATH)
    os.makedirs(lane_config.SAVE_PATH, exist_ok=True)

    # Dataset set-up.
    # pin_memory places loaded tensors in page-locked host memory, which
    # speeds up the copy to GPU memory. num_workers is the number of
    # subprocesses that load batches for the main process; too many can
    # exhaust host memory.
    kwargs = {'num_workers': 4, 'pin_memory': True} if torch.cuda.is_available() else {}

    # Augmentation is applied to the training set only, not to validation.
    train_dataset = LaneDataset(
        "train.csv",
        transform=transforms.Compose(
            [ImageAug(), DeformAug(), ScaleAug(), CutOut(32, 0.5), ToTensor()]))
    train_data_batch = DataLoader(train_dataset, batch_size=len(device_list),
                                  shuffle=True, drop_last=True, **kwargs)

    val_dataset = LaneDataset("val.csv", transform=transforms.Compose([ToTensor()]))
    val_data_batch = DataLoader(val_dataset, batch_size=len(device_list),
                                shuffle=False, drop_last=False, **kwargs)

    # Build the model.
    net = nets[train_net](lane_config)
    if torch.cuda.is_available():
        net = net.cuda(device=device_list[0])
        net = torch.nn.DataParallel(net, device_ids=device_list)

    optimizer = torch.optim.Adam(net.parameters(), lr=lane_config.BASE_LR,
                                 weight_decay=lane_config.WEIGHT_DECAY)

    # NOTE: when the net is wrapped in DataParallel, net.state_dict() keys carry
    # the wrapper's "module." prefix; loaders of these checkpoints must match.
    checkpoint_dir = os.path.join(os.getcwd(), lane_config.SAVE_PATH)

    # Training and validation. The context manager closes both logs even if
    # an epoch raises part-way through.
    with open(os.path.join(lane_config.SAVE_PATH, "train.csv"), 'w') as trainF, \
            open(os.path.join(lane_config.SAVE_PATH, "test.csv"), 'w') as testF:
        for epoch in range(lane_config.EPOCHS):
            train_epoch(net, epoch, train_data_batch, optimizer, trainF, lane_config)
            test(net, epoch, val_data_batch, testF, lane_config)
            # Keep a checkpoint every second epoch.
            if epoch % 2 == 0:
                torch.save({'state_dict': net.state_dict()},
                           os.path.join(checkpoint_dir, "laneNet{}.pth.tar".format(epoch)))

    torch.save({'state_dict': net.state_dict()},
               os.path.join(checkpoint_dir, "finalNet.pth.tar"))
示例#4
0
文件: train_1.py 项目: ShawnXiee/myDL
def main():
    """Run one pass over the lane training loader with a BCE + Dice loss.

    Builds an augmented ``LaneDataset`` loader, a ``BasicDataset`` train/val
    split, a ``unet_base`` model on the GPU and an Adam optimizer, then
    iterates once over ``training_data_batch`` performing an optimizer step
    per batch.

    NOTE(review): ``epochs``, ``train_loader`` and ``val_loader`` are created
    but never used in the visible code; only ``training_data_batch`` is
    iterated, and only once.
    """
    # network = 'deeplabv3p'
    # save_model_path = "./model_weights/" + network + "_"
    # model_path = "./model_weights/" + network + "_0_6000"
    data_dir = ''
    # Fraction of BasicDataset held out for validation.
    val_percent = .1

    epochs = 9

    # Worker processes and pinned memory are only requested when CUDA exists.
    kwargs = {
        'num_workers': 4,
        'pin_memory': True
    } if torch.cuda.is_available() else {}
    # NOTE(review): the path starts with "~"; whether it is expanded depends
    # on how LaneDataset reads the file -- confirm.
    training_dataset = LaneDataset(
        "~/workspace/myDL/CV/week8/Lane_Segmentation_pytorch/data_list/train.csv",
        transform=transforms.Compose(
            [ImageAug(),
             DeformAug(),
             ScaleAug(),
             CutOut(32, 0.5),
             ToTensor()]))

    training_data_batch = DataLoader(training_dataset,
                                     batch_size=2,
                                     shuffle=True,
                                     drop_last=True,
                                     **kwargs)

    dataset = BasicDataset(data_dir,
                           img_size=cfg.IMG_SIZE,
                           crop_offset=cfg.crop_offset)

    # Split BasicDataset into train/validation subsets by val_percent.
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])

    train_loader = DataLoader(train,
                              batch_size=cfg.batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=cfg.batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True)

    model = unet_base(cfg.num_classes, cfg.IMG_SIZE)
    # NOTE(review): this call is unconditional, so the function requires a
    # CUDA device even though later code checks torch.cuda.is_available().
    model.cuda()

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=cfg.base_lr,
                                 betas=(0.9, 0.99))

    bce_criterion = nn.BCEWithLogitsLoss()
    dice_criterion = MulticlassDiceLoss()

    model.train()
    # Running sum of per-batch loss values (accumulated but not reported here).
    epoch_loss = 0

    dataprocess = tqdm(training_data_batch)
    for batch_item in dataprocess:
        image, mask = batch_item['image'], batch_item['mask']
        # NOTE(review): the whole training step below is nested under this
        # check, so without CUDA the loop iterates the data and does nothing.
        if torch.cuda.is_available():
            image, mask = image.cuda(), mask.cuda()
            image = image.to(torch.float32).requires_grad_()
            mask = mask.to(torch.float32).requires_grad_()

            masks_pred = model(image)
            # NOTE(review): torch.argmax returns integer indices and is not
            # differentiable, so masks_pred is cut off from the model here.
            # The loss can then only backpropagate into `mask` (marked
            # requires_grad_ above), not into the model parameters --
            # confirm whether the raw model output should feed the losses.
            masks_pred = torch.argmax(masks_pred, dim=1)
            masks_pred = masks_pred.to(torch.float32)
            mask = mask.to(torch.float32)

            # print('mask_pred:', masks_pred)
            # print('mask:', mask)
            loss = bce_criterion(masks_pred, mask) + dice_criterion(
                masks_pred, mask)
            epoch_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
示例#5
0
from tqdm import tqdm
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from utils.image_process import LaneDataset, ImageAug, DeformAug
from utils.image_process import ScaleAug, CutOut, ToTensor

# Worker processes and pinned memory are only requested when CUDA exists.
if torch.cuda.is_available():
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

# Training set with the full augmentation pipeline applied per sample.
training_dataset = LaneDataset(
    "train.csv",
    transform=transforms.Compose(
        [ImageAug(), DeformAug(), ScaleAug(), CutOut(32, 0.5), ToTensor()]
    ),
)

# Shuffled batches of 16; a trailing partial batch is dropped.
training_data_batch = DataLoader(
    training_dataset, batch_size=16, shuffle=True, drop_last=True, **kwargs
)

# Wrap the loader so iteration shows a progress bar.
dataprocess = tqdm(training_data_batch)
for batch_item in dataprocess:
示例#6
0
def main():
    """Resume lane-segmentation training from a saved epoch checkpoint.

    The resume switch and the epoch to resume from are hard-coded below
    (``Resume`` / ``epoch_to_continue``). When resuming, the model weights,
    optimizer state and epoch counter are restored from
    ``epoch{N}Net.pth.tar`` under ``lane_config.SAVE_PATH``. A checkpoint
    holding the epoch, the unwrapped model weights and the optimizer state
    is written every fifth epoch.

    Raises:
        SystemExit: with status 1 if resuming is requested but the
            checkpoint file does not exist.
    """
    lane_config = Config()
    os.makedirs(lane_config.SAVE_PATH, exist_ok=True)

    # Whether to resume training, and from which saved epoch.
    Resume = True
    epoch_to_continue = 65
    checkpoint_path = os.path.join(
        os.getcwd(), lane_config.SAVE_PATH,
        "epoch{}Net.pth.tar".format(epoch_to_continue))
    # Validate before any log file is opened, so a missing checkpoint
    # neither touches existing logs nor exits with a success status.
    if Resume is True and not os.path.exists(checkpoint_path):
        print("checkpoint_path not exists!")
        raise SystemExit(1)

    kwargs = {
        'num_workers': 4,
        'pin_memory': True
    } if torch.cuda.is_available() else {}
    train_dataset = LaneDataset("train.csv",
                                transform=transforms.Compose([
                                    ImageAug(),
                                    DeformAug(),
                                    ScaleAug(),
                                    CutOut(32, 0.5),
                                    ToTensor()
                                ]))

    train_data_batch = DataLoader(train_dataset,
                                  batch_size=2 * len(device_list),
                                  shuffle=True,
                                  drop_last=True,
                                  **kwargs)
    val_dataset = LaneDataset("val.csv",
                              transform=transforms.Compose([ToTensor()]))

    val_data_batch = DataLoader(val_dataset,
                                batch_size=2 * len(device_list),
                                shuffle=False,
                                drop_last=False,
                                **kwargs)

    net = nets[train_net](lane_config)

    # Move the bare network to the first device. The DataParallel wrapper is
    # added only after the checkpoint is loaded, because the checkpoint
    # stores unwrapped weights.
    if torch.cuda.is_available():
        print("cuda is available")
        net = net.cuda(device=device_list[0])

    # Create the optimizer; when resuming, its state is restored below.
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=lane_config.BASE_LR,
                                 weight_decay=lane_config.WEIGHT_DECAY)

    if Resume is True:
        # Fall back to CPU so the checkpoint still loads without CUDA.
        if torch.cuda.is_available():
            map_location = 'cuda:{}'.format(device_list[0])
        else:
            map_location = 'cpu'
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        net.load_state_dict(checkpoint['state_dict'])  # model weights
        optimizer.load_state_dict(
            checkpoint['optimizer_state_dict'])  # optimizer state
        epoch_to_continue = checkpoint['epoch']

    # Add data parallelism; this wraps the network in an extra `module` layer.
    if torch.cuda.is_available():
        net = torch.nn.DataParallel(net, device_ids=device_list)

    # Append to the logs when resuming so earlier epochs are not truncated.
    log_mode = 'a' if Resume is True else 'w'

    # NOTE(review): this range runs EPOCHS - 1 epochs and starts at
    # epoch_to_continue + 1 even when Resume is False -- confirm intent.
    with open(os.path.join(lane_config.SAVE_PATH, "train.csv"), log_mode) as trainF, \
            open(os.path.join(lane_config.SAVE_PATH, "test.csv"), log_mode) as testF:
        for epoch in range(epoch_to_continue + 1,
                           epoch_to_continue + lane_config.EPOCHS):
            adjust_lr(optimizer, epoch)
            train_epoch(net, epoch, train_data_batch, optimizer, trainF,
                        lane_config)
            if epoch % 5 == 0:
                # Store the unwrapped weights (parameters only, not the
                # architecture). On CPU the net is not wrapped, so `.module`
                # must not be assumed.
                if isinstance(net, torch.nn.DataParallel):
                    bare_net = net.module
                else:
                    bare_net = net
                torch.save(
                    {
                        'epoch': epoch,
                        'state_dict': bare_net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    },
                    os.path.join(os.getcwd(), lane_config.SAVE_PATH,
                                 "epoch{}Net.pth.tar".format(epoch)))

            test(net, epoch, val_data_batch, testF, lane_config)
# @Author: chargerKong
# @Time: 20-6-22 下午3:18
# @File: data_generator.py

import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
from utils.image_process import LaneDataset, ToTensor, CutOut, ImageAug, DeformAug, ScaleAug
from utils.config import train_csv_path
from torchvision import transforms

# Worker processes and pinned memory are only requested when CUDA exists.
if torch.cuda.is_available():
    kwargs = {'num_workers': 4, 'pin_memory': True}
else:
    kwargs = {}

# Training set with the full augmentation pipeline applied per sample.
training_dataset = LaneDataset(
    train_csv_path,
    transform=transforms.Compose([
        ImageAug(),
        DeformAug(),
        ScaleAug(),
        CutOut(32, 0.5),
        ToTensor()
    ]))

# Batches of two, unshuffled; a trailing partial batch is dropped.
data_gen = DataLoader(training_dataset,
                      batch_size=2,
                      drop_last=True,
                      **kwargs)


# Iterate the loader once with a progress bar, moving each batch to the GPU
# when one is available.
for batch in tqdm(data_gen):

    image = batch['image']
    mask = batch['mask']
    if torch.cuda.is_available():
        image = image.cuda()
        mask = mask.cuda()
示例#8
0
def main():
    """Train DeeplabV3Plus on the lane dataset with periodic full checkpoints.

    The output directory ``lane_config.SAVE_PATH`` is wiped and recreated for
    the CSV logs. Checkpoints holding the model weights, optimizer state and
    epoch are written every fifth epoch, and once more after training, under
    the hard-coded checkpoint directory below.
    """
    lane_config = Config()
    if os.path.exists(lane_config.SAVE_PATH):
        shutil.rmtree(lane_config.SAVE_PATH)
    os.makedirs(lane_config.SAVE_PATH, exist_ok=True)

    kwargs = {
        'num_workers': 4,
        'pin_memory': True
    } if torch.cuda.is_available() else {}
    train_dataset = LaneDataset("train.csv",
                                transform=transforms.Compose([
                                    ImageAug(),
                                    DeformAug(),
                                    ScaleAug(),
                                    ToTensor()
                                ]))

    train_data_batch = DataLoader(train_dataset,
                                  batch_size=4 * len(device_list),
                                  shuffle=True,
                                  drop_last=True,
                                  **kwargs)
    val_dataset = LaneDataset("val.csv",
                              transform=transforms.Compose([ToTensor()]))

    val_data_batch = DataLoader(val_dataset,
                                batch_size=2 * len(device_list),
                                shuffle=False,
                                drop_last=False,
                                **kwargs)

    net = DeeplabV3Plus(lane_config)
    if torch.cuda.is_available():
        net = net.cuda(device=device_list[0])
        net = torch.nn.DataParallel(net, device_ids=device_list)

    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=lane_config.BASE_LR,
                                 weight_decay=lane_config.WEIGHT_DECAY)

    # Final checkpoint location. Resuming from this file (keys 'model',
    # 'optimizer', 'epoch') is not implemented; training always starts at
    # epoch 0.
    path = "/home/ubuntu/baidu/Lane-Segmentation/logs/finalNet.pth"
    # Create the checkpoint directory up front, so a missing directory fails
    # here rather than at the first torch.save after a full epoch of training.
    os.makedirs(os.path.dirname(path), exist_ok=True)

    # The context manager closes both logs even if an epoch raises.
    with open(os.path.join(lane_config.SAVE_PATH, "train.csv"), 'w') as trainF, \
            open(os.path.join(lane_config.SAVE_PATH, "test.csv"), 'w') as testF:
        for epoch in range(lane_config.EPOCHS):
            train_epoch(net, epoch, train_data_batch, optimizer, trainF,
                        lane_config)
            test(net, epoch, val_data_batch, testF, lane_config)
            # Keep a full checkpoint every fifth epoch.
            if epoch % 5 == 0:
                path1 = "/home/ubuntu/baidu/Lane-Segmentation/logs/laneNet{}.pth".format(
                    epoch)
                state = {
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch
                }
                torch.save(state, path1)

    state = {
        'model': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': lane_config.EPOCHS
    }
    torch.save(state, path)
示例#9
0
def main():
    """Train DeeplabV3Plus on the lane dataset, checkpointing every second epoch.

    The output directory ``lane_config.SAVE_PATH`` is wiped and recreated,
    training/validation loaders are built, and the network is trained for
    ``lane_config.EPOCHS`` epochs. ``finalNet.pth.tar`` is written at the end.
    """
    # Model/training parameters.
    lane_config = Config()

    # If the output directory already exists, delete it entirely, then
    # create a fresh one.
    if os.path.exists(lane_config.SAVE_PATH):
        shutil.rmtree(lane_config.SAVE_PATH)
    os.makedirs(lane_config.SAVE_PATH, exist_ok=True)

    # Dataset set-up.
    # pin_memory places loaded tensors in page-locked host memory, which
    # speeds up the copy to GPU memory.
    kwargs = {
        'num_workers': 4,
        'pin_memory': True
    } if torch.cuda.is_available() else {}

    # Training dataset with augmentation.
    train_dataset = LaneDataset("train.csv",
                                transform=transforms.Compose([
                                    ImageAug(),
                                    DeformAug(),
                                    ScaleAug(),
                                    CutOut(32, 0.5),
                                    ToTensor()
                                ]))

    # Training loader.
    train_data_batch = DataLoader(train_dataset,
                                  batch_size=8 * len(device_list),
                                  shuffle=True,
                                  drop_last=True,
                                  **kwargs)

    # Validation dataset (no augmentation).
    val_dataset = LaneDataset("val.csv",
                              transform=transforms.Compose([ToTensor()]))

    # Validation loader.
    val_data_batch = DataLoader(val_dataset,
                                batch_size=4 * len(device_list),
                                shuffle=False,
                                drop_last=False,
                                **kwargs)

    # Build the model.
    net = DeeplabV3Plus(lane_config)

    # If a GPU is present, move the model to it and wrap it for data
    # parallelism.
    if torch.cuda.is_available():
        net = net.cuda(device=device_list[0])
        net = torch.nn.DataParallel(net, device_ids=device_list)

    # Configure the optimizer.
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=lane_config.BASE_LR,
                                 weight_decay=lane_config.WEIGHT_DECAY)

    # NOTE: when the net is wrapped in DataParallel, net.state_dict() keys carry
    # the wrapper's "module." prefix; loaders of these checkpoints must match.
    checkpoint_dir = os.path.join(os.getcwd(), lane_config.SAVE_PATH)

    # Training and validation; per-epoch results go to the two CSV logs.
    # The context manager closes both logs even if an epoch raises.
    with open(os.path.join(lane_config.SAVE_PATH, "train.csv"), 'w') as trainF, \
            open(os.path.join(lane_config.SAVE_PATH, "test.csv"), 'w') as testF:
        for epoch in range(lane_config.EPOCHS):
            train_epoch(net, epoch, train_data_batch, optimizer, trainF,
                        lane_config)

            test(net, epoch, val_data_batch, testF, lane_config)

            # Keep a checkpoint every second epoch.
            if epoch % 2 == 0:
                torch.save({'state_dict': net.state_dict()},
                           os.path.join(checkpoint_dir,
                                        "laneNet{}.pth.tar".format(epoch)))

    torch.save({'state_dict': net.state_dict()},
               os.path.join(checkpoint_dir, "finalNet.pth.tar"))
示例#10
0
from tqdm import tqdm
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from utils.image_process import LaneDataset, ImageAug, DeformAug
from utils.image_process import ScaleAug, CutOut, ToTensor


# Worker processes and pinned memory are only requested when CUDA exists.
if torch.cuda.is_available():
    kwargs = {'num_workers': 4, 'pin_memory': True}
else:
    kwargs = {}

# Training set with the full augmentation pipeline applied per sample.
training_dataset = LaneDataset(
    "train.csv",
    transform=transforms.Compose([
        ImageAug(),
        DeformAug(),
        ScaleAug(),
        CutOut(32, 0.5),
        ToTensor()
    ]))


# This is where the data actually starts being processed: shuffled batches
# of 16, with a trailing partial batch dropped.
training_data_batch = DataLoader(training_dataset,
                                 batch_size=16,
                                 shuffle=True,
                                 drop_last=True,
                                 **kwargs)
"""
102
20
2"""
for batch_item in training_data_batch:
    # Each batch holds the already-augmented image and its mask.
    image = batch_item['image']
    mask = batch_item['mask']
    if torch.cuda.is_available():
        image = image.cuda()
        mask = mask.cuda()

    # With a model available, this is where the batch would be fed to it:
    #   prediction = model(image)
    #   loss = f(prediction, mask)
示例#11
0
def main():
    """Train and evaluate a lane-segmentation network as configured by ``cfg``.

    Builds the train/eval loaders, instantiates the model selected by
    ``cfg.MODEL``, optionally loads pretrained weights, and then alternates
    ``train_epoch`` / ``eval_epoch`` for ``cfg.EPOCHS`` epochs. A checkpoint
    is saved after every training epoch and CSV logs are written to
    ``cfg.LOG_DIR``.
    """
    # Use multi-process loading and pinned memory only when CUDA is available.
    kwargs = {
        'num_workers': 4,
        'pin_memory': True
    } if torch.cuda.is_available() else {}

    # Image augmentation for the training set.
    augments = [ImageAug(), ToTensor()]

    # Datasets and iterable loaders.
    train_dataset = LaneSegTrainDataset("train.csv",
                                        transform=transforms.Compose(augments))
    eval_dataset = LaneSegTrainDataset("eval.csv",
                                       transform=transforms.Compose(
                                           [ToTensor()]))

    # With multi-GPU training the configured batch size is per device.
    device_count = len(cfg.DEVICE_LIST) if cfg.MULTI_GPU else 1
    train_batch_size = cfg.TRAIN_BATCH_SIZE * device_count
    eval_batch_size = cfg.EVAL_BATCH_SIZE * device_count

    train_data_batch = DataLoader(train_dataset,
                                  batch_size=train_batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  **kwargs)
    eval_data_batch = DataLoader(eval_dataset,
                                 batch_size=eval_batch_size,
                                 shuffle=False,
                                 drop_last=False,
                                 **kwargs)

    # Define the model: 'unet' and any unrecognised name both use UNetv1.
    if cfg.MODEL == 'deeplabv3+':
        net = Deeplabv3plus(class_num=cfg.CLASS_NUM, normal=cfg.NORMAL)
    else:
        net = UNetv1(class_num=cfg.CLASS_NUM, normal=cfg.NORMAL)

    # Use CUDA if available.
    if torch.cuda.is_available():
        if cfg.MULTI_GPU:
            net = torch.nn.DataParallel(net, device_ids=cfg.DEVICE_LIST)
            net = net.cuda(device=cfg.DEVICE_LIST[0])
        else:
            net = net.cuda()
        # Load pretrained weights.
        # NOTE(review): this only happens when CUDA is available; on CPU
        # cfg.PRE_TRAINED is silently ignored -- confirm this is intended.
        if cfg.PRE_TRAINED:
            checkpoint = torch.load(
                os.path.join(cfg.LOG_DIR, cfg.PRE_TRAIN_WEIGHTS))
            net.load_state_dict(checkpoint['state_dict'])

    # Optimizer and loss.
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=cfg.BASE_LR,
                                 weight_decay=cfg.WEIGHT_DECAY)
    criterion = FocalLoss(cfg.CLASS_NUM)

    # Make sure the log/checkpoint directory exists before writing to it.
    os.makedirs(cfg.LOG_DIR, exist_ok=True)

    train_log_path = os.path.join(
        cfg.LOG_DIR, "train_log_{}.csv".format(cfg.TRAIN_NUMBER))
    eval_log_path = os.path.join(
        cfg.LOG_DIR, "eval_log_{}.csv".format(cfg.TRAIN_NUMBER))

    # The context manager closes both logs even if an epoch raises.
    with open(train_log_path, 'w') as train_log, \
            open(eval_log_path, 'w') as eval_log:
        train_log_title = "epoch,average loss\n"
        train_log.write(train_log_title)
        train_log.flush()

        eval_log_title = "epoch,average_loss, mean_iou, iou_0,iou_1,iou_2,iou_3,iou_4,iou_5,iou_6,iou_7, " \
                         "mean_precision, precision_0,precision_1,precision_2,precision_3, precision_4," \
                         "precision_5,precision_6,precision_7, mean_recall, recall_0,recall_1,recall_2," \
                         "recall_3,recall_4,recall_5,recall_6,recall_7\n"
        eval_log.write(eval_log_title)
        eval_log.flush()

        # Train and evaluate epoch by epoch.
        for epoch in range(cfg.EPOCHS):
            # Report the rate the optimizer actually holds, not the config
            # constant, so the message stays correct if the rate is adjusted.
            print('current epoch learning rate: {}'.format(
                optimizer.param_groups[0]['lr']))
            train_epoch(net, epoch, train_data_batch, optimizer, criterion,
                        train_log)

            # Save the model: intermediate epochs carry the epoch number,
            # the last epoch gets the plain final name.
            if epoch != cfg.EPOCHS - 1:
                checkpoint_name = "laneNet_{0}_{1}th_epoch_{2}.pt".format(
                    cfg.MODEL, cfg.TRAIN_NUMBER, epoch)
            else:
                checkpoint_name = "laneNet_{0}_{1}th.pt".format(
                    cfg.MODEL, cfg.TRAIN_NUMBER)
            torch.save({'state_dict': net.state_dict()},
                       os.path.join(cfg.LOG_DIR, checkpoint_name))

            eval_epoch(net, epoch, eval_data_batch, eval_log)