def create_transforms(relax_crop, zero_crop):
    """Build the training and testing transform pipelines.

    Both pipelines crop ``image``/``gt`` around the mask, resize the crops to
    512x512, generate extreme points from ``crop_gt``, convert them to an
    image, concatenate them with ``crop_image`` and convert to tensors.  The
    training pipeline additionally starts with a random horizontal flip and a
    random scale/rotation, and perturbs the extreme points (``pert=5``); the
    testing pipeline uses ``pert=0``.

    Args:
        relax_crop: Forwarded to ``tr.CropFromMask`` as ``relax``.
        zero_crop: Forwarded to ``tr.CropFromMask`` as ``zero_pad``.

    Returns:
        A ``(train_tf, test_tf)`` tuple of ``transforms.Compose`` objects.
    """
    # Stage shared by both pipelines: crop around the mask, then resize.
    crop_and_resize = [
        tr.CropFromMask(crop_elems=('image', 'gt'),
                        relax=relax_crop,
                        zero_pad=zero_crop),
        tr.FixedResize(resolutions={
            'crop_image': (512, 512),
            'crop_gt': (512, 512),
        }),
    ]
    # Stage shared by both pipelines: assemble the network input.
    to_network_input = [
        tr.ToImage(norm_elem='extreme_points'),
        tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
        tr.ToTensor(),
    ]

    # Training: augmentation first, then the shared stages around a
    # perturbed extreme-point generator.
    train_steps = [
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
    ]
    train_steps.extend(crop_and_resize)
    train_steps.append(tr.ExtremePoints(sigma=10, pert=5, elem='crop_gt'))
    train_steps.extend(to_network_input)
    train_tf = transforms.Compose(train_steps)

    # Testing: same shared stage instances, no augmentation, no perturbation.
    test_steps = list(crop_and_resize)
    test_steps.append(tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'))
    test_steps.extend(to_network_input)
    test_tf = transforms.Compose(test_steps)

    return train_tf, test_tf
def transform_tr(self, sample):
    """Run ``sample`` through the training transform chain and return the result.

    The chain is: random horizontal flip, random scale/rotation, crop of
    ``image``/``gt`` around the mask (``relax=20``, zero padded), resize of
    both crops to 256x256, normalisation of ``crop_image`` and conversion to
    tensors.  A new ``transforms.Compose`` is built on every call.
    """
    steps = [
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
        tr.CropFromMask(crop_elems=('image', 'gt'), relax=20, zero_pad=True),
        tr.FixedResize(resolutions={
            'crop_image': (256, 256),
            'crop_gt': (256, 256),
        }),
        tr.Normalize(elems='crop_image'),
        tr.ToTensor(),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(sample)
net.to(device) # Training the network if resume_epoch != nEpochs: # Logging into Tensorboard log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname()) # writer = SummaryWriter(log_dir=log_dir) # Use the following optimizer optimizer = optim.SGD(train_params, lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd']) p['optimizer'] = str(optimizer) # Preparation of the data loaders composed_transforms_tr = transforms.Compose([ tr.RandomHorizontalFlip(), tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)), tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop), tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}), tr.ExtremePoints(sigma=10, pert=5, elem='crop_gt'), tr.ToImage(norm_elem='extreme_points'), tr.ConcatInputs(elems=('crop_image', 'extreme_points')), tr.ToTensor()]) composed_transforms_ts = transforms.Compose([ tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop), tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}), tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'), tr.ToImage(norm_elem='extreme_points'), tr.ConcatInputs(elems=('crop_image', 'extreme_points')), tr.ToTensor()]) voc_train = pascal.VOCSegmentation(split='train', transform=composed_transforms_tr)
'lr': lr / 100, 'weight_decay': wd }, { 'params': net.fuse.bias, 'lr': 2 * lr / 100 }, ], lr=lr, momentum=0.9) # Preparation of the data loaders # Define augmentation transformations as a composition composed_transforms = transforms.Compose([ tr.RandomHorizontalFlip(), tr.ScaleNRotate(rots=(-30, 30), scales=(.75, 1.25)), tr.ToTensor() ]) # Training dataset and its iterator db_train = db.DAVIS2016(train=True, db_root_dir=db_root_dir, transform=composed_transforms, seq_name=seq_name) trainloader = DataLoader(db_train, batch_size=p['trainBatch'], shuffle=True, num_workers=1) # Testing dataset and its iterator db_test = db.DAVIS2016(train=False, db_root_dir=db_root_dir,
img = cv2.resize(img, (self.inputRes[1], self.inputRes[0])) gt = cv2.resize(gt, (self.inputRes[1], self.inputRes[0]), interpolation=cv2.INTER_NEAREST) sample = {'images': img, 'gts': gt} if self.transform is not None: sample = self.transform(sample) return sample def __len__(self): return len(self.imgs) if __name__ == '__main__': composed_transforms = transforms.Compose([ tr.RandomHorizontalFlip(), tr.ScaleNRotate(rots=(-10, 10), scales=(.75, 1.25)), tr.ToTensor() ]) train_set = VOC('train', inputRes=(512, 512), transform=composed_transforms) train_loader = DataLoader(train_set, batch_size=1, num_workers=8, shuffle=True) for ii, sample_batched in enumerate(train_loader): img, mask = sample_batched break
return _img, _target def __str__(self): return 'VOC2012(split=' + str(self.split) + ')' if __name__ == '__main__': from dataloaders import custom_transforms as tr from dataloaders.utils import decode_segmap from torch.utils.data import DataLoader from torchvision import transforms import matplotlib.pyplot as plt composed_transforms_tr = transforms.Compose([ tr.RandomHorizontalFlip(), tr.ScaleNRotate(rots=(-15, 15), scales=(.75, 1.5)), tr.FixedResize(size=512), tr.ToTensor() ]) voc_train = VOCSegmentation(split='train', transform=composed_transforms_tr) dataloader = DataLoader(voc_train, batch_size=2, shuffle=True, num_workers=2) for ii, sample in enumerate(dataloader): for jj in range(sample["image"].size()[0]): img = sample['image'].numpy()
save_dir = os.path.join(Path.save_root_dir(), 'lr_' + str(base_lr) + '_wd_' + str(weight_decay)) if not os.path.exists(save_dir): os.makedirs(os.path.join(save_dir)) davis17loader = db17.DAVISLoader(year=cfg.YEAR, phase=cfg.PHASE) seq_data = davis17loader[seq_name] images = seq_data.images anno = seq_data.annotations composed_transforms = transforms.Compose([ tr.RandomHorizontalFlip(), tr.ScaleNRotate(rots=(-30, 30), scales=(.5, 1.3)), tr.ToTensor() ]) alreadyTrained = False file_name = os.path.join( save_dir, 'online_training_' + seq_name + '_object_id_' + str(1) + 'epoch_' + str(nEpochs) + '.pth') if os.path.exists(file_name): print('Training already completed! Not doing it again.') alreadyTrained = True if not alreadyTrained: if os.path.exists(os.path.join(save_dir, 'logs')) == False: os.mkdir(os.path.join(save_dir, 'logs'))
def train(self,
          first_frame,
          n_interaction,
          obj_id,
          scribbles_data,
          scribble_iter,
          subset,
          use_previous_mask=False):
    """Fine-tune ``self.net`` on the scribbled frames of one object.

    On the first interaction the parent model state is loaded; on later
    interactions the state stored in ``self.prev_models[obj_id]`` is loaded.
    Training then loops over the annotated frames until the elapsed wall-clock
    time exceeds ``self.time_budget``, after which a deep copy of the network
    state is stored back into ``self.prev_models[obj_id]``.

    Args:
        first_frame: Forwarded to ``tr.CustomScribbleInteractive``.
        n_interaction: 1-based interaction counter; also caps the batch size.
        obj_id: Object identifier; key into ``self.prev_models``.
        scribbles_data: Dict providing ``'scribbles'`` and ``'sequence'``.
        scribble_iter: Used in the previous-mask path and the log message.
        subset: Forwarded to ``db.DAVIS2017`` as ``split``.
        use_previous_mask: Forwarded to ``tr.CustomScribbleInteractive``.
    """
    grad_accum_steps = 1
    loader_workers = 4
    batch_size = min(n_interaction, self.train_batch)
    annotated_frames = interactive_utils.scribbles.annotated_frames_object(
        scribbles_data, obj_id)
    scribbles_list = scribbles_data['scribbles']
    sequence = scribbles_data['sequence']

    # A new sequence starts with object 1 / interaction 1: drop stored models.
    if obj_id == 1 and n_interaction == 1:
        self.prev_models = {}

    # Network definition
    if n_interaction == 1:
        print('Loading weights from: {}'.format(self.parent_model))
        self.net.load_state_dict(self.parent_model_state)
        self.prev_models[obj_id] = None
    else:
        print(
            'Loading weights from previous network: objId-{}_interaction-{}_scribble-{}.pth'
            .format(obj_id, n_interaction - 1, scribble_iter))
        self.net.load_state_dict(self.prev_models[obj_id])

    def params_matching(module, keyword):
        # Parameters of ``module`` whose name contains ``keyword``.
        return [
            param for name, param in module.named_parameters()
            if keyword in name
        ]

    lr = 1e-8
    wd = 0.0002
    # Group order is kept exactly as before; groups without an explicit 'lr'
    # fall back to the optimizer-level ``lr``.
    param_groups = [
        {'params': params_matching(self.net.stages, 'weight'),
         'weight_decay': wd},
        {'params': params_matching(self.net.stages, 'bias'),
         'lr': lr * 2},
        {'params': params_matching(self.net.side_prep, 'weight'),
         'weight_decay': wd},
        {'params': params_matching(self.net.side_prep, 'bias'),
         'lr': lr * 2},
        {'params': params_matching(self.net.upscale, 'weight'),
         'lr': 0},
        {'params': params_matching(self.net.upscale_, 'weight'),
         'lr': 0},
        {'params': self.net.fuse.weight,
         'lr': lr / 100,
         'weight_decay': wd},
        {'params': self.net.fuse.bias,
         'lr': 2 * lr / 100},
    ]
    optimizer = optim.SGD(param_groups, lr=lr, momentum=0.9)

    prev_mask_path = os.path.join(
        self.save_res_dir,
        'interaction-{}'.format(n_interaction - 1),
        'scribble-{}'.format(scribble_iter))
    train_transforms = transforms.Compose([
        tr.SubtractMeanImage(self.meanval),
        tr.CustomScribbleInteractive(scribbles_list,
                                     first_frame,
                                     use_previous_mask=use_previous_mask,
                                     previous_mask_path=prev_mask_path),
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-30, 30), scales=(.75, 1.25)),
        tr.ToTensor(),
    ])

    # Training dataset and its iterator
    dataset = db.DAVIS2017(split=subset,
                           transform=train_transforms,
                           custom_frames=annotated_frames,
                           seq_name=sequence,
                           obj_id=obj_id,
                           no_gt=True,
                           retname=True)
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=loader_workers)

    batches_per_epoch = len(loader)
    loss_history = []
    accumulated = 0
    started_at = timeit.default_timer()

    # Main training loop: runs until the time budget is exhausted.
    epoch = 0
    while True:
        running_loss = 0
        for batch_idx, batch in enumerate(loader):
            inputs = batch['image']
            gts = batch['scribble_gt']
            void = batch['scribble_void_pixels']

            # Forward-Backward of the mini-batch
            inputs, gts, void = Variable(inputs), Variable(gts), Variable(void)
            if self.gpu_id >= 0:
                inputs, gts, void = inputs.cuda(), gts.cuda(), void.cuda()

            outputs = self.net.forward(inputs)

            # Compute the fuse loss
            loss = class_balanced_cross_entropy_loss(outputs[-1],
                                                     gts,
                                                     size_average=False,
                                                     void_pixels=void)
            running_loss += loss.item()

            # Print stuff
            if epoch % 10 == 0:
                running_loss /= batches_per_epoch
                loss_history.append(running_loss)
                print('[Epoch: %d, numImages: %5d]' % (epoch + 1, batch_idx + 1))
                print('Loss: %f' % running_loss)

            # Backward the averaged gradient
            loss /= grad_accum_steps
            loss.backward()
            accumulated += 1

            # Update the weights once every ``grad_accum_steps`` passes
            if accumulated % grad_accum_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                accumulated = 0

        # The epoch counter advances by the batch size, not by one.
        epoch += batch_size
        elapsed = timeit.default_timer() - started_at
        if elapsed > self.time_budget:
            break

    # Save the model into dictionary
    self.prev_models[obj_id] = copy.deepcopy(self.net.state_dict())
