Esempio n. 1
0
def _rmse(y_true, y_pred):
    """Return the root-mean-squared error between ``y_true`` and ``y_pred``.

    ``mean_squared_error(..., squared=False)`` was deprecated in scikit-learn
    1.4 and removed in 1.6. Use ``root_mean_squared_error`` when available
    and fall back to the old keyword on older releases (same semantics).
    """
    try:
        from sklearn.metrics import root_mean_squared_error
    except ImportError:  # scikit-learn < 1.4
        return mean_squared_error(y_true, y_pred, squared=False)
    return root_mean_squared_error(y_true, y_pred)


def main(_):
    """Train a model, evaluate its train/test RMSE and report the results.

    Configuration is read from the module-level ``FLAGS``. When ``FLAGS.log``
    is set, the fitted model is dumped with joblib to
    ``ARTIFACT_DIR/models/model.joblib`` and uploaded as a W&B ``model``
    artifact, and both RMSE values are logged to W&B. Otherwise the RMSE
    values are printed.
    """
    if FLAGS.log:
        wandb.init(config=FLAGS, **FLAGS.wandb)

    # Pipeline
    ## Setup
    set_seed()

    ## Data
    X_train, X_test, y_train, y_test = load_train_test_splits()

    ## Train model
    model = load_model()
    model.fit(X_train, y_train)

    ## Evaluate
    # "Residual" RMSE is measured on the same data the model was fit on.
    y_fit = model.predict(X_train)
    residual_rmse = _rmse(y_train, y_fit)
    y_pred = model.predict(X_test)
    test_rmse = _rmse(y_test, y_pred)

    # Log
    if FLAGS.log:
        # Log model
        model_artifact = wandb.Artifact('model', type='model')
        model_path = ARTIFACT_DIR / "models" / "model.joblib"
        # joblib.dump does not create missing parent directories.
        model_path.parent.mkdir(parents=True, exist_ok=True)
        dump(model, model_path, compress=3)
        model_artifact.add_file(str(model_path))
        wandb.log_artifact(model_artifact)

        # Log eval
        wandb.log({'residual_rmse': residual_rmse, 'test_rmse': test_rmse})
    else:
        print(f"Residual RMSE: {residual_rmse}\t Test RMSE: {test_rmse}")
Esempio n. 2
0
def save_models_to_artifact(cfg, workers, stage, metadata, filename=None):
    """Bundle each worker's model and optimizer checkpoints into one W&B artifact.

    For every worker, ``<tmp>/<filename>.pth`` and ``<tmp>/<filename>_optim.pth``
    are added under names encoding the worker's rank, model variant, model
    mapping and ``stage``. ``filename`` defaults to ``stage``.

    Returns the logged ``wandb.Artifact``.
    """
    base_name = stage if filename is None else filename
    artifact = wandb.Artifact(f"{stage}-{cfg['model_variant']}",
                              type='model',
                              metadata=metadata)

    for worker in workers:
        wcfg = worker.cfg
        stored_prefix = (f"{wcfg['rank']}-v{wcfg['model_variant']}"
                         f"-m{wcfg['model_mapping']}-{stage}")
        # Model weights first, then the optimizer state.
        for suffix in ("", "_optim"):
            artifact.add_file(str(wcfg['tmp'] / f"{base_name}{suffix}.pth"),
                              f"{stored_prefix}{suffix}.pth")

    wandb.log_artifact(artifact)
    try:
        # wait() throws an exception when running in offline mode.
        artifact.wait()
        print(f'Model: Save "{stage}" models as version {artifact.version}')
    except Exception:
        print(f'Model: Save "{stage}" models in offline mode')
    return artifact
Esempio n. 3
0
    def on_train_end(self, last, best, plots, epoch, results):
        """Publish final plots, metrics and weights when training ends.

        - Optionally re-renders ``results.png`` from ``results.csv``.
        - Sends the summary plots that exist in ``save_dir`` to TensorBoard
          and W&B.
        - Logs ``results`` to W&B under the keys ``self.keys[3:10]``.
        - Unless ``opt.evolve`` is set, uploads ``best`` (or ``last`` if
          ``best`` is missing) as a W&B model artifact.
        - Finishes the W&B run.
        """
        if plots:
            plot_results(file=self.save_dir / 'results.csv')  # save results.png

        candidates = ['results.png', 'confusion_matrix.png']
        candidates += [f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]
        files = []
        for candidate in candidates:
            plot_path = self.save_dir / candidate
            if plot_path.exists():
                files.append(plot_path)

        if self.tb:
            for plot_path in files:
                rgb = cv2.imread(str(plot_path))[..., ::-1]  # OpenCV loads BGR
                self.tb.add_image(plot_path.stem, rgb, epoch, dataformats='HWC')

        if not self.wandb:
            return

        final_metrics = dict(zip(self.keys[3:10], results))
        self.wandb.log(final_metrics)  # log best.pt val results
        images = [wandb.Image(str(p), caption=p.name) for p in files]
        self.wandb.log({"Results": images})
        # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
        if not self.opt.evolve:
            weights = best if best.exists() else last
            wandb.log_artifact(str(weights),
                               type='model',
                               name='run_' + self.wandb.wandb_run.id + '_model',
                               aliases=['latest', 'best', 'stripped'])
        self.wandb.finish_run()
