示例#1
0
    def __init__(self, config: BatchProcessorConfigType, *args, **kwargs):
        """Build every processor listed under ``config["processors"]``.

        Args:
            config: Batch-processor configuration. Its optional
                ``"processors"`` entry (an empty dict when absent) is passed
                to ``build_processors`` together with the MMF ``data_dir``
                environment value as the ``data_dir`` keyword.
            *args, **kwargs: Accepted for signature compatibility; unused here.
        """
        data_dir = get_mmf_env(key="data_dir")
        processor_configs = config.get("processors", {})

        # Deferred import: build_processors' module itself imports processor
        # modules, so importing it at module load time would be circular.
        from mmf.utils.build import build_processors

        self.processors = build_processors(processor_configs, data_dir=data_dir)
示例#2
0
    def _build_model(self):
        """Load the pretrained checkpoint and assemble its runtime pieces.

        Side effects: stores the loaded items in ``self.model_items`` and the
        full configuration (converted with ``OmegaConf.create``) in
        ``self.config``.

        Returns:
            tuple: ``(processor, feature_extractor, model)``. ``processor`` is
            built from the processors of the first entry in
            ``dataset_config``. ``feature_extractor`` is built from the model
            config's ``image_feature_encodings``. ``model`` is the built model
            with the checkpoint's state dict loaded.
        """
        self.model_items = load_pretrained_model(self.checkpoint)
        self.config = OmegaConf.create(self.model_items["full_config"])

        # Only the first configured dataset contributes processors.
        datasets = self.config.dataset_config
        first_dataset = list(datasets.keys())[0]
        processor = build_processors(datasets[first_dataset].processors)

        model_config = self.model_items["config"]
        feature_extractor = build_encoder(model_config.image_feature_encodings)

        checkpoint_state = self.model_items["checkpoint"]
        model = build_model(model_config)
        model.load_state_dict(checkpoint_state)

        return processor, feature_extractor, model
示例#3
0
    def init_processors(self):
        """Build this dataset's processors and expose them.

        Does nothing when ``self.config`` has no ``processors`` attribute.
        Otherwise each built processor is:

        - set as an attribute of ``self`` named after its key, and
        - registered in ``registry`` under ``"<dataset_name>_<key>"``.
        """
        if not hasattr(self.config, "processors"):
            return

        # Imported at call time (presumably to avoid a circular import with
        # mmf.utils.build) — TODO confirm.
        from mmf.utils.build import build_processors

        data_dir = self.config.data_dir
        # "<dataset_name>_{}" — the literal "{}" is filled with each processor key.
        key_template = f"{self._dataset_name}_{{}}"
        built = build_processors(self.config.processors, key_template,
                                 data_dir=data_dir)
        for name, processor in built.items():
            setattr(self, name, processor)
            registry.register(key_template.format(name), processor)
示例#4
0
文件: handler.py 项目: nskool/serve
    def initialize(self, ctx):
        """Load labels, config and MMF model weights for this TorchServe handler.

        Args:
            ctx: TorchServe context. ``ctx.manifest['model']['serializedFile']``
                names the weights file. ``ctx.system_properties`` supplies
                ``model_dir`` (where the ``.mar`` archive was extracted) and
                ``gpu_id``.

        Sets ``manifest``, ``map_location``, ``device``, ``class_to_idx``,
        ``classes``, ``labels``, ``idx_to_class``, ``model``, ``processor``
        and ``initialized`` on ``self``.
        """
        import logging

        logger = logging.getLogger(__name__)

        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)

        if torch.cuda.is_available():
            self.map_location = "cuda"
            self.device = torch.device(f"cuda:{properties.get('gpu_id')}")
        else:
            self.map_location = "cpu"
            self.device = torch.device("cpu")

        # Build the class/index mapping from every label that appears in the
        # dataset CSV, so the model's outputs match the dataset's label count.
        # Each cell is parsed with ast.literal_eval after stripping double quotes.
        df = pd.read_csv(os.path.join(model_dir, "charades_action_lables.csv"))
        raw_cells = df["action_labels"].str.replace('"', "")
        labels = [ast.literal_eval(cell) for cell in raw_cells]
        classes = sorted({item for row in labels for item in row})

        self.class_to_idx = {name: idx for idx, name in enumerate(classes)}
        self.classes = classes
        self.labels = labels
        self.idx_to_class = classes

        config = OmegaConf.load(os.path.join(model_dir, "config.yaml"))
        logger.info("Loaded config with keys: %s", list(config.keys()))
        setup_very_basic_config()
        setup_imports()
        self.model = MMFTransformer(config.model_config.mmf_transformer)
        self.model.build()
        self.model.init_losses()
        self.processor = build_processors(
            config.dataset_config["charades"].processors)

        # Load weights from the extracted model directory. The bare file name
        # only worked when the process CWD happened to be model_dir.
        state_dict = torch.load(model_pt_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        self.initialized = True
        logger.info("Files extracted from the .mar archive into %s: %s",
                    model_dir, os.listdir(model_dir))
示例#5
0
 def init_processors(self):
     """Build the hateful_memes processors into ``self.processor_dict``.

     The dataset's ``data_dir`` is passed to ``build_processors`` as the
     ``data_dir`` keyword argument.
     """
     memes_config = self.config.dataset_config.hateful_memes
     data_dir = memes_config.data_dir
     self.processor_dict = build_processors(memes_config.processors,
                                            data_dir=data_dir)