Пример #1
0
def main():
    """Configure and launch training of the attention segmentation model.

    Reads the command-line options, seeds the random generators, builds the
    U-Net, the train/validation loaders, the loss, the metric and the
    optimizer, and passes everything to ``train``.
    """
    opts = get_parser().parse_args()

    # fix the random seed before anything stochastic is created
    seed_everything(opts.seed)

    # dataset split directories
    train_dir = osp.join(opts.dataset_path, 'train')
    test_dir = osp.join(opts.dataset_path, 'test')

    # U-Net with two output classes (occluders and their shadows)
    net = smp.Unet(encoder_name=opts.encoder, classes=2, activation='sigmoid')
    # the model input is a 4-channel tensor, so the first convolution is replaced
    net.encoder.conv1 = nn.Conv2d(
        4,
        64,
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )

    # datasets
    train_set = ARDataset(
        train_dir,
        augmentation=get_training_augmentation(opts.img_size),
        preprocessing=get_preprocessing(),
    )
    valid_set = ARDataset(
        test_dir,
        augmentation=get_validation_augmentation(opts.img_size),
        preprocessing=get_preprocessing(),
    )

    # loaders: only the training loader is shuffled
    loader_options = dict(batch_size=opts.batch_size,
                          num_workers=opts.num_workers)
    train_loader = DataLoader(train_set, shuffle=True, **loader_options)
    valid_loader = DataLoader(valid_set, shuffle=False, **loader_options)

    # loss function, metric and optimizer
    criterion = smp.utils.losses.DiceLoss()
    iou = smp.utils.metrics.IoU(threshold=opts.iou_th)
    adam = torch.optim.Adam([{'params': net.parameters(), 'lr': opts.lr}])

    # tensorboard
    writer = SummaryWriter()

    # requested device, or the CPU when CUDA is not available
    device_name = opts.device if torch.cuda.is_available() else 'cpu'
    device = torch.device(device_name)

    # start training
    train(
        writer=writer,
        n_epoch=opts.n_epoch,
        train_loader=train_loader,
        valid_loader=valid_loader,
        model_path=opts.model_path,
        model=net,
        loss=criterion,
        metric=iou,
        optimizer=adam,
        device=device,
    )
Пример #2
0
def get_train_val_loaders(encoder_type, batch_size=16, pseudo_label=False,
                          num_workers=24):
    """Build the training and validation data loaders.

    Args:
        encoder_type: encoder name used to look up the imagenet preprocessing
            function; any name starting with ``'myunet'`` is mapped to
            ``'resnet50'``.
        batch_size: batch size used by both loaders.
        pseudo_label: when true, the rows produced by ``prepare_df`` for the
            blend submission file are appended to the training data and
            flagged with ``pseudo == 1``.
        num_workers: number of worker processes for both loaders. Defaults to
            24, the value that used to be hard-coded.

    Returns:
        dict with the keys ``"train"`` and ``"valid"``; each loader carries an
        extra ``num`` attribute holding the number of image ids behind it.
    """
    if encoder_type.startswith('myunet'):
        encoder_type = 'resnet50'
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        encoder_type, 'imagenet')

    train, train_ids, valid_ids = prepare_df()
    train['pseudo'] = 0
    pseudo_imgs = set()
    if pseudo_label:
        # NOTE(review): with pseudo_label=True prepare_df is used as if it
        # returned a single dataframe - confirm against its definition.
        train_pseudo = prepare_df(
            train_file=f'{settings.DATA_DIR}/sub_blend_1111_1.csv',
            pseudo_label=True)
        train_pseudo['pseudo'] = 1
        pseudo_ids = train_pseudo.im_id.unique().tolist()
        print(pseudo_ids[:10])
        pseudo_imgs = set(pseudo_ids)
        # pseudo-labelled images take part in training only
        train_ids.extend(pseudo_ids)
        train = pd.concat([train, train_pseudo])
        print(train.head())
        print(train_pseudo.head())
        print(train.shape)
        print(len(train_ids))

    train_dataset = CloudDataset(
        df=train,
        datatype='train',
        img_ids=train_ids,
        transforms=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        pseudo_imgs=pseudo_imgs)
    valid_dataset = CloudDataset(
        df=train,
        datatype='valid',
        img_ids=valid_ids,
        transforms=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn))

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=num_workers)

    # expose the number of ids behind each loader to the caller
    train_loader.num = len(train_ids)
    valid_loader.num = len(valid_ids)

    return {"train": train_loader, "valid": valid_loader}