Esempio n. 4
0
    def log_model(self, path, opt, epoch, fitness_score, best_model=False):
        """
        Log the model checkpoint as W&B artifact

        The artifact is named ``run_<run id>_model`` and contains ``last.pt``.
        The version is aliased ``latest``, ``last`` and
        ``epoch <self.current_epoch>``. ``best`` is added only when
        ``best_model`` is true.

        arguments:
        path (Path)   -- Path of directory containing the checkpoints
        opt (namespace) -- Command line arguments for this run
        epoch (int)  -- Current epoch number
        fitness_score (float) -- fitness score for current epoch
        best_model (boolean) -- Boolean representing if the current checkpoint is the best yet.
        """
        model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model',
                                        type='model',
                                        metadata={
                                            'original_url': str(path),
                                            'epochs_trained': epoch + 1,
                                            'save period': opt.save_period,
                                            'project': opt.project,
                                            'total_epochs': opt.epochs,
                                            'fitness_score': fitness_score
                                        })
        model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
        aliases = ['latest', 'last', 'epoch ' + str(self.current_epoch)]
        if best_model:
            # Append conditionally instead of sending a meaningless '' alias.
            aliases.append('best')
        wandb.log_artifact(model_artifact, aliases=aliases)
        LOGGER.info(f"Saving model artifact on epoch {epoch + 1}")
Esempio n. 5
0
    def end_epoch(self, best_result=False):
        """Commit the log_dict, model artifacts and Tables to W&B and flush the log_dict.

        If ``wandb.log`` fails, the error is reported through ``LOGGER.info``.
        The W&B run is then finished and ``self.wandb_run`` is set to ``None``,
        so training continues without W&B. The result artifact is not logged
        in that case.

        arguments:
        best_result (boolean): Boolean representing if the result of this evaluation is best or not
        """
        if not self.wandb_run:
            return

        with all_logging_disabled():
            if self.bbox_media_panel_images:
                self.log_dict["BoundingBoxDebugger"] = self.bbox_media_panel_images
            try:
                wandb.log(self.log_dict)
            # Exception (not BaseException) so Ctrl-C / SystemExit still stop training.
            except Exception as e:
                LOGGER.info(
                    f"An error occurred in wandb logger. The training will proceed without interruption. More info\n{e}"
                )
                self.wandb_run.finish()
                self.wandb_run = None

            self.log_dict = {}
            self.bbox_media_panel_images = []

        # The run may have just been finished above. After that, wandb.run is
        # no longer usable, so skip the artifact upload.
        if not self.wandb_run:
            return

        if self.result_artifact:
            self.result_artifact.add(self.result_table, 'result')
            aliases = ['latest', 'last', 'epoch ' + str(self.current_epoch)]
            if best_result:
                aliases.append('best')  # avoid sending an empty alias otherwise
            wandb.log_artifact(self.result_artifact, aliases=aliases)

            wandb.log({"evaluation": self.result_table})
            # Start a fresh table/artifact for the next epoch.
            columns = ["epoch", "id", "ground truth", "prediction"]
            columns.extend(self.data_dict['names'])
            self.result_table = wandb.Table(columns)
            self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
Esempio n. 6
0
def main():
    """Create train/validation split files for MNIST and FashionMNIST.

    For each dataset, the training set is downloaded into ``args.data_dir``
    and passed to ``generate_splits`` with sizes ``{'train': 50000,
    'valid': 10000}``. The output is written to ``splits.pt`` inside a
    ``dataset_split`` W&B artifact (``mnist`` / ``fashion_mnist``).
    """
    description = '''
        Generate train/validation splits for the MNIST and FashionMNIST datasets.
    '''
    parser = ArgumentParser(description=description,
                            parents=[MNISTData.add_data_args()])
    args = parser.parse_args()

    splits_sizes = {'train': 50000, 'valid': 10000}

    wandb.init('dataset_splits')

    datasets = (('mnist', MNIST), ('fashion_mnist', FashionMNIST))
    for artifact_name, dataset_cls in datasets:
        artifact = wandb.Artifact(artifact_name, type='dataset_split')
        with artifact.new_file('splits.pt', mode='wb') as out:
            full_train = dataset_cls(args.data_dir, train=True, download=True)
            generate_splits(full_train, splits_sizes, out)
        wandb.log_artifact(artifact)
Esempio n. 7
0
    def on_train_end(self, last, best, plots, epoch):
        """Publish final plots and weights when training ends.

        Optionally re-renders ``results.png`` in ``save_dir``. The summary
        plots that exist there are sent to TensorBoard (read with PIL) and
        W&B. ``best`` (or ``last`` if ``best`` is missing) is uploaded as a
        W&B model artifact before the W&B run is finished.
        """
        if plots:
            plot_results(dir=self.save_dir)  # save results.png

        plot_names = ['results.png', 'confusion_matrix.png'] + [
            f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')
        ]
        candidates = [self.save_dir / plot_name for plot_name in plot_names]
        files = [p for p in candidates if p.exists()]  # filter

        if self.tb:
            from PIL import Image
            import numpy as np
            for plot_path in files:
                pixels = np.asarray(Image.open(plot_path))
                self.tb.add_image(plot_path.stem, pixels, epoch, dataformats='HWC')

        if not self.wandb:
            return

        images = [wandb.Image(str(p), caption=p.name) for p in files]
        self.wandb.log({"Results": images})
        # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
        weights = best if best.exists() else last
        wandb.log_artifact(str(weights),
                           type='model',
                           name='run_' + self.wandb.wandb_run.id + '_model',
                           aliases=['latest', 'best', 'stripped'])
        self.wandb.finish_run()
