Example no. 1
0
    def objective(trial):
        """Optuna objective: train a trial's model and return its test error rate.

        Per epoch, logs the test loss to TensorBoard, reports the value to
        Optuna, and lets the pruner terminate unpromising trials early.

        Parameters
        ----------
        trial : optuna.trial.Trial
            Trial object used to suggest hyperparameters and report progress.

        Returns
        -------
        float
            Error rate on the test set after the final epoch.

        Raises
        ------
        optuna.TrialPruned
            If the pruner decides to stop this trial early.
        """
        global writer
        writer = SummaryWriter()
        device = "cuda" if torch.cuda.is_available() else "cpu"

        model = Net(trial).to(device)
        optimizer = get_optimizer(trial, model)
        # NOTE(review): h_params is not defined in this scope — presumably a
        # module-level dict of this trial's hyperparameters; verify.
        writer.add_hparams_start(h_params)
        for step in range(EPOCH):
            train(model, device, train_loader, optimizer)
            error_rate = test(model, device, test_loader)
            writer.add_scalar('test/loss', error_rate, step)
            trial.report(error_rate, step)
            # should_prune() takes no argument in current Optuna (the `step`
            # argument was deprecated); the last reported step is used.
            if trial.should_prune():
                pbar.update()
                # optuna.structs.TrialPruned was deprecated and removed in
                # Optuna 3.0; optuna.TrialPruned is the stable spelling.
                raise optuna.TrialPruned()

        pbar.update()
        writer.add_hparams_end()  # save hyper parameters
        return error_rate
Example no. 2
0
    def emit(self, record: logging.LogRecord) -> None:
        """Save to tensorboard logging directory

        Overrides `logging.Handler.emit`. Records without a raw message
        object or a log directory are silently ignored.

        Parameters
        ----------
        record : logging.LogRecord
            LogRecord with data relevant to Tensorboard

        Returns
        -------
        None

        """
        # Handler relies on access to raw objects which flambe logging
        # provides
        if not hasattr(record, "raw_msg_obj"):
            return
        message = record.raw_msg_obj  # type: ignore
        # A log directory from the logging context selects (or creates)
        # the SummaryWriter used for this record.
        # NOTE(review): the original comment claimed log_dir is prepended to
        # the final tag, but no such prepending happens below — confirm
        # whether that behavior was lost.
        if not hasattr(record, "_tf_log_dir"):
            return
        log_dir = record._tf_log_dir  # type: ignore
        # Lazily create and cache one writer per log directory; hparams
        # (if present on the record) are registered only on first creation.
        if log_dir in self.writers:
            writer = self.writers[log_dir]
        else:
            writer = SummaryWriter(log_dir=log_dir)
            hparams = getattr(record, "_tf_hparams", dict())
            if len(hparams):
                writer.add_hparams_start(hparams=hparams)

            self.writers[log_dir] = writer
        # Datatypes with a standard `tag` field: dispatch on the concrete
        # message class to the matching writer method.
        if isinstance(message, (ScalarT, HistogramT, TextT, EmbeddingT, ImageT, PRCurveT)):
            # _replace(tag=message.tag) in the original was a no-op; the
            # namedtuple's fields map directly onto the writer's kwargs.
            kwargs = message._asdict()
            fn = {
                ScalarT: writer.add_scalar,
                HistogramT: writer.add_histogram,
                TextT: writer.add_text,
                EmbeddingT: writer.add_embedding,
                ImageT: writer.add_image,
                PRCurveT: writer.add_pr_curve
            }
            fn[message.__class__](**kwargs)
        # Datatypes with a special tag field
        elif isinstance(message, ScalarsT):
            writer.add_scalars(**message._asdict())
        # Datatypes without a tag field
        elif isinstance(message, GraphT):
            kwargs = message._asdict()
            # BUG FIX: iterating a dict yields only its keys, so the original
            # `for k, v in kwargs['kwargs']:` could never unpack (k, v) pairs;
            # .items() is required to flatten the nested kwargs dict.
            for k, v in kwargs['kwargs'].items():
                kwargs[k] = v
            del kwargs['kwargs']
            writer.add_model(**kwargs)
        writer.file_writer.flush()