Example 1
def train(criterion, net_name, use_dataset_index):
    assert net_name == "UNet" or net_name == "CENet" or net_name == "FCN"
    assert use_dataset_index < len(config.path_data)
    path_data = config.path_data[use_dataset_index]

    path_data_root = path_data["dataset"]
    path_checkpoints = path_data["checkpoints"]
    os.makedirs(path_checkpoints, exist_ok=True)

    for seed in config.random_seed:
        # set the random seed
        setup_seed(seed)
        net = get_net(net_name).to(config.device)

        dataset = MyDataset(path_data_root=path_data_root,
                            phase="train",
                            transform_list=transform_compose)
        data_loader = DataLoader(dataset,
                                 batch_size=config.batch_size,
                                 shuffle=True)

        optimizer = Adam(params=net.parameters(), lr=config.lr)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                           gamma=0.95)
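        # ExponentialLR multiplies the learning rate by gamma=0.95 on every
        # scheduler.step() call, i.e. once per epoch in the loop below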

        for epoch in range(config.epochs):
            for i, (img, label, mask, name) in enumerate(data_loader):
                img = img.to(config.device)
                label = label.to(config.device)

                pred = net(img)
                loss = criterion(pred, label)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if i % 50 == 0:
                    print("loss = {}".format(loss.item()))
            scheduler.step()

        torch.save(
            net.state_dict(),
            os.path.join(path_checkpoints,
                         "{}_{}.pth".format(net_name, seed)))
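
Every example below also calls a setup_seed helper that the listing never shows. A minimal sketch of what such a helper usually does (an assumption, not the original repositories' exact code):

import random

import numpy as np
import torch


def setup_seed(seed):
    # seed every RNG the training loops touch so runs are repeatable
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
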
Example 2
def train_model(args):
    """
    args:
       args: global arguments
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("=====> input size:{}".format(input_size))

    print(args)

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    # set the seed
    setup_seed(GLOBAL_SEED)
    print("=====> set Global Seed: ", GLOBAL_SEED)

    cudnn.enabled = True
    print("=====> building network")

    # build the model and initialization
    model = build_model(args.model, num_classes=args.classes)
    init_weight(model,
                nn.init.kaiming_normal_,
                nn.BatchNorm2d,
                1e-3,
                0.1,
                mode='fan_in')

    print("=====> computing network parameters and FLOPs")
    total_parameters = netParams(model)
    print("the number of parameters: %d ==> %.2f M" %
          (total_parameters, (total_parameters / 1e6)))

    # load data and data augmentation
    datas, trainLoader, valLoader = build_dataset_train(
        args.dataset, args.classes, input_size, args.batch_size,
        args.train_type, False, False, args.num_workers)

    args.per_iter = len(trainLoader)
    args.max_iter = args.max_epochs * args.per_iter
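    # per_iter / max_iter are presumably consumed by a poly-style lr schedule
    # inside train(); they are not referenced again in this function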

    print('=====> Dataset statistics')
    print("data['classWeights']: ", datas['classWeights'])
    print('mean and std: ', datas['mean'], datas['std'])
    # datas['classWeights'] = np.array([4.044603, 2.0614128, 4.2246304, 6.0238333,
    #                                   10.107266, 8.601249, 8.808282], dtype=np.float32)
    # datas['mean'] = [0.5, 0.5, 0.5]
    # datas['std'] = [0.2, 0.2, 0.2]

    # define the loss function
    weight = torch.from_numpy(datas['classWeights'])
    if args.dataset == 'pollen':
        weight = torch.tensor([1., 1.])

    # Note: the flag-specific branches must be tested before the bare
    # per-dataset defaults, otherwise they are unreachable.
    if args.dataset == 'camvid' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(
            weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'camvid':
        criteria = CrossEntropyLoss2d(weight=weight,
                                      ignore_label=args.ignore_label)
    elif args.dataset == 'cityscapes' and args.use_ohem:
        min_kept = int(args.batch_size // len(args.gpus) * h * w // 16)
        criteria = ProbOhemCrossEntropy2d(use_weight=True,
                                          ignore_label=args.ignore_label,
                                          thresh=0.7,
                                          min_kept=min_kept)
    elif args.dataset == 'cityscapes' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(
            weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'cityscapes' and args.use_lovaszsoftmax:
        criteria = LovaszSoftmax(ignore_index=args.ignore_label)
    elif args.dataset == 'cityscapes' and args.use_focal:
        criteria = FocalLoss2d(weight=weight, ignore_index=args.ignore_label)
    elif args.dataset == 'remote' and args.use_ohem:
        min_kept = int(args.batch_size // len(args.gpus) * h * w // 16)
        criteria = ProbOhemCrossEntropy2d(use_weight=True,
                                          ignore_label=args.ignore_label,
                                          thresh=0.7,
                                          min_kept=min_kept)
    elif args.dataset == 'remote' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(
            weight=weight, ignore_label=args.ignore_label)
    elif args.dataset == 'remote' and args.use_lovaszsoftmax:
        criteria = LovaszSoftmax(ignore_index=args.ignore_label)
    elif args.dataset == 'remote' and args.use_focal:
        criteria = FocalLoss2d(weight=weight, ignore_index=args.ignore_label)
    else:
        # 'seed', 'remote' (no flags set) and any other dataset fall back to
        # plain weighted 2D cross-entropy
        criteria = CrossEntropyLoss2d(weight=weight,
                                      ignore_label=args.ignore_label)

    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = nn.DataParallel(model).cuda()  # multi-card data parallel
        else:
            args.gpu_nums = 1
            print("single GPU for training")
            model = model.cuda()  # 1-card data parallel

    args.savedir = (args.savedir + args.dataset + '/' + args.model + 'bs' +
                    str(args.batch_size) + 'gpu' + str(args.gpu_nums) + "_" +
                    str(args.train_type) + '/')

    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    start_epoch = 0

    # continue training
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            # model.load_state_dict(convert_state_dict(checkpoint['model']))
            print("=====> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True
    # cudnn.deterministic = True ## my add

    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s Seed: %s" %
                     (str(total_parameters), GLOBAL_SEED))
        logger.write("\n%s\t\t%s\t%s\t%s" %
                     ('Epoch', 'Loss(Tr)', 'mIOU (val)', 'lr'))
    logger.flush()

    # define optimization strategy
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                           model.parameters()),
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif args.optim == 'adam':
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     lr=args.lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=1e-4)
    elif args.optim == 'radam':
        optimizer = RAdam(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          betas=(0.90, 0.999),
                          eps=1e-08,
                          weight_decay=1e-4)
    elif args.optim == 'ranger':
        optimizer = Ranger(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.lr,
                           betas=(0.95, 0.999),
                           eps=1e-08,
                           weight_decay=1e-4)
    elif args.optim == 'adamw':
        optimizer = AdamW(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          betas=(0.9, 0.999),
                          eps=1e-08,
                          weight_decay=1e-4)

    lossTr_list = []
    epoches = []
    mIOU_val_list = []

    print('=====> beginning training')
    for epoch in range(start_epoch, args.max_epochs):
        # training

        lossTr, lr = train(args, trainLoader, model, criteria, optimizer,
                           epoch)
        lossTr_list.append(lossTr)

        # validation
        if epoch % 2 == 0 or epoch == (args.max_epochs - 1):
            epoches.append(epoch)
            mIOU_val, per_class_iu = val(args, valLoader, model)
            mIOU_val_list.append(mIOU_val)
            # record train information
            logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.7f" %
                         (epoch, lossTr, mIOU_val, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print(
                "Epoch No.: %d\tTrain Loss = %.4f\t mIOU(val) = %.4f\t lr= %.6f\n"
                % (epoch, lossTr, mIOU_val, lr))
        else:
            # record train information
            logger.write("\n%d\t\t%.4f\t\t\t\t%.7f" % (epoch, lossTr, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print("Epoch No.: %d\tTrain Loss = %.4f\t lr= %.6f\n" %
                  (epoch, lossTr, lr))

        # save the model
        model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
        state = {"epoch": epoch + 1, "model": model.state_dict()}

        # Individual setting for saving the model
        if args.dataset == 'cityscapes':
            # keep checkpoints from the last 10 epochs plus every 50th epoch
            if epoch >= args.max_epochs - 10 or epoch % 50 == 0:
                torch.save(state, model_file_name)
        else:
            # 'camvid', 'seed' and all other datasets save every epoch
            torch.save(state, model_file_name)

        # draw plots for visualization every 5 epochs and at the last epoch
        if epoch % 5 == 0 or epoch == (args.max_epochs - 1):
            fig1, ax1 = plt.subplots(figsize=(11, 8))

            ax1.plot(range(start_epoch, epoch + 1), lossTr_list)
            ax1.set_title("Average training loss vs epochs")
            ax1.set_xlabel("Epochs")
            ax1.set_ylabel("Current loss")

            plt.savefig(args.savedir + "loss_vs_epochs.png")

            plt.clf()

            fig2, ax2 = plt.subplots(figsize=(11, 8))

            ax2.plot(epoches, mIOU_val_list, label="Val IoU")
            ax2.set_title("Average IoU vs epochs")
            ax2.set_xlabel("Epochs")
            ax2.set_ylabel("Current IoU")
            plt.legend(loc='lower right')

            plt.savefig(args.savedir + "iou_vs_epochs.png")

            plt.close('all')

    logger.close()
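
Examples 2, 4 and 5 report the model size through a netParams helper that the listing omits. A one-line count along these lines (an assumption, not necessarily the repository's exact code) reproduces the printed figure:

def netParams(model):
    # total number of scalar parameters in the model
    return sum(p.numel() for p in model.parameters())
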
Example 3
def main(args):
    """
    args:
       args: global arguments
    """
    # set the seed
    setup_seed(GLOBAL_SEED)
    # cudnn.enabled = True
    # cudnn.benchmark = True  # find the optimal configuration
    # cudnn.deterministic = True  # reduce volatility

    # learning-rate scheduling: every 10 epochs, lr *= 0.85
    # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.85)

    # build the model and initialization weights
    model = build_model(args.model, args.classes, args.backbone, args.pretrained, args.out_stride, args.mult_grid)

    # define the loss function
    criterion = build_loss(args, None, ignore_label)

    # load train set and data augmentation
    datas, traindataset = build_dataset_train(args.root, args.dataset, args.base_size, args.crop_size)
    # load the test set; for the Cityscapes test split (no ground truth),
    # call build_dataset_test with gt=False instead
    testdataset, class_dict_df = build_dataset_test(args.root, args.dataset, args.crop_size,
                                                    mode=args.predict_mode, gt=True)

    # move the model and criterion onto cuda
    if args.cuda:
        # check availability up front, before any cuda call can fail
        if not torch.cuda.is_available():
            raise Exception("No GPU found or Wrong gpu id, please run without --cuda")
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus_id
        dist.init_process_group(backend="nccl", init_method='env://')
        args.local_rank = torch.distributed.get_rank()
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        # number of visible GPUs; the ids are comma-separated, e.g. "0,1"
        gpus = len(os.environ["CUDA_VISIBLE_DEVICES"].split(','))

        trainLoader, model, criterion = Distribute(args, traindataset, model, criterion, device, gpus)
        # test with distributed
        # testLoader, _, _ = Distribute(args, testdataset, model, criterion, device, gpus)
        # test with a single card
        testLoader = data.DataLoader(testdataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.batch_size, pin_memory=True, drop_last=False)

    # define optimization strategy
    # parameters = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
    #             {'params': model.get_10x_lr_params(), 'lr': args.lr}]
    parameters = model.parameters()

    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=0.9, weight_decay=5e-4, nesterov=False)
    elif args.optim == 'adam':
        optimizer = torch.optim.Adam(parameters, weight_decay=5e-4)
    elif args.optim == 'adamw':
        optimizer = torch.optim.AdamW(parameters, weight_decay=5e-4)

    # initialize the log file and the directory for saved val output
    args.savedir = (args.savedir + args.dataset + '/' + args.model + '/')
    if not os.path.exists(args.savedir) and args.local_rank == 0:
        os.makedirs(args.savedir)

    # save_seg_dir
    args.save_seg_dir = os.path.join(args.savedir, args.predict_mode)
    if not os.path.exists(args.save_seg_dir) and args.local_rank == 0:
        os.makedirs(args.save_seg_dir)

    recorder = record_log(args)
    if args.resume is None and args.local_rank == 0:
        recorder.record_args(datas, str(netParams(model) / 1e6) + ' M', GLOBAL_SEED)

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=300)
    start_epoch = 1
    if args.local_rank == 0:
        print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>\n"
              ">>>>>>>>>>>  beginning training   >>>>>>>>>>>\n"
              ">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")

    epoch_list = []
    lossTr_list = []
    Miou_list = []
    lossVal_list = []
    Miou = 0
    Best_Miou = 0
    # continue training
    if args.resume:
        logger, lines = recorder.resume_logfile()
        for index, line in enumerate(lines):
            lossTr_list.append(float(line.strip().split()[2]))
            if len(line.strip().split()) != 3:
                epoch_list.append(int(line.strip().split()[0]))
                lossVal_list.append(float(line.strip().split()[3]))
                Miou_list.append(float(line.strip().split()[5]))

        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch'] + 1
            optimizer.load_state_dict(checkpoint['optimizer'])
            check_list = [i for i in checkpoint['model'].items()]
            # weights saved from a multi-card (DataParallel/DDP) run, resumed
            # on a single card: strip the 'module.' prefix from each key
            if 'module.' in check_list[0][0]:
                new_stat_dict = {}
                for k, v in checkpoint['model'].items():
                    new_stat_dict[k[len('module.'):]] = v
                model.load_state_dict(new_stat_dict, strict=True)
            # weights saved from a single-card run: load them directly
            else:
                model.load_state_dict(checkpoint['model'])
            if args.local_rank == 0:
                print("loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            if args.local_rank == 0:
                print("no checkpoint found at '{}'".format(args.resume))
    else:
        logger = recorder.initial_logfile()
        logger.flush()

    for epoch in range(start_epoch, args.max_epochs + 1):
        start_time = time.time()
        # training
        train_start = time.time()

        lossTr, lr = train(args, trainLoader, model, criterion, optimizer, epoch, device)
        if args.local_rank == 0:
            lossTr_list.append(lossTr)

        train_end = time.time()
        train_per_epoch_seconds = train_end - train_start
        validation_per_epoch_seconds = 60  # initial estimate of the validation time
        # validation: 'validation' mode predicts with labels, 'predict' mode
        # predicts without labels

        if epoch % args.val_epochs == 0 or epoch == 1 or args.max_epochs - 10 < epoch <= args.max_epochs:
            validation_start = time.time()

            loss, FWIoU, Miou, MIoU, PerCiou_set, Pa, PerCpa_set, Mpa, MF, F_set, F1_avg = \
                predict_multiscale_sliding(args=args, model=model,
                                           testLoader=testLoader,
                                           class_dict_df=class_dict_df,
                                           # scales=[1.25, 1.5, 1.75, 2.0],
                                           scales=[1.0],
                                           overlap=0.3,
                                           criterion=criterion,
                                           mode=args.predict_type,
                                           save_result=True)
            torch.cuda.empty_cache()

            if args.local_rank == 0:
                epoch_list.append(epoch)
                Miou_list.append(Miou)
                lossVal_list.append(loss.item())
                # record trainVal information
                recorder.record_trainVal_log(logger, epoch, lr, lossTr, loss,
                                             FWIoU, Miou, MIoU, PerCiou_set, Pa, Mpa,
                                             PerCpa_set, MF, F_set, F1_avg,
                                             class_dict_df)

                torch.cuda.empty_cache()
                validation_end = time.time()
                validation_per_epoch_seconds = validation_end - validation_start
        else:
            if args.local_rank == 0:
                # record train information
                recorder.record_train_log(logger, epoch, lr, lossTr)

            # # Update lr_scheduler. In pytorch 1.1.0 and later, should call 'optimizer.step()' before 'lr_scheduler.step()'
            # lr_scheduler.step()
        if args.local_rank == 0:
            # draw log fig
            draw_log(args, epoch, epoch_list, lossTr_list, Miou_list, lossVal_list)

            # save the model
            model_file_name = args.savedir + '/best_model.pth'
            last_model_file_name = args.savedir + '/last_model.pth'
            state = {
                "epoch": epoch,
                "model": model.state_dict(),
                'optimizer': optimizer.state_dict()
            }
            if Miou > Best_Miou:
                Best_Miou = Miou
                torch.save(state, model_file_name)
                recorder.record_best_epoch(epoch, Best_Miou, Pa)

            # early_stopping monitor
            early_stopping.monitor(monitor=Miou)
            if early_stopping.early_stop:
                print("Early stopping and Save checkpoint")
                if not os.path.exists(last_model_file_name):
                    torch.save(state, last_model_file_name)
                    torch.cuda.empty_cache()  # empty_cache

                    loss, FWIoU, Miou, Miou_Noback, PerCiou_set, Pa, PerCpa_set, Mpa, MF, F_set, F1_Noback = \
                        predict_multiscale_sliding(args=args, model=model,
                                                   testLoader=testLoader,
                                                   scales=[1.0],
                                                   overlap=0.3,
                                                   criterion=criterion,
                                                   mode=args.predict_type,
                                                   save_result=False)
                    print("Epoch {}  lr= {:.6f}  Train Loss={:.4f}  Val Loss={:.4f}  Miou={:.4f}  PerCiou_set={}\n"
                          .format(epoch, lr, lossTr, loss, Miou, str(PerCiou_set)))
                break

            # rough finish-time estimate: now, plus the remaining training and
            # validation epochs, plus a fixed 12 h (43200 s) margin
            total_second = start_time + (args.max_epochs - epoch) * train_per_epoch_seconds + \
                           ((args.max_epochs - epoch) / args.val_epochs + 10) * validation_per_epoch_seconds + 43200
            print('Best Validation MIoU:{}'.format(Best_Miou))
            print('Estimated training deadline: {}\n'.format(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(total_second))))
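
Examples 3 and 5 drive their stop condition through an EarlyStopping class that is not listed. Below is a minimal sketch consistent with how it is called above (a monitor() method plus an early_stop flag), assuming a larger-is-better metric such as the mIoU monitored here; the original class presumably also handles smaller-is-better metrics like the validation loss monitored in Example 5:

class EarlyStopping:
    def __init__(self, patience=50, delta=0.0):
        self.patience = patience  # checks to wait after the last improvement
        self.delta = delta        # minimum change that counts as improvement
        self.best = None
        self.counter = 0
        self.early_stop = False

    def monitor(self, monitor):
        # reset the counter on improvement, otherwise count towards patience
        if self.best is None or monitor > self.best + self.delta:
            self.best = monitor
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
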
Example 4
def train_model(args):
    """
    args:
       args: global arguments
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("=====> input size:{}".format(input_size))

    print(args)

    if args.cuda:
        print("=====> use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    # set the seed
    setup_seed(GLOBAL_SEED)
    print("=====> set Global Seed: ", GLOBAL_SEED)

    cudnn.enabled = True
    print("=====> building network")

    # build the model and initialization
    model = build_model(args.model, num_classes=args.classes)
    init_weight(model,
                nn.init.kaiming_normal_,
                nn.BatchNorm2d,
                1e-3,
                0.1,
                mode='fan_in')

    print("=====> computing network parameters and FLOPs")
    total_parameters = netParams(model)
    print("the number of parameters: %d ==> %.2f M" %
          (total_parameters, (total_parameters / 1e6)))

    # load data and data augmentation
    datas, trainLoader, valLoader = build_dataset_train(
        args.dataset, input_size, args.batch_size, args.train_type,
        args.random_scale, args.random_mirror, args.num_workers)

    print('=====> Dataset statistics')
    print("data['classWeights']: ", datas['classWeights'])
    print('mean and std: ', datas['mean'], datas['std'])

    # define the loss function
    weight = torch.from_numpy(datas['classWeights'])

    if args.dataset == 'camvid':
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=ignore_label)
    elif args.dataset == 'cityscapes':
        min_kept = int(args.batch_size // len(args.gpus) * h * w // 16)
        criteria = ProbOhemCrossEntropy2d(use_weight=True,
                                          ignore_label=ignore_label,
                                          thresh=0.7,
                                          min_kept=min_kept)
    else:
        raise NotImplementedError(
            "This repository now supports two datasets: cityscapes and camvid, %s is not included"
            % args.dataset)

    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = nn.DataParallel(model).cuda()  # multi-card data parallel
        else:
            args.gpu_nums = 1
            print("single GPU for training")
            model = model.cuda()  # 1-card data parallel

    args.savedir = (args.dataset + '/' + args.savedir + args.model + 'bs' +
                    str(args.batch_size) + "_" + str(args.train_type) + '/')

    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    start_epoch = 0

    # continue training
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            # model.load_state_dict(convert_state_dict(checkpoint['model']))
            print("=====> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=====> no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True

    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s Seed: %s" %
                     (str(total_paramters), GLOBAL_SEED))
        logger.write("\n%s\t\t%s\t%s\t%s" %
                     ('Epoch', 'Loss(Tr)', 'mIOU (val)', 'lr'))
    logger.flush()

    # define the optimization strategy
    if args.dataset == 'camvid':
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     args.lr, (0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=2e-4)

    elif args.dataset == 'cityscapes':
        # optimizer = torch.optim.SGD(
        #     filter(lambda p: p.requires_grad, model.parameters()),
        #     args.lr, momentum=0.9, weight_decay=1e-4)
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     args.lr, (0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=1e-5)

    lossTr_list = []
    epoches = []
    mIOU_val_list = []

    print('=====> beginning training')
    for epoch in range(start_epoch, args.max_epochs):
        # training
        lossTr, lr = train(args, trainLoader, model, criteria, optimizer,
                           epoch)
        lossTr_list.append(lossTr)

        # validation
        if epoch % 30 == 0 or epoch == (args.max_epochs - 1):
            epoches.append(epoch)
            mIOU_val, per_class_iu = val(args, valLoader, model)
            mIOU_val_list.append(mIOU_val)
            # record train information
            logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.7f" %
                         (epoch, lossTr, mIOU_val, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print(
                "Epoch No.: %d\tTrain Loss = %.4f\t mIOU(val) = %.4f\t lr= %.6f\n"
                % (epoch, lossTr, mIOU_val, lr))
        else:
            # record train information
            logger.write("\n%d\t\t%.4f\t\t\t\t%.7f" % (epoch, lossTr, lr))
            logger.flush()
            print("Epoch : " + str(epoch) + ' Details')
            print("Epoch No.: %d\tTrain Loss = %.4f\t lr= %.6f\n" %
                  (epoch, lossTr, lr))

        # save the model
        model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
        state = {"epoch": epoch + 1, "model": model.state_dict()}

        if epoch >= args.max_epochs - 10:
            torch.save(state, model_file_name)
        elif epoch % 100 == 0:
            torch.save(state, model_file_name)

        # draw plots for visualization every 30 epochs and at the last epoch
        if epoch % 30 == 0 or epoch == (args.max_epochs - 1):
            fig1, ax1 = plt.subplots(figsize=(11, 8))

            ax1.plot(range(start_epoch, epoch + 1), lossTr_list)
            ax1.set_title("Average training loss vs epochs")
            ax1.set_xlabel("Epochs")
            ax1.set_ylabel("Current loss")

            plt.savefig(args.savedir + "loss_vs_epochs.png")

            plt.clf()

            fig2, ax2 = plt.subplots(figsize=(11, 8))

            ax2.plot(epoches, mIOU_val_list, label="Val IoU")
            ax2.set_title("Average IoU vs epochs")
            ax2.set_xlabel("Epochs")
            ax2.set_ylabel("Current IoU")
            plt.legend(loc='lower right')

            plt.savefig(args.savedir + "iou_vs_epochs.png")

            plt.close('all')

    logger.close()
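
The CrossEntropyLoss2d criterion built in Examples 2, 4 and 5 is a project-local wrapper that the listing does not define. A minimal sketch matching its call signature (a class weight plus an ignore_label, applied to (N, C, H, W) logits and (N, H, W) targets); the body is an assumption:

import torch.nn as nn


class CrossEntropyLoss2d(nn.Module):
    def __init__(self, weight=None, ignore_label=255):
        super().__init__()
        # nn.CrossEntropyLoss already supports per-pixel (K-dimensional) targets
        self.loss = nn.CrossEntropyLoss(weight=weight,
                                        ignore_index=ignore_label)

    def forward(self, output, target):
        # output: (N, C, H, W) logits, target: (N, H, W) class indices
        return self.loss(output, target)
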
Example 5
def train_model(args):
    """
    args:
       args: global arguments
    """
    h, w = map(int, args.input_size.split(','))
    input_size = (h, w)
    print("input size:{}".format(input_size))

    print(args)

    if args.cuda:
        print("use gpu id: '{}'".format(args.gpus))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        if not torch.cuda.is_available():
            raise Exception(
                "No GPU found or Wrong gpu id, please run without --cuda")

    # set the seed
    setup_seed(GLOBAL_SEED)
    print("set Global Seed: ", GLOBAL_SEED)
    cudnn.enabled = True
    print("building network")

    # build the model and initialization
    model = build_model(args.model, num_classes=args.classes)
    init_weight(model,
                nn.init.kaiming_normal_,
                nn.BatchNorm2d,
                1e-3,
                0.1,
                mode='fan_in')

    print("computing network parameters and FLOPs")
    total_parameters = netParams(model)
    print("the number of parameters: %d ==> %.2f M" %
          (total_parameters, (total_parameters / 1e6)))

    # load data and data augmentation
    datas, trainLoader, valLoader = build_dataset_train(
        args.dataset, input_size, args.batch_size, args.train_type,
        args.random_scale, args.random_mirror, args.num_workers)

    args.per_iter = len(trainLoader)
    args.max_iter = args.max_epochs * args.per_iter

    print('Dataset statistics')
    print("data['classWeights']: ", datas['classWeights'])
    print('mean and std: ', datas['mean'], datas['std'])

    # define the loss function
    weight = torch.from_numpy(datas['classWeights'])

    # Note: the label-smoothing branch must be tested before the bare camvid
    # default, otherwise it is unreachable.
    if args.dataset == 'camvid' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(weight=weight,
                                                 ignore_label=ignore_label)
    elif args.dataset == 'camvid':
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=ignore_label)
    elif args.dataset == 'cityscapes' and args.use_ohem:
        min_kept = int(args.batch_size // len(args.gpus) * h * w // 16)
        criteria = ProbOhemCrossEntropy2d(use_weight=True,
                                          ignore_label=ignore_label,
                                          thresh=0.7,
                                          min_kept=min_kept)
    elif args.dataset == 'cityscapes' and args.use_label_smoothing:
        criteria = CrossEntropyLoss2dLabelSmooth(weight=weight,
                                                 ignore_label=ignore_label)
    elif args.dataset == 'cityscapes' and args.use_lovaszsoftmax:
        criteria = LovaszSoftmax(ignore_index=ignore_label)
    elif args.dataset == 'cityscapes' and args.use_focal:
        criteria = FocalLoss2d(weight=weight, ignore_index=ignore_label)
    elif args.dataset == 'paris':
        criteria = CrossEntropyLoss2d(weight=weight, ignore_label=ignore_label)
    else:
        raise NotImplementedError(
            "This repository now supports two datasets: cityscapes and camvid, %s is not included"
            % args.dataset)

    if args.cuda:
        criteria = criteria.cuda()
        if torch.cuda.device_count() > 1:
            print("torch.cuda.device_count()=", torch.cuda.device_count())
            args.gpu_nums = torch.cuda.device_count()
            model = nn.DataParallel(model).cuda()  # multi-card data parallel
        else:
            args.gpu_nums = 1
            print("single GPU for training")
            model = model.cuda()  # 1-card data parallel

    args.savedir = (args.savedir + args.dataset + '/' + args.model + 'bs' +
                    str(args.batch_size) + 'gpu' + str(args.gpu_nums) + "_" +
                    str(args.train_type) + '/')

    if not os.path.exists(args.savedir):
        os.makedirs(args.savedir)

    with open(args.savedir + 'args.txt', 'w') as f:
        f.write('mean:{}\nstd:{}\n'.format(datas['mean'], datas['std']))
        f.write("Parameters: {} Seed: {}\n".format(str(total_paramters),
                                                   GLOBAL_SEED))
        f.write(str(args))

    start_epoch = 0
    # continue training
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            # model.load_state_dict(convert_state_dict(checkpoint['model']))
            print("loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("no checkpoint found at '{}'".format(args.resume))

    model.train()
    cudnn.benchmark = True
    # cudnn.deterministic = True ## my add

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=50)

    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("%s\t%s\t\t%s\t%s\t%s" %
                     ('Epoch', '   lr', 'Loss(Tr)', 'Loss(Val)', 'mIOU(Val)'))
    logger.flush()

    # define optimization strategy
    if args.optim == 'sgd':
        optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad,
                                           model.parameters()),
                                    lr=args.lr,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif args.optim == 'adam':
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            model.parameters()),
                                     lr=args.lr,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=1e-4)
    elif args.optim == 'radam':
        optimizer = RAdam(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          betas=(0.90, 0.999),
                          eps=1e-08,
                          weight_decay=1e-4)
    elif args.optim == 'ranger':
        optimizer = Ranger(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=args.lr,
                           betas=(0.95, 0.999),
                           eps=1e-08,
                           weight_decay=1e-4)
    elif args.optim == 'adamw':
        optimizer = AdamW(filter(lambda p: p.requires_grad,
                                 model.parameters()),
                          lr=args.lr,
                          betas=(0.9, 0.999),
                          eps=1e-08,
                          weight_decay=1e-4)

    lossTr_list = []
    epoches = []
    mIOU_val_list = []
    lossVal_list = []
    print('>>>>>>>>>>>beginning training>>>>>>>>>>>')
    for epoch in range(start_epoch, args.max_epochs):
        # training
        lossTr, lr = train(args, trainLoader, model, criteria, optimizer,
                           epoch)
        lossTr_list.append(lossTr)

        # validation
        if epoch % args.val_miou_epochs == 0:
            epoches.append(epoch)
            val_loss, mIOU_val, per_class_iu = val(args, valLoader, criteria,
                                                   model, epoch)
            mIOU_val_list.append(mIOU_val)
            lossVal_list.append(val_loss.item())
            # record train information
            logger.write(
                "\n%d\t%.6f\t%.4f\t\t%.4f\t%0.4f\t %s" %
                (epoch, lr, lossTr, val_loss, mIOU_val, str(per_class_iu)))
            logger.flush()
            print(
                "Epoch  %d\tlr= %.6f\tTrain Loss = %.4f\tVal Loss = %.4f\tmIOU(val) = %.4f\tper_class_iu= %s\n"
                % (epoch, lr, lossTr, val_loss, mIOU_val, str(per_class_iu)))
        else:
            # record train information
            val_loss = val(args, valLoader, criteria, model, epoch)
            lossVal_list.append(val_loss.item())
            logger.write("\n%d\t%.6f\t%.4f\t\t%.4f" %
                         (epoch, lr, lossTr, val_loss))
            logger.flush()
            print("Epoch  %d\tlr= %.6f\tTrain Loss = %.4f\tVal Loss = %.4f\n" %
                  (epoch, lr, lossTr, val_loss))

        # save the model
        model_file_name = args.savedir + '/model_' + str(epoch) + '.pth'
        state = {"epoch": epoch, "model": model.state_dict()}

        # Individual Setting for save model
        if epoch >= args.max_epochs - 10:
            torch.save(state, model_file_name)
        elif epoch % 10 == 0:
            torch.save(state, model_file_name)

        # draw plots for visualization
        if os.path.isfile(args.savedir + "loss.png"):
            # re-read the full log so the curves cover all epochs after a resume
            f = open(args.savedir + 'log.txt', 'r')
            next(f)  # skip the header line
            epoch_list = []
            lossTr_list = []
            lossVal_list = []
            for line in f.readlines():
                epoch_list.append(int(line.strip().split()[0]))
                lossTr_list.append(float(line.strip().split()[2]))
                lossVal_list.append(float(line.strip().split()[3]))
            f.close()
            assert len(epoch_list) == len(lossTr_list) == len(lossVal_list)

            fig1, ax1 = plt.subplots(figsize=(11, 8))

            ax1.plot(range(0, epoch + 1), lossTr_list, label='Train_loss')
            ax1.plot(range(0, epoch + 1), lossVal_list, label='Val_loss')
            ax1.set_title("Average training loss vs epochs")
            ax1.set_xlabel("Epochs")
            ax1.set_ylabel("Current loss")
            ax1.legend()

            plt.savefig(args.savedir + "loss.png")
            plt.clf()
        else:
            fig1, ax1 = plt.subplots(figsize=(11, 8))

            ax1.plot(range(0, epoch + 1), lossTr_list, label='Train_loss')
            ax1.plot(range(0, epoch + 1), lossVal_list, label='Val_loss')
            ax1.set_title("Average training loss vs epochs")
            ax1.set_xlabel("Epochs")
            ax1.set_ylabel("Current loss")
            ax1.legend()

            plt.savefig(args.savedir + "loss.png")
            plt.clf()

            fig2, ax2 = plt.subplots(figsize=(11, 8))

            ax2.plot(epoches, mIOU_val_list, label="Val IoU")
            ax2.set_title("Average IoU vs epochs")
            ax2.set_xlabel("Epochs")
            ax2.set_ylabel("Current IoU")
            ax2.legend()

            plt.savefig(args.savedir + "mIou.png")
            plt.close('all')

        early_stopping.monitor(monitor=val_loss)
        if early_stopping.early_stop:
            print("Early stopping and Save checkpoint")
            if not os.path.exists(model_file_name):
                torch.save(state, model_file_name)
            break

    logger.close()
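
Examples 2, 4 and 5 initialize their networks with init_weight(model, nn.init.kaiming_normal_, nn.BatchNorm2d, 1e-3, 0.1, mode='fan_in'), but the helper itself is never shown. A sketch matching that signature (an assumption about the original implementation):

import torch.nn as nn


def init_weight(module, conv_init, norm_layer, bn_eps, bn_momentum, **kwargs):
    # apply conv_init (e.g. kaiming_normal_) to every conv weight and reset
    # the normalization layers to the given eps/momentum
    for m in module.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            conv_init(m.weight, **kwargs)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, norm_layer):
            m.eps = bn_eps
            m.momentum = bn_momentum
            if m.weight is not None:
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
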
Example 6
def main():
    args = parse_args()
    update_config(cfg, args)

    setup_seed(cfg.SEED)

    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(x) for x in cfg.GPUS])

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, args.mention, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=True)
    # print(model)

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1]))

    try:
        writer_dict['writer'].add_graph(model, (dump_input, ))
    except Exception as e:
        logger.info(e)

    try:
        logger.info(get_model_summary(model, dump_input))
    except Exception:
        # the model summary is informational only; ignore failures
        pass

    model = torch.nn.DataParallel(model, device_ids=list(range(len(
        cfg.GPUS)))).cuda()

    # define loss function (criterion) and optimizer
    criterion = eval(cfg.LOSS.NAME)(cfg).cuda()

    if cfg.LOSS.NAME == 'ModMSE_KL_CC_NSS_Loss':
        criterion_val = ModMSE_KL_CC_Loss(cfg).cuda()
    else:
        criterion_val = criterion

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    logger.info(os.linesep +
                'train_set : {:d} entries'.format(len(train_dataset)))
    logger.info('val_set   : {:d} entries'.format(len(valid_dataset)) +
                os.linesep)

    if cfg.DATASET.SAMPLER == "":
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
            shuffle=cfg.TRAIN.SHUFFLE,
            num_workers=cfg.WORKERS,
            pin_memory=cfg.PIN_MEMORY)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
            shuffle=False,
            num_workers=cfg.WORKERS,
            pin_memory=cfg.PIN_MEMORY)
    elif cfg.DATASET.SAMPLER == "RandomIdentitySampler":
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            sampler=dataset.RandomIdentitySampler(
                train_dataset.images,
                cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
                cfg.DATASET.NUM_INSTANCES),
            batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS) //
            cfg.DATASET.NUM_INSTANCES,
            shuffle=False,
            num_workers=cfg.WORKERS,
            pin_memory=cfg.PIN_MEMORY)
        valid_loader = torch.utils.data.DataLoader(
            valid_dataset,
            sampler=dataset.RandomIdentitySampler(
                valid_dataset.images,
                cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
                cfg.DATASET.NUM_INSTANCES),
            batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS) //
            cfg.DATASET.NUM_INSTANCES,
            shuffle=False,
            num_workers=cfg.WORKERS,
            pin_memory=cfg.PIN_MEMORY)
    else:
        raise ValueError(
            "unsupported cfg.DATASET.SAMPLER: {}".format(cfg.DATASET.SAMPLER))

    best_perf = None
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    if cfg.TRAIN.WARMUP_EPOCHS > 0:
        lr_scheduler = WarmupMultiStepLR(optimizer,
                                         cfg.TRAIN.LR_STEP,
                                         cfg.TRAIN.LR_FACTOR,
                                         warmup_iters=cfg.TRAIN.WARMUP_EPOCHS,
                                         last_epoch=last_epoch)
    else:
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            cfg.TRAIN.LR_STEP,
            cfg.TRAIN.LR_FACTOR,
            last_epoch=last_epoch)
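    # WarmupMultiStepLR is a project-local scheduler; it presumably ramps the
    # lr up during the first warmup_iters epochs and then decays it at the
    # cfg.TRAIN.LR_STEP milestones like MultiStepLR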

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        # torch.cuda.empty_cache()

        # train for one epoch
        train(cfg, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)

        # since PyTorch 1.1, scheduler.step() belongs after the epoch's
        # optimizer steps, not before them
        lr_scheduler.step()

        # torch.cuda.empty_cache()

        # evaluate on validation set
        perf_indicator, is_larger_better = validate(cfg, valid_loader,
                                                    valid_dataset, model,
                                                    criterion_val,
                                                    final_output_dir,
                                                    tb_log_dir, writer_dict)

        if is_larger_better:
            if best_perf is None or perf_indicator >= best_perf:
                best_perf = perf_indicator
                best_model = True
            else:
                best_model = False
        else:
            if best_perf is None or perf_indicator <= best_perf:
                best_perf = perf_indicator
                best_model = True
            else:
                best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
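
Example 6 delegates persistence to a save_checkpoint helper. A plausible minimal version, consistent with the call above (latest state always written, best state copied aside) but an assumption nonetheless:

import os
import shutil

import torch


def save_checkpoint(states, is_best, output_dir, filename='checkpoint.pth'):
    # always write the latest checkpoint; duplicate it when it is the best
    torch.save(states, os.path.join(output_dir, filename))
    if is_best:
        shutil.copyfile(os.path.join(output_dir, filename),
                        os.path.join(output_dir, 'model_best.pth'))
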
Example 7
    def parse_arg(self):
        opt = self.parser.parse_args()
        setup_seed(opt.seed)
        # AMP requires PyTorch >= 1.6
        opt.amp_available = bool(opt.amp) and LooseVersion(
            torch.__version__) >= LooseVersion('1.6.0')
        ####################################################################################################
        """ Directory """
        dir_root = os.getcwd()
        opt.dir_data = os.path.join(dir_root, 'data', opt.dataset)
        opt.dir_img = os.path.join(opt.dir_data, 'image')
        opt.dir_label = os.path.join(opt.dir_data, 'label')
        opt.dir_log = os.path.join(dir_root, 'logs', opt.dataset,
                                   f"EXP_{opt.exp_id}_NET_{opt.arch}")

        opt.dir_vis = os.path.join(dir_root, 'vis', opt.dataset, opt.exp_id)
        opt.dir_result = os.path.join(dir_root, 'results', opt.dataset,
                                      opt.exp_id)

        ####################################################################################################
        """ Model Architecture """
        os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.gpus)
        opt.net = get_model(3, 1, opt.arch)
        opt.param = "%.2fM" % (sum(x.numel()
                                   for x in opt.net.parameters()) / 1e+6)
        opt.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        if opt.gpus is not None:
            warnings.warn(
                'You have chosen a specific GPU. This will completely '
                'disable data parallelism.')
        opt.net.to(device=opt.device)
        ####################################################################################################
        """ Optimizer """
        if opt.optim == "Adam":
            opt.optimizer = optim.Adam(opt.net.parameters(),
                                       lr=opt.lr,
                                       weight_decay=opt.l2)
        elif opt.optim == "SGD":
            opt.optimizer = optim.SGD(opt.net.parameters(),
                                      lr=opt.lr,
                                      momentum=0.9,
                                      weight_decay=opt.l2)
        ####################################################################################################
        """ Scheduler """
        if opt.sche == "ExpLR":
            gamma = 0.95
            opt.scheduler = torch.optim.lr_scheduler.ExponentialLR(
                opt.optimizer, gamma=gamma, last_epoch=-1)
        elif opt.sche == "MulStepLR":
            milestones = [90, 120]
            opt.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                opt.optimizer, milestones=milestones, gamma=0.1)
        elif opt.sche == "CosAnnLR":
            t_max = 5
            opt.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                opt.optimizer, T_max=t_max, eta_min=0.)
        elif opt.sche == "ReduceLR":
            mode = "max"
            factor = 0.9
            patience = 10
            opt.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                opt.optimizer, mode=mode, factor=factor, patience=patience)
        ####################################################################################################
        """ Loss Function """
        if opt.loss == "dice_loss":
            opt.loss_function = DiceLoss()
        elif opt.loss == "dice_bce_loss":
            opt.loss_function = DiceBCELoss()
        ####################################################################################################

        return opt
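
Example 7 wires opt.loss_function to DiceLoss/DiceBCELoss classes that the listing never defines. A minimal sketch of the Dice part, assuming raw logits and binary {0, 1} masks (the class name matches the call above; the body is an assumption):

import torch
import torch.nn as nn


class DiceLoss(nn.Module):
    def __init__(self, smooth=1.0):
        super().__init__()
        self.smooth = smooth  # avoids division by zero on empty masks

    def forward(self, pred, target):
        # pred: raw logits, target: binary mask of the same shape
        pred = torch.sigmoid(pred).flatten(1)
        target = target.flatten(1)
        inter = (pred * target).sum(dim=1)
        denom = pred.sum(dim=1) + target.sum(dim=1)
        return 1.0 - ((2.0 * inter + self.smooth) /
                      (denom + self.smooth)).mean()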