def test_augment_audio_should_match_model_audio_input_shape_with_augmentor(test_raw_file, segments, model, augmentor):
    """Every spectrogram returned by ``augment_audio`` has the model's audio input shape.

    The ``augmentor`` fixture is installed as the generator's *audio*
    augmentor, because ``augment_audio`` is the code path under test.
    """
    with temp_dir(segments[0].dir):
        with temp_copy(test_raw_file, segments[0].dir):
            # Use the fixture for the audio path; as a vision augmentor it would
            # leave augment_audio running with the default augmentor.
            generator = SegmentsGenerator(segments, model, audio_augmentor=augmentor)
            target_shape = model.audio_input_shape
            # Correct height, but only 10 columns long.
            spectrogram = np.ones((target_shape[0], 10, 1))
            augmented = list(generator.augment_audio(spectrogram))
            # all() is trivially true on an empty result, so require output first.
            assert augmented
            assert all(x.shape == target_shape for x in augmented)
def test_augment_audio_with_invalid_spectrogram_height(test_raw_file, segments, model):
    """``augment_audio`` raises ``ValueError`` for a spectrogram of the wrong height.

    Assumes a 10-row spectrogram never matches ``model.audio_input_shape[0]``.
    """
    with temp_dir(segments[0].dir):
        with temp_copy(test_raw_file, segments[0].dir):
            generator = SegmentsGenerator(segments, model)
            # Build the input outside pytest.raises so only augment_audio can
            # satisfy the expected ValueError.
            spectrogram = np.ones((10, 10, 1))
            with pytest.raises(ValueError):
                generator.augment_audio(spectrogram)
def test_segments_should_be_randomized_on_each_epoch_end(test_raw_file, segments, model):
    """The first batch changes after ``on_epoch_end`` reshuffles the segments."""
    with temp_dir(segments[0].dir):
        with temp_copy(test_raw_file, segments[0].dir):
            generator = SegmentsGenerator(segments, model, 16)
            # A batch is ((frames, spectrograms), (labels,)). Compare each array
            # on its own: handing the ragged nested tuple to np.array_equal can
            # fail conversion, and np.array_equal then returns False, which would
            # make this test pass trivially.
            (frames, spectrograms), (labels,) = generator[0]
            before = (frames, spectrograms, labels)
            generator.on_epoch_end()
            (frames, spectrograms), (labels,) = generator[0]
            after = (frames, spectrograms, labels)
            assert not all(np.array_equal(old, new) for old, new in zip(before, after))
def test_augment_audio_with_invalid_spectrogram_length(test_raw_file, segments, model):
    """``augment_audio`` raises ``ValueError`` for an over-long spectrogram.

    The spectrogram has the model's audio height but 1000 columns. This
    assumes 1000 columns exceeds what the generator accepts.
    """
    with temp_dir(segments[0].dir):
        with temp_copy(test_raw_file, segments[0].dir):
            generator = SegmentsGenerator(segments, model)
            target_shape = model.audio_input_shape
            # Build the input outside pytest.raises so only augment_audio can
            # satisfy the expected ValueError.
            spectrogram = np.ones((target_shape[0], 1000, 1))
            with pytest.raises(ValueError):
                generator.augment_audio(spectrogram)
def test_zip_samples_with_augmentor(test_raw_file, segments, model, augmentor):
    """``zip_samples`` yields ``sample_size`` samples when a vision augmentor is given."""
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        generator = SegmentsGenerator(segments, model, 16, augmentor)
        n_samples = 2 * generator.batch_size
        n_augmented = n_samples * len(augmentor)
        frames = np.ones((n_augmented, *model.vision_input_shape))
        spectrograms = np.ones((n_augmented, *model.audio_input_shape))
        labels = np.ones((n_samples, *model.output_shape, 1))
        zipped = generator.zip_samples(frames, spectrograms, labels)
        assert len(zipped) == generator.sample_size
def test_getitem(test_raw_file, segments, model, augmentor):
    """Zipping a batch's inputs and targets gives ``sample_size`` tuples."""
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        generator = SegmentsGenerator(segments, model, 16, augmentor, augmentor)
        batch = generator[0]
        inputs, targets = batch[0], batch[1]
        rows = [*zip(*inputs, *targets)]
        assert len(rows) == generator.sample_size