Esempio n. 8
0
    def end_epoch(self, best_result=False):
        """Flush this epoch's metrics, debug images and evaluation table to W&B.

        Pending bounding-box debug images are added to ``log_dict``, which is
        logged and reset. If a result artifact is being tracked, the current
        table is added to it and the artifact is logged with aliases
        ``latest``, ``last`` and ``epoch <n>`` (plus ``best`` when
        ``best_result`` is true). A fresh table and artifact are then created
        for the next epoch.
        """
        if self.wandb_run:
            with all_logging_disabled():
                if self.bbox_media_panel_images:
                    self.log_dict[
                        "Bounding Box Debugger/Images"] = self.bbox_media_panel_images
                wandb.log(self.log_dict)
                self.log_dict = {}
                self.bbox_media_panel_images = []
            if self.result_artifact:
                self.result_artifact.add(self.result_table, 'result')
                aliases = ['latest', 'last', 'epoch ' + str(self.current_epoch)]
                if best_result:
                    # Append conditionally instead of sending an empty '' alias.
                    aliases.append('best')
                wandb.log_artifact(self.result_artifact, aliases=aliases)

                wandb.log({"evaluation": self.result_table})
                self.result_table = wandb.Table([
                    "epoch", "id", "ground truth", "prediction",
                    "avg_confidence"
                ])
                self.result_artifact = wandb.Artifact(
                    "run_" + wandb.run.id + "_progress", "evaluation")
Esempio n. 9
0
    def end_epoch(self, best_result=False):
        """
        commit the log_dict, model artifacts and Tables to W&B and flush the log_dict.

        If a result artifact is being tracked, it is logged with aliases
        ``latest``, ``last`` and ``epoch <n>``. ``best`` is added only when
        ``best_result`` is true. A fresh table and artifact are then created
        for the next epoch.

        arguments:
        best_result (boolean): Boolean representing if the result of this evaluation is best or not
        """
        if self.wandb_run:
            with all_logging_disabled():
                if self.bbox_media_panel_images:
                    self.log_dict[
                        "Bounding Box Debugger/Images"] = self.bbox_media_panel_images
                wandb.log(self.log_dict)
                self.log_dict = {}
                self.bbox_media_panel_images = []
            if self.result_artifact:
                self.result_artifact.add(self.result_table, 'result')
                aliases = ['latest', 'last', 'epoch ' + str(self.current_epoch)]
                if best_result:
                    # Append conditionally instead of sending an empty '' alias.
                    aliases.append('best')
                wandb.log_artifact(self.result_artifact, aliases=aliases)

                wandb.log({"evaluation": self.result_table})
                self.result_table = wandb.Table([
                    "epoch", "id", "ground truth", "prediction",
                    "avg_confidence"
                ])
                self.result_artifact = wandb.Artifact(
                    "run_" + wandb.run.id + "_progress", "evaluation")
Esempio n. 10
0
def write_file(kind, metrics, output_dir, save_artifact=False):
    """Write ``metrics`` to ``<output_dir>/<kind>_results.txt`` and log them to W&B.

    Each ``key = value`` pair is echoed through ``printm`` and written to the
    results file. A leading ``eval_`` is stripped from each key to form its
    title. Titles starting with ``label_`` (prefix removed) are collected into
    a ``"{kind}:labels"`` W&B table; all others go into a ``kind`` table.
    Every metric is also logged as the scalar ``"{kind}:{title}"``.

    Args:
        kind: label for this result set, used in the file name and W&B keys.
        metrics: mapping of metric name to value.
        output_dir: ``pathlib.Path``-like directory for the results file.
        save_artifact: also upload the results file as a ``result`` artifact.
    """
    output_file = output_dir / f"{kind}_results.txt"
    headers, data = [], []
    label_headers, label_data = [], []
    scalars = {}
    with open(output_file, "w", encoding="utf-8") as writer:
        printm(f"**{kind.capitalize()} results**")
        for key, value in metrics.items():
            printm(f"\t{key} = {value}")
            writer.write(f"{key} = {value}\n")
            # Strip "eval_" only as a prefix; str.replace would also remove it
            # from the middle of a key.
            title = key[len("eval_"):] if key.startswith("eval_") else key
            if title.startswith("label_"):
                label_headers.append(title[len("label_"):])
                label_data.append(value)
            else:
                headers.append(title)
                data.append(value)
            scalars[f"{kind}:{title}"] = value

    # Each wandb.log call advances the W&B step. Logging metrics one by one
    # would spread a single evaluation across many steps, so log all scalars
    # together, then the summary tables together.
    if scalars:
        wandb.log(scalars)
    tables = {kind: wandb.Table(data=[data], columns=headers)}
    if label_headers:
        tables[f"{kind}:labels"] = wandb.Table(data=[label_data],
                                               columns=label_headers)
    wandb.log(tables)

    if save_artifact:
        artifact = wandb.Artifact(kind, type="result")
        artifact.add_file(str(output_file))
        wandb.log_artifact(artifact)
