def RetinaNet(input_shape=None):  # pylint: disable=W0613
    """Build a RetinaNet model from the sample COCO object-detection config.

    ``input_shape`` is accepted only for signature compatibility and is not
    used. The model is configured from
    ``examples/tensorflow/object_detection/configs/retinanet_coco.json``,
    located relative to the parent directory of ``TEST_ROOT``.

    Returns:
        The model produced by ``build_model()`` of the model builder selected
        for the merged configuration.
    """
    nncf_root = TEST_ROOT.parent
    config_path = nncf_root.joinpath(
        'examples', 'tensorflow', 'object_detection', 'configs', 'retinanet_coco.json'
    )
    json_config = SampleConfig.from_json(config_path)

    # Start from the predefined config for the model named in the JSON file,
    # then overlay the JSON values on top of it.
    merged_config = get_predefined_config(json_config.model)
    merged_config.update(json_config)

    builder = get_model_builder(merged_config)
    return builder.build_model()
def export(config):
    """Build the compressed model described by ``config`` and export it.

    The model is built (optionally with ``config['weights']``), wrapped with
    NNCF compression according to ``config.nncf_config``, restored from
    ``config.ckpt_path`` when that is set, and finally exported through the
    compression controller to the location given by ``get_saving_parameters``.

    Args:
        config: Sample configuration; reads ``weights``, ``nncf_config`` and
            ``ckpt_path``, plus whatever ``get_saving_parameters`` consumes.
    """
    builder = get_model_builder(config)
    base_model = builder.build_model(weights=config.get('weights', None))
    controller, compressed = create_compressed_model(base_model, config.nncf_config)

    # Restore the compressed model's weights only when a checkpoint was given.
    if config.ckpt_path:
        load_checkpoint(tf.train.Checkpoint(model=compressed), config.ckpt_path)

    target_path, target_format = get_saving_parameters(config)
    controller.export_model(target_path, target_format)
    logger.info("Saved to {}".format(target_path))
def checkpoint_saver(config):
    """Load a checkpoint and re-save it without the optimizer state.

    The checkpoint at ``config.ckpt_path`` is loaded into a freshly built
    compressed model. A new ``tf.train.CheckpointManager`` over
    ``config.checkpoint_save_dir`` then saves it again. The re-saved checkpoint
    tracks only the model, which reduces its memory footprint.

    Args:
        config: Sample configuration; reads ``nncf_config``, ``ckpt_path`` and
            ``checkpoint_save_dir``.
    """
    builder = get_model_builder(config)
    _, compressed = create_compressed_model(builder.build_model(), config.nncf_config)

    # Only the model is tracked, so optimizer slots are not carried over.
    ckpt = tf.train.Checkpoint(model=compressed)
    load_checkpoint(ckpt, config.ckpt_path)

    manager = tf.train.CheckpointManager(ckpt, config.checkpoint_save_dir, max_to_keep=None)
    saved_to = manager.save()
    logger.info('Saved checkpoint: {}'.format(saved_to))
def run(config):
    """Train and/or evaluate, and optionally export, a compressed detection model.

    Pipeline:
      * build the train/test datasets and distribute them with the strategy;
      * inside ``TFOriginalModelManager`` and ``strategy.scope()``, wrap the
        model with NNCF compression and create the scheduler, optimizer,
        metric, loss, post-processing function and checkpoint manager;
      * resume from ``config.ckpt_path`` if set; otherwise run compression
        initialization on the (non-distributed) training dataset;
      * train when ``'train'`` is in ``config.mode``, then always print
        compression statistics and evaluate;
      * export when ``'export'`` is in ``config.mode``.

    When ``config.metrics_dump`` is set, a metric value of 0 is written up
    front and the evaluated ``'AP'`` value is written after evaluation.
    """
    strategy = get_distribution_strategy(config)
    if config.metrics_dump is not None:
        write_metrics(0, config.metrics_dump)

    # --- Datasets ---------------------------------------------------------
    dataset_builders = get_dataset_builders(config, strategy.num_replicas_in_sync)
    built_datasets = [b.build() for b in dataset_builders]
    train_builder, test_builder = dataset_builders
    train_dataset, test_dataset = built_datasets
    train_dist_dataset, test_dist_dataset = (
        strategy.experimental_distribute_dataset(ds) for ds in (train_dataset, test_dataset)
    )

    # --- Training parameters ---------------------------------------------
    epochs = config.epochs
    steps_per_epoch = train_builder.steps_per_epoch
    num_test_batches = test_builder.steps_per_epoch

    # --- Model, compression and training objects --------------------------
    model_builder = get_model_builder(config)
    with TFOriginalModelManager(model_builder.build_model,
                                weights=config.get('weights', None)) as model:
        with strategy.scope():
            compression_ctrl, compressed_model = create_compressed_model(
                model, config.nncf_config)

            lr_scheduler = build_scheduler(config=config, steps_per_epoch=steps_per_epoch)
            optimizer = build_optimizer(config=config, scheduler=lr_scheduler)

            eval_metric = model_builder.eval_metrics()
            loss_fn = model_builder.build_loss_fn(compressed_model, compression_ctrl.loss)
            post_process_fn = model_builder.post_processing

            checkpoint = tf.train.Checkpoint(model=compressed_model, optimizer=optimizer)
            checkpoint_manager = tf.train.CheckpointManager(
                checkpoint, config.checkpoint_save_dir, max_to_keep=None)

            # Fresh run: initialize compression; otherwise pick up where the
            # checkpoint left off (epoch/step come from resume_from_checkpoint).
            if not config.ckpt_path:
                initial_epoch, initial_step = 0, 0
                logger.info('Initialization...')
                compression_ctrl.initialize(dataset=train_dataset)
            else:
                initial_epoch, initial_step = resume_from_checkpoint(
                    checkpoint_manager, compression_ctrl, config.ckpt_path,
                    steps_per_epoch, config)

    train_step = create_train_step_fn(strategy, compressed_model, loss_fn, optimizer)
    test_step = create_test_step_fn(strategy, compressed_model, post_process_fn)

    # --- Train / evaluate / export ----------------------------------------
    if 'train' in config.mode:
        train(train_step, test_step, eval_metric, train_dist_dataset, test_dist_dataset,
              initial_epoch, initial_step, epochs, steps_per_epoch, checkpoint_manager,
              compression_ctrl, config.log_dir, optimizer, num_test_batches,
              config.print_freq)

    print_statistics(compression_ctrl.statistics())
    metric_result = evaluate(test_step, eval_metric, test_dist_dataset,
                             num_test_batches, config.print_freq)
    logger.info('Validation metric = {}'.format(metric_result))

    if config.metrics_dump is not None:
        write_metrics(metric_result['AP'], config.metrics_dump)

    if 'export' in config.mode:
        save_path, save_format = get_saving_parameters(config)
        compression_ctrl.export_model(save_path, save_format)
        logger.info("Saved to {}".format(save_path))