def __init__(self, image_dir, config_path, weights_path, threshold=0.5,
             class_names=None, extensions=('.jpg', '.png', '.bmp')):
    """Create the inference viewer for the images found in ``image_dir``.

    Args:
        image_dir: Directory scanned (non-recursively) for image files.
        config_path: Model configuration path, passed to ``load_config``
            and then to ``get_model_wrapper``.
        weights_path: Weights file loaded into the model. When falsy, the
            model keeps its randomly initialized weights and a notice is
            printed.
        threshold: Stored as ``self.threshold`` (presumably the prediction
            threshold used when displaying results -- confirm in
            ``display``).
        class_names: Optional class names, stored as ``self.class_names``.
        extensions: Filename suffixes (matched case-insensitively) that are
            treated as images. Defaults to ``.jpg``, ``.png`` and ``.bmp``.

    Raises:
        NotADirectoryError: If ``image_dir`` is not an existing directory.
    """
    super(GuiInferenceViewer, self).__init__('Inference Viewer')
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not os.path.isdir(image_dir):
        raise NotADirectoryError('No such directory: {}'.format(image_dir))

    self.threshold = threshold
    self.class_names = class_names

    # Collect the image files, sorted so navigation order is deterministic.
    suffixes = tuple(ext.lower() for ext in extensions)
    self.image_files = sorted(
        os.path.join(image_dir, name)
        for name in os.listdir(image_dir)
        if name.lower().endswith(suffixes)
    )
    self.num_images = len(self.image_files)

    self.model = get_model_wrapper(load_config(config_path))
    if weights_path:
        self.model.load_weights(weights_path)
    else:
        print(
            'No weights path provided. Will use randomly initialized weights'
        )

    self.create_slider()
    self.create_textbox()
    self.display()
