Example #1
0
def main(params):
    """Train BiSeNet on the SUN-RGBD dataset.

    Parses command-line ``params``, builds the training dataloader, the
    (optionally DataParallel/CUDA) model and an RMSprop optimizer, loads a
    pretrained checkpoint when given, and hands off to ``train``.

    Args:
        params: list of CLI argument strings forwarded to argparse.
    """

    def str2bool(value):
        # argparse's ``type=bool`` is a classic trap: bool('False') is True,
        # so '--use_gpu False' would silently enable the GPU. Parse explicitly.
        return str(value).lower() in ('1', 'true', 'yes', 'y')

    # basic parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='/home/disk2/xs/sun', help='path of training data')
    parser.add_argument('--num_epochs', type=int, default=300, help='Number of epochs to train for')
    parser.add_argument('--epoch_start_i', type=int, default=0, help='Start counting epochs from this number')
    parser.add_argument('--checkpoint_step', type=int, default=10, help='How often to save checkpoints (epochs)')
    parser.add_argument('--validation_step', type=int, default=1, help='How often to perform validation (epochs)')
    parser.add_argument('--dataset', type=str, default='SUN', help='Dataset you are using.')
    parser.add_argument('--crop_height', type=int, default=480, help='Height of cropped/resized input image to network')
    parser.add_argument('--crop_width', type=int, default=640, help='Width of cropped/resized input image to network')
    parser.add_argument('--batch_size', type=int, default=5, help='Number of images in each batch')
    parser.add_argument('--context_path', type=str, default="resnet101", help='The context path model you are using.')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate used for train')
    parser.add_argument('--num_workers', type=int, default=4, help='num of workers')
    parser.add_argument('--num_classes', type=int, default=38, help='num of object classes (with void)')
    parser.add_argument('--cuda', type=str, default='2', help='GPU ids used for training')
    parser.add_argument('--use_gpu', type=str2bool, default=True, help='whether to use gpu for training')
    parser.add_argument('--pretrained_model_path', type=str, default=None, help='path to pretrained model')
    parser.add_argument('--save_model_path', type=str, default='./checkpoints', help='path to save model')
    parser.add_argument('--csv_path', type=str, default='/home/disk2/xs/sun/seg37_class_dict.csv', help='Path to label info csv file')

    args = parser.parse_args(params)

    # create dataset and dataloader
    train_img_path = os.path.join(args.data, 'train/image')
    train_depth_path = os.path.join(args.data, 'train/depth')
    train_label_path = os.path.join(args.data, 'train/label')

    # NOTE(review): the --csv_path argument is ignored here; the csv is always
    # resolved relative to --data. Confirm whether that is intended.
    csv_path = os.path.join(args.data, 'seg37_class_dict.csv')

    dataset_train = SUN(train_img_path, train_depth_path, train_label_path, csv_path, scale=(args.crop_height, args.crop_width), mode='train')
    dataloader_train = DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers
    )

    # build model (restrict visible GPUs before any CUDA initialisation)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    model = BiSeNet(args.num_classes, args.context_path)
    if torch.cuda.is_available() and args.use_gpu:
        model = torch.nn.DataParallel(model).cuda()

    # build optimizer
    optimizer = torch.optim.RMSprop(model.parameters(), args.learning_rate)

    # load pretrained model if exists
    if args.pretrained_model_path is not None:
        print('load model from %s ...' % args.pretrained_model_path)
        # The model is only wrapped in DataParallel on the GPU path, so load
        # into the bare module when running on CPU instead of crashing on
        # a missing ``.module`` attribute.
        target = model.module if isinstance(model, torch.nn.DataParallel) else model
        target.load_state_dict(torch.load(args.pretrained_model_path))
        print('Done!')

    # train
    train(args, model, optimizer, dataloader_train, csv_path)