def _object_crop_transforms(crop_size, gt_size, jitters_bound, augment=False):
    """Compose the crop / resize / normalize / to-tensor chain for object crops.

    Args:
        crop_size: Side length used for the resized ``crop_image``.
        gt_size: Side length used for the resized ``crop_gt``.
        jitters_bound: Forwarded to ``tr.CropFromMask``.
        augment: When true, prepend a random horizontal flip and a random
            scale/rotation (the training configuration).

    Returns:
        A ``transforms.Compose`` instance.
    """
    steps = []
    if augment:
        steps.append(tr.RandomHorizontalFlip())
        steps.append(tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)))
    steps.extend([
        tr.CropFromMask(crop_elems=('image', 'gt'),
                        relax=20,
                        zero_pad=True,
                        jitters_bound=jitters_bound),
        tr.FixedResize(resolutions={
            'crop_image': (crop_size, crop_size),
            'crop_gt': (gt_size, gt_size),
        }),
        tr.Normalize(elems='crop_image'),
        tr.ToTensor(),
    ])
    return transforms.Compose(steps)


def make_data_loader(args, **kwargs):
    """Create the data loaders for the dataset named by ``args.dataset``.

    Args:
        args: Namespace-like object.  ``dataset``, ``crop_size``, ``gt_size``
            and ``batch_size`` are always read; ``use_sbd`` is read for
            ``'pascal'``/``'click'`` and ``coco_part`` for ``'coco'``.
        **kwargs: Extra keyword arguments forwarded to every ``DataLoader``.

    Returns:
        ``(train_loader, val_loader, test_loader, num_classes)``.  Loaders
        that do not exist for a dataset are ``None``.

    Raises:
        NotImplementedError: If ``args.dataset`` is not one of the handled
            names.
    """
    crop_size = args.crop_size
    gt_size = args.gt_size

    if args.dataset == 'pascal' or args.dataset == 'click':
        composed_transforms_tr = _object_crop_transforms(
            crop_size, gt_size, jitters_bound=(40, 70), augment=True)
        composed_transforms_val = _object_crop_transforms(
            crop_size, gt_size, jitters_bound=(50, 51))

        train_set = pascal.VOCSegmentation(split='train',
                                           transform=composed_transforms_tr)
        if args.dataset == 'click':
            train_set.reset_target_list(args)
        val_set = pascal.VOCSegmentation(split='val',
                                         transform=composed_transforms_val)
        if args.use_sbd:
            sbd_train = sbd.SBDSegmentation(args, split=['train', 'val'])
            train_set = combine_dbs.CombineDBs([train_set, sbd_train],
                                               excluded=[val_set])

        train_loader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  **kwargs)
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                **kwargs)
        test_loader = None
        NUM_CLASSES = 2
        return train_loader, val_loader, test_loader, NUM_CLASSES

    elif args.dataset in ('grabcut', 'bekeley'):
        # Evaluation-only datasets; the dataset name doubles as the ``which``
        # selector of the evaluation dataset class (spelling kept as-is
        # because it is a runtime identifier).
        composed_transforms_val = _object_crop_transforms(
            crop_size, gt_size, jitters_bound=(50, 51))
        val_set = grab_berkeley_eval.GrabBerkely(
            which=args.dataset, transform=composed_transforms_val)
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                **kwargs)
        test_loader = None
        train_loader = None
        NUM_CLASSES = 2
        return train_loader, val_loader, test_loader, NUM_CLASSES

    elif args.dataset == 'cityscapes':
        train_set = cityscapes.CityscapesSegmentation(args, split='train')
        val_set = cityscapes.CityscapesSegmentation(args, split='val')
        test_set = cityscapes.CityscapesSegmentation(args, split='test')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  **kwargs)
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                **kwargs)
        test_loader = DataLoader(test_set,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 **kwargs)
        return train_loader, val_loader, test_loader, num_class

    elif args.dataset == 'coco':
        val_set = coco_eval.COCOSegmentation(split='val', cat=args.coco_part)
        num_class = 2
        train_loader = None
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                **kwargs)
        test_loader = None
        return train_loader, val_loader, test_loader, num_class

    else:
        raise NotImplementedError
