示例#1
0
 def __init__(self):
     """Build the image augmentation pipeline and store it in ``self.aug``.

     Images are first resized via ``iaa.Scale((224, 224))``. A fixed-order
     chain of augmenters follows; most of them are wrapped in
     ``iaa.Sometimes``:
     blur, brightness multiplication, inversion, coarse element replacement
     with gaussian values, Poisson noise, horizontal flip, rotation,
     (coarse) dropout and a hue/saturation shift.
     """
     def maybe(probability, augmenter):
         # Wrap `augmenter` so it is only applied with the given probability.
         return iaa.Sometimes(probability, augmenter)

     # NOTE: every augmenter is created inline, in list order, exactly as
     # before, so construction order is unchanged.
     steps = [
         iaa.Scale((224, 224)),
         maybe(0.30, iaa.GaussianBlur(sigma=(0, 3.0))),
         maybe(0.25, iaa.Multiply((0.5, 1.5), per_channel=0.5)),
         maybe(0.20, iaa.Invert(0.25, per_channel=0.5)),
         maybe(
             0.25,
             iaa.ReplaceElementwise(
                 # mask: Binomial(0.1) sampled via FromLowerResolution(size_px=8)
                 iap.FromLowerResolution(iap.Binomial(0.1), size_px=8),
                 # replacement values: gaussian around 128
                 iap.Normal(128, 0.4 * 128),
                 per_channel=0.5,
             ),
         ),
         maybe(0.30, iaa.AdditivePoissonNoise(40)),
         iaa.Fliplr(0.5),
         iaa.Affine(rotate=(-20, 20), mode='symmetric'),
         maybe(
             0.30,
             iaa.OneOf([
                 iaa.Dropout(p=(0, 0.1)),
                 iaa.CoarseDropout(0.1, size_percent=0.5),
             ]),
         ),
         iaa.AddToHueAndSaturation(value=(-10, 10), per_channel=True),
     ]
     self.aug = iaa.Sequential(steps)
def chapter_augmenters_replaceelementwise():
    """Render the documentation example grids for ``iaa.ReplaceElementwise``.

    Four variants are written. Each is a 4x2 grid built from 8 quokka images
    of size 128x128, saved with quality=95 and seed=2 to a file named
    ``arithmetic/replaceelementwise<suffix>.jpg``.
    """
    augmenter_class = iaa.ReplaceElementwise
    prefix = "arithmetic/replaceelementwise"

    def render(suffix, augmenter):
        # A fresh list of 8 images is created for every grid.
        # The augmenter is always constructed before these images,
        # as in the unrolled version.
        batch = [ia.quokka(size=(128, 128)) for _ in range(8)]
        run_and_save_augseq(
            prefix + suffix,
            augmenter,
            batch,
            cols=4,
            rows=2,
            quality=95,
            seed=2,
        )

    # mask probability 0.05, replacement values drawn from [0, 255]
    render(".jpg", augmenter_class(0.05, [0, 255]))

    # same, but per-channel replacement for part of the images
    render("_per_channel_050.jpg",
           augmenter_class(0.05, [0, 255], per_channel=0.5))

    # replacement values from a gaussian around 128
    render("_gaussian_noise.jpg",
           augmenter_class(0.1, iap.Normal(128, 0.4 * 128), per_channel=0.5))

    # gaussian replacement with a mask sampled via FromLowerResolution
    render("_gaussian_noise_coarse.jpg",
           augmenter_class(iap.FromLowerResolution(iap.Binomial(0.1), size_px=8),
                           iap.Normal(128, 0.4 * 128),
                           per_channel=0.5))
示例#3
0
def _create_train_augseq():
    """Return the randomized image augmentation cascade used for training batches."""
    # wrap an augmenter so it is applied with probability 0.1 / 0.2 / 0.4
    rarely = lambda aug: iaa.Sometimes(0.1, aug)
    sometimes = lambda aug: iaa.Sometimes(0.2, aug)
    often = lambda aug: iaa.Sometimes(0.4, aug)
    return iaa.Sequential([
            sometimes(iaa.Crop(percent=(0, 0.025))), # crop 0 to 2.5% from the image sides
            rarely(iaa.GaussianBlur((0, 1.0))), # blur images with a sigma between 0 and 1.0
            rarely(iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.02*255), per_channel=0.5)), # add gaussian noise (scale up to 0.02*255)
            # dropout whose mask comes from FromLowerResolution(size_px=(2, 16)),
            # so whole patches are affected; Binomial(1 - 0.2) presumably keeps
            # ~80% of values -- NOTE(review): confirm against imgaug's Dropout semantics
            often(iaa.Dropout(
                iap.FromLowerResolution(
                    other_param=iap.Binomial(1 - 0.2),
                    size_px=(2, 16)
                ),
                per_channel=0.2
            )),
            often(iaa.Add((-20, 20), per_channel=0.5)), # change brightness of images (add -20 to 20 to pixel values)
            often(iaa.Multiply((0.8, 1.2), per_channel=0.25)), # change brightness of images (80-120% of original value)
            often(iaa.ContrastNormalization((0.8, 1.2), per_channel=0.5)), # improve or worsen the contrast
            often(iaa.Affine(
                scale={"x": (0.8, 1.3), "y": (0.8, 1.3)},
                translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
                rotate=(-0, 0), # rotation effectively disabled
                shear=(-0, 0), # shear effectively disabled
                order=[0, 1],
                cval=(0, 255),
                mode=["constant", "edge"]
            )),
            rarely(iaa.Grayscale(alpha=(0.0, 1.0)))
        ],
        random_order=True # do all of the above in random order
    )


