示例#1
0
    def checkExportImport(self, model, input, check_value=True):
        """Round-trip ``model`` through graph IR and generated code, then compare outputs.

        The model is converted with ``self._convert_model``, rendered by
        ``model_to_pytorch_script`` (the code is printed), executed to build an
        instance of the generated ``_model`` class, and loaded with ``model``'s
        state dict through ``original_state_dict_hooks``. Both models are then
        run in eval mode under ``torch.no_grad()`` on ``*input``.

        Args:
            model: the original module under test.
            input: tuple of positional arguments passed to both models.
            check_value: when True, assert that the outputs have the same
                length and that paired elements match. Bool tensors must be
                exactly equal. Ints must have zero difference. Floats and other
                tensors must differ by less than 1e-4 in absolute value.

        Returns:
            The converted model instance.
        """
        model_ir = self._convert_model(model, input)
        model_code = model_to_pytorch_script(model_ir)
        print(model_code)

        exec_vars = {}
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']

        with original_state_dict_hooks(converted_model):
            converted_model.load_state_dict(model.state_dict())

        with torch.no_grad():
            expected_output = model.eval()(*input)
            converted_output = converted_model.eval()(*input)
        if check_value:
            self.assertEqual(len(converted_output), len(expected_output))
            for a, b in zip(converted_output, expected_output):
                if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
                    # `a - b` and `a ^ b` broadcast, which would hide a shape mismatch.
                    self.assertEqual(a.shape, b.shape)
                if hasattr(a, 'dtype') and a.dtype == torch.bool:
                    # `a ^ b` is True wherever the two disagree. Reduce with any()
                    # so multi-element tensors don't raise the ambiguous-truth error.
                    self.assertFalse(bool((a ^ b).any()))
                    continue
                diff = a - b
                if isinstance(diff, int):
                    self.assertEqual(diff, 0)
                elif isinstance(diff, float):
                    # Plain Python floats have no `.abs()` method.
                    self.assertLess(abs(diff), 1E-4)
                elif diff.numel() > 0:
                    # `max()` of an empty tensor raises, so empty outputs are skipped.
                    self.assertLess(diff.abs().max().item(), 1E-4)
        return converted_model
示例#2
0
    def run_test(self, model, input, check_value=True, strict_load=True):
        """Convert ``model`` to code, rebuild it, reload its weights and compare outputs.

        The model is converted with ``self._convert_model`` and rendered by
        ``model_to_pytorch_script``. ``remove_inject_pytorch_nn()`` is then
        called (presumably restoring the stock ``torch.nn`` modules; confirm
        in ``inject_nn``) before the generated code is executed to build an
        instance of ``_model``. ``model``'s state dict is loaded into it
        through ``original_state_dict_hooks``, and both models are run in eval
        mode under ``torch.no_grad()`` on ``*input``.

        Args:
            model: the original module under test.
            input: tuple of positional arguments passed to both models.
            check_value: when True, assert that the outputs match recursively.
                Floating tensors must agree within an absolute tolerance of
                1e-4 (NaNs compare equal). Other tensors must be exactly equal.
                Tuples, lists and dicts are compared element by element.
                Everything else is compared with ``assertEqual``.
            strict_load: forwarded as ``strict`` to ``load_state_dict``.

        Returns:
            The converted model instance.
        """
        model_ir = self._convert_model(model, input)
        model_code = model_to_pytorch_script(model_ir)

        from .inject_nn import remove_inject_pytorch_nn
        remove_inject_pytorch_nn()

        exec_vars = {}
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']

        with original_state_dict_hooks(converted_model):
            converted_model.load_state_dict(model.state_dict(),
                                            strict=strict_load)

        with torch.no_grad():
            expected_output = model.eval()(*input)
            converted_output = converted_model.eval()(*input)

        def assert_outputs_match(actual, expected):
            # Recursively compare (possibly nested) model outputs. Every branch
            # asserts, unlike a bare `torch.eq(...)` whose result would be ignored.
            if isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor):
                self.assertEqual(actual.shape, expected.shape)
                self.assertEqual(actual.dtype, expected.dtype)
                if actual.is_floating_point():
                    self.assertTrue(
                        torch.allclose(actual, expected, rtol=0, atol=1E-4,
                                       equal_nan=True),
                        'floating tensor outputs differ beyond atol=1e-4')
                else:
                    self.assertTrue(torch.equal(actual, expected),
                                    'tensor outputs differ')
            elif (isinstance(actual, (tuple, list))
                  and isinstance(expected, (tuple, list))):
                self.assertEqual(len(actual), len(expected))
                for a, b in zip(actual, expected):
                    assert_outputs_match(a, b)
            elif isinstance(actual, dict) and isinstance(expected, dict):
                self.assertEqual(set(actual), set(expected))
                for key in expected:
                    assert_outputs_match(actual[key], expected[key])
            else:
                self.assertEqual(actual, expected)

        if check_value:
            assert_outputs_match(converted_output, expected_output)
        return converted_model