Пример #3
0
def get_data_loaders(bs=8, num_workers=0, shuffle=True, ts=0.2):
    """Create the training and validation loaders from a stratified split.

    Args:
        bs: batch size used by both loaders.
        num_workers: number of worker processes used by both loaders.
        shuffle: whether the *training* loader reshuffles its samples. The
            validation loader is never shuffled, so evaluation order stays
            stable from run to run.
        ts: fraction of the images put into the validation split.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    train_df, img_2_ohe_vector = get_df()
    # stratify on the sorted class collection of every image so that both
    # splits see the same class combinations
    train_imgs, val_imgs = train_test_split(
        train_df['Image'].values,
        test_size=ts,
        stratify=train_df['Class'].map(lambda x: str(sorted(x))),
        random_state=42)
    print(train_imgs)
    print(val_imgs)
    print(len(train_imgs))
    print(len(val_imgs))

    train_dataset = CloudDataset(img_2_ohe_vector, img_ids=train_imgs,
                                 transforms=get_training_augmentation())
    train_loader = DataLoader(train_dataset, batch_size=bs,
                              shuffle=shuffle, num_workers=num_workers)

    val_dataset = CloudDataset(img_2_ohe_vector, img_ids=val_imgs,
                               transforms=get_validation_augmentation())
    # validation data is read in a fixed order regardless of ``shuffle``
    val_loader = DataLoader(val_dataset, batch_size=bs,
                            shuffle=False, num_workers=num_workers)

    return train_loader, val_loader
Пример #4
0
def _str2bool(value):
    """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0', ...)."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if text in ('0', 'false', 'f', 'no', 'n', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def main():
    """Train a single-class ('bubble') U-Net from the command line.

    Parses the options, builds the model (or loads a pretrained one), runs
    ``args.train_epoch`` epochs, runs the validation helper every
    ``args.valid_epoch`` epochs, and saves the model whenever the training
    IoU improves. The learning rate is lowered to 1e-5 after epoch 25.
    """
    parser = argparse.ArgumentParser(
        description='Train the Model on images and target masks',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('-g_num',
                        '--gpu_num',
                        metavar='G',
                        type=str,
                        default='0',
                        help='GPU',
                        dest='gpu_num')
    parser.add_argument('-d',
                        '--device',
                        metavar='G',
                        type=str,
                        default='cpu',
                        help='device (cpu or cuda)',
                        dest='device')
    parser.add_argument('-train_dir',
                        '--train_data_dir',
                        type=str,
                        default='/Users/ryeon/Desktop/segmentation_data/',
                        help='dataset dir',
                        dest='train_dir')
    parser.add_argument('-valid_dir',
                        '--valid_data_dir',
                        type=str,
                        default='/Users/ryeon/Desktop/valid_final/result/',
                        help='dataset dir',
                        dest='valid_dir')
    parser.add_argument('-train_epoch',
                        type=int,
                        default=50,
                        help='train epoch')
    parser.add_argument('-valid_epoch',
                        type=int,
                        default=1,
                        help='validation epoch')
    parser.add_argument('-pretrained',
                        type=str,
                        default=None,
                        help='pretrained model.pth')
    parser.add_argument('-encoder',
                        type=str,
                        default='mobilenet_v2',
                        help='Encoder')
    parser.add_argument('-encoder_weight',
                        type=str,
                        default='imagenet',
                        help='Encoder weights')
    parser.add_argument('-activation',
                        type=str,
                        default='sigmoid',
                        help='Activation')
    # ``type=bool`` would turn every non-empty string (even "False") into
    # True, so the boolean switches are parsed explicitly.
    parser.add_argument('-simple', type=_str2bool, default=False, help='')
    parser.add_argument('-trans', type=_str2bool, default=False, help='')
    parser.add_argument('-color', type=_str2bool, default=False, help='')
    parser.add_argument('-trans_color', type=_str2bool, default=False, help='')
    args = parser.parse_args()

    print(args)

    if args.device == 'cuda':
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_num
    ENCODER = args.encoder
    ENCODER_WEIGHTS = args.encoder_weight
    CLASSES = ['bubble']
    ACTIVATION = args.activation
    DEVICE = args.device
    CHANNEL = 3

    # The model has to be final BEFORE the optimizer and the TrainEpoch are
    # created; loading it afterwards would leave both bound to the discarded
    # freshly-initialised network.
    if args.pretrained is not None:
        # NOTE(review): this unpickles a complete model object; only load
        # checkpoint files that come from a trusted source.
        if args.device == 'cuda':
            model = torch.load(args.pretrained)
        else:
            model = torch.load(args.pretrained, map_location=args.device)
    else:
        model = smp.Unet(
            encoder_name=ENCODER,
            encoder_weights=ENCODER_WEIGHTS,
            in_channels=CHANNEL,
            classes=len(CLASSES),
            activation=ACTIVATION,
        )

    loss = smp.utils.losses.DiceLoss()
    metrics = [
        smp.utils.metrics.IoU(threshold=0.5),
    ]

    optimizer = torch.optim.Adam([
        dict(params=model.parameters(), lr=0.0001),
    ])

    DATA_DIR = args.train_dir
    x_train_dir = os.path.join(DATA_DIR, 'image')
    y_train_dir = os.path.join(DATA_DIR, 'mask')
    VALID_DIR = args.valid_dir
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        ENCODER, ENCODER_WEIGHTS)

    train_dataset = Dataset(
        x_train_dir,
        y_train_dir,
        augmentation=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        classes=CLASSES,
        train=True,
    )

    train_loader = DataLoader(train_dataset,
                              batch_size=128,
                              shuffle=True,
                              num_workers=0)

    train_epoch = smp.utils.train.TrainEpoch(
        model,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        device=DEVICE,
        verbose=True,
    )

    max_score = 0
    model_dir = './result/model/'

    for i in range(args.train_epoch):
        print(i, args.train_epoch)

        # periodic validation pass, executed before the epoch is trained
        if i % args.valid_epoch == 0:
            print('\nValid: {}'.format(i))
            bubble_mask_with_segmentation(VALID_DIR,
                                          model,
                                          i,
                                          device=args.device)

        print('\nEpoch: {}'.format(i))
        train_logs = train_epoch.run(train_loader)

        # keep a snapshot whenever the training IoU improves
        if max_score < train_logs['iou_score']:
            max_score = train_logs['iou_score']
            os.makedirs(model_dir, exist_ok=True)
            save_model_name = os.path.join(
                model_dir, 'train_model_mob_trans_' + str(i) + '.pth')
            torch.save(model, save_model_name)
            print('Model saved!')

        if i == 25:
            # the optimizer holds a single parameter group for the whole model
            optimizer.param_groups[0]['lr'] = 1e-5
            print('Decrease decoder learning rate to 1e-5!')
