Ejemplo n.º 1
0
    def load_pretrained(self,
                        checkpoint=None,
                        fields=None,
                        ignore_fields=None,
                        load_constructor=False):
        """Load pre-trained network weights from a checkpoint file.

        Args:
            checkpoint: Checkpoint location, passed unchanged to
                ``loading.torch_load_legacy``.
            fields: Checkpoint keys to consider; defaults to every key in the
                checkpoint. Only the ``'net'`` key is acted upon.
            ignore_fields: Accepted for signature compatibility; not used.
            load_constructor: Accepted for signature compatibility; not used.

        Returns:
            True once loading has finished.

        Raises:
            AssertionError: if the checkpoint's ``'net_type'`` differs from the
                class name of the network being loaded into.
        """
        from ltr.admin import loading, multigpu

        # When the network is wrapped for multi-GPU use, operate on the
        # underlying ``.module`` so state-dict keys line up.
        wrapped = self.actor.net
        target = wrapped.module if multigpu.is_multi_gpu(wrapped) else wrapped

        print("load from: {}".format(checkpoint))
        state = loading.torch_load_legacy(checkpoint)

        assert type(target).__name__ == state['net_type'], 'Network is not of correct type.'

        selected = state.keys() if fields is None else fields
        for field in selected:
            # Every other field (optimizer, epoch, ...) is deliberately skipped.
            if field != 'net':
                continue
            target.load_state_dict(state[field])

        return True
Ejemplo n.º 2
0
 def load_weights(self, checkpoint = None):
     """Load only the network weights stored under ``'net'`` in a checkpoint.

     Loading is non-strict (``strict=False``), so ``load_state_dict`` tolerates
     keys that are missing from, or unexpected in, the checkpoint.

     Returns:
         True once the weights have been loaded.
     """
     target = self.actor.net
     if multigpu.is_multi_gpu(target):
         # Load into the wrapped module rather than the multi-GPU wrapper.
         target = target.module
     weights = loading.torch_load_legacy(checkpoint)['net']
     target.load_state_dict(weights, strict=False)
     return True
