def forward(self, x, targets=None):
    """Decode a raw YOLO head feature map into scaled box predictions.

    Args:
        x: Head output. It is viewed as
            ``(num_samples, num_anchors, num_classes + 5, grid, grid)``, so it
            presumably arrives as ``(N, num_anchors * (num_classes + 5), grid, grid)``.
        targets: Not used by this method.

    Returns:
        Tensor of shape ``(num_samples, num_anchors * grid * grid, 5 + num_classes)``.
        The columns are ``[cx, cy, w, h]`` multiplied by ``self.stride``, the
        sigmoid objectness, and the per-class sigmoid scores.
    """
    num_samples = x.size(0)
    grid_size = x.size(2)

    # clone() gives a contiguous copy, so view() is always legal here.
    prediction = x.clone().view(num_samples, self.num_anchors, self.num_classes + 5, grid_size, grid_size).permute(0, 1, 3, 4, 2).contiguous()

    # Split the raw outputs.
    xy = sigmoid(prediction[..., 0:2])  # centre offset inside each cell
    wh = prediction[..., 2:4]  # width/height, later passed through exp()
    pred_conf = sigmoid(prediction[..., 4])  # objectness
    pred_cls = sigmoid(prediction[..., 5:])  # independent per-class scores

    # Add the cell offsets and scale the sizes with the anchors.
    pred_boxes = zeros_like(prediction[..., :4])
    pred_boxes[..., 0:2] = xy + self.grid.to(get_device())
    pred_boxes[..., 2:4] = exp(wh) * self.anchor_wh.to(get_device())

    output = torch.cat(
        (
            pred_boxes.view(num_samples, -1, 4) * self.stride,
            pred_conf.view(num_samples, -1, 1),
            pred_cls.view(num_samples, -1, self.num_classes),
        ),
        -1,
    )
    return output
def clear_state(self):
    """Reset the recurrent hidden and cell states to zeros.

    Each new state copies the shape of the current one. It is built with
    ``dtype.float32`` and ``requires_grad=False``, then moved to ``get_device()``.
    """
    def _fresh(state):
        return zeros_like(state, dtype=dtype.float32, requires_grad=False).to(get_device())

    self.hidden_state = _fresh(self.hidden_state)
    self.cell_state = _fresh(self.cell_state)
def compute_grid_offsets(self, grid_size):
    """(Re)build the cached cell offsets and anchor sizes for a square grid.

    Sets ``self.grid_size``, ``self.stride``, ``self.anchor_vec``,
    ``self.anchor_wh``, ``self.grid`` and ``self.grid1``.

    Args:
        grid_size: Number of cells along each side of the feature map.
    """
    # Remember which size the cached tensors were built for. A forward pass
    # that compares ``self.grid_size`` with the incoming map size would
    # otherwise see a stale value and rebuild the grid on every call.
    self.grid_size = grid_size
    self.stride = self.img_dim / grid_size
    self.anchor_vec = self.anchors / self.stride
    self.anchor_wh = self.anchor_vec.view(1, self.num_anchors, 1, 1, 2)
    yv, xv = torch.meshgrid([
        torch.arange(grid_size, device=get_device()),
        torch.arange(grid_size, device=get_device())
    ])
    # (1, 1, grid, grid, 2) tensor whose last axis is (x, y) = (column, row).
    self.grid = torch.stack((xv, yv), 2).view((1, 1, grid_size, grid_size, 2)).float()
    self.grid1 = meshgrid(grid_size, grid_size, requires_grad=False).view([1, 1, grid_size, grid_size, 2])
def attention(self, lstm_output):
    """Re-weight each time step of ``lstm_output`` with additive attention.

    The attention parameters ``w_omega`` (channels x attention_size) and
    ``u_omega`` (attention_size) are created lazily on the first call.

    Args:
        lstm_output: Tensor of shape ``(batch, seq_len, channels)``.

    Returns:
        ``lstm_output`` multiplied element-wise by weights of shape
        ``(batch, seq_len, 1)``. The weights sum to one over the sequence axis.
        The weighted steps are not summed.
    """
    _batch_size, sequence_length, channels = int_shape(lstm_output)
    if not hasattr(self, 'w_omega') or self.w_omega is None:
        # Random (not zero) initialisation. With both tensors at zero,
        # tanh(x @ 0) == 0 and each tensor's gradient is proportional to the
        # other, so training could never move them and the weights would stay
        # uniform forever.
        self.w_omega = Parameter(
            torch.empty(channels, self.attention_size).normal_(0.0, 0.1).to(get_device()))
        self.u_omega = Parameter(
            torch.empty(self.attention_size).normal_(0.0, 0.1).to(get_device()))
    output_reshape = reshape(lstm_output, (-1, channels))
    attn_tanh = torch.tanh(torch.mm(output_reshape, self.w_omega))
    attn_hidden_layer = torch.mm(attn_tanh, reshape(self.u_omega, [-1, 1]))
    # softmax equals exp / sum(exp), but it subtracts the row maximum first,
    # so large scores cannot overflow to inf/nan.
    alphas = torch.softmax(reshape(attn_hidden_layer, [-1, sequence_length]), dim=1)
    alphas_reshape = reshape(alphas, [-1, sequence_length, 1])
    return lstm_output * alphas_reshape
