Example #1
def train_model(args):
    """
    args:
       args: global arguments
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("=====> input size:{}".format(input_size))

    print(args)

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or wrong GPU id; run without --cuda")

    # set the seed
    setup_seed(GLOBAL_SEED)
    print("=====> set Global Seed: ", GLOBAL_SEED)

    cudnn.enabled = True
    print("=====> building network")

    # build the model and initialization
    model = build_model(args.model, num_classes=args.classes)
    init_weight(model,
                nn.init.kaiming_normal_,
                nn.BatchNorm2d,
                1e-3,
                0.1,
                mode='fan_in')

    print("=====> computing network parameters and FLOPs")
    total_parameters = netParams(model)
    print("the number of parameters: %d ==> %.2f M" %
          (total_parameters, (total_parameters / 1e6)))

    # load data and data augmentation
    datas, trainLoader, valLoader = build_dataset_train(
        args.dataset, args.classes, input_size, args.batch_size,
        args.train_type, False, False, args.num_workers)

    args.per_iter = len(trainLoader)
    args.max_iter = args.max_epochs * args.per_iter

    print('=====> Dataset statistics')
    print("data['classWeights']: ", datas['classWeights'])
    print('mean and std: ', datas['mean'], datas['std'])
    # datas['classWeights'] = np.array([4.044603, 2.0614128, 4.2246304, 6.0238333,
    #                                   10.107266, 8.601249, 8.808282], dtype=np.float32)
    # datas['mean'] = [0.5, 0.5, 0.5]
    # datas['std'] = [0.2, 0.2, 0.2]

    # define the per-dataset loss function
    weight = torch.from_numpy(datas['classWeights'])
    if args.dataset == 'pollen':
        weight = torch.tensor([1., 1.])

    # check the optional loss flags before the plain per-dataset defaults;
    # otherwise the flag-specific branches are unreachable
    if args.dataset == 'camvid' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(
            weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'camvid':
        criteria = CrossEntropyLoss2d(weight=weight,
                                      ignore_label=args.ignore_label)
    elif args.dataset == 'cityscapes' and args.use_ohem:
        min_kept = int(args.batch_size // len(args.gpus) * h * w // 16)
        criteria = ProbOhemCrossEntropy2d(use_weight=True,
                                          ignore_label=args.ignore_label,
                                          thresh=0.7,
                                          min_kept=min_kept)
    elif args.dataset == 'cityscapes' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(
            weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'cityscapes' and args.use_lovaszsoftmax:
        criteria = LovaszSoftmax(ignore_index=args.ignore_label)
    elif args.dataset == 'cityscapes' and args.use_focal:
        criteria = FocalLoss2d(weight=weight, ignore_index=args.ignore_label)
    elif args.dataset == 'remote' and args.use_ohem:
        min_kept = int(args.batch_size // len(args.gpus) * h * w // 16)
        criteria = ProbOhemCrossEntropy2d(use_weight=True,
                                          ignore_label=args.ignore_label,
                                          thresh=0.7,
                                          min_kept=min_kept)
    elif args.dataset == 'remote' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(
            weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'remote' and args.use_lovaszsoftmax:
        criteria = LovaszSoftmax(ignore_index=args.ignore_label)
    elif args.dataset == 'remote' and args.use_focal:
        criteria = FocalLoss2d(weight=weight, ignore_index=args.ignore_label)
    else:
        # 'seed', plain 'remote'/'cityscapes', and any other dataset fall
        # back to weighted cross-entropy
        criteria = CrossEntropyLoss2d(weight=weight,
                                      ignore_label=args.ignore_label)

    args.gpu_nums = 1  # default, so savedir below also works without CUDA
    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = nn.DataParallel(model).cuda()  # multi-GPU data parallel
        else:
            print("single GPU for training")
            model = model.cuda()  # single-GPU training

    args.savedir = (args.savedir + args.dataset + '/' + args.model + 'bs' +
                    str(args.batch_size) + 'gpu' + str(args.gpu_nums) + "_" +
                    str(args.train_type) + '/')

    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    start_epoch = 0

    # continue training
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            # model.load_state_dict(convert_state_dict(checkpoint['model']))
            print("=====> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True
    # cudnn.deterministic = True  # optional: fully reproducible, but slower

    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s Seed: %s" %
                     (str(total_parameters), GLOBAL_SEED))
        logger.write("\n%s\t\t%s\t%s\t%s" %
                     ('Epoch', 'Loss(Tr)', 'mIOU (val)', 'lr'))
    logger.flush()

    # define optimization strategy
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                           model.parameters()),
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif args.optim == 'adam':
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     lr=args.lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=1e-4)
    elif args.optim == 'radam':
        optimizer = RAdam(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          betas=(0.90, 0.999),
                          eps=1e-08,
                          weight_decay=1e-4)
    elif args.optim == 'ranger':
        optimizer = Ranger(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.lr,
                           betas=(0.95, 0.999),
                           eps=1e-08,
                           weight_decay=1e-4)
    elif args.optim == 'adamw':
        optimizer = AdamW(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          betas=(0.9, 0.999),
                          eps=1e-08,
                          weight_decay=1e-4)
    else:
        raise NotImplementedError("unsupported optimizer: %s" % args.optim)

    lossTr_list = []
    epoches = []
    mIOU_val_list = []

    print('=====> beginning training')
    for epoch in range(start_epoch, args.max_epochs):
        # training

        lossTr, lr = train(args, trainLoader, model, criteria, optimizer,
                           epoch)
        lossTr_list.append(lossTr)

        # validation
        if epoch % 2 == 0 or epoch == (args.max_epochs - 1):
            epoches.append(epoch)
            mIOU_val, per_class_iu = val(args, valLoader, model)
            mIOU_val_list.append(mIOU_val)
            # record train information
            logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.7f" %
                         (epoch, lossTr, mIOU_val, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print(
                "Epoch No.: %d\tTrain Loss = %.4f\t mIOU(val) = %.4f\t lr= %.6f\n"
                % (epoch, lossTr, mIOU_val, lr))
        else:
            # record train information
            logger.write("\n%d\t\t%.4f\t\t\t\t%.7f" % (epoch, lossTr, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print("Epoch No.: %d\tTrain Loss = %.4f\t lr= %.6f\n" %
                  (epoch, lossTr, lr))

        # save the model
        model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
        state = {"epoch": epoch + 1, "model": model.state_dict()}

        # per-dataset checkpoint policy: for cityscapes keep only every 50th
        # epoch and the last 10; every other dataset keeps every epoch
        if args.dataset == 'cityscapes':
            if epoch >= args.max_epochs - 10 or epoch % 50 == 0:
                torch.save(state, model_file_name)
        else:
            torch.save(state, model_file_name)

        # draw plots for visualization
        if epoch % 5 == 0 or epoch == (args.max_epochs - 1):
            # plot the figures every 5 epochs and at the last epoch
            fig1, ax1 = plt.subplots(figsize=(11, 8))

            ax1.plot(range(start_epoch, epoch + 1), lossTr_list)
            ax1.set_title("Average training loss vs epochs")
            ax1.set_xlabel("Epochs")
            ax1.set_ylabel("Current loss")

            plt.savefig(args.savedir + "loss_vs_epochs.png")

            plt.clf()

            fig2, ax2 = plt.subplots(figsize=(11, 8))

            ax2.plot(epoches, mIOU_val_list, label="Val IoU")
            ax2.set_title("Average IoU vs epochs")
            ax2.set_xlabel("Epochs")
            ax2.set_ylabel("Current IoU")
            plt.legend(loc='lower right')

            plt.savefig(args.savedir + "iou_vs_epochs.png")

            plt.close('all')

    logger.close()
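For context, train_model consumes a flat args namespace. Below is a minimal argparse sketch that would satisfy the attribute accesses above; the option names are inferred from the code, while every default value and the entry point are illustrative assumptions, not the original project's settings.

# Illustrative argparse sketch for Example #1; option names inferred from
# the attribute accesses in train_model above, defaults are assumptions.
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description='segmentation training')
    parser.add_argument('--model', default='ENet')
    parser.add_argument('--dataset', default='camvid')
    parser.add_argument('--classes', type=int, default=11)
    parser.add_argument('--input_size', default='360,480')  # 'H,W'
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--max_epochs', type=int, default=300)
    parser.add_argument('--train_type', default='train')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--optim', default='adam',
                        choices=['sgd', 'adam', 'radam', 'ranger', 'adamw'])
    parser.add_argument('--ignore_label', type=int, default=255)
    parser.add_argument('--use_ohem', action='store_true')
    parser.add_argument('--use_label_smoothing', action='store_true')
    parser.add_argument('--use_lovaszsoftmax', action='store_true')
    parser.add_argument('--use_focal', action='store_true')
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--gpus', default='0')
    parser.add_argument('--resume', default='')
    parser.add_argument('--savedir', default='./checkpoint/')
    parser.add_argument('--logFile', default='log.txt')
    return parser.parse_args()


if __name__ == '__main__':
    train_model(parse_args())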
Example #2
def train_model(args):
    """
    args:
       args: global arguments
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("=====> input size:{}".format(input_size))

    print(args)

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or wrong GPU id; run without --cuda")

    # set the seed
    setup_seed(GLOBAL_SEED)
    print("=====> set Global Seed: ", GLOBAL_SEED)

    cudnn.enabled = True
    print("=====> building network")

    # build the model and initialization
    model = build_model(args.model, num_classes=args.classes)
    init_weight(model,
                nn.init.kaiming_normal_,
                nn.BatchNorm2d,
                1e-3,
                0.1,
                mode='fan_in')

    print("=====> computing network parameters and FLOPs")
    total_parameters = netParams(model)
    print("the number of parameters: %d ==> %.2f M" %
          (total_parameters, (total_parameters / 1e6)))

    # load data and data augmentation
    datas, trainLoader, valLoader = build_dataset_train(
        args.dataset, input_size, args.batch_size, args.train_type,
        args.random_scale, args.random_mirror, args.num_workers)

    print('=====> Dataset statistics')
    print("data['classWeights']: ", datas['classWeights'])
    print('mean and std: ', datas['mean'], datas['std'])

    # define the per-dataset loss function
    weight = torch.from_numpy(datas['classWeights'])

    if args.dataset == 'camvid':
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=ignore_label)
    elif args.dataset == 'cityscapes':
        min_kept = int(args.batch_size // len(args.gpus) * h * w // 16)
        criteria = ProbOhemCrossEntropy2d(use_weight=True,
                                          ignore_label=ignore_label,
                                          thresh=0.7,
                                          min_kept=min_kept)
    else:
        raise NotImplementedError(
            "This repository currently supports only the cityscapes and "
            "camvid datasets; %s is not supported" % args.dataset)

    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = nn.DataParallel(model).cuda()  # multi-card data parallel
        else:
            args.gpu_nums = 1
            print("single GPU for training")
            model = model.cuda()  # single-GPU training

    args.savedir = (args.dataset + '/' + args.savedir + args.model + 'bs' +
                    str(args.batch_size) + "_" + str(args.train_type) + '/')

    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    start_epoch = 0

    # continue training
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            # model.load_state_dict(convert_state_dict(checkpoint['model']))
            print("=====> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True

    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s Seed: %s" %
                     (str(total_paramters), GLOBAL_SEED))
        logger.write("\n%s\t\t%s\t%s\t%s" %
                     ('Epoch', 'Loss(Tr)', 'mIOU (val)', 'lr'))
    logger.flush()

    # define the optimizer
    if args.dataset == 'camvid':
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     args.lr, (0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=2e-4)

    elif args.dataset == 'cityscapes':
        #optimizer = torch.optim.SGD(
        #filter(lambda p: p.requires_grad, model.parameters()), args.lr, momentum=0.9, weight_decay=1e-4)
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     args.lr, (0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=1e-5)

    lossTr_list = []
    epoches = []
    mIOU_val_list = []

    print('=====> beginning training')
    for epoch in range(start_epoch, args.max_epochs):
        # training
        lossTr, lr = train(args, trainLoader, model, criteria, optimizer,
                           epoch)
        lossTr_list.append(lossTr)

        # validation
        if epoch % 30 == 0 or epoch == (args.max_epochs - 1):
            epoches.append(epoch)
            mIOU_val, per_class_iu = val(args, valLoader, model)
            mIOU_val_list.append(mIOU_val)
            # record train information
            logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.7f" %
                         (epoch, lossTr, mIOU_val, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print(
                "Epoch No.: %d\tTrain Loss = %.4f\t mIOU(val) = %.4f\t lr= %.6f\n"
                % (epoch, lossTr, mIOU_val, lr))
        else:
            # record train information
            logger.write("\n%d\t\t%.4f\t\t\t\t%.7f" % (epoch, lossTr, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print("Epoch No.: %d\tTrain Loss = %.4f\t lr= %.6f\n" %
                  (epoch, lossTr, lr))

        # save the model
        model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
        state = {"epoch": epoch + 1, "model": model.state_dict()}

        # keep every 100th epoch and the last 10
        if epoch >= args.max_epochs - 10 or epoch % 100 == 0:
            torch.save(state, model_file_name)

        # draw plots for visualization
        if epoch % 30 == 0 or epoch == (args.max_epochs - 1):
            # plot the figures every 30 epochs and at the last epoch
            fig1, ax1 = plt.subplots(figsize=(11, 8))

            ax1.plot(range(start_epoch, epoch + 1), lossTr_list)
            ax1.set_title("Average training loss vs epochs")
            ax1.set_xlabel("Epochs")
            ax1.set_ylabel("Current loss")

            plt.savefig(args.savedir + "loss_vs_epochs.png")

            plt.clf()

            fig2, ax2 = plt.subplots(figsize=(11, 8))

            ax2.plot(epoches, mIOU_val_list, label="Val IoU")
            ax2.set_title("Average IoU vs epochs")
            ax2.set_xlabel("Epochs")
            ax2.set_ylabel("Current IoU")
            plt.legend(loc='lower right')

            plt.savefig(args.savedir + "iou_vs_epochs.png")

            plt.close('all')

    logger.close()
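Both examples call setup_seed and netParams without showing them. The sketches below are assumptions about what such helpers typically do (seed every RNG; count model parameters); they are not the original implementations.

# Assumed helpers; the originals are not shown in these examples.
import random

import numpy as np
import torch

GLOBAL_SEED = 1234  # illustrative value


def setup_seed(seed):
    # seed Python, NumPy, and PyTorch (CPU and all GPUs) for reproducibility
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def netParams(model):
    # total number of parameters, as printed by train_model above
    return sum(p.numel() for p in model.parameters())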
Example #3
def train_model(args):
    """
    args:
       args: global arguments
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("input size:{}".format(input_size))

    print(args)

    if args.cuda:
        print("use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or wrong GPU id; run without --cuda")

    # set the seed
    setup_seed(GLOBAL_SEED)
    print("set Global Seed: ", GLOBAL_SEED)
    cudnn.enabled = True
    print("building network")

    # build the model and initialization
    model = build_model(args.model, num_classes=args.classes)
    init_weight(model,
                nn.init.kaiming_normal_,
                nn.BatchNorm2d,
                1e-3,
                0.1,
                mode='fan_in')

    print("computing network parameters and FLOPs")
    total_parameters = netParams(model)
    print("the number of parameters: %d ==> %.2f M" %
          (total_parameters, (total_parameters / 1e6)))

    # load data and data augmentation
    datas, trainLoader, valLoader = build_dataset_train(
        args.dataset, input_size, args.batch_size, args.train_type,
        args.random_scale, args.random_mirror, args.num_workers)

    args.per_iter = len(trainLoader)
    args.max_iter = args.max_epochs * args.per_iter

    print('Dataset statistics')
    print("data['classWeights']: ", datas['classWeights'])
    print('mean and std: ', datas['mean'], datas['std'])

    # define the per-dataset loss function
    weight = torch.from_numpy(datas['classWeights'])

    # check the optional loss flags before the plain per-dataset defaults;
    # otherwise the flag-specific branches are unreachable
    if args.dataset == 'camvid' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(weight=weight,
                                                 ignore_label=ignore_label)
    elif args.dataset == 'camvid':
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=ignore_label)
    elif args.dataset == 'cityscapes' and args.use_ohem:
        min_kept = int(args.batch_size // len(args.gpus) * h * w // 16)
        criteria = ProbOhemCrossEntropy2d(use_weight=True,
                                          ignore_label=ignore_label,
                                          thresh=0.7,
                                          min_kept=min_kept)
    elif args.dataset == 'cityscapes' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(weight=weight,
                                                 ignore_label=ignore_label)
    elif args.dataset == 'cityscapes' and args.use_lovaszsoftmax:
        criteria = LovaszSoftmax(ignore_index=ignore_label)
    elif args.dataset == 'cityscapes' and args.use_focal:
        criteria = FocalLoss2d(weight=weight, ignore_index=ignore_label)
    elif args.dataset == 'paris':
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=ignore_label)
    else:
        raise NotImplementedError(
            "This repository currently supports the cityscapes, camvid, and "
            "paris datasets; %s is not supported" % args.dataset)

    args.gpu_nums = 1  # default, so savedir below also works without CUDA
    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = nn.DataParallel(model).cuda()  # multi-GPU data parallel
        else:
            print("single GPU for training")
            model = model.cuda()  # single-GPU training

    args.savedir = (args.savedir + args.dataset + '/' + args.model + 'bs' +
                    str(args.batch_size) + 'gpu' + str(args.gpu_nums) + "_" +
                    str(args.train_type) + '/')

    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    with open(args.savedir + 'args.txt', 'w') as f:
        f.write('mean:{}\nstd:{}\n'.format(datas['mean'], datas['std']))
        f.write("Parameters: {} Seed: {}\n".format(str(total_paramters),
                                                   GLOBAL_SEED))
        f.write(str(args))

    start_epoch = 0
    # continue training
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            # model.load_state_dict(convert_state_dict(checkpoint['model']))
            print("loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True
    # cudnn.deterministic = True  # optional: fully reproducible, but slower

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=50)

    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("%s\t%s\t\t%s\t%s\t%s" %
                     ('Epoch', '   lr', 'Loss(Tr)', 'Loss(Val)', 'mIOU(Val)'))
    logger.flush()

    # define optimization strategy
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                           model.parameters()),
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif args.optim == 'adam':
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     lr=args.lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=1e-4)
    elif args.optim == 'radam':
        optimizer = RAdam(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          betas=(0.90, 0.999),
                          eps=1e-08,
                          weight_decay=1e-4)
    elif args.optim == 'ranger':
        optimizer = Ranger(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.lr,
                           betas=(0.95, 0.999),
                           eps=1e-08,
                           weight_decay=1e-4)
    elif args.optim == 'adamw':
        optimizer = AdamW(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          betas=(0.9, 0.999),
                          eps=1e-08,
                          weight_decay=1e-4)
    else:
        raise NotImplementedError("unsupported optimizer: %s" % args.optim)

    lossTr_list = []
    epoches = []
    mIOU_val_list = []
    lossVal_list = []
    print('>>>>>>>>>>>beginning training>>>>>>>>>>>')
    for epoch in range(start_epoch, args.max_epochs):
        # training
        lossTr, lr = train(args, trainLoader, model, criteria, optimizer,
                           epoch)
        lossTr_list.append(lossTr)

        # validation
        if epoch % args.val_miou_epochs == 0:
            epoches.append(epoch)
            val_loss, mIOU_val, per_class_iu = val(args, valLoader, criteria,
                                                   model, epoch)
            mIOU_val_list.append(mIOU_val)
            lossVal_list.append(val_loss.item())
            # record train information
            logger.write(
                "\n%d\t%.6f\t%.4f\t\t%.4f\t%0.4f\t %s" %
                (epoch, lr, lossTr, val_loss, mIOU_val, str(per_class_iu)))
            logger.flush()
            print(
                "Epoch  %d\tlr= %.6f\tTrain Loss = %.4f\tVal Loss = %.4f\tmIOU(val) = %.4f\tper_class_iu= %s\n"
                % (epoch, lr, lossTr, val_loss, mIOU_val, str(per_class_iu)))
        else:
            # record train information
            val_loss = val(args, valLoader, criteria, model, epoch)
            lossVal_list.append(val_loss.item())
            logger.write("\n%d\t%.6f\t%.4f\t\t%.4f" %
                         (epoch, lr, lossTr, val_loss))
            logger.flush()
            print("Epoch  %d\tlr= %.6f\tTrain Loss = %.4f\tVal Loss = %.4f\n" %
                  (epoch, lr, lossTr, val_loss))

        # save the model
        model_file_name = args.savedir + '/model_' + str(epoch) + '.pth'
        state = {"epoch": epoch, "model": model.state_dict()}

        # keep every 10th epoch and the last 10
        if epoch >= args.max_epochs - 10 or epoch % 10 == 0:
            torch.save(state, model_file_name)

        # draw plots for visualization
        if os.path.isfile(args.savedir + "loss.png"):
            # rebuild the curves from the log so the plot also covers
            # epochs recorded before a resume; parse the numeric fields,
            # otherwise matplotlib receives lists of strings
            f = open(args.savedir + 'log.txt', 'r')
            next(f)  # skip the header line
            epoch_list = []
            lossTr_list = []
            lossVal_list = []
            for line in f.readlines():
                fields = line.strip().split()
                epoch_list.append(int(fields[0]))
                lossTr_list.append(float(fields[2]))
                lossVal_list.append(float(fields[3]))
            f.close()
            assert len(epoch_list) == len(lossTr_list) == len(lossVal_list)

            fig1, ax1 = plt.subplots(figsize=(11, 8))

            ax1.plot(epoch_list, lossTr_list, label='Train_loss')
            ax1.plot(epoch_list, lossVal_list, label='Val_loss')
            ax1.set_title("Average training loss vs epochs")
            ax1.set_xlabel("Epochs")
            ax1.set_ylabel("Current loss")
            ax1.legend()

            plt.savefig(args.savedir + "loss.png")
            plt.clf()
        else:
            fig1, ax1 = plt.subplots(figsize=(11, 8))

            ax1.plot(range(start_epoch, epoch + 1), lossTr_list, label='Train_loss')
            ax1.plot(range(start_epoch, epoch + 1), lossVal_list, label='Val_loss')
            ax1.set_title("Average training loss vs epochs")
            ax1.set_xlabel("Epochs")
            ax1.set_ylabel("Current loss")
            ax1.legend()

            plt.savefig(args.savedir + "loss.png")
            plt.clf()

            fig2, ax2 = plt.subplots(figsize=(11, 8))

            ax2.plot(epoches, mIOU_val_list, label="Val IoU")
            ax2.set_title("Average IoU vs epochs")
            ax2.set_xlabel("Epochs")
            ax2.set_ylabel("Current IoU")
            ax2.legend()

            plt.savefig(args.savedir + "mIou.png")
            plt.close('all')

        early_stopping.monitor(monitor=val_loss)
        if early_stopping.early_stop:
            print("Early stopping and Save checkpoint")
            if not os.path.exists(model_file_name):
                torch.save(state, model_file_name)
            break

    logger.close()
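Example #3 relies on an EarlyStopping helper with a monitor() method and an early_stop flag. Below is a minimal sketch compatible with that usage, assuming it tracks the best validation loss; the original class is not shown, so the details are assumptions.

# Minimal EarlyStopping sketch matching the monitor()/early_stop usage in
# Example #3; the original class is not shown here.
class EarlyStopping:
    def __init__(self, patience=50):
        self.patience = patience
        self.counter = 0
        self.best = None
        self.early_stop = False

    def monitor(self, monitor):
        # flag early_stop once the monitored value (validation loss here)
        # has not improved for `patience` consecutive calls
        value = float(monitor)
        if self.best is None or value < self.best:
            self.best = value
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True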
def test_model(args):
    """
     main function for testing
     param args: global arguments
     return: None
    """
    print(args)

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception("no GPU found or wrong gpu id, please run without --cuda")

    # build the model
    model = build_model(args.model, num_classes=args.classes)
    init_weight(model, nn.init.kaiming_normal_,
                nn.BatchNorm2d, 1e-3, 0.1,
                mode='fan_in')

    if args.cuda:
        model = model.cuda()  # using GPU for inference
        cudnn.benchmark = True

    if args.save:
        if not os.path.exists(args.save_seg_dir):
            os.makedirs(args.save_seg_dir)

    # load the test set
    datas, testLoader = build_dataset_test(args.dataset, args.num_workers, args.batch_size)
    # datas, _, testLoader = build_dataset_train(args.dataset, (352,480), args.batch_size, 'train', False, False, args.num_workers)

    if not args.best:
        if args.checkpoint:
            if os.path.isfile(args.checkpoint):
                print("=====> loading checkpoint '{}'".format(args.checkpoint))
                checkpoint = torch.load(args.checkpoint)
                model.load_state_dict(checkpoint['model'])
                # model.load_state_dict(convert_state_dict(checkpoint['model']))
            else:
                print("=====> no checkpoint found at '{}'".format(args.checkpoint))
                raise FileNotFoundError("no checkpoint found at '{}'".format(args.checkpoint))

        print("=====> beginning validation")
        print("validation set length: ", len(testLoader))
        mIOU_val, per_class_iu = test(args, testLoader, model)
        print(mIOU_val)
        print(per_class_iu)

    # Get the best test result among the last 10 model records.
    else:
        if args.checkpoint:
            if os.path.isfile(args.checkpoint):
                dirname, basename = os.path.split(args.checkpoint)
                epoch = int(os.path.splitext(basename)[0].split('_')[1])
                mIOU_val = []
                per_class_iu = []
                for i in range(epoch - 9, epoch + 1):
                    basename = 'model_' + str(i) + '.pth'
                    resume = os.path.join(dirname, basename)
                    checkpoint = torch.load(resume)
                    model.load_state_dict(checkpoint['model'])
                    print("=====> beginning test the " + basename)
                    print("validation set length: ", len(testLoader))
                    mIOU_val_0, per_class_iu_0 = test(args, testLoader, model)
                    mIOU_val.append(mIOU_val_0)
                    per_class_iu.append(per_class_iu_0)

                index = list(range(epoch - 9, epoch + 1))[np.argmax(mIOU_val)]
                print("The checkpoint with the best mIoU among the last 10 is from epoch", index)
                print(mIOU_val)
                per_class_iu = per_class_iu[np.argmax(mIOU_val)]
                mIOU_val = np.max(mIOU_val)
                print(mIOU_val)
                print(per_class_iu)

            else:
                print("=====> no checkpoint found at '{}'".format(args.checkpoint))
                raise FileNotFoundError("no checkpoint found at '{}'".format(args.checkpoint))

    # Build the test-log path
    if not args.best:
        model_path = os.path.splitext(os.path.basename(args.checkpoint))
        args.logFile = 'test_' + model_path[0] + '.txt'
        logFileLoc = os.path.join(os.path.dirname(args.checkpoint), args.logFile)
    else:
        args.logFile = 'test_' + 'best' + str(index) + '.txt'
        logFileLoc = os.path.join(os.path.dirname(args.checkpoint), args.logFile)

    # Save the result; append if the log file already exists, so repeated
    # runs are also recorded
    logger = open(logFileLoc, 'a' if os.path.isfile(logFileLoc) else 'w')
    logger.write("Mean IoU: %.4f" % mIOU_val)
    logger.write("\nPer class IoU: ")
    for i in range(len(per_class_iu)):
        logger.write("%.4f\t" % per_class_iu[i])
    logger.write("\n")
    logger.flush()
    logger.close()
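The commented-out convert_state_dict calls above address a common pitfall: checkpoints saved from an nn.DataParallel model carry a 'module.' prefix on every state_dict key, which a bare model rejects on load. A sketch of such a helper follows; the original is not shown, so this is an assumption about its behavior.

# Assumed convert_state_dict: strip the 'module.' prefix that
# nn.DataParallel adds to state_dict keys when saving.
from collections import OrderedDict


def convert_state_dict(state_dict):
    new_state = OrderedDict()
    for k, v in state_dict.items():
        name = k[len('module.'):] if k.startswith('module.') else k
        new_state[name] = v
    return new_state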
def test_curve(args):
    """
     main function for testing
     param args: global arguments
     return: None
    """
    args.save = False
    args.best = False

    print(args)

    # if args.checkpoint:
    #     if os.path.isdir(args.checkpoint):
    #         model_dir = get_dir_list(args.checkpoint, args.model)
    #         if len(model_dir) == 0:
    #             print("=====> no checkpoint found at '{}'".format(args.checkpoint))
    #             return
    #     else:
    #         print("=====> no a dir '{}'".format(args.checkpoint))
    #         # raise FileNotFoundError("no checkpoint found at '{}'".format(args.checkpoint))
    #         return
    model_dir = [args.model + 'bs5gpu1_train_True']

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "no GPU found or wrong gpu id, please run without --cuda")

    # build the model
    model = build_model(args.model, num_classes=args.classes)
    init_weight(model,
                nn.init.kaiming_normal_,
                nn.BatchNorm2d,
                1e-3,
                0.1,
                mode='fan_in')

    if args.cuda:
        model = model.cuda()  # using GPU for inference
        cudnn.benchmark = True

    # if args.save:
    #     if not os.path.exists(args.save_seg_dir):
    #         os.makedirs(args.save_seg_dir)

    # load the test set
    datas, testLoader = build_dataset_test(args.dataset, args.num_workers,
                                           args.batch_size)
    # datas, _, testLoader = build_dataset_train(args.dataset, (352,480), args.batch_size, 'train', False, False, args.num_workers)

    for d in model_dir:
        csv_path = path.join(args.checkpoint, d, d + '_test.csv')
        log_path = path.join(args.checkpoint, d, 'log.txt')
        # skip directories whose CSV is already newer than the log (disabled):
        # if os.path.exists(log_path) and os.path.exists(csv_path) and (
        #         os.path.getmtime(log_path) < os.path.getmtime(csv_path)):
        #     print(csv_path + ' is the newest!')
        #     continue
        pth_list = get_file_list(path.join(args.checkpoint, d))
        if len(pth_list) == 0:
            continue

        # print(pth_list)
        if args.dataset == 'cityscapes':
            results = pd.DataFrame(columns=[
                'epoch', 'mIoU', 'road', 'sidewalk', 'building', 'wall',
                'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation',
                'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus',
                'train', 'motorcycle', 'bicycle'
            ])
        elif (args.dataset == 'camvid') or (args.dataset == 'camvid352'):
            results = pd.DataFrame(columns=[
                'epoch', 'mIoU', 'Sky', 'Building', 'Pole', 'Road', 'Sidewalk',
                'Tree', 'Sign', 'Fence', 'Car', 'Pedestrian', 'Bicyclist'
            ])
        else:
            raise NotImplementedError(
                "This repository currently supports only the cityscapes, "
                "camvid, and camvid352 datasets; %s is not supported" % args.dataset)

        for pth in pth_list:
            checkpoint = torch.load(path.join(args.checkpoint, d, pth),
                                    map_location=torch.device('cpu'))
            print("=====> beginning load model {}/{}".format(d, pth))
            model.load_state_dict(checkpoint['model'])
            print("=====> beginning validation {}/{}".format(d, pth))
            print("validation set length: ", len(testLoader))
            mIOU_val, per_class_iu = test(args, testLoader, model)
            # mIOU_val, per_class_iu = 1,np.array([1,2,3,4,5,6,7,8,9,0,1])
            # str.strip() removes characters, not substrings, so parse the
            # epoch out of 'model_<epoch>.pth' with splitext/split instead
            epoch = int(os.path.splitext(pth)[0].split('_')[1])
            results.loc[results.shape[0]] = [epoch, mIOU_val] + per_class_iu.tolist()
        results.sort_values(by=['epoch'], axis=0, inplace=True)
        results.reset_index(drop=True, inplace=True)
        results.to_csv(csv_path)
        print('saved {}'.format(csv_path))
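test_curve expects get_file_list to return the model_<epoch>.pth checkpoints inside a directory. A sketch under that assumption, sorting numerically by the embedded epoch number; the original helper is not shown.

# Assumed get_file_list: checkpoint file names in a directory, sorted by the
# epoch number embedded in 'model_<epoch>.pth'.
import os
import re


def get_file_list(dirname):
    pths = [f for f in os.listdir(dirname) if f.endswith('.pth')]
    return sorted(pths, key=lambda f: int(re.findall(r'\d+', f)[0]))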