Example #2
0
def main(params):
    """Train BiSeNet on the VOC dataset, then run a final validation pass.

    Parses ``params``, builds train/val dataloaders, the (optionally CUDA)
    model and the optimizer selected by ``--optimizer``, optionally loads a
    pretrained checkpoint, then calls ``train`` followed by ``val``.
    Returns None early (after a message) on an unknown optimizer name.

    Args:
        params: list of CLI argument strings forwarded to argparse.
    """

    def str2bool(value):
        # argparse's ``type=bool`` is broken for flags: bool('False') is
        # True, so any non-empty string would enable the option.
        return str(value).lower() in ('1', 'true', 'yes', 'y')

    # basic parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=300,
                        help='Number of epochs to train for')
    parser.add_argument('--epoch_start_i', type=int, default=0,
                        help='Start counting epochs from this number')
    parser.add_argument('--checkpoint_step', type=int, default=10,
                        help='How often to save checkpoints (epochs)')
    parser.add_argument('--validation_step', type=int, default=10,
                        help='How often to perform validation (epochs)')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Number of images in each batch')
    parser.add_argument('--context_path', type=str, default="resnet101",
                        help='The context path model you are using, resnet18, resnet101.')
    parser.add_argument('--learning_rate', type=float, default=0.01,
                        help='learning rate used for train')
    parser.add_argument('--data', type=str, default='data',
                        help='path of training data')
    parser.add_argument('--num_workers', type=int, default=4,
                        help='num of workers')
    parser.add_argument('--num_classes', type=int, default=32,
                        help='num of object classes (with void)')
    parser.add_argument('--cuda', type=str, default='0',
                        help='GPU ids used for training')
    parser.add_argument('--use_gpu', type=str2bool, default=True,
                        help='whether to use gpu for training')
    parser.add_argument('--pretrained_model_path', type=str, default=None,
                        help='path to pretrained model')
    parser.add_argument('--save_model_path', type=str, default="checkpoints",
                        help='path to save model')
    parser.add_argument('--optimizer', type=str, default='rmsprop',
                        help='optimizer, support rmsprop, sgd, adam')
    parser.add_argument('--loss', type=str, default='crossentropy',
                        help='loss function, dice or crossentropy')

    # parse our parameters
    args = parser.parse_args(params)

    # create dataset and dataloader
    train_path = args.data
    train_transform, val_transform = get_transform()

    # VOC dataset for training
    dataset_train = VOC(train_path,
                        image_set="train",
                        transform=train_transform)
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  num_workers=args.num_workers,
                                  drop_last=True)

    # VOC dataset for validation
    dataset_val = VOC(train_path, image_set="val", transform=val_transform)
    dataloader_val = DataLoader(dataset_val,
                                batch_size=args.batch_size,
                                shuffle=True,
                                num_workers=args.num_workers)

    # build model (restrict visible GPUs before CUDA is initialised)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    model = BiSeNet(args.num_classes, args.context_path)
    if torch.cuda.is_available() and args.use_gpu:
        model = model.cuda()

    # build the optimizer chosen on the command line
    if args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(), args.learning_rate)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    args.learning_rate,
                                    momentum=0.9,
                                    weight_decay=1e-4)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    else:
        # unknown optimizer name: report and bail out
        print('not supported optimizer \n')
        return None

    # load pretrained model if exists
    if args.pretrained_model_path is not None:
        print('load model from %s ...' % args.pretrained_model_path)
        model.load_state_dict(torch.load(args.pretrained_model_path))
        print('Done!')

    # train, then evaluate on the validation split
    train(args, model, optimizer, dataloader_train, dataloader_val)

    val(args, model, dataloader_val)
