def efficientnet_edgetpu(width_coefficient=None,
                         depth_coefficient=None,
                         dropout_rate=0.2,
                         survival_prob=0.8):
  """Builds block arguments and global parameters for EfficientNet-EdgeTPU.

  Args:
    width_coefficient: forwarded unchanged to `GlobalParams`.
    depth_coefficient: forwarded unchanged to `GlobalParams`.
    dropout_rate: forwarded unchanged to `GlobalParams`.
    survival_prob: forwarded unchanged to `GlobalParams`.

  Returns:
    A tuple `(decoded_blocks, global_params)`: the `BlockDecoder.decode`
    output for the EdgeTPU block strings, and the `GlobalParams` instance.
  """
  # Per-stage block specs in the string encoding understood by BlockDecoder.
  # The `c1` / `noskip` suffixes are presumably conv-type / skip flags --
  # NOTE(review): confirm against BlockDecoder.
  edgetpu_block_strings = [
      'r1_k3_s11_e4_i24_o24_c1_noskip',
      'r2_k3_s22_e8_i24_o32_c1',
      'r4_k3_s22_e8_i32_o48_c1',
      'r5_k5_s22_e8_i48_o96',
      'r4_k5_s11_e8_i96_o144',
      'r2_k5_s22_e8_i144_o192',
  ]
  param_kwargs = dict(
      batch_norm_momentum=0.99,
      batch_norm_epsilon=1e-3,
      dropout_rate=dropout_rate,
      survival_prob=survival_prob,
      data_format='channels_last',
      num_classes=1001,
      width_coefficient=width_coefficient,
      depth_coefficient=depth_coefficient,
      depth_divisor=8,
      min_depth=None,
      relu_fn=tf.nn.relu,
      # TPU-specific batch norm is required here; tf.layers.BatchNormalization
      # is the non-TPU alternative.
      batch_norm=utils.TpuBatchNormalization,
      use_se=False,
  )
  params = efficientnet_model.GlobalParams(**param_kwargs)
  block_decoder = efficientnet_builder.BlockDecoder()
  return block_decoder.decode(edgetpu_block_strings), params
def efficientnet_edgetpu(width_coefficient=None,
                         depth_coefficient=None,
                         dropout_rate=0.2,
                         drop_connect_rate=0.2):
    """Builds block arguments and global parameters for EfficientNet-EdgeTPU.

    This variant takes ``drop_connect_rate`` and leaves every other
    ``GlobalParams`` field (e.g. batch norm, SE usage) at its default.

    Args:
        width_coefficient: forwarded unchanged to ``GlobalParams``.
        depth_coefficient: forwarded unchanged to ``GlobalParams``.
        dropout_rate: forwarded unchanged to ``GlobalParams``.
        drop_connect_rate: forwarded unchanged to ``GlobalParams``.

    Returns:
        ``(decoded_blocks, global_params)``: the ``BlockDecoder.decode`` output
        for the EdgeTPU block strings, and the ``GlobalParams`` instance.
    """
    # Per-stage block specs in the string encoding understood by BlockDecoder.
    edgetpu_block_strings = [
        'r1_k3_s11_e4_i24_o24_c1_noskip',
        'r2_k3_s22_e8_i24_o32_c1',
        'r4_k3_s22_e8_i32_o48_c1',
        'r5_k5_s22_e8_i48_o96',
        'r4_k5_s11_e8_i96_o144',
        'r2_k5_s22_e8_i144_o192',
    ]
    param_kwargs = dict(
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        dropout_rate=dropout_rate,
        drop_connect_rate=drop_connect_rate,
        data_format='channels_last',
        num_classes=1001,
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        depth_divisor=8,
        min_depth=None,
        relu_fn=tf.nn.relu,
    )
    params = efficientnet_model.GlobalParams(**param_kwargs)
    block_decoder = efficientnet_builder.BlockDecoder()
    return block_decoder.decode(edgetpu_block_strings), params
def efficientnet_condconv(width_coefficient=None,
                          depth_coefficient=None,
                          dropout_rate=0.2,
                          drop_connect_rate=0.2,
                          condconv_num_experts=None):
    """Builds block arguments and global parameters for EfficientNet-CondConv.

    Args:
        width_coefficient: forwarded unchanged to ``GlobalParams``.
        depth_coefficient: forwarded unchanged to ``GlobalParams``.
        dropout_rate: forwarded unchanged to ``GlobalParams``.
        drop_connect_rate: forwarded unchanged to ``GlobalParams``.
        condconv_num_experts: forwarded unchanged to ``GlobalParams``.

    Returns:
        ``(decoded_blocks, global_params)``: the ``BlockDecoder.decode`` output
        for the CondConv block strings, and the ``GlobalParams`` instance.
    """
    # Per-stage block specs for BlockDecoder. The last three stages carry a
    # ``cc`` suffix, which presumably marks CondConv layers.
    # NOTE(review): confirm the suffix meaning against BlockDecoder.
    condconv_block_strings = [
        'r1_k3_s11_e1_i32_o16_se0.25',
        'r2_k3_s22_e6_i16_o24_se0.25',
        'r2_k5_s22_e6_i24_o40_se0.25',
        'r3_k3_s22_e6_i40_o80_se0.25',
        'r3_k5_s11_e6_i80_o112_se0.25_cc',
        'r4_k5_s22_e6_i112_o192_se0.25_cc',
        'r1_k3_s11_e6_i192_o320_se0.25_cc',
    ]
    param_kwargs = dict(
        batch_norm_momentum=0.99,
        batch_norm_epsilon=1e-3,
        dropout_rate=dropout_rate,
        drop_connect_rate=drop_connect_rate,
        data_format='channels_last',
        num_classes=1000,
        width_coefficient=width_coefficient,
        depth_coefficient=depth_coefficient,
        depth_divisor=8,
        min_depth=None,
        relu_fn=tf.nn.swish,
        # TPU-specific batch norm is required here; tf.layers.BatchNormalization
        # is the non-TPU alternative.
        batch_norm=utils.TpuBatchNormalization,
        use_se=True,
        condconv_num_experts=condconv_num_experts,
    )
    params = efficientnet_model.GlobalParams(**param_kwargs)
    block_decoder = efficientnet_builder.BlockDecoder()
    return block_decoder.decode(condconv_block_strings), params
def get_model_params(model_name, override_params):
  """Returns the decoded block args and global params for `model_name`.

  Only names starting with 'efficientnet-lite' are handled here.

  Args:
    model_name: model identifier string.
    override_params: optional mapping of field names to replacement values
      applied to the global params; a falsy value (None / empty) is ignored.

  Returns:
    A tuple `(blocks_args, global_params)`, where `blocks_args` is the
    `BlockDecoder.decode` output for `global_params.blocks_args`.

  Raises:
    NotImplementedError: if `model_name` is not an efficientnet-lite name.
    ValueError: if `override_params` contains fields not in global_params.
  """
  if not model_name.startswith('efficientnet-lite'):
    raise NotImplementedError('model name is not pre-defined: %s' % model_name)

  # efficientnet_lite_params yields four values; the third is not needed here.
  width, depth, _, dropout = efficientnet_lite_params(model_name)
  params = efficientnet_lite(width, depth, dropout)

  if override_params:
    # `_replace` raises ValueError for field names global_params lacks.
    params = params._replace(**override_params)

  block_decoder = efficientnet_builder.BlockDecoder()
  decoded_blocks = block_decoder.decode(params.blocks_args)

  logging.info('global_params= %s', params)
  return decoded_blocks, params