Example #1
    def __init__(self, para, logger=None, save_path=None, local_rank=0, world_size=1):
        self.para = para
        self.single_object = para['single_object']
        self.local_rank = local_rank

        self.PNet = nn.parallel.DistributedDataParallel(
            PropagationNetwork(self.single_object).cuda(), 
            device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False)

        # Setup logger when local_rank=0
        self.logger = logger
        self.save_path = save_path
        if logger is not None:
            self.last_time = time.time()
        self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size)
        if self.single_object:
            self.train_integrator.add_hook(iou_hooks_so)
        else:
            self.train_integrator.add_hook(iou_hooks_mo)
        self.loss_computer = LossComputer(para)

        self.train()
        self.optimizer = optim.Adam(filter(
            lambda p: p.requires_grad, self.PNet.parameters()), lr=para['lr'], weight_decay=1e-7)
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, para['steps'], para['gamma'])

        # Logging info
        self.report_interval = 100
        self.save_im_interval = 800
        self.save_model_interval = 50000
        if para['debug']:
            self.report_interval = self.save_im_interval = 1
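A minimal sketch of the distributed setup this constructor assumes, following the usual torchrun convention; the environment-variable handling and the commented-out construction are illustrative, not taken from the repository:

import os

import torch
import torch.distributed as dist

# torchrun sets LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT.
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
world_size = dist.get_world_size()

# The constructor above (PropagationModel in Example #6) can then be called as:
# model = PropagationModel(para, logger=logger, save_path=save_path,
#                          local_rank=local_rank, world_size=world_size)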
Example #2
    def __init__(self,
                 para,
                 logger=None,
                 save_path=None,
                 local_rank=0,
                 world_size=1):
        self.para = para
        self.local_rank = local_rank

        self.S2M = nn.parallel.DistributedDataParallel(
            nn.SyncBatchNorm.convert_sync_batchnorm(
                deeplabv3plus_resnet50(num_classes=1,
                                       output_stride=16,
                                       pretrained_backbone=False)).cuda(),
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=False)

        # Setup logger when local_rank=0
        self.logger = logger
        self.save_path = save_path
        if logger is not None:
            self.last_time = time.time()
        self.train_integrator = Integrator(self.logger,
                                           distributed=True,
                                           local_rank=local_rank,
                                           world_size=world_size)
        self.train_integrator.add_hook(iou_hooks_to_be_used)
        self.loss_computer = LossComputer(para)

        self.train()
        self.optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                           self.S2M.parameters()),
                                    lr=para['lr'],
                                    weight_decay=1e-7)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, para['steps'], para['gamma'])

        # Logging info
        self.report_interval = 50
        self.save_im_interval = 800
        self.save_model_interval = 20000
        if para['debug']:
            self.report_interval = self.save_im_interval = 1
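For reference, the hyperparameter dictionary para read by these constructors uses at least the keys below; the values shown are illustrative placeholders, not the repository defaults:

para = {
    'lr': 1e-5,             # Adam learning rate
    'steps': [100000],      # MultiStepLR milestones
    'gamma': 0.1,           # MultiStepLR decay factor
    'debug': False,         # if True, report and save images every iteration
    'single_object': True,  # read by the propagation model (Examples #1 and #6)
    'iterations': 150000,   # read by the training script in Example #3
}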
Example #3
                          pin_memory=True)

sobel_compute = SobelComputer()

# Learning rate decay scheduling
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, para['steps'],
                                           para['gamma'])

saver = ModelSaver(long_id)
report_interval = 50
save_im_interval = 800

total_epoch = int(para['iterations'] / len(train_loader) + 0.5)
print('Actual training epoch: ', total_epoch)

train_integrator = Integrator(logger)
train_integrator.add_hook(iou_hooks_to_be_used)
total_iter = 0
last_time = 0
for e in range(total_epoch):
    np.random.seed()  # reset seed
    epoch_start_time = time.time()

    # Train loop
    model = model.train()
    for im, seg, gt in train_loader:
        im, seg, gt = im.cuda(), seg.cuda(), gt.cuda()

        total_iter += 1
        if total_iter % 5000 == 0:
            saver.save_model(model, total_iter)
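In the model classes of the later examples, scheduler.step() is called once per training iteration inside do_pass, so the milestones in para['steps'] are iteration counts rather than epochs. A small self-contained illustration of MultiStepLR with toy milestones (purely illustrative values):

import torch
from torch import optim

p = torch.nn.Parameter(torch.zeros(1))
opt = optim.Adam([p], lr=1e-5)
sched = optim.lr_scheduler.MultiStepLR(opt, milestones=[3, 5], gamma=0.1)
for it in range(1, 8):
    opt.step()
    sched.step()
    print(it, sched.get_last_lr()[0])  # the lr is multiplied by 0.1 after steps 3 and 5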
Example #4
class FusionModel:
    def __init__(self,
                 para,
                 logger=None,
                 save_path=None,
                 local_rank=0,
                 world_size=1,
                 distributed=True):
        self.para = para
        self.local_rank = local_rank

        if distributed:
            self.net = nn.parallel.DistributedDataParallel(
                FusionNet().cuda(),
                device_ids=[local_rank],
                output_device=local_rank,
                broadcast_buffers=False)
        else:
            self.net = nn.DataParallel(FusionNet().cuda(),
                                       device_ids=[local_rank],
                                       output_device=local_rank)

        self.prop_net = AttentionReadNetwork().eval().cuda()

        # Setup logger when local_rank=0
        self.logger = logger
        self.save_path = save_path
        if logger is not None:
            self.last_time = time.time()
        self.train_integrator = Integrator(self.logger,
                                           distributed=distributed,
                                           local_rank=local_rank,
                                           world_size=world_size)
        self.train_integrator.add_hook(iou_hooks)
        self.val_integrator = Integrator(self.logger,
                                         distributed=distributed,
                                         local_rank=local_rank,
                                         world_size=world_size)
        self.loss_computer = LossComputer(para)

        self.train()
        self.optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                           self.net.parameters()),
                                    lr=para['lr'],
                                    weight_decay=1e-7)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, para['steps'], para['gamma'])

        # Logging info
        self.report_interval = 100
        self.save_im_interval = 500
        self.save_model_interval = 5000
        if para['debug']:
            self.report_interval = self.save_im_interval = 1

    def do_pass(self, data, it=0):
        # No need to store the gradient outside training
        torch.set_grad_enabled(self._is_train)

        for k, v in data.items():
            if type(v) != list and type(v) != dict and type(v) != int:
                data[k] = v.cuda(non_blocking=True)

        # See fusion_dataset.py for variable definitions
        im = data['rgb']

        seg1 = data['seg1']
        seg2 = data['seg2']
        src2_ref = data['src2_ref']
        src2_ref_gt = data['src2_ref_gt']

        seg12 = data['seg12']
        seg22 = data['seg22']
        src2_ref2 = data['src2_ref2']
        src2_ref_gt2 = data['src2_ref_gt2']

        src2_ref_im = data['src2_ref_im']
        selector = data['selector']
        dist = data['dist']

        out = {}
        # Get kernelized memory
        with torch.no_grad():
            attn1, attn2 = self.prop_net(src2_ref_im, src2_ref, src2_ref_gt,
                                         src2_ref2, src2_ref_gt2, im)

        prob1 = torch.sigmoid(self.net(im, seg1, seg2, attn1, dist))
        prob2 = torch.sigmoid(self.net(im, seg12, seg22, attn2, dist))
        prob = torch.cat([prob1, prob2],
                         1) * selector.unsqueeze(2).unsqueeze(2)
        logits, prob = aggregate_wbg_channel(prob, True)

        out['logits'] = logits
        out['mask'] = prob
        out['attn1'] = attn1
        out['attn2'] = attn2

        if self._do_log or self._is_train:
            losses = self.loss_computer.compute({**data, **out}, it)

            # Logging
            if self._do_log:
                self.integrator.add_dict(losses)
                if self._is_train:
                    if it % self.save_im_interval == 0 and it != 0:
                        if self.logger is not None:
                            images = {**data, **out}
                            size = (320, 320)
                            self.logger.log_cv2('train/pairs',
                                                pool_fusion(images, size=size),
                                                it)
                else:
                    # Validation save
                    if data['val_iter'] % 10 == 0:
                        if self.logger is not None:
                            images = {**data, **out}
                            size = (320, 320)
                            self.logger.log_cv2('val/pairs',
                                                pool_fusion(images, size=size),
                                                it)

        if self._is_train:
            if (it) % self.report_interval == 0 and it != 0:
                if self.logger is not None:
                    self.logger.log_scalar('train/lr',
                                           self.scheduler.get_last_lr()[0], it)
                    self.logger.log_metrics('train', 'time',
                                            (time.time() - self.last_time) /
                                            self.report_interval, it)
                self.last_time = time.time()
                self.train_integrator.finalize('train', it)
                self.train_integrator.reset_except_hooks()

            if it % self.save_model_interval == 0 and it != 0:
                if self.logger is not None:
                    self.save(it)

            # Backward pass
            self.optimizer.zero_grad(set_to_none=True)
            losses['total_loss'].backward()
            self.optimizer.step()
            self.scheduler.step()

    def save(self, it):
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        model_path = self.save_path + ('_%s.pth' % it)
        torch.save(self.net.module.state_dict(), model_path)
        print('Model saved to %s.' % model_path)

        self.save_checkpoint(it)

    def save_checkpoint(self, it):
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        checkpoint_path = self.save_path + '_checkpoint.pth'
        checkpoint = {
            'it': it,
            'network': self.net.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }
        torch.save(checkpoint, checkpoint_path)

        print('Checkpoint saved to %s.' % checkpoint_path)

    def load_model(self, path):
        map_location = 'cuda:%d' % self.local_rank
        checkpoint = torch.load(path, map_location={'cuda:0': map_location})

        it = checkpoint['it']
        network = checkpoint['network']
        optimizer = checkpoint['optimizer']
        scheduler = checkpoint['scheduler']

        self.net.module.load_state_dict(network)
        self.optimizer.load_state_dict(optimizer)
        self.scheduler.load_state_dict(scheduler)

        print('Model loaded.')

        return it

    def load_network(self, path):
        map_location = 'cuda:%d' % self.local_rank
        self.net.module.load_state_dict(
            torch.load(path, map_location={'cuda:0': map_location}))
        # self.net.load_state_dict(torch.load(path))
        print('Network weight loaded:', path)

    def load_prop(self, path):
        map_location = 'cuda:%d' % self.local_rank
        self.prop_net.load_state_dict(torch.load(
            path, map_location={'cuda:0': map_location}),
                                      strict=False)
        print('Propagation network weight loaded:', path)

    def finalize_val(self, it):
        self.val_integrator.finalize('val', it)
        self.val_integrator.reset_except_hooks()

    def train(self):
        self._is_train = True
        self._do_log = True
        self.integrator = self.train_integrator
        # Keep the networks in eval() mode so the BN running statistics stay frozen even during training
        self.net.eval()
        self.prop_net.eval()
        return self

    def val(self):
        self._is_train = False
        self.integrator = self.val_integrator
        self._do_log = True
        self.net.eval()
        self.prop_net.eval()
        return self

    def test(self):
        self._is_train = False
        self._do_log = False
        self.net.eval()
        self.prop_net.eval()
        return self
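A hypothetical driver sketch for FusionModel; the checkpoint paths, train_loader, para, logger and save_path are placeholders and not taken from the repository:

model = FusionModel(para, logger=logger, save_path=save_path,
                    local_rank=local_rank, world_size=world_size)
model.load_prop('saves/propagation_model.pth')  # frozen attention-read network
total_iter = 0
# To resume training instead: total_iter = model.load_model('saves/fusion_checkpoint.pth')

model.train()
for data in train_loader:
    total_iter += 1
    model.do_pass(data, total_iter)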
Example #5
class S2MModel:
    def __init__(self,
                 para,
                 logger=None,
                 save_path=None,
                 local_rank=0,
                 world_size=1):
        self.para = para
        self.local_rank = local_rank

        self.S2M = nn.parallel.DistributedDataParallel(
            nn.SyncBatchNorm.convert_sync_batchnorm(
                deeplabv3plus_resnet50(num_classes=1,
                                       output_stride=16,
                                       pretrained_backbone=False)).cuda(),
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=False)

        # Setup logger when local_rank=0
        self.logger = logger
        self.save_path = save_path
        if logger is not None:
            self.last_time = time.time()
        self.train_integrator = Integrator(self.logger,
                                           distributed=True,
                                           local_rank=local_rank,
                                           world_size=world_size)
        self.train_integrator.add_hook(iou_hooks_to_be_used)
        self.loss_computer = LossComputer(para)

        self.train()
        self.optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                           self.S2M.parameters()),
                                    lr=para['lr'],
                                    weight_decay=1e-7)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, para['steps'], para['gamma'])

        # Logging info
        self.report_interval = 50
        self.save_im_interval = 800
        self.save_model_interval = 20000
        if para['debug']:
            self.report_interval = self.save_im_interval = 1

    def do_pass(self, data, it=0):
        # No need to store the gradient outside training
        torch.set_grad_enabled(self._is_train)

        for k, v in data.items():
            if type(v) != list and type(v) != dict and type(v) != int:
                data[k] = v.cuda(non_blocking=True)

        out = {}
        Fs = data['rgb']
        Ss = data['seg']
        Rs = data['srb']
        Ms = data['gt']
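        # The network input stacks RGB with the existing segmentation and the
        # scribble channels (6 channels in total, matching the padded first conv
        # in load_deeplab below).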

        inputs = torch.cat([Fs, Ss, Rs], 1)
        prob = torch.sigmoid(self.S2M(inputs))
        logits, mask = aggregate(prob)

        out['logits'] = logits
        out['mask'] = mask

        if self._do_log or self._is_train:
            losses = self.loss_computer.compute({**data, **out}, it)

            # Logging
            if self._do_log:
                self.integrator.add_dict(losses)
                if self._is_train:
                    if it % self.save_im_interval == 0 and it != 0:
                        if self.logger is not None:
                            images = {**data, **out}
                            size = (384, 384)
                            self.logger.log_cv2('train/pairs',
                                                pool_pairs(images, size=size),
                                                it)

        if self._is_train:
            if (it) % self.report_interval == 0 and it != 0:
                if self.logger is not None:
                    self.logger.log_scalar('train/lr',
                                           self.scheduler.get_last_lr()[0], it)
                    self.logger.log_metrics('train', 'time',
                                            (time.time() - self.last_time) /
                                            self.report_interval, it)
                self.last_time = time.time()
                self.train_integrator.finalize('train', it)
                self.train_integrator.reset_except_hooks()

            if it % self.save_model_interval == 0 and it != 0:
                if self.logger is not None:
                    self.save(it)

            # Backward pass
            self.optimizer.zero_grad()
            losses['total_loss'].backward()
            self.optimizer.step()
            self.scheduler.step()

    def save(self, it):
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        model_path = self.save_path + ('_%s.pth' % it)
        torch.save(self.S2M.module.state_dict(), model_path)
        print('Model saved to %s.' % model_path)

        self.save_checkpoint(it)

    def save_checkpoint(self, it):
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        checkpoint_path = self.save_path + '_checkpoint.pth'
        checkpoint = {
            'it': it,
            'network': self.S2M.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }
        torch.save(checkpoint, checkpoint_path)

        print('Checkpoint saved to %s.' % checkpoint_path)

    def load_model(self, path):
        map_location = 'cuda:%d' % self.local_rank
        checkpoint = torch.load(path, map_location={'cuda:0': map_location})

        it = checkpoint['it']
        network = checkpoint['network']
        optimizer = checkpoint['optimizer']
        scheduler = checkpoint['scheduler']

        self.S2M.module.load_state_dict(network)
        self.optimizer.load_state_dict(optimizer)
        self.scheduler.load_state_dict(scheduler)

        print('Model loaded.')

        return it

    def load_network(self, path):
        map_location = 'cuda:%d' % self.local_rank
        self.S2M.module.load_state_dict(
            torch.load(path, map_location={'cuda:0': map_location}))
        print('Network weight loaded:', path)

    def load_deeplab(self, path):
        map_location = 'cuda:%d' % self.local_rank

        cur_dict = self.S2M.module.state_dict()
        src_dict = torch.load(path, map_location={'cuda:0':
                                                  map_location})['model_state']
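        # The pretrained DeepLab checkpoint has an RGB-only stem and a multi-class
        # head; the loop below re-initializes the mismatched classifier weights for
        # the single-channel output and pads the first conv from 3 to 6 input
        # channels so it also accepts the mask and scribble inputs.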

        for k in list(src_dict.keys()):
            if type(src_dict[k]) is not int:
                if src_dict[k].shape != cur_dict[k].shape:
                    print('Reloading: ', k)
                    if 'bias' in k:
                        # Resetting the class prob bias
                        src_dict[k] = torch.zeros_like((src_dict[k][0:1]))
                    elif src_dict[k].shape[1] != 3:
                        # Resetting the class prob weight
                        src_dict[k] = torch.zeros_like((src_dict[k][0:1]))
                        nn.init.orthogonal_(src_dict[k])
                    else:
                        # Adding the mask and scribbles channel
                        pads = torch.zeros((64, 3, 7, 7),
                                           device=src_dict[k].device)
                        nn.init.orthogonal_(pads)
                        src_dict[k] = torch.cat([src_dict[k], pads], 1)

        self.S2M.module.load_state_dict(src_dict)
        print('Network weight loaded:', path)

    def train(self):
        self._is_train = True
        self._do_log = True
        self.integrator = self.train_integrator
        self.S2M.train()
        return self

    def val(self):
        self._is_train = False
        self._do_log = True
        self.S2M.eval()
        return self

    def test(self):
        self._is_train = False
        self._do_log = False
        self.S2M.eval()
        return self
Example #6
class PropagationModel:
    def __init__(self, para, logger=None, save_path=None, local_rank=0, world_size=1):
        self.para = para
        self.single_object = para['single_object']
        self.local_rank = local_rank

        self.PNet = nn.parallel.DistributedDataParallel(
            PropagationNetwork(self.single_object).cuda(), 
            device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False)

        # Setup logger when local_rank=0
        self.logger = logger
        self.save_path = save_path
        if logger is not None:
            self.last_time = time.time()
        self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size)
        if self.single_object:
            self.train_integrator.add_hook(iou_hooks_so)
        else:
            self.train_integrator.add_hook(iou_hooks_mo)
        self.loss_computer = LossComputer(para)

        self.train()
        self.optimizer = optim.Adam(filter(
            lambda p: p.requires_grad, self.PNet.parameters()), lr=para['lr'], weight_decay=1e-7)
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, para['steps'], para['gamma'])

        # Logging info
        self.report_interval = 100
        self.save_im_interval = 800
        self.save_model_interval = 50000
        if para['debug']:
            self.report_interval = self.save_im_interval = 1

    def do_pass(self, data, it=0):
        # No need to store the gradient outside training
        torch.set_grad_enabled(self._is_train)

        for k, v in data.items():
            if type(v) != list and type(v) != dict and type(v) != int:
                data[k] = v.cuda(non_blocking=True)

        out = {}
        Fs = data['rgb']
        Ms = data['gt']
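        # Three-frame training clip: frame 0 is encoded into memory with its
        # ground-truth mask, frame 1 is predicted from that memory and then encoded
        # into memory as well, and frame 2 is predicted from both memory entries.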
        
        if self.single_object:
            key_k, key_v = self.PNet(Fs[:,0], Ms[:,0])
            prev_logits, prev_mask = self.PNet(Fs[:,1], key_k, key_v)
            prev_k, prev_v = self.PNet(Fs[:,1], prev_mask)

            keys = torch.cat([key_k, prev_k], 2)
            values = torch.cat([key_v, prev_v], 2)
            this_logits, this_mask = self.PNet(Fs[:,2], keys, values)

            out['mask_1'] = prev_mask
            out['mask_2'] = this_mask
            out['logits_1'] = prev_logits
            out['logits_2'] = this_logits
        else:
            sec_Ms = data['sec_gt']
            selector = data['selector']
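            # Two-object variant: each object is encoded with the other object's
            # mask as an extra input; selector (from the dataloader) indicates which
            # of the two objects are actually present in this clip.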

            key_k1, key_v1 = self.PNet(Fs[:,0], Ms[:,0], sec_Ms[:,0])
            key_k2, key_v2 = self.PNet(Fs[:,0], sec_Ms[:,0], Ms[:,0])
            key_k = torch.stack([key_k1, key_k2], 1)
            key_v = torch.stack([key_v1, key_v2], 1)

            prev_logits, prev_mask = self.PNet(Fs[:,1], key_k, key_v, selector)
            
            prev_k1, prev_v1 = self.PNet(Fs[:,1], prev_mask[:,0:1], prev_mask[:,1:2])
            prev_k2, prev_v2 = self.PNet(Fs[:,1], prev_mask[:,1:2], prev_mask[:,0:1])
            prev_k = torch.stack([prev_k1, prev_k2], 1)
            prev_v = torch.stack([prev_v1, prev_v2], 1)

            keys = torch.cat([key_k, prev_k], 3)
            values = torch.cat([key_v, prev_v], 3)

            this_logits, this_mask = self.PNet(Fs[:,2], keys, values, selector)

            out['mask_1'] = prev_mask[:,0:1]
            out['mask_2'] = this_mask[:,0:1]
            out['sec_mask_1'] = prev_mask[:,1:2]
            out['sec_mask_2'] = this_mask[:,1:2]

            out['logits_1'] = prev_logits
            out['logits_2'] = this_logits

        if self._do_log or self._is_train:
            losses = self.loss_computer.compute({**data, **out}, it)

            # Logging
            if self._do_log:
                self.integrator.add_dict(losses)
                if self._is_train:
                    if it % self.save_im_interval == 0 and it != 0:
                        if self.logger is not None:
                            images = {**data, **out}
                            size = (384, 384)
                            self.logger.log_cv2('train/pairs', pool_pairs(images, size, self.single_object), it)

        if self._is_train:
            if (it) % self.report_interval == 0 and it != 0:
                if self.logger is not None:
                    self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it)
                    self.logger.log_metrics('train', 'time', (time.time()-self.last_time)/self.report_interval, it)
                self.last_time = time.time()
                self.train_integrator.finalize('train', it)
                self.train_integrator.reset_except_hooks()

            if it % self.save_model_interval == 0 and it != 0:
                if self.logger is not None:
                    self.save(it)

            # Backward pass
            for param_group in self.optimizer.param_groups:
                for p in param_group['params']:
                    p.grad = None
            losses['total_loss'].backward() 
            self.optimizer.step()
            self.scheduler.step()

    def save(self, it):
        if self.save_path is None:
            print('Saving has been disabled.')
            return
        
        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        model_path = self.save_path + ('_%s.pth' % it)
        torch.save(self.PNet.module.state_dict(), model_path)
        print('Model saved to %s.' % model_path)

        self.save_checkpoint(it)

    def save_checkpoint(self, it):
        if self.save_path is None:
            print('Saving has been disabled.')
            return

        os.makedirs(os.path.dirname(self.save_path), exist_ok=True)
        checkpoint_path = self.save_path + '_checkpoint.pth'
        checkpoint = { 
            'it': it,
            'network': self.PNet.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()}
        torch.save(checkpoint, checkpoint_path)

        print('Checkpoint saved to %s.' % checkpoint_path)

    def load_model(self, path):
        map_location = 'cuda:%d' % self.local_rank
        checkpoint = torch.load(path, map_location={'cuda:0': map_location})

        it = checkpoint['it']
        network = checkpoint['network']
        optimizer = checkpoint['optimizer']
        scheduler = checkpoint['scheduler']

        self.PNet.module.load_state_dict(network)
        self.optimizer.load_state_dict(optimizer)
        self.scheduler.load_state_dict(scheduler)

        print('Model loaded.')

        return it

    def load_network(self, path):
        map_location = 'cuda:%d' % self.local_rank
        src_dict = torch.load(path, map_location={'cuda:0': map_location})

        # Maps SO weight (without other_mask) to MO weight (with other_mask)
        for k in list(src_dict.keys()):
            if k == 'mask_rgb_encoder.conv1.weight':
                if src_dict[k].shape[1] == 4:
                    pads = torch.zeros((64,1,7,7), device=src_dict[k].device)
                    nn.init.orthogonal_(pads)
                    src_dict[k] = torch.cat([src_dict[k], pads], 1)

        self.PNet.module.load_state_dict(src_dict)
        print('Network weight loaded:', path)

    def train(self):
        self._is_train = True
        self._do_log = True
        self.integrator = self.train_integrator
        # Keep the network in eval() mode so the BN running statistics stay frozen during training
        self.PNet.eval()
        return self

    def val(self):
        self._is_train = False
        self._do_log = True
        self.PNet.eval()
        return self

    def test(self):
        self._is_train = False
        self._do_log = False
        self.PNet.eval()
        return self