import logging
import os

import pandas as pd
import spacy
import torch
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

from config.data import (
    DATASET_FOLDER,
    PROCESSED_DATASET,
    PROCESSED_DATASET_FOLDER,
    RAW_DATASET,
)
from config.root import LOGGING_FORMAT, LOGGING_LEVEL, SEED, seed_all

# Load the English spaCy pipeline used for preprocessing.
# NOTE(review): the bare "en" shorthand only resolves on spaCy v2.x with a
# linked model; spaCy v3+ requires the full package name (e.g.
# "en_core_web_sm") — confirm the pinned spaCy version in requirements.
nlp = spacy.load("en")
# Seed RNGs with the project-wide seed for reproducible preprocessing
# (see config.root.seed_all for exactly which generators are seeded).
seed_all(SEED)

# Initialize logger for this file
logger = logging.getLogger(__name__)
logging.basicConfig(level=LOGGING_LEVEL, format=LOGGING_FORMAT)


class PreProcessDataset:
    """
    Preprocess the dataset.

    Takes an input location; falls back to the configured default
    location when none is given.

    Parameters
    ----------
    location : str or None
        Path to the raw dataset file. Any falsy value (None, "")
        selects the configured default path.
    """

    def __init__(self, location=None):
        # Prefer an explicitly supplied path over the configured default.
        if location:
            self.dataset_location = location
        else:
            # FIXME(review): the original fallback branch was destroyed by
            # file corruption ("Example #2" / "0" garbage). RAW_DATASET is
            # the configured default imported at the top of this module and
            # matches the class docstring's stated behavior — confirm
            # against project history before relying on this.
            self.dataset_location = RAW_DATASET
        default="RNNHiddenClassifier",
        choices=["RNNHiddenClassifier"],
        help="select the classifier to train on",
    )

    parser.add_argument(
        "-lhd",
        "--linear-hidden-dim",
        default=LINEAR_HIDDEN_DIM,
        help="Freeze Embeddings of Model",
        type=int,
    )

    args = parser.parse_args()

    seed_all(args.seed)
    logger.debug(args)
    logger.debug("Custom seed set with: {}".format(args.seed))

    logger.info("Loading Dataset")

    dataset = GrammarDasetAnswerKey.get_iterators(args.batch_size)

    logger.info("Dataset Loaded Successfully")

    if args.model_location:
        model = torch.load(args.model_location)
    else:
        model = initialize_new_model(
            args.model,
            dataset,
import torch.nn as nn
import torch.nn.functional as F
import torchtext.data as data
import tqdm
from torchtext.datasets import TranslationDataset
from torchtext.data import Field, BucketIterator


from config.data import DATA_FOLDER, DATA_FOLDER_PROCESSED, DATASETS, SQUAD_NAME
from config.root import LOGGING_FORMAT, LOGGING_LEVEL, seed_all, device
from config.hyperparameters import VANILLA_SEQ2SEQ

from utils import word_tokenizer

# TODO: Move this to main menu to seed in the starting of application
# Seed RNGs with the project default seed (no explicit argument here).
seed_all()

# Initialize logger for this file
# NOTE(review): this re-binds `logger` and re-runs basicConfig already
# executed earlier in this file — the file appears to be two modules
# concatenated together; verify and split before shipping.
logger = logging.getLogger(__name__)
logging.basicConfig(level=LOGGING_LEVEL, format=LOGGING_FORMAT)


# Location of the processed dataset directory.
# NOTE(review): requires `import os` at the top of the file — it is not
# imported anywhere in the visible source.
FILE_PATH = os.path.join(DATA_FOLDER, DATA_FOLDER_PROCESSED)


def load_dataset(
    dataset_name="SQUAD",
    tokenizer=word_tokenizer,
    init_token="<sos>",
    eos_token="<eos>",
    lower=True,