Esempio n. 11
0
 def log_dataset_artifact(self, dataset, class_to_id, name='dataset'):
     """Upload ``dataset`` as a W&B ``dataset`` artifact.

     The artifact contains the image directory (``data/images``), a W&B table
     of annotated images, and a zip of the labels directory
     (``data/labels.zip``).

     Args:
         dataset: iterable yielding ``(img, labels, paths, shapes)`` tuples;
             ``dataset.path`` is the image directory.
         class_to_id: despite its name, used as an ``{id: class name}``
             mapping (keys are integer class ids, values are names).
         name: artifact name. It is also the table's name inside the artifact
             and the prefix of the ``<name>_labels.zip`` file.
     """
     artifact = wandb.Artifact(name=name, type="dataset")
     image_path = dataset.path
     artifact.add_dir(image_path, name='data/images')
     table = wandb.Table(columns=["id", "train_image", "Classes"])
     # The comprehension's ``id``/``name`` are local to it (Python 3), so the
     # ``name`` parameter is still intact for ``artifact.add`` below.
     class_set = wandb.Classes([{'id': id, 'name': name} for id, name in class_to_id.items()])
     for si, (img, labels, paths, shapes) in enumerate(dataset):
         height, width = shapes[0]
         # Column 1 of ``labels`` is the class id; columns 2: hold the box.
         # NOTE(review): assumes boxes arrive as normalized xywh, which are
         # converted to xyxy and scaled to pixels here — confirm with the dataset.
         # ``labels`` is modified in place.
         labels[:, 2:] = (xywh2xyxy(labels[:, 2:].view(-1, 4)))
         labels[:, 2:] *= torch.Tensor([width, height, width, height])
         box_data = []
         img_classes = {}
         for cls, *xyxy in labels[:, 1:].tolist():
             cls = int(cls)
             box_data.append({"position": {"minX": xyxy[0], "minY": xyxy[1], "maxX": xyxy[2], "maxY": xyxy[3]},
                              "class_id": cls,
                              "box_caption": "%s" % (class_to_id[cls]),
                              "scores": {"acc": 1},
                              "domain": "pixel"})
             img_classes[cls] = class_to_id[cls]
         boxes = {"ground_truth": {"box_data": box_data, "class_labels": class_to_id}}  # inference-space
         # One row per sample: index, annotated image, JSON of the classes present.
         table.add_data(si, wandb.Image(paths, classes=class_set, boxes=boxes), json.dumps(img_classes))
     artifact.add(table, name)
     # Labels dir = image path with its last "images" component replaced by
     # "labels" (unchanged if "images" does not occur).
     labels_path = 'labels'.join(image_path.rsplit('images', 1))
     zip_path = Path(labels_path).parent / (name + '_labels.zip')
     if not zip_path.is_file():  # make_archive won't check if file exists
         shutil.make_archive(zip_path.with_suffix(''), 'zip', labels_path)
     artifact.add_file(str(zip_path), name='data/labels.zip')
     wandb.log_artifact(artifact)
     print("Saving data to W&B...")
Esempio n. 12
0
def test_log_artifact_simple(runner, wandb_init_run):
    """Logging a directory requires a type and prefixes the run id to its name.

    Without ``type``, ``wandb.log_artifact`` must raise ``ValueError``. With
    ``type``, the artifact is named ``run-<run id>-<dir name>``.
    """
    util.mkdir_exists_ok("artsy")
    # Context managers guarantee the files are flushed and closed before
    # log_artifact reads the directory (and avoid ResourceWarnings).
    with open("artsy/file1.txt", "w") as f:
        f.write("hello")
    with open("artsy/file2.txt", "w") as f:
        f.write("goodbye")
    with pytest.raises(ValueError):
        wandb.log_artifact("artsy")
    art = wandb.log_artifact("artsy", type="dataset")
    assert art.name == "run-" + wandb_init_run.id + "-artsy"
Esempio n. 13
0
    def _log_model_as_artifact(self, model):
        """Save ``model`` as ``<run id>_model.json`` in the run dir and upload it.

        The file is logged in a W&B ``model`` artifact that carries the same
        name as the file.
        """
        run = wandb.run
        file_name = f"{run.id}_model.json"
        target = Path(run.dir) / file_name
        model.save_model(str(target))

        artifact = wandb.Artifact(name=file_name, type="model")
        artifact.add_file(target)
        wandb.log_artifact(artifact)