def main(args):
    """
    Main code for training a classification model.

    Args:
        args (argparse.Namespace): parsed command-line arguments (the result
            of `parse_args`). The attributes read here are `dset_path`,
            `split_seed`, `test_size`, `batch_size`, `num_workers` and
            `num_epochs`.
    Returns:
        None
    """
    # Loading the dataframes. The former direct reads of train.csv and
    # sample_submission.csv were dropped: their results were overwritten by
    # this call before ever being used.
    train, _, id_mask_count = setup_train_and_sub_df(args.dset_path)

    # setting up the train/val split with filenames
    seed_everything(args.split_seed)
    train_ids, valid_ids = train_test_split(id_mask_count["im_id"].values,
                                            random_state=args.split_seed,
                                            stratify=id_mask_count["count"],
                                            test_size=args.test_size)

    # setting up the classification model
    ENCODER_WEIGHTS = "imagenet"
    model = ResNet34(pre=ENCODER_WEIGHTS, num_classes=4, use_simple_head=True)

    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        "resnet34", ENCODER_WEIGHTS)

    # Setting up the I/O
    train_dataset = ClassificationSteelDataset(
        args.dset_path,
        df=train,
        datatype="train",
        im_ids=train_ids,
        transforms=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
    )
    valid_dataset = ClassificationSteelDataset(
        args.dset_path,
        df=train,
        datatype="valid",
        im_ids=valid_ids,
        transforms=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
    )

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers)

    loaders = {"train": train_loader, "valid": valid_loader}
    # everything is saved here (i.e. weights + stats)
    logdir = "./logs/segmentation"

    # model, criterion, optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    scheduler = ReduceLROnPlateau(optimizer, factor=0.15, patience=2)
    criterion = smp.utils.losses.BCEDiceLoss(eps=1.)
    runner = SupervisedRunner()

    runner.train(model=model,
                 criterion=criterion,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 loaders=loaders,
                 callbacks=[
                     DiceCallback(),
                     EarlyStoppingCallback(patience=5, min_delta=0.001)
                 ],
                 logdir=logdir,
                 num_epochs=args.num_epochs,
                 verbose=True)
    utils.plot_metrics(
        logdir=logdir,
        # specify which metrics we want to plot
        metrics=["loss", "dice", "lr", "_base/lr"])