def load(cls):
    """Download and restore the pretrained Chinese word2vec model.

    If the restored model has no ``tw2cn`` table, ``tw2cn``/``cn2tw`` are
    built by pairing each line of ``vocabs_tw.txt`` with the model's
    ``_vocabs`` keys by position. Only the first occurrence of a line is kept.

    Returns:
        The restored model, moved to ``get_device()``.
    """
    # Load the model from Google Drive.
    st = datetime.datetime.now()
    download_model_from_google_drive('13XZPWh8QhEsC8EdIp1niLtZz0ipatSGC', dirname, 'word2vec_chinese.pth')
    recovery_model = fix_layer(load(os.path.join(dirname, 'word2vec_chinese.pth')))
    recovery_model.locale = locale.getdefaultlocale()[0].lower()
    recovery_model.to(get_device())

    vocab_dir = os.path.join(_trident_dir, 'download')
    download_file_from_google_drive(file_id='16yDlJJ4-O9pHF-ZbXy7XPZZk6vo3aw4e', dirname=vocab_dir, filename='vocabs_tw.txt')
    # Read the vocabulary from where it was just downloaded. `download_path`
    # was previously used below without ever being assigned (NameError).
    download_path = os.path.join(vocab_dir, 'vocabs_tw.txt')
    if not hasattr(recovery_model, 'tw2cn') or recovery_model.tw2cn is None:
        with open(download_path, 'r', encoding='utf-8-sig') as f:
            vocabs_tw = [s.replace('\n', '') for s in f.readlines() if s != '\n']
        recovery_model.tw2cn = OrderedDict()
        recovery_model.cn2tw = OrderedDict()
        for w, w_cn in tqdm(zip(vocabs_tw, recovery_model._vocabs.keys())):
            if w not in recovery_model.tw2cn:
                recovery_model.tw2cn[w] = w_cn
                recovery_model.cn2tw[w_cn] = w
    et = datetime.datetime.now()
    print('total loading time:{0}'.format(et - st))
    return recovery_model
def build(self, input_shape):
    """Create the affine parameters and running-statistic buffers on first call.

    Behaviour:

    - With ``self.affine``, ``weight`` (ones) and ``bias`` (zeros) of length
      ``self.input_filters`` are created. Otherwise both are registered as
      ``None``.
    - With ``self.track_running_stats``, the buffers ``running_mean`` (zeros),
      ``running_var`` (ones) and ``num_batches_tracked`` (0) are created.
      Otherwise those names are registered as ``None``.
    - Afterwards it calls ``reset_running_stats()``, moves the layer to
      ``get_device()`` and marks it as built.

    Args:
        input_shape: Not read here; sizes come from ``self.input_filters``.
    """
    # Truthiness test instead of ``== False`` (PEP 8). Any falsy ``_built``
    # means the layer still has to be built.
    if self._built:
        return
    if self.affine:
        self.weight = Parameter(torch.Tensor(self.input_filters))
        self.bias = Parameter(torch.Tensor(self.input_filters))
        init.ones_(self.weight)
        init.zeros_(self.bias)
    else:
        self.register_parameter('weight', None)
        self.register_parameter('bias', None)
    if self.track_running_stats:
        self.register_buffer('running_mean', torch.zeros(self.input_filters))
        self.register_buffer('running_var', torch.ones(self.input_filters))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
    else:
        self.register_parameter('running_mean', None)
        self.register_parameter('running_var', None)
        self.register_parameter('num_batches_tracked', None)
    self.reset_running_stats()
    self.to(get_device())
    self._built = True
def load(cls):
    """Rebuild the pretrained Chinese word2vec embedding as a new ``cls`` instance.

    The stored weight matrix is used as ``_weight``. The ``_vocabs``,
    ``tw2cn`` and ``cn2tw`` tables are deep-copied from the checkpoint.

    NOTE(review): this calls ``set_device('cpu')`` first, a global change that
    presumably makes the final ``get_device()`` (and later callers) resolve to
    the CPU. Confirm this is intended.

    Returns:
        The new embedding instance.
    """
    # Load the model from Google Drive.
    started = datetime.datetime.now()
    set_device('cpu')
    model_dir = os.path.join(get_trident_dir(), 'models')
    download_model_from_google_drive('13XZPWh8QhEsC8EdIp1niLtZz0ipatSGC', model_dir, 'word2vec_chinese.pth')
    source = load(os.path.join(model_dir, 'word2vec_chinese.pth'))
    weight = source.state_dict()['weight']
    shape = int_shape(weight)
    embedding = cls(pretrained=True, num_embeddings=shape[0], embedding_dim=shape[-1], _weight=weight, name='word2vec_chinese')
    for attr in ('_vocabs', 'tw2cn', 'cn2tw'):
        setattr(embedding, attr, copy.deepcopy(getattr(source, attr)))
    del source
    embedding.locale = ctx.locale
    embedding.to(get_device())
    finished = datetime.datetime.now()
    print('total loading time:{0}'.format(finished - started))
    return embedding
def __init__(self, anchors=None, num_classes=80, grid_size=76, img_dim=608, small_item_enhance=False):
    """Set up a YOLO detection head.

    Args:
        anchors: Anchor sizes. When ``None`` they come from
            ``generate_anchors(grid_size)``. Stored as the ``anchors`` buffer
            on ``get_device()``.
        num_classes: Number of object classes.
        grid_size: Initial feature-map side length.
        img_dim: Input image side length; ``stride = img_dim / grid_size``.
        small_item_enhance: Flag stored on the layer.
    """
    super(YoloLayer, self).__init__()
    anchors = generate_anchors(grid_size) if anchors is None else anchors
    self.register_buffer('anchors', to_tensor(anchors, requires_grad=False).to(get_device()))
    self.small_item_enhance = small_item_enhance
    self.num_anchors = len(anchors)
    self.num_classes = num_classes
    # Thresholds and weights kept on the layer (not read in this method).
    self.ignore_thres = 0.5
    self.obj_scale = 1
    self.noobj_scale = 100
    self.metrics = {}
    self.img_dim = img_dim
    self.grid_size = grid_size
    self.stride = self.img_dim / grid_size
    self.compute_grid_offsets(grid_size)
