コード例 #1
0
def tweet_transformer(lang, n_gram, voc=None):
    """
    Build the transformation pipeline applied to tweets.

    The pipeline chains, in order: RemoveRegex (with a URL pattern),
    ToLower, a character-level tokenizer, ToIndex, ToLength and MaxIndex.
    :param lang: Language key used to look up settings.voc_sizes[n_gram][lang]
    :param n_gram: 'c1' selects the Character tokenizer, any other value
                   selects Character2Gram; also used as key in settings.voc_sizes
    :param voc: Optional token-to-index mapping handed to ToIndex; a new
                empty dict is used when None
    :return: The composed transformer (transforms.Compose)
    """
    # Reuse the given vocabulary or start from an empty one
    token_to_ix = dict() if voc is None else voc

    # Only the tokenizer differs between the two n-gram modes
    tokenizer = ltransforms.Character if n_gram == 'c1' else ltransforms.Character2Gram

    # Pattern matching http/https URLs
    url_regex = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'

    return transforms.Compose([
        ltransforms.RemoveRegex(regex=url_regex),
        ltransforms.ToLower(),
        tokenizer(),
        ltransforms.ToIndex(start_ix=1, token_to_ix=token_to_ix),
        ltransforms.ToLength(length=settings.min_length),
        ltransforms.MaxIndex(max_id=settings.voc_sizes[n_gram][lang] - 1)
    ])
コード例 #2
0
ファイル: functions.py プロジェクト: pan-webis-de/schaetti18c
def text_transformer_cnn(window_size, n_gram, token_to_ix):
    """
    Build the text transformation pipeline for CNNSCD.

    The pipeline chains, in order: ToLower, a character-level tokenizer,
    ToIndex, ToLength, Reshape and MaxIndex.
    :param window_size: Value passed to ToLength as its length
    :param n_gram: 'c1' selects the Character tokenizer, any other value
                   selects Character2Gram; also used as key in settings.voc_sizes
    :param token_to_ix: Token-to-index mapping handed to ToIndex
    :return: The composed transformer (ltransforms.Compose)
    """
    # Pick the tokenizer class for the requested n-gram mode
    if n_gram == 'c1':
        tokenizer = ltransforms.Character
    else:
        tokenizer = ltransforms.Character2Gram
    # end if

    # Pipeline steps, applied in list order
    steps = [
        ltransforms.ToLower(),
        tokenizer(),
        ltransforms.ToIndex(start_ix=1, token_to_ix=token_to_ix),
        ltransforms.ToLength(length=window_size),
        # Note: the argument is the plain int -1, not a one-element tuple
        ltransforms.Reshape(-1),
        ltransforms.MaxIndex(max_id=settings.voc_sizes[n_gram]),
    ]
    return ltransforms.Compose(steps)
コード例 #3
0
ファイル: functions.py プロジェクト: pan-webis-de/schaetti18c
def text_transformer(n_gram, window_size):
    """
    Build the text transformation pipeline producing windows of indices.

    The pipeline chains, in order: ToLower, a character-level tokenizer,
    ToIndex, ToNGram (overlapse=True), Reshape((-1, window_size)) and MaxIndex.
    :param n_gram: 'c1' selects the Character tokenizer, any other value
                   selects Character2Gram; also used as key in settings.voc_sizes
    :param window_size: Passed to ToNGram as n and used as the second
                        dimension of the Reshape target
    :return: The composed transformer (transforms.Compose)
    """
    # Only the tokenizer differs between the two n-gram modes
    use_single_chars = n_gram == 'c1'
    tokenizer = ltransforms.Character if use_single_chars else ltransforms.Character2Gram

    pipeline = [
        ltransforms.ToLower(),
        tokenizer(),
        ltransforms.ToIndex(start_ix=0),
        ltransforms.ToNGram(n=window_size, overlapse=True),
        ltransforms.Reshape((-1, window_size)),
        ltransforms.MaxIndex(max_id=settings.voc_sizes[n_gram] - 1),
    ]
    return transforms.Compose(pipeline)
# Load the trained weights. The context manager closes the file handle as
# soon as the state dict has been read.
# NOTE(review): torch.load unpickles the file content, so args.model and
# args.voc must point to trusted files.
with open(args.model, 'rb') as model_file:
    model.load_state_dict(torch.load(model_file))
# end with
if args.cuda:
    model.cuda()
# end if

# Load the vocabulary (handed to ToIndex as token_to_ix below)
with open(args.voc, 'rb') as voc_file:
    voc = torch.load(voc_file)
# end with

# Eval
model.eval()

# Text pipeline: 'c1' tokenizes with Character, anything else with Character2Gram
if args.n_gram == 'c1':
    transforms = ltransforms.Compose([
        ltransforms.ToLower(),
        ltransforms.Character(),
        ltransforms.ToIndex(start_ix=1, token_to_ix=voc),
        ltransforms.ToLength(length=window_size),
        ltransforms.MaxIndex(max_id=settings.voc_sizes[args.n_gram])
    ])
else:
    transforms = ltransforms.Compose([
        ltransforms.ToLower(),
        ltransforms.Character2Gram(),
        ltransforms.ToIndex(start_ix=1, token_to_ix=voc),
        ltransforms.ToLength(length=window_size),
        ltransforms.MaxIndex(max_id=settings.voc_sizes[args.n_gram])
    ])
# end if

# Validation counters: total count, success matrix of shape
# (n_levels, n_thresholds) initialised to zeros, and file count
validation_total = 0
validation_success = np.zeros((n_levels, n_thresholds))
n_files = 0.0
# Experiment settings
input_sparsity = 0.1
w_sparsity = 0.1
input_scaling = 0.5
n_test = 10
n_samples = 2
n_epoch = 100
# Window length: used as the n of ToNGram and as the row width of Reshape
text_length = 20

# Argument
args = tools.functions.argument_parser_training_model()

# Transforms
# NOTE(review): max_id=83 is a hard-coded bound, presumably tied to the
# vocabulary size - confirm.
transform = transforms.Compose([
    transforms.Character(),
    transforms.ToIndex(start_ix=0),
    transforms.MaxIndex(max_id=83),
    transforms.ToNGram(n=text_length, overlapse=True),
    # Row width follows text_length so it stays in sync with ToNGram
    transforms.Reshape((-1, text_length))
])

# Author identification training dataset
dataset_train = dataset.AuthorIdentificationDataset(
    root="./data/", download=True, transform=transform, problem=1, lang='en'
)

# Author identification test dataset
dataset_valid = dataset.AuthorIdentificationDataset(
    root="./data/", download=True, transform=transform, problem=1, train=False, lang='en'
)

# Cross validation
dataloader_train = torch.utils.data.DataLoader(
    torchlanguage.utils.CrossValidation(dataset_train), batch_size=1, shuffle=True
)
dataloader_valid = torch.utils.data.DataLoader(
    torchlanguage.utils.CrossValidation(dataset_valid, train=False), batch_size=1, shuffle=True
)

# Author to idx