Ejemplo n.º 1
0
def fit_validate(exp_params, k, data_path, write_path, others=None, custom_tag=''):
    """Fit a model on one fold and log train/validation metrics to Comet.

    Intended for hyperparameter search: only the final metrics and the plot
    produced by the model after fitting are logged.

    Args:
        exp_params(dict): Parameter dict. Should at least have keys model_name, dataset_name & random_state. Other
        keys are assumed to be model parameters.
        k(int): Fold identifier, used to index FOLD_SEEDS.
        data_path(str): Data directory.
        write_path(str): Where to write temp files.
        others(dict): Other things to log to Comet experiment.
        custom_tag(str): Custom tag for comet experiment; '_validate' is appended to it.

    """
    # Set up the Comet experiment and record the run configuration.
    experiment = Experiment(parse_args=False)
    experiment.disable_mp()
    experiment.add_tag(custom_tag + '_validate')
    experiment.log_parameters(exp_params)
    if others is not None:
        experiment.log_others(others)

    # Split the experiment parameters into their individual parts.
    model_name, dataset_name, random_state, model_params = parse_params(exp_params)

    # Load the training split, then carve a validation set out of it using the fold seed.
    dataset_cls = getattr(grae.data, dataset_name)
    full_train = dataset_cls(split='train', random_state=random_state, data_path=data_path)
    data_train, data_val = full_train.validation_split(random_state=FOLD_SEEDS[k])

    # Build the model; it is seeded with the fold seed as well.
    model_cls = getattr(grae.models, model_name)
    model = model_cls(random_state=FOLD_SEEDS[k], **model_params)
    model.write_path = write_path
    model.data_val = data_val

    with experiment.train():
        model.fit(data_train)

        # The experiment is attached only after fitting, right before plotting.
        model.comet_exp = experiment
        model.plot(data_train, data_val, title=f'{model_name} : {dataset_name}')

        # Probe the embedding and log the train metrics.
        prober = EmbeddingProber()
        prober.fit(model=model, dataset=data_train, mse_only=True)
        _, train_metrics = prober.score(data_train, is_train=True)
        experiment.log_metrics(train_metrics)

    with experiment.validate():
        # Log the validation metrics.
        _, val_metrics = prober.score(data_val)
        experiment.log_metrics(val_metrics)

    # Marker used to identify experiments that ran to completion.
    experiment.log_other('success', 1)
    def init_callbacks(self):
        """Register the training callbacks on ``self.callbacks``.

        Appends, in order: a ``ModelCheckpoint``, a ``TensorBoard`` callback and,
        when ``self.config`` has a ``comet_api_key`` attribute, the Keras callback
        of a newly created Comet ``Experiment``. All settings come from
        ``self.config``.
        """
        cfg = self.config

        # The experiment name is substituted now; the epoch / val_nme fields
        # remain as format placeholders in the resulting file name pattern.
        checkpoint_path = os.path.join(
            cfg.checkpoint_dir,
            '%s-{epoch:03d}-{val_nme:.5f}.hdf5' % cfg.exp_name,
        )
        checkpoint_cb = ModelCheckpoint(
            filepath=checkpoint_path,
            monitor=cfg.checkpoint_monitor,
            mode=cfg.checkpoint_mode,
            save_best_only=cfg.checkpoint_save_best_only,
            save_weights_only=cfg.checkpoint_save_weights_only,
            verbose=cfg.checkpoint_verbose,
        )
        self.callbacks.append(checkpoint_cb)

        tensorboard_cb = TensorBoard(
            log_dir=cfg.tensorboard_log_dir,
            write_graph=cfg.tensorboard_write_graph,
        )
        self.callbacks.append(tensorboard_cb)

        # Learning-rate scheduling is currently switched off:
        # self.callbacks.append(
        #     LearningRateScheduler(self.lr_scheduler)
        # )

        # Comet logging is optional and keyed on the presence of an API key.
        if not hasattr(cfg, "comet_api_key"):
            return

        from comet_ml import Experiment
        comet_exp = Experiment(api_key=cfg.comet_api_key, project_name=cfg.exp_name)
        comet_exp.disable_mp()
        comet_exp.log_multiple_params(cfg)
        self.callbacks.append(comet_exp.get_keras_callback())
Ejemplo n.º 3
0
    def init_callbacks(self):
        """Register the training callbacks on ``self.callbacks``.

        Appends a ``ModelCheckpoint`` writing into ``self.checkpoint_dir``, a
        ``TensorBoard`` callback writing into ``self.tensorboard_log_dir`` and,
        when the config contains a ``comet_api_key`` entry, the Keras callback of
        a newly created Comet ``Experiment``. ``self.config`` is read with item
        access (``self.config[...]``) throughout.
        """
        # The experiment name is substituted now; the epoch / val_loss fields
        # remain as format placeholders in the resulting file name pattern.
        checkpoint_path = os.path.join(
            self.checkpoint_dir,
            '%s-{epoch:02d}-{val_loss:.2f}.hdf5' % self.config['exp_name'])
        self.callbacks.append(
            ModelCheckpoint(
                filepath=checkpoint_path,
                monitor=self.config['checkpoint_monitor'],
                mode=self.config['checkpoint_mode'],
                save_best_only=self.config['checkpoint_save_best_only'],
                save_weights_only=self.config['checkpoint_save_weights_only'],
                verbose=self.config['checkpoint_verbose'],
            ))

        self.callbacks.append(
            TensorBoard(
                log_dir=self.tensorboard_log_dir,
                write_graph=self.config['tensorboard_write_graph'],
                histogram_freq=0,  # don't compute histograms
                write_images=False,  # don't write model weights to visualize as image in TensorBoard
            ))

        # The config is used as a mapping in this method, so test for the key
        # with ``in``: ``hasattr`` looks at attributes and does not see the keys
        # of a plain mapping, which would silently keep Comet switched off.
        if "comet_api_key" in self.config:
            from comet_ml import Experiment
            experiment = Experiment(api_key=self.config['comet_api_key'],
                                    project_name=self.config['exp_name'])
            experiment.disable_mp()
            experiment.log_multiple_params(self.config)
            self.callbacks.append(experiment.get_keras_callback())