def main(args):
    """Jointly train the STCNN segmentation network and its frame predictor.

    The segmentation encoder/decoder is trained with SGD on a BCE-with-logits
    loss over all side outputs.  The frame-prediction encoder/decoder
    (generator) and ``netD`` (discriminator) are trained adversarially with
    Adam; which of the two is updated on an iteration is controlled by the
    ``updateD`` / ``updateG`` flags, which are re-evaluated from the
    discriminator losses against ``margin`` after every iteration.

    Args:
        args: Object providing ``frame_nums`` (number of input frames).
    """
    # Select which GPU, -1 if CPU
    gpu_id = 0
    device = torch.device(
        "cuda:" + str(gpu_id) if torch.cuda.is_available() else "cpu")

    # Setting other parameters
    resume_epoch = 0  # Default is 0, change if want to resume
    nEpochs = 10  # Number of epochs for training (500.000/2079)
    batch_size = 1
    snapshot = 1  # Store a model every snapshot epochs
    pred_lr = 1e-8
    seg_lr = 1e-4
    lr_D = 1e-4
    wd = 5e-4
    beta = 0.001
    margin = 0.3
    updateD = True
    updateG = False
    num_frame = args.frame_nums

    modelName = 'STCNN_frame_' + str(num_frame)

    save_dir = Path.save_root_dir()
    os.makedirs(save_dir, exist_ok=True)
    save_model_dir = os.path.join(save_dir, modelName)
    os.makedirs(save_model_dir, exist_ok=True)

    # Network definition
    netD = Inception3(num_classes=1, aux_logits=False, transform_input=True)
    initialize_netD(
        netD,
        os.path.join(save_dir, 'FramePredModels',
                     'frame_nums_' + str(num_frame), 'NetD_epoch-90.pth'))

    seg_enc = SegEncoder()
    pred_enc = FramePredEncoder(frame_nums=num_frame)
    pred_dec = FramePredDecoder()
    j_seg_dec = JointSegDecoder()
    if resume_epoch == 0:
        initialize_model(pred_enc,
                         seg_enc,
                         pred_dec,
                         j_seg_dec,
                         save_dir,
                         num_frame=num_frame)
        net = STCNN(pred_enc, seg_enc, pred_dec, j_seg_dec)
    else:
        net = STCNN(pred_enc, seg_enc, pred_dec, j_seg_dec)
        resume_path = os.path.join(
            save_model_dir,
            modelName + '_epoch-' + str(resume_epoch - 1) + '.pth')
        print("Updating weights from: {}".format(resume_path))
        net.load_state_dict(
            torch.load(resume_path,
                       map_location=lambda storage, loc: storage))

    # Logging into Tensorboard
    log_dir = os.path.join(
        save_dir, 'JointPredSegNet_runs',
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    writer = SummaryWriter(log_dir=log_dir, comment='-parent')

    # PyTorch 0.4.0 style
    net.to(device)
    netD.to(device)

    lp_function = nn.MSELoss().to(device)
    criterion = nn.BCELoss().to(device)
    seg_criterion = nn.BCEWithLogitsLoss().to(device)

    # Use the following optimizer
    optimizer = optim.SGD([
        {
            'params': [param for name, param in net.seg_encoder.named_parameters()],
            'lr': seg_lr
        },
        {
            'params': [param for name, param in net.seg_decoder.named_parameters()],
            'lr': seg_lr
        },
    ], weight_decay=wd, momentum=0.9)

    optimizerG = optim.Adam([
        {
            'params': [param for name, param in net.pred_encoder.named_parameters()],
            'lr': pred_lr
        },
        {
            'params': [param for name, param in net.pred_decoder.named_parameters()],
            'lr': pred_lr
        },
    ], lr=pred_lr, weight_decay=wd)

    optimizerD = optim.Adam(netD.parameters(), lr=lr_D, weight_decay=wd)

    # Preparation of the data loaders
    # Define augmentation transformations as a composition
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-30, 30), scales=(0.75, 1.25))
    ])

    # Training dataset and its iterator
    db_train = db.DAVISDataset(
        inputRes=(400, 710),
        samples_list_file=os.path.join(
            Path.data_dir(), 'DAVIS16_samples_list_' + str(num_frame) + '.txt'),
        transform=composed_transforms,
        num_frame=num_frame)
    trainloader = DataLoader(db_train,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=4)

    num_img_tr = len(trainloader)
    iter_num = nEpochs * num_img_tr
    curr_iter = resume_epoch * num_img_tr

    print("Training Network")
    real_label = torch.ones(batch_size).float().to(device)
    fake_label = torch.zeros(batch_size).float().to(device)

    # Most recent generator-side loss values, cached as plain floats for
    # logging.  The generator branch is skipped while ``updateG`` is False
    # (its initial value), so these stay NaN until it has run at least once
    # instead of the log statements failing on unassigned names.
    lp_loss_val = float('nan')
    errG_val = float('nan')
    total_loss_val = float('nan')

    for epoch in range(resume_epoch, nEpochs):
        start_time = timeit.default_timer()
        for ii, sample_batched in enumerate(trainloader):
            global_step = ii + num_img_tr * epoch

            seqs, frames, gts, pred_gts = (sample_batched['images'],
                                           sample_batched['frame'],
                                           sample_batched['seg_gt'],
                                           sample_batched['pred_gt'])

            # Forward-Backward of the mini-batch
            seqs.requires_grad_()
            frames.requires_grad_()

            seqs, frames, gts, pred_gts = (seqs.to(device), frames.to(device),
                                           gts.to(device), pred_gts.to(device))

            pred_gts = F.upsample(pred_gts,
                                  size=(100, 178),
                                  mode='bilinear',
                                  align_corners=False)
            pred_gts = pred_gts.detach()

            seg_res, pred = net.forward(seqs, frames)

            # Discriminator losses on the real frame and on the detached
            # prediction (no gradient flows back into the generator here).
            D_real = netD(pred_gts)
            errD_real = criterion(D_real, real_label)
            D_fake = netD(pred.detach())
            errD_fake = criterion(D_fake, fake_label)

            # Segmentation update: the final output has weight 1, the side
            # outputs are weighted by (1 - progress) so they fade out.
            optimizer.zero_grad()
            seg_loss = seg_criterion(seg_res[-1], gts)
            for i in reversed(range(len(seg_res) - 1)):
                seg_loss = seg_loss + (
                    1 - curr_iter / iter_num) * seg_criterion(seg_res[i], gts)
            seg_loss.backward()
            optimizer.step()

            curr_iter += 1

            if updateD:
                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                netD.zero_grad()
                d_loss = errD_fake + errD_real
                d_loss.backward()
                optimizerD.step()

            if updateG:
                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################
                optimizerG.zero_grad()
                D_fake = netD(pred)
                errG = criterion(D_fake, real_label)
                lp_loss = lp_function(pred, pred_gts)
                total_loss = lp_loss + beta * errG
                total_loss.backward()
                optimizerG.step()

                lp_loss_val = lp_loss.item()
                errG_val = errG.item()
                total_loss_val = total_loss.item()

            # Re-evaluate which of the two adversarial players to update.
            if (errD_fake.data < margin).all() or (errD_real.data < margin).all():
                updateD = False
            if (errD_fake.data > (1. - margin)).all() or (
                    errD_real.data > (1. - margin)).all():
                updateG = False
            if not updateD and not updateG:
                updateD = True
                updateG = True

            if global_step % 5 == 4:
                print(
                    "Iters: [%2d] time: %4.4f, lp_loss: %.8f, G_loss: %.8f,seg_loss: %.8f"
                    % (global_step, timeit.default_timer() - start_time,
                       lp_loss_val, errG_val, seg_loss.item()))
                print('updateD:', updateD, 'updateG:', updateG)

            if global_step % 10 == 9:
                writer.add_scalar('data/loss_iter', total_loss_val, global_step)
                writer.add_scalar('data/lp_loss_iter', lp_loss_val, global_step)
                writer.add_scalar('data/G_loss_iter', errG_val, global_step)
                writer.add_scalar('data/seg_loss_iter', seg_loss.item(),
                                  global_step)

            if global_step % 500 == 0:
                # Dump a visual sample of the first element of the batch.
                seg_pred = seg_res[-1][0, :, :, :].data.cpu().numpy()
                seg_pred = 1 / (1 + np.exp(-seg_pred))
                gt_sample = gts[0, :, :, :].data.cpu().numpy().transpose(
                    [1, 2, 0]) * 255
                seg_pred = seg_pred.transpose([1, 2, 0]) * 255
                frame_sample = frames[0, :, :, :].data.cpu().numpy().transpose(
                    [1, 2, 0])
                frame_sample = inverse_transform(frame_sample) * 255
                gt_sample3 = np.concatenate([gt_sample, gt_sample, gt_sample],
                                            axis=2)
                seg_pred3 = np.concatenate([seg_pred, seg_pred, seg_pred],
                                           axis=2)
                samples1 = np.concatenate(
                    (seg_pred3, gt_sample3, frame_sample), axis=0)

                pred_sample = pred[0, :, :, :].data.cpu().numpy().transpose(
                    [1, 2, 0])
                frame_sample = pred_gts[0, :, :, :].data.cpu().numpy().transpose(
                    [1, 2, 0])
                samples2 = np.concatenate((pred_sample, frame_sample), axis=0)
                samples2 = inverse_transform(samples2) * 255

                print("Saving sample ...")
                running_res_dir = os.path.join(save_dir, modelName + '_results')
                os.makedirs(running_res_dir, exist_ok=True)
                imageio.imwrite(
                    os.path.join(running_res_dir,
                                 "train_%s_s.png" % global_step),
                    np.uint8(samples1))
                imageio.imwrite(
                    os.path.join(running_res_dir,
                                 "train_%s_p.png" % global_step),
                    np.uint8(samples2))

        # Print stuff
        print('[Epoch: %d, numImages: %5d]' % (epoch, (ii + 1) * batch_size))
        stop_time = timeit.default_timer()
        print("Execution time: " + str(stop_time - start_time))

        # Save the model
        if (epoch % snapshot) == snapshot - 1 and epoch != 0:
            torch.save(
                net.state_dict(),
                os.path.join(save_model_dir,
                             modelName + '_epoch-' + str(epoch) + '.pth'))

    writer.close()