def build(self, input_shape: TensorShape):
    """Register affine parameters and running-statistic buffers once.

    Names are registered with ``None`` values when ``affine`` or
    ``track_running_stats`` is off. ``num_batches_tracked`` is a
    non-persistent buffer, so it is not saved in the state dict. Afterwards
    the layer moves to ``get_device()`` and is marked as built.
    """
    if self._built:
        return
    n = self.input_filters
    if self.affine:
        weight, bias = Parameter(ones(n)), Parameter(zeros(n))
    else:
        weight = bias = None
    self.register_parameter('weight', weight)
    self.register_parameter('bias', bias)
    if self.track_running_stats:
        running_mean, running_var = zeros(n), ones(n)
        batches_seen = to_tensor(0, dtype=torch.long)
    else:
        running_mean = running_var = batches_seen = None
    self.register_buffer('running_mean', running_mean)
    self.register_buffer('running_var', running_var)
    self.register_buffer('num_batches_tracked', batches_seen, persistent=False)
    self.to(get_device())
    self._built = True
def __init__(self, d_model, max_len=512):
    """Precompute sinusoidal positional encodings into the buffer ``pe``.

    ``pe`` has shape ``(1, max_len, d_model)``. Even columns hold
    ``sin(pos * w_k)`` and odd columns hold ``cos(pos * w_k)``, where
    ``w_k = 10000 ** (-2k / d_model)``.

    Args:
        d_model: Encoding width. Odd values are supported.
        max_len: Number of positions to precompute.
    """
    super().__init__()
    # Compute the positional encodings once in log space.
    # (The former ``pe.require_grad = False`` was a typo that only created an
    # unused attribute. torch.zeros already has requires_grad=False, and a
    # registered buffer is never a trainable parameter.)
    pe = torch.zeros(max_len, d_model).float().to(get_device())
    position = torch.arange(0, max_len).float().unsqueeze(1).to(get_device())
    div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp().to(get_device())
    pe[:, 0::2] = torch.sin(position * div_term)
    # An odd d_model has one fewer odd column than even column, so trim div_term.
    pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
    pe = pe.unsqueeze(0)
    self.register_buffer('pe', pe)
def build(self, input_shape: TensorShape):
    """Create the RNN weight (and optional bias) parameters on first call.

    For every layer and direction it registers ``weight_ih_l{layer}{suffix}``
    and ``weight_hh_l{layer}{suffix}``, plus ``bias_ih_*``/``bias_hh_*`` when
    ``self.use_bias`` is set. The suffix is ``'_reverse'`` for direction 1.
    It then refreshes ``_flat_weights``, calls ``flatten_parameters()`` and
    ``reset_parameters()`` (which presumably fills the uninitialised tensors),
    and marks the layer as built.

    Args:
        input_shape: Its last entry is the input feature size of layer 0.
    """
    if self._built:
        return
    for layer in range(self.num_layers):
        for direction in range(self.num_directions):
            layer_input_size = input_shape[-1] if layer == 0 else self.hidden_size * self.num_directions
            layer_params = [
                Parameter(torch.Tensor(self.gate_size, layer_input_size).to(get_device())),
                Parameter(torch.Tensor(self.gate_size, self.hidden_size).to(get_device())),
            ]
            param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
            if self.use_bias:
                # Two bias vectors are kept for CuDNN compatibility; the standard
                # formulation needs only one. They are created only when used.
                layer_params += [
                    Parameter(torch.Tensor(self.gate_size).to(get_device())),
                    Parameter(torch.Tensor(self.gate_size).to(get_device())),
                ]
                param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
            suffix = '_reverse' if direction == 1 else ''
            param_names = [x.format(layer, suffix) for x in param_names]
            for name, param in zip(param_names, layer_params):
                if hasattr(self, "_flat_weights_names") and name in self._flat_weights_names:
                    # Keep self._flat_weights in sync when a name is re-registered.
                    idx = self._flat_weights_names.index(name)
                    self._flat_weights[idx] = param
                self.register_parameter(name, param)
            self._flat_weights_names.extend(param_names)
            self._all_weights.append(param_names)
    self._flat_weights = [getattr(self, wn, None) for wn in self._flat_weights_names]
    self.flatten_parameters()
    self.reset_parameters()
    # Mark as built (like the other build() methods) so a second call cannot
    # register the weights again and duplicate _flat_weights_names.
    self._built = True
def initial_state(self, input):
    """Create zero initial hidden and cell states sized for ``input``'s batch.

    Args:
        input: Sequence batch. The batch size is read from dim 0 when
            ``self.batch_first`` is set, otherwise from dim 1.

    Sets ``self.hidden_state`` and ``self.cell_state`` to float32 zero tensors
    of shape ``(num_layers * num_directions, batch, hidden_size)`` on
    ``get_device()``. ``num_directions`` is 2 when bidirectional, else 1.
    """
    max_batch_size = input.size(0) if self.batch_first else input.size(1)
    num_directions = 2 if self.bidirectional else 1
    state_shape = (self.num_layers * num_directions, max_batch_size, self.hidden_size)

    def _zero_state():
        return torch.zeros(*state_shape, dtype=dtype.float32, requires_grad=False).to(get_device())

    # Two distinct tensors. Sharing one object would let an in-place update of
    # one state silently change the other.
    self.hidden_state = _zero_state()
    self.cell_state = _zero_state()
