Example #1
def compress_segments(data_dir, segments, ontology, labels, output_file):
    raw_dir = os.path.join(data_dir, 'raw')
    segments = SegmentsWrapper(segments, raw_dir)
    ontology = Ontology(ontology, os.path.join(data_dir, 'videos'))

    def transform_segment(s):
        # Render a segment as AudioSet-style CSV fields:
        # ytid, start_seconds, end_seconds, "label1,label2,...".
        return (s.ytid, str(float(s.start_seconds)), str(float(s.end_seconds)),
                '"{}"'.format(','.join(s.positive_labels)))

    def segment_in_ontology(o):
        # Returns a predicate: True if any of the segment's positive labels
        # belongs to ontology o.
        def predicate(s):
            return any(map(o.__contains__, s.positive_labels))

        return predicate

    ontologies = ontology.retrieve(*labels)

    # Keep the raw files that correspond to known, downloaded segments with at
    # least one label under the requested ontology branches, then render them
    # as CSV lines.
    available_segments = os.listdir(raw_dir)
    available_segments = filter(segments.__contains__, available_segments)
    available_segments = map(segments.__getitem__, available_segments)
    available_segments = filter(attrgetter('is_available'), available_segments)
    available_segments = filter(segment_in_ontology(ontologies),
                                available_segments)
    available_segments = map(transform_segment, available_segments)
    available_segments = map(', '.join, available_segments)
    available_segments = '\n'.join(available_segments)

    with open(output_file, 'w') as outfile:
        outfile.write(available_segments)

    print('Segments file saved to', output_file)
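
A minimal invocation sketch for the function above. Note that these excerpts omit their module-level imports (os, operator.attrgetter, and project classes such as SegmentsWrapper and Ontology); the paths below are placeholders, and the 'animal' label mirrors the test in Example #2:

compress_segments(data_dir='data',
                  segments='data/balanced_train_segments.csv',
                  ontology='data/ontology.json',
                  labels=('animal', ),
                  output_file='data/compressed_segments.csv')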
Example #2
def test_compress_segments(data_dir, segments, ontology):
    with temp_dir(data_dir):
        outfile = os.path.join(data_dir, 'compressed_segments.csv')
        commands.dataset.download(('animal', ),
                                  data_dir,
                                  segments.filename,
                                  ontology.filename,
                                  limit=1,
                                  seed=1)
        commands.dataset.compress_segments(data_dir, segments.filename,
                                           ontology.filename, ('animal', ),
                                           outfile)
        segments = SegmentsWrapper(outfile, os.path.join(data_dir, 'raw'))
        segments_length = len(segments)
    assert segments_length == 1
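
This test depends on pytest fixtures named data_dir, segments, and ontology (fixtures of this shape appear in Example #4). A sketch of invoking just this test programmatically, assuming it lives somewhere under tests/:

import pytest

# -k selects the test by name; pytest resolves the fixtures automatically.
pytest.main(['-k', 'test_compress_segments'])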
Example #3
def cleanup(data_dir, segments, audio, frames, spectrograms):
    segments = SegmentsWrapper(segments, os.path.join(data_dir, 'raw'))
    segments = list(filter(attrgetter('is_available'), segments))

    for s in segments:
        # Check the flag first to skip the filesystem call when disabled.
        if audio and os.path.exists(s.wav):
            print(f'{s.ytid}: Deleting audio...', end='\r')
            os.remove(s.wav)
            print(f'{s.ytid}: Deleting audio OK')

        if frames and os.path.exists(s.frames_dir):
            print(f'{s.ytid}: Deleting frames...', end='\r')
            shutil.rmtree(s.frames_dir)
            print(f'{s.ytid}: Deleting frames OK')

        if spectrograms and os.path.exists(s.spectrograms_dir):
            print(f'{s.ytid}: Deleting spectrograms...', end='\r')
            shutil.rmtree(s.spectrograms_dir)
            print(f'{s.ytid}: Deleting spectrograms OK')
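
A hedged usage sketch: remove extracted frames and spectrograms while keeping the raw audio. The paths are assumptions following the earlier examples:

cleanup(data_dir='data',
        segments='data/compressed_segments.csv',
        audio=False,          # keep .wav files
        frames=True,          # delete extracted frames
        spectrograms=True)    # delete computed spectrograms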
Example #4
@pytest.fixture
def segments(data_dir):
    return SegmentsWrapper('tests/data/segments/test.csv',
                           os.path.join(data_dir, 'raw'))


def test_create_with_invalid_filename_type():
    with pytest.raises(TypeError):
        SegmentsWrapper(0, None)


def test_create_with_non_existing_segments_file(non_existing_segments_file):
    with pytest.raises(FileNotFoundError):
        SegmentsWrapper(non_existing_segments_file, None)


@pytest.fixture
def segments_with_lot_of_comments():
    return SegmentsWrapper('tests/data/segments/lot_of_comments.csv',
                           'tests/.temp/segments')


@pytest.fixture
def segments():
    return SegmentsWrapper('tests/data/segments/test.csv',
                           'tests/.temp/segments')
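
The non_existing_segments_file fixture referenced above is not shown in this excerpt; a minimal sketch of one possible definition (the path is hypothetical):

@pytest.fixture
def non_existing_segments_file():
    # Any path that is guaranteed not to exist will do.
    return 'tests/data/segments/does_not_exist.csv'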
Example #9
def train(data_dir, train_segments, negative_segments, valid_segments,
          ontology, labels, seed, network, resume_training, epochs,
          initial_epoch, checkpoints_period, logdir, checkpoints, modeldir,
          output):
    random.seed(seed)
    tf.random.set_seed(seed)

    raw_dir = os.path.join(data_dir, 'raw')
    train_segments = SegmentsWrapper(train_segments, raw_dir)
    valid_segments = SegmentsWrapper(valid_segments, raw_dir)
    negative_segments = SegmentsWrapper(negative_segments, raw_dir)

    def segment_in_ontology(o):
        # Returns a predicate: True if any of the segment's positive labels
        # belongs to ontology o.
        def predicate(s):
            return any(map(o.__contains__, s.positive_labels))

        return predicate

    videos_dir = os.path.join(data_dir, 'videos')
    ontology = Ontology(ontology, videos_dir)
    ontologies = ontology.retrieve(*labels)

    train_segments = filter(segment_in_ontology(ontologies), train_segments)
    train_segments = list(filter(attrgetter('is_available'), train_segments))

    valid_segments = filter(segment_in_ontology(ontologies), valid_segments)
    valid_segments = list(filter(attrgetter('is_available'), valid_segments))

    negative_segments = filter(segment_in_ontology(ontologies),
                               negative_segments)
    negative_segments = list(
        filter(attrgetter('is_available'), negative_segments))

    os.makedirs(logdir, exist_ok=True)

    with open(os.path.join(logdir, 'train_segments.txt'), 'w') as outfile:
        # writelines() adds no separators, so append the newlines explicitly.
        outfile.writelines(s.ytid + '\n' for s in train_segments)

    with open(os.path.join(logdir, 'valid_segments.txt'), 'w') as outfile:
        outfile.writelines(s.ytid + '\n' for s in valid_segments)

    print('Train segments: {} / Valid segments: {}'.format(
        len(train_segments), len(valid_segments)))
    model = models.retrieve_model(network)()

    train_generator = SegmentsGenerator(train_segments, negative_segments,
                                        model, 55)
    valid_generator = SegmentsGenerator(valid_segments, negative_segments,
                                        model, 34)

    def decayer(epoch):
        # Step decay: scale the 1e-5 base learning rate by 0.94 every 16 epochs.
        return 1e-5 * math.pow(0.94, (1 + epoch) // 16)

    numpyz_board = NumpyzBoard(logdir,
                               period=checkpoints_period,
                               resume_training=resume_training)
    model_checkpoint = ModelCheckpoint(checkpoints, period=checkpoints_period)
    lr_scheduler = LearningRateScheduler(decayer)

    callbacks = [numpyz_board, model_checkpoint, lr_scheduler]

    if resume_training:
        checkpoints_dir = os.path.dirname(checkpoints)
        checkpoint_models = os.listdir(checkpoints_dir)
        # Checkpoint filenames are expected to start with the epoch number,
        # e.g. '12-0.34.h5', so the epoch can be parsed back out.
        checkpoint_models = {
            int(x.split('-')[0]): x
            for x in checkpoint_models
        }
        initial_epoch = max(checkpoint_models.keys())
        latest_model = checkpoint_models[initial_epoch]
        model = keras_models.load_model(
            os.path.join(checkpoints_dir, latest_model))
    else:
        # The project's compile() wrapper is assumed to return the compiled
        # Keras model here (keras.Model.compile itself returns None).
        model: Model = model.compile()

    model.fit_generator(train_generator,
                        epochs=epochs,
                        callbacks=callbacks,
                        validation_data=valid_generator,
                        workers=34,
                        max_queue_size=21,
                        initial_epoch=initial_epoch)

    model_filepath = os.path.join(modeldir, '{}.h5'.format(output))
    model.save(model_filepath)
    print('Model saved to', model_filepath)
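
A sketch of one way to call train(). Every value below is a placeholder: the network name must be one registered with models.retrieve_model, and the checkpoint template must start with the epoch number, since the resume branch parses it with int(x.split('-')[0]):

train(data_dir='data',
      train_segments='data/train_segments.csv',
      negative_segments='data/negative_segments.csv',
      valid_segments='data/valid_segments.csv',
      ontology='data/ontology.json',
      labels=('animal', ),
      seed=1,
      network='some_registered_network',
      resume_training=False,
      epochs=100,
      initial_epoch=0,
      checkpoints_period=5,
      logdir='logs',
      checkpoints='checkpoints/{epoch:02d}-{val_loss:.2f}.h5',
      modeldir='models',
      output='trained_model')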
Example #10
def download(labels,
             data_dir,
             segments,
             ontology,
             limit=None,
             min_size=None,
             max_size=None,
             blacklist=None,
             seed=None):
    random.seed(seed)
    segments = SegmentsWrapper(segments, os.path.join(data_dir, 'raw'))
    ontology = Ontology(ontology, os.path.join(data_dir, 'videos'))

    if blacklist is None:
        blacklist = pd.DataFrame(columns=['YTID', 'reason'])
    else:
        blacklist = pd.read_csv(blacklist)

    def segment_in_ontology(o):
        # Returns a predicate: True if any of the segment's positive labels
        # belongs to ontology o.
        def predicate(s):
            return any(map(o.__contains__, s.positive_labels))

        return predicate

    def filter_by_ontology(s):
        # Given segments s, returns a function that keeps only the segments
        # belonging to a given ontology.
        def apply(o):
            return list(filter(segment_in_ontology(o), s))

        return apply

    ontologies = ontology.retrieve(*labels)
    segments = list(filter(segment_in_ontology(ontologies), segments))

    # Retrieve each label's ontology separately so downloads can be counted
    # per label below.
    ontologies = list(map(ontology.retrieve, labels))
    downloaded = list(filter(attrgetter('is_available'), segments))
    downloaded = map(filter_by_ontology(downloaded), ontologies)
    downloaded = zip(map(attrgetter('name'), ontologies), downloaded)

    # counter maps each label name to its list of already-downloaded segments.
    counter = {name: s for name, s in downloaded}
    pprint.pprint({name: len(s) for name, s in counter.items()})

    random.shuffle(segments)

    for segment in segments:
        # Stop once every label has collected at least `limit` segments.
        finished = limit is not None and all(
            map(limit.__le__, map(len, counter.values())))
        print(list(map(len, counter.values())))

        if finished:
            break

        # Skip segments that were already downloaded under any label.
        if not any(map(lambda x: segment.ytid in x, counter.values())):

            if limit is not None:
                ok = True
                exceeded = list()

                for ont in ontologies:
                    if segment_in_ontology(ont)(segment) and len(
                            counter[ont.name]) >= limit:
                        ok = False
                        exceeded.append(ont.proper_name)

                if not ok:
                    print('[{}] "{}" has reached limit.'.format(
                        segment.ytid, exceeded))
                    continue

            blacklisted = blacklist[blacklist['YTID'] == segment.ytid]
            if not blacklisted.empty:
                print(
                    '[{}] is blacklisted. {}.'.format(*blacklisted.values[0]))
                continue

            info = yt.info(segment.ytid)
            if info == -1:
                continue

            formats = filter(lambda x: 'filesize' in x, info['formats'])
            filesizes = map(itemgetter('filesize'), formats)
            filesizes = list(filter(lambda x: x is not None, filesizes))
            # Largest available format size, converted from bytes to MiB.
            filesize = int(max(filesizes) / 1024 / 1024) if filesizes else None

            if filesize is None:
                print('[{}] cannot retrieve filesize from youtube info'.format(
                    segment.ytid))
                continue

            if min_size is not None and filesize < min_size:
                print('[{}] smaller than min_size ({} MiB).'.format(
                    segment.ytid, filesize))
                continue

            if max_size is not None and filesize > max_size:
                print('[{}] exceeds max_size ({} MiB).'.format(
                    segment.ytid, filesize))
                continue

            yt.dl(segment.ytid, outtmpl=segment.ydl_outtmpl)

            for ont in ontologies:
                if segment_in_ontology(ont)(segment):
                    counter[ont.name].append(segment)
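
A usage sketch mirroring the call in Example #2, with hypothetical size bounds in MiB added:

# Download at most 10 segments per requested label.
download(('animal', ),
         data_dir='data',
         segments='data/balanced_train_segments.csv',
         ontology='data/ontology.json',
         limit=10,
         min_size=1,
         max_size=50,
         seed=1)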
Example #11
def preprocess(data_dir, segments, workers=1):
    if not isinstance(workers, int):
        raise TypeError('WORKERS can\'t be of type {}'.format(
            type(workers).__name__))

    if workers < 1:
        # workers == 0 would spawn no threads and divide by zero below.
        raise ValueError('WORKERS must be positive (not {}).'.format(workers))

    def thread_print_function(thread_id):
        # Builds a printer that prefixes messages with the thread id when
        # running multi-threaded, and uses a carriage return to keep
        # transient "wait" messages on one line in single-threaded mode.
        def thread_print(*args, wait=False, **kwargs):
            end = '\r' if wait and workers == 1 else '\n'
            if workers > 1:
                args = ('[Thread {}]'.format(thread_id), ) + args
            print(*args, **kwargs, end=end)

        return thread_print

    def thread_function(thread_id, thread_segments):
        print_function = thread_print_function(thread_id)

        for i, segment in enumerate(thread_segments):
            print_function('{}: ({} / {})'.format(segment.ytid, i + 1,
                                                  len(thread_segments)))
            print_function('{}: Extracting frames'.format(segment.ytid),
                           wait=True)

            ops.extract_frames(segment.raw, segment.frames_dir,
                               segment.start_seconds)

            print_function('{}: Extracting frames (finished)'.format(
                segment.ytid))
            print_function('{}: Computing spectrograms'.format(segment.ytid),
                           wait=True)

            waveform, sr = segment.waveform
            for j in range(segment.start_frames,
                           min(segment.end_frames, len(segment))):
                if workers == 1:
                    print_function('{}: Computing spectrograms ({})'.format(
                        segment.ytid, j),
                                   wait=True)

                if not os.path.exists(segment.spectrogram(j)):
                    start_samples = segment.get_sample_index(j)
                    samples_slice = slice(start_samples,
                                          start_samples + segment.sample_rate)
                    ops.compute_spectrogram(waveform[samples_slice],
                                            segment.spectrogram(j))

            print_function('{}: Computing spectrograms (finished)'.format(
                segment.ytid))

    segments = SegmentsWrapper(segments, os.path.join(data_dir, 'raw'))
    segments = list(filter(attrgetter('is_available'), segments))

    thread_args = list()

    for idx in range(workers):
        thread_size = math.ceil(len(segments) / workers)
        thread_start = idx * thread_size
        thread_args.append(
            (idx, segments[thread_start:thread_start + thread_size]))

    fork(workers, thread_function, *thread_args)
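
A usage sketch for the preprocessing step, assuming the segments file produced by compress_segments in Example #1:

# Extract frames and compute spectrograms for all downloaded segments,
# splitting the work across four threads.
preprocess(data_dir='data',
           segments='data/compressed_segments.csv',
           workers=4)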