def train_beta(model):
    """Run the second training stage on cropped, full-resolution images.

    Trains on the DAVIS17 2017 train split (one randomly chosen object per
    sequence, random start frame) and validates on the YTVOS 'val_joakim'
    split. Resumes from the latest checkpoint if one exists.

    Args:
        model: the VOS network; only parameters with requires_grad=True are
            optimized.
    """
    print("Starting initial training (with cropped images)")
    num_epochs = 100
    batch_size = 2
    nframes = 14        # frames per training sample
    nframes_val = 32    # frames per validation sample
    size = (480, 864)   # (height, width) after resize

    # Build the per-frame transforms once; the readers close over them.
    frame_pipeline = tv.transforms.Compose([
        tv.transforms.Resize(size, interpolation=Image.BILINEAR),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    mask_pipeline = tv.transforms.Compose([
        tv.transforms.Resize(size, interpolation=Image.NEAREST),
        LabelToLongTensor(),
    ])

    def image_read(path):
        # Load an RGB frame and normalize it to ImageNet statistics.
        return frame_pipeline(Image.open(path))

    def label_read(path):
        # Missing annotation files become an all-255 label, which the loss
        # ignores (ignore_index=255 below).
        if not os.path.exists(path):
            return torch.LongTensor(1, *size).fill_(255)  # Put label that will be ignored
        return mask_pipeline(Image.open(path))

    def random_object_sampler(lst):
        # Train on one randomly picked object id per sequence.
        return [random.choice(lst)]

    def deterministic_object_sampler(lst):
        # Validation always tracks the first object for reproducibility.
        return [lst[0]]

    train_transform = dataset_loaders.JointCompose(
        [dataset_loaders.JointRandomHorizontalFlip()])
    train_set = torch.utils.data.ConcatDataset([
        DAVIS17V2(config['davis17_path'], '2017', 'train', image_read,
                  label_read, train_transform, nframes,
                  random_object_sampler, start_frame='random'),
    ])
    val_set = YTVOSV2(config['ytvos_path'], 'train', 'val_joakim',
                      'JPEGImages', image_read, label_read, None,
                      nframes_val, deterministic_object_sampler,
                      start_frame='first')

    # Uniform weights: draw 118 samples per epoch, with replacement.
    sampler = torch.utils.data.WeightedRandomSampler(
        [1] * len(train_set), 118, replacement=True)
    train_loader = DataLoader(train_set, batch_size=batch_size,
                              sampler=sampler, num_workers=11)
    val_loader = DataLoader(val_set, shuffle=False, batch_size=batch_size,
                            num_workers=11)
    print("Sets initiated with {} (train) and {} (val) samples.".format(
        len(train_set), len(val_set)))

    objective = nn.NLLLoss(ignore_index=255).cuda()
    trainable = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=1e-5, weight_decay=1e-6)
    lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, .985)

    trainer = trainers.VOSTrainer(
        model, optimizer, objective, lr_sched, train_loader, val_loader,
        use_gpu=True, workspace_dir=config['workspace_path'],
        save_name=os.path.splitext(os.path.basename(__file__))[0] + "_beta",
        checkpoint_interval=100, print_interval=25, debug=False)
    trainer.load_checkpoint()  # resume if a checkpoint already exists
    trainer.train(num_epochs)
def train_alpha(model):
    """Run the first training stage at reduced resolution.

    Trains on the union of DAVIS17 2017 train and the YTVOS 'train_joakim'
    split (one randomly chosen object per sequence, random start frame) and
    validates on the YTVOS 'val_joakim' split. Resumes from the latest
    checkpoint if one exists.

    Args:
        model: the VOS network; only parameters with requires_grad=True are
            optimized.
    """
    num_epochs = 160
    batch_size = 4
    nframes = 8         # frames per training sample
    nframes_val = 32    # frames per validation sample
    size = (240, 432)   # (height, width) after resize

    # Build the per-frame transforms once; the readers close over them.
    frame_pipeline = tv.transforms.Compose([
        tv.transforms.Resize(size, interpolation=Image.BILINEAR),
        tv.transforms.ToTensor(),
        tv.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    mask_pipeline = tv.transforms.Compose([
        tv.transforms.Resize(size, interpolation=Image.NEAREST),
        LabelToLongTensor(),
    ])

    def image_read(path):
        # Load an RGB frame and normalize it to ImageNet statistics.
        return frame_pipeline(Image.open(path))

    def label_read(path):
        # Missing annotation files become an all-255 label, which the loss
        # ignores (ignore_index=255 below).
        if not os.path.exists(path):
            return torch.LongTensor(1, *size).fill_(255)  # Put label that will be ignored
        return mask_pipeline(Image.open(path))

    def random_object_sampler(lst):
        # Train on one randomly picked object id per sequence.
        return [random.choice(lst)]

    def deterministic_object_sampler(lst):
        # Validation always tracks the first object for reproducibility.
        return [lst[0]]

    train_transform = dataset_loaders.JointCompose(
        [dataset_loaders.JointRandomHorizontalFlip()])
    train_set = torch.utils.data.ConcatDataset([
        DAVIS17V2(config["davis17_path"], "2017", "train", image_read,
                  label_read, train_transform, nframes,
                  random_object_sampler, start_frame="random"),
        YTVOSV2(config["ytvos_path"], "train", "train_joakim", "JPEGImages",
                image_read, label_read, train_transform, nframes,
                random_object_sampler, start_frame="random"),
    ])
    val_set = YTVOSV2(config["ytvos_path"], "train", "val_joakim",
                      "JPEGImages", image_read, label_read, None,
                      nframes_val, deterministic_object_sampler,
                      start_frame="first")

    train_loader = DataLoader(train_set, shuffle=True,
                              batch_size=batch_size, num_workers=11)
    val_loader = DataLoader(val_set, shuffle=False, batch_size=batch_size,
                            num_workers=11)
    print("Sets initiated with {} (train) and {} (val) samples.".format(
        len(train_set), len(val_set)))

    objective = nn.NLLLoss(ignore_index=255).cuda()
    trainable = [param for param in model.parameters() if param.requires_grad]
    optimizer = torch.optim.Adam(trainable, lr=1e-4, weight_decay=1e-5)
    lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.975)

    trainer = trainers.VOSTrainer(
        model, optimizer, objective, lr_sched, train_loader, val_loader,
        use_gpu=True, workspace_dir=config["workspace_path"],
        save_name=os.path.splitext(os.path.basename(__file__))[0] + "_alpha",
        checkpoint_interval=10, print_interval=25, debug=False)
    trainer.load_checkpoint()  # resume if a checkpoint already exists
    trainer.train(num_epochs)