def main():
    """Initialize the steering wheel tracker training and run all batches.

    If ``steering_wheel.tar`` exists and ``--nocontinue`` is not given,
    training resumes from that checkpoint. The checkpoint supplies model
    weights, loss/acc history, batch index and, if present, the optimizer
    state. During the loop:

    - every ``VAL_EVERY`` batches, ``NB_VAL_BATCHES`` validation batches run;
    - every ``PLOT_EVERY`` batches, the training plot is redrawn;
    - every ``SAVE_EVERY`` batches, a checkpoint is written.
    """

    parser = argparse.ArgumentParser(description="Train steering wheel tracker")
    parser.add_argument('--nocontinue', default=False, action="store_true", help="Whether to NOT continue the previous experiment", required=False)
    args = parser.parse_args()

    checkpoint_fp = "steering_wheel.tar"
    if os.path.isfile(checkpoint_fp) and not args.nocontinue:
        # When training on the CPU, map tensors that were saved from a GPU to
        # CPU storage; otherwise keep torch.load's default device placement.
        map_location = (lambda storage, loc: storage) if Config.GPU < 0 else None
        checkpoint = torch.load(checkpoint_fp, map_location=map_location)
    else:
        checkpoint = None

    if checkpoint is not None:
        history = plotting.History.from_string(checkpoint["history"])
    else:
        history = plotting.History()
        history.add_group("loss", ["train", "val"], increasing=False)
        history.add_group("acc", ["train", "val"], increasing=True)
    loss_plotter = plotting.LossPlotter(
        history.get_group_names(),
        history.get_groups_increasing(),
        save_to_fp="train_plot.jpg"
    )
    # presumably makes the plot skip the first 100 batches -- TODO confirm
    loss_plotter.start_batch_idx = 100

    tracker_cnn = models.SteeringWheelTrackerCNNModel()
    tracker_cnn.train()

    optimizer = optim.Adam(tracker_cnn.parameters())

    criterion = nn.CrossEntropyLoss()
    if checkpoint is not None:
        tracker_cnn.load_state_dict(checkpoint["tracker_cnn_state_dict"])

    if Config.GPU >= 0:
        tracker_cnn.cuda(Config.GPU)
        criterion.cuda(Config.GPU)

    # Restore the Adam state only after the model is on its final device;
    # PyTorch casts loaded optimizer state to the parameters' device.
    # Older checkpoints lack this key and resume with a fresh optimizer.
    if checkpoint is not None and "optimizer_state_dict" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    # initialize image augmentation cascade
    augseq = _create_train_augseq()

    batch_loader_train = BatchLoader(val=False, augseq=augseq, queue_size=15, nb_workers=4)
    batch_loader_val = BatchLoader(val=True, augseq=iaa.Noop(), queue_size=NB_VAL_BATCHES, nb_workers=2)

    start_batch_idx = 0 if checkpoint is None else checkpoint["batch_idx"] + 1
    # NOTE(review): xrange is Python 2 only; use range when porting to Python 3
    for batch_idx in xrange(start_batch_idx, NB_BATCHES):
        run_batch(batch_idx, False, batch_loader_train, tracker_cnn, criterion, optimizer, history, (batch_idx % 20) == 0)

        if (batch_idx+1) % VAL_EVERY == 0:
            for i in xrange(NB_VAL_BATCHES):
                run_batch(batch_idx, True, batch_loader_val, tracker_cnn, criterion, optimizer, history, i == 0)

        if (batch_idx+1) % PLOT_EVERY == 0:
            loss_plotter.plot(history)

        # every N batches, save a checkpoint (including optimizer state so
        # that resuming continues Adam's moment estimates)
        if (batch_idx+1) % SAVE_EVERY == 0:
            torch.save({
                "batch_idx": batch_idx,
                "history": history.to_string(),
                "tracker_cnn_state_dict": tracker_cnn.state_dict(),
                "optimizer_state_dict": optimizer.state_dict()
            }, checkpoint_fp)