def build(self, input_shape: TensorShape):
    """Create the positional embedding (and optional class token) on first call.

    Args:
        input_shape: Its ``dims`` unpack to ``(B, N, C)``; the original note
            gives ``(B, 196, 768)`` as an example.

    With ``self.use_cls_token``, a ``(1, 1, C)`` ``cls_token`` is created and
    ``pos_embed`` gets ``N + 1`` positions; otherwise ``pos_embed`` is
    ``(1, N, C)``. ``pos_embed`` is truncated-normal initialised only when
    ``self.mode == 'trainable'``.
    """
    if self._built:
        return
    _, N, C = input_shape.dims
    device = get_device()
    # Move the data BEFORE wrapping it. ``Parameter(t).to(device)`` returns a
    # plain Tensor when the device changes, so the attribute would not be
    # registered as a trainable/saved parameter.
    if self.use_cls_token:
        self.cls_token = Parameter(torch.zeros(1, 1, C).to(device))
        trunc_normal(self.cls_token, std=.02)
        self.pos_embed = Parameter(torch.zeros(1, N + 1, C).to(device))
    else:
        self.pos_embed = Parameter(torch.zeros(1, N, C).to(device))
    # NOTE(review): the collapsed original does not show whether this check
    # applied to both branches; it is applied to both here.
    if self.mode == 'trainable':
        trunc_normal(self.pos_embed, std=.02)
    self._built = True
def build(self, *input_shape: TensorShape):
    """Resolve default widths and register the ``weight`` parameter once.

    Defaults:

    - ``num_filters``: ``minimum(input_filters // 2, 64)``.
    - ``hidden_filters``: ``num_filters // 2``.

    ``weight`` is drawn from ``random_normal`` with shape
    ``(input_filters, num_filters, hidden_filters)`` and placed on
    ``get_device()``.

    NOTE(review): unlike other build() methods this does not set
    ``self._built``; confirm the caller does.
    """
    if self._built:
        return
    if self.num_filters is None:
        self.num_filters = minimum(self.input_filters // 2, 64)
    if self.hidden_filters is None:
        self.hidden_filters = self.num_filters // 2
    weight_shape = (self.input_filters, self.num_filters, self.hidden_filters)
    initial = random_normal(weight_shape).to(get_device())
    self.register_parameter('weight', Parameter(initial))
def _make_params(self):
    """Replace ``self.module.weight`` with spectral-normalisation parameters.

    Registers on ``self.module``:

    - ``weight_u``: random vector of length ``rows``, normalised with
      ``l2_normalize``; not trainable.
    - ``weight_v``: random vector of length ``flattened columns``, normalised
      the same way; not trainable.
    - ``weight_bar``: trainable, holding the original weight data.

    The original ``weight`` parameter is removed.
    """
    w = getattr(self.module, 'weight')
    height = w.data.shape[0]
    width = w.view(height, -1).data.shape[1]
    device = get_device()
    # Move tensors BEFORE wrapping them in Parameter. ``Parameter(t).to(device)``
    # returns a plain Tensor when the device changes, and register_parameter
    # rejects that with a TypeError.
    u = Parameter(w.data.new(height).normal_(0, 1).to(device), requires_grad=False)
    v = Parameter(w.data.new(width).normal_(0, 1).to(device), requires_grad=False)
    u.data = l2_normalize(u.data)
    v.data = l2_normalize(v.data)
    w_bar = Parameter(w.data.to(device))
    del self.module._parameters['weight']
    self.module.register_parameter('weight' + "_u", u)
    self.module.register_parameter('weight' + "_v", v)
    self.module.register_parameter('weight' + "_bar", w_bar)
def forward(self, x, targets=None):
    """Decode a YOLO head output into scaled boxes, objectness and class scores.

    In training mode the grid offsets are recomputed when the map size differs
    from ``self.grid_size``.

    Args:
        x: Head output, viewed as
            ``(num_batch, num_anchors, num_classes + 5, grid, grid)``.
        targets: Not used by this method.

    Returns:
        ``(num_batch, num_anchors * grid * grid, 5 + num_classes)`` tensor of
        ``[cx, cy, w, h, conf, class scores...]``. Box values are multiplied
        by ``self.stride``.
    """
    num_batch = x.size(0)
    grid_size = x.size(-1)
    if self.training and (self.grid_size != x.size(-1) or self.grid_size != x.size(-2)):
        self.compute_grid_offsets(grid_size)
    # Tensor.to() is not in-place, so the previous bare ``self.grid.to(...)``
    # calls did nothing. Use the returned copies, on the same device as the
    # prediction they are combined with.
    grid = self.grid.to(x.device)
    anchors = self.anchors.to(x.device)
    anchor_wh = (anchors / self.stride).view(1, self.num_anchors, 1, 1, 2)

    prediction = x.view(num_batch, self.num_anchors, self.num_classes + 5, grid_size, grid_size).permute(0, 1, 3, 4, 2).contiguous()

    xy = sigmoid(prediction[..., 0:2])  # centre offset inside each cell
    wh = prediction[..., 2:4]  # width/height, passed through exp()
    xy = reshape((xy + grid) * self.stride, (num_batch, -1, 2))
    wh = reshape((exp(wh) * anchor_wh) * self.stride, (num_batch, -1, 2))
    pred_conf = reshape(sigmoid(prediction[..., 4]), (num_batch, -1, 1))
    pred_class = reshape(sigmoid(prediction[..., 5:]), (num_batch, -1, self.num_classes))

    if self.small_item_enhance and self.stride == 8:
        # Geometric mean of objectness and the best class score.
        cls_probs = reduce_max(pred_class, -1, keepdims=True)
        pred_conf = (pred_conf * cls_probs).sqrt()

    return torch.cat([xy, wh, pred_conf, pred_class], -1)
def YoLoV4(pretrained=True, freeze_features=False, input_shape=(3, 608, 608), classes=80, **kwargs):
    """Build a YOLOv4 detector, optionally swapping in MS-COCO pretrained weights.

    The detector's model gets ``input_shape`` set and is moved to ``get_device()``.

    NOTE(review): ``freeze_features`` and ``kwargs`` are not used. When
    ``pretrained`` is true, the downloaded model replaces ``detector.model``
    whatever ``classes`` is.
    """
    detector = YoloDetectionModel(input_shape=input_shape, output=yolo4_body(classes, input_shape[-1]))
    if pretrained:
        weights_path = os.path.join(dirname, 'pretrained_yolov4_mscoco.pth')
        download_model_from_google_drive('1CcbyinE8gQFjMjt05arSg2W0LLUwjsdt', dirname, 'pretrained_yolov4_mscoco.pth')
        detector.model = fix_layer(load(weights_path))
    detector.model.input_shape = input_shape
    detector.model.to(get_device())
    return detector
def __init__(self, convert_ratio=0.5, name='random_homomorphic_typo', **kwargs):
    """Load the character-embedding dictionary and the character frequency table.

    Args:
        convert_ratio: Stored on the instance.
        name: Stored on the instance.
        **kwargs: Ignored.

    ``chardict.pkl`` is downloaded into ``<trident dir>/download`` and
    unpickled. All of its vectors are stacked into ``self.all_embedding`` on
    ``get_device()``.

    The ``char_freq`` session resource is shared. It is filled from the
    tab-separated ``char_freq.txt`` next to this file (character, then a
    float) only when it is not already present.
    """
    super().__init__()
    self.convert_ratio = convert_ratio
    download_file_from_google_drive('1MDk7eH7nORa16SyzNzqv7fYzBofzxGRI', dirname=os.path.join(get_trident_dir(), 'download'), filename='chardict.pkl')
    # NOTE(review): this unpickles a downloaded file; only safe for a trusted source.
    self.chardict = unpickle(os.path.join(get_trident_dir(), 'download', 'chardict.pkl'))
    self.all_embedding = to_tensor(np.stack(self.chardict.value_list, 0)).to(get_device())
    self.name = name
    if not get_session().get_resources('char_freq'):
        char_freq = get_session().regist_resources('char_freq', OrderedDict())
        freq_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'char_freq.txt')
        with open(freq_path, 'r', encoding='utf-8-sig') as f:
            for line in f:
                stripped = line.strip()
                # Skip blank lines (e.g. a trailing newline). They used to
                # raise IndexError on cols[1].
                if not stripped:
                    continue
                cols = stripped.split('\t')
                char_freq[cols[0]] = float(cols[1])
        self.char_freq = char_freq
    else:
        self.char_freq = get_session().get_resources('char_freq')
