Example #1
    def from_data(
        cls,
        # Your raw dataset. Supports DataFrames, Hugging Face Datasets, as well as file paths
        # to .csv, .xlsx, .xls, and .jsonl files
        data: Union[pd.DataFrame, Path, str, List[Dict]],
        # The name or path of the pretrained model you want to fine-tune
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        # The language modeling strategy (or objective)
        lm_strategy_cls: BaseLMStrategy = CausalLMStrategy,
        # The attribute in your dataset that contains your raw text
        text_attr: str = "text",
        # A function that will split your Dataset into a training and validation set
        # See [here](https://docs.fast.ai/data.transforms.html#Split) for a list of fast.ai splitters
        dblock_splitter: Optional[Callable] = None,
        # Any kwargs to pass to your `DataLoaders`
        dl_kwargs={},
        # Any kwargs to pass to your task specific `Blearner`
        learner_kwargs={},
    ):
        # if we get a path/str then we're loading something like a .csv file
        if isinstance(data, Path) or isinstance(data, str):
            content_type = mimetypes.guess_type(data)[0]
            if content_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
                data = pd.read_excel(data)
            elif content_type == "text/csv":
                data = pd.read_csv(data)
            elif content_type == "application/json":
                data = pd.read_json(data, orient="records")
            else:
                raise ValueError("'data' must be a .xlsx, .xls, .csv, or .jsonl file")

        # infer our datablock splitter if None
        if dblock_splitter is None:
            dblock_splitter = ColSplitter() if hasattr(data, "is_valid") else RandomSplitter()

        # get our hf objects
        lm_type = lm_strategy_cls.get_lm_type()
        model_cls = cls.get_model_cls(lm_type=lm_type)
        hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name_or_path, model_cls=model_cls)

        # not all architectures include a native pad_token (e.g., gpt2, ctrl, etc...), so we add one here
        if hf_tokenizer.pad_token is None:
            hf_tokenizer.add_special_tokens({"pad_token": "<pad>"})
            hf_config.pad_token_id = hf_tokenizer.get_vocab()["<pad>"]
            hf_model.resize_token_embeddings(len(hf_tokenizer))

        # define DataBlock and DataLoaders
        bbtfm = LMBatchTokenizeTransform(hf_arch, hf_config, hf_tokenizer, hf_model, lm_strategy_cls=lm_strategy_cls)

        input_return_type = CausalLMTextInput if (lm_type == LMType.CAUSAL) else MLMTextInput
        blocks = (TextBlock(batch_tokenize_tfm=bbtfm, input_return_type=input_return_type), noop)

        dblock = DataBlock(blocks=blocks, get_x=ItemGetter(text_attr), splitter=dblock_splitter)
        dls = dblock.dataloaders(data, **dl_kwargs.copy())

        # return BLearner instance with default metrics (optional)
        learner_kwargs["metrics"] = learner_kwargs.pop("metrics", [perplexity])
        return cls(dls, hf_model, **learner_kwargs.copy())
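
A minimal usage sketch of this classmethod follows; the class name BlearnerForLM, its import path, and the toy DataFrame are assumptions for illustration only.

import pandas as pd
# from blurr import BlearnerForLM  # assumed import; the actual path depends on your install

df = pd.DataFrame({"text": [
    "the quick brown fox jumps over the lazy dog",
    "pack my box with five dozen liquor jugs",
]})

learn = BlearnerForLM.from_data(
    df,                  # raw DataFrame with a "text" column
    "gpt2",              # any causal LM checkpoint on the Hugging Face hub
    text_attr="text",
    dl_kwargs={"bs": 2},
)
learn.fit_one_cycle(1, lr_max=3e-5)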
Example #2
def create_inference_model(checkpoint: Optional[str] = None, model: str = 'resnet34', path: str = '.'):
    if model == 'resnet34':
        model = resnet34
    elif model == 'resnet18':
        model = resnet18
    elif model == 'mobilenet_v2':
        model = mobilenet_v2
    else:
        raise ValueError(f"Unsupported model: {model!r}")

    # Create an inference model instance and load the requested checkpoint
    inf_db = DataBlock(blocks=[ImageBlock, CategoryBlock],
                       get_x=ItemGetter(0),
                       get_y=ItemGetter(1))

    dummy_img = PILImage.create(np.zeros((415, 415, 3), dtype=np.uint8))
    source = [(dummy_img, False), (dummy_img, True)]

    inf_dls = inf_db.dataloaders(source)

    if model == mobilenet_v2:
        learner = cnn_learner(inf_dls,
                              model,
                              cut=-1,
                              splitter=_mobilenetv2_split,
                              pretrained=False)
    else:
        learner = cnn_learner(inf_dls, model, pretrained=False)
    learner.path = Path(path)

    if checkpoint is not None:
        learner.load(checkpoint, with_opt=False, device='cpu')

    return learner
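
A hypothetical way to use the helper above for CPU inference; the checkpoint name and image path below are placeholders.

learn = create_inference_model(checkpoint="stage-1", model="resnet18", path="models")
pred, pred_idx, probs = learn.predict("some_image.jpg")  # placeholder image path
print(pred, probs[pred_idx])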
Example #3
def fake_dataloaders(a=2, b=3, bs=16, n=10):
    def get_data(n):
        x = torch.randn(bs * n, 1)
        return torch.cat((x, a * x + b + 0.1 * torch.randn(bs * n, 1)), 1)

    ds = get_data(n)
    dblock = DataBlock()
    return dblock.dataloaders(ds)
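
A quick sanity check of the synthetic loaders above (a sketch; note that DataBlock() is used here without getters, so the raw tensor rows flow through unchanged and the default RandomSplitter reserves roughly 20% of them for validation):

dls = fake_dataloaders(a=2, b=3, bs=16, n=10)
b = dls.one_batch()
print(len(b), b[0].shape)        # inspect the raw batch structure
print(dls.train.n, dls.valid.n)  # ~80/20 split from the default RandomSplitter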
Example #4
def train_classifier(train_df, lm_dls, config, arch, args, label_list):
    # Train the classifier using the previously fine-tuned LM
    if len(label_list) > 1:
        block_category = MultiCategoryBlock()
        label_delim = args.label_delim
    else:
        block_category = CategoryBlock()
        label_delim = None

    blocks = (TextBlock.from_df(args.text_col,
                                is_lm=False,
                                seq_len=lm_dls.seq_len,
                                vocab=lm_dls.vocab,
                                tok=lm_dls.tok), block_category)

    clf_datablock = DataBlock(blocks=blocks,
                              get_x=ColReader("text"),
                              get_y=ColReader(LABEL_COL_NAME,
                                              label_delim=label_delim),
                              splitter=RandomSplitter(valid_pct=VAL_SIZE,
                                                      seed=RANDOM_STATE))

    clf_dataloaders = clf_datablock.dataloaders(train_df, bs=args.batch_size)

    config_cls = update_classif_config(config)

    learner_clf = text_classifier_learner(clf_dataloaders,
                                          arch,
                                          path=clf_dataloaders.path,
                                          drop_mult=args.drop_mult,
                                          config=config_cls,
                                          pretrained=False).to_fp32()
    learner_clf.load_encoder(ENCODER_FILE_NAME)

    lr = find_best_lr(learner_clf)
    learner_clf = fit_with_gradual_unfreezing(learner_clf, args.epochs, lr,
                                              args)
    learner_clf.export(args.model_filename)
    return learner_clf, args.model_filename
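
A hypothetical invocation of train_classifier; train_df and lm_dls are assumed to come from an earlier language-model fine-tuning step that also saved the encoder under ENCODER_FILE_NAME, and the args namespace simply mirrors the attributes the function reads.

from types import SimpleNamespace

args = SimpleNamespace(text_col="text", label_delim=";", batch_size=64,
                       drop_mult=0.5, epochs=4, model_filename="clf_export.pkl")
learner_clf, exported = train_classifier(
    train_df, lm_dls, awd_lstm_clas_config.copy(), AWD_LSTM, args,
    label_list=["sports", "politics", "tech"],  # more than one label -> multi-label branch
)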
Example #5
def get_data(inputs,
             df_all=None,
             batch_tfms=None,
             item_tfms=None,
             verbose=False,
             autoencoder=False):

    if df_all is None:
        df_all = get_dataframe(inputs, verbose)

    if item_tfms is None:
        tfms = [Resize(128, method="squish")]
    else:
        tfms = item_tfms

    if autoencoder:
        blocks = (ImageBlock, ImageBlock)
        y_reader = ColReader("cam/image_array")
    else:
        blocks = (ImageBlock, RegressionBlock(n_out=2))
        y_reader = ColReader(['user/angle', 'user/throttle'])

    pascal = DataBlock(blocks=blocks,
                       splitter=RandomSplitter(),
                       get_x=ColReader("cam/image_array"),
                       get_y=y_reader,
                       item_tfms=tfms,
                       batch_tfms=batch_tfms,
                       n_inp=1)

    dls = pascal.dataloaders(df_all)

    if verbose:
        dls.show_batch()
        print(dls.one_batch()[0].shape)

    return dls
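
A sketch of calling get_data with a pre-built DataFrame; the image paths and values are placeholders, but the column names match the ColReaders above.

import pandas as pd

df = pd.DataFrame({
    "cam/image_array": ["images/0001.jpg", "images/0002.jpg", "images/0003.jpg", "images/0004.jpg"],
    "user/angle":      [0.10, -0.25, 0.05, 0.30],
    "user/throttle":   [0.50,  0.45, 0.60, 0.40],
})
dls = get_data(inputs=None, df_all=df, batch_tfms=aug_transforms())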
Example #6
    def from_data(
        cls,
        # Your raw dataset. Supports DataFrames, Hugging Face Datasets, as well as file paths
        # to .csv, .xlsx, .xls, and .jsonl files
        data: Union[pd.DataFrame, Path, str, List[Dict]],
        # The name or path of the pretrained model you want to fine-tune
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        # The attribute in your dataset that contains a list of your tokens
        tokens_attr: str = "tokens",
        # The attribute in your dataset that contains the entity labels for each token in your raw text
        token_labels_attr: str = "token_labels",
        # The unique entity labels (or vocab) available in your dataset
        labels: Optional[List[str]] = None,
        # A function that will split your Dataset into a training and validation set
        # See [here](https://docs.fast.ai/data.transforms.html#Split) for a list of fast.ai splitters
        dblock_splitter: Optional[Callable] = None,
        # Any kwargs to pass to your `DataLoaders`
        dl_kwargs: dict = {},
        # Any kwargs to pass to your task specific `Blearner`
        learner_kwargs: dict = {},
    ):
        # if we get a path/str then we're loading something like a .csv file
        if isinstance(data, Path) or isinstance(data, str):
            content_type = mimetypes.guess_type(data)[0]
            if content_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
                data = pd.read_excel(data)
            elif content_type == "text/csv":
                data = pd.read_csv(data)
            elif content_type == "application/json":
                data = pd.read_json(data, orient="records")
            else:
                raise ValueError(
                    "'data' must be a .xlsx, .xls, .csv, or .jsonl file")

        # we need to tell transformer how many labels/classes to expect
        if labels is None:
            if isinstance(data, pd.DataFrame):
                labels = sorted(
                    list(
                        set([
                            lbls
                            for sublist in data[token_labels_attr].tolist()
                            for lbls in sublist
                        ])))
            else:
                labels = sorted(
                    list(set([item[token_labels_attr] for item in data])))

        # infer our datablock splitter if None
        if dblock_splitter is None:
            dblock_splitter = ColSplitter() if hasattr(
                data, "is_valid") else RandomSplitter()

        # get our hf objects
        n_labels = len(labels)
        hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(
            pretrained_model_name_or_path,
            model_cls=cls.get_model_cls(),
            config_kwargs={"num_labels": n_labels})

        # not all architectures include a native pad_token (e.g., gpt2, ctrl, etc...), so we add one here
        if hf_tokenizer.pad_token is None:
            hf_tokenizer.add_special_tokens({"pad_token": "<pad>"})
            hf_config.pad_token_id = hf_tokenizer.get_vocab()["<pad>"]
            hf_model.resize_token_embeddings(len(hf_tokenizer))

        batch_tok_tfm = TokenClassBatchTokenizeTransform(
            hf_arch, hf_config, hf_tokenizer, hf_model)
        blocks = (
            TextBlock(batch_tokenize_tfm=batch_tok_tfm,
                      input_return_type=TokenClassTextInput),
            TokenCategoryBlock(vocab=labels),
        )

        dblock = DataBlock(blocks=blocks,
                           get_x=ItemGetter(tokens_attr),
                           get_y=ItemGetter(token_labels_attr),
                           splitter=dblock_splitter)
        dls = dblock.dataloaders(data, **dl_kwargs.copy())

        # return BLearner instance
        return cls(dls, hf_model, **learner_kwargs.copy())
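
A hypothetical usage of this token-classification factory method; the class name BlearnerForTokenClassification and the toy data are assumptions.

import pandas as pd

df = pd.DataFrame({
    "tokens":       [["John", "lives", "in", "Berlin"], ["Acme", "hired", "Mary"]],
    "token_labels": [["B-PER", "O", "O", "B-LOC"],      ["B-ORG", "O", "B-PER"]],
})  # toy rows; use a real dataset in practice

learn = BlearnerForTokenClassification.from_data(df, "distilbert-base-uncased")  # assumed class name
learn.fit_one_cycle(1, lr_max=2e-5)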
Example #7
    def from_data(
        cls,
        # Your raw dataset. Supports DataFrames, Hugging Face Datasets, as well as file paths
        # to .csv, .xlsx, .xls, and .jsonl files
        data: Union[pd.DataFrame, Path, str, List[Dict]],
        # The name or path of the pretrained model you want to fine-tune
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        # The attribute in your dataset that contains your raw text
        text_attr: str = "text",
        # The attribute in your dataset that contains your labels/targets
        label_attr: str = "label",
        # The number of labels/classes your model should predict
        n_labels: Optional[int] = None,
        # A function that will split your Dataset into a training and validation set
        # See [here](https://docs.fast.ai/data.transforms.html#Split) for a list of fast.ai splitters
        dblock_splitter: Optional[Callable] = None,
        # Any kwargs to pass to your `DataLoaders`
        dl_kwargs: dict = {},
        # Any kwargs to pass to your task specific `Blearner`
        learner_kwargs: dict = {},
    ):
        # if we get a path/str then we're loading something like a .csv file
        if isinstance(data, Path) or isinstance(data, str):
            content_type = mimetypes.guess_type(data)[0]
            if content_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
                data = pd.read_excel(data)
            elif content_type == "text/csv":
                data = pd.read_csv(data)
            elif content_type == "application/json":
                data = pd.read_json(data, orient="records")
            else:
                raise ValueError("'data' must be a .xlsx, .xls, .csv, or .jsonl file")

        # we need to tell transformer how many labels/classes to expect
        if n_labels is None:
            if isinstance(data, pd.DataFrame):
                n_labels = len(label_attr) if (is_listy(label_attr)) else len(data[label_attr].unique())
            else:
                n_labels = len(label_attr) if (is_listy(label_attr)) else len(set([item[label_attr] for item in data]))

        # infer our datablock splitter if None
        if dblock_splitter is None:
            dblock_splitter = ColSplitter() if hasattr(data, "is_valid") else RandomSplitter()

        # get our hf objects
        hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(
            pretrained_model_name_or_path, model_cls=cls.get_model_cls(), config_kwargs={"num_labels": n_labels}
        )

        # not all architectures include a native pad_token (e.g., gpt2, ctrl, etc...), so we add one here
        if hf_tokenizer.pad_token is None:
            hf_tokenizer.add_special_tokens({"pad_token": "<pad>"})
            hf_config.pad_token_id = hf_tokenizer.get_vocab()["<pad>"]
            hf_model.resize_token_embeddings(len(hf_tokenizer))

        # infer loss function and default metrics
        if is_listy(label_attr):
            trg_block = MultiCategoryBlock(encoded=True, vocab=label_attr)
            learner_kwargs["metrics"] = learner_kwargs.get("metrics", [F1ScoreMulti(), accuracy_multi])
        else:
            trg_block = CategoryBlock
            learner_kwargs["metrics"] = learner_kwargs.get("metrics", [F1Score(), accuracy])

        # build our DataBlock and DataLoaders
        blocks = (TextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), trg_block)
        dblock = DataBlock(
            blocks=blocks, get_x=partial(cls._get_x, attr=text_attr), get_y=partial(cls._get_y, attr=label_attr), splitter=dblock_splitter
        )

        dls = dblock.dataloaders(data, **dl_kwargs.copy())

        # return BLearner instance
        return cls(dls, hf_model, **learner_kwargs.copy())
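
A hypothetical usage of this sequence-classification factory method; the class name BlearnerForSequenceClassification and the toy data are assumptions.

import pandas as pd

df = pd.DataFrame({
    "text":  ["great movie", "terrible plot", "loved it", "fell asleep", "a masterpiece"],
    "label": ["pos", "neg", "pos", "neg", "pos"],
})

learn = BlearnerForSequenceClassification.from_data(  # assumed class name
    df, "distilroberta-base", text_attr="text", label_attr="label",
)
learn.fit_one_cycle(1, lr_max=2e-5)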
Example #8
    def from_data(
        cls,
        # Your raw dataset. Supports DataFrames, Hugging Face Datasets, as well as file paths
        # to .csv, .xlsx, .xls, and .jsonl files
        data: Union[pd.DataFrame, Path, str, List[Dict]],
        # The name or path of the pretrained model you want to fine-tune
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        # The language of your source texts (inputs)
        src_lang_name: str = "English",
        # The attribute/column of your source language texts
        src_lang_attr: str = "src_lang",
        # The language of your target texts (what you want to predict)
        trg_lang_name: str = "English",
        # The attribute/column of your target language texts
        trg_lang_attr: str = "trg_lang",
        # The max length of your source texts to consider for translation
        max_length: Optional[Union[int, str]] = None,
        # The max length of your target (translated) texts
        max_target_length: Optional[Union[int, str]] = None,
        # A function that will split your Dataset into a training and validation set
        # See [here](https://docs.fast.ai/data.transforms.html#Split) for a list of fast.ai splitters
        dblock_splitter: Optional[Callable] = None,
        # Any additional keyword arguments applied during tokenization
        hf_tok_kwargs: dict = {},
        # If you want to override your Blurr transform's `text_gen_kwargs`, do that here
        text_gen_kwargs: dict = {},
        # Any kwargs to pass to your `DataLoaders`
        dl_kwargs: dict = {},
        # Any kwargs to pass to your task specific `Blearner`
        learner_kwargs: dict = {},
    ):
        # if we get a path/str then we're loading something like a .csv file
        if isinstance(data, Path) or isinstance(data, str):
            content_type = mimetypes.guess_type(data)[0]
            if content_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
                data = pd.read_excel(data)
            elif content_type == "text/csv":
                data = pd.read_csv(data)
            elif content_type == "application/json":
                data = pd.read_json(data, orient="records")
            else:
                raise ValueError(
                    "'data' must be a .xlsx, .xls, .csv, or .jsonl file")

        # infer our datablock splitter if None
        if dblock_splitter is None:
            dblock_splitter = ColSplitter() if hasattr(
                data, "is_valid") else RandomSplitter()

        # we need to find the architecture to ensure "mbart" specific tokenizer kwargs are included
        model_cls = cls.get_model_cls()
        model = model_cls.from_pretrained(pretrained_model_name_or_path)
        hf_arch = model.__module__.split(".")[2]

        if hf_arch == "mbart":
            hf_tok_kwargs = {
                **{
                    "src_lang": "en_XX",
                    "tgt_lang": "en_XX"
                },
                **hf_tok_kwargs
            }

        # get our hf objects
        hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(
            pretrained_model_name_or_path,
            model_cls=model_cls,
            tokenizer_kwargs=hf_tok_kwargs)

        # update text generation kwargs
        text_gen_kwargs = {
            **text_gen_kwargs,
            **default_text_gen_kwargs(hf_config, hf_model, task="translation")
        }

        # not all "translation" parameters are for the model.generate method ... remove them here
        generate_func_args = list(
            inspect.signature(hf_model.generate).parameters.keys())
        for k in text_gen_kwargs.copy():
            if k not in generate_func_args:
                del text_gen_kwargs[k]

        # update our text generation kwargs for mbart
        if hf_arch == "mbart":
            text_gen_kwargs = {
                **{
                    "decoder_start_token_id": "en_XX"
                },
                **text_gen_kwargs
            }

        # build dblock, dls, and default metrics (optional)
        get_x = Pipeline(funcs=[ColReader(src_lang_attr)])
        get_y = ItemGetter(trg_lang_attr)

        if hf_arch == "t5":
            get_x.add(
                partial(cls._add_t5_prefix,
                        src_lang_name=src_lang_name,
                        trg_lang_name=trg_lang_name))

        batch_tokenize_tfm = Seq2SeqBatchTokenizeTransform(
            hf_arch,
            hf_config,
            hf_tokenizer,
            hf_model,
            max_length=max_length,
            max_target_length=max_target_length,
            text_gen_kwargs=text_gen_kwargs,
        )

        blocks = (Seq2SeqTextBlock(batch_tokenize_tfm=batch_tokenize_tfm),
                  noop)
        dblock = DataBlock(blocks=blocks,
                           get_x=get_x,
                           get_y=get_y,
                           splitter=dblock_splitter)

        dls = dblock.dataloaders(data, **dl_kwargs.copy())

        # return BLearner instance
        learner_kwargs["splitter"] = learner_kwargs.pop(
            "splitter", partial(blurr_seq2seq_splitter, arch=hf_arch))
        learner_kwargs["loss_func"] = learner_kwargs.pop(
            "loss_func", PreCalculatedCrossEntropyLoss())
        return cls(dls, hf_model, **learner_kwargs.copy())
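
A hypothetical usage of this translation factory method; the class name BlearnerForTranslation and the toy data are assumptions, and any seq2seq translation checkpoint can stand in for the one shown.

import pandas as pd

df = pd.DataFrame({
    "src_lang": ["Hello, how are you?", "The weather is nice today."],
    "trg_lang": ["Bonjour, comment allez-vous ?", "Il fait beau aujourd'hui."],
})

learn = BlearnerForTranslation.from_data(  # assumed class name
    df, "Helsinki-NLP/opus-mt-en-fr",
    src_lang_name="English", src_lang_attr="src_lang",
    trg_lang_name="French",  trg_lang_attr="trg_lang",
    max_length=128, max_target_length=128,
)
learn.fit_one_cycle(1, lr_max=3e-5)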
Example #9
    def from_data(
        cls,
        # Your raw dataset. Supports DataFrames, Hugging Face Datasets, as well as file paths
        # to .csv, .xlsx, .xls, and .jsonl files
        data: Union[pd.DataFrame, Path, str, List[Dict]],
        # The name or path of the pretrained model you want to fine-tune
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        # The maximum sequence length to constrain our data
        max_seq_len: Optional[int] = None,
        # The unique identifier in the dataset. If not specified and "return_overflowing_tokens": True, an "_id" attribute
        # will be added to your dataset with its value a unique, sequential integer, assigned to each record
        id_attr: Optional[str] = None,
        # The attribute in your dataset that contains the context (where the answer is included) (default: 'context')
        context_attr: str = "context",
        # The attribute in your dataset that contains the question being asked (default: 'question')
        question_attr: str = "question",
        # The attribute in your dataset that contains the tokenized answer start (default: 'ans_start_token_idx')
        tok_ans_start_attr: str = "ans_start_token_idx",
        # The attribute in your dataset that contains the tokenized answer end (default: 'ans_end_token_idx')
        tok_ans_end_attr: str = "ans_end_token_idx",
        # A function that will split your Dataset into a training and validation set
        # See [here](https://docs.fast.ai/data.transforms.html#Split) for a list of fast.ai splitters
        dblock_splitter: Optional[Callable] = None,
        # Any kwargs to pass to your `DataLoaders`
        dl_kwargs={},
        # Any kwargs to pass to your task specific `Blearner`
        learner_kwargs={},
    ):
        # if we get a path/str then we're loading something like a .csv file
        if isinstance(data, Path) or isinstance(data, str):
            content_type = mimetypes.guess_type(data)[0]
            if content_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet":
                data = pd.read_excel(data)
            elif content_type == "text/csv":
                data = pd.read_csv(data)
            elif content_type == "application/json":
                data = pd.read_json(data, orient="records")
            else:
                raise ValueError("'data' must be a .xlsx, .xls, .csv, or .jsonl file")

        # infer our datablock splitter if None
        if dblock_splitter is None:
            dblock_splitter = ColSplitter() if hasattr(data, "is_valid") else RandomSplitter()

        hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(pretrained_model_name_or_path, model_cls=cls.get_model_cls())

        # potentially used by our preprocess_func, it is the basis for our CategoryBlock vocab
        if max_seq_len is None:
            max_seq_len = hf_config.get("max_position_embeddings", 128)

        # bits required by our "before_batch_tfm" and DataBlock
        vocab = list(range(max_seq_len))
        padding_side = hf_tokenizer.padding_side

        # define DataBlock and DataLoaders
        before_batch_tfm = QABatchTokenizeTransform(hf_arch, hf_config, hf_tokenizer, hf_model, max_length=max_seq_len)
        blocks = (
            TextBlock(batch_tokenize_tfm=before_batch_tfm, input_return_type=QATextInput),
            CategoryBlock(vocab=vocab),
            CategoryBlock(vocab=vocab),
        )
        dblock = DataBlock(
            blocks=blocks,
            get_x=partial(cls._get_x, qst=question_attr, ctx=context_attr, id=id_attr, padding_side=padding_side),
            get_y=[ItemGetter(tok_ans_start_attr), ItemGetter(tok_ans_end_attr)],
            splitter=dblock_splitter,
            n_inp=1,
        )

        dls = dblock.dataloaders(data, **dl_kwargs.copy())

        # return BLearner instance
        return cls(dls, hf_model, **learner_kwargs.copy())
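
A hypothetical call for extractive question answering; this assumes a DataFrame (preprocessed_df, not defined here) that already carries the question/context columns plus the tokenized answer spans, and the class name is an assumption.

learn = BlearnerForQuestionAnswering.from_data(  # assumed class name
    preprocessed_df,            # must include ans_start_token_idx / ans_end_token_idx columns
    "distilbert-base-uncased",
    max_seq_len=256,
)
learn.fit_one_cycle(1, lr_max=3e-5)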
Example #10
    # print(learn.lr_find())
    # learn.fine_tune(5, 1e-2, cbs=TensorBoardCallback(PATH_TENSORBOARD, trace_model=True))

    #%%
    # Prepare IMDB data
    path = untar_data(URLs.IMDB)
    bs = 32

    # Fine-tune pretrained language model (based on wikitext) to the IMDB corpus
    get_imdb = partial(get_text_files, folders=["train", "test", "unsup"])
    dls_lm = DataBlock(blocks=TextBlock.from_folder(path,
                                                    is_lm=True,
                                                    n_workers=4),
                       get_items=get_imdb,
                       splitter=RandomSplitter(0.1))
    dls_lm = dls_lm.dataloaders(path, path=path, bs=bs, seq_len=80)
    print(dls_lm.show_batch(max_n=3))

    # #%%
    # learn = language_model_learner(
    #     dls_lm, AWD_LSTM, drop_mult=0.3,
    #     metrics=[accuracy, Perplexity()]).to_fp16()
    # learn.lr_find()
    # print(learn.model)
    # learn.fit_one_cycle(1, 2e-2, moms=(0.8,0.7,0.8), cbs=cbs)
    # learn.save("1epoch")

    # #%%
    learn = language_model_learner(dls_lm,
                                   AWD_LSTM,
                                   drop_mult=0.3,
                                   metrics=[accuracy, Perplexity()]).to_fp16()
Example #11
    train="training",
    valid="testing",
    device=device,
)

# %%
mnist_block = DataBlock(
    blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
    get_items=get_image_files,
    splitter=GrandparentSplitter(train_name="training", valid_name="testing"),
    get_y=parent_label,
    # batch_tfms=aug_transforms(mult=1.2, do_flip=False)
)

# %%
mnist_dls = mnist_block.dataloaders(mnist_dir)

# %%
mnist_dls.train.one_batch()[0].shape

# %%
mnist_dls.show_batch()

# %% [markdown]
#
# ## Multi-Layer Perceptron (Feed-Forward Network)

# %%
mlp_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
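
The nn.Sequential above is cut off in this snippet; one plausible completion for 28x28 MNIST inputs (an assumption, not recovered from the source) is:

mlp_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Linear(128, 10),  # 10 digit classes
)
mlp_learn = Learner(mnist_dls, mlp_model,
                    loss_func=CrossEntropyLossFlat(), metrics=accuracy)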
Example #12
import matplotlib.pyplot as plt
from fastai.distributed import *
# In[]:
products = DataBlock(blocks=(ImageBlock, CategoryBlock),
                     get_items=get_image_files,
                     splitter=GrandparentSplitter(train_name="train",
                                                  valid_name="validation"),
                     get_y=parent_label,
                     item_tfms=Resize(192))

products = products.new(item_tfms=RandomResizedCrop(168, min_scale=0.8),
                        batch_tfms=aug_transforms())

project_path = Path("/home/yaro/Workspace/fastai/")
dataset_path = project_path.joinpath("for_test")
dls = products.dataloaders(dataset_path)

gpu = None
if torch.cuda.is_available():
    if gpu is not None: torch.cuda.set_device(gpu)
    n_gpu = torch.cuda.device_count()
else:
    n_gpu = None

learn = cnn_learner(dls, resnet18, metrics=error_rate).to_fp16()

# The context manager way of dp/ddp, both can handle single GPU base case.
if gpu is None and n_gpu is not None:
    ctx = learn.parallel_ctx
    with partial(ctx, gpu)():
        print(