示例#1
0
def recursiveType(param, type, tensorCache=None):
    """Recursively convert ``param`` to the tensor type named by ``type``.

    Args:
        param: The object to convert. A list is converted element by element
            in place; a ``Module`` or ``Criterion`` is converted through its
            own ``type`` method; a ``torch.Tensor`` whose ``type()`` differs
            from ``type`` is replaced by a converted tensor. Anything else is
            returned untouched.
        type: Tensor type name as a string. It is compared against
            ``param.type()`` and the matching storage type name is derived
            by replacing ``'Tensor'`` with ``'Storage'``.
        tensorCache: Optional dict mapping the ``_cdata`` of already
            converted tensors and storages to their converted counterparts,
            so objects reached more than once are converted only once. A
            fresh dict is created when omitted.

    Returns:
        ``param`` itself (lists, modules, criteria, tensors that already have
        the requested type, other objects) or the converted tensor.
    """
    from .Criterion import Criterion
    from .Module import Module

    # A ``{}`` default would be created once and shared by every call that
    # omits the argument; since keys are raw ``_cdata`` handles, entries left
    # over from earlier conversions could be handed back for unrelated
    # objects. Start each top-level call with its own cache instead.
    if tensorCache is None:
        tensorCache = {}

    if isinstance(param, list):
        for i, p in enumerate(param):
            param[i] = recursiveType(p, type, tensorCache)
    elif isinstance(param, (Module, Criterion)):
        param.type(type, tensorCache)
    elif isinstance(param, torch.Tensor):
        if param.type() != type:
            key = param._cdata
            if key in tensorCache:
                newparam = tensorCache[key]
            else:
                newparam = torch.Tensor().type(type)
                storageType = type.replace('Tensor', 'Storage')
                param_storage = param.storage()
                if param_storage:
                    storage_key = param_storage._cdata
                    # Convert each source storage only once so tensors that
                    # share a storage before conversion still share one after.
                    if storage_key not in tensorCache:
                        tensorCache[storage_key] = torch._import_dotted_name(
                            storageType)(param_storage.size()).copy_(param_storage)
                    # Rebuild the same view (offset, size, stride) on top of
                    # the converted storage.
                    newparam.set_(
                        tensorCache[storage_key],
                        param.storage_offset(),
                        param.size(),
                        param.stride()
                    )
                tensorCache[key] = newparam
            param = newparam
    return param
示例#2
0
def recursiveType(param, type, tensorCache=None):
    """Recursively convert ``param`` to the tensor type named by ``type``.

    Args:
        param: The object to convert. A list is converted element by element
            in place; a ``Module`` or ``Criterion`` is converted through its
            own ``type`` method; anything accepted by ``torch.is_tensor``
            whose ``type()`` differs from ``type`` is replaced by a converted
            tensor. Anything else is returned untouched.
        type: Tensor type name as a string. It is compared against
            ``param.type()`` and the matching storage type name is derived
            by replacing ``'Tensor'`` with ``'Storage'``.
        tensorCache: Optional dict mapping the ``_cdata`` of already
            converted tensors and storages to their converted counterparts,
            so objects reached more than once are converted only once. A
            fresh dict is created when omitted.

    Returns:
        ``param`` itself (lists, modules, criteria, tensors that already have
        the requested type, other objects) or the converted tensor.
    """
    from .Criterion import Criterion
    from .Module import Module

    # Avoid a shared mutable default: a single ``{}`` default would outlive
    # the call and, being keyed on raw ``_cdata`` handles, could return stale
    # conversions for unrelated objects in later calls.
    if tensorCache is None:
        tensorCache = {}

    if isinstance(param, list):
        for i, p in enumerate(param):
            param[i] = recursiveType(p, type, tensorCache)
    elif isinstance(param, (Module, Criterion)):
        param.type(type, tensorCache)
    elif torch.is_tensor(param):
        if param.type() != type:
            key = param._cdata
            if key in tensorCache:
                newparam = tensorCache[key]
            else:
                newparam = torch.Tensor().type(type)
                storageType = type.replace('Tensor', 'Storage')
                param_storage = param.storage()
                if param_storage:
                    storage_key = param_storage._cdata
                    # Convert each source storage only once so tensors that
                    # share a storage before conversion still share one after.
                    if storage_key not in tensorCache:
                        tensorCache[storage_key] = torch._import_dotted_name(
                            storageType)(
                                param_storage.size()).copy_(param_storage)
                    # Rebuild the same view (offset, size, stride) on top of
                    # the converted storage.
                    newparam.set_(tensorCache[storage_key],
                                  param.storage_offset(), param.size(),
                                  param.stride())
                tensorCache[key] = newparam
            param = newparam
    return param
示例#3
0
 def _get_type(self, name):
     module = torch._import_dotted_name(self.data.__module__)
     return getattr(module, name)
示例#4
0
 def _get_type(self, name):
     module = torch._import_dotted_name(self.data.__module__)
     return getattr(module, name)