def get_similar(self, char):
    """Return the most frequent look-alike of ``char``, or ``char`` itself.

    The input is returned unchanged when it is not in ``self.chardict``, or
    when it is a digit, punctuation, an ASCII letter or a bopomofo symbol.

    Otherwise the first 5 entries of ``argsort`` over the cosine results
    against all embeddings are examined. Candidates other than ``char`` whose
    score exceeds 0.8 are kept. Among them, the one with the highest
    ``char_freq`` above -15.5 wins; if none qualifies, ``char`` is returned.
    """
    if char not in self.chardict:
        return char
    if char in string.digits or char in string.punctuation or char in string.ascii_letters or char in bpmf_phonetic:
        return char

    embedding = to_tensor(expand_dims(self.chardict[char], 0)).to(get_device())
    scores_tensor = element_cosine_distance(embedding, self.all_embedding, -1)[0]
    candidates = argsort(scores_tensor, axis=0)[:5]
    scores = to_numpy(scores_tensor)
    keys = self.chardict.key_list

    similar_chars = []
    for candidate in candidates:
        pos = candidate.item()
        if keys[pos] != char and scores[pos] > 0.8:
            similar_chars.append(keys[pos])

    best_char, best_freq = char, -15.5
    for similar_char in similar_chars:
        if similar_char in self.char_freq and self.char_freq[similar_char] > best_freq:
            best_char, best_freq = similar_char, self.char_freq[similar_char]
    return best_char
def __init__(self, anchors, num_classes, grid_size, img_dim=608):
    """Set up a YOLO detection layer with fixed anchors.

    Args:
        anchors: Anchor sizes. Stored as the ``anchors`` buffer on ``get_device()``.
        num_classes: Number of object classes.
        grid_size: Initial feature-map side length.
        img_dim: Input image side length.

    A ``grid`` buffer is registered empty (``None``) before
    ``compute_grid_offsets(grid_size)`` runs.
    """
    super(YoloLayer, self).__init__()
    self.register_buffer('grid', None)
    self.register_buffer('anchors', to_tensor(anchors, requires_grad=False).to(get_device()))
    self.num_anchors, self.num_classes = len(anchors), num_classes
    self.ignore_thres = 0.5
    self.mse_loss, self.bce_loss = nn.MSELoss(), nn.BCELoss()
    self.obj_scale, self.noobj_scale = 1, 100
    self.metrics = {}
    self.img_dim = img_dim
    self.grid_size = grid_size
    self.compute_grid_offsets(grid_size)