Пример #6
0
def main(args):
    """
    Main code for training a U-Net with some user-defined encoder.

    Args:
        args (argparse.Namespace): parsed command-line arguments (the result
            of `parse_args`). The attributes read here are `dset_path`,
            `split_seed`, `test_size`, `encoder`, `attention_type`,
            `use_resized_dataset`, `batch_size`, `num_workers`, `encoder_lr`,
            `decoder_lr`, `checkpoint_path` and `num_epochs`.
    Returns:
        None
    """
    # loading the dataframes
    train, sub, id_mask_count = setup_train_and_sub_df(args.dset_path)
    # setting up the train/val split with filenames
    seed_everything(args.split_seed)
    train_ids, valid_ids = train_test_split(id_mask_count["im_id"].values,
                                            random_state=args.split_seed,
                                            stratify=id_mask_count["count"],
                                            test_size=args.test_size)
    # setting up model (U-Net with ImageNet Encoders)
    ENCODER_WEIGHTS = "imagenet"

    # the string "None" on the command line stands for "no attention"
    attention_type = None if args.attention_type == "None" else args.attention_type
    model = smp.Unet(encoder_name=args.encoder,
                     encoder_weights=ENCODER_WEIGHTS,
                     classes=4,
                     activation=None,
                     attention_type=attention_type)
    preprocessing_fn = smp.encoders.get_preprocessing_fn(
        args.encoder, ENCODER_WEIGHTS)

    # Setting up the I/O
    train_dataset = SteelDataset(
        args.dset_path,
        df=train,
        datatype="train",
        im_ids=train_ids,
        transforms=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        use_resized_dataset=args.use_resized_dataset)
    valid_dataset = SteelDataset(
        args.dset_path,
        df=train,
        datatype="valid",
        im_ids=valid_ids,
        transforms=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        use_resized_dataset=args.use_resized_dataset)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers)

    loaders = {"train": train_loader, "valid": valid_loader}
    # everything is saved here (i.e. weights + stats)
    logdir = "./logs/segmentation"

    # model, criterion, optimizer
    # Each parameter group gets its own learning rate: the decoder trains
    # with `decoder_lr`, the encoder with `encoder_lr` (the two used to be
    # swapped). The group order (decoder first) is kept as it was so saved
    # optimizer states still line up.
    optimizer = torch.optim.Adam([
        {
            "params": model.decoder.parameters(),
            "lr": args.decoder_lr
        },
        {
            "params": model.encoder.parameters(),
            "lr": args.encoder_lr
        },
    ])
    scheduler = ReduceLROnPlateau(optimizer, factor=0.15, patience=2)
    criterion = smp.utils.losses.BCEDiceLoss(eps=1.)
    runner = SupervisedRunner()

    callbacks_list = [
        DiceCallback(),
        EarlyStoppingCallback(patience=5, min_delta=0.001),
    ]
    # the literal string "None" means: do not resume from a checkpoint
    if args.checkpoint_path != "None":
        ckpoint_p = Path(args.checkpoint_path)
        fname = ckpoint_p.name
        # everything in the path besides the base file name
        resume_dir = str(ckpoint_p.parents[0])
        print(
            f"Loading {fname} from {resume_dir}. Checkpoints will also be saved in {resume_dir}."
        )
        callbacks_list = callbacks_list + [
            CheckpointCallback(resume=fname, resume_dir=resume_dir),
        ]

    runner.train(model=model,
                 criterion=criterion,
                 optimizer=optimizer,
                 scheduler=scheduler,
                 loaders=loaders,
                 callbacks=callbacks_list,
                 logdir=logdir,
                 num_epochs=args.num_epochs,
                 verbose=True)
