def meme_add_name_idx(name: str):
    """Number the rows of template ``name`` from 0 in each meme table.

    For MemeCorrectTrain, MemeCorrectTest, MemeIncorrectTrain and
    MemeIncorrectTest (in that order), every row whose ``name`` equals
    ``name`` gets a consecutive ``name_idx`` (0, 1, 2, ...) in the order the
    database returns the rows (the query has no ORDER BY).  All changes are
    committed once at the end.

    Args:
        name: the template name whose rows are renumbered.
    """
    meme_tables = (
        MemeCorrectTrain,
        MemeCorrectTest,
        MemeIncorrectTrain,
        MemeIncorrectTest,
    )
    for table in meme_tables:
        same_name = cast(ClauseElement, table.name == name)
        rows = training_db.query(table).filter(same_name)
        for position, row in enumerate(rows):
            row.name_idx = position  # type:ignore
    training_db.commit()
def download_imgflip_memes(fresh: bool = False, min_files: int = 1000):
    """Download meme images for every template into ``MEMES_REPO + name``.

    A folder is created for every ``Template.name``.  Then each
    ``(Template.name, Template.page)`` row is passed to
    ``download_imgs_from_page`` in a process pool, with a tqdm progress bar.

    Args:
        fresh: when True, every template is processed.  When False, templates
            whose folder already holds at least ``min_files`` entries are
            skipped.
        min_files: threshold used when ``fresh`` is False (previously the
            hard-coded 1000).

    Raises:
        Exception: if the ``Template`` table is empty.
    """
    if not training_db.query(Template).count():
        raise Exception("No Templates")
    for (name, ) in training_db.query(Template.name):
        Path(MEMES_REPO + name).mkdir(parents=True, exist_ok=True)

    query = training_db.query(Template.name, Template.page)
    if not fresh:
        # Only sub-directories are template folders.  Without the isdir check,
        # any plain file in MEMES_REPO made os.listdir() raise
        # NotADirectoryError.
        names_filled = [
            entry for entry in os.listdir(MEMES_REPO)
            if os.path.isdir(MEMES_REPO + entry)
            and len(os.listdir(MEMES_REPO + entry)) >= min_files
        ]
        query = query.filter(
            cast(ClauseElement, ~Template.name.in_(names_filled)))
    # Materialise the rows in the parent process before handing them to the pool.
    name_page = query.all()

    with Pool(cpu_count()) as workers:
        _: List[None] = list(
            tqdm(
                workers.imap_unordered(download_imgs_from_page, name_page),
                total=len(name_page),
            ))
Example #3
0
 def get_assortment(self):
     """Pop the next template name and return one random image of it.

     A ``name_idx`` is drawn uniformly from
     ``[0, self.max_name_idx["correct"][name]]``, and the ``path`` of the
     first ``self.correct_entity`` row with that name and index is opened.
     With probability 0.5 only ``toTensorOnly`` is applied; otherwise
     ``self.transforms`` is used.

     Returns:
         ``(transformed image, label)``, where label is 1 when the popped name
         equals ``self.name`` and 0 otherwise.
     """
     picked = self.names.pop()
     table = self.correct_entity
     target_idx = random.randint(0, self.max_name_idx["correct"][picked])
     criteria = and_(
         cast(ClauseElement, table.name == picked),
         cast(ClauseElement, table.name_idx == target_idx),
     )
     row = training_db.query(table.path).filter(criteria).first()
     image_path = cast(Tuple[str], row)[0]
     if random.random() < 0.5:
         pipeline = toTensorOnly
     else:
         pipeline = self.transforms
     label = 1 if picked == self.name else 0
     return (pipeline(Img.open(image_path)), label)