def generate_priors(feature_map_list, shrinkage_list, image_size, min_boxes, clamp=True) -> torch.Tensor:
    """Generate normalised ``[cx, cy, w, h]`` prior boxes for every feature map.

    For feature map ``k``, each cell ``(row j, column i)`` of a
    ``feature_map_list[1][k]`` x ``feature_map_list[0][k]`` grid gets one
    prior per size in ``min_boxes[k]``:

    - Centres are ``(i + 0.5) / (image_w / shrink_w)`` and
      ``(j + 0.5) / (image_h / shrink_h)``.
    - Width and height are ``size / image_w`` and ``size / image_h``.

    The prior count is printed. The result is a tensor on ``get_device()``,
    clamped in place to ``[0, 1]`` when ``clamp`` is true.
    """
    image_w, image_h = image_size[0], image_size[1]
    priors = []
    for k in range(len(feature_map_list[0])):
        scale_w = image_w / shrinkage_list[0][k]
        scale_h = image_h / shrinkage_list[1][k]
        priors.extend(
            [(i + 0.5) / scale_w, (j + 0.5) / scale_h, size / image_w, size / image_h]
            for j in range(feature_map_list[1][k])
            for i in range(feature_map_list[0][k])
            for size in min_boxes[k]
        )
    print("priors nums:{}".format(len(priors)))
    priors = to_tensor(priors).to(get_device())
    if clamp:
        priors.clamp_(0.0, 1.0)
    return priors
def forward(self, x):
    """Run token ids through the embedding and every ``transformer_block`` child.

    If ``x`` has size 2 on axis 1, it is split into token ids and segment ids.
    Otherwise the segment ids are all zeros.

    Padding positions (``self.pad_idx``) are masked out of attention.
    Assuming ``x`` is ``(B, N)`` after the split, the mask is ``(B, 1, N, N)``.
    """
    if int_shape(x)[1] != 2:
        segments_tensor = zeros_like(x, dtype=x.dtype).to(get_device())
    else:
        x, segments_tensor = split(x, num_splits=2, axis=1)
        x = x.squeeze(1)
        segments_tensor = segments_tensor.squeeze(1)

    # Attention mask for padded tokens; True where the key token is real.
    seq_len = x.size(1)
    not_pad = (x != self.pad_idx).unsqueeze(1)
    mask = not_pad.repeat(1, seq_len, 1).unsqueeze(1).detach()

    # Embed the indexed sequence into a sequence of vectors.
    hidden = self.embedding(x, segments_tensor)
    for child_name, block in self.named_children():
        if 'transformer_block' in child_name:
            hidden = block.forward(hidden, mask)
    return hidden
def forward(self, x, segments_tensor=None):
    """Sum token, position and segment embeddings, then normalise (and drop out).

    Args:
        x: Token ids, presumably ``(B, N)`` (two values are unpacked from its shape).
        segments_tensor: Optional segment ids with the same shape as ``x``.
            When omitted they are derived per row from the ``self.sep_idx``
            separators:
            - Rows with fewer than two separators get all zeros.
            - Otherwise, tokens up to and including the first separator get 1.
            - Tokens after it, up to and including the second separator, get 2.
            - Any remaining tokens stay 0.

    Returns:
        The normalised embeddings. Dropout is applied only in training mode
        with ``dropout_rate > 0``.
    """
    if segments_tensor is None:
        # The previous version called the result of nonzero(), indexed the
        # per-dimension tuple by batch row, compared a tuple to 2, and wrote
        # into the index tensor instead of the segment row, so this path
        # always failed. It is rebuilt here per row.
        batch_size, _ = int_shape(x)
        segments_tensor_list = []
        for i in range(batch_size):
            row = x[i]
            t = zeros_like(row)
            sep_positions = (row == self.sep_idx).nonzero(as_tuple=True)[0]
            if len(sep_positions) >= 2:
                first, second = int(sep_positions[0]), int(sep_positions[1])
                t[:first + 1] = 1
                t[first + 1:second + 1] = 2
            segments_tensor_list.append(t)
        segments_tensor = stack(segments_tensor_list, axis=0).to(get_device())
    x = self.token(x) + self.position(x) + self.segment(segments_tensor)
    x = self.norm(x)
    if self.dropout_rate > 0 and self.training:
        x = self.dropout(x)
    return x
def EfficientNetB7(include_top=True, pretrained=True, freeze_features=False, input_shape=(3, 600, 600), classes=1000, **kwargs):
    """Create EfficientNet-B7, optionally from the pretrained ``efficientnet-b7.pth``.

    Args:
        include_top: Forwarded to ``EfficientNet`` and ``_make_recovery_model_include_top``.
        pretrained: Download the checkpoint and use it as the model.
        freeze_features: Forwarded only on the pretrained path.
        input_shape: Used when it is a 3-element sequence; otherwise ``(3, 600, 600)``.
        classes: Number of output classes.
        **kwargs: Ignored.

    Returns:
        The ``EfficientNet`` wrapper. Its ``model`` has ``input_shape`` set and
        has been moved to ``get_device()``.
    """
    default_shape = (3, 600, 600)
    input_shape = tuple(input_shape) if input_shape is not None and len(input_shape) == 3 else default_shape
    effb7 = EfficientNet(2.0, 3.1, input_shape, 0.5, model_name='efficientnet-b7', include_top=include_top, num_classes=classes)
    if not pretrained:
        effb7.model = _make_recovery_model_include_top(effb7.model, include_top=include_top, classes=classes, freeze_features=False)
    else:
        weights_file = 'efficientnet-b7.pth'
        download_model_from_google_drive('1M2DfvsNPRCWSo_CeXnUCQOR46rvOrhLl', dirname, weights_file)
        restored = fix_layer(load(os.path.join(dirname, weights_file)))
        effb7.model = _make_recovery_model_include_top(restored, input_shape=input_shape, include_top=include_top, classes=classes, freeze_features=freeze_features)
    effb7.model.input_shape = input_shape
    effb7.model.to(get_device())
    return effb7