Ejemplo n.º 4
0
    def init_callbacks(self):
        """Register the training callbacks on ``self.callbacks``.

        Appends a ``ModelCheckpoint`` and a ``TensorBoard`` callback configured
        from ``self.config.callbacks`` and, when ``self.config`` has a
        ``comet_api_key`` attribute, the Keras callback of a newly created Comet
        ``Experiment``.
        """
        callbacks_cfg = self.config.callbacks

        # The experiment name is substituted now; the epoch / val_loss fields
        # remain as format placeholders in the resulting file name pattern.
        checkpoint_path = os.path.join(
            callbacks_cfg.checkpoint_dir,
            '%s-{epoch:02d}-{val_loss:.2f}.hdf5' % self.config.exp.name,
        )
        checkpoint_cb = ModelCheckpoint(
            filepath=checkpoint_path,
            monitor=callbacks_cfg.checkpoint_monitor,
            mode=callbacks_cfg.checkpoint_mode,
            save_best_only=callbacks_cfg.checkpoint_save_best_only,
            save_weights_only=callbacks_cfg.checkpoint_save_weights_only,
            verbose=callbacks_cfg.checkpoint_verbose,
        )
        self.callbacks.append(checkpoint_cb)

        tensorboard_cb = TensorBoard(
            log_dir=callbacks_cfg.tensorboard_log_dir,
            write_graph=callbacks_cfg.tensorboard_write_graph,
        )
        self.callbacks.append(tensorboard_cb)

        # Comet logging is optional and keyed on the presence of an API key.
        if not hasattr(self.config, "comet_api_key"):
            return

        # NOTE(review): the project name is read from ``config.exp_name`` here
        # while the checkpoint name uses ``config.exp.name`` - confirm both exist.
        from comet_ml import Experiment
        comet_exp = Experiment(api_key=self.config.comet_api_key,
                               project_name=self.config.exp_name)
        comet_exp.disable_mp()
        comet_exp.log_multiple_params(self.config)
        self.callbacks.append(comet_exp.get_keras_callback())
    def init_callbacks(self):
        """Register the training callbacks on ``self.callbacks``.

        The callback classes are imported from standalone ``keras`` when the
        configured model name is "encoder" and from ``tensorflow.keras``
        otherwise. Appends, in order: a per-epoch ``ModelCheckpoint``, a
        ``ModelCheckpoint`` writing ``best_model-<monitor>.hdf5``, a
        ``ReduceLROnPlateau``, an ``EarlyStopping`` and, when the config contains
        ``comet_api_key``, the Keras callback of a new Comet ``Experiment``.
        """
        uses_standalone_keras = self.config.model.name == "encoder"
        if uses_standalone_keras:
            import keras
            from keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau, EarlyStopping
        else:
            import tensorflow.keras as keras
            from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau, EarlyStopping

        callbacks_cfg = self.config.callbacks

        # Checkpoint named after the experiment; the epoch / val_loss fields
        # remain as format placeholders in the resulting file name pattern.
        epoch_checkpoint_path = os.path.join(
            callbacks_cfg.checkpoint_dir,
            '%s-{epoch:02d}-{val_loss:.2f}.hdf5' % self.config.exp.name,
        )
        epoch_checkpoint = ModelCheckpoint(
            filepath=epoch_checkpoint_path,
            monitor=callbacks_cfg.checkpoint_monitor,
            mode=callbacks_cfg.checkpoint_mode,
            save_best_only=callbacks_cfg.checkpoint_save_best_only,
            save_weights_only=callbacks_cfg.checkpoint_save_weights_only,
            verbose=callbacks_cfg.checkpoint_verbose,
        )
        self.callbacks.append(epoch_checkpoint)

        # Second checkpoint with a fixed file name that embeds the monitored metric.
        best_checkpoint_path = os.path.join(
            callbacks_cfg.checkpoint_dir,
            'best_model-%s.hdf5' % callbacks_cfg.checkpoint_monitor,
        )
        best_checkpoint = ModelCheckpoint(
            filepath=best_checkpoint_path,
            monitor=callbacks_cfg.checkpoint_monitor,
            mode=callbacks_cfg.checkpoint_mode,
            save_best_only=callbacks_cfg.checkpoint_save_best_only,
        )
        self.callbacks.append(best_checkpoint)

        lr_reducer = ReduceLROnPlateau(monitor='val_loss',
                                       factor=0.5,
                                       patience=10,
                                       min_lr=0.0001)
        self.callbacks.append(lr_reducer)

        early_stopper = EarlyStopping(monitor='val_loss', patience=10, verbose=1)
        self.callbacks.append(early_stopper)

        # TCN uses the WeightNormalization layer from tensorflow_addons, which is
        # incompatible with TensorBoard, so the TensorBoard callback is disabled.
        # if (self.config.model.name != "tcn"):
        #     self.callbacks.append(
        #         TensorBoard(
        #             log_dir=self.config.callbacks.tensorboard_log_dir,
        #             write_graph=self.config.callbacks.tensorboard_write_graph,
        #             histogram_freq=1,
        #         )
        #     )
        # if self.config.dataset.name == "ptbdb":
        #     self.callbacks.append(
        #         AdvancedLearnignRateScheduler(monitor='val_main_output_loss', patience=6, verbose=1, mode='auto',
        #                                       decayRatio=0.1),
        #     )

        # Comet logging is optional and keyed on the presence of an API key.
        if "comet_api_key" not in self.config:
            return

        from comet_ml import Experiment
        comet_exp = Experiment(api_key=self.config.comet_api_key,
                               project_name=self.config.exp_name)
        comet_exp.disable_mp()
        comet_exp.log_parameters(self.config["trainer"])
        self.callbacks.append(comet_exp.get_callback('keras'))
Ejemplo n.º 6
0
class Logger:
    """Writes TensorFlow summaries and optionally mirrors them to Comet.

    Two ``tf.summary.FileWriter`` objects are created under
    ``config.summary_dir`` ("train" and "test"). When ``config`` contains
    ``comet_api_key`` a Comet ``Experiment`` is created as well and stored in
    ``self.experiment``; otherwise that attribute is never set.
    """

    def __init__(self, sess, config):
        """Create the summary writers and, if configured, the Comet experiment.

        :param sess: session used to evaluate the summary ops; its graph is
            passed to the train writer
        :param config: configuration providing ``summary_dir`` and, optionally,
            the ``comet_api_key`` / ``exp_name`` entries
        """
        self.sess = sess
        self.config = config
        # Placeholders and summary ops are created on first use and cached per tag.
        self.summary_placeholders = {}
        self.summary_ops = {}

        train_dir = os.path.join(self.config.summary_dir, "train")
        self.train_summary_writer = tf.summary.FileWriter(train_dir, self.sess.graph)
        test_dir = os.path.join(self.config.summary_dir, "test")
        self.test_summary_writer = tf.summary.FileWriter(test_dir)

        if "comet_api_key" not in config:
            return

        from comet_ml import Experiment
        self.experiment = Experiment(
            api_key=config['comet_api_key'], project_name=config['exp_name'])
        self.experiment.disable_mp()
        self.experiment.log_multiple_params(config)

    # it can summarize scalars and images.
    def summarize(self, step, summarizer="train", scope="", summaries_dict=None):
        """
        Write the given values as summaries for ``step``.

        Values whose shape has at most one dimension are written as scalar
        summaries, all others as image summaries.

        :param step: the step of the summary
        :param summarizer: use the train summary writer or the test one
        :param scope: variable scope
        :param summaries_dict: the dict of the summaries values (tag,value)
        :return:
        """
        if summarizer == "train":
            writer = self.train_summary_writer
        else:
            writer = self.test_summary_writer

        with tf.variable_scope(scope):
            if summaries_dict is None:
                return

            evaluated = []
            for tag, value in summaries_dict.items():
                if tag not in self.summary_ops:
                    # First time this tag is seen: build its placeholder and op.
                    if len(value.shape) <= 1:
                        placeholder = tf.placeholder(
                            'float32', value.shape, name=tag)
                        build_summary = tf.summary.scalar
                    else:
                        # Leading dimension is left open for the image batch.
                        placeholder = tf.placeholder(
                            'float32', [None] + list(value.shape[1:]), name=tag)
                        build_summary = tf.summary.image
                    self.summary_placeholders[tag] = placeholder
                    self.summary_ops[tag] = build_summary(tag, placeholder)

                feed = {self.summary_placeholders[tag]: value}
                evaluated.append(self.sess.run(self.summary_ops[tag], feed))

            for summary in evaluated:
                writer.add_summary(summary, step)

            # The experiment attribute only exists when Comet was configured.
            comet_exp = getattr(self, 'experiment', None)
            if comet_exp is not None:
                comet_exp.log_multiple_metrics(summaries_dict, step=step)

            writer.flush()
Ejemplo n.º 7
0
 def _build_experiment(self):
     """Create and return a Comet ``Experiment`` with automatic logging turned off.

     The API key, project name and workspace are taken from ``self`` and
     passed positionally; ``disable_mp()`` is called before returning.
     """
     # Automatic metric/parameter/graph logging is off; the experiment itself stays enabled.
     logging_flags = dict(
         auto_metric_logging=False,
         auto_param_logging=False,
         log_graph=False,
         disabled=False,
     )
     experiment = Experiment(self.api_key, self.project_name, self.workspace,
                             **logging_flags)
     experiment.disable_mp()
     return experiment