def test_frames_should_be_normalized(test_raw_file, segments, model):
    """No frame value in a generated batch exceeds 1.0."""
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        generator = SegmentsGenerator(segments, model, 16)
        inputs = generator[0][0]
        # inputs[0] holds the frames; inputs[1] holds the spectrograms.
        assert np.max(inputs[0]) <= 1.
def test_getitem_should_has_same_number_of_sample_for_all_inputs(test_raw_file, segments, model, augmentor):
    """Every input and target array of a batch holds the same number of samples."""
    with temp_dir(segments[0].dir):
        with temp_copy(test_raw_file, segments[0].dir):
            generator = SegmentsGenerator(segments, model, 16, augmentor, augmentor)
            batch = generator[0]
            # A list is never None, and zip() silently truncates to the shortest
            # argument, so compare the array lengths directly.
            lengths = [len(array) for array in (*batch[0], *batch[1])]
            assert len(set(lengths)) == 1, lengths
def test_getitem_output_shape_with_augmentor(test_raw_file, segments, model, augmentor):
    """Each sample of an augmented batch has the model's input/output shapes."""
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        generator = SegmentsGenerator(segments, model, 16, augmentor, augmentor)
        (frames, spectrograms), (labels,) = generator[0]
        assert all(frame.shape == model.vision_input_shape for frame in frames)
        assert all(spec.shape == model.audio_input_shape for spec in spectrograms)
        assert all(label.shape == model.output_shape for label in labels)
def test_augment_vision_should_match_model_vision_input_shape_with_augmentor(test_raw_file, segments, model, augmentor):
    """``augment_vision`` output has the model's vision input shape for any frame size."""
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        generator = SegmentsGenerator(segments, model, vision_augmentor=augmentor)
        target_shape = model.vision_input_shape
        # Small/small, small/large, large/small and large/large 3-channel frames.
        for height, width in ((10, 10), (10, 1000), (1000, 10), (1000, 1000)):
            frame = np.ones((height, width, 3))
            augmented = generator.augment_vision(frame)
            assert all(x.shape == target_shape for x in augmented)
def test_model_with_invalid_type(test_raw_file, segments):
    """Passing a non-model object as ``model`` raises ``TypeError``."""
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        not_a_model = 0
        with pytest.raises(TypeError):
            assert SegmentsGenerator(segments, not_a_model)
def test_audio_augmentor(test_raw_file, segments, model, augmentor):
    """An explicitly supplied audio augmentor is stored on the generator."""
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        generator = SegmentsGenerator(segments, model, audio_augmentor=augmentor)
        assert generator.audio_augmentor is not None
def test_with_empty_segments(model):
    """An empty segment list raises ``ValueError``."""
    no_segments = []
    with pytest.raises(ValueError):
        assert SegmentsGenerator(no_segments, model)
def test_augmentation_factor(test_raw_file, segments, model):
    """With the default augmentors the augmentation factor is 1."""
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        factor = SegmentsGenerator(segments, model, 16).augmentation_factor
        assert factor == 1
def test_audio_augmentor_with_invalid_type(test_raw_file, segments, model):
    """A non-augmentor ``audio_augmentor`` raises ``TypeError``."""
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        bogus_augmentor = 1
        with pytest.raises(TypeError):
            assert SegmentsGenerator(segments, model, audio_augmentor=bogus_augmentor)
def test_sample_size_with_augmentor(test_raw_file, segments, model, augmentor):
    """``sample_size`` scales with both the batch size and the augmentation factor."""
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        generator = SegmentsGenerator(segments, model, 16, augmentor, augmentor)
        expected = 2 * generator.batch_size * generator.augmentation_factor
        assert generator.sample_size == expected
def test_sample_size(test_raw_file, segments, model):
    """With the default augmentors ``sample_size`` is twice the batch size."""
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        generator = SegmentsGenerator(segments, model, 16)
        expected = 2 * generator.batch_size
        assert generator.sample_size == expected
def test_augmentation_factor_with_augmentor(test_raw_file, segments, model, augmentor):
    """Using the same augmentor for vision and audio squares the augmentation factor."""
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        generator = SegmentsGenerator(segments, model, 16, augmentor, augmentor)
        assert generator.augmentation_factor == len(augmentor) ** 2
