Example #1
0
def main(params):
    """Evaluate a pretrained BiSeNet model on the CamVid test split.

    Parses command-line style arguments from ``params``, builds the test
    dataloader, restores the checkpoint weights, and runs evaluation.

    Args:
        params: list of argument strings forwarded to argparse,
            e.g. ``['--checkpoint_path', '...', '--data', '...']``.
    """
    def str2bool(v):
        # argparse's type=bool treats ANY non-empty string (even 'False')
        # as True; parse the common spellings explicitly instead.
        return str(v).lower() in ('true', '1', 'yes', 'y')

    # basic parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_path', type=str, default=None, required=True, help='The path to the pretrained weights of model')
    parser.add_argument('--crop_height', type=int, default=640, help='Height of cropped/resized input image to network')
    parser.add_argument('--crop_width', type=int, default=640, help='Width of cropped/resized input image to network')
    parser.add_argument('--data', type=str, default='/path/to/data', help='Path of training data')
    parser.add_argument('--batch_size', type=int, default=1, help='Number of images in each batch')
    parser.add_argument('--context_path', type=str, default="resnet101", help='The context path model you are using.')
    parser.add_argument('--cuda', type=str, default='0', help='GPU ids used for training')
    parser.add_argument('--use_gpu', type=str2bool, default=True, help='Whether to use gpu for training')
    parser.add_argument('--num_classes', type=int, default=32, help='num of object classes (with void)')
    args = parser.parse_args(params)

    # create dataset and dataloader
    test_path = os.path.join(args.data, 'test')
    test_label_path = os.path.join(args.data, 'test_labels')
    csv_path = os.path.join(args.data, 'class_dict.csv')
    dataset = CamVid(test_path, test_label_path, csv_path, scale=(args.crop_height, args.crop_width), mode='test')
    dataloader = DataLoader(
        dataset,
        batch_size=1,  # evaluation runs one image at a time
        shuffle=True,
        num_workers=4,
    )

    # build model
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    model = BiSeNet(args.num_classes, args.context_path)
    use_cuda = torch.cuda.is_available() and args.use_gpu
    if use_cuda:
        model = torch.nn.DataParallel(model).cuda()

    # load pretrained model; map GPU-saved checkpoints to CPU when no GPU
    # is used, and only unwrap .module when DataParallel is actually applied
    # (the original unconditionally used model.module and crashed on CPU runs)
    print('load model from %s ...' % args.checkpoint_path)
    state_dict = torch.load(args.checkpoint_path, map_location=None if use_cuda else 'cpu')
    target = model.module if isinstance(model, torch.nn.DataParallel) else model
    target.load_state_dict(state_dict)
    print('Done!')

    # get label info
    label_info = get_label_info(csv_path)
    # test
    eval(model, dataloader, args, label_info)