Ejemplo n.º 3
0
def steepest_descent_learn_filter_resnet50_newiou(filter_size=1, optim_iter=3, optim_init_step=1.0, optim_init_reg=0.01, output_activation=None,
                                 classification_layer='layer3', backbone_pretrained=False, clf_feat_blocks=1,
                                 clf_feat_norm=True, init_filter_norm=False, final_conv=False,
                                 out_feature_dim=256, init_gauss_sigma=1.0, num_dist_bins=5, bin_displacement=1.0, test_loss=None,
                                 mask_init_factor=4.0, iou_input_dim=(256, 256), iou_inter_dim=(256, 256),
                                 jitter_sigma_factor=None,
                                 pretrainmodel_path='/home/lichao/projects/pytracking_lichao/pytracking/DiMP_nets/sdlearn_300_onlytestloss_lr_causal_mg30_iou_nocf_res50_lfilt512_coco/OptimTracker_ep0040.pth.tar',
                                 use_pretrain=True, update_backbone=True, update_classifier=True,
                                 update_bb_regressor=True):
    """Build an ``OptimTracker`` from a ResNet-50 backbone, a steepest-descent
    learned linear-filter classifier and an ``AtomIoUNet`` box regressor.

    Parameters up to ``jitter_sigma_factor`` are forwarded to the respective
    sub-module constructors. The trailing keyword arguments control optional
    initialisation of the sub-modules from a pretrained checkpoint.

    Args:
        pretrainmodel_path: Checkpoint read with ``loading.torch_load_legacy``.
            Its ``'net'`` entry must be a state dict whose keys are prefixed
            with ``'feature_extractor.'``, ``'classifier.'`` and
            ``'bb_regressor.'``.
        use_pretrain: When False, no checkpoint is read at all.
        update_backbone: Initialise the backbone from the checkpoint.
        update_classifier: Initialise the classifier from the checkpoint.
        update_bb_regressor: Initialise the box regressor from the checkpoint.

    Returns:
        The assembled ``OptimTracker``.
    """
    # backbone
    backbone_net = backbones.resnet50(pretrained=backbone_pretrained)

    # Feature scale passed to the classifier feature extractor.
    norm_scale = math.sqrt(1.0 / (out_feature_dim * filter_size * filter_size))

    # classifier
    clf_feature_extractor = clf_features.residual_bottleneck_comb(num_blocks=clf_feat_blocks, l2norm=clf_feat_norm,
                                                                  final_conv=final_conv, norm_scale=norm_scale,
                                                                  out_dim=out_feature_dim)
    initializer = clf_initializer.FilterInitializerLinear(filter_size=filter_size, filter_norm=init_filter_norm,
                                                          feature_dim=out_feature_dim)
    optimizer = clf_optimizer.SteepestDescentLearn(num_iter=optim_iter, filter_size=filter_size, init_step_length=optim_init_step,
                                                   init_filter_reg=optim_init_reg, feature_dim=out_feature_dim,
                                                   init_gauss_sigma=init_gauss_sigma, num_dist_bins=num_dist_bins,
                                                   bin_displacement=bin_displacement, test_loss=test_loss,
                                                   mask_init_factor=mask_init_factor)
    classifier = target_clf.LinearFilter(filter_size=filter_size, filter_initializer=initializer,
                                         filter_optimizer=optimizer, feature_extractor=clf_feature_extractor,
                                         output_activation=output_activation, jitter_sigma_factor=jitter_sigma_factor)

    # Bounding box regressor.
    # NOTE(review): the original comment said "combine RGB and TIR by 2*", but
    # these input dims (512, 1024) are not doubled — confirm they match the
    # features the regressor actually receives.
    bb_regressor = bbmodels.AtomIoUNet(input_dim=(4*128, 4*256), pred_input_dim=iou_input_dim,
                                       pred_inter_dim=iou_inter_dim)

    if use_pretrain:
        pretrainmodel = loading.torch_load_legacy(pretrainmodel_path)['net']

        if update_backbone:
            _load_prefixed_state(backbone_net, pretrainmodel, 'feature_extractor.')

        if update_classifier:
            # Duplicate the first classifier conv weight along dim 1 so it
            # accepts twice as many input channels (two modalities concatenated).
            # NOTE(review): assumes dim 1 is the input-channel axis — confirm.
            first_conv = 'classifier.feature_extractor.0.weight'
            pretrainmodel[first_conv] = torch.cat((pretrainmodel[first_conv], pretrainmodel[first_conv]), 1)
            _load_prefixed_state(classifier, pretrainmodel, 'classifier.')

        if update_bb_regressor:
            _load_prefixed_state(bb_regressor, pretrainmodel, 'bb_regressor.')

    net = OptimTracker(feature_extractor=backbone_net, classifier=classifier, bb_regressor=bb_regressor,
                       classification_layer=classification_layer, bb_regressor_layer=['layer2', 'layer3'])
    return net


def _load_prefixed_state(module, state_dict, prefix):
    """Load into ``module`` the entries of ``state_dict`` whose key starts with ``prefix``.

    The prefix is stripped, and only keys present in ``module.state_dict()`` are
    kept. Loading is strict, so a module parameter missing from the checkpoint
    still raises.
    """
    cut = len(prefix)
    own_keys = module.state_dict().keys()
    # Require a real prefix match: plain slicing would also map unrelated keys
    # whose tail happens to coincide with one of the module's keys.
    selected = {k[cut:]: v for k, v in state_dict.items()
                if k.startswith(prefix) and k[cut:] in own_keys}
    module.load_state_dict(selected)