help='Directory of the MS-COCO dataset') parser.add_argument('--tag', required=False, default='simple', metavar='<tag>', help='Tag of the KITTI dataset (default=simple)') parser.add_argument('-t', '--threshold', required=False, type=float, default=0.5, metavar='Threshold value for inference', help='Must be between 0 and 1.') args = parser.parse_args() model = get_model_wrapper(load_config(args.model_cfg)) model.load_weights(args.weights) dataset = KittiDataset() dataset.load_kitti(args.dataset, 'val', args.tag) assert dataset.num_classes == model.config.NUM_CLASSES num_classes = dataset.num_classes confusion_matrix = np.zeros((num_classes, num_classes)) for i in tqdm(range(dataset.num_images)): img, gt_mask = load_image_gt(dataset, i, model.config.IMAGE_SHAPE) pr_mask = model.predict(img.astype(np.float32), args.threshold) confusion_matrix += compute_confusion_matrix(gt_mask.astype(np.int), pr_mask.astype(np.int),
help= 'Workspace is a parent directory containing log directories (default=logs)' ) args = parser.parse_args() print('Model weights: ', args.weights) print('Dataset: ', args.dataset) print('Tag: ', args.tag) print('Workspace: ', args.workspace) # Model Configuration model_config = load_config(args.model_cfg) # Create model print('Building model...') model_wrapper = get_model_wrapper(model_config) # Training Configurations train_cfgs = [] for path in args.train_cfg: train_cfgs.append(load_config(path)) for stage, train_config in enumerate(train_cfgs): # Create trainer trainer = Trainer(model_wrapper=model_wrapper, train_config=train_config, workspace=args.workspace, stage=stage) # Load the weights file designated by the command line argument # only on the first stage and continue training on the current
def main(_run, _config, _log):
    """Train the model with cosine warm restarts, then write test predictions.

    Training runs epoch by epoch. At the end of every learning-rate cycle the
    model is validated. An improved validation MAP@3 is checkpointed; a
    non-improving one stops training early. Training also stops after
    ``_config['max_cycles']`` cycles. The latest checkpoint is then restored,
    top-3 predictions are generated for the test set and written to
    ``common.PREDICTIONS_DIR/<experiment name>.csv``, and that file is added
    to the run as an artifact.

    Args:
        _run: Run object providing ``experiment_info``, ``log_scalar`` and
            ``add_artifact`` (presumably a Sacred run -- confirm).
        _config: Configuration mapping. Keys read here: ``seed``,
            ``image_size``, ``batch_size``, ``train_augmentation``,
            ``train_examples_per_class``, ``learning_rate``, ``cycle_len``
            (epochs in the first cycle), ``cycle_mult``, ``momentum`` and
            ``max_cycles``. It is also passed whole to ``get_model_wrapper``
            and ``generate_predictions``.
        _log: Logger used for progress messages.
    """
    experiment_name = _run.experiment_info['name']
    _log.info(f'Starting run {experiment_name}')
    pprint.pprint(_config)

    # Initialize the random number generators
    random.seed(_config['seed'])
    np.random.seed(_config['seed'])
    tf.set_random_seed(_config['seed'])

    # Load training and validation datasets
    train_dataset, val_dataset = load_train_val_datasets(
        _config['image_size'], _config['batch_size'],
        _config['train_augmentation'], _config['seed'])

    # Initialize the model and model checkpointing
    model_wrapper = get_model_wrapper(_config)
    model_wrapper.model.summary()
    checkpoint = tf.train.Checkpoint(model=model_wrapper.model)
    checkpoint_dir = os.path.join(common.MODEL_DIR, experiment_name)
    checkpoint_manager = tf.contrib.checkpoint.CheckpointManager(
        checkpoint, directory=checkpoint_dir, max_to_keep=1)

    # Initialize the optimizer: SGD with warm restarts (SGDR)
    # https://arxiv.org/abs/1608.03983
    batch_size = _config['batch_size']
    examples_per_class = _config['train_examples_per_class']
    batches_per_epoch = (
        len(common.WORD2LABEL) * examples_per_class) // batch_size
    cycle_len = _config['cycle_len']
    cycle_mult = _config['cycle_mult']
    global_step = tf.train.get_or_create_global_step()
    global_step.assign(0)
    learning_rate = tf.train.cosine_decay_restarts(
        _config['learning_rate'], global_step,
        first_decay_steps=cycle_len * batches_per_epoch,
        t_mul=cycle_mult)
    # https://stackoverflow.com/a/50778921
    optimizer = tf.train.MomentumOptimizer(
        learning_rate, _config['momentum'], use_nesterov=True)

    epoch_num = 0
    cycle_count = 0
    # Epoch at which the current learning-rate cycle ends. cosine_decay_restarts
    # makes cycle k (0-based) last first_decay_steps * t_mul**k steps, i.e.
    # cycle_len * cycle_mult**k epochs, so the boundaries are kept in the same
    # units here.
    next_cycle_epoch = cycle_len
    best_val_metric = 0
    while True:
        epoch_num += 1
        _log.info(f'Starting epoch {epoch_num}')
        train_loss, train_metric = model_wrapper.train(optimizer, train_dataset)
        _log.info('train_loss: {:.6f} train_MAP@3: {:.6f}'
                  .format(train_loss, train_metric))
        _run.log_scalar('train_loss', train_loss.numpy(), epoch_num)
        _run.log_scalar('train_MAP@3', train_metric, epoch_num)

        # ``>=`` rather than ``==``: with a fractional cycle_mult (or
        # cycle_len) the boundary need not be an integer epoch, and an
        # equality test would never fire, so training would never stop.
        if epoch_num >= next_cycle_epoch:
            cycle_count += 1
            _log.info('Finished cycle {}/{}'.format(cycle_count,
                                                    _config['max_cycles']))

            # Evaluation on the validation dataset
            val_loss, val_metric = model_wrapper.validate(val_dataset)
            _log.info('val_loss: {:.6f} val_MAP@3: {:.6f}'
                      .format(val_loss, val_metric))
            _run.log_scalar('val_loss', val_loss.numpy(), epoch_num)
            _run.log_scalar('val_MAP@3', val_metric, epoch_num)

            if val_metric <= best_val_metric:
                # Early stopping
                _log.warning('val_MAP@3 did not improve {:.6f} <= {:.6f}'
                             .format(val_metric, best_val_metric))
                break
            else:
                _log.info('Model checkpoint. val_MAP@3: {:.6f} > {:.6f}'
                          .format(val_metric, best_val_metric))
                checkpoint_manager.save()
                best_val_metric = val_metric

            if cycle_count < _config['max_cycles']:
                # The next cycle is cycle_mult**cycle_count times as long as
                # the first one (cycle_len epochs).
                next_cycle_epoch += cycle_len * cycle_mult ** cycle_count
            else:
                # Reached the maximum number of cycles
                break

    # Generate predictions on the test dataset using the latest checkpoint.
    latest_checkpoint = checkpoint_manager.latest_checkpoint
    if latest_checkpoint is None:
        # Nothing was ever saved (the first validation did not beat 0);
        # restoring None would make assert_consumed() raise.
        _log.warning('No checkpoint was saved; '
                     'predicting with the current model weights')
    else:
        status = checkpoint.restore(latest_checkpoint)
        status.assert_consumed()
    test_key_ids, test_scores = generate_predictions(_config, _log,
                                                     model_wrapper)

    # Generating the submission file: the top-3 label indices for every test
    # example, written as space-separated words with spaces replaced by '_'.
    test_labels = tf.nn.top_k(test_scores, k=3)[1].numpy().tolist()
    predictions_file = os.path.join(common.PREDICTIONS_DIR,
                                    f'{experiment_name}.csv')
    with open(predictions_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, ['key_id', 'word'])
        writer.writeheader()
        for key_id, labels in zip(test_key_ids, test_labels):
            words = [common.LABEL2WORD[i].replace(' ', '_') for i in labels]
            row = {'key_id': key_id.numpy().decode(), 'word': ' '.join(words)}
            writer.writerow(row)
    _run.add_artifact(predictions_file, 'predictions')