def test_with_invalid_segments_item_type(model):
    """A list containing a non-segment item raises ``TypeError``."""
    bad_items = [0]
    with pytest.raises(TypeError):
        assert SegmentsGenerator(bad_items, model)
def test_default_augmentor(test_raw_file, segments, model):
    """Without explicit augmentors both vision and audio augmentors are still set."""
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        generator = SegmentsGenerator(segments, model)
        for attribute in ('vision_augmentor', 'audio_augmentor'):
            assert getattr(generator, attribute) is not None, attribute
def test_with_invalid_segments_type():
    """A non-iterable ``segments`` argument raises ``TypeError``."""
    not_segments = 0
    with pytest.raises(TypeError):
        assert SegmentsGenerator(not_segments, None)
def test_len(test_raw_file, segments, model):
    """``len(generator)`` is the number of batches needed to cover all segments."""
    batch_size = 25
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        generator = SegmentsGenerator(segments, model, batch_size)
        assert len(generator) == math.ceil(len(segments) / batch_size)
def test_with_single_segment(test_raw_file, segment, model):
    """A single segment passed directly (not in a list) becomes a one-item collection."""
    with temp_dir(segment.dir), temp_copy(test_raw_file, segment.dir):
        generator = SegmentsGenerator(segment, model)
        assert len(generator.segments) == 1
def test_len_should_ceil_the_batch_count(test_raw_file, segments, model):
    """A trailing partial batch still counts toward ``len(generator)``."""
    batch_size = 16
    segment_dir = segments[0].dir
    with temp_dir(segment_dir), temp_copy(test_raw_file, segment_dir):
        generator = SegmentsGenerator(segments, model, batch_size)
        expected_batches = math.ceil(len(segments) / batch_size)
        assert len(generator) == expected_batches
def test_with_unavailable_segments(segments, model):
    """Without the temp_dir/temp_copy setup, constructing the generator raises ``ValueError``.

    Presumably this is because the segments' files are then unavailable.
    """
    with pytest.raises(ValueError):
        generator = SegmentsGenerator(segments, model)
        assert generator
# Example no. 26
# 0
def _in_ontology_and_available(segments, ontologies):
    """Return the segments that carry an ontology label and are available.

    A segment is kept when at least one of its ``positive_labels`` is
    contained in ``ontologies`` and its ``is_available`` attribute is truthy.
    ``is_available`` is only read for segments that pass the label check.
    """
    return [
        segment
        for segment in segments
        if any(map(ontologies.__contains__, segment.positive_labels))
        and segment.is_available
    ]


def _write_ytids(path, segments):
    """Write the ``ytid`` of every segment to ``path``, one id per line."""
    with open(path, 'w') as outfile:
        # writelines() adds no separators, so terminate each id explicitly.
        outfile.writelines('{}\n'.format(segment.ytid) for segment in segments)


def _load_latest_checkpoint(checkpoints):
    """Load the checkpoint with the highest epoch prefix next to ``checkpoints``.

    Files in ``os.path.dirname(checkpoints)`` are expected to be named
    ``<epoch>-<rest>``. Entries whose text before the first ``-`` is not a
    decimal integer are ignored.

    Returns:
        ``(epoch, model)``, where ``epoch`` is the largest prefix found and
        ``model`` is ``keras_models.load_model`` applied to that file.

    Raises:
        ValueError: if the directory holds no file with an integer prefix.
    """
    # dirname() is '' for a bare filename, and os.listdir('') fails.
    checkpoints_dir = os.path.dirname(checkpoints) or os.curdir
    by_epoch = {}
    for filename in os.listdir(checkpoints_dir):
        prefix = filename.split('-')[0]
        if prefix.isdecimal():
            by_epoch[int(prefix)] = filename
    if not by_epoch:
        raise ValueError('No checkpoint found in {!r}'.format(checkpoints_dir))
    latest_epoch = max(by_epoch)
    model = keras_models.load_model(
        os.path.join(checkpoints_dir, by_epoch[latest_epoch]))
    return latest_epoch, model