Ejemplo n.º 4
0
    def load_checkpoint(self,
                        checkpoint=None,
                        fields=None,
                        ignore_fields=None,
                        load_constructor=False):
        """Loads a network checkpoint file.

        Can be called in three different ways:
            load_checkpoint():
                Loads the latest epoch from the workspace. Use this to continue training.
            load_checkpoint(epoch_num):
                Loads the network at the given epoch number (int).
            load_checkpoint(path_to_checkpoint):
                Loads the file from the given path (str, ``~`` is expanded). If the
                path is a directory, the last ``*_ep*.pth.tar`` file in it
                (sorted by name) is loaded.

        Network weights are merged: only checkpoint entries whose key exists in
        the current network's state dict are taken, and all other parameters
        keep their current values.

        Args:
            fields: Checkpoint keys to load; defaults to all keys.
            ignore_fields: Keys to skip; defaults to ``['settings']``. The list
                is copied and never modified.
            load_constructor: Also restore ``net.constructor`` when present.

        Returns:
            True on success, or None when no checkpoint exists in the workspace.

        Raises:
            Exception: If a checkpoint directory contains no checkpoint file.
            TypeError: If ``checkpoint`` is not None, an int or a str.
        """
        net = self.actor.net.module if multigpu.is_multi_gpu(
            self.actor.net) else self.actor.net

        net_type = type(net).__name__

        if checkpoint is None:
            # Load most recent checkpoint of this network type from the workspace.
            pattern = '{}/{}/{}_ep*.pth.tar'.format(
                self._checkpoint_dir, self.settings.project_path, net_type)
            checkpoint_list = sorted(glob.glob(pattern))
            print(pattern)
            if checkpoint_list:
                checkpoint_path = checkpoint_list[-1]
            else:
                print('No matching checkpoint file found')
                return
        elif isinstance(checkpoint, int):
            # Checkpoint is the epoch number
            checkpoint_path = '{}/{}/{}_ep{:04d}.pth.tar'.format(
                self._checkpoint_dir, self.settings.project_path, net_type,
                checkpoint)
        elif isinstance(checkpoint, str):
            # Expand ``~`` before the directory check so "~/dir" is recognised.
            checkpoint = os.path.expanduser(checkpoint)
            if os.path.isdir(checkpoint):
                checkpoint_list = sorted(
                    glob.glob('{}/*_ep*.pth.tar'.format(checkpoint)))
                if checkpoint_list:
                    checkpoint_path = checkpoint_list[-1]
                else:
                    raise Exception('No checkpoint found')
            else:
                checkpoint_path = checkpoint
        else:
            raise TypeError(
                'checkpoint must be None, an int epoch or a str path, got {}'.format(
                    type(checkpoint).__name__))

        # Load network
        checkpoint_dict = loading.torch_load_legacy(checkpoint_path)

        # The net_type check is intentionally disabled in this variant.

        if fields is None:
            fields = checkpoint_dict.keys()
        # Copy so a caller-supplied list is not mutated by the extend below.
        ignore_fields = ['settings'] if ignore_fields is None else list(ignore_fields)
        # Never load the scheduler (it exists in older checkpoints), nor the
        # metadata entries that are handled separately below.
        ignore_fields.extend([
            'lr_scheduler', 'constructor', 'net_type', 'actor_type', 'net_info'
        ])

        # Load all fields
        for key in fields:
            if key in ignore_fields:
                continue
            if key == 'net':
                # Merge instead of a strict load so networks with extra layers
                # (absent from the checkpoint) can still be initialised.
                model_dict = net.state_dict()
                model_dict.update({
                    k: v
                    for k, v in checkpoint_dict[key].items() if k in model_dict
                })
                net.load_state_dict(model_dict)
            elif key == 'optimizer':
                try:
                    self.optimizer.load_state_dict(checkpoint_dict[key])
                except Exception as err:
                    # Best effort: the optimizer state may not match a modified
                    # network; continue with a fresh optimizer but say so.
                    print('Could not load optimizer state, skipping: {}'.format(err))
            else:
                setattr(self, key, checkpoint_dict[key])

        # Set the net info
        if load_constructor and checkpoint_dict.get('constructor') is not None:
            net.constructor = checkpoint_dict['constructor']
        if checkpoint_dict.get('net_info') is not None:
            net.info = checkpoint_dict['net_info']

        # Update the epoch in lr scheduler
        if 'epoch' in fields:
            self.lr_scheduler.last_epoch = self.epoch

        return True
Ejemplo n.º 5
0
import sys

# NOTE: these paths are relative to the current working directory.
sys.path.append('../../pytracking-rgbd/')
sys.path.append('../../pytracking-rgbd/pytracking')
sys.path.append('../../pytracking-rgbd/ltr')

from ltr.admin.loading import torch_load_legacy
import torch

# Checkpoint converted when no path is given on the command line.
DEFAULT_NET_PATH = '/home/yan/Data2/DeT-models/DeT_ATOM_Mean.pth.tar'


def convert_to_legacy_format(net_path):
    """Re-save the checkpoint at ``net_path`` in place, in torch's legacy
    (non-zipfile) serialization format.

    Networks are trained on a machine with torch 1.7.1 but tested with
    torch 1.4.0, which presumably cannot read the newer zipfile format.
    """
    checkpoints = torch_load_legacy(net_path)
    torch.save(checkpoints, net_path, _use_new_zipfile_serialization=False)


if __name__ == '__main__':
    # Usage: python <script> [checkpoint.pth.tar ...]
    for path in sys.argv[1:] or [DEFAULT_NET_PATH]:
        convert_to_legacy_format(path)