Example #3
0
def main():
    """Adversarial domain-adaptation training: BiSeNet generator vs. discriminator.

    Trains the segmentation generator on the IDDA source dataset and the
    discriminator against the CamVid target dataset, saving checkpoints of
    both networks every ``CHECKPOINT_STEP`` epochs. All hyperparameters come
    from module-level constants.
    """
    # Call Python's garbage collector, and empty torch's CUDA cache. Just in case
    gc.collect()
    torch.cuda.empty_cache()

    # Enable cuDNN in benchmark mode. For more info see:
    # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # Load BiSeNet generator (segmentation network)
    generator = BiSeNet(NUM_CLASSES, CONTEXT_PATH).cuda()
    generator.train()
    # Build discriminator
    discriminator = Discriminator(NUM_CLASSES).cuda()
    discriminator.train()

    # Load source dataset
    source_dataset = IDDA(image_path=IDDA_PATH,
                          label_path=IDDA_LABEL_PATH,
                          classes_info_path=JSON_IDDA_PATH,
                          scale=(CROP_HEIGHT, CROP_WIDTH),
                          loss=LOSS,
                          mode='train')
    source_dataloader = DataLoader(source_dataset,
                                   batch_size=BATCH_SIZE_IDDA,
                                   shuffle=True,
                                   num_workers=NUM_WORKERS,
                                   drop_last=True,
                                   pin_memory=True)

    # Load target dataset
    target_dataset = CamVid(image_path=CAMVID_PATH,
                            label_path=CAMVID_LABEL_PATH,
                            csv_path=CSV_CAMVID_PATH,
                            scale=(CROP_HEIGHT, CROP_WIDTH),
                            loss=LOSS,
                            mode='adversarial_train')
    target_dataloader = DataLoader(target_dataset,
                                   batch_size=BATCH_SIZE_CAMVID,
                                   shuffle=True,
                                   num_workers=NUM_WORKERS,
                                   drop_last=True,
                                   pin_memory=True)

    optimizer_BiSeNet = torch.optim.SGD(generator.parameters(),
                                        lr=LEARNING_RATE_SEGMENTATION,
                                        momentum=MOMENTUM,
                                        weight_decay=WEIGHT_DECAY)
    optimizer_discriminator = torch.optim.Adam(discriminator.parameters(),
                                               lr=LEARNING_RATE_DISCRIMINATOR,
                                               betas=(0.9, 0.99))

    # Loss for discriminator training
    # Sigmoid layer + BCELoss
    bce_loss = nn.BCEWithLogitsLoss()

    # Loss for segmentation loss
    # Log-softmax layer + 2D Cross Entropy
    cross_entropy_loss = CrossEntropy2d()

    # NOTE(review): the epoch range is hard-coded to 1..50 while the learning
    # rate schedule below decays over NUM_STEPS — confirm the two agree.
    for epoch in range(1, 51):
        # Fresh iterators each epoch: the two dataloaders have different
        # lengths, so they are advanced manually inside minibatch().
        source_dataloader_iter = iter(source_dataloader)
        target_dataloader_iter = iter(target_dataloader)

        print(f'begin epoch {epoch}')

        # Initialize gradients=0 for Generator and Discriminator
        optimizer_BiSeNet.zero_grad()
        optimizer_discriminator.zero_grad()

        # Accumulators for the epoch-level loss report
        l_seg_to_print_acc, l_adv_to_print_acc, l_d_to_print_acc = 0, 0, 0

        # Compute learning rate for this epoch (poly decay)
        adjust_learning_rate(optimizer_BiSeNet, LEARNING_RATE_SEGMENTATION,
                             epoch, NUM_STEPS, POWER)
        adjust_learning_rate(optimizer_discriminator,
                             LEARNING_RATE_DISCRIMINATOR, epoch, NUM_STEPS,
                             POWER)

        for i in tqdm(range(len(target_dataloader))):
            optimizer_BiSeNet.zero_grad()
            optimizer_discriminator.zero_grad()
            l_seg_to_print, l_adv_to_print, l_d_to_print = minibatch(
                source_dataloader_iter, target_dataloader_iter, generator,
                discriminator, cross_entropy_loss, bce_loss, source_dataloader,
                target_dataloader)
            l_seg_to_print_acc += l_seg_to_print
            l_adv_to_print_acc += l_adv_to_print
            l_d_to_print_acc += l_d_to_print
            # Run optimizers using the gradient obtained via backpropagations
            optimizer_BiSeNet.step()
            optimizer_discriminator.step()

        # Output at each epoch
        print(
            f'epoch = {epoch}/{NUM_STEPS}, loss_seg = {l_seg_to_print_acc:.3f}, loss_adv = {l_adv_to_print_acc:.3f}, loss_D = {l_d_to_print_acc:.3f}'
        )

        # Save intermediate generator and discriminator (checkpoint).
        # The loop starts at 1, so no extra "epoch != 0" guard is needed.
        if epoch % CHECKPOINT_STEP == 0:
            # makedirs handles nested paths and is a no-op when the
            # directory already exists (os.mkdir would fail on both).
            os.makedirs(CHECKPOINT_PATH, exist_ok=True)
            # Save the parameters of the generator (segmentation network) and discriminator
            generator_checkpoint_path = os.path.join(
                CHECKPOINT_PATH, f"{BETA}_{epoch}_Generator.pth")
            torch.save(generator.state_dict(), generator_checkpoint_path)
            discriminator_checkpoint_path = os.path.join(
                CHECKPOINT_PATH, f"{BETA}_{epoch}_Discriminator.pth")
            torch.save(discriminator.state_dict(),
                       discriminator_checkpoint_path)
            print(
                f"saved:\n{generator_checkpoint_path}\n{discriminator_checkpoint_path}"
            )