Пример #7
0
reset_index().rename(columns={'index': 'img_id', 'Image_Label': 'count'})
# Stratified 90/10 split of the image ids, stratified on the 'count' column.
train_ids, valid_ids = train_test_split(
    id_mask_count['img_id'].values,
    random_state=42,
    stratify=id_mask_count['count'],
    test_size=0.1,
)
# Unique submission image ids: the part of 'Image_Label' before the first '_'.
test_ids = (
    sub['Image_Label']
    .apply(lambda label: label.split('_')[0])
    .drop_duplicates()
    .values
)

print("creating preprocessing module...")

ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

print("creating data loader...")

num_workers = 4
bs = 16

# Both datasets read from the same dataframe; they differ in the id list,
# the datatype tag and the augmentation pipeline.
train_dataset = CloudDataset(
    path=path,
    df=train,
    datatype='train',
    img_ids=train_ids,
    transforms=utils.get_training_augmentation(),
    preprocessing=utils.get_preprocessing(preprocessing_fn),
)
valid_dataset = CloudDataset(
    path=path,
    df=train,
    datatype='valid',
    img_ids=valid_ids,
    transforms=utils.get_validation_augmentation(),
    preprocessing=utils.get_preprocessing(preprocessing_fn),
)

# Only the training loader is shuffled.
train_loader = DataLoader(
    train_dataset,
    batch_size=bs,
    shuffle=True,
    num_workers=num_workers,
)
valid_loader = DataLoader(
    valid_dataset,
    batch_size=bs,
    shuffle=False,
    num_workers=num_workers,
)

loaders = dict(train=train_loader, valid=valid_loader)

print("setting for training...")

ACTIVATION = None
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER, 
Пример #8
0
                                                 factor=0.15)
# Named loss functions.
criteria = dict(dice=DiceLoss(), bce=torch.nn.BCEWithLogitsLoss())

# Preprocessed training table and the image ids of the selected fold.
train = pd.read_csv('train_preprocessed.csv')
train_ids = pd.read_csv(f'./folds/fold_{args.fold}_train.csv').values.ravel()
valid_ids = pd.read_csv(f'./folds/fold_{args.fold}_val.csv').values.ravel()
num_workers = 4
bs = args.bs

# Both datasets share the same dataframe and image size; they differ in the
# id list, the datatype tag and the augmentation pipeline.
train_dataset = CloudDataset(
    df=train,
    image_size=(args.size, args.size * 2),
    path=path,
    datatype='train',
    preload=False,
    img_ids=train_ids,
    filter_bad_images=True,
    transforms=get_training_augmentation(
        size=(args.size, args.size * 2),
        p=0.5,
    ),
    preprocessing=get_preprocessing(preprocessing_fn),
)
valid_dataset = CloudDataset(
    df=train,
    image_size=(args.size, args.size * 2),
    path=path,
    datatype='valid',
    preload=False,
    img_ids=valid_ids,
    filter_bad_images=True,
    transforms=get_validation_augmentation((args.size, args.size * 2)),
    preprocessing=get_preprocessing(preprocessing_fn),
)

train_loader = DataLoader(train_dataset,
                          batch_size=bs,
                          shuffle=True,
Пример #9
0
def get_train_val_loaders(encoder_type, batch_size=16, ifold=0,
                          num_workers=24):
    """Build the training and validation data loaders for one fold.

    Args:
        encoder_type: encoder name used to look up the imagenet preprocessing
            function; any name starting with ``'myunet'`` is mapped to
            ``'resnet50'``.
        batch_size: batch size used by both loaders.
        ifold: fold index forwarded to ``prepare_df``.
        num_workers: number of worker processes for both loaders. Defaults to
            24, the value that used to be hard-coded.

    Returns:
        dict with the keys ``"train"`` and ``"valid"``; each loader carries an
        extra ``num`` attribute holding the number of image ids behind it.
    """
    if encoder_type.startswith('myunet'):
        encoder_type = 'resnet50'
    preprocessing_fn = smp.encoders.get_preprocessing_fn(encoder_type, 'imagenet')
    train, train_ids, valid_ids = prepare_df(ifold=ifold)
    print('val:', valid_ids[:10])

    train_dataset = CloudDataset(
        df=train,
        datatype='train',
        img_ids=train_ids,
        transforms=get_training_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn))
    valid_dataset = CloudDataset(
        df=train,
        datatype='valid',
        img_ids=valid_ids,
        transforms=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn))

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=num_workers)

    # expose the number of ids behind each loader to the caller
    train_loader.num = len(train_ids)
    valid_loader.num = len(valid_ids)

    return {"train": train_loader, "valid": valid_loader}