Esempio n. 14
0
    def log_artifact(self, name, type, file_or_dir, metadata=None, aliases=None):
        """Upload a single file or a whole directory to W&B as an artifact.

        Args:
            name: artifact name.
            type: artifact type (e.g. "model", "dataset").
            file_or_dir: path to a file (added with ``add_file``) or a
                directory (added with ``add_dir``).
            metadata: optional metadata dict stored on the artifact.
            aliases: optional list of aliases for the logged version.
        """
        # None sentinels: a literal {} / [] default would be one object shared
        # by every call.
        metadata = {} if metadata is None else metadata
        aliases = [] if aliases is None else aliases

        artifact = wandb.Artifact(name=name, type=type, metadata=metadata)
        path_to_log = Path(file_or_dir)
        if path_to_log.is_file():
            artifact.add_file(file_or_dir)
        elif path_to_log.is_dir():
            artifact.add_dir(file_or_dir)
        # NOTE(review): a path that is neither a file nor a directory yields
        # an empty artifact; consider raising instead.

        wandb.log_artifact(artifact, aliases=aliases)
Esempio n. 15
0
File: loggers.py Progetto: csr/spaCy
 def log_dir_artifact(
     path: str,
     name: str,
     type: str,
     metadata: Optional[Dict[str, Any]] = None,
     aliases: Optional[List[str]] = None,
 ):
     """Upload the directory at ``path`` to W&B as an artifact.

     The directory's contents are stored under ``name`` inside the artifact.
     ``metadata`` and ``aliases`` default to an empty dict / list. ``None``
     sentinels are used so no mutable default is shared between calls.
     """
     metadata = {} if metadata is None else metadata
     aliases = [] if aliases is None else aliases
     dataset_artifact = wandb.Artifact(name, type=type, metadata=metadata)
     dataset_artifact.add_dir(path, name=name)
     wandb.log_artifact(dataset_artifact, aliases=aliases)
 def end_epoch(self, best_result=False):
     """Flush ``log_dict`` to W&B and log this epoch's joined result table.

     If a result artifact is tracked, ``result_table`` is joined with
     ``val_table`` on ``id`` and added as ``result``. The artifact is logged
     with aliases ``latest`` and ``epoch <n>``, plus ``best`` when
     ``best_result`` is true. A fresh table and artifact are then created.
     """
     if self.wandb_run:
         wandb.log(self.log_dict)
         self.log_dict = {}
         if self.result_artifact:
             train_results = wandb.JoinedTable(self.val_table, self.result_table, "id")
             self.result_artifact.add(train_results, 'result')
             aliases = ['latest', 'epoch ' + str(self.current_epoch)]
             if best_result:
                 # Append conditionally instead of sending an empty '' alias.
                 aliases.append('best')
             wandb.log_artifact(self.result_artifact, aliases=aliases)
             self.result_table = wandb.Table(["epoch", "id", "prediction", "avg_confidence"])
             self.result_artifact = wandb.Artifact("run_" + wandb.run.id + "_progress", "evaluation")
Esempio n. 17
0
 def finish_run(self):
     """Upload pending results and metrics, then finish the W&B run.

     If a progress artifact is pending, it receives:
     - the raw result table, as ``result``;
     - that table joined on ``id`` with the ``val`` table from
       ``self.testset_artifact``, as ``joined_result``.

     The artifact is then logged. Any remaining ``log_dict`` entries are
     flushed before ``wandb.run.finish()``.
     """
     if not self.wandb_run:
         return
     pending = self.result_artifact
     if pending:
         print("Add Training Progress Artifact")
         pending.add(self.result_table, 'result')
         val_table = self.testset_artifact.get("val")
         joined = wandb.JoinedTable(val_table, self.result_table, "id")
         pending.add(joined, 'joined_result')
         wandb.log_artifact(pending)
     if self.log_dict:
         wandb.log(self.log_dict)
     wandb.run.finish()
Esempio n. 18
0
 def log_model(self, path, opt, epoch):
     """Upload ``last.pt`` and ``best.pt`` from ``path`` as a W&B model artifact.

     The artifact is named ``run_<run id>_model``. Its metadata records:
     - the checkpoint directory;
     - the 1-based epoch;
     - ``opt.save_period`` and ``opt.project``;
     - a local-time timestamp.
     """
     metadata = {
         'original_url': str(path),
         'epoch': epoch + 1,
         'save period': opt.save_period,
         'project': opt.project,
         'datetime': datetime.today().strftime('%Y-%m-%d-%H-%M-%S'),
     }
     artifact_name = 'run_' + wandb.run.id + '_model'
     model_artifact = wandb.Artifact(artifact_name, type='model', metadata=metadata)
     for ckpt_name in ('last.pt', 'best.pt'):
         model_artifact.add_file(str(path / ckpt_name), name=ckpt_name)
     wandb.log_artifact(model_artifact)
     print("Saving model artifact on epoch ", epoch + 1)
Esempio n. 19
0
def store_model_artifact(path: str, name: str):
    """
    Upload a TensorFlow SavedModel directory to W&B as a ``model`` artifact.

    Args:
        path: Directory of the saved model (e.g., /path/to/model.tf/).
        name: Artifact name. Logging under an existing name stores a new
            version of that artifact.
    """
    saved_model = wandb.Artifact(name, type="model")
    saved_model.add_dir(path)
    wandb.log_artifact(saved_model)