Example #4
0
def main(params):
    """Train BiSeNet on CamVid with a selectable optimizer.

    Parses ``params``, builds the CamVid train/val dataloaders, the
    (optionally DataParallel/CUDA) model and the optimizer chosen via
    ``--optimizer``, optionally loads a pretrained checkpoint, then calls
    ``train``. Returns None early on an unknown optimizer name.

    Args:
        params: list of CLI argument strings forwarded to argparse.
    """

    def str2bool(value):
        # argparse's ``type=bool`` is broken: bool('False') is True, so
        # '--use_gpu False' would silently enable the GPU.
        return str(value).lower() in ('1', 'true', 'yes', 'y')

    # basic parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=300, help='Number of epochs to train for')
    parser.add_argument('--epoch_start_i', type=int, default=0, help='Start counting epochs from this number')
    parser.add_argument('--checkpoint_step', type=int, default=1000, help='How often to save checkpoints (epochs)')
    parser.add_argument('--validation_step', type=int, default=100, help='How often to perform validation (epochs)')
    parser.add_argument('--dataset', type=str, default="CamVid", help='Dataset you are using.')
    parser.add_argument('--crop_height', type=int, default=720, help='Height of cropped/resized input image to network')
    parser.add_argument('--crop_width', type=int, default=960, help='Width of cropped/resized input image to network')
    parser.add_argument('--batch_size', type=int, default=1, help='Number of images in each batch')
    parser.add_argument('--context_path', type=str, default="resnet101",
                        help='The context path model you are using, resnet18, resnet101.')
    parser.add_argument('--learning_rate', type=float, default=0.01, help='learning rate used for train')
    parser.add_argument('--data', type=str, default='', help='path of training data')
    parser.add_argument('--num_workers', type=int, default=4, help='num of workers')
    parser.add_argument('--num_classes', type=int, default=32, help='num of object classes (with void)')
    parser.add_argument('--cuda', type=str, default='0', help='GPU ids used for training')
    parser.add_argument('--use_gpu', type=str2bool, default=True, help='whether to use gpu for training')
    parser.add_argument('--pretrained_model_path', type=str, default=None, help='path to pretrained model')
    parser.add_argument('--save_model_path', type=str, default=None, help='path to save model')
    parser.add_argument('--optimizer', type=str, default='rmsprop', help='optimizer, support rmsprop, sgd, adam')
    parser.add_argument('--loss', type=str, default='dice', help='loss function, dice or crossentropy')

    args = parser.parse_args(params)

    # create dataset and dataloader: train and val splits are merged for
    # training; the test split serves as the validation set
    train_path = [os.path.join(args.data, 'train'), os.path.join(args.data, 'val')]
    train_label_path = [os.path.join(args.data, 'train_labels'), os.path.join(args.data, 'val_labels')]
    test_path = os.path.join(args.data, 'test')
    test_label_path = os.path.join(args.data, 'test_labels')
    csv_path = os.path.join(args.data, 'class_dict.csv')
    dataset_train = CamVid(train_path, train_label_path, csv_path, scale=(args.crop_height, args.crop_width),
                           loss=args.loss, mode='train')
    dataloader_train = DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True
    )
    dataset_val = CamVid(test_path, test_label_path, csv_path, scale=(args.crop_height, args.crop_width),
                         loss=args.loss, mode='test')
    dataloader_val = DataLoader(
        dataset_val,
        # this has to be 1
        batch_size=1,
        shuffle=True,
        num_workers=args.num_workers
    )

    # build model (restrict visible GPUs before any CUDA initialisation)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    model = BiSeNet(args.num_classes, args.context_path)
    if torch.cuda.is_available() and args.use_gpu:
        print('Training using a GPU')
        model = torch.nn.DataParallel(model).cuda()

    # build the optimizer chosen on the command line
    if args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(), args.learning_rate)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), args.learning_rate, momentum=0.9, weight_decay=1e-4)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    else:
        # unknown optimizer name: report and bail out
        print('not supported optimizer \n')
        return None

    # load pretrained model if exists
    if args.pretrained_model_path is not None:
        print('load model from %s ...' % args.pretrained_model_path)
        # The model is only wrapped in DataParallel on the GPU path, so load
        # into the bare module when running on CPU.
        target = model.module if isinstance(model, torch.nn.DataParallel) else model
        target.load_state_dict(torch.load(args.pretrained_model_path))
        print('Done!')

    # train
    train(args, model, optimizer, dataloader_train, dataloader_val)
