def get_yolo2_model(model_type, num_anchors, num_classes, input_tensor=None,
                    input_shape=None, model_pruning=False, pruning_end_step=10000):
    """Build a YOLOv2 model body from the ``yolo2_model_map`` registry.

    Args:
        model_type: Key into ``yolo2_model_map``.
        num_anchors: Anchor count, passed unchanged to the model function
            (not validated here).
        num_classes: Number of object classes, passed to the model function.
        input_tensor: Optional existing input tensor. It is ignored (replaced)
            whenever ``input_shape`` is truthy.
        input_shape: Optional shape; when truthy a fresh ``Input`` named
            'image_input' is created from it.
        model_pruning: When True, the body is wrapped via ``get_pruning_model``
            with ``begin_step=0``.
        pruning_end_step: ``end_step`` forwarded to ``get_pruning_model``.

    Returns:
        Tuple ``(model_body, backbone_len)``, where ``backbone_len`` is the
        second element of the registry entry for ``model_type``.

    Raises:
        ValueError: If ``model_type`` is not a key of ``yolo2_model_map``.
    """
    # Prepare input tensor: an explicit input_shape takes precedence over any
    # caller-supplied tensor; fall back to a variable-size 3-channel input.
    if input_shape:
        input_tensor = Input(shape=input_shape, name='image_input')
    if input_tensor is None:
        input_tensor = Input(shape=(None, None, 3), name='image_input')

    # The failure here is an unknown model type, not an anchor mismatch;
    # use the same wording as get_yolo3_model for this condition.
    if model_type not in yolo2_model_map:
        raise ValueError('This model type is not supported now')

    # Registry entries are indexed as (model_function, backbone_len, weights_path).
    # Per the original author's note, YOLOv2 uses 5 anchors; num_anchors is
    # forwarded as-is without checking that.
    entry = yolo2_model_map[model_type]
    model_function, backbone_len, weights_path = entry[0], entry[1], entry[2]

    # Only pass weights_path when the registry provides a truthy value.
    if weights_path:
        model_body = model_function(input_tensor, num_anchors, num_classes,
                                    weights_path=weights_path)
    else:
        model_body = model_function(input_tensor, num_anchors, num_classes)

    if model_pruning:
        model_body = get_pruning_model(model_body, begin_step=0,
                                       end_step=pruning_end_step)

    return model_body, backbone_len
def get_yolo3_model(model_type, num_feature_layers, num_anchors, num_classes,
                    input_tensor=None, input_shape=None, model_pruning=False,
                    pruning_end_step=10000):
    """Build a YOLOv3 (or Tiny YOLOv3) model body from the model registries.

    Args:
        model_type: Key into the registry selected by ``num_feature_layers``.
        num_feature_layers: 2 selects ``yolo3_tiny_model_map``; 3 selects
            ``yolo3_model_map``. Any other value is rejected.
        num_anchors: Total anchor count. It is floor-divided by the number of
            feature layers before being passed to the model function.
        num_classes: Number of object classes, passed to the model function.
        input_tensor: Optional existing input tensor. It is ignored (replaced)
            whenever ``input_shape`` is truthy.
        input_shape: Optional shape; when truthy a fresh ``Input`` named
            'image_input' is created from it.
        model_pruning: When True, the body is wrapped via ``get_pruning_model``
            with ``begin_step=0``.
        pruning_end_step: ``end_step`` forwarded to ``get_pruning_model``.

    Returns:
        Tuple ``(model_body, backbone_len)``, where ``backbone_len`` is the
        second element of the selected registry entry.

    Raises:
        ValueError: ``'model type mismatch anchors'`` if ``num_feature_layers``
            is neither 2 nor 3; ``'This model type is not supported now'`` if
            ``model_type`` is missing from the selected registry.
    """
    # An explicit input_shape wins over a caller-supplied tensor; otherwise
    # default to a variable-size 3-channel input.
    if input_shape:
        input_tensor = Input(shape=input_shape, name='image_input')
    if input_tensor is None:
        input_tensor = Input(shape=(None, None, 3), name='image_input')

    # Choose the registry and the per-layer anchor divisor.
    # Per the original author's notes: Tiny YOLOv3 has 6 anchors over
    # 2 feature layers, and YOLOv3 has 9 anchors over 3 feature layers.
    if num_feature_layers == 2:
        registry, divisor = yolo3_tiny_model_map, 2
    elif num_feature_layers == 3:
        registry, divisor = yolo3_model_map, 3
    else:
        raise ValueError('model type mismatch anchors')

    if model_type not in registry:
        raise ValueError('This model type is not supported now')

    # Registry entries are indexed as (model_function, backbone_len, weights_path).
    entry = registry[model_type]
    model_function, backbone_len, weights_path = entry[0], entry[1], entry[2]

    # weights_path is forwarded as a keyword only when it is truthy.
    extra_kwargs = {'weights_path': weights_path} if weights_path else {}
    model_body = model_function(input_tensor, num_anchors // divisor,
                                num_classes, **extra_kwargs)

    if model_pruning:
        model_body = get_pruning_model(model_body, begin_step=0,
                                       end_step=pruning_end_step)

    return model_body, backbone_len