def test(load_model_weight=False):
    """Run evaluation on every test set, optionally restoring weights first.

    Args:
        load_model_weight: when true, parameters are restored before testing.
            If ``cfg.model_weight_file`` is a non-empty path, only the model
            state dict is loaded from it; otherwise the full checkpoint at
            ``cfg.ckpt_file`` is restored into ``modules_optims``.
    """
    if load_model_weight:
        if cfg.model_weight_file == '':
            load_ckpt(modules_optims, cfg.ckpt_file)
        else:
            weights = torch.load(
                cfg.model_weight_file,
                # Returning the storage unchanged deserializes onto CPU.
                map_location=lambda storage, loc: storage,
            )
            load_state_dict(model, weights)
            print('Loaded model weights from {}'.format(cfg.model_weight_file))

    # Local distances are used only when the local loss has positive weight
    # and the config asks for local-distance hard-sample mining.
    local_flag = cfg.l_loss_weight > 0 and cfg.local_dist_own_hard_sample

    for dataset, dataset_name in zip(test_sets, test_set_names):
        dataset.set_feat_func(ExtractFeature(model_w, TVT))
        print('\n=========> Test on dataset: {} <=========\n'.format(dataset_name))
        dataset.eval(normalize_feat=cfg.normalize_feature,
                     use_local_distance=local_flag)
# Example No. 2 (votes: 0)
def prepare():
    """Build the feature extractor and image preprocessor used for inference.

    Reads module-level settings: ``local_conv_out_channels``,
    ``model_weight_file``, ``resize_h_w``, ``scale_im``, ``im_mean`` and
    ``im_std``. Also uses the device helpers ``TVT`` / ``TMO``.

    Returns:
        tuple: ``(extractor, preprocessor)`` -- an ``ExtractFeature`` wrapping
        the ``DataParallel`` model, and a ``PreProcessIm`` instance.
    """
    net = Model(local_conv_out_channels)
    net_wrapper = DataParallel(net)

    checkpoint = torch.load(model_weight_file,
                            map_location=lambda storage, loc: storage)
    # A checkpoint holding a 'state_dicts' list keeps the model weights at
    # index 0; otherwise the file is treated as a bare state dict.
    if 'state_dicts' in checkpoint:
        state = checkpoint['state_dicts'][0]
    else:
        state = checkpoint
    load_state_dict(net, state)
    print('Loaded model weight from {}'.format(model_weight_file))

    # Transfer the network to the configured device.
    TMO([net])

    extractor = ExtractFeature(net_wrapper, TVT)
    preprocessor = PreProcessIm(
        resize_h_w,
        scale=scale_im,
        im_mean=im_mean,
        im_std=im_std,
    )
    return extractor, preprocessor