Example #5
0
def main(params):
    """Train a BiSeNet variant for license-plate segmentation/recognition.

    Parses ``params``, creates log/checkpoint directories when requested,
    builds train/val dataloaders over the real license-plate datasets, the
    (optionally DataParallel/CUDA) model, an Adam optimizer and a
    cross-entropy criterion, optionally loads a pretrained checkpoint
    (non-strict), then calls ``train``.

    Args:
        params: list of CLI argument strings forwarded to argparse.
    """

    def str2bool(value):
        # argparse's ``type=bool`` is broken: bool('False') is True, so
        # '--use_gpu False' would silently enable the GPU.
        return str(value).lower() in ('1', 'true', 'yes', 'y')

    # basic parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs',
                        type=int,
                        default=300,
                        help='Number of epochs to train for')
    parser.add_argument('--epoch_start_i',
                        type=int,
                        default=0,
                        help='Start counting epochs from this number')
    parser.add_argument('--checkpoint_step',
                        type=int,
                        default=5,
                        help='How often to save checkpoints (epochs)')
    parser.add_argument('--validation_step',
                        type=int,
                        default=1,
                        help='How often to perform validation (epochs)')
    parser.add_argument('--dataset',
                        type=str,
                        default="CamVid",
                        help='Dataset you are using.')
    parser.add_argument(
        '--crop_height',
        type=int,
        default=640,
        help='Height of cropped/resized input image to network')
    parser.add_argument('--crop_width',
                        type=int,
                        default=640,
                        help='Width of cropped/resized input image to network')
    parser.add_argument('--train_batch_size',
                        type=int,
                        default=1,
                        help='Number of images in each batch')
    parser.add_argument('--val_batch_size',
                        type=int,
                        default=1,
                        help='Number of images in each batch')
    parser.add_argument('--context_path',
                        type=str,
                        default="resnet101",
                        help='The context path model you are using.')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.01,
                        help='learning rate used for train')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='num of workers')
    parser.add_argument('--num_classes',
                        type=int,
                        default=32,
                        help='num of object classes (with void)')
    parser.add_argument('--num_char',
                        type=int,
                        default=8,
                        help='num of lincense chars (include background)')
    parser.add_argument('--cuda',
                        type=str,
                        default='0',
                        help='GPU ids used for training')
    parser.add_argument('--use_gpu',
                        type=str2bool,
                        default=True,
                        help='whether to use gpu for training')
    parser.add_argument('--log_path',
                        type=str,
                        default=None,
                        help='tensorboard path')
    parser.add_argument('--pretrained_model_path',
                        type=str,
                        default=None,
                        help='path to pretrained model')
    parser.add_argument('--save_model_path',
                        type=str,
                        default=None,
                        help='path to save model')

    args = parser.parse_args(params)

    # Both paths default to None; os.path.exists(None) raises TypeError,
    # so only create the directories when a path was actually supplied.
    if args.log_path is not None:
        os.makedirs(args.log_path, exist_ok=True)
    if args.save_model_path is not None:
        os.makedirs(args.save_model_path, exist_ok=True)

    dataset_train = License_Real_seg_pos_train(split='train_without_night',
                                               num_epochs=args.num_epochs)
    dataloader_train = DataLoader(dataset_train,
                                  batch_size=args.train_batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=args.num_workers)

    dataset_val = License_Real_seg_pos_val(split='val_without_night',
                                           num_epochs=1)
    dataloader_val = DataLoader(dataset_val,
                                batch_size=args.val_batch_size,
                                shuffle=False,
                                pin_memory=True,
                                num_workers=args.num_workers)

    # build model
    model = BiSeNet(args.num_classes, args.num_char, args.context_path)

    if torch.cuda.is_available() and args.use_gpu:
        model = torch.nn.DataParallel(model).cuda()

    # build optimizer and loss (255 is the ignore label)
    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    # load pretrained model if exists (non-strict: checkpoint may only
    # partially match the current architecture)
    if args.pretrained_model_path is not None:
        print('load model from %s ...' % args.pretrained_model_path)
        # The model is only wrapped in DataParallel on the GPU path, so load
        # into the bare module when running on CPU.
        target = model.module if isinstance(model, torch.nn.DataParallel) else model
        target.load_state_dict(torch.load(args.pretrained_model_path), False)
        print('Done!')

    # train
    train(args, model, optimizer, criterion, dataloader_train, dataloader_val)