Esempio n. 20
0
def _checkpoint_artifact(model: "Booster", iteration: int,
                         aliases: "List[str]") -> None:
    """Save ``model`` at ``iteration`` and upload it as a W&B model artifact.

    The checkpoint is written to ``model_ckpt_<iteration>.txt`` in the run
    directory, saved with ``num_iteration=iteration``. It is logged to the
    ``model_<run id>`` artifact with the given ``aliases``.
    """
    # NOTE: type ignore required because wandb.run is improperly inferred as None type
    run = wandb.run
    ckpt_path = Path(run.dir) / f"model_ckpt_{iteration}.txt"  # type: ignore
    model.save_model(ckpt_path, num_iteration=iteration)

    artifact = wandb.Artifact(name=f"model_{run.id}", type="model")  # type: ignore
    artifact.add_file(ckpt_path)
    wandb.log_artifact(artifact, aliases=aliases)
 def log_model(self, path, opt, epoch, fitness_score, best_model=False):
     """Upload ``last.pt`` from ``path`` as the run's W&B model artifact.

     Metadata records:
     - the checkpoint directory;
     - epochs trained (``epoch + 1``);
     - ``opt.save_period``, ``opt.project`` and ``opt.epochs``;
     - ``fitness_score``.

     The version is aliased ``latest`` and ``epoch <self.current_epoch>``,
     plus ``best`` when ``best_model`` is true.
     """
     model_artifact = wandb.Artifact('run_' + wandb.run.id + '_model', type='model', metadata={
         'original_url': str(path),
         'epochs_trained': epoch + 1,
         'save period': opt.save_period,
         'project': opt.project,
         'total_epochs': opt.epochs,
         'fitness_score': fitness_score
     })
     model_artifact.add_file(str(path / 'last.pt'), name='last.pt')
     aliases = ['latest', 'epoch ' + str(self.current_epoch)]
     if best_model:
         # Append conditionally instead of sending an empty '' alias.
         aliases.append('best')
     wandb.log_artifact(model_artifact, aliases=aliases)
     print("Saving model artifact on epoch ", epoch + 1)
Esempio n. 22
0
    def save_state_dict(self, epoch: int = None):
        """Save every agent's network weights and upload them as one W&B artifact.

        Each agent's ``state_dict`` is taken after moving the network to CPU.
        It is saved to ``agent<id>.pth`` in the current working directory and
        added to a ``pretrained_weight`` artifact named ``<run id>.pth``.

        The artifact metadata is a copy of ``self.config`` plus
        ``save_state_epoch``: ``epoch`` if given, otherwise
        ``self.config.max_epochs - 1``.
        """
        meta_config = dict(self.config)
        # Compare against None: epoch 0 is a valid epoch and must not fall
        # back to the final-epoch default.
        meta_config["save_state_epoch"] = (
            epoch if epoch is not None else self.config.max_epochs - 1)

        weight_artifact = wandb.Artifact(name=wandb.run.id + ".pth",
                                         type="pretrained_weight",
                                         metadata=meta_config)
        for agent_id, agent in enumerate(self.agents):
            model_path = f"agent{agent_id}.pth"
            network = agent.brain.network
            try:
                torch.save(network.to("cpu").state_dict(), model_path)
            finally:
                # Move the network back to its device even if saving fails.
                network.to(agent.brain.device)
            weight_artifact.add_file(model_path)

        wandb.log_artifact(weight_artifact)
Esempio n. 23
0
def log_checkpoints(trainer, save=False, log=True):
    """Record the best checkpoint of each ``ModelCheckpoint`` callback in W&B.

    This applies to every ``pl.callbacks.ModelCheckpoint`` on ``trainer``
    that has already written a checkpoint.

    Metadata collected per callback:
    - the file name relative to the callback's ``dirpath``;
    - the monitored metric name and value (``latest_epoch`` and the epoch
      when nothing is monitored);
    - the epoch parsed from an ``epoch=<n>[-...].ckpt`` name.

    With ``save``, the file is uploaded as a ``checkpoint`` artifact named
    after the run id, aliased by the metric name with '/' replaced by '_'.
    With ``log`` and a monitored metric, the best value and epoch are written
    to the run summary.
    """
    checkpoint_callbacks = (cb for cb in trainer.callbacks
                            if isinstance(cb, pl.callbacks.ModelCheckpoint))
    for callback in checkpoint_callbacks:
        best_path = callback.best_model_path
        # best_model_path stays an empty string until the first checkpoint
        # has been saved.
        if not best_path:
            continue

        rel_name = os.path.relpath(best_path, callback.dirpath)
        match = re.match(r"^epoch=(\d+)(-.+)?\.ckpt$", rel_name)
        epoch = match.group(1) if match else None

        monitored = callback.monitor
        if monitored:
            metric_name, metric_value = monitored, callback.best_model_score
        else:
            metric_name, metric_value = 'latest_epoch', epoch

        if isinstance(metric_value, torch.Tensor):
            metric_value = metric_value.item()

        metadata = {
            'file_name': rel_name,
            'metric_name': metric_name,
            'metric_value': metric_value,
            'epoch': epoch,
        }
        # Aliases cannot contain '/', which metric names often do.
        alias = metric_name.replace('/', '_')

        if save:
            artifact = wandb.Artifact(name=f'{wandb.run.id}', type='checkpoint',
                                      metadata=metadata)
            artifact.add_file(best_path, name='checkpoint.ckpt')
            wandb.log_artifact(artifact, aliases=[alias])

        if log and monitored:
            wandb.summary[f'{metric_name}/best_value'] = metric_value
            wandb.summary[f'{metric_name}/best_epoch'] = epoch
