    def __init__(self,
                 image_dir,
                 config_path,
                 weights_path,
                 threshold=0.5,
                 class_names=None):
        super(GuiInferenceViewer, self).__init__('Inference Viewer')
        assert os.path.isdir(image_dir), 'No such directory: {}'.format(
            image_dir)

        self.threshold = threshold
        self.class_names = class_names
        # Get the list of image files
        self.image_files = sorted([
            os.path.join(image_dir, x) for x in os.listdir(image_dir)
            if x.lower().endswith(('.jpg', '.png', '.bmp'))
        ])
        self.num_images = len(self.image_files)

        self.model = get_model_wrapper(load_config(config_path))
        if weights_path:
            self.model.load_weights(weights_path)
        else:
            print(
                'No weights path provided. Will use randomly initialized weights'
            )

        self.create_slider()
        self.create_textbox()

        self.display()
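
A minimal usage sketch for the viewer above; the import path, directory, file names and class list are illustrative placeholders, not taken from the original project:

# Hypothetical paths and class names, for illustration only.
from gui_inference_viewer import GuiInferenceViewer

# Construction builds the model, optionally loads the weights, and opens
# the viewer (create_slider, create_textbox, display).
viewer = GuiInferenceViewer(image_dir='data/val_images',
                            config_path='configs/model.yaml',
                            weights_path='weights/final.h5',
                            threshold=0.5,
                            class_names=['BG', 'car', 'pedestrian'])
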
Example #2
                        help='Directory of the KITTI dataset')
    parser.add_argument('--tag',
                        required=False,
                        default='simple',
                        metavar='<tag>',
                        help='Tag of the KITTI dataset (default=simple)')
    parser.add_argument('-t',
                        '--threshold',
                        required=False,
                        type=float,
                        default=0.5,
                        metavar='<threshold>',
                        help='Threshold value for inference; must be between 0 and 1.')
    args = parser.parse_args()

    model = get_model_wrapper(load_config(args.model_cfg))
    model.load_weights(args.weights)

    dataset = KittiDataset()
    dataset.load_kitti(args.dataset, 'val', args.tag)

    assert dataset.num_classes == model.config.NUM_CLASSES

    num_classes = dataset.num_classes

    confusion_matrix = np.zeros((num_classes, num_classes))
    for i in tqdm(range(dataset.num_images)):
        img, gt_mask = load_image_gt(dataset, i, model.config.IMAGE_SHAPE)
        pr_mask = model.predict(img.astype(np.float32), args.threshold)
        confusion_matrix += compute_confusion_matrix(gt_mask.astype(int),
                                                     pr_mask.astype(int),
                                                     num_classes)
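
The loop above accumulates a per-pixel confusion matrix over the validation split. The excerpt cuts the call off, so the final num_classes argument is an assumption; below is a minimal sketch of such a helper, assuming gt_mask and pr_mask are integer label maps of the same shape (the project's actual compute_confusion_matrix is not shown):

import numpy as np

def compute_confusion_matrix(gt_mask, pr_mask, num_classes):
    """Per-pixel confusion matrix: rows are ground truth, columns are predictions."""
    gt = gt_mask.reshape(-1)
    pr = pr_mask.reshape(-1)
    # Encode each (gt, pr) pair as a single index and count occurrences.
    pair_index = gt * num_classes + pr
    counts = np.bincount(pair_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)
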
Example #3
        help=
        'Workspace is a parent directory containing log directories (default=logs)'
    )
    args = parser.parse_args()

    print('Model weights:', args.weights)
    print('Dataset:', args.dataset)
    print('Tag:', args.tag)
    print('Workspace:', args.workspace)

    # Model Configuration
    model_config = load_config(args.model_cfg)

    # Create model
    print('Building model...')
    model_wrapper = get_model_wrapper(model_config)

    # Training Configurations
    train_cfgs = []
    for path in args.train_cfg:
        train_cfgs.append(load_config(path))

    for stage, train_config in enumerate(train_cfgs):
        # Create trainer
        trainer = Trainer(model_wrapper=model_wrapper,
                          train_config=train_config,
                          workspace=args.workspace,
                          stage=stage)

        # Load the weights file designated by the command line argument
        # only on the first stage, and continue training on the current
        # weights in later stages.
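
The excerpt ends here. A minimal sketch of how the stage loop might continue, assuming the weights are loaded only before the first stage and that Trainer exposes a train() method (neither call appears in the excerpt above):

        if stage == 0 and args.weights:
            model_wrapper.load_weights(args.weights)

        # Assumed Trainer API; later stages keep training the same wrapper.
        trainer.train()
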
Example #4
def main(_run, _config, _log):
    experiment_name = _run.experiment_info['name']
    _log.info(f'Starting run {experiment_name}')
    pprint.pprint(_config)

    # Initialize the random number generators
    random.seed(_config['seed'])
    np.random.seed(_config['seed'])
    tf.set_random_seed(_config['seed'])

    # Load training and validation datasets
    train_dataset, val_dataset = load_train_val_datasets(
        _config['image_size'], _config['batch_size'],
        _config['train_augmentation'], _config['seed'])

    # Initialize the model and model checkpointing
    model_wrapper = get_model_wrapper(_config)
    model_wrapper.model.summary()
    checkpoint = tf.train.Checkpoint(model=model_wrapper.model)
    checkpoint_dir = os.path.join(common.MODEL_DIR, experiment_name)
    checkpoint_manager = tf.contrib.checkpoint.CheckpointManager(
        checkpoint, directory=checkpoint_dir, max_to_keep=1)

    # Initialize the optimizer
    # https://arxiv.org/abs/1608.03983
    batch_size = _config['batch_size']
    examples_per_class = _config['train_examples_per_class']
    batches_per_epoch = (len(common.WORD2LABEL) * examples_per_class) // batch_size
    global_step = tf.train.get_or_create_global_step()
    global_step.assign(0)
    learning_rate = tf.train.cosine_decay_restarts(
        _config['learning_rate'], global_step,
        first_decay_steps=_config['cycle_len'] * batches_per_epoch,
        t_mul=_config['cycle_mult'])
    # https://stackoverflow.com/a/50778921
    optimizer = tf.train.MomentumOptimizer(
        learning_rate, _config['momentum'], use_nesterov=True)

    # With tf.train.cosine_decay_restarts, cycle i lasts
    # first_decay_steps * t_mul**i steps, i.e. cycle_len * cycle_mult**i epochs;
    # next_cycle_epoch tracks the epoch at which the current cycle ends.
    epoch_num = 0
    cycle_count = 0
    next_cycle_epoch = _config['cycle_len']
    best_val_metric = 0

    while True:
        epoch_num += 1
        _log.info(f'Starting epoch {epoch_num}')

        train_loss, train_metric = model_wrapper.train(optimizer, train_dataset)
        _log.info('train_loss: {:.6f} train_MAP@3: {:.6f}'
                  .format(train_loss, train_metric))
        _run.log_scalar('train_loss', train_loss.numpy(), epoch_num)
        _run.log_scalar('train_MAP@3', train_metric, epoch_num)

        if epoch_num == next_cycle_epoch:
            cycle_count += 1
            _log.info('Finished cycle {}/{}'.format(cycle_count, _config['max_cycles']))

            # Evaluation on the validation dataset
            val_loss, val_metric = model_wrapper.validate(val_dataset)
            _log.info('val_loss: {:.6f} val_MAP@3: {:.6f}'
                      .format(val_loss, val_metric))
            _run.log_scalar('val_loss', val_loss.numpy(), epoch_num)
            _run.log_scalar('val_MAP@3', val_metric, epoch_num)

            if val_metric <= best_val_metric:
                # Early stopping
                _log.warning('val_MAP@3 did not improve {:.6f} <= {:.6f}'
                             .format(val_metric, best_val_metric))
                break
            else:
                _log.info('Model checkpoint. val_MAP@3: {:.6f} > {:.6f}'
                          .format(val_metric, best_val_metric))
                checkpoint_manager.save()
                best_val_metric = val_metric
                if cycle_count < _config['max_cycles']:
                    # Update the target epoch for the new cycle
                    next_cycle_epoch += (_config['cycle_len'] *
                                         _config['cycle_mult'] ** cycle_count)
                else:
                    # Reached the maximum number of cycles
                    break

    # Generate predictions on the test dataset using the latest checkpoint.
    latest_checkpoint = checkpoint_manager.latest_checkpoint
    status = checkpoint.restore(latest_checkpoint)
    status.assert_consumed()
    test_key_ids, test_scores = generate_predictions(_config, _log, model_wrapper)

    # Generating the submission file.
    test_labels = tf.nn.top_k(test_scores, k=3)[1].numpy().tolist()
    predictions_file = os.path.join(common.PREDICTIONS_DIR, f'{experiment_name}.csv')
    with open(predictions_file, 'w', newline='') as f:
        writer = csv.DictWriter(f, ['key_id', 'word'])
        writer.writeheader()
        for key_id, labels in zip(test_key_ids, test_labels):
            words = [common.LABEL2WORD[i].replace(' ', '_') for i in labels]
            row = {'key_id': key_id.numpy().decode(), 'word': ' '.join(words)}
            writer.writerow(row)
    _run.add_artifact(predictions_file, 'predictions')
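
For reference, the MAP@3 metric logged above scores each sample as 1, 1/2, or 1/3 when the true label appears at rank 1, 2, or 3 among the top-3 predictions, and 0 otherwise. A small standalone sketch (not the wrapper's own metric implementation, which is not shown here):

import numpy as np

def mapk3(y_true, y_pred_top3):
    """MAP@3 with a single correct label per sample.

    y_true: true label indices, shape (N,).
    y_pred_top3: top-3 predicted label indices, shape (N, 3).
    """
    scores = []
    for true_label, preds in zip(y_true, y_pred_top3):
        score = 0.0
        for rank, pred in enumerate(preds, start=1):
            if pred == true_label:
                score = 1.0 / rank
                break
        scores.append(score)
    return float(np.mean(scores))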