def init_boto(self):
    """Return the boto3 S3 resource, creating and caching it on first use.

    On the first call a boto3 session is created and an ``"s3"`` resource is
    requested from it, using the ``AWS_S3_ENDPOINT_URL`` and ``AWS_REGION``
    environment variables for the endpoint and region. The ``botocore``
    module is loaded at the same time and stored on ``self._botocore``.
    Later calls return the cached ``self._s3`` unchanged.
    """
    if self._s3 is None:
        boto3 = util.get_module(
            "boto3",
            required="s3:// references requires the boto3 library, run pip install wandb[aws]",
        )
        session = boto3.session.Session()
        self._s3 = session.resource(
            "s3",
            endpoint_url=os.getenv("AWS_S3_ENDPOINT_URL"),
            region_name=os.getenv("AWS_REGION"),
        )
        self._botocore = util.get_module("botocore")
    return self._s3
def _prepare_video(self, V):
    """Tile a batch of videos into one grid video.

    This logic was mostly taken from tensorboardX.

    Args:
        V: numpy array with 4 dimensions (time, channels, height, width) or
            5 dimensions (batch, time, channels, height, width). A 4-D input
            is treated as a batch containing a single video.

    Returns:
        Array of shape (time, n_rows * height, n_cols * width, channels) in
        which the videos of the batch are laid out on a grid. ``uint8`` input
        is converted to ``float32`` and divided by 255.

    Raises:
        ValueError: if ``V`` has fewer than 4 dimensions.
    """
    np = util.get_module(
        "numpy",
        required='wandb.Video requires numpy when passing raw data. To get it, run "pip install numpy".'
    )
    if V.ndim < 4:
        raise ValueError(
            "Video must be atleast 4 dimensions: time, channels, height, width"
        )
    if V.ndim == 4:
        # Promote a single video to a batch of one.
        V = V.reshape(1, *V.shape)
    b, t, c, h, w = V.shape

    if V.dtype == np.uint8:
        V = np.float32(V) / 255.

    def is_power2(num):
        return num != 0 and ((num & (num - 1)) == 0)

    # pad to nearest power of 2, all at once
    if not is_power2(V.shape[0]):
        len_addition = int(2**V.shape[0].bit_length() - V.shape[0])
        V = np.concatenate((V, np.zeros(shape=(len_addition, t, c, h, w))), axis=0)

    n_rows = 2**((b.bit_length() - 1) // 2)
    n_cols = V.shape[0] // n_rows

    # The target shape is passed positionally: the ``newshape`` keyword of
    # ``np.reshape`` is deprecated in recent NumPy releases, while the
    # positional form is accepted by every version.
    V = np.reshape(V, (n_rows, n_cols, t, c, h, w))
    V = np.transpose(V, axes=(2, 0, 4, 1, 5, 3))
    V = np.reshape(V, (t, n_rows * h, n_cols * w, c))
    return V
def history(self, samples=500, pandas=True, stream="default"):
    """Fetch the sampled history of this run.

    Args:
        samples (int, optional): The number of samples to return
        pandas (bool, optional): Return a pandas dataframe
        stream (str, optional): "default" for metrics, "system" for machine metrics

    Returns:
        A pandas DataFrame when ``pandas`` is truthy and the pandas module
        can be loaded, otherwise a list of the decoded JSON rows.
    """
    # Any stream other than "default" is read from the "events" field.
    node = "events" if stream != "default" else "history"
    query_text = (
        " query Run($project: String!, $entity: String!, $name: String!, $samples: Int!) {"
        " project(name: $project, entityName: $entity) {"
        " run(name: $name) {"
        " %s(samples: $samples)"
        " } } } "
    ) % node
    response = self._exec(gql(query_text), samples=samples)
    raw_rows = response['project']['run'][node]
    lines = [json.loads(raw) for raw in raw_rows]

    if not pandas:
        return lines

    pandas_module = util.get_module("pandas")
    if not pandas_module:
        # Fall back to the plain list when pandas cannot be loaded.
        print("Unable to load pandas, call history with pandas=False")
        return lines
    return pandas_module.DataFrame.from_records(lines)
def transform(audio_list, out_dir, key, step):
    """Write audio clips to disk as WAV files and return their metadata.

    At most ``Audio.MAX_AUDIO_COUNT`` clips are written; a warning is logged
    when more are supplied. Files are written to ``<out_dir>/media/audio``
    and named ``<key>_<step>_<index>.wav``.

    Args:
        audio_list: sequence of audio objects exposing ``audio_data`` and
            ``sample_rate``.
        out_dir: directory under which ``media/audio`` is created.
        key: name used as the first component of each file name.
        step: step number used as the second component of each file name.

    Returns:
        dict with ``"_type"`` and ``"count"``, plus ``"sampleRates"``,
        ``"durations"`` and ``"captions"`` when the corresponding
        ``Audio`` helper returns a truthy value.
    """
    if len(audio_list) > Audio.MAX_AUDIO_COUNT:
        # ``logging.warn`` is a deprecated alias that newer Python versions
        # no longer provide; ``logging.warning`` works everywhere.
        logging.warning(
            "The maximum number of audio files to store per step is %i.",
            Audio.MAX_AUDIO_COUNT)
    sf = util.get_module(
        "soundfile",
        required="wandb.Audio requires the soundfile package. To get it, run: pip install soundfile"
    )
    base_path = os.path.join(out_dir, "media", "audio")
    util.mkdir_exists_ok(base_path)

    # Slice once and reuse, instead of re-slicing for every consumer.
    kept = audio_list[:Audio.MAX_AUDIO_COUNT]
    for i, audio in enumerate(kept):
        sf.write(
            os.path.join(base_path, "{}_{}_{}.wav".format(key, step, i)),
            audio.audio_data, audio.sample_rate)

    meta = {
        "_type": "audio",
        "count": min(len(audio_list), Audio.MAX_AUDIO_COUNT)
    }
    sample_rates = Audio.sample_rates(kept)
    if sample_rates:
        meta["sampleRates"] = sample_rates
    durations = Audio.durations(kept)
    if durations:
        meta["durations"] = durations
    captions = Audio.captions(kept)
    if captions:
        meta["captions"] = captions
    return meta
def __init__(self, data_or_path):
    """Build a Bokeh media object from a JSON file path or a Bokeh model.

    Args:
        data_or_path: either the path of an existing Bokeh JSON file, a
            ``bokeh.model.Model``, or a ``bokeh.document.Document``.

    Raises:
        TypeError: if ``data_or_path`` is none of the accepted kinds
            (including a string that is not an existing path).
    """
    super(Bokeh, self).__init__()
    bokeh = util.get_module("bokeh", required=True)
    if isinstance(data_or_path, str) and os.path.exists(data_or_path):
        with open(data_or_path, "r") as file:
            b_json = json.load(file)
        self.b_obj = bokeh.document.Document.from_json(b_json)
        self._set_file(data_or_path, is_tmp=False, extension=".bokeh.json")
    elif isinstance(data_or_path, bokeh.model.Model):
        _data = bokeh.document.Document()
        _data.add_root(data_or_path)
        # serialize/deserialize pairing followed by sorting attributes ensures
        # that the file's shas are equivalent in subsequent calls
        self.b_obj = bokeh.document.Document.from_json(_data.to_json())
        b_json = self.b_obj.to_json()
        if "references" in b_json["roots"]:
            b_json["roots"]["references"].sort(key=lambda x: x["id"])

        tmp_path = os.path.join(MEDIA_TMP.name, util.generate_id() + ".bokeh.json")
        # Close the handle deterministically so the JSON is fully flushed to
        # disk before the file is registered below.
        with codecs.open(tmp_path, "w", encoding="utf-8") as fp:
            util.json_dump_safer(b_json, fp)
        self._set_file(tmp_path, is_tmp=True, extension=".bokeh.json")
    elif not isinstance(data_or_path, bokeh.document.Document):
        raise TypeError(
            "Bokeh constructor accepts Bokeh document/model or path to Bokeh json file"
        )
    # NOTE(review): a ``bokeh.document.Document`` argument is accepted but
    # neither ``b_obj`` nor a file is set for it here - confirm intended.
def nest(thing):
    """Flatten ``thing`` into a list.

    Uses ``nest.flatten`` from ``tensorflow.python.util`` when that module
    can be loaded; otherwise ``thing`` is simply wrapped in a one-element
    list.
    """
    tf_util = util.get_module('tensorflow.python.util')
    if not tf_util:
        return [thing]
    return tf_util.nest.flatten(thing)
def image_segmentation_dataframe(x, y_true, y_pred, labels=None, example_ids=None, class_colors=None):
    """Dispatch to the binary or multiclass segmentation dataframe builder.

    ``y_pred`` is converted to a numpy array. When the last dimension of its
    first element has size 1 the binary builder is used (``labels`` and
    ``class_colors`` are not forwarded to it); otherwise the multiclass
    builder is used.
    """
    np = util.get_module('numpy', required='dataframes require numpy')
    y_pred = np.array(y_pred)

    single_channel = y_pred[0].shape[-1] == 1
    if not single_channel:
        return image_segmentation_multiclass_dataframe(
            x,
            y_true,
            y_pred,
            labels=labels,
            example_ids=example_ids,
            class_colors=class_colors,
        )
    return image_segmentation_binary_dataframe(
        x, y_true, y_pred, example_ids=example_ids)
def __init__(self, sequence=None, np_histogram=None, num_bins=64):
    """Create a histogram from raw values or from a precomputed histogram.

    Accepts a sequence to be converted into a histogram, or ``np_histogram``
    can be set to a tuple of (values, bin_edges) as ``np.histogram`` returns,
    i.e.

        wandb.log({"histogram": wandb.Histogram(
            np_histogram=np.histogram(data))})

    Args:
        sequence: values to bin with ``np.histogram`` when ``np_histogram``
            is not given.
        np_histogram: 2-tuple of (values, bin_edges).
        num_bins: number of bins used when binning ``sequence``.

    Raises:
        ValueError: if ``np_histogram`` does not have exactly two elements,
            if the histogram is longer than ``self.MAX_LENGTH`` (documented
            upstream as 512), or if ``len(bins) != len(histogram) + 1``.
    """
    if np_histogram:
        if len(np_histogram) != 2:
            raise ValueError(
                'Expected np_histogram to be a tuple of (values, bin_edges) or sequence to be specified'
            )
        values, bin_edges = np_histogram[0], np_histogram[1]
        # Store plain lists when numpy arrays are supplied, so the
        # attributes have the same types as on the ``sequence`` path below.
        self.histogram = values.tolist() if hasattr(values, "tolist") else values
        self.bins = bin_edges.tolist() if hasattr(bin_edges, "tolist") else bin_edges
    else:
        np = util.get_module(
            "numpy", required="Auto creation of histograms requires numpy")
        counts, edges = np.histogram(sequence, bins=num_bins)
        self.histogram = counts.tolist()
        self.bins = edges.tolist()

    if len(self.histogram) > self.MAX_LENGTH:
        raise ValueError("The maximum length of a histogram is %i" %
                         self.MAX_LENGTH)
    if len(self.histogram) + 1 != len(self.bins):
        raise ValueError("len(bins) must be len(histogram) + 1")
def explain_text(text, probas, target_names=None):
    """Explain a text prediction with eli5's LIME based TextExplainer.

    Arguments:
        text (str): Text to explain
        probas (black-box classification pipeline): A function which
            takes a list of strings (documents) and returns a matrix
            of shape (n_samples, n_classes) with probability values,
            i.e. a row per document and a column per output label.
        target_names: forwarded to ``TextExplainer.show_prediction``.

    Returns:
        A ``wandb.Html`` object wrapping the rendered explanation, or
        ``None`` when ``test_missing`` rejects the inputs.

    Example:
        wandb.log({'roc': wandb.plots.ExplainText(text, probas)})
    """
    eli5 = util.get_module(
        "eli5",
        required="explain_text requires the eli5 library, install with `pip install eli5`"
    )
    if not test_missing(text=text, probas=probas):
        return None

    wandb.termlog('Visualizing TextExplainer.')
    # Fixed random_state keeps the explainer's sampling reproducible.
    explainer = eli5.lime.TextExplainer(random_state=42)
    explainer.fit(text, probas)
    prediction = explainer.show_prediction(target_names=target_names)
    return wandb.Html(prediction.data)
def grpc_server(project=None, entity=None):
    """Run the wandb gRPC server entry point.

    The ``grpc`` module is requested first so a missing dependency is
    reported with an installation hint. ``project`` and ``entity`` are
    accepted but not used by this function.
    """
    util.get_module(
        "grpc",
        required="grpc-server requires the grpcio library, run pip install wandb[grpc]",
    )
    # Imported lazily, only after the grpc dependency has been checked.
    from wandb.server.grpc_server import main as run_grpc_server

    run_grpc_server()
def upload_dataset(dataset_name):
    """Upload a dataset from the local Prodigy database to Weights & Biases.

    Args:
        dataset_name: The name of the dataset in the Prodigy database.

    Raises:
        ValueError: if ``wandb.init()`` has not been called yet.
    """
    # A live run is needed to log the table to.
    if wandb.run is None:
        raise ValueError("You must call wandb.init() before upload_dataset()")

    with wb_telemetry.context(run=wandb.run) as tel:
        tel.feature.prodigy = True

    prodigy_db = util.get_module(
        "prodigy.components.db",
        required="`prodigy` library is required but not installed. Please see https://prodi.gy/docs/install",
    )

    # Pull the records out of the Prodigy database.
    records = prodigy_db.connect().get_dataset(dataset_name)

    array_dict_types = []
    schema = get_schema(records, {}, array_dict_types)

    # Standardize every record in place against the inferred schema.
    for index, _unused in enumerate(records):
        standardize(records[index], schema, array_dict_types)

    wandb.log({dataset_name: create_table(records)})
    print("Prodigy dataset `" + dataset_name + "` uploaded.")
def __init__(self, data_or_path, sample_rate=None, caption=None):
    """Create an Audio object from a file path/reference or raw audio data.

    Args:
        data_or_path: a string (file path, or a reference as decided by
            ``Audio.path_is_reference``) or raw audio data that
            ``soundfile.write`` can encode.
        sample_rate: sample rate of the raw data; required for raw data.
        caption: optional caption stored on the object.

    Raises:
        ValueError: if raw data is given without ``sample_rate``.
    """
    super(Audio, self).__init__()
    self._duration = None
    self._sample_rate = sample_rate
    self._caption = caption

    if isinstance(data_or_path, six.string_types):
        if not Audio.path_is_reference(data_or_path):
            self._set_file(data_or_path, is_tmp=False)
            return
        # References are not copied; they are identified by a hash of the
        # reference string itself.
        self._path = data_or_path
        digest = hashlib.sha256(data_or_path.encode("utf-8"))
        self._sha256 = digest.hexdigest()
        self._is_tmp = False
        return

    if sample_rate is None:
        raise ValueError(
            'Argument "sample_rate" is required when instantiating wandb.Audio with raw data.'
        )

    soundfile = util.get_module(
        "soundfile",
        required='Raw audio requires the soundfile package. To get it, run "pip install soundfile"',
    )

    wav_path = os.path.join(MEDIA_TMP.name, util.generate_id() + ".wav")
    soundfile.write(wav_path, data_or_path, sample_rate)
    self._duration = len(data_or_path) / float(sample_rate)

    self._set_file(wav_path, is_tmp=True)
def history(self, samples=500, keys=None, x_axis="_step", pandas=True, stream="default"):
    """Fetch history metrics for this run.

    Args:
        samples (int, optional): The number of samples to return
        pandas (bool, optional): Return a pandas dataframe
        keys (list, optional): Only return metrics for specific keys
        x_axis (str, optional): Use this metric as the xAxis defaults to _step
        stream (str, optional): "default" for metrics, "system" for machine metrics

    Returns:
        A pandas DataFrame when ``pandas`` is truthy and the pandas module
        can be loaded, otherwise the list of rows. An empty list is
        returned when ``keys`` is combined with a non-default stream.
    """
    if keys:
        if stream != "default":
            wandb.termerror("stream must be default when specifying keys")
            return []
        lines = self._sampled_history(keys=keys, x_axis=x_axis, samples=samples)
    else:
        lines = self._full_history(samples=samples, stream=stream)

    if not pandas:
        return lines

    pandas_module = util.get_module("pandas")
    if not pandas_module:
        # Fall back to the plain rows when pandas cannot be loaded.
        print("Unable to load pandas, call history with pandas=False")
        return lines
    return pandas_module.DataFrame.from_records(lines)
def __init__(self, data_or_path, sample_rate=None, caption=None):
    """Accepts a path to an audio file or a numpy array of audio data.

    Args:
        data_or_path: a string path to an existing audio file, or raw audio
            data that ``soundfile.write`` can encode.
        sample_rate: sample rate of the raw data; required for raw data.
        caption: optional caption stored on the object.

    Raises:
        ValueError: if raw data is given without ``sample_rate``.
    """
    self._duration = None
    self._sample_rate = sample_rate
    self._caption = caption

    if isinstance(data_or_path, six.string_types):
        super(Audio, self).__init__(data_or_path, is_tmp=False)
    else:
        # Identity comparison: ``== None`` can be hijacked by objects that
        # overload equality (e.g. numpy scalars/arrays).
        if sample_rate is None:
            raise ValueError(
                'Argument "sample_rate" is required when instantiating wandb.Audio with raw data.'
            )

        soundfile = util.get_module(
            "soundfile",
            required='Raw audio requires the soundfile package. To get it, run "pip install soundfile"'
        )

        # Raw data is encoded to a temporary WAV file, which then backs
        # the media object.
        tmp_path = os.path.join(MEDIA_TMP.name, util.generate_id() + '.wav')
        soundfile.write(tmp_path, data_or_path, sample_rate)
        self._duration = len(data_or_path) / float(sample_rate)

        super(Audio, self).__init__(tmp_path, is_tmp=True)
def encode(self):
    """Encode ``self.data`` into a temporary gif/video file.

    The raw data is tiled with ``_prepare_video``; the resulting height,
    width and channel count are stored on the instance. The frames are
    written with moviepy to a file named after a generated id and
    ``self._format``, and the media base class is initialised with it.
    """
    mpy = util.get_module(
        "moviepy.editor",
        required='wandb.Video requires moviepy and imageio when passing raw data.  Install with "pip install moviepy imageio"'
    )
    frames = self._prepare_video(self.data)
    _, self._height, self._width, self._channels = frames.shape

    # encode sequence of images into gif string
    clip = mpy.ImageSequenceClip(list(frames), fps=self._fps)

    filename = os.path.join(MEDIA_TMP.name,
                            util.generate_id() + '.' + self._format)

    write = clip.write_gif if self._format == "gif" else clip.write_videofile
    try:
        write(filename, verbose=False, progress_bar=False)
    except TypeError:
        # older version of moviepy does not support progress_bar argument.
        write(filename, verbose=False)

    super(Video, self).__init__(filename, is_tmp=True)
def seq_to_json(cls, seq, run, key, step):
    """Bind a sequence of audio objects to a run and describe them as JSON.

    Every audio object that is not yet bound is bound to ``run`` under
    ``key``/``step``, and the ``media/audio`` directory of the run is
    created.

    Returns:
        dict with ``"_type"``, ``"count"`` and the per-item ``"audio"``
        JSON, plus ``"sampleRates"``, ``"durations"`` and ``"captions"``
        when the corresponding class helper returns a truthy value.
    """
    audio_list = list(seq)
    for item in audio_list:
        if not item.is_bound():
            item.bind_to_run(run, key, step)

    # Requested only for the dependency check; the module is not used here.
    util.get_module(
        "soundfile",
        required="wandb.Audio requires the soundfile package. To get it, run: pip install soundfile"
    )
    base_path = os.path.join(run.dir, "media", "audio")
    util.mkdir_exists_ok(base_path)

    meta = {
        "_type": "audio",
        "count": len(audio_list),
        'audio': [item.to_json(run) for item in audio_list],
    }

    optional_fields = (
        ("sampleRates", cls.sample_rates),
        ("durations", cls.durations),
        ("captions", cls.captions),
    )
    for field, extract in optional_fields:
        values = extract(audio_list)
        if values:
            meta[field] = values

    return meta
def test_fitted(model):
    """Report whether ``model`` appears to have been fitted.

    The model is asked to predict on a dummy ``(7, 3)`` array of zeros:

    * ``NotFittedError`` means the model is not fitted.
    * ``AttributeError`` (no ``predict``) falls back to sklearn's
      ``check_is_fitted`` with a list of common fitted attributes.
    * Any other exception (for instance a shape mismatch) is taken to
      mean the model is fitted, since ``NotFittedError`` was not raised.
    * A successful prediction also means the model is fitted.

    Returns:
        bool: True when the model is considered fitted, False otherwise
        (an error is printed with ``wandb.termerror`` in that case).
    """
    np = util.get_module("numpy", required="Logging plots requires numpy")
    # pandas and scipy are requested only as dependency checks.
    util.get_module("pandas", required="Logging dataframes requires pandas")
    util.get_module("scipy", required="Logging scipy matrices requires scipy")
    scikit_utils = util.get_module(
        "sklearn.utils",
        required="roc requires the scikit utils submodule, install with `pip install scikit-learn`",
    )
    scikit_exceptions = util.get_module(
        "sklearn.exceptions",
        "roc requires the scikit preprocessing submodule, install with `pip install scikit-learn`",
    )

    try:
        model.predict(np.zeros((7, 3)))
    except scikit_exceptions.NotFittedError:
        wandb.termerror("Please fit the model before passing it in.")
        return False
    except AttributeError:
        # Some clustering models (LDA, PCA, Agglomerative) don't implement ``predict``
        try:
            scikit_utils.validation.check_is_fitted(
                model,
                [
                    "coef_",
                    "estimator_",
                    "labels_",
                    "n_clusters_",
                    "children_",
                    "components_",
                    "n_components_",
                    "n_iter_",
                    "n_batch_iter_",
                    "explained_variance_",
                    "singular_values_",
                    "mean_",
                ],
                all_or_any=any,
            )
        except scikit_exceptions.NotFittedError:
            wandb.termerror("Please fit the model before passing it in.")
            return False
        return True
    except Exception:
        # Assume it's fitted, since ``NotFittedError`` wasn't raised
        return True

    # ``predict`` ran without raising, so the model is fitted. Previously
    # this path fell off the end and returned None, which callers treat as
    # "not fitted".
    return True
def test_missing(**kwargs):
    """Check keyword arguments for missing or non-numeric data.

    Every value that is ``None`` is reported with ``wandb.termerror``. For
    the keys ``X`` and ``X_test`` the value is additionally converted to a
    numpy array (from a scipy CSR matrix, pandas DataFrame/Series or list)
    and checked for null entries and for entries that are not numbers.

    Returns:
        bool: True when no problem was found, False otherwise.
    """
    np = util.get_module("numpy", required="Logging plots requires numpy")
    pd = util.get_module("pandas", required="Logging dataframes requires pandas")
    # Load the sparse submodule explicitly and use the public
    # ``csr_matrix`` name; the ``scipy.sparse.csr`` namespace is deprecated.
    sparse = util.get_module(
        "scipy.sparse", required="Logging scipy matrices requires scipy")
    # Requested only as a dependency check.
    util.get_module(
        "sklearn", required="Logging plots matrices requires scikit-learn")

    def is_number(val):
        return isinstance(val, (int, float, complex, np.number))

    test_passed = True
    for k, v in kwargs.items():
        # Missing/empty params/datapoint arrays
        if v is None:
            wandb.termerror("%s is None. Please try again." % (k))
            test_passed = False
            # Nothing else can be checked on None; the array checks below
            # would otherwise fail on ``v.ndim``.
            continue

        if k not in ("X", "X_test"):
            continue

        if isinstance(v, sparse.csr_matrix):
            v = v.toarray()
        elif isinstance(v, (pd.DataFrame, pd.Series)):
            v = v.to_numpy()
        elif isinstance(v, list):
            v = np.asarray(v)

        # Warn the user about missing values
        missing = np.count_nonzero(pd.isnull(v))
        if missing > 0:
            wandb.termwarn("%s contains %d missing values. " % (k, missing))
            test_passed = False

        # Ensure the dataset contains only numbers
        if v.ndim == 1:
            non_nums = sum(1 for val in v if not is_number(val))
        else:
            non_nums = sum(
                1 for sl in v for val in sl if not is_number(val))
        if non_nums > 0:
            wandb.termerror(
                "%s contains values that are not numbers. Please vectorize, label encode or one hot encode %s and call the plotting function again."
                % (k, k))
            test_passed = False
    return test_passed
def init_gcs(self):
    """Return the Google Cloud Storage client, creating it on first use.

    The client is built with ``google.cloud.storage.Client()`` and cached
    on ``self._client``; later calls return the cached instance.
    """
    if self._client is None:
        storage = util.get_module(
            "google.cloud.storage",
            required="gs:// references requires the google-cloud-storage library, run pip install wandb[gcp]",
        )
        self._client = storage.Client()
    return self._client
def __init__(self, data_or_path, mode=None, caption=None, grouping=None):
    """Create an Image from a file path or from in-memory image data.

    A string is treated as the path of an existing image file. Anything
    else is converted to a PIL image - matplotlib objects, PIL images,
    PyTorch tensors (via ``torchvision.utils.make_grid``) and array-likes
    are handled - then saved as a temporary PNG that backs the media
    object.

    Args:
        data_or_path: file path or image data.
        mode: PIL mode used for array data; guessed when omitted.
        caption: optional caption stored on the object.
        grouping: stored on the object (see TODO below).
    """
    # TODO: We should remove grouping, it's a terrible name and I don't
    # think anyone uses it.
    self._grouping = grouping
    self._caption = caption
    self._width = None
    self._height = None
    self._image = None

    if isinstance(data_or_path, six.string_types):
        super(Image, self).__init__(data_or_path, is_tmp=False)
        return

    data = data_or_path

    PILImage = util.get_module(
        "PIL.Image", required='wandb.Image needs the PIL package. To get it, run "pip install pillow".')
    if util.is_matplotlib_typename(util.get_full_typename(data)):
        # Render the figure into an in-memory buffer and read it back.
        buf = six.BytesIO()
        util.ensure_matplotlib_figure(data).savefig(buf)
        self._image = PILImage.open(buf)
    elif isinstance(data, PILImage.Image):
        self._image = data
    elif util.is_pytorch_tensor_typename(util.get_full_typename(data)):
        vis_util = util.get_module(
            "torchvision.utils", "torchvision is required to render images")
        if hasattr(data, "requires_grad") and data.requires_grad:
            data = data.detach()
        grid = vis_util.make_grid(data, normalize=True)
        pixels = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
        self._image = PILImage.fromarray(pixels)
    else:
        if hasattr(data, "numpy"):  # TF data eager tensors
            data = data.numpy()
        if data.ndim > 2:
            # get rid of trivial dimensions as a convenience
            data = data.squeeze()
        self._image = PILImage.fromarray(
            self.to_uint8(data), mode=mode or self.guess_mode(data))

    self._width, self._height = self._image.size

    png_path = os.path.join(MEDIA_TMP.name, util.generate_id() + '.png')
    self._image.save(png_path, transparency=None)
    super(Image, self).__init__(png_path, is_tmp=True)
def part_of_speech(docs):
    """Render spaCy's dependency visualizer for ``docs``.

    The visualizer shows part-of-speech tags and syntactic dependencies.
    ``docs`` is converted with ``str`` and parsed with the
    ``en_core_web_md`` model before rendering.

    Arguments:
        docs (list, Doc, Span): Document(s) to visualize.

    Returns:
        A ``wandb.Html`` object with the rendered page, or ``None`` when
        ``test_missing`` rejects the input.

    Example:
        wandb.log({'part_of_speech': wandb.plots.POS(docs=doc)})
    """
    deprecation_notice()
    spacy = util.get_module(
        "spacy",
        required="part_of_speech requires the spacy library, install with `pip install spacy`"
    )
    en_core_web_md = util.get_module(
        "en_core_web_md",
        required="part_of_speech requires the en_core_web_md library, install with `python -m spacy download en_core_web_md`"
    )
    nlp = en_core_web_md.load()

    if not test_missing(docs=docs):
        return None

    wandb.termlog('Visualizing part of speech.')
    render_options = {
        "compact": True,
        "color": "#1a1c1f",
        "font": "Source Sans Pro",
        "collapse_punct": True,
        "collapse_phrases": True
    }
    parsed = nlp(str(docs))
    html = spacy.displacy.render(
        parsed, style='dep', minify=True, options=render_options, page=True)
    return wandb.Html(html)
def hook_torch_modules(
    self, module, criterion=None, prefix=None, graph_idx=0, parent=None
):
    """Register forward hooks on the children of a torch module.

    Container-like children (``Container``, ``Sequential``, ``ModuleList``,
    ``ModuleDict`` - whichever the installed torch provides) are recursed
    into; every other child gets a forward hook created by
    ``create_forward_hook``. Hook handles are stored in
    ``wandb.run.history.torch._hook_handles`` and their names are appended
    to ``parent._wandb_hook_names``.

    Args:
        module: the torch module whose children are hooked.
        criterion: optional criterion stored on the graph when truthy.
        prefix: dotted name prefix for the children.
        graph_idx: index passed to ``create_forward_hook``.
        parent: module that collects the hook names; defaults to ``module``.

    Raises:
        ValueError: if ``module`` was already passed through here before.
    """
    torch = util.get_module("torch", "Could not import torch")

    if getattr(module, "_wandb_watch_called", False):
        raise ValueError(
            "You can only call `wandb.watch` once per model.  Pass a new instance of the model if you need to call wandb.watch again in your code."
        )
    module._wandb_watch_called = True

    if criterion:
        self.criterion = criterion
        self.criterion_passed = True

    # Trying to support torch >0.3 making this code complicated
    # We want a list of types that we should recurse into
    # Torch 0.3   uses containers
    #       0.4   has ModuleList
    #       0.4.1 has ModuleDict
    container_names = ("Container", "Sequential", "ModuleList", "ModuleDict")
    container_types = tuple(
        getattr(torch.nn, classname)
        for classname in container_names
        if hasattr(torch.nn, classname)
    )

    if parent is None:
        parent = module

    for layers, (name, sub_module) in enumerate(module.named_children()):
        # Unnamed children are labelled with their position.
        name = name or str(layers)
        if prefix:
            name = prefix + "." + name

        if not isinstance(sub_module, torch.nn.Module):
            # TODO: Why does this happen?
            break

        if isinstance(sub_module, container_types):
            # NOTE(review): graph_idx is not forwarded to the recursive
            # call, so nested children use the default - confirm intended.
            self.hook_torch_modules(sub_module, prefix=name, parent=parent)
            continue

        self._graph_hooks |= {id(sub_module)}
        graph_hook = sub_module.register_forward_hook(
            self.create_forward_hook(name, graph_idx)
        )
        hook_name = "topology/" + str(id(graph_hook))
        wandb.run.history.torch._hook_handles[hook_name] = graph_hook
        if not hasattr(parent, "_wandb_hook_names"):
            # should never happen but let's be extra safe
            parent._wandb_hook_names = []
        parent._wandb_hook_names.append(hook_name)
def __init__(self):
    """Initialise the torch graph and decide which backward hooks to use.

    ``_should_use_full_hooks`` is True unless fastai has already been
    imported and its version string starts with ``"1."``.
    """
    super(TorchGraph, self).__init__("torch")

    # Switching to register_full_backward_hook broke a regression test that
    # runs fastai v1. To maximize compatibility, full backward hooks are
    # not used when fastai v1 is detected among the imported modules :(
    fastai_v1_loaded = False
    if "fastai" in sys.modules:
        fastai = util.get_module("fastai")
        fastai_v1_loaded = fastai.__version__.startswith("1.")
    self._should_use_full_hooks = not fastai_v1_loaded
def test_types(**kwargs):
    """Check that keyword arguments have the types the plots expect.

    Array-like keys (``X``, ``y``, ``y_probas``, ...) must be sequences,
    iterables, numpy or pandas objects. ``model`` must be a classifier or
    regressor, ``clf``/``binary_clf`` a classifier, ``regressor`` a
    regressor and ``clusterer`` an estimator whose ``_estimator_type`` is
    ``"clusterer"``. Every violation is reported with ``wandb.termerror``.

    Returns:
        bool: True when every checked argument passed, False otherwise.
    """
    np = util.get_module("numpy", required="Logging plots requires numpy")
    pd = util.get_module("pandas", required="Logging dataframes requires pandas")
    # scipy is requested only as a dependency check.
    util.get_module("scipy", required="Logging scipy matrices requires scipy")
    scikit = util.get_module(
        "sklearn", required="Logging plots matrices requires scikit-learn")

    array_keys = (
        'X', 'X_test', 'y', 'y_test', 'y_true', 'y_probas',
        'x_labels', 'y_labels', 'matrix_values',
    )
    array_types = (Sequence, Iterable, np.ndarray, np.generic,
                   pd.DataFrame, pd.Series, list)

    test_passed = True
    for k, v in kwargs.items():
        # check for incorrect types
        # FIXME: do this individually
        if k in array_keys and not isinstance(v, array_types):
            wandb.termerror("%s is not an array. Please try again." % (k))
            test_passed = False

        # check for classifier types
        if k == 'model':
            is_estimator = scikit.base.is_classifier(v) or scikit.base.is_regressor(v)
            if not is_estimator:
                wandb.termerror(
                    "%s is not a classifier or regressor. Please try again."
                    % (k))
                test_passed = False
        elif k in ('clf', 'binary_clf'):
            if not scikit.base.is_classifier(v):
                wandb.termerror("%s is not a classifier. Please try again." % (k))
                test_passed = False
        elif k == 'regressor':
            if not scikit.base.is_regressor(v):
                wandb.termerror("%s is not a regressor. Please try again." % (k))
                test_passed = False
        elif k == 'clusterer':
            if getattr(v, "_estimator_type", None) != "clusterer":
                wandb.termerror("%s is not a clusterer. Please try again." % (k))
                test_passed = False
    return test_passed
def mock_boto(artifact, path=False):
    """Install a fake S3 resource on the artifact's ``s3`` handler.

    The fake resource returns canned objects and buckets. When ``path`` is
    truthy, ``S3Object.load`` raises a botocore ``ClientError`` with code
    ``"404"``. The handler's ``_botocore`` attribute is set to the real
    botocore module with its ``exceptions`` submodule attached.

    Returns:
        The fake S3 resource that was installed.
    """
    class S3Object(object):
        def __init__(self, name="my_object.pb", metadata=None):
            self.metadata = metadata or {"md5": "1234567890abcde"}
            self.e_tag = '"1234567890abcde"'
            self.version_id = "1"
            self.name = name
            self.key = name
            self.content_length = 10

        def load(self):
            if not path:
                return
            botocore = util.get_module("botocore")
            error_response = {"Error": {"Code": "404"}}
            raise botocore.exceptions.ClientError(error_response, "HeadObject")

    class Filtered(object):
        def limit(self, *args, **kwargs):
            first = S3Object()
            second = S3Object(name="my_other_object.pb")
            return [first, second]

    class S3Objects(object):
        def filter(self, **kwargs):
            return Filtered()

    class S3Bucket(object):
        def __init__(self, *args, **kwargs):
            self.objects = S3Objects()

    class S3Resource(object):
        def Object(self, bucket, key):
            return S3Object()

        def Bucket(self, bucket):
            return S3Bucket()

    mock = S3Resource()
    s3_handler = artifact._storage_policy._handler._handlers["s3"]
    s3_handler._s3 = mock
    s3_handler._botocore = util.get_module("botocore")
    # Make sure ``botocore.exceptions`` is reachable as an attribute.
    s3_handler._botocore.exceptions = util.get_module("botocore.exceptions")
    return mock
def confusion_matrix(preds=None, y_true=None, class_names=None):
    """Compute a multi-run confusion matrix.

    Arguments:
        preds (arr): Array of predicted label indices (0-based).
        y_true (arr): Array of label indices (0-based).
        class_names (arr): Array of class names. When omitted, names
            ``Class_1`` ... ``Class_N`` are generated, where N is the
            highest index seen plus one.

    Returns:
        The result of ``wandb.plot_table`` for the
        ``wandb/confusion_matrix/v0`` preset, backed by a table with the
        columns ``Actual``, ``Predicted`` and ``Count``.

    Raises:
        AssertionError: if ``preds`` and ``y_true`` differ in length or
            contain an index that has no entry in ``class_names``.

    Example:
        wandb.log({'pr': wandb.plot.confusion_matrix(preds, y_true, labels)})
    """
    np = util.get_module(
        "numpy",
        required="confusion matrix requires the numpy library, install with `pip install numpy`",
    )

    assert len(preds) == len(
        y_true), "Number of predictions and label indices must match"

    # ``default=-1`` lets empty inputs yield an empty/zero matrix instead
    # of failing inside ``max``.
    highest_pred = max(preds, default=-1)
    highest_true = max(y_true, default=-1)

    if class_names is not None:
        n_classes = len(class_names)
        # Indices are 0-based, so the highest valid index is n_classes - 1.
        assert highest_pred < n_classes, "Higher predicted index than number of classes"
        assert highest_true < n_classes, "Higher label class index than number of classes"
    else:
        # Highest 0-based index plus one gives the number of classes.
        n_classes = max(highest_pred, highest_true) + 1
        class_names = ["Class_{}".format(i) for i in range(1, n_classes + 1)]

    counts = np.zeros((n_classes, n_classes))
    for actual, predicted in zip(y_true, preds):
        counts[actual, predicted] += 1

    data = [
        [class_names[i], class_names[j], counts[i, j]]
        for i in range(n_classes)
        for j in range(n_classes)
    ]

    fields = {
        "Actual": "Actual",
        "Predicted": "Predicted",
        "nPredicted": "Count"
    }

    return wandb.plot_table(
        "wandb/confusion_matrix/v0",
        wandb.Table(columns=["Actual", "Predicted", "Count"], data=data),
        fields,
    )
def node_from_module(cls, nid, module):
    """Build a ``wandb.Node`` describing ``module``.

    The node gets ``nid`` as its id, the class name of ``module`` and the
    total element count over ``module.parameters()`` (the product of each
    parameter's size, summed).
    """
    numpy = util.get_module("numpy", "Could not import numpy")

    node = wandb.Node()
    node.id = nid
    node.child_parameters = sum(
        numpy.prod(param.size()) for param in module.parameters())
    node.class_name = type(module).__name__

    return node
def __init__(self, data, mode=None, caption=None, grouping=None):
    """Create an Image from in-memory image data.

    Accepts a matplotlib object, a PIL image, a PyTorch tensor (rendered
    with ``torchvision.utils.make_grid``) or array-like data, and converts
    it to a PIL image stored on ``self.image``.

    Args:
        data: the image data.
        mode: PIL mode used for array data; guessed when omitted.
        caption: optional caption stored on the object.
        grouping: if set to a number the interface combines N images.
    """
    PILImage = util.get_module(
        "PIL.Image",
        required="wandb.Image requires the PIL package, to get it run: pip install pillow"
    )
    if util.is_matplotlib_typename(util.get_full_typename(data)):
        # Render the figure into an in-memory buffer and read it back.
        buf = six.BytesIO()
        util.ensure_matplotlib_figure(data).savefig(buf)
        self.image = PILImage.open(buf)
    elif isinstance(data, PILImage.Image):
        self.image = data
    elif util.is_pytorch_tensor_typename(util.get_full_typename(data)):
        vis_util = util.get_module(
            "torchvision.utils", "torchvision is required to render images")
        if hasattr(data, "requires_grad") and data.requires_grad:
            data = data.detach()
        grid = vis_util.make_grid(data, normalize=True)
        pixels = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
        self.image = PILImage.fromarray(pixels)
    else:
        # Handle TF eager tensors
        if hasattr(data, "numpy"):
            data = data.numpy()
        # get rid of trivial dimensions as a convenience
        data = data.squeeze()
        self.image = PILImage.fromarray(
            self.to_uint8(data), mode=mode or self.guess_mode(data))

    self.grouping = grouping
    self.caption = caption
def plot_to_json(obj):
    """Convert a matplotlib or plotly object to a JSON-friendly dict.

    Matplotlib objects are first converted to plotly with
    ``plotly.tools.mpl_to_plotly``. Plotly objects are returned as
    ``{"_type": "plotly", "plot": ...}``; anything else is returned
    unchanged.
    """
    if util.is_matplotlib_typename(util.get_full_typename(obj)):
        tools = util.get_module(
            "plotly.tools",
            required="plotly is required to log interactive plots, install with: pip install plotly or convert the plot to an image with `wandb.Image(plt)`"
        )
        obj = tools.mpl_to_plotly(obj)

    if not util.is_plotly_typename(util.get_full_typename(obj)):
        return obj
    return {"_type": "plotly", "plot": obj.to_plotly_json()}
def plot_to_json(obj):
    """Convert a matplotlib or plotly object to JSON for the wandb server.

    Matplotlib objects are first converted to plotly with
    ``plotly.tools.mpl_to_plotly``. Plotly objects are returned as
    ``{"_type": "plotly", "plot": ...}`` with the plotly JSON passed
    through ``numpy_arrays_to_lists``; anything else is returned unchanged.
    """
    if util.is_matplotlib_typename(util.get_full_typename(obj)):
        tools = util.get_module(
            "plotly.tools",
            required="plotly is required to log interactive plots, install with: pip install plotly or convert the plot to an image with `wandb.Image(plt)`")
        obj = tools.mpl_to_plotly(obj)

    if not util.is_plotly_typename(util.get_full_typename(obj)):
        return obj

    plot_json = numpy_arrays_to_lists(obj.to_plotly_json())
    return {"_type": "plotly", "plot": plot_json}