Ejemplo n.º 8
0
    def init_callbacks(self):
        """Register the training callbacks on ``self.callbacks``.

        The callback classes are imported from standalone ``keras`` when the
        configured model name is "encoder" and from ``tensorflow.keras``
        otherwise. Appends, in order: a per-epoch ``ModelCheckpoint``, a
        ``ModelCheckpoint`` writing ``best_model-<monitor>.hdf5``, a
        ``ReduceLROnPlateau``, a ``TensorBoard`` callback and, when the config
        contains ``comet_api_key``, the Keras callback of a new Comet
        ``Experiment``.
        """
        # The callback classes must come from the same Keras flavour as the
        # selected branch. A single ``from keras.callbacks import ...`` after the
        # branch would always load standalone Keras, whatever ``keras`` aliases.
        if (self.config.model.name == "encoder"):
            import keras
            from keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau
        else:
            import tensorflow.keras as keras
            from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, ReduceLROnPlateau

        # Checkpoint named after the experiment; the epoch / val_loss fields
        # remain as format placeholders in the resulting file name pattern.
        self.callbacks.append(
            ModelCheckpoint(
                filepath=os.path.join(
                    self.config.callbacks.checkpoint_dir,
                    '%s-{epoch:02d}-{val_loss:.2f}.hdf5' %
                    self.config.exp.name),
                monitor=self.config.callbacks.checkpoint_monitor,
                mode=self.config.callbacks.checkpoint_mode,
                save_best_only=self.config.callbacks.checkpoint_save_best_only,
                save_weights_only=self.config.callbacks.
                checkpoint_save_weights_only,
                verbose=self.config.callbacks.checkpoint_verbose,
            ))
        # Second checkpoint with a fixed file name that embeds the monitored metric.
        self.callbacks.append(
            ModelCheckpoint(
                filepath=os.path.join(
                    self.config.callbacks.checkpoint_dir,
                    'best_model-%s.hdf5' %
                    self.config.callbacks.checkpoint_monitor),
                monitor=self.config.callbacks.checkpoint_monitor,
                mode=self.config.callbacks.checkpoint_mode,
                save_best_only=self.config.callbacks.checkpoint_save_best_only,
            ))
        self.callbacks.append(
            ReduceLROnPlateau(monitor='val_loss',
                              factor=0.5,
                              patience=50,
                              min_lr=0.0001))
        self.callbacks.append(
            TensorBoard(
                log_dir=self.config.callbacks.tensorboard_log_dir,
                write_graph=self.config.callbacks.tensorboard_write_graph,
                histogram_freq=1,
            ))

        # Comet logging is optional and keyed on the presence of an API key.
        if ("comet_api_key" in self.config):
            from comet_ml import Experiment
            experiment = Experiment(api_key=self.config.comet_api_key,
                                    project_name=self.config.exp_name)
            experiment.disable_mp()
            experiment.log_parameters(self.config["trainer"])
            self.callbacks.append(experiment.get_callback('keras'))
Ejemplo n.º 9
0
    def init_callbacks(self):
        """Register the training callbacks on ``self.callbacks``.

        Appends a ``ModelCheckpoint`` and a ``TensorBoard`` callback configured
        from ``self.config.callbacks``. When ``self.config.debug`` is set to
        ``True`` the Keras backend session is wrapped in a tfdbg CLI session.
        When ``self.config.exp`` has a ``comet_api_key`` attribute, the Keras
        callback of a newly created Comet ``Experiment`` is appended as well.
        """
        # The experiment name is substituted now; the epoch / val_loss fields
        # remain as format placeholders in the resulting file name pattern.
        self.callbacks.append(
            ModelCheckpoint(
                filepath=os.path.join(
                    self.config.callbacks.checkpoint_dir,
                    '%s-{epoch:02d}-{val_loss:.2f}.hdf5' %
                    self.config.exp.name),
                monitor=self.config.callbacks.checkpoint_monitor,
                mode=self.config.callbacks.checkpoint_mode,
                save_best_only=self.config.callbacks.checkpoint_save_best_only,
                save_weights_only=self.config.callbacks.
                checkpoint_save_weights_only,
                verbose=self.config.callbacks.checkpoint_verbose,
            ))

        self.callbacks.append(
            TensorBoard(
                log_dir=self.config.callbacks.tensorboard_log_dir,
                write_graph=self.config.callbacks.tensorboard_write_graph,
            ))

        # if the config has the debug flag on, turn on tfdbg (TODO: make it work)
        # The explicit comparison with True is kept on purpose: other truthy
        # values (e.g. a non-empty string) must not switch debug mode on.
        if hasattr(self.config, "debug") and self.config.debug == True:  # noqa: E712
            import keras.backend
            from tensorflow.python import debug as tf_debug
            print("#=========== DEBUG MODE ===========#")
            sess = keras.backend.get_session()
            sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            keras.backend.set_session(sess)

        # if the config file has a comet_ml key, log on comet.
        # The key is read from ``config.exp`` below, so the guard must look at
        # ``config.exp`` too; testing the top-level config checked a different
        # object than the one that is actually read.
        if hasattr(self.config.exp, "comet_api_key"):
            from comet_ml import Experiment  # PUT the import in main
            experiment = Experiment(api_key=self.config.exp.comet_api_key,
                                    project_name=self.config.exp.name)
            experiment.disable_mp()
            experiment.log_multiple_params(self.config)
            self.callbacks.append(experiment.get_keras_callback())
Ejemplo n.º 10
0
    def init_callbacks(self):
        """Register the training callbacks on ``self.callbacks``.

        Appends, in order: a ``ModelCheckpoint`` writing ``best_model.hdf5``, an
        ``EarlyStopping``, an ``AdvancedLearnignRateScheduler``, a ``TensorBoard``
        callback and, when the config contains ``comet_api_key``, the Keras
        callback of a newly created Comet ``Experiment``.
        """
        callbacks_cfg = self.config.callbacks

        checkpoint_cb = ModelCheckpoint(
            filepath=os.path.join(callbacks_cfg.checkpoint_dir, 'best_model.hdf5'),
            monitor=callbacks_cfg.checkpoint_monitor,
            mode=callbacks_cfg.checkpoint_mode,
            save_best_only=callbacks_cfg.checkpoint_save_best_only,
            save_weights_only=callbacks_cfg.checkpoint_save_weights_only,
            verbose=callbacks_cfg.checkpoint_verbose,
        )
        self.callbacks.append(checkpoint_cb)

        early_stopper = EarlyStopping(monitor='val_loss', patience=10, verbose=1)
        self.callbacks.append(early_stopper)

        lr_scheduler = AdvancedLearnignRateScheduler(
            monitor='val_loss',
            patience=5,
            verbose=1,
            mode='auto',
            warmup_batches=10,
            decayRatio=0.1,
        )
        self.callbacks.append(lr_scheduler)

        tensorboard_cb = TensorBoard(
            log_dir=callbacks_cfg.tensorboard_log_dir,
            write_graph=callbacks_cfg.tensorboard_write_graph,
        )
        self.callbacks.append(tensorboard_cb)

        # Comet logging is optional and keyed on the presence of an API key.
        if "comet_api_key" not in self.config:
            return

        from comet_ml import Experiment
        comet_exp = Experiment(api_key=self.config.comet_api_key,
                               project_name=self.config.exp_name)
        comet_exp.disable_mp()
        comet_exp.log_parameters(self.config["args"])
        self.callbacks.append(comet_exp.get_callback('keras'))