def train(epochs_wo_avegrad):
    """Fine-tune the parent OSVOS model on one sequence, then run it on that sequence.

    The sequence name is taken from the ``SEQ_NAME`` environment variable
    (default ``'blackswan'``).  The parent checkpoint is loaded from the save
    root, trained for ``epochs_wo_avegrad * nAveGrad`` epochs with gradient
    accumulation over ``nAveGrad`` iterations, a checkpoint is written at the
    snapshot epoch, and finally the sigmoid of the last network output is
    saved as a PNG for every test frame.

    Args:
        epochs_wo_avegrad: Number of epochs before multiplying by the
            gradient-averaging factor.
    """
    # Setting of parameters
    if 'SEQ_NAME' not in os.environ.keys():
        seq_name = 'blackswan'
    else:
        seq_name = str(os.environ['SEQ_NAME'])

    db_root_dir = Path.db_root_dir()
    save_dir = Path.save_root_dir()
    os.makedirs(save_dir, exist_ok=True)

    vis_net = 0  # Visualize the network?
    vis_res = 0  # Visualize the results?
    nAveGrad = 5  # Average the gradient every nAveGrad iterations
    nEpochs = epochs_wo_avegrad * nAveGrad  # Number of epochs for training
    snapshot = nEpochs  # Store a model every snapshot epochs
    parentEpoch = 240

    # Parameters in p are used for the name of the model
    p = {
        'trainBatch': 1,  # Number of Images in each mini-batch
    }
    seed = 0

    parentModelName = 'parent'
    # Select which GPU, -1 if CPU
    gpu_id = 0
    device = torch.device(
        "cuda:" + str(gpu_id) if torch.cuda.is_available() else "cpu")

    # Network definition
    net = vo.OSVOS(pretrained=0)
    net.load_state_dict(
        torch.load(os.path.join(
            save_dir,
            parentModelName + '_epoch-' + str(parentEpoch - 1) + '.pth'),
                   map_location=lambda storage, loc: storage))

    # Logging into Tensorboard
    log_dir = os.path.join(
        save_dir, 'runs',
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname() +
        '-' + seq_name)
    writer = SummaryWriter(logdir=log_dir)

    net.to(device)  # PyTorch 0.4.0 style

    # Visualize the network
    if vis_net:
        x = torch.randn(1, 3, 480, 854)
        x.requires_grad_()
        x = x.to(device)
        y = net.forward(x)
        g = viz.make_dot(y, net.state_dict())
        g.view()

    # Use the following optimizer
    lr = 1e-8
    wd = 0.0002
    optimizer = optim.SGD([
        {
            'params': [
                pr[1] for pr in net.stages.named_parameters()
                if 'weight' in pr[0]
            ],
            'weight_decay': wd
        },
        {
            'params': [
                pr[1] for pr in net.stages.named_parameters()
                if 'bias' in pr[0]
            ],
            'lr': lr * 2
        },
        {
            'params': [
                pr[1] for pr in net.side_prep.named_parameters()
                if 'weight' in pr[0]
            ],
            'weight_decay': wd
        },
        {
            'params': [
                pr[1] for pr in net.side_prep.named_parameters()
                if 'bias' in pr[0]
            ],
            'lr': lr * 2
        },
        {
            'params': [
                pr[1] for pr in net.upscale.named_parameters()
                if 'weight' in pr[0]
            ],
            'lr': 0
        },
        {
            'params': [
                pr[1] for pr in net.upscale_.named_parameters()
                if 'weight' in pr[0]
            ],
            'lr': 0
        },
        {
            'params': net.fuse.weight,
            'lr': lr / 100,
            'weight_decay': wd
        },
        {
            'params': net.fuse.bias,
            'lr': 2 * lr / 100
        },
    ], lr=lr, momentum=0.9)

    # Preparation of the data loaders
    # Define augmentation transformations as a composition
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-30, 30), scales=(.75, 1.25)),
        tr.ToTensor()
    ])
    # Training dataset and its iterator
    db_train = db.DAVIS2016(train=True,
                            db_root_dir=db_root_dir,
                            transform=composed_transforms,
                            seq_name=seq_name)
    trainloader = DataLoader(db_train,
                             batch_size=p['trainBatch'],
                             shuffle=True,
                             num_workers=1)

    # Testing dataset and its iterator
    db_test = db.DAVIS2016(train=False,
                           db_root_dir=db_root_dir,
                           transform=tr.ToTensor(),
                           seq_name=seq_name)
    testloader = DataLoader(db_test,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    num_img_tr = len(trainloader)
    num_img_ts = len(testloader)
    loss_tr = []
    aveGrad = 0

    # Loss is reported on the last epoch of each 1/20th of training.  Clamp
    # the interval to at least 1 so runs shorter than 20 epochs do not divide
    # by zero; for nEpochs >= 20 the schedule is unchanged.
    log_every = max(nEpochs // 20, 1)

    print("Start of Online Training, sequence: " + seq_name)
    start_time = timeit.default_timer()
    # Main Training and Testing Loop
    for epoch in range(0, nEpochs):
        # One training epoch
        running_loss_tr = 0
        np.random.seed(seed + epoch)
        for ii, sample_batched in enumerate(trainloader):

            inputs, gts = sample_batched['image'], sample_batched['gt']

            # Forward-Backward of the mini-batch
            inputs.requires_grad_()
            inputs, gts = inputs.to(device), gts.to(device)

            outputs = net.forward(inputs)

            # Compute the fuse loss
            loss = class_balanced_cross_entropy_loss(outputs[-1],
                                                     gts,
                                                     size_average=False)
            running_loss_tr += loss.item()  # PyTorch 0.4.0 style

            # Print stuff
            if epoch % log_every == log_every - 1:
                running_loss_tr /= num_img_tr
                loss_tr.append(running_loss_tr)

                print('[Epoch: %d, numImages: %5d]' % (epoch + 1, ii + 1))
                print('Loss: %f' % running_loss_tr)
                writer.add_scalar('data/total_loss_epoch', running_loss_tr,
                                  epoch)

            # Backward the averaged gradient
            loss /= nAveGrad
            loss.backward()
            aveGrad += 1

            # Update the weights once in nAveGrad forward passes
            if aveGrad % nAveGrad == 0:
                writer.add_scalar('data/total_loss_iter', loss.item(),
                                  ii + num_img_tr * epoch)
                optimizer.step()
                optimizer.zero_grad()
                aveGrad = 0

        # Save the model
        if (epoch % snapshot) == snapshot - 1 and epoch != 0:
            torch.save(
                net.state_dict(),
                os.path.join(save_dir,
                             seq_name + '_epoch-' + str(epoch) + '.pth'))

    stop_time = timeit.default_timer()
    print('Online training time: ' + str(stop_time - start_time))

    # Testing Phase
    if vis_res:
        import matplotlib.pyplot as plt
        plt.close("all")
        plt.ion()
        f, ax_arr = plt.subplots(1, 3)

    save_dir_res = os.path.join(save_dir, 'Results', seq_name)
    os.makedirs(save_dir_res, exist_ok=True)

    print('Testing Network')
    with torch.no_grad():  # PyTorch 0.4.0 style
        # Main Testing Loop
        for ii, sample_batched in enumerate(testloader):

            img, gt, fname = sample_batched['image'], sample_batched[
                'gt'], sample_batched['fname']

            # Forward of the mini-batch
            inputs, gts = img.to(device), gt.to(device)

            outputs = net.forward(inputs)

            for jj in range(int(inputs.size()[0])):
                pred = np.transpose(
                    outputs[-1].cpu().data.numpy()[jj, :, :, :], (1, 2, 0))
                pred = 1 / (1 + np.exp(-pred))
                pred = np.squeeze(pred)

                # Save the result, attention to the index jj
                sm.imsave(
                    os.path.join(save_dir_res,
                                 os.path.basename(fname[jj]) + '.png'), pred)

                if vis_res:
                    img_ = np.transpose(img.numpy()[jj, :, :, :], (1, 2, 0))
                    gt_ = np.transpose(gt.numpy()[jj, :, :, :], (1, 2, 0))
                    # Squeeze the per-sample array (previously the whole
                    # batch tensor ``gt`` was squeezed here, discarding the
                    # transposed sample selected on the line above).
                    gt_ = np.squeeze(gt_)
                    # Plot the particular example
                    ax_arr[0].cla()
                    ax_arr[1].cla()
                    ax_arr[2].cla()
                    ax_arr[0].set_title('Input Image')
                    ax_arr[1].set_title('Ground Truth')
                    ax_arr[2].set_title('Detection')
                    ax_arr[0].imshow(im_normalize(img_))
                    ax_arr[1].imshow(gt_)
                    ax_arr[2].imshow(im_normalize(pred))
                    plt.pause(0.001)

    writer.close()
def train(self,
          first_frame,
          n_interaction,
          obj_id,
          scribbles_data,
          scribble_iter,
          subset,
          use_previous_mask=False):
    """Train ``self.net`` on the scribbled frames of one object with Adam.

    No state dict is loaded here: training continues from whatever weights
    ``self.net`` currently holds.  The loop runs over the annotated frames
    until the elapsed wall-clock time exceeds ``self.time_budget``.  After
    every step the detached masks and aggregated features are appended to
    lists that are fed back into the next forward pass.  When training stops,
    a deep copy of the network state is stored in
    ``self.prev_models[obj_id]``.

    Args:
        first_frame: Forwarded to ``tr.CustomScribbleInteractive``.
        n_interaction: 1-based interaction counter; also caps the batch size.
        obj_id: Object identifier; key into ``self.prev_models``.
        scribbles_data: Dict providing ``'scribbles'`` and ``'sequence'``.
        scribble_iter: Used in the previous-mask path.
        subset: Forwarded to ``db_scribblenet.DAVIS2017`` as ``split``.
        use_previous_mask: Forwarded to ``tr.CustomScribbleInteractive``.
    """
    num_workers = 4
    train_batch = min(n_interaction, self.train_batch)

    frames_list = interactive_utils.scribbles.annotated_frames_object(
        scribbles_data, obj_id)
    scribbles_list = scribbles_data['scribbles']
    seq_name = scribbles_data['sequence']

    # A new sequence starts with object 1 / interaction 1: drop stored models.
    if obj_id == 1 and n_interaction == 1:
        self.prev_models = {}

    lr = 1e-5
    # Adam has no ``momentum`` keyword (passing it raises TypeError); its
    # first-moment coefficient is betas[0], whose default is already 0.9.
    optimizer = optim.Adam(
        (param for param in self.net.parameters() if param.requires_grad),
        lr=lr)

    prev_mask_path = os.path.join(
        self.save_res_dir,
        'interaction-{}'.format(n_interaction - 1),
        'scribble-{}'.format(scribble_iter))
    composed_transforms_tr = transforms.Compose([
        tr.CenterCrop((480, 832)),
        tr.SubtractMeanImage(self.meanval),
        tr.CustomScribbleInteractive(scribbles_list,
                                     first_frame,
                                     use_previous_mask=use_previous_mask,
                                     previous_mask_path=prev_mask_path),
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-30, 30), scales=(.75, 1.25)),
        tr.ToTensor(),
    ])

    # Training dataset and its iterator
    db_train = db_scribblenet.DAVIS2017(split=subset,
                                        transform=composed_transforms_tr,
                                        custom_frames=frames_list,
                                        seq_name=seq_name,
                                        obj_id=obj_id,
                                        no_gt=True,
                                        retname=True)
    trainloader = DataLoader(db_train,
                             batch_size=train_batch,
                             shuffle=True,
                             num_workers=num_workers)

    num_img_tr = len(trainloader)
    loss_tr = []

    # List of all previous masks and aggregated features
    prev_masks = []
    prev_aggs = []

    start_time = timeit.default_timer()
    # Main training loop: runs until the time budget is exhausted.
    epoch = 0
    while True:
        # One training epoch
        running_loss_tr = 0
        for ii, sample_batched in enumerate(trainloader):
            optimizer.zero_grad()

            # Parse from dataset loader
            inputs = sample_batched['images'].cuda()
            gts = sample_batched['scribble_gt'].cuda()
            scribbles = sample_batched['scribble_raw'].cuda()
            scribble_idx = sample_batched['scribble_idx'].cuda()

            # Forward pass, conditioned on the history of earlier rounds
            masks, agg = self.net.forward(inputs, scribbles, scribble_idx,
                                          prev_masks, prev_aggs)

            # Compute the fuse loss
            loss = class_balanced_cross_entropy_loss(masks, gts, scribble_idx)
            running_loss_tr += loss.item()

            # Print stuff
            if epoch % 10 == 0:
                running_loss_tr /= num_img_tr
                loss_tr.append(running_loss_tr)

                print('[Epoch: %d, numImages: %5d]' % (epoch + 1, ii + 1))
                print('Loss: %f' % running_loss_tr)

            # Backward and parameter update
            loss.backward()
            optimizer.step()

            # Update the current round data (detached so the stored history
            # does not keep earlier autograd graphs alive)
            prev_masks.append(masks.detach())
            prev_aggs.append(agg.detach())

        # The epoch counter advances by the batch size, not by one.
        epoch += train_batch
        stop_time = timeit.default_timer()
        if stop_time - start_time > self.time_budget:
            break

    # Save the model into dictionary
    self.prev_models[obj_id] = copy.deepcopy(self.net.state_dict())