Exemplo n.º 1
0
         num_layers_in_each_block=None,
         data_format='channels_last',
         bottleneck=True,
         compression=0.5,
         weight_decay=1e-4,
         dropout_rate=0.,
         pool_initial=False,
         include_top=True,
         train_mode='custom_loop',
         data_dir=None):

    model = densenet.DenseNet(mode, growth_rate, output_classes,
                              depth_of_model, num_of_blocks,
                              num_layers_in_each_block, data_format,
                              bottleneck, compression, weight_decay,
                              dropout_rate, pool_initial, include_top)
    train_obj = Train(epochs, enable_function, model)
    train_dataset, test_dataset, _ = utils.create_dataset(
        buffer_size, batch_size, data_format, data_dir)

    print('Training...')
    if train_mode == 'custom_loop':
        return train_obj.custom_loop(train_dataset, test_dataset)
    elif train_mode == 'keras_fit':
        return train_obj.keras_fit(train_dataset, test_dataset)


if __name__ == "__main__":
    # Script entry point: register the DenseNet command-line flags first,
    # then let app.run take over and invoke run_main.
    utils.define_densenet_flags()
    app.run(main=run_main)
Exemplo n.º 2
0
    train_dataset, test_dataset, metadata = utils.create_dataset(
        buffer_size, batch_size, data_format, data_dir)

    num_train_steps_per_epoch = metadata.splits[
        'train'].num_examples // batch_size
    num_test_steps_per_epoch = metadata.splits[
        'test'].num_examples // batch_size

    train_iterator = strategy.make_dataset_iterator(train_dataset)
    test_iterator = strategy.make_dataset_iterator(test_dataset)

    print('Training...')
    if train_mode == 'custom_loop':
      return trainer.custom_loop(train_iterator,
                                 test_iterator,
                                 num_train_steps_per_epoch,
                                 num_test_steps_per_epoch,
                                 strategy)
    elif train_mode == 'keras_fit':
      raise ValueError(
          '`tf.distribute.Strategy` does not support subclassed models yet.')
    else:
      raise ValueError(
          'Please enter either "keras_fit" or "custom_loop" as the argument.')


if __name__ == "__main__":
  # Script entry point: register the DenseNet command-line flags first,
  # then let app.run take over and invoke run_main.
  utils.define_densenet_flags()
  app.run(main=run_main)