Esempio n. 24
0
def add_checkpoint_artifact(run, api: wandb.Api, dry_run):
    """Copy a run's checkpoints from Google Drive into W&B checkpoint artifacts.

    Steps:
    1. Download ``drive:data/runs/<run.id>/checkpoints`` with ``rclone`` into
       a temporary directory.
    2. Log every ``epoch=<n>.ckpt`` found there as a ``checkpoint`` artifact
       named after the run id and aliased ``latest_epoch``.
    3. Wait for each upload via ``wait_pending_artifact``.
    4. Unless ``dry_run``, link the uploaded artifact to ``run``.

    Files whose relative name does not match ``epoch=<n>[-...].ckpt`` (e.g.
    checkpoints in sub-directories) are skipped with a warning, as are
    metric-tracking checkpoints ``epoch=<n>-<...>.ckpt``.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Download checkpoints from Google Drive
        cmd = ['rclone', 'copy', '--progress', f'drive:data/runs/{run.id}/checkpoints', tmp_dir]
        subprocess.run(cmd, check=True)

        artifacts = []
        artifact_name = run.id
        file_paths = glob.iglob(os.path.join(tmp_dir, '**/*.ckpt'), recursive=True)

        for file_path in file_paths:
            file_name = os.path.relpath(file_path, tmp_dir)

            matches = re.match(r"^epoch=(\d+)(-.+)?\.ckpt$", file_name)
            if matches is None:
                # re.match returns None for unexpected names, including files
                # in sub-directories found by the recursive glob.
                warnings.warn(f"Unrecognized checkpoint file name {file_name!r}, skipping.")
                continue

            if matches.group(2) is not None:
                warnings.warn("Support for checkpoints tracking a metric has not been implemented yet, skipping.")
                continue

            epoch = matches.group(1)
            metric_name = 'latest_epoch'
            metric_value = epoch

            metadata = dict(
                file_name=file_name,
                metric_name=metric_name,
                metric_value=metric_value,
                epoch=epoch
            )

            # Handle metrics with a slash in the name
            metric_slug = metric_name.replace('/', '_')

            artifact = wandb.Artifact(artifact_name, type='checkpoint', metadata=metadata)
            artifact.add_file(file_path, name='checkpoint.ckpt')
            wandb.log_artifact(artifact, aliases=[metric_slug])

            artifacts.append((artifact, metric_slug))

        # Wait until each artifact has been uploaded and link it to its run.
        for artifact, metric_slug in artifacts:
            name = f'{artifact.name}:{metric_slug}'
            manifest = wait_pending_artifact(api, name, type='checkpoint')

            if not dry_run:
                run.log_artifact(manifest)
Esempio n. 25
0
def save_idx_to_artifact(cfg, idxs, counts, test_idxs):
    """Store index arrays in a ``private_indices`` W&B artifact.

    ``idxs`` and ``test_idxs`` are saved as object-dtype NumPy arrays in
    ``idxs.npy`` and ``test_idxs.npy``. The metadata records:
    - the partitioning settings taken from ``cfg``;
    - ``counts``, plus its sums over axis 0 (``class_total``) and axis 1
      (``party_total``).

    Returns the logged artifact.
    """
    artifact_name = get_idx_artifact_name(cfg)
    metadata = {
        'parties': cfg['parties'],
        'normalize': cfg['partition_normalize'],
        'samples': cfg['samples'],
        'classes': cfg['classes'],
        'concentration': cfg['concentration'],
        'distributions': counts,
        'class_total': counts.sum(axis=0),
        'party_total': counts.sum(axis=1),
    }
    artifact = wandb.Artifact(artifact_name, type='private_indices',
                              metadata=metadata)
    # 'xb': exclusive binary create inside the artifact's staging area.
    for file_name, values in (('idxs.npy', idxs), ('test_idxs.npy', test_idxs)):
        with artifact.new_file(file_name, 'xb') as handle:
            np.save(handle, np.array(values, dtype=object))
    wandb.log_artifact(artifact)
    return artifact
Esempio n. 26
0
def _checkpoint_artifact(model: Union[CatBoostClassifier, CatBoostRegressor],
                         aliases: List[str]) -> None:
    """
    Upload a CatBoost model checkpoint as a W&B ``model`` artifact.

    The model is saved to ``<run dir>/model`` and logged to the
    ``model_<run id>`` artifact with ``aliases``.

    Raises:
        wandb.Error: if no W&B run is active.
    """
    run = wandb.run
    if run is None:
        raise wandb.Error(
            "You must call `wandb.init()` before `_checkpoint_artifact()`")

    # save the model in the default `cbm` format
    ckpt_path = Path(run.dir) / "model"
    model.save_model(ckpt_path)

    artifact = wandb.Artifact(name=f"model_{run.id}", type="model")
    artifact.add_file(str(ckpt_path))
    wandb.log_artifact(artifact, aliases=aliases)
Esempio n. 27
0
def init_wandb(cfg: DictConfig):
    """Start a W&B run configured from ``cfg`` and upload the project's source.

    Top-level keys of the resolved config are prefixed with ``cfg/``.
    Project, tags and group come from ``cfg.logger``. Every ``*.py`` file
    directly inside ``cfg.work_dir`` is uploaded in a ``code`` artifact named
    ``project-source``.
    """
    resolved = OmegaConf.to_container(cfg, resolve=True)
    config = {f"cfg/{key}": value for key, value in resolved.items()}

    logger_cfg = cfg.logger
    wandb.init(
        project=logger_cfg.project,
        config=config,
        tags=logger_cfg.tags,
        group=logger_cfg.group,
    )

    # upload code
    code = wandb.Artifact("project-source", type="code")
    source_pattern = os.path.join(cfg.work_dir, "*.py")
    for source_file in glob.glob(source_pattern):
        code.add_file(source_file)
    wandb.log_artifact(code)
Esempio n. 28
0
 def on_train_end(self, last, best, plots):
     """Log final result plots and the best (or last) weights to W&B.

     Optionally re-renders ``results.png`` in ``save_dir``. If W&B is enabled:
     - the summary plots that exist are logged as ``Results`` images;
     - ``best`` (or ``last`` if ``best`` is missing) is uploaded as a model
       artifact;
     - the W&B run is finished.
     """
     if plots:
         plot_results(dir=self.save_dir)  # save results.png
     curve_names = [f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]
     all_names = ['results.png', 'confusion_matrix.png', *curve_names]
     files = [self.save_dir / n for n in all_names
              if (self.save_dir / n).exists()]  # filter
     if not self.wandb:
         return
     images = [wandb.Image(str(p), caption=p.name) for p in files]
     wandb.log({"Results": images})
     weights = best if best.exists() else last
     wandb.log_artifact(str(weights),
                        type='model',
                        name='run_' + self.wandb.wandb_run.id + '_model',
                        aliases=['latest', 'best', 'stripped'])
     self.wandb.finish_run()
Esempio n. 29
0
    def on_fit_end(self):
        """After fitting, upload checkpoints and evaluation GIFs to W&B.

        - Registers the files in ``<log_dir>/<exp_name>/ckpts`` with
          ``wandb.save``.
        - Sets ``scene_name`` and ``N_importance`` on the hparams.
        - Evaluates the newest checkpoint whose name contains ``epoch`` and
          logs the returned depth/RGB GIFs.
        - Unless ``debug``:
          - optionally (``save_dataset``) logs ``root_dir`` as a ``dataset``
            artifact;
          - logs the checkpoint directory as a ``model`` artifact named
            ``exp_name``.
        """
        import re

        # Use self.hparams throughout; the original tail read a bare global
        # ``hparams`` that may not exist at module scope.
        hparams = self.hparams
        ckpt_dir = os.path.join(hparams.log_dir, hparams.exp_name, "ckpts")
        wandb.save(ckpt_dir + "/*")
        print("saving checkpoint at: " + str(ckpt_dir))

        hparams.scene_name = hparams.exp_name
        hparams.N_importance = 64

        def _natural_key(file_name):
            # Compare digit runs numerically so "epoch=10" sorts after "epoch=9"
            # (plain string sort would pick epoch=9 as the newest).
            return [int(tok) if tok.isdigit() else tok
                    for tok in re.split(r"(\d+)", file_name)]

        ckpts = sorted((f for f in os.listdir(ckpt_dir) if "epoch" in f),
                       key=_natural_key)
        if ckpts:
            hparams.eval_ckpt_path = os.path.join(ckpt_dir, ckpts[-1])
            # NOTE(review): ``eval`` here is a project evaluation function
            # shadowing the builtin; it returns (rgb_gif, depth_gif).
            img_gif, depth_gif = eval(hparams)

            wandb.log({
                "eval/depth_gif":
                wandb.Video(depth_gif, fps=30, format="gif")
            })
            wandb.log(
                {"eval/out_gif": wandb.Video(img_gif, fps=30, format="gif")})

        if not hparams.debug:
            if hparams.save_dataset:
                print("saving dataset artifact...")
                dataset_name = hparams.root_dir.split("/")[-1]
                artifact = wandb.Artifact(dataset_name, type="dataset")
                artifact.add_dir(hparams.root_dir)
                artifact.description = hparams.exp_name
                wandb.log_artifact(artifact)
            print("saving model artifact...")
            artifact = wandb.Artifact(hparams.exp_name, type="model")
            artifact.add_dir(ckpt_dir)
            artifact.description = hparams.exp_name
            wandb.log_artifact(artifact)
Esempio n. 30
0
File: wandb.py Progetto: smorad/ray
 def _handle_checkpoint(self, checkpoint_path: str):
     """Upload the checkpoint directory at ``checkpoint_path`` to W&B.

     The directory is logged as a ``model`` artifact named
     ``checkpoint_<trial name>``.
     """
     artifact_name = f"checkpoint_{self._trial_name}"
     ckpt_artifact = wandb.Artifact(name=artifact_name, type="model")
     ckpt_artifact.add_dir(checkpoint_path)
     wandb.log_artifact(ckpt_artifact)