# Example No. 3 (votes: 0)
    def test(load_model_weight=False):
        """Evaluate the model on each configured test set.

        Args:
            load_model_weight: when True, restore parameters first -- from
                ``cfg.model_weight_file`` if it is non-empty, otherwise from
                the full checkpoint at ``cfg.ckpt_file`` via ``load_ckpt``.
        """
        if load_model_weight:
            weight_file = cfg.model_weight_file
            if weight_file == '':
                load_ckpt(modules_optims, cfg.ckpt_file)
            else:
                # Identity map_location: deserialize tensors onto CPU.
                to_cpu = lambda storage, loc: storage
                weights = torch.load(weight_file, map_location=to_cpu)
                load_state_dict(model, weights)
                print('Loaded model weights from {}'.format(weight_file))

        # Local distance is enabled only if the local loss is weighted and
        # the config requests local-distance hard-sample mining.
        use_local_distance = (cfg.l_loss_weight > 0
                              and cfg.local_dist_own_hard_sample)

        pairs = zip(test_sets, test_set_names)
        for dataset, dataset_name in pairs:
            dataset.set_feat_func(ExtractFeature(model_w, TVT))
            banner = '\n=========> Test on dataset: {} <=========\n'
            print(banner.format(dataset_name))
            dataset.eval(
                normalize_feat=cfg.normalize_feature,
                use_local_distance=use_local_distance,
            )
    def test(load_model_weight=False, ckpt_file=None,
             disable_local_distance=True):
        """Evaluate the model on each configured test set.

        Args:
            load_model_weight: when True, restore parameters before testing --
                from ``cfg.model_weight_file`` if it is non-empty, otherwise
                from a full checkpoint via ``load_ckpt``.
            ckpt_file: checkpoint path for the ``load_ckpt`` branch. Defaults
                to ``cfg.ckpt_file``; pass a path to test a specific epoch.
            disable_local_distance: when True (default), evaluation uses
                global distance only. When False, local distance is enabled
                iff ``cfg.l_loss_weight > 0`` and
                ``cfg.local_dist_own_hard_sample``.
        """
        if load_model_weight:
            if cfg.model_weight_file != '':
                # Identity map_location: deserialize tensors onto CPU.
                map_location = (lambda storage, loc: storage)
                sd = torch.load(cfg.model_weight_file,
                                map_location=map_location)
                load_state_dict(model, sd)
                print('Loaded model weights from {}'.format(
                    cfg.model_weight_file))
            else:
                # Take the path from the caller or the config instead of an
                # absolute path tied to one developer's machine.
                if ckpt_file is None:
                    ckpt_file = cfg.ckpt_file
                load_ckpt(modules_optims, ckpt_file)

        if disable_local_distance:
            use_local_distance = False
        else:
            use_local_distance = (cfg.l_loss_weight > 0) \
                                 and cfg.local_dist_own_hard_sample

        for test_set, name in zip(test_sets, test_set_names):
            test_set.set_feat_func(ExtractFeature(model_w, TVT))
            print('\n=========> Test on dataset: {} <=========\n'.format(name))
            test_set.eval(normalize_feat=cfg.normalize_feature,
                          use_local_distance=use_local_distance)
# Example No. 5 (votes: 0)
    def __call__(self, ims):
        """Extract global and local features for a batch of images.

        Args:
            ims: numpy array of images. It is converted with
                ``torch.from_numpy(...).float()`` and transferred with
                ``self.TVT``. Presumably (N, C, H, W) -- TODO confirm.

        Returns:
            tuple: ``(global_feat, local_feat)`` as numpy arrays -- the first
            two outputs of ``self.model``, copied to CPU.

        The model's train/eval mode is restored on return, even if the
        forward pass raises.
        """
        was_training = self.model.training
        # Set eval mode: BN layers use their running (global) mean and
        # variance, and dropout is disabled.
        self.model.eval()
        try:
            ims = Variable(self.TVT(torch.from_numpy(ims).float()))
            global_feat, local_feat = self.model(ims)[:2]
            global_feat = global_feat.data.cpu().numpy()
            local_feat = local_feat.data.cpu().numpy()
        finally:
            # Restore the caller's mode so extraction has no lasting effect.
            self.model.train(was_training)
        return global_feat, local_feat


# Inference script: load trained weights and extract features for one batch.

# Select devices before building the model or touching CUDA.
# NOTE(review): set_devices is assumed to configure the visible CUDA devices,
# which must happen before CUDA is initialized (e.g. by DataParallel) --
# confirm against its implementation.
TVT, TMO = set_devices((0, ))

model = Model(local_conv_out_channels=128, num_classes=751)
# Model wrapper
model_w = DataParallel(model)

weight = './model_weight.pth'
# Identity map_location: deserialize tensors onto CPU first.
map_location = (lambda storage, loc: storage)
sd = torch.load(weight, map_location=map_location)
load_state_dict(model, sd)
print('Loaded model weights from {}'.format(weight))

# Move the model to the selected device so it matches the inputs that TVT
# transfers (same step as TMO([model]) in prepare()).
TMO([model])

FeatureExtractor = ExtractFeature(model_w, TVT)
# TODO: replace the XXX placeholder with a numpy array batch of images;
# ExtractFeature.__call__ converts it with torch.from_numpy.
global_feat, local_feat = FeatureExtractor(XXX)