예제 #1
0
def main():
    """Collect Stephen's captions from every ``.vtt`` file into one text file.

    Every ``*.vtt`` file found under ``DATA_DIR`` contributes one line to
    ``DATA_DIR/../captions.txt``: the kept captions joined by spaces, followed
    by a ``<|endoftext|>`` line. Stephen is assumed to be speaking at the start
    of each file; captions are skipped from a ``>> Jon:`` caption until the
    next ``>> Stephen:`` caption.
    """
    jon_turn = re.compile(r">> (Jon:|JON:|jon:) (.*)$")
    stephen_turn = re.compile(r">> (Stephen:|STEPHEN:|stephen:) (.*)$")
    # Drops a ">> " marker (with optional speaker name and colon) wherever it
    # occurs, keeping the remainder of that line.
    marker = re.compile(
        r">> (Stephen|STEPHEN|stephen)*(Jon|JON|jon)*:?(.*)$", re.M)

    reader = webvtt.WebVTT()
    sources = glob.iglob(f"{DATA_DIR}/*.vtt")
    destination = os.path.join(DATA_DIR, "../", "captions.txt")

    with open(destination, "w+") as out:
        for source in sources:
            kept = []
            stephen_speaking = True
            for caption in reader.read(source):
                line = caption.text

                # A caption opened by Jon is dropped and mutes what follows.
                if jon_turn.match(line):
                    stephen_speaking = False
                    continue

                stephen = stephen_turn.match(line)
                if stephen:
                    # Stephen takes over: keep only what follows his prefix.
                    stephen_speaking = True
                    line = stephen.group(2)
                elif not stephen_speaking:
                    continue

                line = marker.sub(r"\3", line)
                kept.append(line.replace("\n", " "))

            out.writelines((" ".join(kept), "\n", "<|endoftext|>", "\n"))
예제 #2
0
def test_webvtt():
    """Write a single caption to ``test.vtt`` through the ``webvtt`` API.

    The still-empty document is saved first, then one caption is appended and
    the document is written again through an explicitly opened file handle.
    """
    target = 'test.vtt'
    document = webvtt.WebVTT()
    document.save(target)

    cue_lines = ['Caption line 1\nCaption line 2']
    cue = webvtt.Caption('00:00:00.500', '00:00:07.000', cue_lines)
    document.captions.append(cue)

    with open(target, 'w') as handle:
        document.write(handle)
예제 #3
0
def write2vtt(dict_list, output_file):
    """Write the entries of ``dict_list`` to ``output_file`` as WebVTT.

    Each entry is read through its ``"start"``, ``"end"`` and ``"content"``
    keys. The two times are converted with ``float2string`` and the content
    becomes the caption's only text line.

    :param dict_list: sequence of caption entries, in output order
    :param output_file: path of the ``.vtt`` file to write
    """
    document = webvtt.WebVTT()
    document.save(output_file)

    for entry in dict_list:
        begin = float2string(entry["start"])
        finish = float2string(entry["end"])
        document.captions.append(
            webvtt.Caption(begin, finish, [entry["content"]]))

    with open(output_file, 'w') as handle:
        document.write(handle)
def write_vtt(df, filename):
    """Write the rows of ``df`` to ``filename`` in WebVTT format.

    Each row becomes one caption. ``row["start_time"]`` and
    ``row["end_time"]`` get ``".000"`` appended to form the timestamps, and
    ``row["comment"]`` is the caption text. Comments longer than 80 characters
    are wrapped into several lines. A truthy ``row["speaker"]`` is stored as
    the caption identifier.

    :param df: DataFrame with ``start_time``, ``end_time``, ``comment`` and
        ``speaker`` columns
    :param filename: path of the ``.vtt`` file to write
    """
    logging.info("Writing VTT")

    vtt = webvtt.WebVTT()

    for _, row in df.iterrows():
        comment = row["comment"]
        start = row["start_time"] + ".000"
        end = row["end_time"] + ".000"

        if len(comment) <= 80:
            # Short comments are passed through as a single string.
            caption = webvtt.Caption(start=start, end=end, text=comment)
        else:
            caption = webvtt.Caption(
                start=start, end=end, text=_split_caption_lines(comment)
            )

        # NOTE(review): a missing speaker stored as NaN is truthy and would be
        # set as the identifier - confirm the data uses None or "".
        if row["speaker"]:
            caption.identifier = row["speaker"]

        vtt.captions.append(caption)

    vtt.save(filename)
    logging.info(f"VTT saved to {filename}")


def _split_caption_lines(text, width=80):
    """Split ``text`` into lines shorter than ``width``, breaking at spaces.

    The remainder left after the last break is kept as the final line, and a
    run of ``width`` characters without any space is hard-broken at ``width``
    instead of raising ``ValueError``.
    """
    lines = []
    text = text.lstrip()
    while len(text) > width:
        cut = text.rfind(" ", 0, width)
        if cut <= 0:
            # No usable space in this window: break the word itself.
            cut = width
        lines.append(text[:cut])
        text = text[cut:].lstrip()
    if text:
        lines.append(text)
    return lines