Ejemplo n.º 6
0
    def load_checkpoint(self, checkpoint = None, fields = None, ignore_fields = None, load_constructor = False):
        """Loads a network checkpoint file.

        Can be called in three different ways:
            load_checkpoint():
                Loads the latest epoch from the workspace. Use this to continue training.
            load_checkpoint(epoch_num):
                Loads the network at the given epoch number (int).
            load_checkpoint(path_to_checkpoint):
                Loads the file from the given path (str, ``~`` is expanded). If the
                path is a directory, the last ``*_ep*.pth.tar`` file in it
                (sorted by name) is loaded.

        The checkpoint's ``net_type`` only has to be a substring of this
        network's class name, so e.g. a ``DiMPnet`` checkpoint is accepted by a
        ``DiMPnet_rgbd`` network. Weights are loaded with ``strict=False``.

        Args:
            fields: Checkpoint keys to load; defaults to all keys.
            ignore_fields: Keys to skip; defaults to ``['settings']``. The list
                is copied and never modified.
            load_constructor: Also restore ``net.constructor`` when present.

        Returns:
            True on success, or None when no checkpoint exists in the workspace.

        Raises:
            Exception: If a checkpoint directory contains no checkpoint file.
            TypeError: If ``checkpoint`` is not None, an int or a str.
            AssertionError: If the checkpoint's ``net_type`` does not match.
        """
        net = self.actor.net.module if multigpu.is_multi_gpu(self.actor.net) else self.actor.net

        net_type = type(net).__name__

        if checkpoint is None:
            # Load most recent checkpoint of this network type from the workspace.
            checkpoint_list = sorted(glob.glob('{}/{}/{}_ep*.pth.tar'.format(self._checkpoint_dir,
                                                                             self.settings.project_path, net_type)))
            print(checkpoint_list)

            if checkpoint_list:
                checkpoint_path = checkpoint_list[-1]
            else:
                print('No matching checkpoint file found')
                return
        elif isinstance(checkpoint, int):
            # Checkpoint is the epoch number
            checkpoint_path = '{}/{}/{}_ep{:04d}.pth.tar'.format(self._checkpoint_dir, self.settings.project_path,
                                                                 net_type, checkpoint)
        elif isinstance(checkpoint, str):
            # Expand ``~`` before the directory check so "~/dir" is recognised.
            checkpoint = os.path.expanduser(checkpoint)
            if os.path.isdir(checkpoint):
                checkpoint_list = sorted(glob.glob('{}/*_ep*.pth.tar'.format(checkpoint)))
                if checkpoint_list:
                    checkpoint_path = checkpoint_list[-1]
                else:
                    raise Exception('No checkpoint found')
            else:
                checkpoint_path = checkpoint
        else:
            raise TypeError('checkpoint must be None, an int epoch or a str path, got {}'.format(
                type(checkpoint).__name__))

        # Load network
        checkpoint_dict = loading.torch_load_legacy(checkpoint_path)

        # Substring match on purpose: a 'DiMPnet' checkpoint may initialise a
        # 'DiMPnet_rgbd' network.
        assert checkpoint_dict['net_type'] in net_type, 'Network is not of correct type.'

        if fields is None:
            fields = checkpoint_dict.keys()
        # Copy so a caller-supplied list is not mutated by the extend below.
        ignore_fields = ['settings'] if ignore_fields is None else list(ignore_fields)
        # Never load the scheduler (it exists in older checkpoints), nor the
        # metadata entries that are handled separately below.
        ignore_fields.extend(['lr_scheduler', 'constructor', 'net_type', 'actor_type', 'net_info'])

        # Load all fields
        for key in fields:
            if key in ignore_fields:
                continue
            if key == 'net':
                # Non-strict so networks with extra or missing layers still load.
                net.load_state_dict(checkpoint_dict[key], strict=False)
            elif key == 'optimizer':
                # Raises if the optimizer's parameter groups do not match the
                # saved ones (e.g. after the network architecture changed).
                self.optimizer.load_state_dict(checkpoint_dict[key])
            else:
                setattr(self, key, checkpoint_dict[key])

        # Set the net info
        if load_constructor and checkpoint_dict.get('constructor') is not None:
            net.constructor = checkpoint_dict['constructor']
        if checkpoint_dict.get('net_info') is not None:
            net.info = checkpoint_dict['net_info']

        # Update the epoch in lr scheduler
        if 'epoch' in fields:
            self.lr_scheduler.last_epoch = self.epoch

        return True