Пример #10
0
    def Setup(self):
        '''
        User function: Setup all the parameters

        Builds and compiles the segmentation model selected in the system
        dictionary, creates the training and validation datasets, and stores
        the model and both dataloaders under ``system_dict["local"]``.

        Args:
            None

        Returns:
            None
        '''
        params = self.system_dict["params"]
        preprocess_input = sm.get_preprocessing(params["backbone"])

        local = self.system_dict["local"]
        train_cfg = self.system_dict["dataset"]["train"]

        # network parameters: a single output channel for the binary case,
        # otherwise one channel per trained class plus one extra
        n_selected = len(train_cfg["classes_to_train"])
        local["n_classes"] = 1 if n_selected == 1 else n_selected + 1
        if local["n_classes"] == 1:
            activation = 'sigmoid'
        else:
            activation = 'softmax'

        # create model; an unrecognised model name creates nothing
        builders = {
            "Unet": sm.Unet,
            "FPN": sm.FPN,
            "Linknet": sm.Linknet,
            "PSPNet": sm.PSPNet,
        }
        build = builders.get(params["model"])
        if build is not None:
            local["model"] = build(params["backbone"],
                                   classes=local["n_classes"],
                                   activation=activation)

        # define optimizer
        optim = keras.optimizers.Adam(params["lr"])

        # Segmentation models losses can be combined with '+' and scaled by
        # an integer or float factor
        dice_loss = sm.losses.DiceLoss()
        if local["n_classes"] == 1:
            focal_loss = sm.losses.BinaryFocalLoss()
        else:
            focal_loss = sm.losses.CategoricalFocalLoss()
        total_loss = dice_loss + (1 * focal_loss)

        metrics = [
            sm.metrics.IOUScore(threshold=0.5),
            sm.metrics.FScore(threshold=0.5)
        ]

        # compile keras model with the optimizer, loss and metrics above
        local["model"].compile(optim, total_loss, metrics)

        # Dataset for train images
        train_dataset = Dataset(
            train_cfg["img_dir"],
            train_cfg["mask_dir"],
            train_cfg["classes_dict"],
            classes_to_train=train_cfg["classes_to_train"],
            augmentation=get_training_augmentation(),
            preprocessing=get_preprocessing(preprocess_input),
        )

        # round both image dimensions up to the next multiple of 32
        shape = params["image_shape"]
        for axis in (0, 1):
            remainder = shape[axis] % 32
            if remainder != 0:
                shape[axis] += 32 - remainder

        # Dataset for validation images: taken from the validation folders
        # when they are enabled, otherwise from the training folders. The
        # class definitions always come from the training configuration.
        if self.system_dict["dataset"]["val"]["status"]:
            source_cfg = self.system_dict["dataset"]["val"]
        else:
            source_cfg = train_cfg
        valid_dataset = Dataset(
            source_cfg["img_dir"],
            source_cfg["mask_dir"],
            train_cfg["classes_dict"],
            classes_to_train=train_cfg["classes_to_train"],
            augmentation=get_validation_augmentation(shape[0], shape[1]),
            preprocessing=get_preprocessing(preprocess_input),
        )

        local["train_dataloader"] = Dataloder(
            train_dataset,
            batch_size=params["batch_size"],
            shuffle=True)
        local["valid_dataloader"] = Dataloder(
            valid_dataset, batch_size=1, shuffle=False)