Ejemplo n.º 11
0
class SingletonObject:
    """Holder of a single, shared Comet ``Experiment``.

    Only one instance may exist: use ``getInstance()`` to obtain it. Calling
    the constructor while an instance is registered raises an exception.
    """
    __instance = None

    @staticmethod
    def getInstance():
        """Return the shared instance, creating it with default arguments on first use."""
        if SingletonObject.__instance is None:
            SingletonObject()
        return SingletonObject.__instance

    def __init__(self, disabled=False):
        """Create and register the shared instance.

        :param disabled: forwarded to the Comet ``Experiment`` constructor
        :raises Exception: if an instance has already been registered
        """
        if SingletonObject.__instance is not None:
            raise Exception("This class is a singleton!")
        # Build the experiment first and register the instance only once that
        # succeeded; otherwise a failed initialisation would leave a
        # half-constructed object registered as the singleton forever.
        self.init_comet(disabled)
        SingletonObject.__instance = self

    def init_comet(self, disabled):
        """
        Create the Comet experiment and store it in ``self.comet_ml_experiment``.

        :param disabled: forwarded to the ``Experiment`` constructor
        :return: None
        """
        # SECURITY NOTE(review): the API key is hard-coded in source. It is left
        # unchanged here to preserve behaviour, but it should be rotated and
        # loaded from configuration or the environment instead.
        self.comet_ml_experiment = Experiment(
            api_key="S3mM1eMq6NumMxk2QJAXASkUM",
            project_name="nmt",
            workspace="ttpro1995",
            auto_output_logging="simple",
            disabled=disabled)

    def get_comet_ml_experiment(self):
        """Return the Comet experiment created by ``init_comet``."""
        return self.comet_ml_experiment

    def disable_comet(self):
        """Call ``disable_mp()`` on the experiment.

        NOTE(review): despite the method name, only ``disable_mp()`` is invoked
        here - confirm this is the intended way to "disable" Comet.
        """
        self.comet_ml_experiment.disable_mp()
Ejemplo n.º 12
0
    def init_callbacks(self):
        """Register the training callbacks on ``self.callbacks``.

        Appends a ``TensorBoard`` callback and a ``ModelCheckpoint`` configured
        from ``self.config.callbacks``. When ``self.config.api`` has a ``comet``
        attribute, a Comet ``Experiment`` is created, its id is stored in
        ``self.experiment_id`` and its Keras callback is appended as well.
        """
        callbacks_cfg = self.config.callbacks

        # Stops training if accuracy does not change at least 0.005 over 10 epochs
        # self.callbacks.append(
        #     EarlyStopping(monitor='acc', min_delta=.005, patience=10, verbose=1, mode='auto')
        # )

        tensorboard_cb = TensorBoard(
            log_dir=callbacks_cfg.tensorboard_log_dir,
            write_graph=callbacks_cfg.tensorboard_write_graph,
        )
        self.callbacks.append(tensorboard_cb)

        # The experiment name is substituted now; the epoch / val_loss fields
        # remain as format placeholders in the resulting file name pattern.
        checkpoint_path = os.path.join(
            callbacks_cfg.checkpoint_dir,
            '%s-{epoch:02d}-{val_loss:.2f}.hdf5' % self.config.exp.name,
        )
        checkpoint_cb = ModelCheckpoint(
            filepath=checkpoint_path,
            monitor=callbacks_cfg.checkpoint_monitor,
            mode=callbacks_cfg.checkpoint_mode,
            save_best_only=callbacks_cfg.checkpoint_save_best_only,
            save_weights_only=callbacks_cfg.checkpoint_save_weights_only,
            verbose=callbacks_cfg.checkpoint_verbose,
        )
        self.callbacks.append(checkpoint_cb)

        # log experiments to comet.ml (optional)
        if not hasattr(self.config.api, "comet"):
            return

        from comet_ml import Experiment
        comet_cfg = self.config.api.comet
        comet_exp = Experiment(
            api_key=comet_cfg.api_key,
            project_name=comet_cfg.exp_name)
        comet_exp.disable_mp()
        comet_exp.log_parameters(self.config.toDict())
        self.experiment_id = comet_exp.id
        self.callbacks.append(comet_exp.get_callback('keras'))
class CometML:
    """Small wrapper around a Comet ``Experiment`` stored in ``self._exp``."""

    def __init__(self, api_key, project_name, workspace, debug=True, tags=None):
        """Create the experiment.

        ``debug`` is passed to the experiment as ``disabled``. Unless ``debug``
        is set, an experiment that is not alive is treated as a connection
        failure.

        Raises:
            RuntimeError: if the experiment is not alive and ``debug`` is falsy.
        """
        comet_kwargs = dict(
            api_key=api_key,
            project_name=project_name,
            workspace=workspace,
            disabled=debug,
        )
        self._exp = Experiment(**comet_kwargs)

        if not self._exp.alive and not debug:
            raise RuntimeError("Cannot connect to Comet ML")
        self._exp.disable_mp()

        if tags is None:
            return
        self._exp.add_tags(tags)

    @property
    def run_name(self):
        """Key of the underlying experiment."""
        experiment_key = self._exp.get_key()
        return experiment_key

    def args(self, arg_text):
        """Log ``arg_text`` as the "cmd args" parameter."""
        self._exp.log_parameter("cmd args", arg_text)

    def meta(self, params):
        """Log ``params`` as experiment parameters."""
        self._exp.log_parameters(params)

    def log(self, name, value, step):
        """Log the metric ``name`` with ``value`` at ``step``."""
        self._exp.log_metric(name=name, value=value, step=step)
