Example #1
def get_datasets(dataset, dataset_dir, bptt, bs, lang, max_vocab, ds_pct, lm_type):
    tmp_dir = dataset_dir / 'tmp'
    tmp_dir.mkdir(exist_ok=True)
    vocab_file = tmp_dir / f'vocab_{lang}.pkl'
    if not (tmp_dir / f'{TRN}_{lang}_ids.npy').exists():
        print('Reading the data...')
        toks, lbls = read_clas_data(dataset_dir, dataset, lang)
        # create the vocabulary
        counter = Counter(word for example in toks[TRN]+toks[TST]+toks[VAL] for word in example)
        itos = [word for word, count in counter.most_common(n=max_vocab)]
        itos.insert(0, PAD)
        itos.insert(0, UNK)
        vocab = Vocab(itos)
        stoi = vocab.stoi
        with open(vocab_file, 'wb') as f:
            pickle.dump(vocab, f)

        ids = {}
        for split in [TRN, VAL, TST]:
            ids[split] = np.array([([stoi.get(w, stoi[UNK]) for w in s])
                                   for s in toks[split]])
            np.save(tmp_dir / f'{split}_{lang}_ids.npy', ids[split])
            np.save(tmp_dir / f'{split}_{lang}_lbl.npy', lbls[split])
    else:
        print('Loading the pickled data...')
        ids, lbls = {}, {}
        for split in [TRN, VAL, TST]:
            ids[split] = np.load(tmp_dir / f'{split}_{lang}_ids.npy')
            lbls[split] = np.load(tmp_dir / f'{split}_{lang}_lbl.npy')
        with open(vocab_file, 'rb') as f:
            vocab = pickle.load(f)
    print(f'Train size: {len(ids[TRN])}. Valid size: {len(ids[VAL])}. '
          f'Test size: {len(ids[TST])}.')
    if ds_pct < 1.0:
        print(f"Making the dataset smaller: keeping {ds_pct} of each split")
        for split in [TRN, VAL, TST]:
            ids[split] = ids[split][:int(len(ids[split]) * ds_pct)]
            lbls[split] = lbls[split][:int(len(lbls[split]) * ds_pct)]
    for split in [TRN, VAL, TST]:
        # convert each example to an integer id array
        ids[split] = np.array([np.array(e, dtype=np.int) for e in ids[split]])
        if split == TRN: print("Processing TRN labels...")
        lbls[split] = np.array([np.array(e, dtype=np.int) for e in lbls[split]])
        if split == TRN: print("Converted the train labels to np.array successfully.")
    data_lm = TextLMDataBunch.from_ids(path=tmp_dir, vocab=vocab, train_ids=np.concatenate([ids[TRN],ids[TST]]),
                                       valid_ids=ids[VAL], bs=bs, bptt=bptt, lm_type=lm_type)
    #  TODO TextClasDataBunch allows tst_ids as input, but not tst_lbls?
    data_clas = TextClasDataBunch.from_ids(
        path=tmp_dir, vocab=vocab, train_ids=ids[TRN], valid_ids=ids[VAL],
        train_lbls=lbls[TRN], valid_lbls=lbls[VAL], bs=bs, classes={l:l for l in lbls[TRN]})

    print(f"Sizes of train_ds {len(data_clas.train_ds)}, valid_ds {len(data_clas.valid_ds)}")
    return data_clas, data_lm
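A minimal usage sketch for the function above, assuming the surrounding module already defines TRN/VAL/TST, read_clas_data, Vocab, and the other imports; every argument value below (paths, batch size, lm_type) is a hypothetical placeholder rather than something taken from the original snippet:

from pathlib import Path

# Hypothetical call; adjust the directory layout and lm_type to the host project.
data_clas, data_lm = get_datasets(dataset='imdb',
                                  dataset_dir=Path('data/imdb'),  # assumed location
                                  bptt=70,
                                  bs=20,
                                  lang='en',
                                  max_vocab=30000,
                                  ds_pct=1.0,
                                  lm_type='lstm')  # placeholder value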
Example #2
    def create_lm_learner(self, data_lm, dps=None, **kwargs):
        learner = super().create_lm_learner(data_lm, dps, **kwargs)
        if self.parallel_data_path is not None:
            src_trn_df = pd.read_csv(self.parallel_data_path / self.src_lang /
                                     'train.csv',
                                     header=None)
            tgt_trn_df = pd.read_csv(self.parallel_data_path / self.tgt_lang /
                                     'train.csv',
                                     header=None)
            bs = self.parallel_data_bs
            data_src = TextClasDataBunch.from_df(path=self.cache_dir,
                                                 train_df=src_trn_df,
                                                 lm_type=self.lm_type,
                                                 bs=bs)
            data_tgt = TextClasDataBunch.from_df(path=self.cache_dir,
                                                 train_df=tgt_trn_df,
                                                 lm_type=self.lm_type,
                                                 bs=bs)
            learner.callback_fns = [
                partial(ParallelAlignmentCallback,
                        data_src=data_src,
                        data_tgt=data_tgt)
            ] + learner.callback_fns
        return learner
Example #3
    def _fit_class(self, df_train, df_val, data_lm):
        n_data = min(len(df_train), len(df_val))
        # Classifier model data
        data_class = TextClasDataBunch.from_df(
            path="",
            train_df=df_train,
            valid_df=df_val,
            vocab=data_lm.train_ds.vocab,
            bs=self.batch_size if self.batch_size < n_data else n_data // 2,
        )
        # train the learner object
        class_learner = text_classifier_learner(
            data_class, self.arch, drop_mult=self.dropout_lm
        )
        class_learner.load_encoder(self.path_lm.name)
        class_learner.fit_one_cycle(1, self.lr_class)
        class_learner.export(self.path_class)
Example #4
def prepare_clas_dataset(input_path,
                         output_dir=None,
                         valid_split=0.2,
                         tokenizer_lang="xx",
                         min_freq=2,
                         seed=42):
    """
    Reads a CSV file with texts and labels, splits it into training and validation sets,
    tokenizes texts and saves datasets for fine-tuning and for classification.

    Attributes:
        input_path (str): Path to CSV file with texts in the first and labels in second column.
        output_dir (str): Folder where to store the processed dataset.
        valid_split (float): A fraction of data used for validation.
        tokenizer_lang (str): Language setting for tokenizer.
        min_freq (int): Minimal number of occurrences of a word to be conidered for adding to
            vocabulary.
        seed (int): Random seed that determines the training-validation split.
    """
    input_path = Path(input_path)
    output_dir = Path(output_dir or input_path.parent)
    output_dir.mkdir(parents=True, exist_ok=True)

    train_df, valid_df = csv_to_train_valid_df(input_path, valid_split, seed)

    data_finetune_lm = TextLMDataBunch.from_df(
        output_dir,
        train_df,
        valid_df,
        tokenizer=Tokenizer(lang=tokenizer_lang),
        text_cols=0,
        min_freq=min_freq)
    data_clas = TextClasDataBunch.from_df(
        output_dir,
        train_df,
        valid_df,
        tokenizer=Tokenizer(lang=tokenizer_lang),
        text_cols=0,
        label_cols=1,
        vocab=data_finetune_lm.train_ds.vocab,
        bs=32,
        min_freq=min_freq)

    data_finetune_lm.save("data_finetune_lm.pkl")
    data_clas.save("data_clas.pkl")
Example #5
    def validate_cls(self, save_name='cls_last', bs=40):
        args = self.tokenzier_to_fastai_args(
            trn_data_loading_func=lambda: trn_df[1], add_moses=True)
        data_clas, data_lm = self.load_cls_data(bs,
                                                use_test_for_validation=True)
        data_eval = [
            TextClasDataBunch.from_csv(path=Path(tgt_path),
                                       csv_name=self.csv_name,
                                       **args)
            for tgt_path in self.target_paths
        ]

        for data in [data_clas] + data_eval:
            learn = self.create_cls_learner(data, drop_mult=0.1)
            learn.load(save_name)
            print(
                f"Loss and accuracy using ({save_name}) for dataset at {data.path}:",
                learn.validate())
Example #6
    def load_cls_data_old_for_xnli(self, bs):
        tmp_dir = self.cache_dir
        tmp_dir.mkdir(exist_ok=True)
        vocab_file = tmp_dir / f'vocab_{self.lang}.pkl'
        if not (tmp_dir / f'{TRN}_{self.lang}_ids.npy').exists():
            print('Reading the data...')
            toks, lbls = read_clas_data(self.dataset_dir,
                                        self.dataset_dir.name, self.lang)
            # create the vocabulary
            counter = Counter(word
                              for example in toks[TRN] + toks[TST] + toks[VAL]
                              for word in example)
            itos = [
                word for word, count in counter.most_common(n=self.max_vocab)
            ]
            itos.insert(0, PAD)
            itos.insert(0, UNK)
            vocab = Vocab(itos)
            stoi = vocab.stoi
            with open(vocab_file, 'wb') as f:
                pickle.dump(vocab, f)
            ids = {}
            for split in [TRN, VAL, TST]:
                ids[split] = np.array([([stoi.get(w, stoi[UNK]) for w in s])
                                       for s in toks[split]])
                np.save(tmp_dir / f'{split}_{self.lang}_ids.npy', ids[split])
                np.save(tmp_dir / f'{split}_{self.lang}_lbl.npy', lbls[split])
        else:
            print('Loading the pickled data...')
            ids, lbls = {}, {}
            for split in [TRN, VAL, TST]:
                ids[split] = np.load(tmp_dir / f'{split}_{self.lang}_ids.npy')
                lbls[split] = np.load(tmp_dir / f'{split}_{self.lang}_lbl.npy')
            with open(vocab_file, 'rb') as f:
                vocab = pickle.load(f)
        print(f'Train size: {len(ids[TRN])}. Valid size: {len(ids[VAL])}. '
              f'Test size: {len(ids[TST])}.')
        for split in [TRN, VAL, TST]:
            ids[split] = np.array(
                [np.array(e, dtype=np.int) for e in ids[split]])
            lbls[split] = np.array(
                [np.array(e, dtype=np.int) for e in lbls[split]])
        data_lm = TextLMDataBunch.from_ids(path=tmp_dir,
                                           vocab=vocab,
                                           train_ids=np.concatenate(
                                               [ids[TRN], ids[TST]]),
                                           valid_ids=ids[VAL],
                                           bs=bs,
                                           bptt=self.bptt,
                                           lm_type=self.lm_type)
        #  TODO TextClasDataBunch allows tst_ids as input, but not tst_lbls?
        data_clas = TextClasDataBunch.from_ids(
            path=tmp_dir,
            vocab=vocab,
            train_ids=ids[TRN],
            valid_ids=ids[VAL],
            train_lbls=lbls[TRN],
            valid_lbls=lbls[VAL],
            bs=bs,
            classes={l: l
                     for l in lbls[TRN]})

        print(
            f"Sizes of train_ds {len(data_clas.train_ds)}, valid_ds {len(data_clas.valid_ds)}"
        )
        return data_clas, data_lm
Example #7
    def load_cls_data_imdb(self,
                           bs,
                           force=False,
                           use_test_for_validation=False):
        trn_df = pd.read_csv(self.dataset_path / 'train.csv', header=None)
        tst_df = pd.read_csv(self.dataset_path / 'test.csv', header=None)
        unsp_df = pd.read_csv(self.dataset_path / 'unsup.csv', header=None)

        lm_trn_df = pd.concat([unsp_df, trn_df, tst_df])
        val_len = max(int(len(lm_trn_df) * 0.1), 2)
        # split off the validation rows before truncating the training frame,
        # so the two sets do not overlap
        lm_val_df = lm_trn_df[:val_len]
        lm_trn_df = lm_trn_df[val_len:]

        if use_test_for_validation:
            val_df = tst_df
            cls_cache = 'notst'
        else:
            val_len = max(int(len(trn_df) * 0.1), 2)
            trn_len = len(trn_df) - val_len
            trn_df, val_df = trn_df[:trn_len], trn_df[trn_len:]
            cls_cache = '.'

        if self.tokenizer is Tokenizers.SUBWORD:
            args = get_sentencepiece(self.dataset_path,
                                     self.dataset_path / 'train.csv',
                                     self.name,
                                     vocab_size=self.max_vocab,
                                     pre_rules=[],
                                     post_rules=[])
        elif self.tokenizer is Tokenizers.MOSES:
            args = dict(tokenizer=Tokenizer(tok_func=MosesTokenizerFunc,
                                            lang='en',
                                            pre_rules=[],
                                            post_rules=[]))
        elif self.tokenizer is Tokenizers.MOSES_FA:
            args = dict(
                tokenizer=Tokenizer(tok_func=MosesTokenizerFunc,
                                    lang='en'))  # use default pre/post rules
        elif self.tokenizer is Tokenizers.FASTAI:
            args = dict()
        else:
            raise ValueError(
                f"Unsupported tokenizer value {self.tokenizer}; allowed values are defined in {Tokenizers}"
            )

        try:
            if force: raise FileNotFoundError("Forcing reloading of caches")
            data_lm = TextLMDataBunch.load(self.cache_dir,
                                           'lm',
                                           lm_type=self.lm_type,
                                           bs=bs)
            print(
                f"Tokenized data loaded, lm.trn {len(data_lm.train_ds)}, lm.val {len(data_lm.valid_ds)}"
            )
        except FileNotFoundError:
            print(f"Running tokenization...")
            data_lm = TextLMDataBunch.from_df(path=self.cache_dir,
                                              train_df=lm_trn_df,
                                              valid_df=lm_val_df,
                                              max_vocab=self.max_vocab,
                                              bs=bs,
                                              lm_type=self.lm_type,
                                              **args)
            print(
                f"Saving tokenized: lm.trn {len(data_lm.train_ds)}, lm.val {len(data_lm.valid_ds)}"
            )
            data_lm.save('lm')

        try:
            if force: raise FileNotFoundError("Forcing reloading of caches")
            data_cls = TextClasDataBunch.load(self.cache_dir, cls_cache, bs=bs)
            print(
                f"Tokenized data loaded, cls.trn {len(data_cls.train_ds)}, cls.val {len(data_cls.valid_ds)}"
            )
        except FileNotFoundError:
            # make sure we use the same vocab for classification
            args['vocab'] = data_lm.vocab
            print(f"Running tokenization...")
            data_cls = TextClasDataBunch.from_df(path=self.cache_dir,
                                                 train_df=trn_df,
                                                 valid_df=val_df,
                                                 test_df=tst_df,
                                                 max_vocab=self.max_vocab,
                                                 bs=bs,
                                                 **args)
            print(
                f"Saving tokenized: cls.trn {len(data_cls.train_ds)}, cls.val {len(data_cls.valid_ds)}"
            )
            data_cls.save(cls_cache)
        print('Size of vocabulary:', len(data_lm.vocab.itos))
        print('First 20 words in vocab:', data_lm.vocab.itos[:20])
        return data_cls, data_lm
Example #8
    def load_cls_data(self,
                      bs,
                      force=False,
                      use_test_for_validation=False,
                      **kwargs):
        args = self.tokenzier_to_fastai_args(
            trn_data_loading_func=lambda: trn_df[1], add_moses=True)
        src_path = self.dataset_path
        csv_name = self.csv_name
        tgt_paths = [Path(tgt_path) for tgt_path in self.target_paths]
        mixed_csv = pd.read_csv(src_path / csv_name, header=None)
        for tgt_path in tgt_paths:
            mixed_csv = pd.concat(
                [mixed_csv,
                 pd.read_csv(tgt_path / csv_name, header=None)])

        xcvs_name = ('x_' + csv_name)
        mixed_csv.to_csv(src_path / xcvs_name, header=None, index=False)

        try:
            if force: raise FileNotFoundError("Forcing reloading of caches")
            data_lm = TextLMDataBunch.load(src_path,
                                           'xlm',
                                           lm_type=self.lm_type,
                                           bs=bs)
            print(
                f"Tokenized data loaded, xlm.trn {len(data_lm.train_ds)}, xlm.val {len(data_lm.valid_ds)}"
            )
        except FileNotFoundError:
            print(f"Running tokenization...")
            data_lm = TextLMDataBunch.from_csv(path=src_path,
                                               csv_name=xcvs_name,
                                               bs=bs,
                                               lm_type=self.lm_type,
                                               **kwargs,
                                               **args)
            print(
                f"Saving tokenized: xlm.trn {len(data_lm.train_ds)}, xlm.val {len(data_lm.valid_ds)}"
            )
            data_lm.save('xlm')

        try:
            if force: raise FileNotFoundError("Forcing reloading of caches")
            data_cls = TextClasDataBunch.load(src_path, 'cls', bs=bs)
            print(
                f"Tokenized data loaded, cls.trn {len(data_cls.train_ds)}, cls.val {len(data_cls.valid_ds)}"
            )
        except FileNotFoundError:
            # make sure we use the same vocab for classification
            args['vocab'] = data_lm.vocab
            print(f"Running tokenization...")
            data_cls = TextClasDataBunch.from_csv(path=src_path,
                                                  csv_name=csv_name,
                                                  bs=bs,
                                                  **kwargs,
                                                  **args)

            print(
                f"Saving tokenized: cls.trn {len(data_cls.train_ds)}, cls.val {len(data_cls.valid_ds)}"
            )
            data_cls.save('cls')

        print('Size of vocabulary:', len(data_lm.vocab.itos))
        print('First 20 words in vocab:', data_lm.vocab.itos[:20])
        return data_cls, data_lm
Example #9
SAMPLES_PER_CLASS = 12500

print('loading data')
texts = []
target = []

for class_index, classname in enumerate(CLASS_NAMES):

    for n, line in enumerate(open(DATA_FOLDER+classname+'.txt')):

        texts.append(preprocess_string(line,False))
        target.append(class_index)

        if n > SAMPLES_PER_CLASS:
            break

df = DataFrame({'label':target,'text':texts})
df_train, df_val = train_test_split(df, stratify = df['label'], test_size = 0.4, random_state = 12)

data_lm = TextLMDataBunch.from_df(train_df = df_train, valid_df = df_val, path = "")
data_clas = TextClasDataBunch.from_df(path = "", train_df = df_train, valid_df = df_val, vocab=data_lm.train_ds.vocab, bs=32)

learn = language_model_learner(data_lm, pretrained_model=URLs.WT103, drop_mult=0.7)
learn.fit_one_cycle(1, 1e-2)
learn.save_encoder('ft_enc')

learn = text_classifier_learner(data_clas, drop_mult=0.7)
learn.load_encoder('ft_enc')
learn.fit_one_cycle(1, 1e-2)

Example #10
    lm_learner = language_model_learner(
        text_lm,
        arch=AWD_LSTM,
        drop_mult=0.2,
    )

    lm_learner.lr_find()
    lm_learner.recorder.plot(suggestion=True)

    lm_learner.fit_one_cycle(1, lm_learner.recorder.min_grad_lr)

    lm_learner.save_encoder(model)

    text_clas = TextClasDataBunch.from_df(
        train_df=train_df,
        valid_df=valid_df,
        vocab=text_lm.train_ds.vocab,
        path="",
    )

    clf = text_classifier_learner(
        text_clas,
        arch=AWD_LSTM,
        drop_mult=0.2,
    )
    clf.load_encoder(model)

    clf.lr_find()
    clf.recorder.plot(suggestion=True)

    clf.fit_one_cycle(1, clf.recorder.min_grad_lr)
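A short follow-up sketch (not part of the original snippet): once the classifier above is trained, a fastai v1 learner can typically score new text with predict; the sample sentence is purely illustrative:

# Run the fine-tuned classifier on a single illustrative string.
pred_class, pred_idx, probs = clf.predict("This movie was surprisingly good.")
print(pred_class, probs)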
Example #11
def new_train_clas(data_dir,
                   lang='en',
                   cuda_id=0,
                   pretrain_name='wt103',
                   model_dir='models',
                   qrnn=False,
                   fine_tune=True,
                   max_vocab=30000,
                   bs=20,
                   bptt=70,
                   name='imdb-clas',
                   dataset='imdb',
                   ds_pct=1.0):
    """
    :param data_dir: The path to the `data` directory
    :param lang: the language unicode
    :param cuda_id: The id of the GPU. Uses GPU 0 by default or no GPU when
                    run on CPU.
    :param pretrain_name: name of the pretrained model
    :param model_dir: The path to the directory where the pretrained model is saved
    :param qrrn: Use a QRNN. Requires installing cupy.
    :param fine_tune: Fine-tune the pretrained language model
    :param max_vocab: The maximum size of the vocabulary.
    :param bs: The batch size.
    :param bptt: The back-propagation-through-time sequence length.
    :param name: The name used for both the model and the vocabulary.
    :param dataset: The dataset used for evaluation. Currently only IMDb and
                    XNLI are implemented. Assumes dataset is located in `data`
                    folder and that name of folder is the same as dataset name.
    """
    results = {}
    if not torch.cuda.is_available():
        print('CUDA not available. Setting device=-1.')
        cuda_id = -1
    torch.cuda.set_device(cuda_id)

    print(f'Dataset: {dataset}. Language: {lang}.')
    assert dataset in DATASETS, f'Error: {dataset} processing is not implemented.'
    assert (dataset == 'imdb' and lang == 'en') or not dataset == 'imdb',\
        'Error: IMDb is only available in English.'

    data_dir = Path(data_dir)
    assert data_dir.name == 'data',\
        f'Error: Name of data directory should be data, not {data_dir.name}.'
    dataset_dir = data_dir / dataset
    model_dir = Path(model_dir)

    if qrnn:
        print('Using QRNNs...')
    model_name = 'qrnn' if qrnn else 'lstm'
    lm_name = f'{model_name}_{pretrain_name}'
    pretrained_fname = (lm_name, f'itos_{pretrain_name}')

    ensure_paths_exists(data_dir, dataset_dir, model_dir,
                        model_dir / f"{pretrained_fname[0]}.pth",
                        model_dir / f"{pretrained_fname[1]}.pkl")

    tmp_dir = dataset_dir / 'tmp'
    tmp_dir.mkdir(exist_ok=True)
    vocab_file = tmp_dir / f'vocab_{lang}.pkl'

    if not (tmp_dir / f'{TRN}_{lang}_ids.npy').exists():
        print('Reading the data...')
        toks, lbls = read_clas_data(dataset_dir, dataset, lang)

        # create the vocabulary
        counter = Counter(word for example in toks[TRN] for word in example)
        itos = [word for word, count in counter.most_common(n=max_vocab)]
        itos.insert(0, PAD)
        itos.insert(0, UNK)
        vocab = Vocab(itos)
        stoi = vocab.stoi
        with open(vocab_file, 'wb') as f:
            pickle.dump(vocab, f)

        ids = {}
        for split in [TRN, VAL, TST]:
            ids[split] = np.array([([stoi.get(w, stoi[UNK]) for w in s])
                                   for s in toks[split]])
            np.save(tmp_dir / f'{split}_{lang}_ids.npy', ids[split])
            np.save(tmp_dir / f'{split}_{lang}_lbl.npy', lbls[split])
    else:
        print('Loading the pickled data...')
        ids, lbls = {}, {}
        for split in [TRN, VAL, TST]:
            ids[split] = np.load(tmp_dir / f'{split}_{lang}_ids.npy')
            lbls[split] = np.load(tmp_dir / f'{split}_{lang}_lbl.npy')
        with open(vocab_file, 'rb') as f:
            vocab = pickle.load(f)

    print(f'Train size: {len(ids[TRN])}. Valid size: {len(ids[VAL])}. '
          f'Test size: {len(ids[TST])}.')

    if ds_pct < 1.0:
        print(f"Making the dataset smaller: keeping {ds_pct} of each split")
        for split in [TRN, VAL, TST]:
            ids[split] = ids[split][:int(len(ids[split]) * ds_pct)]
            # keep labels aligned with the subsampled ids
            lbls[split] = lbls[split][:int(len(lbls[split]) * ds_pct)]

    data_lm = TextLMDataBunch.from_ids(path=tmp_dir,
                                       vocab=vocab,
                                       train_ids=ids[TRN],
                                       valid_ids=ids[VAL],
                                       bs=bs,
                                       bptt=bptt)

    # TODO TextClasDataBunch allows tst_ids as input, but not tst_lbls?
    data_clas = TextClasDataBunch.from_ids(path=tmp_dir,
                                           vocab=vocab,
                                           train_ids=ids[TRN],
                                           valid_ids=ids[VAL],
                                           train_lbls=lbls[TRN],
                                           valid_lbls=lbls[VAL],
                                           bs=bs)

    if qrnn:
        emb_sz, nh, nl = 400, 1550, 3
    else:
        emb_sz, nh, nl = 400, 1150, 3
    learn = language_model_learner(data_lm,
                                   bptt=bptt,
                                   emb_sz=emb_sz,
                                   nh=nh,
                                   nl=nl,
                                   qrnn=qrnn,
                                   pad_token=PAD_TOKEN_ID,
                                   pretrained_fnames=pretrained_fname,
                                   path=model_dir.parent,
                                   model_dir=model_dir.name)
    lm_enc_finetuned = f"{lm_name}_{dataset}_enc"
    if fine_tune and not (model_dir / f"{lm_enc_finetuned}.pth").exists():
        print('Fine-tuning the language model...')
        learn.unfreeze()
        learn.fit(2, slice(1e-4, 1e-2))

        # save encoder
        learn.save_encoder(lm_enc_finetuned)

    print("Starting classifier training")
    learn = text_classifier_learner(data_clas,
                                    bptt=bptt,
                                    pad_token=PAD_TOKEN_ID,
                                    path=model_dir.parent,
                                    model_dir=model_dir.name,
                                    qrnn=qrnn,
                                    emb_sz=emb_sz,
                                    nh=nh,
                                    nl=nl)

    learn.load_encoder(lm_enc_finetuned)

    learn.fit_one_cycle(1, 2e-2, moms=(0.8, 0.7), wd=1e-7)

    learn.freeze_to(-2)
    learn.fit_one_cycle(1,
                        slice(1e-2 / (2.6**4), 1e-2),
                        moms=(0.8, 0.7),
                        wd=1e-7)

    learn.freeze_to(-3)
    learn.fit_one_cycle(1,
                        slice(5e-3 / (2.6**4), 5e-3),
                        moms=(0.8, 0.7),
                        wd=1e-7)

    learn.unfreeze()
    learn.fit_one_cycle(2,
                        slice(1e-3 / (2.6**4), 1e-3),
                        moms=(0.8, 0.7),
                        wd=1e-7)
    results['accuracy'] = learn.validate()[1]
    print(f"Saving models at {learn.path / learn.model_dir}")
    learn.save(f'{model_name}_{name}')
    return results
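A hedged invocation sketch for new_train_clas; the directory layout is inferred from the docstring and the asserts above (a folder literally named data containing an imdb subfolder, plus a models folder with lstm_wt103.pth and itos_wt103.pkl), and the concrete values are assumptions rather than part of the original example:

# Hypothetical call mirroring the documented defaults.
results = new_train_clas(data_dir='data',       # must be named 'data' (asserted above)
                         lang='en',
                         cuda_id=0,
                         pretrain_name='wt103',
                         model_dir='models',     # expects lstm_wt103.pth and itos_wt103.pkl
                         qrnn=False,
                         fine_tune=True,
                         bs=20,
                         bptt=70,
                         dataset='imdb',
                         ds_pct=1.0)
print('Validation accuracy:', results['accuracy'])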