Example #6
0
def main(params):
    """Adversarial training of BiSeNet (generator) against a discriminator.

    Parses ``params``, builds CamVid train/val dataloaders and the IDDA
    source dataloader, the generator/discriminator pair (optionally
    DataParallel/CUDA) with their separately selectable optimizers, resumes
    from a combined checkpoint when given, then calls ``train``.
    Returns None early on an unknown optimizer name.

    Args:
        params: list of CLI argument strings forwarded to argparse.
    """

    def str2bool(value):
        # argparse's ``type=bool`` is broken: bool('False') is True, so
        # '--use_gpu False' would silently enable the GPU.
        return str(value).lower() in ('1', 'true', 'yes', 'y')

    # basic parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs', type=int, default=300, help='Number of epochs to train for')
    parser.add_argument('--epoch_start_i', type=int, default=0, help='Start counting epochs from this number')
    parser.add_argument('--checkpoint_step', type=int, default=10, help='How often to save checkpoints (epochs)')
    parser.add_argument('--validation_step', type=int, default=2, help='How often to perform validation (epochs)')
    parser.add_argument('--dataset', type=str, default="CamVid", help='Dataset you are using.')
    parser.add_argument('--crop_height', type=int, default=720, help='Height of cropped/resized input image to network')
    parser.add_argument('--crop_width', type=int, default=960, help='Width of cropped/resized input image to network')
    parser.add_argument('--batch_size', type=int, default=32, help='Number of images in each batch')
    parser.add_argument('--context_path', type=str, default="resnet101",
                        help='The context path model you are using, resnet18, resnet101.')
    parser.add_argument('--learning_rate_G', type=float, default=0.01, help='learning rate for G')
    parser.add_argument('--learning_rate_D', type=float, default=0.01, help='learning rate for D')
    parser.add_argument('--data_CamVid', type=str, default='', help='path of training data_CamVid')
    parser.add_argument('--data_IDDA', type=str, default='', help='path of training data_IDDA')
    parser.add_argument('--num_workers', type=int, default=4, help='num of workers')
    parser.add_argument('--num_classes', type=int, default=32, help='num of object classes (with void)')
    parser.add_argument('--cuda', type=str, default='0', help='GPU ids used for training')
    parser.add_argument('--use_gpu', type=str2bool, default=True, help='whether to use gpu for training')
    parser.add_argument('--pretrained_model_path', type=str, default=None, help='path to pretrained model')
    parser.add_argument('--save_model_path', type=str, default=None, help='path to save model')
    parser.add_argument('--optimizer_G', type=str, default='rmsprop', help='optimizer_G, support rmsprop, sgd, adam')
    parser.add_argument('--optimizer_D', type=str, default='rmsprop', help='optimizer_D, support rmsprop, sgd, adam')
    parser.add_argument('--loss', type=str, default='dice', help='loss function, dice or crossentropy')
    parser.add_argument('--loss_G', type=str, default='dice', help='loss function, dice or crossentropy')
    parser.add_argument('--lambda_adv', type=float, default=0.01, help='lambda coefficient for adversarial loss')

    args = parser.parse_args(params)

    # create dataset and dataloader for CamVid: train+val splits form the
    # training set, the test split serves as the validation set
    CamVid_train_path = [os.path.join(args.data_CamVid, 'train'), os.path.join(args.data_CamVid, 'val')]
    CamVid_train_label_path = [os.path.join(args.data_CamVid, 'train_labels'),
                               os.path.join(args.data_CamVid, 'val_labels')]
    CamVid_test_path = os.path.join(args.data_CamVid, 'test')
    CamVid_test_label_path = os.path.join(args.data_CamVid, 'test_labels')
    CamVid_csv_path = os.path.join(args.data_CamVid, 'class_dict.csv')
    CamVid_dataset_train = CamVid(CamVid_train_path, CamVid_train_label_path, CamVid_csv_path,
                                  scale=(args.crop_height, args.crop_width),
                                  loss=args.loss, mode='train')
    CamVid_dataloader_train = DataLoader(
        CamVid_dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True
    )
    CamVid_dataset_val = CamVid(CamVid_test_path, CamVid_test_label_path, CamVid_csv_path,
                                scale=(args.crop_height, args.crop_width),
                                loss=args.loss, mode='test')
    CamVid_dataloader_val = DataLoader(
        CamVid_dataset_val,
        # this has to be 1
        batch_size=1,
        shuffle=True,
        num_workers=args.num_workers
    )

    # create dataset and dataloader for IDDA (source domain)
    IDDA_path = os.path.join(args.data_IDDA, 'rgb')
    IDDA_label_path = os.path.join(args.data_IDDA, 'labels')
    IDDA_info_path = os.path.join(args.data_IDDA, 'classes_info.json')
    IDDA_dataset = IDDA(IDDA_path, IDDA_label_path, IDDA_info_path, CamVid_csv_path,
                        scale=(args.crop_height, args.crop_width), loss=args.loss)
    IDDA_dataloader = DataLoader(
        IDDA_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True
    )

    # build model_G (restrict visible GPUs before any CUDA initialisation)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    model_G = BiSeNet(args.num_classes, args.context_path)
    if torch.cuda.is_available() and args.use_gpu:
        model_G = torch.nn.DataParallel(model_G).cuda()

    # build model_D
    model_D = DW_Discriminator(args.num_classes)
    if torch.cuda.is_available() and args.use_gpu:
        model_D = torch.nn.DataParallel(model_D).cuda()

    # build optimizer G
    if args.optimizer_G == 'rmsprop':
        optimizer_G = torch.optim.RMSprop(model_G.parameters(), args.learning_rate_G)
    elif args.optimizer_G == 'sgd':
        optimizer_G = torch.optim.SGD(model_G.parameters(), args.learning_rate_G, momentum=0.9, weight_decay=1e-4)
    elif args.optimizer_G == 'adam':
        optimizer_G = torch.optim.Adam(model_G.parameters(), args.learning_rate_G)
    else:
        # unknown optimizer name: report and bail out
        print('not supported optimizer \n')
        return None

    # build optimizer D
    if args.optimizer_D == 'rmsprop':
        optimizer_D = torch.optim.RMSprop(model_D.parameters(), args.learning_rate_D)
    elif args.optimizer_D == 'sgd':
        optimizer_D = torch.optim.SGD(model_D.parameters(), args.learning_rate_D, momentum=0.9, weight_decay=1e-4)
    elif args.optimizer_D == 'adam':
        optimizer_D = torch.optim.Adam(model_D.parameters(), args.learning_rate_D)
    else:
        # unknown optimizer name: report and bail out
        print('not supported optimizer \n')
        return None

    curr_epoch = 0
    max_miou = 0

    # resume from a combined checkpoint if one was given: it bundles both
    # models, both optimizers, the epoch counter and the best mIoU so far
    if args.pretrained_model_path is not None:
        print('load model from %s ...' % args.pretrained_model_path)
        state = torch.load(os.path.realpath(args.pretrained_model_path))
        # The models are only wrapped in DataParallel on the GPU path, so
        # load into the bare modules when running on CPU.
        target_G = model_G.module if isinstance(model_G, torch.nn.DataParallel) else model_G
        target_D = model_D.module if isinstance(model_D, torch.nn.DataParallel) else model_D
        target_G.load_state_dict(state['model_G_state'])
        optimizer_G.load_state_dict(state['optimizer_G'])
        target_D.load_state_dict(state['model_D_state'])
        optimizer_D.load_state_dict(state['optimizer_D'])
        curr_epoch = state["epoch"]
        max_miou = state["max_miou"]
        print(str(curr_epoch - 1) + " already trained")
        print("start training from epoch " + str(curr_epoch))
        print('Done!')

    # train
    train(args, model_G, model_D, optimizer_G, optimizer_D, CamVid_dataloader_train,
          CamVid_dataloader_val, IDDA_dataloader, curr_epoch, max_miou)
