def tweet_transformer(lang, n_gram, voc=None):
    """
    Get tweet transformer.

    Builds the preprocessing pipeline applied to a tweet. It performs these steps in order:

    1. Remove URLs.
    2. Lower-case the text.
    3. Tokenize into characters (``n_gram == 'c1'``) or character bigrams (any other value).
    4. Map tokens to indices starting at 1.
    5. Apply ``ToLength`` with ``settings.min_length``.
    6. Apply ``MaxIndex`` with ``settings.voc_sizes[n_gram][lang] - 1``.

    :param lang: Language key used to look up the vocabulary size in
                 ``settings.voc_sizes[n_gram]``.
    :param n_gram: ``'c1'`` selects single characters; any other value
                   selects character bigrams.
    :param voc: Optional token-to-index mapping handed to ``ToIndex``. When
                ``None``, a fresh empty dict is used.
    :return: A ``transforms.Compose`` pipeline.
    """
    # The same dict object is given to ToIndex, so a caller-supplied voc is
    # presumably extended in place as new tokens appear -- confirm in ToIndex.
    token_to_ix = dict() if voc is None else voc

    return transforms.Compose([
        # NOTE: inside the class `[$-_@.&+]`, `$-_` is a range from '$' to '_',
        # so it also matches digits, upper-case letters and most punctuation.
        ltransforms.RemoveRegex(
            regex=r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'),
        ltransforms.ToLower(),
        # Only the tokenizer differs between the two n-gram variants.
        ltransforms.Character() if n_gram == 'c1' else ltransforms.Character2Gram(),
        ltransforms.ToIndex(start_ix=1, token_to_ix=token_to_ix),
        ltransforms.ToLength(length=settings.min_length),
        ltransforms.MaxIndex(max_id=settings.voc_sizes[n_gram][lang] - 1)
    ])
# end tweet_transformer
def text_transformer_cnn(window_size, n_gram, token_to_ix):
    """
    Get text transformer for CNNSCD.

    Builds the preprocessing pipeline. It performs these steps in order:

    1. Lower-case the text.
    2. Tokenize into characters (``n_gram == 'c1'``) or character bigrams (any other value).
    3. Map tokens to indices starting at 1 using ``token_to_ix``.
    4. Apply ``ToLength`` with ``window_size``.
    5. Reshape with ``-1``.
    6. Apply ``MaxIndex`` with ``settings.voc_sizes[n_gram]``.

    :param window_size: Length passed to ``ToLength``.
    :param n_gram: ``'c1'`` selects single characters; any other value
                   selects character bigrams.
    :param token_to_ix: Token-to-index mapping handed to ``ToIndex``.
    :return: An ``ltransforms.Compose`` pipeline.
    """
    return ltransforms.Compose([
        ltransforms.ToLower(),
        # Only the tokenizer differs between the two n-gram variants.
        ltransforms.Character() if n_gram == 'c1' else ltransforms.Character2Gram(),
        ltransforms.ToIndex(start_ix=1, token_to_ix=token_to_ix),
        ltransforms.ToLength(length=window_size),
        # Scalar -1 (not a tuple); presumably flattens to 1-D -- confirm in Reshape.
        ltransforms.Reshape(-1),
        # NOTE(review): no "- 1" here; with start_ix=1 this presumably keeps
        # index 0 free for padding -- confirm against the embedding size.
        ltransforms.MaxIndex(max_id=settings.voc_sizes[n_gram])
    ])
# end text_transformer_cnn
def text_transformer(n_gram, window_size):
    """
    Get text transformer (sliding-window variant).

    Builds the preprocessing pipeline. It performs these steps in order:

    1. Lower-case the text.
    2. Tokenize into characters (``n_gram == 'c1'``) or character bigrams (any other value).
    3. Map tokens to indices starting at 0.
    4. Group them into overlapping n-grams of ``window_size`` tokens.
    5. Reshape to ``(-1, window_size)``.
    6. Apply ``MaxIndex`` with ``settings.voc_sizes[n_gram] - 1``.

    :param n_gram: ``'c1'`` selects single characters; any other value
                   selects character bigrams.
    :param window_size: Size of each overlapping window; also the width of
                        the reshaped output.
    :return: A ``transforms.Compose`` pipeline.
    """
    return transforms.Compose([
        ltransforms.ToLower(),
        # Only the tokenizer differs between the two n-gram variants.
        ltransforms.Character() if n_gram == 'c1' else ltransforms.Character2Gram(),
        ltransforms.ToIndex(start_ix=0),
        ltransforms.ToNGram(n=window_size, overlapse=True),
        ltransforms.Reshape((-1, window_size)),
        ltransforms.MaxIndex(max_id=settings.voc_sizes[n_gram] - 1)
    ])
# end text_transformer
# Load the trained weights. The `with` block guarantees the file is closed.
# NOTE: torch.load unpickles its input -- only load trusted checkpoints.
with open(args.model, 'rb') as model_file:
    model.load_state_dict(torch.load(model_file))
# end with
if args.cuda:
    model.cuda()
# end if

# Token-to-index vocabulary handed to ToIndex below (also unpickled).
with open(args.voc, 'rb') as voc_file:
    voc = torch.load(voc_file)
# end with

# Eval
model.eval()

# NOTE: this rebinds the name `transforms`; any module imported under that
# name is shadowed from this point on in this scope.
transforms = ltransforms.Compose([
    ltransforms.ToLower(),
    # Only the tokenizer differs between the two n-gram variants.
    ltransforms.Character() if args.n_gram == 'c1' else ltransforms.Character2Gram(),
    ltransforms.ToIndex(start_ix=1, token_to_ix=voc),
    ltransforms.ToLength(length=window_size),
    ltransforms.MaxIndex(max_id=settings.voc_sizes[args.n_gram])
])

# Validation losses / counters.
# validation_success is a (n_levels, n_thresholds) array; presumably one
# success count per (level, threshold) pair -- confirm in the code that follows.
validation_total = 0
validation_success = np.zeros((n_levels, n_thresholds))
n_files = 0.0
# Experiment settings
input_sparsity = 0.1
w_sparsity = 0.1
input_scaling = 0.5
n_test = 10
n_samples = 2
n_epoch = 100
text_length = 20

# Argument
args = tools.functions.argument_parser_training_model()

# Transforms: characters -> indices (start 0, limited by MaxIndex 83) ->
# overlapping windows of `text_length` tokens -> reshape to rows of width
# `text_length`.
transform = transforms.Compose([
    transforms.Character(),
    transforms.ToIndex(start_ix=0),
    transforms.MaxIndex(max_id=83),
    transforms.ToNGram(n=text_length, overlapse=True),
    # Row width must equal ToNGram's n, so both are derived from text_length.
    transforms.Reshape((-1, text_length))
])

# Author identification training dataset
dataset_train = dataset.AuthorIdentificationDataset(root="./data/", download=True, transform=transform, problem=1, lang='en')

# Author identification test dataset
dataset_valid = dataset.AuthorIdentificationDataset(root="./data/", download=True, transform=transform, problem=1, train=False, lang='en')

# Cross validation
dataloader_train = torch.utils.data.DataLoader(torchlanguage.utils.CrossValidation(dataset_train), batch_size=1, shuffle=True)
dataloader_valid = torch.utils.data.DataLoader(torchlanguage.utils.CrossValidation(dataset_valid, train=False), batch_size=1, shuffle=True)

# Author to idx