Ejemplo n.º 14
0
class Logger:
    """
    Logs/plots results to comet.

    Args:
        exp_config (dict): experiment configuration hyperparameters
        model_config (dict): model configuration hyperparameters
        data_config (dict): data configuration hyperparameters
    """
    def __init__(self, exp_config, model_config, data_config):
        self.exp_config = exp_config
        self.experiment = Experiment(**exp_config['comet_config'])
        self.experiment.disable_mp()
        self._log_hyper_params(exp_config, model_config, data_config)
        # Incremented once per call to log(..., 'val').
        self._epoch = 0

    @classmethod
    def _flatten_arg_dict(cls, arg_dict):
        """
        Flatten a nested dict, joining nested keys with '_'.

        Args:
            arg_dict (dict): possibly nested dictionary

        Returns:
            dict: flat dictionary, e.g. {'a': {'b': 1}} becomes {'a_b': 1}
        """
        flat_dict = {}
        for key, value in arg_dict.items():
            # isinstance (rather than an exact type check) also flattens dict
            # subclasses such as OrderedDict.
            if isinstance(value, dict):
                for sub_key, sub_value in cls._flatten_arg_dict(value).items():
                    flat_dict[key + '_' + sub_key] = sub_value
            else:
                flat_dict[key] = value
        return flat_dict

    @staticmethod
    def _latest_checkpoint_asset(asset_list):
        """
        Return the most recently created checkpoint asset.

        Args:
            asset_list (list): asset dicts with 'fileName' and 'createdAt' keys

        Returns:
            dict: the asset whose file name contains 'ckpt' with the largest
            'createdAt' value

        Raises:
            ValueError: if no asset file name contains 'ckpt'
        """
        ckpt_assets = [
            asset for asset in asset_list if 'ckpt' in asset['fileName']
        ]
        if not ckpt_assets:
            raise ValueError('No checkpoint asset found in the experiment.')
        # Select from the filtered list itself; an index computed on the
        # filtered list is not valid for the unfiltered asset list.
        return max(ckpt_assets, key=lambda asset: asset['createdAt'])

    def _log_hyper_params(self, exp_config, model_config, data_config):
        """
        Log the hyper-parameters for the experiment.

        Each configuration is flattened before being logged.

        Args:
            exp_config (dict): experiment configuration hyperparameters
            model_config (dict): model configuration hyperparameters
            data_config (dict): data configuration hyperparameters
        """
        for config in (exp_config, model_config, data_config):
            self.experiment.log_parameters(self._flatten_arg_dict(config))

    def log(self, results, train_val):
        """
        Plot the results in comet.

        Objectives, parameters and metrics are logged with a '_<train_val>'
        suffix; gradients are only logged for 'train'. The internal epoch
        counter advances after each 'val' call.

        Args:
            results (tuple): (objectives, grads, params, images, metrics), each a
                dictionary keyed by name
            train_val (str): either 'train' or 'val'
        """
        objectives, grads, params, images, metrics = results
        for metric_name, metric in objectives.items():
            self.experiment.log_metric(metric_name + '_' + train_val, metric,
                                       self._epoch)
            print(metric_name, ':', metric.item())
        if train_val == 'train':
            for grad_metric_name, grad_metric in grads.items():
                self.experiment.log_metric('grads_' + grad_metric_name,
                                           grad_metric, self._epoch)
        for param_name, param in params.items():
            self.experiment.log_metric(param_name + '_' + train_val, param,
                                       self._epoch)
        for image_name, imgs in images.items():
            self.plot_images(imgs, image_name, train_val)
        for metric_name, metric in metrics.items():
            self.experiment.log_metric(metric_name + '_' + train_val, metric,
                                       self._epoch)
        if train_val == 'val':
            self._epoch += 1

    def plot_images(self, images, title, train_val):
        """
        Plot a tensor of images.

        At most the first 10 batch entries are shown, one per grid row.

        Args:
            images (torch.Tensor): a tensor of shape [steps, b, c, h, w]
                (or [steps, b, h, w], in which case a channel axis is added)
            title (str): title for the images, e.g. reconstructions
            train_val (str): either 'train' or 'val'
        """
        # add a channel dimension if necessary
        if len(images.shape) == 4:
            s, b, h, w = images.shape
            images = images.view(s, b, 1, h, w)
        s, b, c, h, w = images.shape
        if b > 10:
            images = images[:, :10]
        # swap the steps and batch dimensions
        images = images.transpose(0, 1).contiguous()
        images = images.view(-1, c, h, w)
        # grid = make_grid(images.clamp(0, 1), nrow=s).numpy()
        grid = make_grid(images, nrow=s).numpy()
        if c == 1:
            grid = grid[0]
            cmap = 'gray'
        else:
            # channels-first -> channels-last for imshow
            grid = np.transpose(grid, (1, 2, 0))
            cmap = None
        plt.imshow(grid, cmap=cmap)
        plt.axis('off')
        self.experiment.log_figure(figure=plt,
                                   figure_name=title + '_' + train_val)
        plt.close()

    def save(self, model):
        """
        Save the model weights in comet.

        Nothing happens unless the current epoch is a multiple of
        exp_config['checkpoint_interval'].

        Args:
            model (nn.Module): the model to be saved
        """
        if self._epoch % self.exp_config['checkpoint_interval'] == 0:
            print('Checkpointing the model...')
            state_dict = model.state_dict()
            cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
            # save the state dictionary to a temporary local file
            ckpt_path = os.path.join('./ckpt_epoch_' + str(self._epoch) +
                                     '.ckpt')
            try:
                torch.save(cpu_state_dict, ckpt_path)
                self.experiment.log_asset(ckpt_path)
            finally:
                # Always clean up the temporary file, even if saving or
                # uploading failed part-way.
                if os.path.exists(ckpt_path):
                    os.remove(ckpt_path)
            print('Done.')

    def load(self, model):
        """
        Load the model weights.

        Downloads the most recent checkpoint asset of the experiment named by
        exp_config['checkpoint_exp_key'] and passes it to ``model.load``.

        Args:
            model: the model to load the weights into

        Raises:
            ValueError: if the experiment has no checkpoint asset
        """
        assert self.exp_config[
            'checkpoint_exp_key'] is not None, 'Checkpoint experiment key must be set.'
        print('Loading checkpoint from ' +
              self.exp_config['checkpoint_exp_key'] + '...')
        comet_api = comet_ml.papi.API(
            rest_api_key=self.exp_config['rest_api_key'])
        exp = comet_api.get_experiment(
            workspace=self.exp_config['comet_config']['workspace'],
            project_name=self.exp_config['comet_config']['project_name'],
            experiment=self.exp_config['checkpoint_exp_key'])
        # asset_list = comet_api.get_experiment_asset_list(self.exp_config['checkpoint_exp_key'])
        asset_list = exp.get_asset_list()
        # get most recent checkpoint
        asset = self._latest_checkpoint_asset(asset_list)
        print('Checkpoint Name:', asset['fileName'])
        ckpt = exp.get_asset(asset['assetId'])
        state_dict = torch.load(io.BytesIO(ckpt))
        # NOTE(review): relies on the model exposing a ``load`` method that
        # accepts a state dict - confirm against the model class.
        model.load(state_dict)
        print('Done.')