Example #2
0
File: train.py  Project: qiuhui1991/BiSeNet
def main(params):
    """Train BiSeNet on CamVid.

    Parses training hyper-parameters from ``params``, builds train/val
    dataloaders, constructs the model and optimizer, optionally restores a
    pretrained checkpoint, and launches training.

    Args:
        params: list of argument strings forwarded to argparse.
    """
    def str2bool(v):
        # argparse's type=bool treats ANY non-empty string (even 'False')
        # as True; parse the common spellings explicitly instead.
        return str(v).lower() in ('true', '1', 'yes', 'y')

    # basic parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=300,
                        help='Number of epochs to train for')
    parser.add_argument('--epoch_start_i', type=int, default=0,
                        help='Start counting epochs from this number')
    parser.add_argument('--checkpoint_step', type=int, default=5,
                        help='How often to save checkpoints (epochs)')
    parser.add_argument('--validation_step', type=int, default=1,
                        help='How often to perform validation (epochs)')
    parser.add_argument('--dataset', type=str, default="CamVid",
                        help='Dataset you are using.')
    parser.add_argument('--crop_height', type=int, default=640,
                        help='Height of cropped/resized input image to network')
    parser.add_argument('--crop_width', type=int, default=640,
                        help='Width of cropped/resized input image to network')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Number of images in each batch')
    parser.add_argument('--context_path', type=str, default="resnet101",
                        help='The context path model you are using.')
    parser.add_argument('--learning_rate', type=float, default=0.01,
                        help='learning rate used for train')
    parser.add_argument('--data', type=str, default='/path/to/data',
                        help='path of training data')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='num of workers')
    parser.add_argument('--num_classes', type=int, default=32,
                        help='num of object classes (with void)')
    parser.add_argument('--cuda', type=str, default='0',
                        help='GPU ids used for training')
    parser.add_argument('--use_gpu', type=str2bool, default=True,
                        help='whether to use gpu for training')
    parser.add_argument('--pretrained_model_path', type=str, default=None,
                        help='path to pretrained model')
    parser.add_argument('--save_model_path', type=str, default=None,
                        help='path to save model')

    args = parser.parse_args(params)

    # create dataset and dataloader
    train_path = os.path.join(args.data, 'train')
    train_label_path = os.path.join(args.data, 'train_labels')
    val_path = os.path.join(args.data, 'val')
    val_label_path = os.path.join(args.data, 'val_labels')
    csv_path = os.path.join(args.data, 'class_dict.csv')
    dataset_train = CamVid(train_path,
                           train_label_path,
                           csv_path,
                           scale=(args.crop_height, args.crop_width),
                           mode='train')
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers)
    dataset_val = CamVid(val_path,
                         val_label_path,
                         csv_path,
                         scale=(args.crop_height, args.crop_width),
                         mode='val')
    dataloader_val = DataLoader(dataset_val,
                                # validation has to run with batch size 1
                                batch_size=1,
                                shuffle=True,
                                num_workers=args.num_workers)

    # build model
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    model = BiSeNet(args.num_classes, args.context_path)
    use_cuda = torch.cuda.is_available() and args.use_gpu
    if use_cuda:
        model = torch.nn.DataParallel(model).cuda()

    # build optimizer
    optimizer = torch.optim.RMSprop(model.parameters(), args.learning_rate)

    # load pretrained model if it exists; the original code unconditionally
    # used model.module (DataParallel only) and omitted map_location, so it
    # crashed on CPU-only runs — unwrap conditionally and remap when needed
    if args.pretrained_model_path is not None:
        print('load model from %s ...' % args.pretrained_model_path)
        state_dict = torch.load(args.pretrained_model_path,
                                map_location=None if use_cuda else 'cpu')
        target = model.module if isinstance(model, torch.nn.DataParallel) else model
        target.load_state_dict(state_dict)
        print('Done!')

    # train
    train(args, model, optimizer, dataloader_train, dataloader_val, csv_path)