def EfficientNetB6(include_top=True, pretrained=True, freeze_features=False, input_shape=(3, 528, 528), classes=1000, **kwargs):
    """Create EfficientNet-B6, optionally from the pretrained ``efficientnet-b6.pth``.

    Args:
        include_top: Forwarded to ``EfficientNet`` and ``_make_recovery_model_include_top``.
        pretrained: Download the checkpoint and use it as the model.
        freeze_features: Forwarded only on the pretrained path.
        input_shape: Used when it is a 3-element sequence; otherwise ``(3, 528, 528)``.
        classes: Number of output classes.
        **kwargs: Ignored.

    Returns:
        The ``EfficientNet`` wrapper. Its ``model`` has ``input_shape`` set and
        has been moved to ``get_device()``.
    """
    default_shape = (3, 528, 528)
    input_shape = tuple(input_shape) if input_shape is not None and len(input_shape) == 3 else default_shape
    effb6 = EfficientNet(1.8, 2.6, input_shape, 0.5, model_name='efficientnet-b6', include_top=include_top, num_classes=classes)
    if not pretrained:
        effb6.model = _make_recovery_model_include_top(effb6.model, include_top=include_top, classes=classes, freeze_features=False)
    else:
        weights_file = 'efficientnet-b6.pth'
        download_model_from_google_drive('1XJrKmcmMObN_nnjP2Z-YH_BQ3img58qF', dirname, weights_file)
        restored = fix_layer(load(os.path.join(dirname, weights_file)))
        effb6.model = _make_recovery_model_include_top(restored, input_shape=input_shape, include_top=include_top, classes=classes, freeze_features=freeze_features)
    effb6.model.input_shape = input_shape
    effb6.model.to(get_device())
    return effb6
def EfficientNetB5(include_top=True, pretrained=True, freeze_features=False, input_shape=(3, 456, 456), classes=1000, **kwargs):
    """Create EfficientNet-B5, optionally from the pretrained ``efficientnet-b5.pth``.

    Args:
        include_top: Forwarded to ``EfficientNet`` and ``_make_recovery_model_include_top``.
        pretrained: Download the checkpoint and use it as the model.
        freeze_features: Forwarded only on the pretrained path.
        input_shape: Used when it is a 3-element sequence; otherwise ``(3, 456, 456)``.
        classes: Number of output classes.
        **kwargs: Ignored.

    Returns:
        The ``EfficientNet`` wrapper. Its ``model`` has ``input_shape`` set and
        has been moved to ``get_device()``.
    """
    default_shape = (3, 456, 456)
    input_shape = tuple(input_shape) if input_shape is not None and len(input_shape) == 3 else default_shape
    effb5 = EfficientNet(1.6, 2.2, input_shape, 0.4, model_name='efficientnet-b5', include_top=include_top, num_classes=classes)
    if not pretrained:
        effb5.model = _make_recovery_model_include_top(effb5.model, include_top=include_top, classes=classes, freeze_features=False)
    else:
        weights_file = 'efficientnet-b5.pth'
        download_model_from_google_drive('17iTD12G9oW3jYAui84MKtdY4gjd9vpgG', dirname, weights_file)
        restored = fix_layer(load(os.path.join(dirname, weights_file)))
        effb5.model = _make_recovery_model_include_top(restored, input_shape=input_shape, include_top=include_top, classes=classes, freeze_features=freeze_features)
    effb5.model.input_shape = input_shape
    effb5.model.to(get_device())
    return effb5
def EfficientNetB4(include_top=True, pretrained=True, freeze_features=False, input_shape=(3, 380, 380), classes=1000, **kwargs):
    """Create EfficientNet-B4, optionally from the pretrained ``efficientnet-b4.pth``.

    Args:
        include_top: Forwarded to ``EfficientNet`` and ``_make_recovery_model_include_top``.
        pretrained: Download the checkpoint and use it as the model. The
            checkpoint path goes through ``sanitize_path`` before loading.
        freeze_features: Forwarded only on the pretrained path.
        input_shape: Used when it is a 3-element sequence; otherwise ``(3, 380, 380)``.
        classes: Number of output classes.
        **kwargs: Ignored.

    Returns:
        The ``EfficientNet`` wrapper. Its ``model`` has ``input_shape`` set and
        has been moved to ``get_device()``.
    """
    default_shape = (3, 380, 380)
    input_shape = tuple(input_shape) if input_shape is not None and len(input_shape) == 3 else default_shape
    effb4 = EfficientNet(1.4, 1.8, input_shape, 0.4, model_name='efficientnet-b4', include_top=include_top, num_classes=classes)
    if not pretrained:
        effb4.model = _make_recovery_model_include_top(effb4.model, include_top=include_top, classes=classes, freeze_features=False)
    else:
        weights_file = 'efficientnet-b4.pth'
        download_model_from_google_drive('1X4ZOBR_ETRHZJeffJHvCmWTTy9_aW8SP', dirname, weights_file)
        restored = fix_layer(load(sanitize_path(os.path.join(dirname, weights_file))))
        effb4.model = _make_recovery_model_include_top(restored, input_shape=input_shape, include_top=include_top, classes=classes, freeze_features=freeze_features)
    effb4.model.input_shape = input_shape
    effb4.model.to(get_device())
    return effb4
