# Exemplo n.º 1
# 0
def train_beta(model):
    """Fine-tune *model* at 480x864 resolution on DAVIS 2017.

    Training draws 118 uniformly weighted 14-frame clips per epoch;
    validation runs on the 'val_joakim' split of YouTube-VOS with
    32-frame clips. Checkpoints are saved under the configured
    workspace with a "_beta" suffix. Resumes from the latest
    checkpoint if one exists.
    """
    print("Starting initial training (with cropped images)")
    num_epochs = 100
    batch_size = 2
    nframes = 14       # clip length per training sample
    nframes_val = 32   # clip length per validation sample

    size = (480, 864)

    def image_read(path):
        # Resize, tensorize, and normalize frames with ImageNet statistics.
        frame = Image.open(path)
        pipeline = tv.transforms.Compose([
            tv.transforms.Resize(size, interpolation=Image.BILINEAR),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])
        return pipeline(frame)

    def label_read(path):
        # Missing annotation files become all-255 maps, which the loss ignores.
        if not os.path.exists(path):
            return torch.LongTensor(1, *size).fill_(255)
        annotation = Image.open(path)
        pipeline = tv.transforms.Compose([
            tv.transforms.Resize(size, interpolation=Image.NEAREST),
            LabelToLongTensor(),
        ])
        return pipeline(annotation)

    def random_object_sampler(lst):
        # Train on one randomly chosen object per sequence.
        return [random.choice(lst)]

    def deterministic_object_sampler(lst):
        # Validation always takes the first object, for reproducibility.
        return [lst[0]]

    train_transform = dataset_loaders.JointCompose(
        [dataset_loaders.JointRandomHorizontalFlip()])

    davis_train = DAVIS17V2(
        config['davis17_path'], '2017', 'train', image_read, label_read,
        train_transform, nframes, random_object_sampler, start_frame='random')
    train_set = torch.utils.data.ConcatDataset([davis_train])
    val_set = YTVOSV2(
        config['ytvos_path'], 'train', 'val_joakim', 'JPEGImages',
        image_read, label_read, None, nframes_val,
        deterministic_object_sampler, start_frame='first')

    # Uniform weights; the sampler just fixes the epoch length at 118 draws.
    sampler = torch.utils.data.WeightedRandomSampler(
        [1] * len(train_set), 118, replacement=True)
    train_loader = DataLoader(
        train_set, batch_size=batch_size, sampler=sampler, num_workers=11)
    val_loader = DataLoader(
        val_set, shuffle=False, batch_size=batch_size, num_workers=11)
    print("Sets initiated with {} (train) and {} (val) samples.".format(
        len(train_set), len(val_set)))

    # NLL over log-probabilities; 255 marks unannotated pixels.
    objective = nn.NLLLoss(ignore_index=255).cuda()
    trainable = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=1e-5, weight_decay=1e-6)
    lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, .985)

    trainer = trainers.VOSTrainer(
        model, optimizer, objective, lr_sched, train_loader, val_loader,
        use_gpu=True, workspace_dir=config['workspace_path'],
        save_name=os.path.splitext(os.path.basename(__file__))[0] + "_beta",
        checkpoint_interval=100, print_interval=25, debug=False)
    trainer.load_checkpoint()  # resume if a checkpoint already exists
    trainer.train(num_epochs)
# Exemplo n.º 2
# 0
def train_alpha(model):
    """First-stage training of *model* at reduced 240x432 resolution.

    Trains on the union of DAVIS 2017 and the 'train_joakim' split of
    YouTube-VOS with 8-frame clips, validating on 'val_joakim' with
    32-frame clips. Checkpoints are saved under the configured
    workspace with an "_alpha" suffix. Resumes from the latest
    checkpoint if one exists.
    """
    num_epochs = 160
    batch_size = 4
    nframes = 8        # clip length per training sample
    nframes_val = 32   # clip length per validation sample

    size = (240, 432)

    def image_read(path):
        # Resize, tensorize, and normalize frames with ImageNet statistics.
        frame = Image.open(path)
        pipeline = tv.transforms.Compose([
            tv.transforms.Resize(size, interpolation=Image.BILINEAR),
            tv.transforms.ToTensor(),
            tv.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ])
        return pipeline(frame)

    def label_read(path):
        # Missing annotation files become all-255 maps, which the loss ignores.
        if not os.path.exists(path):
            return torch.LongTensor(1, *size).fill_(255)
        annotation = Image.open(path)
        pipeline = tv.transforms.Compose([
            tv.transforms.Resize(size, interpolation=Image.NEAREST),
            LabelToLongTensor(),
        ])
        return pipeline(annotation)

    def random_object_sampler(lst):
        # Train on one randomly chosen object per sequence.
        return [random.choice(lst)]

    def deterministic_object_sampler(lst):
        # Validation always takes the first object, for reproducibility.
        return [lst[0]]

    train_transform = dataset_loaders.JointCompose(
        [dataset_loaders.JointRandomHorizontalFlip()])

    davis_train = DAVIS17V2(
        config["davis17_path"], "2017", "train", image_read, label_read,
        train_transform, nframes, random_object_sampler, start_frame="random")
    ytvos_train = YTVOSV2(
        config["ytvos_path"], "train", "train_joakim", "JPEGImages",
        image_read, label_read, train_transform, nframes,
        random_object_sampler, start_frame="random")
    train_set = torch.utils.data.ConcatDataset([davis_train, ytvos_train])
    val_set = YTVOSV2(
        config["ytvos_path"], "train", "val_joakim", "JPEGImages",
        image_read, label_read, None, nframes_val,
        deterministic_object_sampler, start_frame="first")

    train_loader = DataLoader(
        train_set, shuffle=True, batch_size=batch_size, num_workers=11)
    val_loader = DataLoader(
        val_set, shuffle=False, batch_size=batch_size, num_workers=11)
    print("Sets initiated with {} (train) and {} (val) samples.".format(
        len(train_set), len(val_set)))

    # NLL over log-probabilities; 255 marks unannotated pixels.
    objective = nn.NLLLoss(ignore_index=255).cuda()
    trainable = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=1e-4, weight_decay=1e-5)
    lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.975)

    trainer = trainers.VOSTrainer(
        model, optimizer, objective, lr_sched, train_loader, val_loader,
        use_gpu=True, workspace_dir=config["workspace_path"],
        save_name=os.path.splitext(os.path.basename(__file__))[0] + "_alpha",
        checkpoint_interval=10, print_interval=25, debug=False)
    trainer.load_checkpoint()  # resume if a checkpoint already exists
    trainer.train(num_epochs)