Example #3
0
def main():
    """Adversarial domain-adaptation training: IDDA (source) -> CamVid (target).

    Trains the BiSeNet generator jointly with a discriminator for a fixed
    number of epochs and periodically checkpoints both networks. All
    hyper-parameters come from module-level constants.
    """
    # Call Python's garbage collector, and empty torch's CUDA cache. Just in case
    gc.collect()
    torch.cuda.empty_cache()

    # Enable cuDNN in benchmark mode. For more info see:
    # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # Load BiSeNet generator (segmentation network) in training mode
    generator = BiSeNet(NUM_CLASSES, CONTEXT_PATH).cuda()
    generator.train()
    # Build discriminator in training mode
    discriminator = Discriminator(NUM_CLASSES).cuda()
    discriminator.train()

    # Load source dataset (synthetic IDDA)
    source_dataset = IDDA(image_path=IDDA_PATH,
                          label_path=IDDA_LABEL_PATH,
                          classes_info_path=JSON_IDDA_PATH,
                          scale=(CROP_HEIGHT, CROP_WIDTH),
                          loss=LOSS,
                          mode='train')
    source_dataloader = DataLoader(source_dataset,
                                   batch_size=BATCH_SIZE_IDDA,
                                   shuffle=True,
                                   num_workers=NUM_WORKERS,
                                   drop_last=True,
                                   pin_memory=True)

    # Load target dataset (real CamVid)
    target_dataset = CamVid(image_path=CAMVID_PATH,
                            label_path=CAMVID_LABEL_PATH,
                            csv_path=CSV_CAMVID_PATH,
                            scale=(CROP_HEIGHT, CROP_WIDTH),
                            loss=LOSS,
                            mode='adversarial_train')
    target_dataloader = DataLoader(target_dataset,
                                   batch_size=BATCH_SIZE_CAMVID,
                                   shuffle=True,
                                   num_workers=NUM_WORKERS,
                                   drop_last=True,
                                   pin_memory=True)

    optimizer_BiSeNet = torch.optim.SGD(generator.parameters(),
                                        lr=LEARNING_RATE_SEGMENTATION,
                                        momentum=MOMENTUM,
                                        weight_decay=WEIGHT_DECAY)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=LEARNING_RATE_DISCRIMINATOR,
                                               betas=(0.9, 0.99))

    # Loss for discriminator training: Sigmoid layer + BCELoss
    bce_loss = nn.BCEWithLogitsLoss()

    # Loss for segmentation: Log-softmax layer + 2D Cross Entropy
    cross_entropy_loss = CrossEntropy2d()

    # NOTE(review): the loop runs epochs 1..50 while the LR schedule and the
    # progress print reference NUM_STEPS — confirm the two are meant to differ.
    for epoch in range(1, 51):
        # Fresh iterators each epoch so both datasets restart together
        source_dataloader_iter = iter(source_dataloader)
        target_dataloader_iter = iter(target_dataloader)

        print(f'begin epoch {epoch}')

        # Initialize gradients=0 for Generator and Discriminator
        optimizer_BiSeNet.zero_grad()
        optimizer_discriminator.zero_grad()

        # Per-epoch accumulators for the three losses (for reporting only)
        l_seg_to_print_acc, l_adv_to_print_acc, l_d_to_print_acc = 0, 0, 0

        # Compute (polynomially decayed) learning rate for this epoch
        adjust_learning_rate(optimizer_BiSeNet, LEARNING_RATE_SEGMENTATION,
                             epoch, NUM_STEPS, POWER)
        adjust_learning_rate(optimizer_discriminator,
                             LEARNING_RATE_DISCRIMINATOR, epoch, NUM_STEPS,
                             POWER)

        for i in tqdm(range(len(target_dataloader))):
            optimizer_BiSeNet.zero_grad()
            optimizer_discriminator.zero_grad()
            # One adversarial minibatch: segmentation step on source,
            # adversarial step on target, discriminator step on both
            l_seg_to_print, l_adv_to_print, l_d_to_print = minibatch(
                source_dataloader_iter, target_dataloader_iter, generator,
                discriminator, cross_entropy_loss, bce_loss, source_dataloader,
                target_dataloader)
            l_seg_to_print_acc += l_seg_to_print
            l_adv_to_print_acc += l_adv_to_print
            l_d_to_print_acc += l_d_to_print
            # Run optimizers using the gradient obtained via backpropagations
            optimizer_BiSeNet.step()
            optimizer_discriminator.step()

        # Output at each epoch
        print(
            f'epoch = {epoch}/{NUM_STEPS}, loss_seg = {l_seg_to_print_acc:.3f}, loss_adv = {l_adv_to_print_acc:.3f}, loss_D = {l_d_to_print_acc:.3f}'
        )

        # Save intermediate generator and discriminator (checkpoint)
        if epoch % CHECKPOINT_STEP == 0 and epoch != 0:
            # makedirs(exist_ok=True) replaces the racy isdir/mkdir pair and
            # also handles nested checkpoint directories
            os.makedirs(CHECKPOINT_PATH, exist_ok=True)
            generator_checkpoint_path = os.path.join(
                CHECKPOINT_PATH, f"{BETA}_{epoch}_Generator.pth")
            torch.save(generator.state_dict(), generator_checkpoint_path)
            discriminator_checkpoint_path = os.path.join(
                CHECKPOINT_PATH, f"{BETA}_{epoch}_Discriminator.pth")
            torch.save(discriminator.state_dict(),
                       discriminator_checkpoint_path)
            print(
                f"saved:\n{generator_checkpoint_path}\n{discriminator_checkpoint_path}"
            )