def train(data_dir, train_segments, negative_segments, valid_segments,
          ontology, labels, seed, network, resume_training, epochs,
          initial_epoch, checkpoints_period, logdir, checkpoints, modeldir,
          output):
    """Train ``network`` on the segments matching ``labels`` and save it as ``.h5``.

    Args:
        data_dir: root directory. Segments are wrapped against
            ``<data_dir>/raw``; the ontology is built with ``<data_dir>/videos``.
        train_segments, negative_segments, valid_segments: segment sources
            handed to ``SegmentsWrapper`` (presumably file paths; confirm).
        ontology: first argument to ``Ontology``.
        labels: names passed to ``Ontology.retrieve``. Only segments with a
            positive label in the retrieved ontologies are used.
        seed: seeds both ``random`` and ``tf.random``.
        network: name resolved with ``models.retrieve_model``.
        resume_training: if true, load the newest checkpoint next to
            ``checkpoints`` and override ``initial_epoch`` with its epoch.
            Also forwarded to ``NumpyzBoard``.
        epochs, initial_epoch: forwarded to ``fit_generator``.
        checkpoints_period: period for ``NumpyzBoard`` and ``ModelCheckpoint``.
        logdir: ``NumpyzBoard`` directory. It also receives
            ``train_segments.txt`` and ``valid_segments.txt`` (one ytid per line).
        checkpoints: filepath given to ``ModelCheckpoint``.
        modeldir, output: the final model is written to ``<modeldir>/<output>.h5``.
    """
    random.seed(seed)
    tf.random.set_seed(seed)

    raw_dir = os.path.join(data_dir, 'raw')
    train_segments = SegmentsWrapper(train_segments, raw_dir)
    valid_segments = SegmentsWrapper(valid_segments, raw_dir)
    negative_segments = SegmentsWrapper(negative_segments, raw_dir)

    videos_dir = os.path.join(data_dir, 'videos')
    ontology = Ontology(ontology, videos_dir)
    ontologies = ontology.retrieve(*labels)

    train_segments = _in_ontology_and_available(train_segments, ontologies)
    valid_segments = _in_ontology_and_available(valid_segments, ontologies)
    # NOTE(review): negatives go through the same ontology predicate as the
    # positives; confirm this is intended.
    negative_segments = _in_ontology_and_available(negative_segments, ontologies)

    os.makedirs(logdir, exist_ok=True)
    # Create the output directory up front so a long run does not fail at the
    # final save. '' means the current directory, where makedirs('') would fail.
    if modeldir:
        os.makedirs(modeldir, exist_ok=True)

    _write_ytids(os.path.join(logdir, 'train_segments.txt'), train_segments)
    _write_ytids(os.path.join(logdir, 'valid_segments.txt'), valid_segments)

    print(len(train_segments), len(valid_segments))
    model = models.retrieve_model(network)()

    train_generator = SegmentsGenerator(train_segments, negative_segments,
                                        model, 55)
    valid_generator = SegmentsGenerator(valid_segments, negative_segments,
                                        model, 34)

    def decayer(epoch):
        # Step decay from 1e-5: one factor of 0.94 per 16 epochs, with the
        # first step taken at epoch index 15.
        return 1e-5 * math.pow((94. / 100), ((1 + epoch) // 16))

    numpyz_board = NumpyzBoard(logdir,
                               period=checkpoints_period,
                               resume_training=resume_training)
    model_checkpoint = ModelCheckpoint(checkpoints, period=checkpoints_period)
    lr_scheduler = LearningRateScheduler(decayer)

    callbacks = [numpyz_board, model_checkpoint, lr_scheduler]

    if resume_training:
        initial_epoch, model = _load_latest_checkpoint(checkpoints)
    else:
        model: Model = model.compile()

    # NOTE(review): in tf.keras >= 2.1, fit_generator is deprecated in favour
    # of fit. Confirm which Keras flavour is in use before switching.
    model.fit_generator(train_generator,
                        epochs=epochs,
                        callbacks=callbacks,
                        validation_data=valid_generator,
                        workers=34,
                        max_queue_size=21,
                        initial_epoch=initial_epoch)

    model_filepath = os.path.join(modeldir, '{}.h5'.format(output))
    model.save(model_filepath)
    print('Model saved to', model_filepath)