def EfficientNetB3(include_top=True, pretrained=True, freeze_features=False, input_shape=(3, 300, 300), classes=1000, **kwargs):
    """Create EfficientNet-B3, optionally from the pretrained ``efficientnet-b3.pth``.

    Args:
        include_top: Forwarded to ``EfficientNet`` and ``_make_recovery_model_include_top``.
        pretrained: Download the checkpoint and use it as the model.
        freeze_features: Forwarded only on the pretrained path.
        input_shape: Used when it is a 3-element sequence; otherwise ``(3, 300, 300)``.
        classes: Number of output classes.
        **kwargs: Ignored.

    Returns:
        The ``EfficientNet`` wrapper. Its ``model`` has ``input_shape`` set and
        has been moved to ``get_device()``.
    """
    default_shape = (3, 300, 300)
    input_shape = tuple(input_shape) if input_shape is not None and len(input_shape) == 3 else default_shape
    effb3 = EfficientNet(1.2, 1.4, input_shape, 0.3, model_name='efficientnet-b3', include_top=include_top, num_classes=classes)
    if not pretrained:
        effb3.model = _make_recovery_model_include_top(effb3.model, include_top=include_top, classes=classes, freeze_features=False)
    else:
        weights_file = 'efficientnet-b3.pth'
        download_model_from_google_drive('11tMxdYdFfaEREwnESO4cwjtcoEB42zB_', dirname, weights_file)
        restored = fix_layer(load(os.path.join(dirname, weights_file)))
        effb3.model = _make_recovery_model_include_top(restored, input_shape=input_shape, include_top=include_top, classes=classes, freeze_features=freeze_features)
    effb3.model.input_shape = input_shape
    effb3.model.to(get_device())
    return effb3
def EfficientNetB2(include_top=True, pretrained=True, freeze_features=False, input_shape=(3, 260, 260), classes=1000, **kwargs):
    """Create EfficientNet-B2, optionally from the pretrained ``efficientnet-b2.pth``.

    Args:
        include_top: Forwarded to ``EfficientNet`` and ``_make_recovery_model_include_top``.
        pretrained: Download the checkpoint and use it as the model.
        freeze_features: Forwarded only on the pretrained path.
        input_shape: Used when it is a 3-element sequence; otherwise ``(3, 260, 260)``.
        classes: Number of output classes.
        **kwargs: Ignored.

    Returns:
        The ``EfficientNet`` wrapper. Its ``model`` has ``input_shape`` set and
        has been moved to ``get_device()``.
    """
    default_shape = (3, 260, 260)
    input_shape = tuple(input_shape) if input_shape is not None and len(input_shape) == 3 else default_shape
    effb2 = EfficientNet(1.1, 1.2, input_shape, 0.3, model_name='efficientnet-b2', include_top=include_top, num_classes=classes)
    if not pretrained:
        effb2.model = _make_recovery_model_include_top(effb2.model, include_top=include_top, classes=classes, freeze_features=False)
    else:
        weights_file = 'efficientnet-b2.pth'
        download_model_from_google_drive('1PjqhB7WJasF_hqOwYtSBNSXSGBY-cRLU', dirname, weights_file)
        restored = fix_layer(load(os.path.join(dirname, weights_file)))
        effb2.model = _make_recovery_model_include_top(restored, input_shape=input_shape, include_top=include_top, classes=classes, freeze_features=freeze_features)
    effb2.model.input_shape = input_shape
    effb2.model.to(get_device())
    return effb2
def EfficientNetB1(include_top=True, pretrained=True, freeze_features=False, input_shape=(3, 240, 240), classes=1000, **kwargs):
    """Create EfficientNet-B1, optionally from the pretrained ``efficientnet-b1.pth``.

    Args:
        include_top: Forwarded to ``EfficientNet`` and ``_make_recovery_model_include_top``.
        pretrained: Download the checkpoint and use it as the model.
        freeze_features: Forwarded only on the pretrained path.
        input_shape: Used when it is a 3-element sequence; otherwise ``(3, 240, 240)``.
        classes: Number of output classes.
        **kwargs: Ignored.

    Returns:
        The ``EfficientNet`` wrapper. Its ``model`` has ``input_shape`` set and
        has been moved to ``get_device()``.
    """
    default_shape = (3, 240, 240)
    input_shape = tuple(input_shape) if input_shape is not None and len(input_shape) == 3 else default_shape
    effb1 = EfficientNet(1.0, 1.1, input_shape, 0.2, model_name='efficientnet-b1', include_top=include_top, num_classes=classes)
    if not pretrained:
        effb1.model = _make_recovery_model_include_top(effb1.model, include_top=include_top, classes=classes, freeze_features=False)
    else:
        weights_file = 'efficientnet-b1.pth'
        download_model_from_google_drive('1F3BtnAjmDz4G9RS9Q0hqU_K7WWXCni1G', dirname, weights_file)
        restored = fix_layer(load(os.path.join(dirname, weights_file)))
        effb1.model = _make_recovery_model_include_top(restored, input_shape=input_shape, include_top=include_top, classes=classes, freeze_features=freeze_features)
    effb1.model.input_shape = input_shape
    effb1.model.to(get_device())
    return effb1