Example #4
0
def main(params):
    """Adversarial training entry point: BiSeNet generator (G) + discriminator (D).

    Builds CamVid (target, train+val splits for training / test split for
    validation) and IDDA (source) dataloaders, the two networks and their
    optimizers, optionally resumes from a combined checkpoint, then launches
    training.

    Args:
        params: list of argument strings forwarded to argparse.
    """
    def str2bool(v):
        # argparse's type=bool treats ANY non-empty string (even 'False')
        # as True; parse the common spellings explicitly instead.
        return str(v).lower() in ('true', '1', 'yes', 'y')

    def build_optimizer(name, parameters, lr):
        # Shared optimizer factory for G and D; returns None for an unknown
        # name (callers then abort, matching the original behavior).
        if name == 'rmsprop':
            return torch.optim.RMSprop(parameters, lr)
        if name == 'sgd':
            return torch.optim.SGD(parameters, lr, momentum=0.9, weight_decay=1e-4)
        if name == 'adam':
            return torch.optim.Adam(parameters, lr)
        print('not supported optimizer \n')
        return None

    # basic parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=300, help='Number of epochs to train for')
    parser.add_argument('--epoch_start_i', type=int, default=0, help='Start counting epochs from this number')
    parser.add_argument('--checkpoint_step', type=int, default=10, help='How often to save checkpoints (epochs)')
    parser.add_argument('--validation_step', type=int, default=2, help='How often to perform validation (epochs)')
    parser.add_argument('--dataset', type=str, default="CamVid", help='Dataset you are using.')
    parser.add_argument('--crop_height', type=int, default=720, help='Height of cropped/resized input image to network')
    parser.add_argument('--crop_width', type=int, default=960, help='Width of cropped/resized input image to network')
    parser.add_argument('--batch_size', type=int, default=32, help='Number of images in each batch')
    parser.add_argument('--context_path', type=str, default="resnet101",
                        help='The context path model you are using, resnet18, resnet101.')
    parser.add_argument('--learning_rate_G', type=float, default=0.01, help='learning rate for G')
    parser.add_argument('--learning_rate_D', type=float, default=0.01, help='learning rate for D')
    parser.add_argument('--data_CamVid', type=str, default='', help='path of training data_CamVid')
    parser.add_argument('--data_IDDA', type=str, default='', help='path of training data_IDDA')
    parser.add_argument('--num_workers', type=int, default=4, help='num of workers')
    parser.add_argument('--num_classes', type=int, default=32, help='num of object classes (with void)')
    parser.add_argument('--cuda', type=str, default='0', help='GPU ids used for training')
    parser.add_argument('--use_gpu', type=str2bool, default=True, help='whether to use gpu for training')
    parser.add_argument('--pretrained_model_path', type=str, default=None, help='path to pretrained model')
    parser.add_argument('--save_model_path', type=str, default=None, help='path to save model')
    parser.add_argument('--optimizer_G', type=str, default='rmsprop', help='optimizer_G, support rmsprop, sgd, adam')
    parser.add_argument('--optimizer_D', type=str, default='rmsprop', help='optimizer_D, support rmsprop, sgd, adam')
    parser.add_argument('--loss', type=str, default='dice', help='loss function, dice or crossentropy')
    parser.add_argument('--loss_G', type=str, default='dice', help='loss function, dice or crossentropy')
    parser.add_argument('--lambda_adv', type=float, default=0.01, help='lambda coefficient for adversarial loss')

    args = parser.parse_args(params)

    # create dataset and dataloader for CamVid
    CamVid_train_path = [os.path.join(args.data_CamVid, 'train'), os.path.join(args.data_CamVid, 'val')]
    CamVid_train_label_path = [os.path.join(args.data_CamVid, 'train_labels'),
                               os.path.join(args.data_CamVid, 'val_labels')]
    CamVid_test_path = os.path.join(args.data_CamVid, 'test')
    CamVid_test_label_path = os.path.join(args.data_CamVid, 'test_labels')
    CamVid_csv_path = os.path.join(args.data_CamVid, 'class_dict.csv')
    CamVid_dataset_train = CamVid(CamVid_train_path, CamVid_train_label_path, CamVid_csv_path,
                                  scale=(args.crop_height, args.crop_width),
                                  loss=args.loss, mode='train')
    CamVid_dataloader_train = DataLoader(
        CamVid_dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True
    )
    CamVid_dataset_val = CamVid(CamVid_test_path, CamVid_test_label_path, CamVid_csv_path,
                                scale=(args.crop_height, args.crop_width),
                                loss=args.loss, mode='test')
    CamVid_dataloader_val = DataLoader(
        CamVid_dataset_val,
        # validation has to run with batch size 1
        batch_size=1,
        shuffle=True,
        num_workers=args.num_workers
    )

    # create dataset and dataloader for IDDA
    IDDA_path = os.path.join(args.data_IDDA, 'rgb')
    IDDA_label_path = os.path.join(args.data_IDDA, 'labels')
    IDDA_info_path = os.path.join(args.data_IDDA, 'classes_info.json')
    IDDA_dataset = IDDA(IDDA_path, IDDA_label_path, IDDA_info_path, CamVid_csv_path,
                        scale=(args.crop_height, args.crop_width), loss=args.loss)
    IDDA_dataloader = DataLoader(
        IDDA_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True
    )

    # build model_G (generator / segmentation network)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    use_cuda = torch.cuda.is_available() and args.use_gpu
    model_G = BiSeNet(args.num_classes, args.context_path)
    if use_cuda:
        model_G = torch.nn.DataParallel(model_G).cuda()

    # build model_D (discriminator)
    model_D = DW_Discriminator(args.num_classes)
    if use_cuda:
        model_D = torch.nn.DataParallel(model_D).cuda()

    # build optimizers for G and D; abort (returning None) on an unknown name
    optimizer_G = build_optimizer(args.optimizer_G, model_G.parameters(), args.learning_rate_G)
    if optimizer_G is None:
        return None
    optimizer_D = build_optimizer(args.optimizer_D, model_D.parameters(), args.learning_rate_D)
    if optimizer_D is None:
        return None

    curr_epoch = 0
    max_miou = 0

    # resume from a combined checkpoint if one is given; the original code
    # unconditionally used .module (DataParallel only) and omitted
    # map_location, so it crashed on CPU-only runs
    if args.pretrained_model_path is not None:
        print('load model from %s ...' % args.pretrained_model_path)
        state = torch.load(os.path.realpath(args.pretrained_model_path),
                           map_location=None if use_cuda else 'cpu')
        g = model_G.module if isinstance(model_G, torch.nn.DataParallel) else model_G
        d = model_D.module if isinstance(model_D, torch.nn.DataParallel) else model_D
        g.load_state_dict(state['model_G_state'])
        optimizer_G.load_state_dict(state['optimizer_G'])
        d.load_state_dict(state['model_D_state'])
        optimizer_D.load_state_dict(state['optimizer_D'])
        curr_epoch = state["epoch"]
        max_miou = state["max_miou"]
        print(str(curr_epoch - 1) + " already trained")
        print("start training from epoch " + str(curr_epoch))
        print('Done!')

    # train
    train(args, model_G, model_D, optimizer_G, optimizer_D, CamVid_dataloader_train, CamVid_dataloader_val, IDDA_dataloader, curr_epoch, max_miou)