Example #1
    def _select_stream(self, streams, chosen_idx, name):
        if not streams:
            raise SushiError('No {0} streams found in {1}'.format(
                name, self._path))
        if chosen_idx is None:
            if len(streams) > 1:
                default_track = next((s for s in streams if s.default), None)
                if default_track:
                    logging.warning(
                        'Using default track {0} in {1} because there are multiple candidates'
                        .format(self._format_stream(default_track),
                                self._path))
                    return default_track
                raise SushiError(
                    'More than one {0} stream found in {1}. '
                    'You need to specify the exact one to demux. Here are all candidates:\n'
                    '{2}'.format(name, self._path,
                                 self._format_streams_list(streams)))
            return streams[0]

        try:
            return next(x for x in streams if x.id == chosen_idx)
        except StopIteration:
            raise SushiError("Stream with index {0} doesn't exist in {1}.\n"
                             "Here are all that do:\n"
                             "{2}".format(chosen_idx, self._path,
                                          self._format_streams_list(streams)))
Example #2
    def from_file(cls, path):
        script_info, styles, events = [], [], []
        other_sections = collections.OrderedDict()

        def parse_script_info_line(line):
            if line.startswith(u'Format:'):
                return
            script_info.append(line)

        def parse_styles_line(line):
            if line.startswith(u'Format:'):
                return
            styles.append(line)

        def parse_event_line(line):
            if line.startswith(u'Format:'):
                return
            events.append(AssEvent(line, position=len(events) + 1))

        def create_generic_parse(section_name):
            if section_name in other_sections:
                raise SushiError("Duplicate section detected, invalid script?")
            other_sections[section_name] = []
            return other_sections[section_name].append

        parse_function = None

        try:
            with codecs.open(path, encoding='utf-8-sig') as script:
                for line_idx, line in enumerate(script):
                    line = line.strip()
                    if not line:
                        continue
                    if line[0] == u';':
                        continue

                    low = line.lower()
                    if low == u'[script info]':
                        parse_function = parse_script_info_line
                    elif low == u'[v4+ styles]':
                        parse_function = parse_styles_line
                    elif low == u'[events]':
                        parse_function = parse_event_line
                    elif re.match(r'\[.+?\]', low):
                        parse_function = create_generic_parse(line)
                    elif not parse_function:
                        raise SushiError(
                            "That's some invalid ASS script: {0} [line {1}]".
                            format(line, line_idx))
                    else:
                        try:
                            parse_function(line)
                        except Exception as e:
                            raise SushiError(
                                "That's some invalid ASS script: {0} [line {1}]"
                                .format(e.message, line_idx))
        except IOError:
            raise SushiError("Script {0} not found".format(path))
        return cls(script_info, styles, events, other_sections)
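
A minimal usage sketch, assuming this classmethod belongs to AssScript (that is how it is invoked in Example #19); the file name is hypothetical:

    script = AssScript.from_file('episode01.ass')   # hypothetical path
    script.sort_by_time()
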
Example #3
    def __init__(self, path, sample_rate=12000, sample_type='uint8'):
        if sample_type not in ('float32', 'uint8'):
            raise SushiError('Unknown sample type of WAV stream, must be uint8 or float32')

        stream = DownmixedWavFile(path)
        total_seconds = stream.frames_count / float(stream.framerate)
        downsample_rate = sample_rate / float(stream.framerate)

        self.sample_count = math.ceil(total_seconds * sample_rate)
        self.sample_rate = sample_rate
        # pre-allocating the data array and some place for padding
        self.data = np.empty((1, int(self.PADDING_SECONDS * 2 * stream.framerate + self.sample_count)), np.float32)
        self.padding_size = 10 * stream.framerate
        before_read = time()
        try:
            seconds_read = 0
            samples_read = self.padding_size
            while seconds_read < total_seconds:
                data = stream.readframes(int(self.READ_CHUNK_SIZE * stream.framerate))
                new_length = int(round(len(data) * downsample_rate))

                dst_view = self.data[0][samples_read:samples_read+new_length]

                if downsample_rate != 1:
                    # resample this chunk down to the target sample rate using
                    # nearest-neighbor interpolation by treating it as a 1xN image
                    data = data.reshape((1, len(data)))
                    data = cv2.resize(data, (new_length, 1), interpolation=cv2.INTER_NEAREST)[0]

                np.copyto(dst_view, data, casting='no')
                samples_read += new_length
                seconds_read += self.READ_CHUNK_SIZE

            # padding the audio from both sides
            self.data[0][0:self.padding_size].fill(self.data[0][self.padding_size])
            self.data[0][-self.padding_size:].fill(self.data[0][-self.padding_size-1])

            # normalizing
            # also clipping the stream by 3*median value from both sides of zero
            max_value = np.median(self.data[self.data >= 0], overwrite_input=True) * 3
            min_value = np.median(self.data[self.data <= 0], overwrite_input=True) * 3

            np.clip(self.data, min_value, max_value, out=self.data)

            self.data -= min_value
            self.data /= (max_value - min_value)

            if sample_type == 'uint8':
                self.data *= 255.0
                self.data += 0.5
                self.data = self.data.astype('uint8')

        except Exception as e:
            raise SushiError('Error while loading {0}: {1}'.format(path, e))
        finally:
            stream.close()
        logging.info('Done reading WAV {0} in {1}s'.format(path, time() - before_read))
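
A usage sketch mirroring how Example #19 constructs its streams (the path is hypothetical):

    src_stream = WavStream('source.sushi.wav', sample_rate=12000, sample_type='uint8')
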
Example #4
    def demux_file(input_path, **kwargs):
        args = ['ffmpeg', '-hide_banner', '-i', input_path, '-y']

        audio_stream = kwargs.get('audio_stream', None)
        audio_path = kwargs.get('audio_path', None)
        audio_rate = kwargs.get('audio_rate', None)
        if audio_stream is not None:
            args.extend(('-map', '0:{0}'.format(audio_stream)))
            if audio_rate:
                args.extend(('-ar', str(audio_rate)))
            args.extend(('-ac', '1', '-acodec', 'pcm_s16le', audio_path))

        script_stream = kwargs.get('script_stream', None)
        script_path = kwargs.get('script_path', None)
        if script_stream is not None:
            args.extend(('-map', '0:{0}'.format(script_stream), script_path))

        video_stream = kwargs.get('video_stream', None)
        timecodes_path = kwargs.get('timecodes_path', None)
        if timecodes_path is not None:
            args.extend(('-map', '0:{0}'.format(video_stream), '-f',
                         'mkvtimestamp_v2', timecodes_path))

        logging.debug('ffmpeg args: {0}'.format(' '.join(
            ('"{0}"' if ' ' in a else '{0}').format(a) for a in args)))
        try:
            subprocess.call(args)
        except OSError as e:
            if e.errno == 2:
                raise SushiError(
                    "Couldn't invoke ffmpeg, check that it's installed")
            raise
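
A usage sketch with hypothetical stream indices and output paths; only the keyword groups that are actually passed end up in the ffmpeg command line:

    demux_file('source.mkv',
               audio_stream=1, audio_path='source.sushi.wav', audio_rate=12000,
               script_stream=2, script_path='source.sushi.ass',
               video_stream=0, timecodes_path='source.sushi.timecodes.txt')
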
Example #5
    def readframes(self, count):
        if not count:
            return ''
        data = self._file.read(count * self.frame_size)
        if self.sample_width == 2:
            unpacked = np.fromstring(data, dtype=np.int16)
        elif self.sample_width == 3:
            bytes = np.ndarray(len(data), 'int8', data)
            unpacked = np.zeros(len(data) / 3, np.int16)
            unpacked.view(dtype='int8')[0::2] = bytes[1::3]
            unpacked.view(dtype='int8')[1::2] = bytes[2::3]
        else:
            raise SushiError('Unsupported sample width: {0}'.format(
                self.sample_width))

        unpacked = unpacked.astype('float32')

        if self.channels_count == 1:
            return unpacked
        else:
            min_length = len(unpacked) // self.channels_count
            real_length = len(unpacked) / float(self.channels_count)
            if min_length != real_length:
                logging.error(
                    "Length of audio channels didn't match. This might result in broken output"
                )

            channels = (unpacked[i::self.channels_count]
                        for i in xrange(self.channels_count))
            data = reduce(lambda a, b: a[:min_length] + b[:min_length],
                          channels)
            data /= float(self.channels_count)
            return data
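
The 24-bit branch above keeps only the two most significant bytes of each little-endian sample, which amounts to an arithmetic shift right by 8 bits. A standalone sketch of the same trick on two hand-made samples (it assumes a little-endian host, just like the int8 view above):

    import numpy as np

    raw = np.array([0x00, 0x00, 0x40,    # 0x400000 -> expect 16384
                    0xff, 0xff, 0xbf],   # 0xbfffff -> expect -16385
                   dtype=np.uint8).tobytes()
    as_int8 = np.ndarray(len(raw), 'int8', raw)
    out = np.zeros(len(raw) // 3, np.int16)
    out.view('int8')[0::2] = as_int8[1::3]  # middle byte -> low byte of the int16
    out.view('int8')[1::2] = as_int8[2::3]  # top byte -> high byte of the int16
    print(out)  # [ 16384 -16385]
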
Example #6
def parse_keyframes(path):
    text = read_all_text(path)
    if text.find('# XviD 2pass stat file') >= 0:
        frames = parse_scxvid_keyframes(text)
    else:
        raise SushiError('Unsupported keyframes type')
    if 0 not in frames:
        frames.insert(0, 0)
    return frames
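
A usage sketch with a hypothetical path; the file is expected to be an XviD 2-pass stat file, such as the SCXvid log produced by make_keyframes in Example #9:

    keyframes = parse_keyframes('source.sushi.keyframes.txt')
    # a list of keyframe indices that always starts with frame 0
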
Example #7
    def _read_fmt_chunk(self, chunk):
        wFormatTag, self.channels_count, self.framerate, dwAvgBytesPerSec, wBlockAlign = struct.unpack(
            '<HHLLH', chunk.read(14))
        if wFormatTag == WAVE_FORMAT_PCM or wFormatTag == WAVE_FORMAT_EXTENSIBLE:  # ignore the rest
            bits_per_sample = struct.unpack('<H', chunk.read(2))[0]
            self.sample_width = (bits_per_sample + 7) // 8
        else:
            raise SushiError('unknown format: {0}'.format(wFormatTag))
        self.frame_size = self.channels_count * self.sample_width
Example #8
    def get_info(path):
        try:
            process = subprocess.Popen(['ffmpeg', '-hide_banner', '-i', path], stderr=subprocess.PIPE)
            out, err = process.communicate()
            process.wait()
            return err
        except OSError as e:
            if e.errno == 2:
                raise SushiError("Couldn't invoke ffmpeg, check that it's installed")
            raise
Example #9
    def make_keyframes(cls, video_path, log_path):
        try:
            ffmpeg_process = subprocess.Popen(['ffmpeg', '-i', video_path,
                            '-f', 'yuv4mpegpipe',
                            '-vf', 'scale=640:360',
                            '-pix_fmt', 'yuv420p',
                            '-vsync', 'drop', '-'], stdout=subprocess.PIPE)
        except OSError as e:
            if e.errno == 2:
                raise SushiError("Couldn't invoke ffmpeg, check that it's installed")
            raise

        try:
            scxvid_process = subprocess.Popen(['SCXvid', log_path], stdin=ffmpeg_process.stdout)
        except OSError as e:
            ffmpeg_process.kill()
            if e.errno == 2:
                raise SushiError("Couldn't invoke scxvid, check that it's installed")
            raise
        scxvid_process.wait()
Example #10
    def select_timecodes(external_file, fps_arg, demuxer):
        if external_file:
            return external_file
        elif fps_arg:
            return None
        elif demuxer.has_video:
            path = format_full_path(args.temp_dir, demuxer.path, '.sushi.timecodes.txt')
            demuxer.set_timecodes(output_path=path)
            return path
        else:
            raise SushiError('Fps, timecodes or video files must be provided if keyframes are used')
Example #11
    def select_keyframes(file_arg, demuxer):
        auto_file = format_full_path(args.temp_dir, demuxer.path, '.sushi.keyframes.txt')
        if file_arg in ('auto', 'make'):
            if file_arg == 'make' or not os.path.exists(auto_file):
                if not demuxer.has_video:
                    raise SushiError("Cannot make keyframes for {0} because it doesn't have any video!"
                                     .format(demuxer.path))
                demuxer.set_keyframes(output_path=auto_file)
            return auto_file
        else:
            return file_arg
Example #12
def running_median(values, window_size):
    if window_size % 2 != 1:
        raise SushiError('Median window size should be odd')
    half_window = window_size // 2
    medians = []
    items_count = len(values)
    for idx in xrange(items_count):
        radius = min(half_window, idx, items_count - idx - 1)
        med = np.median(values[idx - radius:idx + radius + 1])
        medians.append(med)
    return medians
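
A quick sketch showing how the window shrinks to the available radius near both edges (hypothetical shift values):

    shifts = [10.0, 0.0, 10.0, 0.0, 10.0]
    medians = running_median(shifts, window_size=3)
    # medians are 10, 10, 0, 10, 10; the first and last items use a radius of 0
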
Example #13
    def __init__(self, path):
        super(DownmixedWavFile, self).__init__()
        self._file = None
        self._file = open(path, 'rb')
        try:
            riff = Chunk(self._file, bigendian=False)
            if riff.getname() != 'RIFF':
                raise SushiError('File does not start with RIFF id')
            if riff.read(4) != 'WAVE':
                raise SushiError('Not a WAVE file')

            fmt_chunk_read = False
            data_chunk_read = False
            file_size = os.path.getsize(path)

            while True:
                try:
                    chunk = Chunk(self._file, bigendian=False)
                except EOFError:
                    break

                if chunk.getname() == 'fmt ':
                    self._read_fmt_chunk(chunk)
                    fmt_chunk_read = True
                elif chunk.getname() == 'data':
                    if file_size > 0xFFFFFFFF:
                        # large broken wav
                        self.frames_count = (
                            file_size - self._file.tell()) // self.frame_size
                    else:
                        self.frames_count = chunk.chunksize // self.frame_size
                    data_chunk_read = True
                    break
                chunk.skip()
            if not fmt_chunk_read or not data_chunk_read:
                raise SushiError('Invalid WAV file')
        except:
            if self._file:
                self._file.close()
            raise
Example #14
    def from_file(cls, path):
        try:
            with codecs.open(path, encoding='utf-8-sig') as script:
                text = script.read()
                events_list = [SrtEvent(
                    source_index=int(match.group(1)),
                    start=SrtEvent.parse_time(match.group(2)),
                    end=SrtEvent.parse_time(match.group(3)),
                    text=match.group(4).strip()
                ) for match in SrtEvent.EVENT_REGEX.finditer(text)]
                return cls(events_list)
        except IOError:
            raise SushiError("Script {0} not found".format(path))
Example #15
    def parse(cls, text):
        lines = text.splitlines()
        if not lines:
            return []
        first = lines[0].lower().lstrip()
        if first.startswith('# timecode format v2'):
            tcs = [float(x) / 1000.0 for x in lines[1:]]
            return Timecodes(tcs, None)
        elif first.startswith('# timecode format v1'):
            default = float(lines[1].lower().replace('assume ', ''))
            overrides = (x.split(',') for x in lines[2:])
            return Timecodes(cls._convert_v1_to_v2(default, overrides), default)
        else:
            raise SushiError('This timecodes format is not supported')
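
Hand-made inputs for both branches above (a sketch; the v1 override lines are assumed to be 'start,end,fps' triples as consumed by _convert_v1_to_v2):

    v2_text = '# timecode format v2\n0\n41.708\n83.417'            # one timestamp in ms per frame
    v1_text = '# timecode format v1\nAssume 23.976\n0,100,29.970'  # default fps plus overrides
    timecodes = Timecodes.parse(v2_text)
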
Example #16
    def create_generic_parse(section_name):
        if section_name in other_sections:
            raise SushiError("Duplicate section detected, invalid script?")
        other_sections[section_name] = []
        return other_sections[section_name].append
Example #17
    def from_file(cls, path):
        try:
            with codecs.open(path, encoding='utf-8-sig') as script:
                return cls([SrtEvent(x) for x in script.read().replace(os.linesep, '\n').split('\n\n') if x])
        except IOError:
            raise SushiError("Script {0} not found".format(path))
Example #18
def check_file_exists(path, file_title):
    if path and not os.path.exists(path):
        raise SushiError("{0} file doesn't exist".format(file_title))
Example #19
def run(args):
    ignore_chapters = args.chapters_file is not None and args.chapters_file.lower() == 'none'
    write_plot = plot_enabled and args.plot_path
    if write_plot:
        plt.clf()
        plt.ylabel('Shift, seconds')
        plt.xlabel('Event index')

    # the first part should do all possible validation and should NOT take a significant amount of time
    check_file_exists(args.source, 'Source')
    check_file_exists(args.destination, 'Destination')
    check_file_exists(args.src_timecodes, 'Source timecodes')
    check_file_exists(args.dst_timecodes, 'Destination timecodes')
    check_file_exists(args.script_file, 'Script')

    if not ignore_chapters:
        check_file_exists(args.chapters_file, 'Chapters')
    if args.src_keyframes not in ('auto', 'make'):
        check_file_exists(args.src_keyframes, 'Source keyframes')
    if args.dst_keyframes not in ('auto', 'make'):
        check_file_exists(args.dst_keyframes, 'Destination keyframes')

    if (args.src_timecodes and args.src_fps) or (args.dst_timecodes and args.dst_fps):
        raise SushiError(
            'Both fps and timecodes file cannot be specified at the same time')

    src_demuxer = Demuxer(args.source)
    dst_demuxer = Demuxer(args.destination)

    if src_demuxer.is_wav and not args.script_file:
        raise SushiError("Script file isn't specified")

    if ((args.src_keyframes and not args.dst_keyframes)
            or (args.dst_keyframes and not args.src_keyframes)):
        raise SushiError(
            'Either none or both of src and dst keyframes should be provided')

    create_directory_if_not_exists(args.temp_dir)

    # selecting source audio
    if src_demuxer.is_wav:
        src_audio_path = args.source
    else:
        src_audio_path = format_full_path(args.temp_dir, args.source,
                                          '.sushi.wav')
        src_demuxer.set_audio(stream_idx=args.src_audio_idx,
                              output_path=src_audio_path,
                              sample_rate=args.sample_rate)

    # selecting destination audio
    if dst_demuxer.is_wav:
        dst_audio_path = args.destination
    else:
        dst_audio_path = format_full_path(args.temp_dir, args.destination,
                                          '.sushi.wav')
        dst_demuxer.set_audio(stream_idx=args.dst_audio_idx,
                              output_path=dst_audio_path,
                              sample_rate=args.sample_rate)

    # selecting source subtitles
    if args.script_file:
        src_script_path = args.script_file
    else:
        stype = src_demuxer.get_subs_type(args.src_script_idx)
        src_script_path = format_full_path(args.temp_dir, args.source,
                                           '.sushi' + stype)
        src_demuxer.set_script(stream_idx=args.src_script_idx,
                               output_path=src_script_path)

    script_extension = get_extension(src_script_path)
    if script_extension not in ('.ass', '.srt'):
        raise SushiError('Unknown script type')

    # selecting destination subtitles
    if args.output_script:
        dst_script_path = args.output_script
        dst_script_extension = get_extension(args.output_script)
        if dst_script_extension != script_extension:
            raise SushiError(
                "Source and destination script file types don't match ({0} vs {1})"
                .format(script_extension, dst_script_extension))
    else:
        dst_script_path = format_full_path(args.temp_dir, args.destination,
                                           '.sushi' + script_extension)

    # selecting chapters
    if args.grouping and not ignore_chapters:
        if args.chapters_file:
            if get_extension(args.chapters_file) == '.xml':
                chapter_times = chapters.get_xml_start_times(
                    args.chapters_file)
            else:
                chapter_times = chapters.get_ogm_start_times(
                    args.chapters_file)
        elif not src_demuxer.is_wav:
            chapter_times = src_demuxer.chapters
            output_path = format_full_path(args.temp_dir, src_demuxer.path,
                                           ".sushi.chapters.txt")
            src_demuxer.set_chapters(output_path)
        else:
            chapter_times = []
    else:
        chapter_times = []

    # selecting keyframes and timecodes
    if args.src_keyframes:

        def select_keyframes(file_arg, demuxer):
            auto_file = format_full_path(args.temp_dir, demuxer.path,
                                         '.sushi.keyframes.txt')
            if file_arg in ('auto', 'make'):
                if file_arg == 'make' or not os.path.exists(auto_file):
                    if not demuxer.has_video:
                        raise SushiError(
                            "Cannot make keyframes for {0} because it doesn't have any video!"
                            .format(demuxer.path))
                    demuxer.set_keyframes(output_path=auto_file)
                return auto_file
            else:
                return file_arg

        def select_timecodes(external_file, fps_arg, demuxer):
            if external_file:
                return external_file
            elif fps_arg:
                return None
            elif demuxer.has_video:
                path = format_full_path(args.temp_dir, demuxer.path,
                                        '.sushi.timecodes.txt')
                demuxer.set_timecodes(output_path=path)
                return path
            else:
                raise SushiError(
                    'Fps, timecodes or video files must be provided if keyframes are used'
                )

        src_keyframes_file = select_keyframes(args.src_keyframes, src_demuxer)
        dst_keyframes_file = select_keyframes(args.dst_keyframes, dst_demuxer)
        src_timecodes_file = select_timecodes(args.src_timecodes, args.src_fps,
                                              src_demuxer)
        dst_timecodes_file = select_timecodes(args.dst_timecodes, args.dst_fps,
                                              dst_demuxer)

    # after this point nothing should fail so it's safe to start slow operations
    # like running the actual demuxing
    src_demuxer.demux()
    dst_demuxer.demux()

    try:
        if args.src_keyframes:
            src_timecodes = (Timecodes.cfr(args.src_fps) if args.src_fps
                             else Timecodes.from_file(src_timecodes_file))
            src_keytimes = [src_timecodes.get_frame_time(f)
                            for f in parse_keyframes(src_keyframes_file)]

            dst_timecodes = (Timecodes.cfr(args.dst_fps) if args.dst_fps
                             else Timecodes.from_file(dst_timecodes_file))
            dst_keytimes = [dst_timecodes.get_frame_time(f)
                            for f in parse_keyframes(dst_keyframes_file)]

        if script_extension == '.ass':
            script = AssScript.from_file(src_script_path)
        else:
            script = SrtScript.from_file(src_script_path)
        script.sort_by_time()

        src_stream = WavStream(src_audio_path,
                               sample_rate=args.sample_rate,
                               sample_type=args.sample_type)
        dst_stream = WavStream(dst_audio_path,
                               sample_rate=args.sample_rate,
                               sample_type=args.sample_type)

        calculate_shifts(
            src_stream,
            dst_stream,
            script.events,
            chapter_times=chapter_times,
            window=args.window,
            max_window=args.max_window,
            rewind_thresh=args.rewind_thresh if args.grouping else 0,
            max_ts_duration=args.max_ts_duration,
            max_ts_distance=args.max_ts_distance)

        events = script.events

        if write_plot:
            plt.plot([x.shift for x in events], label='From audio')

        if args.grouping:
            if not ignore_chapters and chapter_times:
                groups = groups_from_chapters(events, chapter_times)
                for g in groups:
                    fix_near_borders(g)
                    smooth_events([x for x in g if not x.linked],
                                  args.smooth_radius)
                groups = split_broken_groups(groups, args.min_group_size)
            else:
                fix_near_borders(events)
                smooth_events([x for x in events if not x.linked],
                              args.smooth_radius)
                groups = detect_groups(events, args.min_group_size)

            if write_plot:
                plt.plot([x.shift for x in events], label='Borders fixed')

            for g in groups:
                start_shift = g[0].shift
                end_shift = g[-1].shift
                avg_shift = average_shifts(g)
                logging.info(
                    u'Group (start: {0}, end: {1}, lines: {2}), '
                    u'shifts (start: {3}, end: {4}, average: {5})'.format(
                        format_time(g[0].start), format_time(g[-1].end),
                        len(g), start_shift, end_shift, avg_shift))

            if args.src_keyframes:
                for e in (x for x in events if x.linked):
                    e.resolve_link()
                for g in groups:
                    snap_groups_to_keyframes(
                        g, chapter_times, args.max_ts_duration,
                        args.max_ts_distance, src_keytimes, dst_keytimes,
                        src_timecodes, dst_timecodes, args.max_kf_distance,
                        args.kf_mode)

            if args.write_avs:
                write_shift_avs(dst_script_path + '.avs', groups,
                                src_audio_path, dst_audio_path)
        else:
            fix_near_borders(events)
            if write_plot:
                plt.plot([x.shift for x in events], label='Borders fixed')

            if args.src_keyframes:
                for e in (x for x in events if x.linked):
                    e.resolve_link()
                snap_groups_to_keyframes(events, chapter_times,
                                         args.max_ts_duration,
                                         args.max_ts_distance, src_keytimes,
                                         dst_keytimes, src_timecodes,
                                         dst_timecodes, args.max_kf_distance,
                                         args.kf_mode)

        for event in events:
            event.apply_shift()

        script.save_to_file(dst_script_path)

        if write_plot:
            plt.plot([x.shift + (x._start_shift + x._end_shift) / 2.0 for x in events],
                     label='After correction')
            plt.legend(fontsize=5, frameon=False, fancybox=False)
            plt.savefig(args.plot_path, dpi=300)

    finally:
        if args.cleanup:
            src_demuxer.cleanup()
            dst_demuxer.cleanup()