def download_blanks() -> None:
    """Re-download the blank image of every template in MEMES_TO_USE.

    BLANKS_REPO is wiped (if present) and recreated.  Each
    ``(Template.name, Template.blank_url)`` row is then passed, together with
    the target directory, to ``download_img_from_url_blank`` in a process
    pool.  At the end the number of truthy results and the number of rows
    processed are printed as ``"<truthy>/<total>"``.
    """
    q = training_db.query(Template.name, Template.blank_url).filter(
        cast(ClauseElement, Template.name.in_(MEMES_TO_USE)))
    path = BLANKS_REPO
    # shutil.rmtree raises FileNotFoundError when the folder does not exist
    # yet (e.g. on the first run), so only wipe an existing folder.
    if Path(path).exists():
        shutil.rmtree(path)
    Path(path).mkdir(parents=True, exist_ok=True)
    # Materialise the rows up front.  Pool.imap_unordered consumes its
    # iterable from a background task-feeder thread.  Passing the lazy query
    # while the main thread also ran q.count() used one SQLAlchemy session
    # from two threads at once.
    rows = q.all()
    with Pool(cpu_count()) as workers:
        results: List[bool] = list(
            tqdm(
                workers.imap_unordered(download_img_from_url_blank,
                                       zip(rows, repeat(path))),
                total=len(rows),
            ))
    print(f"{sum(results)}/{len(results)}")
def miss_match() -> None:
    """Display the templates that have no blank image in BLANKS_REPO.

    The file names in BLANKS_REPO, with their extension stripped, are compared
    against ``Template.name``.  ``Template`` rows whose name is NOT among them
    are loaded into a DataFrame and passed to ``display_df``.
    """
    have_blank = []
    for entry in os.listdir(BLANKS_REPO):
        stem, _ext = os.path.splitext(entry)
        have_blank.append(stem)
    not_downloaded = cast(ClauseElement, ~Template.name.in_(have_blank))
    select_stmt = training_db.query(Template).filter(not_downloaded).statement
    frame = pd.read_sql(cast(str, select_stmt), training_db.bind)
    display_df(frame)
def download_names_not_used():
    """Download memes for every MEMES_TO_USE template that has no images yet.

    A template counts as missing when ``MEMES_REPO + name`` is not a directory
    or is an empty directory.  A folder is created for each missing template.
    Then every matching ``(Template.name, Template.page)`` row is passed to
    ``download_imgs_from_page`` in a process pool, with a tqdm progress bar.
    """
    missing: List[str] = []
    for candidate in MEMES_TO_USE:
        folder = MEMES_REPO + candidate
        if not Path(folder).is_dir() or not os.listdir(folder):
            missing.append(candidate)

    for candidate in missing:
        Path(MEMES_REPO + candidate).mkdir(parents=True, exist_ok=True)

    in_missing = cast(ClauseElement, Template.name.in_(missing))
    pages = training_db.query(Template.name, Template.page).filter(
        in_missing).all()
    with Pool(cpu_count()) as workers:
        progress = tqdm(
            workers.imap_unordered(download_imgs_from_page, pages),
            total=len(pages),
        )
        _: List[None] = [done for done in progress]
Example #7
0
 def get_not_meme(self):
     """Return one random image from ``self.not_a_meme_entity`` with label 0.

     A ``name_idx`` is drawn uniformly from
     ``[0, self.max_name_idx["not_a_meme"]]``, and the ``path`` of the first
     row with that index is opened.  Only ``toTensorOnly`` is applied.

     Returns:
         ``(transformed image, 0)``.
     """
     table = self.not_a_meme_entity
     target_idx = random.randint(0, self.max_name_idx["not_a_meme"])
     same_idx = cast(ClauseElement, table.name_idx == target_idx)
     row = training_db.query(table.path).filter(same_idx).first()
     image_path = cast(Tuple[str], row)[0]
     image = toTensorOnly(Img.open(image_path))
     return (image, 0)
 def get_assortment(self):
     """Pop the next template name and return one of its images with its class number.

     A ``name_idx`` is drawn uniformly from
     ``[0, self.max_name_idx["correct"][name]]``, and the ``path`` of the
     first ``self.meme_entity`` row with that name and index is opened.

     Transform choice:
       * training with ``use_transforms`` left as None: ``toTensorOnly`` with
         probability 0.05, otherwise ``trainingTransforms``;
       * otherwise: ``trainingTransforms`` if ``use_transforms`` is truthy,
         else ``toTensorOnly``.

     Returns:
         ``(transformed image, self.name_num[name])``.
     """
     name = self.names.pop()
     entity = self.meme_entity
     target_idx = random.randint(0, self.max_name_idx["correct"][name])
     criteria = and_(
         cast(ClauseElement, entity.name == name),
         cast(ClauseElement, entity.name_idx == target_idx),
     )
     row = training_db.query(entity.path).filter(criteria).first()
     path = cast(Tuple[str], row)[0]

     augment_by_chance = self.is_training and self.use_transforms is None
     if not augment_by_chance:
         transforms = trainingTransforms if self.use_transforms else toTensorOnly
     elif random.random() < 0.05:
         transforms = toTensorOnly
     else:
         transforms = trainingTransforms
     return (transforms(Img.open(path)), self.name_num[name])