Ejemplo n.º 15
0
def main():
    """Parse command-line arguments, then optionally train and evaluate a model.

    The function builds the argument parser, optionally creates a Comet
    experiment (``--comet``), selects the device / distributed setup, loads the
    config, tokenizer and model for ``--model_type``, runs training when
    ``--do_train`` is set (saving the result to ``--output_dir``) and runs
    evaluation when ``--do_eval`` is set.

    Returns:
        dict: evaluation results collected over all evaluated checkpoints, with
        every key suffixed by ``"_<global_step>"``; empty when no evaluation ran.

    Raises:
        ValueError: if the output directory is non-empty while training without
            ``--overwrite_output_dir``, or if the task name is unknown.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
    )
    parser.add_argument(
        "--task_name",
        default=None,
        type=str,
        required=True,
        help="The name of the task to train selected in the list: " + ", ".join(processors.keys()),
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Other parameters
    parser.add_argument(
        "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name",
    )
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from s3",
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.",
    )
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step.",
    )
    parser.add_argument(
        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.",
    )

    # Optimization parameters
    parser.add_argument(
        "--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument(
        "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument(
        "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")

    # Logging, checkpointing and runtime parameters
    parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
    parser.add_argument("--save_steps", type=int, default=50, help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action="store_true",
        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument(
        "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory",
    )
    parser.add_argument(
        "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets",
    )
    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")

    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
    # Adversarial-training, GPU selection, Comet and dropout parameters.
    # argparse exposes the dashed names with underscores (e.g. args.adv_steps).
    parser.add_argument('--adv-lr', type=float, default=0)
    parser.add_argument('--adv-steps', type=int, default=1, help="should be at least 1")
    parser.add_argument('--adv-init-mag', type=float, default=0)
    parser.add_argument('--norm-type', type=str, default="l2", choices=["l2", "linf"])
    parser.add_argument('--adv-max-norm', type=float, default=0, help="set to 0 to be unlimited")
    parser.add_argument('--gpu', type=str, default="0")
    parser.add_argument('--expname', type=str, default="default")
    parser.add_argument('--comet', default=False, action="store_true")
    parser.add_argument('--comet_key', default="", type=str)
    parser.add_argument('--hidden_dropout_prob', type=float, default=0.1)
    parser.add_argument('--attention_probs_dropout_prob', type=float, default=0)
    args = parser.parse_args()

    # Select the visible GPU(s); these are set before any torch.cuda query below.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    # Optional Comet experiment; `experiment` stays None when --comet is not given
    # and is passed as-is to train() and evaluate() further down.
    if args.comet:
        experiment = Experiment(api_key=args.comet_key,
                                project_name="pytorch-freelb", workspace="NLP",
                                auto_param_logging=False, auto_metric_logging=False,
                                parse_args=True, auto_output_logging=True
                                )
        experiment.disable_mp()  # Turn off monkey patching
        experiment.log_parameters(vars(args))
        experiment.set_name(args.expname)
    else:
        experiment = None


    # NOTE(review): assert statements are skipped when Python runs with -O,
    # so this check is not enforced in optimized mode.
    assert args.adv_steps >= 1

    # Refuse to train into a non-empty output directory unless explicitly allowed.
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        # NOTE(review): --server_port is parsed as str and forwarded unchanged in
        # the address tuple - confirm that ptvsd accepts a string port.
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = processors[args.task_name]()
    args.output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    # The two dropout probabilities from the command line are passed into the config.
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        cache_dir=args.cache_dir if args.cache_dir else None,
        attention_probs_dropout_prob=args.attention_probs_dropout_prob,
        hidden_dropout_prob=args.hidden_dropout_prob
    )
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    # A ".ckpt" substring in the model path is treated as a TensorFlow checkpoint.
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer, experiment=experiment)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir)
        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        # The tokenizer is always reloaded from the output directory for evaluation.
        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            # Every directory below output_dir that contains a weights file is a checkpoint.
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            # With a single checkpoint global_step is "", so result keys end in "_".
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            result = evaluate(args, model, tokenizer, prefix=prefix, global_step=global_step, experiment=experiment)
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
            results.update(result)

    return results
Ejemplo n.º 16
0
                        default=1,
                        help="interval evaluations on validation set")
    parser.add_argument("--compute_map",
                        default=False,
                        help="if True computes mAP every tenth batch")
    parser.add_argument("--multiscale_training",
                        default=True,
                        help="allow for multi-scale training")
    opt = parser.parse_args()
    print(opt)

    logger = Logger("logs")
    experiment = Experiment(api_key="hs2nruoKow2CnUKisoeHccvh7",
                            project_name="yolo",
                            workspace="terbed")
    experiment.disable_mp()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs("output", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    # Get data configuration
    data_config = parse_data_config(opt.data_config)
    train_path = data_config["train"]
    valid_path = data_config["valid"]
    class_names = load_classes(data_config["names"])

    # Initiate model
    model = Darknet(opt.model_def).to(device)
    model.apply(weights_init_normal)
Ejemplo n.º 17
0
class Logger:
    """
    Logs/plots results to comet.

    Args:
        exp_config (dict): experiment configuration hyperparameters; the
            'comet_config' entry is forwarded as keyword arguments to the
            comet Experiment
        model_config (dict): model configuration hyperparameters
        data_config (dict): data configuration hyperparameters
    """
    def __init__(self, exp_config, model_config, data_config):
        self.experiment = Experiment(**exp_config['comet_config'])
        self.experiment.disable_mp()
        self._log_hyper_params(exp_config, model_config, data_config)
        # epoch counter used as the step of every logged metric
        self._epoch = 0

    def _log_hyper_params(self, exp_config, model_config, data_config):
        """
        Log the hyper-parameters for the experiment.

        Nested dictionaries are flattened, joining the keys with '_'.

        Args:
            exp_config (dict): experiment configuration hyperparameters
            model_config (dict): model configuration hyperparameters
            data_config (dict): data configuration hyperparameters
        """
        def flatten_arg_dict(arg_dict):
            flat_dict = {}
            for k, v in arg_dict.items():
                # isinstance also flattens dict subclasses such as OrderedDict
                if isinstance(v, dict):
                    for kk, vv in flatten_arg_dict(v).items():
                        flat_dict[k + '_' + kk] = vv
                else:
                    flat_dict[k] = v
            return flat_dict

        for config in (exp_config, model_config, data_config):
            self.experiment.log_parameters(flatten_arg_dict(config))

    def log(self, results, train_val):
        """
        Plot the results in comet.

        The epoch counter is advanced after a 'val' call, so a 'train' call
        followed by a 'val' call share the same step.

        Args:
            results (tuple): five dictionaries, in this order: objectives,
                grads, params, images, metrics. Objective values must provide
                ``.item()``; gradient metrics are only logged for 'train'.
            train_val (str): either 'train' or 'val'
        """
        objectives, grads, params, images, metrics = results
        for metric_name, metric in objectives.items():
            self.experiment.log_metric(metric_name + '_' + train_val, metric,
                                       self._epoch)
            print(metric_name, ':', metric.item())
        if train_val == 'train':
            for grad_metric_name, grad_metric in grads.items():
                self.experiment.log_metric('grads_' + grad_metric_name,
                                           grad_metric, self._epoch)
        for param_name, param in params.items():
            self.experiment.log_metric(param_name + '_' + train_val, param,
                                       self._epoch)
        for image_name, imgs in images.items():
            self.plot_images(imgs, image_name, train_val)
        for metric_name, metric in metrics.items():
            self.experiment.log_metric(metric_name + '_' + train_val, metric,
                                       self._epoch)
        if train_val == 'val':
            self._epoch += 1

    def plot_images(self, images, title, train_val):
        """
        Plot a tensor of images.

        At most the first 10 batch entries are shown; each row of the grid
        holds the steps of one batch entry.

        Args:
            images (torch.Tensor): a tensor of shape [steps, b, c, h, w]
                (or [steps, b, h, w] for single-channel images)
            title (str): title for the images, e.g. reconstructions
            train_val (str): either 'train' or 'val'
        """
        # add a channel dimension if necessary
        if len(images.shape) == 4:
            s, b, h, w = images.shape
            images = images.view(s, b, 1, h, w)
        s, b, c, h, w = images.shape
        if b > 10:
            images = images[:, :10]
        # swap the steps and batch dimensions
        images = images.transpose(0, 1).contiguous()
        images = images.view(-1, c, h, w)
        # detach and move to the CPU so that .numpy() also works for tensors
        # that live on an accelerator or are attached to the autograd graph
        images = images.detach().cpu()
        grid = make_grid(images.clamp(0, 1), nrow=s).numpy()
        if c == 1:
            grid = grid[0]
            cmap = 'gray'
        else:
            grid = np.transpose(grid, (1, 2, 0))
            cmap = None
        plt.imshow(grid, cmap=cmap)
        self.experiment.log_figure(figure=plt,
                                   figure_name=title + '_' + train_val)
        plt.close()

    def save(self, model):
        """
        Save the model weights in comet.

        Currently a no-op placeholder.

        Args:
            model (nn.Module): the model to be saved
        """
        pass
Ejemplo n.º 18
0
def fit_test(exp_params, data_path, k, write_path, others=None, custom_tag=''):
    """Fit model and compute metrics on both train and test sets.

    Also log plot and embeddings to comet. The embedding file is written to
    ``write_path`` temporarily and removed once it has been logged (or if
    logging fails).

    Args:
        exp_params(dict): Parameter dict. Should at least have keys model_name, dataset_name & random_state. Other
        keys are assumed to be model parameters.
        k(int): Fold identifier.
        data_path(str): Data directory.
        write_path(str): Where temp files can be written.
        others(dict): Other things to log to Comet experiment.
        custom_tag(str): Custom tag for Comet experiment.

    """
    # Increment fold to avoid reusing validation seeds
    k += 10

    # Comet experiment
    exp = Experiment(parse_args=False)
    exp.disable_mp()
    custom_tag += '_test'
    exp.add_tag(custom_tag)
    exp.log_parameters(exp_params)

    if others is not None:
        exp.log_others(others)

    # Parse experiment parameters
    model_name, dataset_name, random_state, model_params = parse_params(exp_params)

    # Fetch and split dataset.
    data_train_full = getattr(grae.data, dataset_name)(split='train', random_state=random_state, data_path=data_path)
    data_test = getattr(grae.data, dataset_name)(split='test', random_state=random_state, data_path=data_path)

    if model_name == 'PCA':
        # No validation split on PCA
        data_train, data_val = data_train_full, None
    else:
        data_train, data_val = data_train_full.validation_split(random_state=FOLD_SEEDS[k])

    # Model
    m = getattr(grae.models, model_name)(random_state=FOLD_SEEDS[k], **model_params)
    m.comet_exp = exp  # Used by DL models to log metrics between epochs
    m.write_path = write_path
    m.data_val = data_val  # For early stopping

    # Benchmark fit time
    fit_start = time.time()

    m.fit(data_train)

    fit_stop = time.time()

    fit_time = fit_stop - fit_start

    # Log plots
    m.plot(data_train, data_test, title=f'{model_name}_{dataset_name}')
    if dataset_name in ['Faces', 'RotatedDigits', 'UMIST', 'Tracking', 'COIL100', 'Teapot']:
        m.view_img_rec(data_train, choice='random', title=f'{model_name}_{dataset_name}_train_rec')
        m.view_img_rec(data_test, choice='best', title=f'{model_name}_{dataset_name}_test_rec_best')
        m.view_img_rec(data_test, choice='worst', title=f'{model_name}_{dataset_name}_test_rec_worst')
    elif dataset_name in ['ToroidalHelices', 'Mammoth'] or 'SwissRoll' in dataset_name:
        m.view_surface_rec(data_train, title=f'{model_name}_{dataset_name}_train_rec', dataset_name=dataset_name)
        m.view_surface_rec(data_test, title=f'{model_name}_{dataset_name}_test_rec', dataset_name=dataset_name)

    # Score models. The prober is fit on the full train split, i.e. including
    # the samples held out for validation during model fitting.
    prober = EmbeddingProber()
    prober.fit(model=m, dataset=data_train_full)

    with exp.train():
        train_z, train_metrics = prober.score(data_train_full)
        _, train_y = data_train_full.numpy()

        # Log train metrics
        exp.log_metric(name='fit_time', value=fit_time)
        exp.log_metrics(train_metrics)

    with exp.test():
        test_z, test_metrics = prober.score(data_test)
        _, test_y = data_test.numpy()

        # Log test metrics
        exp.log_metrics(test_metrics)

    # Log embedding as .npy file
    file_name = os.path.join(write_path, f'emb_{model_name}_{dataset_name}.npy')
    try:
        save_dict(dict(train_z=train_z,
                       train_y=train_y,
                       test_z=test_z,
                       test_y=test_y,
                       random_state=random_state,
                       dataset_name=dataset_name,
                       model_name=model_name),
                  file_name)
        # the context manager closes the handle even if logging fails
        with open(file_name, 'rb') as emb_file:
            exp.log_asset(emb_file, file_name=file_name)
    finally:
        # do not leave the temporary embedding file behind on failure
        if os.path.exists(file_name):
            os.remove(file_name)

    # Log marker to mark successful experiment
    exp.log_other('success', 1)
Ejemplo n.º 19
0
class Plotter:
    """
    Handles plotting and logging to comet.

    Args:
        exp_args (args.parse_args): arguments for the experiment
        agent_args (dict): arguments for the agent
        agent (Agent): the agent
    """
    def __init__(self, exp_args, agent_args, agent):
        self.exp_args = exp_args
        self.agent_args = agent_args
        self.agent = agent
        self.experiment = None
        if self.exp_args.plotting:
            self.experiment = Experiment(api_key=LOGGING_API_KEY,
                                         project_name=PROJECT_NAME,
                                         workspace=WORKSPACE)
            self.experiment.disable_mp()
            self.experiment.log_parameters(get_arg_dict(exp_args))
            self.experiment.log_parameters(flatten_arg_dict(agent_args))
            self.experiment.log_asset_data(json.dumps(get_arg_dict(exp_args)), name='exp_args')
            self.experiment.log_asset_data(json.dumps(agent_args), name='agent_args')
            if self.exp_args.checkpoint_exp_key is not None:
                self.load_checkpoint()
        self.result_dict = None
        # keep a hard-coded list of returns in case Comet fails
        self.returns = []

    def _plot_ts(self, key, observations, statistics, label, color):
        dim_obs = min(observations.shape[1], 9)
        k = 1
        for i in range(dim_obs):
            plt.subplot(int(str(dim_obs) + '1' + str(k)))
            observations_i = observations[:-1, i].cpu().numpy()
            if key == 'action' and self.agent.postprocess_action:
                observations_i = np.tanh(observations_i)
            plt.plot(observations_i.squeeze(), 'o', label='observation', color='k', markersize=2)
            if len(statistics) == 1:  # Bernoulli distribution
                probs = statistics['probs']
                probs = probs.cpu().numpy()
                plt.plot(probs, label=label, color=color)
            elif len(statistics) == 2:
                if 'loc' in statistics:
                    # Normal distribution
                    mean = statistics['loc']
                    std = statistics['scale']
                    mean = mean[:, i].cpu().numpy()
                    std = std[:, i].cpu().numpy()
                    mean = mean.squeeze()
                    std = std.squeeze()
                    x, plus, minus = mean, mean + std, mean - std
                    if key == 'action' and label == 'approx_post' and self.agent_args['approx_post_args']['dist_type'] in ['TanhNormal', 'TanhARNormal']:
                        # Tanh Normal distribution
                        x, plus, minus = np.tanh(x), np.tanh(plus), np.tanh(minus)
                    if key == 'action' and label == 'direct_approx_post' and self.agent_args['approx_post_args']['dist_type'] in ['TanhNormal', 'TanhARNormal']:
                        # Tanh Normal distribution
                        x, plus, minus = np.tanh(x), np.tanh(plus), np.tanh(minus)
                    if key == 'action' and label == 'prior' and self.agent_args['prior_args']['dist_type'] in ['TanhNormal', 'TanhARNormal']:
                        # Tanh Normal distribution
                        x, plus, minus = np.tanh(x), np.tanh(plus), np.tanh(minus)
                    if key == 'action' and self.agent.postprocess_action:
                        x, plus, minus = np.tanh(x), np.tanh(plus), np.tanh(minus)
                    if key == 'action' and label == 'prior' and self.agent_args['prior_args']['dist_type'] == 'NormalUniform':
                        # Normal + Uniform distribution
                        x, plus, minus = x, np.minimum(plus, 1.), np.maximum(minus, -1)
                elif 'low' in statistics:
                    # Uniform distribution
                    low = statistics['low'][:, i].cpu().numpy()
                    high = statistics['high'][:, i].cpu().numpy()
                    x = low + (high - low) / 2
                    plus, minus = x + high, x + low
                else:
                    raise NotImplementedError
                plt.plot(x, label=label, color=color)
                plt.fill_between(np.arange(len(x)), plus, minus, color=color, alpha=0.2, label=label)
            else:
                NotImplementedError
            k += 1

    def plot_states_and_rewards(self, states, rewards, step):
        """
        Log two figures for a collected episode: the states and the rewards.

        The last step of both series is left out. States get one subplot per
        dimension; only column 0 of the rewards is plotted.
        """
        # figure 1: one subplot per state dimension
        plt.figure()
        n_dims = states.shape[1]
        for dim in range(n_dims):
            plt.subplot(n_dims, 1, dim + 1)
            series = states[:-1, dim].cpu().numpy()
            plt.plot(series.squeeze(), 'o', label='state', color='k', markersize=2)
        states_name = 'states_ts_' + str(step)
        self.experiment.log_figure(figure=plt, figure_name=states_name)
        plt.close()

        # figure 2: the reward series
        plt.figure()
        reward_series = rewards[:-1, 0].cpu().numpy()
        plt.plot(reward_series.squeeze(), 'o', label='reward', color='k', markersize=2)
        rewards_name = 'rewards_ts_' + str(step)
        self.experiment.log_figure(figure=plt, figure_name=rewards_name)
        plt.close()

    def plot_episode(self, episode, step):
        """
        Plots a newly collected episode.
        """
        if self.exp_args.plotting:
            self.experiment.log_metric('cumulative_reward', episode['reward'].sum(), step)

            def merge_legends():
                handles, labels = plt.gca().get_legend_handles_labels()
                newLabels, newHandles = [], []
                for handle, label in zip(handles, labels):
                    if label not in newLabels:
                        newLabels.append(label)
                        newHandles.append(handle)

                plt.legend(newHandles, newLabels)

            for k in episode['distributions'].keys():
                for i, l in enumerate(episode['distributions'][k].keys()):
                    color = COLORS[i]
                    self._plot_ts(k, episode[k], episode['distributions'][k][l], l, color)
                plt.suptitle(k)
                merge_legends()
                self.experiment.log_figure(figure=plt, figure_name=k + '_ts_'+str(step))
                plt.close()

            self.plot_states_and_rewards(episode['state'], episode['reward'], step)

    def log_eval(self, episode, eval_states, step):
        """
        Plots an evaluation episode performance. Logs the episode.

        Args:
            episode (dict): dictionary containing agent's collected episode
            eval_states (dict): dictionary of MuJoCo simulator states
            step (int): the current step number in training
        """
        # plot and log eval returns
        eval_return = episode['reward'].sum()
        print(' Eval. Return at Step ' + str(step) + ': ' + str(eval_return.item()))
        self.returns.append(eval_return.item())
        if self.exp_args.plotting:
            self.experiment.log_metric('eval_cumulative_reward', eval_return, step)
            json_str = json.dumps(self.returns)
            self.experiment.log_asset_data(json_str, name='eval_returns', overwrite=True)

            # log the episode itself
            for ep_item_str in ['state', 'action', 'reward']:
                ep_item = episode[ep_item_str].tolist()
                json_str = json.dumps(ep_item)
                item_name = 'episode_step_' + str(step) + '_' + ep_item_str
                self.experiment.log_asset_data(json_str, name=item_name)

            # log the MuJoCo simulator states
            for sim_item_str in ['qpos', 'qvel']:
                if len(eval_states[sim_item_str]) > 0:
                    sim_item = eval_states[sim_item_str].tolist()
                    json_str = json.dumps(sim_item)
                    item_name = 'episode_step_' + str(step) + '_' + sim_item_str
                    self.experiment.log_asset_data(json_str, name=item_name)

    def plot_agent_kl(self, agent_kl, step):
        if self.exp_args.plotting:
            self.experiment.log_metric('agent_kl', agent_kl, step)

    def log_results(self, results):
        """
        Accumulate a results dictionary.

        The dictionary is flattened and every value is appended to the list
        kept under its flattened key in ``self.result_dict``.
        """
        if self.result_dict is None:
            self.result_dict = {}
        history = self.result_dict
        for name, value in flatten_arg_dict(results).items():
            # create the list on first sight of a key, then append
            history.setdefault(name, []).append(value)

    def plot_results(self, timestep):
        """
        Plot/log the results to Comet.
        """
        if self.exp_args.plotting:
            for k, v in self.result_dict.items():
                avg_value = np.mean(v)
                self.experiment.log_metric(k, avg_value, timestep)
        self.result_dict = None

    def plot_model_eval(self, episode, predictions, log_likelihoods, step):
        """
        Plot/log the results from model evaluation.

        For each variable in ``log_likelihoods`` this logs the mean log-likelihood
        at the final rollout step as a metric and a figure of the mean +/- std
        log-likelihood per rollout step. It then picks one random time step and,
        for each variable in ``predictions``, logs a figure comparing the predicted
        loc +/- scale with the values recorded in the episode. Does nothing unless
        ``exp_args.plotting`` is set.

        Args:
            episode (dict): a collected episode
            predictions (dict): predictions from each state, containing [n_steps, horizon, n_dims]
            log_likelihoods (dict): log-likelihood evaluations of predictions, containing [n_steps, horizon, 1]
            step: the step passed along with metrics and appended to figure names
        """
        if not self.exp_args.plotting:
            return

        for name, ll in log_likelihoods.items():
            # scalar summary: mean log-likelihood at the end of the horizon
            final_ll = ll[:, -1].mean().item()
            self.experiment.log_metric(name + '_pred_log_likelihood', final_ll, step)

            # log-likelihood as a function of rollout step, with a +/- std band
            plt.figure()
            ll_mean = ll.mean(dim=0).view(-1)
            ll_std = ll.std(dim=0).view(-1)
            plt.plot(ll_mean.numpy())
            band_low = ll_mean - ll_std
            band_high = ll_mean + ll_std
            rollout_steps = np.arange(ll.shape[1])
            plt.fill_between(rollout_steps, band_low.numpy(), band_high.numpy(), alpha=0.2)
            plt.xlabel('Rollout Step')
            plt.ylabel('Prediction Log-Likelihood')
            plt.xticks(rollout_steps)
            ll_figure_name = name + '_pred_ll_' + str(step)
            self.experiment.log_figure(figure=plt, figure_name=ll_figure_name)
            plt.close()

        # compare predictions against recorded values at one randomly chosen time step
        t = np.random.randint(predictions['state']['loc'].shape[0])
        for name, pred in predictions.items():
            loc = pred['loc'][t]
            scale = pred['scale'][t]
            # recorded values for the steps following t, up to the prediction horizon
            actual = episode[name][t + 1:t + 1 + loc.shape[0]]
            plt.figure()
            horizon, n_dims = loc.shape
            rollout_steps = np.arange(horizon)
            # one stacked subplot per predicted dimension
            for subplot_idx in range(1, n_dims + 1):
                dim = subplot_idx - 1
                plt.subplot(n_dims, 1, subplot_idx)
                plt.plot(loc[:, dim].numpy())
                band_low = loc[:, dim] - scale[:, dim]
                band_high = loc[:, dim] + scale[:, dim]
                plt.fill_between(rollout_steps, band_low.numpy(), band_high.numpy(), alpha=0.2)
                plt.plot(actual[:, dim].numpy(), '.')
            plt.xlabel('Rollout Step')
            plt.xticks(rollout_steps)
            pred_figure_name = name + '_pred_' + str(step)
            self.experiment.log_figure(figure=plt, figure_name=pred_figure_name)
            plt.close()

    def save_checkpoint(self, step):
        """
        Checkpoint the model by getting the state dictionary for each component.

        The agent's state dictionary is moved to the CPU, written to a temporary
        file in the current working directory, logged to the experiment with
        ``log_asset`` and then deleted. Does nothing unless ``exp_args.plotting``
        is set.

        Args:
            step: the step used to name the checkpoint file
        """
        if not self.exp_args.plotting:
            return
        print('Checkpointing the agent...')
        state_dict = self.agent.state_dict()
        cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
        ckpt_path = './ckpt_step_' + str(step) + '.ckpt'
        try:
            torch.save(cpu_state_dict, ckpt_path)
            self.experiment.log_asset(ckpt_path)
        finally:
            # always remove the temporary file, even when saving or logging fails,
            # so failed checkpoints do not pile up in the working directory
            if os.path.exists(ckpt_path):
                os.remove(ckpt_path)
        print('Done.')

    def load_checkpoint(self, timestep=None):
        """
        Loads a checkpoint from Comet.

        Delegates to the module-level ``load_checkpoint`` function, passing the
        agent and the experiment key stored in ``exp_args.checkpoint_exp_key``.

        Args:
            timestep (int, optional): the checkpoint timestep, default is latest
        """
        agent = self.agent
        exp_key = self.exp_args.checkpoint_exp_key
        load_checkpoint(agent, exp_key, timestep)