예제 #5
0
 def test_save_no_filename(self):
     """Calling ``save`` with no arguments raises ``MissingFilenameError``."""
     document = webvtt.WebVTT()
     with self.assertRaises(webvtt.errors.MissingFilenameError):
         document.save()
예제 #6
0
    def download(self,
                 filepath,
                 img_filepath=None,
                 annotation_format=entities.ViewAnnotationOptions.MASK,
                 height=None,
                 width=None,
                 thickness=1,
                 with_text=False):
        """
        Write this object's annotations to disk in the requested format.

        If ``filepath`` has no extension it is used as a directory and the
        file is named after ``self.item.name`` (without its extension) plus
        ``.json``, ``.png`` or ``.vtt`` depending on the format.

        * ``JSON`` - dumps ``_id``, ``filename`` and every annotation's
          ``to_json()`` result.
        * ``MASK`` / ``INSTANCE`` / ``ANNOTATION_ON_IMAGE`` - renders through
          ``self.show`` and saves the result as an image. For
          ``ANNOTATION_ON_IMAGE`` the image at ``img_filepath`` is loaded,
          passed to ``self.show`` and drawn with the ``MASK`` format.
        * ``VTT`` - writes annotations whose type is ``'subtitle'``, ordered
          by start time.

        :param filepath: file path, or directory, to save the annotations to
        :param img_filepath: image path, read only for ANNOTATION_ON_IMAGE
        :param annotation_format: one of entities.ViewAnnotationOptions
        :param height: forwarded to ``self.show``
        :param width: forwarded to ``self.show``
        :param thickness: forwarded to ``self.show``
        :param with_text: forwarded to ``self.show``
        :return: the path that was written
        :raises PlatformException: when ``annotation_format`` is not handled
        """
        base_path, extension = os.path.splitext(filepath)
        options = entities.ViewAnnotationOptions

        def resolve_target(template):
            # An explicit extension means the caller chose the exact file.
            if extension:
                return filepath
            item_stem = os.path.splitext(self.item.name)[0]
            return template.format(base_path, item_stem)

        def to_timestamp(seconds):
            # str(timedelta) has no fractional part for whole seconds.
            stamp = str(datetime.timedelta(seconds=seconds))
            if '.' not in stamp:
                stamp += '.000'
            return stamp

        if annotation_format == options.JSON:
            target = resolve_target('{}/{}.json')
            payload = {
                '_id': self.item.id,
                'filename': self.item.filename,
                'annotations': [ann.to_json() for ann in self.annotations],
            }
            with open(target, 'w+') as f:
                json.dump(payload, f, indent=2)
            return target

        if annotation_format in [options.MASK,
                                 options.INSTANCE,
                                 options.ANNOTATION_ON_IMAGE]:
            target = resolve_target('{}/{}.png')
            background = None
            if annotation_format == options.ANNOTATION_ON_IMAGE:
                # Drawn as a mask on top of the loaded image.
                annotation_format = options.MASK
                background = np.asarray(Image.open(img_filepath))
            rendered = self.show(image=background,
                                 thickness=thickness,
                                 with_text=with_text,
                                 height=height,
                                 width=width,
                                 annotation_format=annotation_format)
            Image.fromarray(rendered.astype(np.uint8)).save(target)
            return target

        if annotation_format == options.VTT:
            target = resolve_target('{}/{}.vtt')
            subtitles = [
                {
                    'start_time': annotation.start_time,
                    'end_time': annotation.end_time,
                    'text': annotation.coordinates['text'],
                }
                for annotation in self.annotations
                if annotation.type in ['subtitle']
            ]
            subtitles.sort(key=lambda entry: entry['start_time'])

            vtt = webvtt.WebVTT()
            for entry in subtitles:
                start = to_timestamp(entry['start_time'])
                end = to_timestamp(entry['end_time'])
                vtt.captions.append(
                    webvtt.Caption(start, end, '{}'.format(entry['text'])))
            vtt.save(target)
            return target

        raise PlatformException(
            error="400",
            message="Unknown annotation option: {}".format(annotation_format))
예제 #7
0
 def test_captions_attribute(self):
     """A newly constructed ``WebVTT`` has an empty ``captions`` list."""
     captions = webvtt.WebVTT().captions
     self.assertListEqual([], captions)
예제 #8
0
 def test_webvtt_total_length_no_parser(self):
     """``total_length`` is 0 for a ``WebVTT`` that has not read any file."""
     document = webvtt.WebVTT()
     self.assertEqual(document.total_length, 0)