Example #1
  def __init__(self, app_path):
    # The dotted app path doubles as the app's namespace and name.
    self.namespace = app_path
    self.name = app_path
    self.verbose_name = util.humanize(self.name)

    # Import the app package and its optional urls/models/views submodules.
    self._module = util.import_from_string(app_path)
    self._urls_module = util.import_from_string(app_path + '.urls', True)
    self._models_module = util.import_from_string(app_path + '.models', True)
    self._views_module = util.import_from_string(app_path + '.views', True)

    self.models = RozModel.get_models(self, self._models_module)
    self.routes = RozEngine.Engine.app_routes

    # Use the app's own RozMeta if it defines one (with defaults applied via
    # RozMetaDefault.set_defaults); otherwise fall back to RozMetaDefault.
    if hasattr(self._module, 'RozMeta'):
      meta = getattr(self._module, 'RozMeta')
      RozMetaDefault.set_defaults(meta)
    else:
      meta = RozMetaDefault

    self._module.RozMeta = meta

    # Copy the resolved meta options onto this app instance.
    for key, value in vars(meta).iteritems():
      util.safe_setattr(self, key, value)

    self.register_views()
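
The constructor above leans on util.import_from_string to pull in the app package and its optional urls/models/views submodules. The helper itself is not shown; a minimal sketch, assuming its second argument means "return None instead of raising when the module is missing", might look like this:

import importlib

def import_from_string(dotted_path, ignore_missing=False):
    # Import a module by its dotted path, e.g. 'blog.models'.
    # With ignore_missing=True a missing module yields None instead of an
    # ImportError, so optional app submodules can simply be absent.
    try:
        return importlib.import_module(dotted_path)
    except ImportError:
        if ignore_missing:
            return None
        raise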
Example #2
  def get_parsed_value(self, field_name, field_value, field_class, field=None):
    # Convert the raw request value into the Python value expected by the
    # datastore property class named by field_class.
    value = None

    # Simple scalar properties are handled by a lookup table of parsers.
    if field_class in Edit.FieldMapping:
      value = Edit.FieldMapping[field_class](field_value)
    
    elif field_class == "ReferenceProperty":
      path = (field.reference_class.__module__ + "." + field.reference_class.__name__).split('.')
      
      ref_app_name = ".".join(path[:-2])
      ref_model_name = path[-1]
      ref_app = self.rozengine.apps[ref_app_name]
      ref_model = ref_app.models[ref_model_name]
      value = ref_model.definition.get_by_id(int(field_value))
    
    elif field_class == "ListProperty" or field_class == "StringListProperty":
      value = self.request.POST.getall(field_name)
      
      if field.item_type.__name__ == 'Key':
        # Key lists arrive as parallel POST fields: the value carries the
        # referenced model's path (underscore-separated), '<field>_ref' its id.
        result = []
        value_ref = self.request.POST.getall(field_name+"_ref")
        for i in range(len(value)):
          v = value[i]
          r = value_ref[i]  # TODO: don't rely on order
          path = v.split('_')
          # TODO: move this into RozModel and resolve the model via its namespace
          app = util.import_from_string(".".join(path[:-1]) + '.models', True)
          model = getattr(app, path[-1])
          ref = model.get_by_id(int(r))
          result.append(ref.key())
          
        value = result
      else:
        class_name = Edit.ListFieldMapping[field.item_type.__name__]
        value = [self.get_parsed_value(field_name, v, class_name) for v in value if len(v) > 0]
      
    elif field_class == "UserProperty":
      # Currently only the signed-in user can be selected.
      if field_value == users.get_current_user().user_id():
        value = users.get_current_user()
      else:
        value = None
        
    else:
      value = field_value

    return value
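
The first branch dispatches on Edit.FieldMapping, which is not shown in this example. A rough, hypothetical sketch, assuming it maps App Engine property class names to callables that parse the raw POST string (the real table may differ):

import datetime

# Hypothetical Edit.FieldMapping: property class name -> parser callable.
# All entries here are assumptions for illustration only.
FieldMapping = {
    'StringProperty': unicode,
    'TextProperty': unicode,
    'IntegerProperty': int,
    'FloatProperty': float,
    'BooleanProperty': lambda v: v == 'on',  # HTML checkboxes post 'on'
    'DateTimeProperty': lambda v: datetime.datetime.strptime(v, '%Y-%m-%d %H:%M'),
}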
Example #3
    def __init__(self,
                 model_name_or_path: str = None,
                 modules: Iterable[nn.Module] = None,
                 device: str = None):
        if modules is not None and not isinstance(modules, OrderedDict):
            modules = OrderedDict([(str(idx), module)
                                   for idx, module in enumerate(modules)])

        if model_name_or_path is not None and model_name_or_path != "":
            logging.info("Load pretrained SentenceTransformer: {}".format(
                model_name_or_path))

            if '/' not in model_name_or_path and '\\' not in model_name_or_path and not os.path.isdir(
                    model_name_or_path):
                logging.info(
                    "Did not find a / or \\ in the name. Assume to download model from server"
                )
                model_name_or_path = __DOWNLOAD_SERVER__ + model_name_or_path + '.zip'

            if model_name_or_path.startswith(
                    'http://') or model_name_or_path.startswith('https://'):
                model_url = model_name_or_path
                folder_name = model_url.replace("https://", "").replace(
                    "http://", "").replace("/", "_")[:250]

                try:
                    from torch.hub import _get_torch_home
                    torch_cache_home = _get_torch_home()
                except ImportError:
                    torch_cache_home = os.path.expanduser(
                        os.getenv(
                            'TORCH_HOME',
                            os.path.join(
                                os.getenv('XDG_CACHE_HOME', '~/.cache'),
                                'torch')))
                default_cache_path = os.path.join(torch_cache_home,
                                                  'sentence_transformers')
                model_path = os.path.join(default_cache_path, folder_name)
                os.makedirs(model_path, exist_ok=True)

                if not os.listdir(model_path):
                    if model_url[-1] == "/":
                        model_url = model_url[:-1]
                    logging.info(
                        "Downloading sentence transformer model from {} and saving it at {}"
                        .format(model_url, model_path))
                    try:
                        zip_save_path = os.path.join(model_path, 'model.zip')
                        http_get(model_url, zip_save_path)
                        with ZipFile(zip_save_path, 'r') as zip_file:
                            zip_file.extractall(model_path)
                    except Exception as e:
                        shutil.rmtree(model_path)
                        raise e
            else:
                model_path = model_name_or_path

            #### Load from disk
            if model_path is not None:
                logging.info("Load SentenceTransformer from folder: {}".format(
                    model_path))
                with open(os.path.join(model_path, 'modules.json')) as fIn:
                    contained_modules = json.load(fIn)

                modules = OrderedDict()
                for module_config in contained_modules:
                    module_class = import_from_string(module_config['type'])
                    module = module_class.load(
                        os.path.join(model_path, module_config['path']))
                    modules[module_config['name']] = module

        super().__init__(modules)
        if device is None:
            # Fall back to CUDA when available, otherwise CPU.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info("Use pytorch device: {}".format(device))
        self.device = torch.device(device)
        self.to(device)
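
This constructor appears to follow the sentence-transformers SentenceTransformer class (the download fallback, cache layout, and modules.json loading match that library). Assuming that is the case, typical usage is a one-line load followed by encode:

from sentence_transformers import SentenceTransformer

# A bare model name (no '/' or '\') is resolved against the download server;
# a local path or URL is used directly.
model = SentenceTransformer('bert-base-nli-mean-tokens')

sentences = ['This framework generates embeddings for each input sentence.',
             'Sentences are passed as a list of strings.']
embeddings = model.encode(sentences)  # one embedding vector per sentence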