Example #7
0
def main(params):
    """Train BiSeNet for a 2-class segmentation task.

    Parses ``params``, builds the training dataloader, the (optionally
    DataParallel/CUDA) model and an RMSprop optimizer, prints an estimate
    of the model size, optionally loads a pretrained checkpoint, then
    calls ``train``.

    Args:
        params: list of CLI argument strings forwarded to argparse.
    """

    def str2bool(value):
        # argparse's ``type=bool`` is broken: bool('False') is True, so
        # '--use_gpu False' would silently enable the GPU.
        return str(value).lower() in ('1', 'true', 'yes', 'y')

    # basic parameters
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_epochs',
                        type=int,
                        default=30,
                        help='Number of epochs to train for')
    parser.add_argument('--epoch_start_i',
                        type=int,
                        default=0,
                        help='Start counting epochs from this number')

    parser.add_argument('--checkpoint_step',
                        type=int,
                        default=1,
                        help='How often to save checkpoints (epochs)')
    parser.add_argument('--validation_step',
                        type=int,
                        default=1,
                        help='How often to perform validation (epochs)')
    parser.add_argument('--dataset',
                        type=str,
                        default="CamVid",
                        help='Dataset you are using.')
    parser.add_argument(
        '--crop_height',
        type=int,
        default=640,
        help='Height of cropped/resized input image to network')
    parser.add_argument('--crop_width',
                        type=int,
                        default=640,
                        help='Width of cropped/resized input image to network')
    parser.add_argument('--batch_size',
                        type=int,
                        default=1,
                        help='Number of images in each batch')
    parser.add_argument('--context_path',
                        type=str,
                        default="resnet101",
                        help='The context path model you are using.')
    parser.add_argument('--learning_rate',
                        type=float,
                        default=0.01,
                        help='learning rate used for train')
    parser.add_argument('--data',
                        type=str,
                        default='/path/to/data',
                        help='path of training data')
    parser.add_argument('--num_workers',
                        type=int,
                        default=4,
                        help='num of workers')
    parser.add_argument('--num_classes',
                        type=int,
                        default=2,
                        help='num of object classes (with void)')
    parser.add_argument('--cuda',
                        type=str,
                        default='0',
                        help='GPU ids used for training')
    parser.add_argument('--use_gpu',
                        type=str2bool,
                        default=True,
                        help='whether to use gpu for training')
    parser.add_argument('--pretrained_model_path',
                        type=str,
                        default=None,
                        help='path to pretrained model')
    parser.add_argument('--save_model_path',
                        type=str,
                        default=None,
                        help='path to save model')

    args = parser.parse_args(params)

    # create dataset and dataloader.
    # NOTE(review): batch_size/num_workers are hard-coded here instead of
    # using args.batch_size/args.num_workers, and ``train`` is used both as
    # a dataset factory here and as the training function below — confirm
    # this shadowing is intended.
    dataloader_train = DataLoader(train(input_transform, target_transform),
                                  num_workers=1,
                                  batch_size=2,
                                  shuffle=True)

    # build model (restrict visible GPUs before any CUDA initialisation)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
    model = BiSeNet(args.num_classes, args.context_path)
    if torch.cuda.is_available() and args.use_gpu:
        model = torch.nn.DataParallel(model).cuda()

    # Report approximate model size; assumes 4 bytes per parameter
    # (float32) — presumably correct for this model, verify if it changes.
    para = sum([np.prod(list(p.size())) for p in model.parameters()])
    print('Model {} : params: {:4f}M'.format(model._get_name(),
                                             para * 4 / 1000 / 1000))

    # build optimizer
    optimizer = torch.optim.RMSprop(model.parameters(), args.learning_rate)

    # load pretrained model if exists
    if args.pretrained_model_path is not None:
        print('load model from %s ...' % args.pretrained_model_path)
        # The model is only wrapped in DataParallel on the GPU path, so load
        # into the bare module when running on CPU.
        target = model.module if isinstance(model, torch.nn.DataParallel) else model
        target.load_state_dict(torch.load(args.pretrained_model_path))
        print('Done!')

    # train
    train(args, model, optimizer, dataloader_train)