示例#3
0
def evaluate_acc(class_cls, criterion, args):
    """Evaluate a checkpoint before and after ``retrain_bn`` and report to NNI.

    A fresh model is built with ``class_cls()``. The weights from
    ``args.checkpoint`` are loaded non-strictly through
    ``original_state_dict_hooks``, and the model is moved to CUDA. Validation
    accuracy from ``test_acc`` is reported as an intermediate NNI result.
    ``retrain_bn`` is then run with ``args.train_iters`` on the ImageNet
    training loader. Finally the validation accuracy is measured again and
    reported as both an intermediate and the final NNI result.

    Args:
        class_cls: zero-argument callable that builds the model.
        criterion: loss passed through to ``test_acc`` and ``retrain_bn``.
        args: namespace read for ``checkpoint``, ``spos_preprocessing``,
            ``imagenet_dir``, ``train_batch_size``, ``test_batch_size``,
            ``workers``, ``log_frequency`` and ``train_iters``.
    """
    model = class_cls()
    with original_state_dict_hooks(model):
        model.load_state_dict(load_and_parse_state_dict(args.checkpoint),
                              strict=False)
    model.cuda()

    # The train and validation pipelines share a single tensor conversion.
    # The batches seen by retrain_bn (train_loader) and by test_acc
    # (test_loader) must be encoded the same way.
    if args.spos_preprocessing:
        to_tensor = ToBGRTensor
        train_trans = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.ColorJitter(brightness=0.4,
                                   contrast=0.4,
                                   saturation=0.4),
            transforms.RandomHorizontalFlip(0.5),
            to_tensor()
        ])
    else:
        to_tensor = transforms.ToTensor
        train_trans = transforms.Compose(
            [transforms.RandomResizedCrop(224),
             to_tensor()])
    # Use a deterministic resize + center crop for validation. A random crop
    # here would make the reported accuracy vary from run to run.
    val_trans = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         to_tensor()])
    train_dataset = datasets.ImageNet(args.imagenet_dir,
                                      split='train',
                                      transform=train_trans)
    val_dataset = datasets.ImageNet(args.imagenet_dir,
                                    split='val',
                                    transform=val_trans)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        num_workers=args.workers,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(val_dataset,
                                              batch_size=args.test_batch_size,
                                              num_workers=args.workers,
                                              shuffle=True)

    acc_before = test_acc(model, criterion, args.log_frequency, test_loader)
    nni.report_intermediate_result(acc_before)

    retrain_bn(model, criterion, args.train_iters, args.log_frequency,
               train_loader)
    acc = test_acc(model, criterion, args.log_frequency, test_loader)
    assert isinstance(acc, float)
    nni.report_intermediate_result(acc)
    nni.report_final_result(acc)
示例#4
0
    def checkExportImport(self, model, input):
        """Rebuild ``model`` from its generated code and check the outputs agree.

        The model is converted with ``self._convert_model`` and rendered by
        ``model_to_pytorch_script``. The code is then executed to instantiate
        ``_model``, and a plain-dict copy of ``model``'s state dict is loaded
        through ``original_state_dict_hooks``. Both models are run in eval mode
        without gradients. Outputs must have the same length, and each paired
        element must differ by less than 1e-4 in max absolute value.

        Returns:
            The converted model instance.
        """
        generated_code = model_to_pytorch_script(self._convert_model(model, input))

        namespace = {}
        exec(generated_code + '\n\nconverted_model = _model()', namespace)
        converted_model = namespace['converted_model']
        with original_state_dict_hooks(converted_model):
            converted_model.load_state_dict(dict(model.state_dict()))

        with torch.no_grad():
            reference = model.eval()(*input)
            candidate = converted_model.eval()(*input)

        self.assertEqual(len(candidate), len(reference))
        for got, want in zip(candidate, reference):
            max_abs_diff = (got - want).abs().max().item()
            self.assertLess(max_abs_diff, 1E-4)
        return converted_model
示例#5
0
    def checkExportImport(self, model, input, check_value=True):
        """Round-trip ``model`` through graph IR and generated code, then compare outputs.

        The model is converted with ``self._convert_model`` and rendered by
        ``model_to_pytorch_script``. The code is executed to instantiate the
        generated ``_model`` class, and ``model``'s state dict is loaded
        through ``original_state_dict_hooks``. Both models are run in eval
        mode under ``torch.no_grad()`` on ``*input``.

        Args:
            model: the original module under test.
            input: tuple of positional arguments passed to both models.
            check_value: when True, assert that the outputs match recursively.
                Floating tensors must agree within an absolute tolerance of
                1e-4 (NaNs compare equal). Other tensors must be exactly equal.
                Tuples, lists and dicts are compared element by element.
                Everything else is compared with ``assertEqual``.

        Returns:
            The converted model instance.
        """
        model_ir = self._convert_model(model, input)
        model_code = model_to_pytorch_script(model_ir)

        exec_vars = {}
        exec(model_code + '\n\nconverted_model = _model()', exec_vars)
        converted_model = exec_vars['converted_model']

        with original_state_dict_hooks(converted_model):
            converted_model.load_state_dict(model.state_dict())

        with torch.no_grad():
            expected_output = model.eval()(*input)
            converted_output = converted_model.eval()(*input)

        def assert_outputs_match(actual, expected):
            # Recursively compare (possibly nested) model outputs. Every branch
            # asserts, unlike a bare `torch.eq(...)` whose result would be ignored.
            if isinstance(actual, torch.Tensor) and isinstance(expected, torch.Tensor):
                self.assertEqual(actual.shape, expected.shape)
                self.assertEqual(actual.dtype, expected.dtype)
                if actual.is_floating_point():
                    self.assertTrue(
                        torch.allclose(actual, expected, rtol=0, atol=1E-4,
                                       equal_nan=True),
                        'floating tensor outputs differ beyond atol=1e-4')
                else:
                    self.assertTrue(torch.equal(actual, expected),
                                    'tensor outputs differ')
            elif (isinstance(actual, (tuple, list))
                  and isinstance(expected, (tuple, list))):
                self.assertEqual(len(actual), len(expected))
                for a, b in zip(actual, expected):
                    assert_outputs_match(a, b)
            elif isinstance(actual, dict) and isinstance(expected, dict):
                self.assertEqual(set(actual), set(expected))
                for key in expected:
                    assert_outputs_match(actual[key], expected[key])
            else:
                self.assertEqual(actual, expected)

        if check_value:
            assert_outputs_match(converted_output, expected_output)
        return converted_model