Example #9
0
 def get_incorrect(self):
     """Return one random ``self.incorrect_entity`` image for ``self.name``, with label 0.

     A ``name_idx`` is drawn uniformly from ``[0, self.my_max_incorrect]``,
     and the ``path`` of the first row with name ``self.name`` and that index
     is opened.  Only ``toTensorOnly`` is applied.

     Returns:
         ``(transformed image, 0)``.
     """
     table = self.incorrect_entity
     target_idx = random.randint(0, self.my_max_incorrect)
     criteria = and_(
         cast(ClauseElement, table.name == self.name),
         cast(ClauseElement, table.name_idx == target_idx),
     )
     row = training_db.query(table.path).filter(criteria).first()
     image_path = cast(Tuple[str], row)[0]
     return (toTensorOnly(Img.open(image_path)), 0)
def build_template_db(num_pages: int = 50) -> None:
    """Rebuild the "templates" table from ``get_template_data`` results.

    Steps:
      1. Delete all ``Template`` rows and commit.
      2. Call ``get_template_data`` for page numbers ``0 .. num_pages - 1`` in
         a process pool.
      3. Flatten the returned lists of record dicts into one DataFrame and
         drop duplicate rows.
      4. Write the DataFrame with ``to_sql(if_exists="replace")``.

    ``if_exists="replace"`` drops and recreates the table from the
    DataFrame's columns.

    Args:
        num_pages: how many page numbers to request (previously the hard-coded
            50).
    """
    print(f"num memes to use - {len(MEMES_TO_USE)}")
    _: Any = training_db.query(Template).delete()
    training_db.commit()
    pages = range(0, num_pages)
    # cpu_count() // 2 is 0 on a single-core machine, and Pool(0) raises
    # ValueError, so always request at least one worker process.
    with Pool(max(1, cpu_count() // 2)) as workers:
        per_page = cast(
            Iterable[List[Dict[str, str]]],
            tqdm(
                workers.imap_unordered(get_template_data, pages),
                total=num_pages,
            ),
        )
        # Consume the results inside the `with`: leaving it terminates the pool.
        records = list(chain.from_iterable(per_page))
    df = pd.DataFrame.from_records(records).drop_duplicates(ignore_index=True)
    df.to_sql("templates",
              training_db.bind,
              if_exists="replace",
              index_label="id")
    print(f"Total Templates - {training_db.query(Template).count()}")
def _run_in_pool(func: Any, jobs: Iterable[Any], total: int) -> None:
    """Apply ``func`` to every job in a fresh process pool, with a tqdm bar; discard results."""
    with Pool(cpu_count()) as workers:
        _ = list(tqdm(workers.imap_unordered(func, jobs), total=total))


def build_db_from_imgdir():
    """Rebuild the image tables from the folders on disk.

    Steps:
      1. Empty MemeCorrectTrain/Test, NotAMemeTrain/Test and
         NotATemplateTrain/Test, then commit.
      2. Register every file in NOT_MEME_REPO (``add_not_meme``) and
         NOT_TEMPLATE_REPO (``add_not_template``), passing ``(index, name)``.
      3. Run ``name_imgs_to_db`` for each template that is both a folder
         in MEMES_REPO and listed in MEMES_TO_USE.
      4. Assign ``name_idx`` for the non-meme tables (``add_name_idx``) and
         per template for the meme tables (``meme_add_name_idx``).

    Each step runs in its own freshly created process pool.

    NOTE(review): MemeIncorrectTrain/Test are not cleared here, although
    ``meme_add_name_idx`` renumbers them.  Confirm whether stale rows can
    survive a rebuild.
    """
    for table in (MemeCorrectTrain, MemeCorrectTest, NotAMemeTrain,
                  NotAMemeTest, NotATemplateTrain, NotATemplateTest):
        training_db.query(table).delete()
    training_db.commit()

    template_names = list(
        cast(Set[str],
             set.intersection(set(os.listdir(MEMES_REPO)), MEMES_TO_USE)))

    not_meme_files = list(os.listdir(NOT_MEME_REPO))
    _run_in_pool(add_not_meme, enumerate(not_meme_files), len(not_meme_files))

    not_template_files = list(os.listdir(NOT_TEMPLATE_REPO))
    _run_in_pool(add_not_template, enumerate(not_template_files),
                 len(not_template_files))

    _run_in_pool(name_imgs_to_db, template_names, len(template_names))

    idx_tables: List[Any] = [
        NotAMemeTrain,
        NotAMemeTest,
        NotATemplateTrain,
        NotATemplateTest,
    ]
    _run_in_pool(add_name_idx, idx_tables, len(idx_tables))

    _run_in_pool(meme_add_name_idx, template_names, len(template_names))
Example #12
0
 def get_correct(self) -> Tuple[Tensor, int]:
     """Return one random ``self.correct_entity`` image for ``self.name``, with label 1.

     A ``name_idx`` is drawn uniformly from ``[0, self.my_max_correct]``, and
     the ``path`` of the first row with name ``self.name`` and that index is
     opened.

     Transform choice:
       * training with ``use_transforms`` left as None: ``toTensorOnly`` with
         probability 0.15, otherwise ``trainingTransforms``;
       * otherwise: ``trainingTransforms`` if ``use_transforms`` is truthy,
         else ``toTensorOnly``.

     Returns:
         ``(transformed image, 1)``.
     """
     table = self.correct_entity
     target_idx = random.randint(0, self.my_max_correct)
     criteria = and_(
         cast(ClauseElement, table.name == self.name),
         cast(ClauseElement, table.name_idx == target_idx),
     )
     if not self.is_training or self.use_transforms is not None:
         pipeline = trainingTransforms if self.use_transforms else toTensorOnly
     elif random.random() < 0.15:
         pipeline = toTensorOnly
     else:
         pipeline = trainingTransforms
     row = training_db.query(table.path).filter(criteria).first()
     image_path = cast(Tuple[str], row)[0]
     return (pipeline(Img.open(image_path)), 1)
def _max_name_idx_by_name(entity: Any, names: List[str]) -> Dict[str, int]:
    """Return ``{name: MAX(entity.name_idx)}`` for each template name in ``names``.

    Uses one GROUP BY query instead of one MAX query per name.  A name with
    no rows in ``entity`` maps to ``None``, which is exactly what the old
    per-name ``MAX(...).scalar()`` returned.  The pseudo-names "not_a_meme"
    and "not_a_template" are skipped.
    """
    grouped = training_db.query(entity.name,
                                func.max(entity.name_idx)).group_by(entity.name)
    maxima = {row_name: row_max for row_name, row_max in grouped}
    return {
        name: cast(int, maxima.get(name))
        for name in names
        if name not in ["not_a_meme", "not_a_template"]
    }


def init_static() -> Static:
    """Collect the metadata the meme datasets share.

    Returns a ``Static`` dict with:
      * ``names``: distinct ``MemeCorrectTrain.name`` values, sorted;
      * ``names_to_shuffle``: an independent copy of ``names``;
      * ``name_num`` / ``num_name``: name -> position in ``names`` and
        str(position) -> name;
      * ``folder_count``: number of entries in NOT_MEME_REPO,
        NOT_TEMPLATE_REPO and each sub-directory of MEMES_REPO;
      * ``max_name_idx``: for "train" and "test", the largest ``name_idx`` of
        the not-a-meme / not-a-template tables and, per name, of the correct /
        incorrect meme tables.
    """
    # Without ORDER BY the row order, and therefore the name -> class-index
    # mapping below, is left to the database.  Sorting makes it reproducible
    # between runs.  Plain DISTINCT is equivalent to DISTINCT ON (name) for
    # this single-column query and is not PostgreSQL-only.
    # NOTE(review): models trained with the old mapping may use different
    # class indices.
    names = [
        name for (name, ) in training_db.query(
            MemeCorrectTrain.name).distinct().order_by(MemeCorrectTrain.name)
    ]
    names_to_shuffle = list(names)  # strings are immutable; a shallow copy suffices
    name_num = {name: idx for idx, name in enumerate(names)}
    num_name = {str(v): k for k, v in name_num.items()}

    def scalar_max(column: Any) -> int:
        """MAX(column) over the whole table (None when the table is empty)."""
        return cast(int, training_db.query(func.max(column)).scalar())

    max_name_idx: TestTrainToMax = {
        "train": {
            "not_a_meme": scalar_max(NotAMemeTrain.name_idx),
            "not_a_template": scalar_max(NotATemplateTrain.name_idx),
            "correct": _max_name_idx_by_name(MemeCorrectTrain, names),
            "incorrect": _max_name_idx_by_name(MemeIncorrectTrain, names),
        },
        "test": {
            "not_a_meme": scalar_max(NotAMemeTest.name_idx),
            "not_a_template": scalar_max(NotATemplateTest.name_idx),
            "correct": _max_name_idx_by_name(MemeCorrectTest, names),
            "incorrect": _max_name_idx_by_name(MemeIncorrectTest, names),
        },
    }
    # Only sub-directories of MEMES_REPO are template folders.  A stray file
    # used to make os.listdir() raise NotADirectoryError.
    template_counts = {
        entry: len(os.listdir(MEMES_REPO + entry))
        for entry in os.listdir(MEMES_REPO)
        if os.path.isdir(MEMES_REPO + entry)
    }
    static: Static = {
        "names": names,
        "names_to_shuffle": names_to_shuffle,
        "name_num": name_num,
        "num_name": num_name,
        "folder_count": {
            "not_a_meme": len(os.listdir(NOT_MEME_REPO)),
            "not_a_template": len(os.listdir(NOT_TEMPLATE_REPO)),
            **template_counts,
        },
        "max_name_idx": max_name_idx,
    }
    return static
def add_name_idx(entity: Any):
    """Give every row of ``entity`` a consecutive ``name_idx`` starting at 0.

    Rows are numbered in the order the database returns them (no ORDER BY),
    and the session is committed once afterwards.

    Args:
        entity: the mapped table class whose rows are renumbered.
    """
    all_rows = training_db.query(entity)
    for position, row in enumerate(all_rows):
        row.name_idx = position
    training_db.commit()
 def get_not_template(self):
     """Return one random image from ``self.not_a_template``, labelled ``self.name_size``.

     A ``name_idx`` is drawn uniformly from
     ``[0, self.max_name_idx["not_a_template"]]``, and the ``path`` of the
     first row with that index is opened.  Only ``toTensorOnly`` is applied.
     The label ``self.name_size`` is presumably the class index just after
     the template classes (TODO confirm).
     """
     table = self.not_a_template
     target_idx = random.randint(0, self.max_name_idx["not_a_template"])
     row = training_db.query(table.path).filter(
         cast(ClauseElement, table.name_idx == target_idx)).first()
     image_path = cast(Tuple[str], row)[0]
     image = toTensorOnly(Img.open(image_path))
     return (image, self.name_size)