Пример #11
0
def main():
    """Parameters initialization and starting SG model training.

    Reads the command-line options, seeds the random generators, builds the
    generator and the discriminator (optionally restoring their weights),
    creates the datasets, loaders, losses, optimizers and scheduler, and
    passes everything to ``train``.
    """
    # read command line arguments
    args = get_parser().parse_args()

    # set random seed
    seed_everything(args.seed)

    # paths to dataset
    train_path = osp.join(args.dataset_path, 'train')
    test_path = osp.join(args.dataset_path, 'test')

    # declare generator and discriminator models
    generator = Generator_with_Refin(args.encoder)
    discriminator = Discriminator(input_shape=(3, args.img_size, args.img_size))

    # load weights
    # The tensors are mapped to the CPU while loading, so checkpoints saved
    # on a GPU can also be restored on a CPU-only host; load_state_dict
    # copies them into the parameters of the freshly built models.
    if args.gen_weights:
        generator.load_state_dict(
            torch.load(args.gen_weights, map_location='cpu'))
        print('Generator weights loaded!')

    if args.discr_weights:
        discriminator.load_state_dict(
            torch.load(args.discr_weights, map_location='cpu'))
        print('Discriminator weights loaded!')

    # declare datasets
    train_dataset = ARDataset(
        train_path,
        augmentation=get_training_augmentation(args.img_size),
        augmentation_images=get_image_augmentation(),
        preprocessing=get_preprocessing(),
    )

    valid_dataset = ARDataset(
        test_path,
        augmentation=get_validation_augmentation(args.img_size),
        preprocessing=get_preprocessing(),
    )

    # declare loaders
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers)
    valid_loader = DataLoader(valid_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers)

    # declare loss functions, optimizers and scheduler
    l2loss = nn.MSELoss()
    perloss = ContentLoss(feature_extractor="vgg16", layers=("relu3_3", ))
    GANloss = nn.MSELoss()

    optimizer_G = torch.optim.Adam(
        [dict(params=generator.parameters(), lr=args.lr_G)])
    optimizer_D = torch.optim.Adam(
        [dict(params=discriminator.parameters(), lr=args.lr_D)])

    # only the generator optimizer is scheduled
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_G, mode='min', factor=0.9, patience=args.patience)

    # device
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # tensorboard
    writer = SummaryWriter()

    # start training
    train(
        generator=generator,
        discriminator=discriminator,
        device=device,
        n_epoch=args.n_epoch,
        optimizer_G=optimizer_G,
        optimizer_D=optimizer_D,
        train_loader=train_loader,
        valid_loader=valid_loader,
        scheduler=scheduler,
        losses=[l2loss, perloss, GANloss],
        models_paths=[args.Gmodel_path, args.Dmodel_path],
        bettas=[args.betta1, args.betta2, args.betta3],
        writer=writer,
    )
Пример #12
0
    # Obtain the train/validation dataframes; the helper is seeded with
    # cfg.SEED.
    train_df, val_df = dataset_utils.get_landcover_train_val_df(
        LANDCOVER_ROOT, random_state=cfg.SEED)
    # Class metadata of the dataset.
    # NOTE(review): `include_unknow=False` presumably leaves the "unknown"
    # class out of the selected classes - confirm against dataset_utils.
    dataset_info = dataset_utils.get_landcover_info(LANDCOVER_ROOT,
                                                    include_unknow=False)
    class_names = dataset_info['class_names']
    class_rgb_values = dataset_info['class_rgb_values']
    select_class_rgb_values = dataset_info['select_class_rgb_values']

    # One output class per selected RGB value.
    num_classes = len(select_class_rgb_values)

    # The helper returns the model together with its preprocessing function.
    model, preprocessing_fn = get_deeplab_model(num_classes, cfg.MODEL.encoder)

    # Get train and val dataset instances
    train_dataset = LandCoverDataset(
        train_df,
        augmentation=get_training_augmentation(cfg.TRAIN.augment_type),
        preprocessing=get_preprocessing(preprocessing_fn),
        class_rgb_values=select_class_rgb_values,
    )

    valid_dataset = LandCoverDataset(
        val_df,
        augmentation=get_validation_augmentation(),
        preprocessing=get_preprocessing(preprocessing_fn),
        class_rgb_values=select_class_rgb_values,
    )

    if cfg.DEBUG:
        # Debug mode: restrict both datasets to their first 10 samples so a
        # quick end-to-end run is possible.
        train_dataset = Subset(train_dataset, [n for n in range(10)])
        valid_dataset = Subset(valid_dataset, [n for n in range(10)])