    pin_memory=True)

sobel_compute = SobelComputer()

# Learning rate decay scheduling
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, para['steps'], para['gamma'])

saver = ModelSaver(long_id)
report_interval = 50
save_im_interval = 800

total_epoch = int(para['iterations'] / len(train_loader) + 0.5)
print('Actual training epoch: ', total_epoch)

train_integrator = Integrator(logger)
train_integrator.add_hook(iou_hooks_to_be_used)
total_iter = 0
last_time = 0

for e in range(total_epoch):
    np.random.seed()  # reset seed
    epoch_start_time = time.time()

    # Train loop
    model = model.train()
    for im, seg, gt in train_loader:
        im, seg, gt = im.cuda(), seg.cuda(), gt.cuda()

        total_iter += 1
        if total_iter % 5000 == 0:
            saver.save_model(model, total_iter)
class FusionModel:
    def __init__(self, para, logger=None, save_path=None, local_rank=0, world_size=1, distributed=True):
        self.para = para
        self.local_rank = local_rank

        if distributed:
            self.net = nn.parallel.DistributedDataParallel(
                FusionNet().cuda(),
                device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False)
        else:
            self.net = nn.DataParallel(
                FusionNet().cuda(),
                device_ids=[local_rank], output_device=local_rank)
        self.prop_net = AttentionReadNetwork().eval().cuda()

        # Setup logger when local_rank=0
        self.logger = logger
        self.save_path = save_path
        if logger is not None:
            self.last_time = time.time()
        self.train_integrator = Integrator(self.logger, distributed=distributed,
                                           local_rank=local_rank, world_size=world_size)
        self.train_integrator.add_hook(iou_hooks)
        self.val_integrator = Integrator(self.logger, distributed=distributed,
                                         local_rank=local_rank, world_size=world_size)
        self.loss_computer = LossComputer(para)

        self.train()
        self.optimizer = optim.Adam(filter(
            lambda p: p.requires_grad, self.net.parameters()), lr=para['lr'], weight_decay=1e-7)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, para['steps'], para['gamma'])

        # Logging info
        self.report_interval = 100
        self.save_im_interval = 500
        self.save_model_interval = 5000
        if para['debug']:
            self.report_interval = self.save_im_interval = 1

    def do_pass(self, data, it=0):
        # No need to store the gradient outside training
        torch.set_grad_enabled(self._is_train)

        for k, v in data.items():
            if type(v) != list and type(v) != dict and type(v) != int:
                data[k] = v.cuda(non_blocking=True)

        # See fusion_dataset.py for variable definitions
        im = data['rgb']
        seg1 = data['seg1']
        seg2 = data['seg2']
        src2_ref = data['src2_ref']
        src2_ref_gt = data['src2_ref_gt']
        seg12 = data['seg12']
        seg22 = data['seg22']
        src2_ref2 = data['src2_ref2']
        src2_ref_gt2 = data['src2_ref_gt2']
        src2_ref_im = data['src2_ref_im']
        selector = data['selector']
        dist = data['dist']

        out = {}

        # Get kernelized memory
        with torch.no_grad():
            attn1, attn2 = self.prop_net(src2_ref_im, src2_ref, src2_ref_gt,
                                         src2_ref2, src2_ref_gt2, im)

        prob1 = torch.sigmoid(self.net(im, seg1, seg2, attn1, dist))
        prob2 = torch.sigmoid(self.net(im, seg12, seg22, attn2, dist))
        prob = torch.cat([prob1, prob2], 1) * selector.unsqueeze(2).unsqueeze(2)
        logits, prob = aggregate_wbg_channel(prob, True)

        out['logits'] = logits
        out['mask'] = prob
        out['attn1'] = attn1
        out['attn2'] = attn2

        if self._do_log or self._is_train:
            losses = self.loss_computer.compute({**data, **out}, it)

            # Logging
            if self._do_log:
                self.integrator.add_dict(losses)
                if self._is_train:
                    if it % self.save_im_interval == 0 and it != 0:
                        if self.logger is not None:
                            images = {**data, **out}
                            size = (320, 320)
                            self.logger.log_cv2('train/pairs', pool_fusion(images, size=size), it)
                else:
                    # Validation save
                    if data['val_iter'] % 10 == 0:
                        if self.logger is not None:
                            images = {**data, **out}
                            size = (320, 320)
                            self.logger.log_cv2('val/pairs', pool_fusion(images, size=size), it)

        if self._is_train:
            if it % self.report_interval == 0 and it != 0:
                if self.logger is not None:
                    self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it)
                    self.logger.log_metrics('train', 'time',
                                            (time.time() - self.last_time) / self.report_interval, it)
                self.last_time = time.time()
                self.train_integrator.finalize('train', it)
                self.train_integrator.reset_except_hooks()

            if it % self.save_model_interval == 0 and it != 0:
                if self.logger is not None:
                    self.save(it)

            # Backward pass
            self.optimizer.zero_grad(set_to_none=True)
            losses['total_loss'].backward()
            self.optimizer.step()

            self.scheduler.step()

    def save(self, it):
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        model_path = self.save_path + ('_%s.pth' % it)
        torch.save(self.net.module.state_dict(), model_path)
        print('Model saved to %s.' % model_path)

        self.save_checkpoint(it)

    def save_checkpoint(self, it):
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        checkpoint_path = self.save_path + '_checkpoint.pth'
        checkpoint = {
            'it': it,
            'network': self.net.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()}
        torch.save(checkpoint, checkpoint_path)
        print('Checkpoint saved to %s.' % checkpoint_path)

    def load_model(self, path):
        map_location = 'cuda:%d' % self.local_rank
        checkpoint = torch.load(path, map_location={'cuda:0': map_location})

        it = checkpoint['it']
        network = checkpoint['network']
        optimizer = checkpoint['optimizer']
        scheduler = checkpoint['scheduler']

        self.net.module.load_state_dict(network)
        self.optimizer.load_state_dict(optimizer)
        self.scheduler.load_state_dict(scheduler)

        print('Model loaded.')
        return it

    def load_network(self, path):
        map_location = 'cuda:%d' % self.local_rank
        self.net.module.load_state_dict(
            torch.load(path, map_location={'cuda:0': map_location}))
        print('Network weight loaded:', path)

    def load_prop(self, path):
        map_location = 'cuda:%d' % self.local_rank
        self.prop_net.load_state_dict(torch.load(
            path, map_location={'cuda:0': map_location}), strict=False)
        print('Propagation network weight loaded:', path)

    def finalize_val(self, it):
        self.val_integrator.finalize('val', it)
        self.val_integrator.reset_except_hooks()

    def train(self):
        self._is_train = True
        self._do_log = True
        self.integrator = self.train_integrator
        # Also skip BN
        self.net.eval()
        self.prop_net.eval()
        return self

    def val(self):
        self._is_train = False
        self.integrator = self.val_integrator
        self._do_log = True
        self.net.eval()
        self.prop_net.eval()
        return self

    def test(self):
        self._is_train = False
        self._do_log = False
        self.net.eval()
        self.prop_net.eval()
        return self
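# ----------------------------------------------------------------------------
# Minimal usage sketch for FusionModel (illustration only; not part of the
# training code above). The `para` keys mirror the ones read in __init__
# ('lr', 'steps', 'gamma', 'debug'); the concrete values, the `train_loader`
# argument, and the helper name `_example_fusion_training` are assumptions.
def _example_fusion_training(train_loader):
    para = {'lr': 1e-4, 'steps': [30000], 'gamma': 0.1, 'debug': False}  # assumed values
    # Single-process, non-distributed construction; logging stays disabled
    # because no logger/save_path is passed (as on non-rank-0 processes).
    model = FusionModel(para, logger=None, save_path=None,
                        local_rank=0, world_size=1, distributed=False)
    model.train()
    for it, data in enumerate(train_loader):
        # Each `data` dict must provide the keys unpacked in do_pass
        # ('rgb', 'seg1', 'seg2', 'src2_ref', ..., 'selector', 'dist').
        model.do_pass(data, it)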
class S2MModel:
    def __init__(self, para, logger=None, save_path=None, local_rank=0, world_size=1):
        self.para = para
        self.local_rank = local_rank

        self.S2M = nn.parallel.DistributedDataParallel(
            nn.SyncBatchNorm.convert_sync_batchnorm(
                deeplabv3plus_resnet50(num_classes=1, output_stride=16,
                                       pretrained_backbone=False)).cuda(),
            device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False)

        # Setup logger when local_rank=0
        self.logger = logger
        self.save_path = save_path
        if logger is not None:
            self.last_time = time.time()
        self.train_integrator = Integrator(self.logger, distributed=True,
                                           local_rank=local_rank, world_size=world_size)
        self.train_integrator.add_hook(iou_hooks_to_be_used)
        self.loss_computer = LossComputer(para)

        self.train()
        self.optimizer = optim.Adam(filter(
            lambda p: p.requires_grad, self.S2M.parameters()), lr=para['lr'], weight_decay=1e-7)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, para['steps'], para['gamma'])

        # Logging info
        self.report_interval = 50
        self.save_im_interval = 800
        self.save_model_interval = 20000
        if para['debug']:
            self.report_interval = self.save_im_interval = 1

    def do_pass(self, data, it=0):
        # No need to store the gradient outside training
        torch.set_grad_enabled(self._is_train)

        for k, v in data.items():
            if type(v) != list and type(v) != dict and type(v) != int:
                data[k] = v.cuda(non_blocking=True)

        out = {}
        Fs = data['rgb']
        Ss = data['seg']
        Rs = data['srb']
        Ms = data['gt']

        inputs = torch.cat([Fs, Ss, Rs], 1)
        prob = torch.sigmoid(self.S2M(inputs))
        logits, mask = aggregate(prob)

        out['logits'] = logits
        out['mask'] = mask

        if self._do_log or self._is_train:
            losses = self.loss_computer.compute({**data, **out}, it)

            # Logging
            if self._do_log:
                self.integrator.add_dict(losses)
                if self._is_train:
                    if it % self.save_im_interval == 0 and it != 0:
                        if self.logger is not None:
                            images = {**data, **out}
                            size = (384, 384)
                            self.logger.log_cv2('train/pairs', pool_pairs(images, size=size), it)

        if self._is_train:
            if it % self.report_interval == 0 and it != 0:
                if self.logger is not None:
                    self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it)
                    self.logger.log_metrics('train', 'time',
                                            (time.time() - self.last_time) / self.report_interval, it)
                self.last_time = time.time()
                self.train_integrator.finalize('train', it)
                self.train_integrator.reset_except_hooks()

            if it % self.save_model_interval == 0 and it != 0:
                if self.logger is not None:
                    self.save(it)

            # Backward pass
            self.optimizer.zero_grad()
            losses['total_loss'].backward()
            self.optimizer.step()

            self.scheduler.step()

    def save(self, it):
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        model_path = self.save_path + ('_%s.pth' % it)
        torch.save(self.S2M.module.state_dict(), model_path)
        print('Model saved to %s.' % model_path)

        self.save_checkpoint(it)

    def save_checkpoint(self, it):
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        checkpoint_path = self.save_path + '_checkpoint.pth'
        checkpoint = {
            'it': it,
            'network': self.S2M.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()}
        torch.save(checkpoint, checkpoint_path)
        print('Checkpoint saved to %s.' % checkpoint_path)

    def load_model(self, path):
        map_location = 'cuda:%d' % self.local_rank
        checkpoint = torch.load(path, map_location={'cuda:0': map_location})

        it = checkpoint['it']
        network = checkpoint['network']
        optimizer = checkpoint['optimizer']
        scheduler = checkpoint['scheduler']

        self.S2M.module.load_state_dict(network)
        self.optimizer.load_state_dict(optimizer)
        self.scheduler.load_state_dict(scheduler)

        print('Model loaded.')
        return it

    def load_network(self, path):
        map_location = 'cuda:%d' % self.local_rank
        self.S2M.module.load_state_dict(
            torch.load(path, map_location={'cuda:0': map_location}))
        print('Network weight loaded:', path)

    def load_deeplab(self, path):
        map_location = 'cuda:%d' % self.local_rank

        cur_dict = self.S2M.module.state_dict()
        src_dict = torch.load(path, map_location={'cuda:0': map_location})['model_state']

        for k in list(src_dict.keys()):
            if type(src_dict[k]) is not int:
                if src_dict[k].shape != cur_dict[k].shape:
                    print('Reloading: ', k)
                    if 'bias' in k:
                        # Resetting the class prob bias
                        src_dict[k] = torch.zeros_like((src_dict[k][0:1]))
                    elif src_dict[k].shape[1] != 3:
                        # Resetting the class prob weight
                        src_dict[k] = torch.zeros_like((src_dict[k][0:1]))
                        nn.init.orthogonal_(src_dict[k])
                    else:
                        # Adding the mask and scribbles channels
                        pads = torch.zeros((64, 3, 7, 7), device=src_dict[k].device)
                        nn.init.orthogonal_(pads)
                        src_dict[k] = torch.cat([src_dict[k], pads], 1)

        self.S2M.module.load_state_dict(src_dict)
        print('Network weight loaded:', path)

    def train(self):
        self._is_train = True
        self._do_log = True
        self.integrator = self.train_integrator
        self.S2M.train()
        return self

    def val(self):
        self._is_train = False
        self._do_log = True
        self.S2M.eval()
        return self

    def test(self):
        self._is_train = False
        self._do_log = False
        self.S2M.eval()
        return self
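# ----------------------------------------------------------------------------
# Shape sketch for the S2M input layout (illustration only). The 1+2 channel
# split for mask and scribbles is inferred from load_deeplab, which pads the
# backbone's first convolution from 3 to 6 input channels; the tensor names,
# batch size, and 384x384 resolution below are assumptions.
def _example_s2m_input(batch_size=2):
    import torch
    rgb = torch.zeros(batch_size, 3, 384, 384)  # image
    seg = torch.zeros(batch_size, 1, 384, 384)  # current mask estimate
    srb = torch.zeros(batch_size, 2, 384, 384)  # positive/negative scribble maps
    # Mirrors `inputs = torch.cat([Fs, Ss, Rs], 1)` in do_pass: 6 channels total.
    return torch.cat([rgb, seg, srb], 1)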
class PropagationModel:
    def __init__(self, para, logger=None, save_path=None, local_rank=0, world_size=1):
        self.para = para
        self.single_object = para['single_object']
        self.local_rank = local_rank

        self.PNet = nn.parallel.DistributedDataParallel(
            PropagationNetwork(self.single_object).cuda(),
            device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False)

        # Setup logger when local_rank=0
        self.logger = logger
        self.save_path = save_path
        if logger is not None:
            self.last_time = time.time()
        self.train_integrator = Integrator(self.logger, distributed=True,
                                           local_rank=local_rank, world_size=world_size)
        if self.single_object:
            self.train_integrator.add_hook(iou_hooks_so)
        else:
            self.train_integrator.add_hook(iou_hooks_mo)
        self.loss_computer = LossComputer(para)

        self.train()
        self.optimizer = optim.Adam(filter(
            lambda p: p.requires_grad, self.PNet.parameters()), lr=para['lr'], weight_decay=1e-7)
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, para['steps'], para['gamma'])

        # Logging info
        self.report_interval = 100
        self.save_im_interval = 800
        self.save_model_interval = 50000
        if para['debug']:
            self.report_interval = self.save_im_interval = 1

    def do_pass(self, data, it=0):
        # No need to store the gradient outside training
        torch.set_grad_enabled(self._is_train)

        for k, v in data.items():
            if type(v) != list and type(v) != dict and type(v) != int:
                data[k] = v.cuda(non_blocking=True)

        out = {}
        Fs = data['rgb']
        Ms = data['gt']

        if self.single_object:
            key_k, key_v = self.PNet(Fs[:, 0], Ms[:, 0])

            prev_logits, prev_mask = self.PNet(Fs[:, 1], key_k, key_v)
            prev_k, prev_v = self.PNet(Fs[:, 1], prev_mask)

            keys = torch.cat([key_k, prev_k], 2)
            values = torch.cat([key_v, prev_v], 2)
            this_logits, this_mask = self.PNet(Fs[:, 2], keys, values)

            out['mask_1'] = prev_mask
            out['mask_2'] = this_mask
            out['logits_1'] = prev_logits
            out['logits_2'] = this_logits
        else:
            sec_Ms = data['sec_gt']
            selector = data['selector']

            key_k1, key_v1 = self.PNet(Fs[:, 0], Ms[:, 0], sec_Ms[:, 0])
            key_k2, key_v2 = self.PNet(Fs[:, 0], sec_Ms[:, 0], Ms[:, 0])
            key_k = torch.stack([key_k1, key_k2], 1)
            key_v = torch.stack([key_v1, key_v2], 1)

            prev_logits, prev_mask = self.PNet(Fs[:, 1], key_k, key_v, selector)

            prev_k1, prev_v1 = self.PNet(Fs[:, 1], prev_mask[:, 0:1], prev_mask[:, 1:2])
            prev_k2, prev_v2 = self.PNet(Fs[:, 1], prev_mask[:, 1:2], prev_mask[:, 0:1])
            prev_k = torch.stack([prev_k1, prev_k2], 1)
            prev_v = torch.stack([prev_v1, prev_v2], 1)

            keys = torch.cat([key_k, prev_k], 3)
            values = torch.cat([key_v, prev_v], 3)
            this_logits, this_mask = self.PNet(Fs[:, 2], keys, values, selector)

            out['mask_1'] = prev_mask[:, 0:1]
            out['mask_2'] = this_mask[:, 0:1]
            out['sec_mask_1'] = prev_mask[:, 1:2]
            out['sec_mask_2'] = this_mask[:, 1:2]
            out['logits_1'] = prev_logits
            out['logits_2'] = this_logits

        if self._do_log or self._is_train:
            losses = self.loss_computer.compute({**data, **out}, it)

            # Logging
            if self._do_log:
                self.integrator.add_dict(losses)
                if self._is_train:
                    if it % self.save_im_interval == 0 and it != 0:
                        if self.logger is not None:
                            images = {**data, **out}
                            size = (384, 384)
                            self.logger.log_cv2('train/pairs',
                                                pool_pairs(images, size, self.single_object), it)

        if self._is_train:
            if it % self.report_interval == 0 and it != 0:
                if self.logger is not None:
                    self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it)
                    self.logger.log_metrics('train', 'time',
                                            (time.time() - self.last_time) / self.report_interval, it)
                self.last_time = time.time()
                self.train_integrator.finalize('train', it)
                self.train_integrator.reset_except_hooks()

            if it % self.save_model_interval == 0 and it != 0:
                if self.logger is not None:
                    self.save(it)

            # Backward pass
            for param_group in self.optimizer.param_groups:
                for p in param_group['params']:
                    p.grad = None
            losses['total_loss'].backward()
            self.optimizer.step()

            self.scheduler.step()

    def save(self, it):
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        model_path = self.save_path + ('_%s.pth' % it)
        torch.save(self.PNet.module.state_dict(), model_path)
        print('Model saved to %s.' % model_path)

        self.save_checkpoint(it)

    def save_checkpoint(self, it):
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        checkpoint_path = self.save_path + '_checkpoint.pth'
        checkpoint = {
            'it': it,
            'network': self.PNet.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()}
        torch.save(checkpoint, checkpoint_path)
        print('Checkpoint saved to %s.' % checkpoint_path)

    def load_model(self, path):
        map_location = 'cuda:%d' % self.local_rank
        checkpoint = torch.load(path, map_location={'cuda:0': map_location})

        it = checkpoint['it']
        network = checkpoint['network']
        optimizer = checkpoint['optimizer']
        scheduler = checkpoint['scheduler']

        self.PNet.module.load_state_dict(network)
        self.optimizer.load_state_dict(optimizer)
        self.scheduler.load_state_dict(scheduler)

        print('Model loaded.')
        return it

    def load_network(self, path):
        map_location = 'cuda:%d' % self.local_rank
        src_dict = torch.load(path, map_location={'cuda:0': map_location})

        # Maps SO weight (without other_mask) to MO weight (with other_mask)
        for k in list(src_dict.keys()):
            if k == 'mask_rgb_encoder.conv1.weight':
                if src_dict[k].shape[1] == 4:
                    pads = torch.zeros((64, 1, 7, 7), device=src_dict[k].device)
                    nn.init.orthogonal_(pads)
                    src_dict[k] = torch.cat([src_dict[k], pads], 1)

        self.PNet.module.load_state_dict(src_dict)
        print('Network weight loaded:', path)

    def train(self):
        self._is_train = True
        self._do_log = True
        self.integrator = self.train_integrator
        # Shall be in eval() mode to freeze BN parameters
        self.PNet.eval()
        return self

    def val(self):
        self._is_train = False
        self._do_log = True
        self.PNet.eval()
        return self

    def test(self):
        self._is_train = False
        self._do_log = False
        self.PNet.eval()
        return self
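# ----------------------------------------------------------------------------
# Memory-growth sketch for the single-object branch of do_pass (illustration
# only; it does not call PropagationNetwork). The memory keys/values carry a
# time axis at dim=2, so concatenating the frame-1 memory doubles T before
# frame 2 is queried; the two-object branch does the same at dim=3 because of
# the extra object axis. All sizes below are assumptions.
def _example_memory_growth():
    import torch
    B, Ck, Cv, T, H, W = 1, 64, 512, 1, 24, 24
    key_k = torch.zeros(B, Ck, T, H, W)   # memory key from frame 0
    key_v = torch.zeros(B, Cv, T, H, W)   # memory value from frame 0
    prev_k = torch.zeros(B, Ck, T, H, W)  # memory key from frame 1
    prev_v = torch.zeros(B, Cv, T, H, W)  # memory value from frame 1
    keys = torch.cat([key_k, prev_k], 2)   # same pattern as in do_pass
    values = torch.cat([key_v, prev_v], 2)
    assert keys.shape[2] == 2 and values.shape[2] == 2
    return keys, values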