Example #1

def train(net, sample, epoch):
    # Re-initialize the hidden state for every sample.
    hidden = net.initHidden()

    # Clear gradients accumulated by the previous sample.
    optimizer.zero_grad()

    # Feed the sequence one timestep at a time, carrying the hidden state forward.
    for i in range(sample[0].shape[0]):
        output, hidden = net(sample[0][i, :], hidden)

    # The target is the last element of the label sequence, reshaped to (1, 1).
    loss = criterion(output, sample[1][-1, None, None].to(torch.float))
    loss.backward()
    optimizer.step()

    tb.add_scalar('loss', loss.item(), epoch)
    tb.add_histogram('i2h.weight', net.i2h.weight, epoch)

    return output, loss.item()


for step, s in enumerate(tqdm(samples)):
    train(net, s, step)

tb.flush()
tb.close()
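
The snippet above assumes `net`, `criterion`, `optimizer`, `tb`, and `samples` already exist. A minimal setup sketch, where the model class, loss choice, and log directory are assumptions rather than part of the original:

# Setup sketch: MiniRNN, the loss choice, and the log dir are assumptions.
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

class MiniRNN(nn.Module):
    """A tiny RNN cell exposing the interface train() uses: initHidden() and i2h."""
    def __init__(self, input_size=8, hidden_size=16):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.h2o = nn.Linear(hidden_size, 1)

    def forward(self, x, hidden):
        combined = torch.cat((x.view(1, -1), hidden), dim=1)
        hidden = torch.tanh(self.i2h(combined))
        return self.h2o(hidden), hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

net = MiniRNN()
criterion = nn.MSELoss()  # assumed; the original loss is not shown
optimizer = torch.optim.SGD(net.parameters(), lr=0.005)
tb = SummaryWriter('runs/rnn_example')
# samples: iterable of (sequence, target) tensor pairs, e.g. shapes (T, 8) and (T,)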
Example #2
            summ = SummaryWriter(log_val_str)

            # Compute metrics at different thresholds
            val_metric, metrics, _ = eval_proced(preds, true, vd_high_idxs,
                                                 vd_low_idxs, 'val')

            if val_metric > best_value:
                print('New best model found')
                best_value = val_metric
                best_config = config

            # Logging hyperparams and metrics
            hparams = {**config, 'trait': trait, 'seed': seed}
            summ.add_hparams(hparams, metrics)
            summ.flush()

        # Testing the best configuration
        A_test = sp.csr_matrix(sp.vstack((sp_tr_data, sp_te_tr_data)))
        B_test = EASE(A_test, best_config['lam'])

        Atild_test = sp.csr_matrix(A_test.dot(B_test))
        Atild_test = Atild_test[sp_tr_data.shape[0]:, :]
        # Removing entries from training data
        Atild_test[sp_te_tr_data.nonzero()] = 0.0

        preds = Atild_test.toarray()
        true = sp_te_te_data.toarray()

        summ = SummaryWriter(log_te_str)
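
The excerpt ends just as the test writer is created; a plausible continuation, mirroring the validation logging above (`te_high_idxs`/`te_low_idxs` are assumed test-set counterparts of the `vd_*` index arrays):

        # Hypothetical continuation, mirroring the validation block above.
        test_metric, metrics, _ = eval_proced(preds, true, te_high_idxs,
                                              te_low_idxs, 'test')

        hparams = {**best_config, 'trait': trait, 'seed': seed}
        summ.add_hparams(hparams, metrics)
        summ.flush()
        summ.close()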
Example #3
class TensorBoardCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
    <https://www.tensorflow.org/tensorboard>`__.

    Args:
        tb_writer (:obj:`SummaryWriter`, `optional`):
            The writer to use. Will instantiate one if not set.
    """

    def __init__(self, tb_writer=None):
        assert (
            _has_tensorboard
        ), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        log_dir = log_dir or args.logging_dir
        self.tb_writer = SummaryWriter(log_dir=log_dir)

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return

        log_dir = None

        if state.is_hyper_param_search:
            trial_name = state.trial_name
            if trial_name is not None:
                log_dir = os.path.join(args.logging_dir, trial_name)

        self._init_summary_writer(args, log_dir)

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", args.to_json_string())
            if "model" in kwargs:
                model = kwargs["model"]
                if hasattr(model, "config") and model.config is not None:
                    model_config_json = model.config.to_json_string()
                    self.tb_writer.add_text("model_config", model_config_json)
            # Version of TensorBoard coming from tensorboardX does not have this method.
            if hasattr(self.tb_writer, "add_hparams"):
                self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero:
            return

        if self.tb_writer is None:
            self._init_summary_writer(args)

        if self.tb_writer is not None:
            logs = rewrite_logs(logs)
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        if self.tb_writer:
            self.tb_writer.close()
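
A minimal usage sketch, assuming a transformers version whose `Trainer` accepts a `callbacks` argument (`model`, `training_args`, and `train_ds` are placeholders):

# Sketch only: model, training_args, and train_ds are placeholders.
from torch.utils.tensorboard import SummaryWriter

tb_writer = SummaryWriter(log_dir="runs/my_experiment")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    callbacks=[TensorBoardCallback(tb_writer)],  # reuse an existing writer
)
trainer.train()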
Example #4
    def fit(self,
            series: TimeSeries,
            val_series: Optional[TimeSeries] = None,
            verbose: bool = False) -> None:
        """
        :param series: The training time series
        :param val_series: Optionally, a validation time series that will
                           be used to compute validation loss throughout training
        """

        super().fit(series)

        if self.from_scratch:
            shutil.rmtree(_get_checkpoint_folder(self.work_dir,
                                                 self.model_name),
                          ignore_errors=True)

        if self.batch_size is None:
            self.batch_size = len(series) // 10
            print('No batch size set. Using: {}'.format(self.batch_size))

        # Prepare training data:
        dataset = TimeSeriesDataset1D(series, self.seq_len, self.output_length)
        train_loader = DataLoader(dataset,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  num_workers=0,
                                  pin_memory=True,
                                  drop_last=True)
        raise_if_not(
            len(train_loader) > 0,
            'The provided training time series is too short for obtaining even one training point.',
            logger)

        # Prepare validation data:
        if val_series is not None:
            val_dataset = TimeSeriesDataset1D(val_series, self.seq_len,
                                              self.output_length)
            val_loader = DataLoader(val_dataset,
                                    batch_size=self.batch_size,
                                    shuffle=False,
                                    num_workers=0,
                                    pin_memory=True,
                                    drop_last=False)
            raise_if_not(
                len(val_dataset) > 0 and len(val_loader) > 0,
                'The provided validation time series is too short for this model output length.',
                logger)
        else:
            val_loader = None

        # Tensorboard
        runs_folder = _get_runs_folder(self.work_dir, self.model_name)
        if self.log_tensorboard:
            if self.from_scratch:
                shutil.rmtree(runs_folder, ignore_errors=True)
                tb_writer = SummaryWriter(runs_folder)
                dummy_input = torch.empty(self.batch_size, self.seq_len,
                                          self.input_size).to(self.device)
                tb_writer.add_graph(self.model, dummy_input)
            else:
                tb_writer = SummaryWriter(runs_folder,
                                          purge_step=self.start_epoch)
        else:
            tb_writer = None

        self._train(train_loader, val_loader, tb_writer, verbose)

        if tb_writer is not None:
            tb_writer.flush()
            tb_writer.close()
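
A usage sketch, assuming this `fit` belongs to a darts-style torch forecasting model (the class name and input file are placeholders):

# Sketch only: the model class and the input file are assumptions.
import pandas as pd
from darts import TimeSeries

df = pd.read_csv('my_series.csv')
series = TimeSeries.from_dataframe(df, 'date', 'value')
train_series, val_series = series.split_after(0.8)  # fraction split (recent darts versions)

model = MyForecastingModel(log_tensorboard=True)  # placeholder class defining fit()
model.fit(train_series, val_series=val_series, verbose=True)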
Example #5
class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for 🤗 Transformers.

    Args:
        model (:class:`~transformers.PreTrainedModel`):
            The model to train, evaluate or use for predictions.
        args (:class:`~transformers.TrainingArguments`):
            The arguments to tweak training.
        data_collator (:obj:`DataCollator`, `optional`, defaults to :func:`~transformers.default_data_collator`):
            The function to use to form a batch from a list of elements of :obj:`train_dataset` or
            :obj:`eval_dataset`.
        train_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for training.
        eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
            The dataset to use for evaluation.
        compute_metrics (:obj:`Callable[[EvalPrediction], Dict]`, `optional`):
            The function that will be used to compute metrics at evaluation. Must take a
            :class:`~transformers.EvalPrediction` and return a dictionary mapping strings to metric values.
        prediction_loss_only (:obj:`bool`, `optional`, defaults to `False`):
            When performing evaluation and predictions, only returns the loss.
        tb_writer (:obj:`SummaryWriter`, `optional`):
            Object to write to TensorBoard.
        optimizers (:obj:`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, `optional`):
            A tuple containing the optimizer and the scheduler to use. Will default to an instance of
            :class:`~transformers.AdamW` on your model and a scheduler given by
            :func:`~transformers.get_linear_schedule_with_warmup` controlled by :obj:`args`.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only=False,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        tokenizer = None
    ):
        self.model = model.to(args.device)
        self.args = args
        self.data_collator = data_collator if data_collator is not None else default_data_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
        self.optimizer, self.lr_scheduler = optimizers
        self.tb_writer = tb_writer
        self.tokenizer = tokenizer
        if tb_writer is None and is_tensorboard_available() and self.is_world_process_zero():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
        if is_wandb_available():
            self.setup_wandb()
        elif os.environ.get("WANDB_DISABLED") != "true":
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
        if is_comet_available():
            self.setup_comet()
        elif os.environ.get("COMET_MODE") != "DISABLED":
            logger.info(
                "To use comet_ml logging, run `pip/conda install comet_ml` "
                "see https://www.comet.ml/docs/python-sdk/huggingface/"
            )
        set_seed(self.args.seed)
        # Create output directory if needed
        if self.is_world_process_zero():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available():
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant solution and won't need to do this in the future.
            self.model.config.xla_device = True
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            self.data_collator = self.data_collator.collate_batch
            warnings.warn(
                (
                    "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
                    + "with a `collate_batch` are deprecated and won't be supported in a future version."
                ),
                FutureWarning,
            )
        self.global_step = None
        self.epoch = None
        if self.args.fp16 and _use_native_amp:
            self.scaler = torch.cuda.amp.GradScaler()

    def _get_train_sampler(self) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(self.train_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return get_tpu_sampler(self.train_dataset)
        else:
            return (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.train_dataset` is a :obj:`torch.utils.data.IterableDataset`, a random sampler
        (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_sampler = self._get_train_sampler()

        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.sampler.Sampler]:
        if isinstance(eval_dataset, torch.utils.data.IterableDataset):
            return None
        elif is_torch_tpu_available():
            return SequentialDistributedSampler(eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
        elif self.args.local_rank != -1:
            return SequentialDistributedSampler(eval_dataset)
        else:
            return SequentialSampler(eval_dataset)

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`self.eval_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
        sampler (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                If provided, will override :obj:`self.eval_dataset`.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        eval_sampler = self._get_eval_sampler(eval_dataset)

        return DataLoader(
            eval_dataset,
            sampler=eval_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test :class:`~torch.utils.data.DataLoader`.

        Will use no sampler if :obj:`test_dataset` is a :obj:`torch.utils.data.IterableDataset`, a sequential
        sampler (adapted to distributed training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (:obj:`torch.utils.data.dataset.Dataset`):
                The test dataset to use.
        """
        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through :obj:`optimizers`, or subclass and override this method in a subclass.
        """
        if self.optimizer is None:
            no_decay = ["bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                    "weight_decay": 0.0,
                },
            ]
            self.optimizer = AdamW(
                optimizer_grouped_parameters,
                lr=self.args.learning_rate,
                betas=(self.args.adam_beta1, self.args.adam_beta2),
                eps=self.args.adam_epsilon,
            )
        if self.lr_scheduler is None:
            self.lr_scheduler = get_linear_schedule_with_warmup(
                self.optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
            )

    def setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        `here <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
                or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
        """
        if hasattr(self, "_setup_wandb"):
            warnings.warn(
                "The `_setup_wandb` method is deprecated and won't be called in a future version, define `setup_wandb` in your subclass.",
                FutureWarning,
            )
            return self._setup_wandb()

        if self.is_world_process_zero():
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
            wandb.init(
                project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
            )
            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                wandb.watch(
                    self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
                )

    def setup_comet(self):
        """
        Setup the optional Comet.ml integration.

        Environment:
            COMET_MODE:
                (Optional): str - "OFFLINE", "ONLINE", or "DISABLED"
            COMET_PROJECT_NAME:
                (Optional): str - Comet.ml project name for experiments
            COMET_OFFLINE_DIRECTORY:
                (Optional): str - folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"

        For a number of configurable items in the environment,
        see `here <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__
        """
        if self.is_world_master():
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            experiment = None
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**args)
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**args)
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            if experiment is not None:
                experiment._set_model_graph(self.model, framework="transformers")
                experiment._log_parameters(self.args, prefix="args/", framework="transformers")
                experiment._log_parameters(self.model.config, prefix="config/", framework="transformers")

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
        """
        return len(dataloader.dataset)

    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (
                self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
            )
        else:
            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        self.create_optimizer_and_scheduler(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            self.optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16 and _use_apex:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(
            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_process_zero()
        )
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_process_zero())
            else:
                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_process_zero())

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self.training_step(model, inputs)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    # Use elif here: on TPU, xm.optimizer_step already performs the
                    # update, so the optimizer must not be stepped a second time.
                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)
                    elif self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    self.lr_scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            self.lr_scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else self.lr_scheduler.get_lr()[0]
                        )
                        logging_loss = tr_loss

                        self.log(logs)

                    if self.args.evaluate_during_training and self.global_step % self.args.eval_steps == 0:
                        self.evaluate()

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert (
                                model.module is self.model
                            ), f"Module {model.module} should be a reference to self.model"
                        else:
                            assert model is self.model, f"Model {model} should be a reference to self.model"
                        # Save model checkpoint
                        output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)
                        if self.tokenizer is not None:
                            self.tokenizer.save_vocabulary(os.path.join(output_dir, 'vocab.txt'))

                        if self.is_world_process_zero():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_process_zero():
                            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug or self.args.debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return TrainOutput(self.global_step, tr_loss / self.global_step)

    def log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        """
        Log :obj:`logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (:obj:`Dict[str, float]`):
                The values to log.
            iterator (:obj:`tqdm`, `optional`):
                A potential tqdm progress bar to write the logs on.
        """
        if hasattr(self, "_log"):
            warnings.warn(
                "The `_log` method is deprecated and won't be called in a future version, define `log` in your subclass.",
                FutureWarning,
            )
            return self._log(logs, iterator=iterator)

        # if self.epoch is not None:
        #     logs["epoch"] = self.epoch
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_process_zero():
                wandb.log(logs, step=self.global_step)
        if is_comet_available():
            if self.is_world_process_zero():
                experiment = comet_ml.config.get_global_experiment()
                if experiment is not None:
                    experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
        output = {**logs, **{"step": self.global_step}}
        if iterator is not None:
            iterator.write(output)
        else:
            print(output)

    def _prepare_inputs(
        self, inputs: Dict[str, Union[torch.Tensor, Any]], model: nn.Module
    ) -> Dict[str, Union[torch.Tensor, Any]]:
        """
        Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
        handling potential state.
        """
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        if self.args.past_index >= 0 and self._past is not None:
            inputs["mems"] = self._past

        return inputs

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> float:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`float`: The training loss on this batch.
        """
        if hasattr(self, "_training_step"):
            warnings.warn(
                "The `_training_step` method is deprecated and won't be called in a future version, define `training_step` in your subclass.",
                FutureWarning,
            )
            return self._training_step(model, inputs, self.optimizer)

        model.train()
        inputs = self._prepare_inputs(inputs, model)

        if self.args.fp16 and _use_native_amp:
            with autocast():
                outputs = model(**inputs)
                loss = outputs[0]
        else:
            outputs = model(**inputs)
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs[0]

        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16 and _use_native_amp:
            self.scaler.scale(loss).backward()
        elif self.args.fp16 and _use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()

    def is_local_master(self) -> bool:
        """
        Whether or not this process is the local main process (e.g., the main process on one machine
        when training in a distributed fashion on several machines).

        .. warning::

            This method is deprecated, use :meth:`~transformers.Trainer.is_local_process_zero` instead.
        """
        warnings.warn("This method is deprecated, use `Trainer.is_local_process_zero()` instead.", FutureWarning)
        return self.is_local_process_zero()

    def is_local_process_zero(self) -> bool:
        """
        Whether or not this process is the local main process (e.g., the main process on one machine
        when training in a distributed fashion on several machines).
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

    def is_world_master(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).

        .. warning::

            This method is deprecated, use :meth:`~transformers.Trainer.is_world_process_zero` instead.
        """
        warnings.warn("This method is deprecated, use `Trainer.is_world_process_zero()` instead.", FutureWarning)
        return self.is_world_process_zero()

    def is_world_process_zero(self) -> bool:
        """
        Whether or not this process is the global main process (when training in a distributed fashion on
        several machines, this is only going to be :obj:`True` for one process).
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0

    def save_model(self, output_dir: Optional[str] = None):
        """
        Will save the model, so you can reload it using :obj:`from_pretrained()`.

        Will only save from the world_master process (unless in TPUs).
        """

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_process_zero():
            self._save(output_dir)

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")

        xm.rendezvous("saving_checkpoint")
        self.model.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")
        self.model.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`.

        Returns:
            A dictionary containing the evaluation loss and the potential metrics computed from the predictions.
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self.prediction_loop(eval_dataloader, description="Evaluation")

        self.log(output.metrics)

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics

    def evaluate_return_all(self, eval_dataset: Optional[Dataset] = None):
        """
        Run evaluation and returns metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent (pass it to the init :obj:`compute_metrics` argument).

        You can also subclass and override this method to inject custom behavior.

        Args:
            eval_dataset (:obj:`Dataset`, `optional`):
                Pass a dataset if you wish to override :obj:`self.eval_dataset`.

        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self.prediction_loop(eval_dataloader, description="Evaluation")

        self.log(output.metrics)

        if self.args.tpu_metrics_debug or self.args.debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.predictions

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and returns predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in :obj:`evaluate()`.

        Args:
            test_dataset (:obj:`Dataset`):
                Dataset to run the predictions on.

        Returns:
            `NamedTuple`:
            predictions (:obj:`np.ndarray`):
                The predictions on :obj:`test_dataset`.
            label_ids (:obj:`np.ndarray`, `optional`):
                The labels (if the dataset contained some).
            metrics (:obj:`Dict[str, float]`, `optional`):
                The potential dictionary of metrics (if the dataset contained labels).
        """
        test_dataloader = self.get_test_dataloader(test_dataset)

        return self.prediction_loop(test_dataloader, description="Prediction")

    def prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Works both with or without labels.
        """
        if hasattr(self, "_prediction_loop"):
            warnings.warn(
                "The `_prediction_loop` method is deprecated and won't be called in a future version, define `prediction_loop` in your subclass.",
                FutureWarning,
            )
            return self._prediction_loop(dataloader, description, prediction_loss_only=prediction_loss_only)

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: Optional[torch.Tensor] = None
        label_ids: Optional[torch.Tensor] = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        for inputs in tqdm(dataloader, desc=description):
            loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
            if loss is not None:
                eval_losses.append(loss)
            if logits is not None:
                preds = logits if preds is None else torch.cat((preds, logits), dim=0)
            if labels is not None:
                label_ids = labels if label_ids is None else torch.cat((label_ids, labels), dim=0)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()

        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)

    def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
        assert self.args.local_rank != -1

        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        output = concat[:num_total_examples]
        return output

    def prediction_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on :obj:`model` using obj:`inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
            A tuple with the loss, logits and labels (each being optional).
        """
        has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

        inputs = self._prepare_inputs(inputs, model)

        with torch.no_grad():
            outputs = model(**inputs)
            if has_labels:
                loss, logits = outputs[:2]
                loss = loss.mean().item()
            else:
                loss = None
                logits = outputs[0]
            if self.args.past_index >= 0:
                self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]

        if prediction_loss_only:
            return (loss, None, None)

        labels = inputs.get("labels")
        if labels is not None:
            labels = labels.detach()
        return (loss, logits.detach(), labels)
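
A minimal end-to-end sketch for driving this Trainer variant (model, datasets, and tokenizer are placeholders, not taken from the original):

# Sketch only: model/dataset/tokenizer construction is assumed.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='out',
    num_train_epochs=1,
    per_device_train_batch_size=8,
    logging_steps=50,
)
trainer = Trainer(
    model=model,             # a PreTrainedModel instance
    args=training_args,
    train_dataset=train_ds,  # torch Dataset yielding dicts of tensors
    eval_dataset=eval_ds,
    tokenizer=tokenizer,     # used by the vocabulary-saving step in train()
)
trainer.train()
metrics = trainer.evaluate()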
Example #6
def train(dataset,
          max_iter,
          ckpt_path,
          save_iter=5000,
          lr=0.0002,
          batch_size=64,
          manual_seed=None,
          cuda=True,
          resume=True):
    if manual_seed is None:
        manual_seed = random.randint(1, 10000)
    print("Random Seed: ", manual_seed)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=2)

    nz = 100
    netG = Generator(nz=nz)
    netD = Discriminator()
    criterion = nn.BCELoss()

    if cuda:
        cudnn.benchmark = True
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    netG.to(device)
    netD.to(device)

    # setup optimizer
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

    start_iter = 0

    writer = SummaryWriter()

    if resume and os.path.exists(ckpt_path):
        # ckpt = torch.load(ckpt_path, map_location='cpu')
        ckpt = torch.load(ckpt_path)
        start_iter = ckpt['iteration']
        netG.load_state_dict(ckpt['netG'])
        netD.load_state_dict(ckpt['netD'])
        optimizerG.load_state_dict(ckpt['optimizerG'])
        optimizerD.load_state_dict(ckpt['optimizerD'])
    else:
        netG.apply(weights_init)
        netD.apply(weights_init)

    fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)
    # Use float labels: BCELoss expects the target to be a float tensor.
    real_label = 1.
    fake_label = 0.

    dataloader_iter = iter(dataloader)
    for iteration in range(start_iter, max_iter):
        try:
            data = next(dataloader_iter)
        except StopIteration:
            # Restart the loader once it is exhausted.
            dataloader_iter = iter(dataloader)
            data = next(dataloader_iter)

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        label = torch.full((real_cpu.size(0), ), real_label, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(real_cpu.size(0), nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        if iteration % 200 == 0:
            print('%d/%d errD_real:%.2e errD_fake:%.2e errG:%.2e' %
                  (iteration, max_iter, errD_real.item(), errD_fake.item(),
                   errG.item()))

        if iteration % 1000 == 0:
            writer.add_scalar('anime-dcgan/errD_real', errD_real.item(),
                              iteration)
            writer.add_scalar('anime-dcgan/errD_fake', errD_fake.item(),
                              iteration)
            writer.add_scalar('anime-dcgan/errG', errG.item(), iteration)

            netG.eval()
            with torch.no_grad():
                grid = vutils.make_grid(float2byte(real_cpu),
                                        range=(0, 255),
                                        scale_each=True)
                writer.add_image('real', grid, iteration)
                fake = netG(fixed_noise)
                grid = vutils.make_grid(float2byte(fake),
                                        range=(0, 255),
                                        scale_each=True)
                writer.add_image('fixed_fake', grid, iteration)
            netG.train()

            writer.flush()

        if iteration > 0 and iteration % save_iter == 0:
            save(netD, netG, optimizerD, optimizerG, iteration, ckpt_path)
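
The training function can be driven with a standard torchvision dataset; a hedged sketch where the image folder and transforms are assumptions:

# Sketch only: dataset location and preprocessing are assumptions.
import torchvision.datasets as dset
import torchvision.transforms as transforms

dataset = dset.ImageFolder(
    root='data/anime_faces',
    transform=transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]))

train(dataset, max_iter=50000, ckpt_path='checkpoints/dcgan.pt')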


class Logger(object):
    def __init__(self, args):
        self.args = args
        if args.plot:
            self.plot_path = os.path.join(args.plot_dir, args.plot_name)
            self.writer = SummaryWriter(self.plot_path)
        self.logs = dict()

    def print(self, msg):
        if self.args.local_rank == 0:
            print(msg)
        if self.args.plot:
            self.writer.add_text("stdout", msg)

    def set_state(self, state):
        self.logs = state

    def get_state(self):
        return self.logs

    def log(self, title, value):
        if title not in self.logs:
            self.logs[title] = {"data": [], "type": "line"}
        self.logs[title]["data"].append(value)

    def plot_step(self, step):
        for title, v in self.logs.items():
            if v["type"] == "line" and title != "X":
                self.writer.add_scalar(
                    title, v["data"][step], self.logs["X"]["data"][step]
                )

    def plot_line(self, title, vals, X=None):
        if torch.is_tensor(vals):
            vals = vals.detach().cpu().numpy()
        if X is None:
            X = range(len(vals))
        if torch.is_tensor(X):
            X = X.detach().cpu().numpy()
        fig = plt.figure()
        plt.plot(X, vals)
        self.writer.add_figure(title, fig)

    def plot_bar(self, title, vals, X=None):
        if torch.is_tensor(vals):
            vals = vals.detach().cpu().numpy()
        if X is None:
            X = range(len(vals))
        if torch.is_tensor(X):
            X = X.detach().cpu().numpy()
        fig = plt.figure()
        plt.bar(X, vals)
        self.writer.add_figure(title, fig)

    def plot_heatmap(self, title, vals):
        if torch.is_tensor(vals):
            vals = vals.detach().cpu().numpy()
        fig = plt.figure()
        plt.imshow(vals, cmap="hot", interpolation="nearest")
        self.writer.add_figure(title, fig)

    def step(self, args, stat_train, stat_val, elapsed, gpu_mem):
        if "err" in stat_train:
            print(
                "{}\ttrain: {:.2f}%\tval: {:.2f}%\tms/batch: {:.1f}\tgpu_mem: {:.1f}gb".format(
                    (args.ep + 1) * args.nbatches // args.update_freq,
                    stat_train["err"] * 100,
                    stat_val["err"] * 100,
                    elapsed,
                    gpu_mem,
                )
            )
            self.log("loss/train", stat_train["loss"])
            self.log("loss/val", stat_val["loss"])
            self.log("err/train", stat_train["err"])
            self.log("err/val", stat_val["err"])
        elif args.data_type == "char":
            print(
                "{}\ttrain: {:.2f}bpc\tval: {:.2f}bpc\tms/batch: {:.1f}\tgpu_mem: {:.1f}gb".format(
                    (args.ep + 1) * args.nbatches // args.update_freq,
                    stat_train["loss"] / math.log(2),
                    stat_val["loss"] / math.log(2),
                    elapsed,
                    gpu_mem,
                )
            )
            self.log("loss/train", stat_train["loss"] / math.log(2))
            self.log("loss/val", stat_val["loss"] / math.log(2))
        else:
            train_ppl = math.exp(min(stat_train["loss"], 30))  # avoid overflow
            val_ppl = math.exp(min(stat_val["loss"], 30))  # avoid overflow
            print(
                "{}\ttrain_ppl: {:.1f}\tval_ppl: {:.1f}\tms/batch: {:.1f}\tgpu_mem: {:.1f}gb".format(
                    (args.ep + 1) * args.nbatches // args.update_freq,
                    train_ppl,
                    val_ppl,
                    elapsed,
                    gpu_mem,
                )
            )
            self.log("loss/train", stat_train["loss"])
            self.log("loss/val", stat_val["loss"])
            self.log("loss/ppl_train", train_ppl)
            self.log("loss/ppl_val", val_ppl)
        self.log("X", (args.ep + 1) * args.nbatches // args.update_freq)

        if args.plot:
            self.log("compute/gpu_mem_gb", gpu_mem)
            self.log("compute/batch_time_ms", elapsed)
            self.plot_step(-1)
            self.writer.flush()
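A minimal usage sketch for this Logger, assuming an args namespace carrying the fields it reads (plot, plot_dir, plot_name, local_rank); the values here are illustrative only:

from argparse import Namespace

args = Namespace(plot=True, plot_dir='runs', plot_name='exp1', local_rank=0)
logger = Logger(args)
logger.print('starting run')   # stdout on rank 0, plus a TensorBoard text entry
logger.log('loss/train', 0.7)  # buffered until plot_step is called
logger.log('X', 1)             # the 'X' series supplies the global step axis
logger.plot_step(-1)           # write the most recent buffered values
logger.writer.flush()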
Example #8
0
File: train.py Project: JCly-rikiu/dvector
def ddp_train(
    rank,
    world_size,
    tmp_file_path,
    model_dir,
    n_speakers,
    n_utterances,
    seg_len,
    save_every,
    valid_every,
    decay_every,
    batch_per_valid,
    n_workers,
    start_time,
    checkpoints_path,
    metadata,
    trainset,
    validset,
):
    print(f"[DDP] Running on rank {rank}.")
    dist.init_process_group(
        "nccl",
        rank=rank,
        world_size=world_size,
        store=dist.FileStore(tmp_file_path, world_size),
    )

    train_sampler = DistributedSampler(trainset)
    train_loader = MultiEpochsDataLoader(
        trainset,
        batch_size=n_speakers,
        sampler=train_sampler,
        num_workers=n_workers,
        collate_fn=pad_batch,
        drop_last=True,
    )
    train_iter = infinite_iterator(train_loader, train_sampler)

    # build network and training tools
    dvector = DVector(dim_input=metadata["n_mels"], seg_len=seg_len).to(rank)
    ddp_dvector = DDP(dvector, device_ids=[rank])

    criterion = GE2ELoss().to(rank)
    ddp_criterion = DDP(criterion, device_ids=[rank])
    optimizer = SGD(
        list(ddp_dvector.parameters()) + list(ddp_criterion.parameters()), lr=0.01
    )
    scheduler = StepLR(optimizer, step_size=decay_every, gamma=0.5)

    train_losses, valid_losses = [], []
    batch_ms, model_ms, loss_ms, backward_ms = [], [], [], []

    if rank == 0:
        writer = SummaryWriter(Path(model_dir) / "logs" / start_time)
        valid_loader = MultiEpochsDataLoader(
            validset,
            batch_size=n_speakers,
            num_workers=n_workers,
            collate_fn=pad_batch,
            drop_last=True,
        )
        valid_iter = infinite_iterator(valid_loader)
        pbar = tqdm(total=valid_every, ncols=0, desc="Train")
        cuda_timer = CUDATimer()
    else:
        valid_loader, valid_iter, writer = None, None, None
        pbar, cuda_timer = None, None

    # start training
    for step in count(start=1):

        if rank == 0:
            cuda_timer.record("batch")
        batch = next(train_iter).to(rank)

        if rank == 0:
            cuda_timer.record("model")
        embds = ddp_dvector(batch).view(n_speakers, n_utterances, -1)

        if rank == 0:
            cuda_timer.record("loss")
        loss = ddp_criterion(embds)

        if rank == 0:
            cuda_timer.record("backward")
        optimizer.zero_grad()
        loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(
            list(ddp_dvector.parameters()) + list(ddp_criterion.parameters()),
            max_norm=3,
            norm_type=2.0,
        )
        dvector.embedding.weight.grad *= 0.5
        dvector.embedding.bias.grad *= 0.5
        criterion.w.grad *= 0.01
        criterion.b.grad *= 0.01

        optimizer.step()
        scheduler.step()

        if rank == 0:
            cuda_timer.record()
            elapsed_times = cuda_timer.stop()

            train_losses.append(loss.item())
            batch_ms.append(elapsed_times["batch"])
            model_ms.append(elapsed_times["model"])
            loss_ms.append(elapsed_times["loss"])
            backward_ms.append(elapsed_times["backward"])

            pbar.update(1)
            pbar.set_postfix(loss=loss.item(), grad_norm=grad_norm.item())

            if step % valid_every == 0:
                pbar.close()

                for _ in range(batch_per_valid):
                    batch = next(valid_iter).to(rank)

                    with torch.no_grad():
                        embd = ddp_dvector(batch).view(n_speakers, n_utterances, -1)
                        loss = ddp_criterion(embd)
                        valid_losses.append(loss.item())

                avg_train_loss = sum(train_losses) / len(train_losses)
                avg_valid_loss = sum(valid_losses) / len(valid_losses)
                avg_batch_ms = sum(batch_ms) / len(batch_ms)
                avg_model_ms = sum(model_ms) / len(model_ms)
                avg_loss_ms = sum(loss_ms) / len(loss_ms)
                avg_backward_ms = sum(backward_ms) / len(backward_ms)
                print(f"[DDP] Valid: loss={avg_valid_loss:.1f}, ")
                print(
                    f"[DDP] Average elapsed time: "
                    f"batch {avg_batch_ms:.1f} ms, "
                    f"model {avg_model_ms:.1f} ms, "
                    f"loss {avg_loss_ms:.1f} ms, "
                    f"backward {avg_backward_ms:.1f} ms"
                )

                writer.add_scalar("Loss/train", avg_train_loss, step)
                writer.add_scalar("Loss/valid", avg_valid_loss, step)

                writer.add_scalar("Elapsed time/batch (ms)", avg_batch_ms, step)
                writer.add_scalar("Elapsed time/model (ms)", avg_model_ms, step)
                writer.add_scalar("Elapsed time/loss (ms)", avg_loss_ms, step)
                writer.add_scalar("Elapsed time/backward (ms)", avg_backward_ms, step)
                writer.flush()

                pbar = tqdm(
                    total=step + valid_every, ncols=0, initial=step, desc="Train",
                )
                train_losses, valid_losses = [], []

            if step % save_every == 0:
                ckpt_path = checkpoints_path / f"dvector-step{step}.pt"
                torch.save(ddp_dvector, str(ckpt_path))

    dist.destroy_process_group()
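infinite_iterator is used above but not shown. A plausible sketch, assuming it cycles over the loader indefinitely and, when given a DistributedSampler, calls set_epoch on each pass so shuffling differs between epochs (the signature is inferred from the two call sites above):

from itertools import count

def infinite_iterator(loader, sampler=None):
    # Yield batches forever, reshuffling per pass for distributed training.
    for epoch in count():
        if sampler is not None:
            sampler.set_epoch(epoch)
        yield from loader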
Example #9
0
class ClassificationModelTrainer:
    """
    Model trainer for classification task.
    """
    SchedulerType = Union[optim.lr_scheduler._LRScheduler,
                          optim.lr_scheduler.ReduceLROnPlateau]

    def __init__(self, model: nn.Module, optimizer: optim.Optimizer,
                 scheduler: SchedulerType, dataset: str, batch_size: int,
                 num_workers: int, data_path: Path, log_path: Path,
                 checkpoint_path: Path):
        self.logger = get_logger(name=__name__,
                                 save_dir=str(log_path / 'logs'))
        self.logger.info('Initializing Classification Model Trainer.')

        if dataset.upper() == 'CIFAR10':
            train_loader, eval_loader = get_cifar10_loaders(
                data_root=str(data_path),
                batch_size=batch_size,
                num_workers=num_workers,
                augment=True)
        else:
            raise NotImplementedError('Only CIFAR10 implemented.')

        self.model = model  # Assumes model has already been sent to device.
        self.optimizer = optimizer  # Assumes optimizer is associated with model.
        self.device = get_single_model_device(
            model)  # Finds device of model assuming it is on a single device.
        self.loss_func = nn.CrossEntropyLoss()
        self.writer = SummaryWriter(str(log_path))
        self.manager = CheckpointManager(model,
                                         optimizer,
                                         checkpoint_path,
                                         mode='max',
                                         save_best_only=True,
                                         max_to_keep=1)
        self.scheduler = scheduler  # No learning rate scheduling if scheduler = None.
        self.train_loader = train_loader
        self.eval_loader = eval_loader
        self.epoch = 0
        self.tic = 0
        self.tic_tic = time()

    def _train_epoch(self) -> float:
        self.tic = time()  # Starts counting the time.
        self.model.train()  # Change settings for batchnorm, dropout, etc.
        torch.set_grad_enabled(True)  # Enable gradient calculations. (Assigning to torch.autograd.enable_grad has no effect.)
        losses = list()
        correct = torch.tensor(
            0,
            device=self.device)  # Counter for number of correct predictions.

        for inputs, targets in self.train_loader:
            targets = targets.to(self.device)
            inputs = inputs.to(
                self.device, non_blocking=True
            )  # Asynchronous transfer to minimize data starvation.
            self.optimizer.zero_grad()
            outputs: Tensor = self.model(
                inputs
            )  # Type hinting 'outputs'. This does not affect the value in any way.
            loss = self.loss_func(outputs, targets)
            loss.backward()
            self.optimizer.step()
            losses.append(loss.detach())

            with torch.no_grad(
            ):  # Number of correct values. Maximum of outputs is the same as softmax maximum.
                correct += (targets == outputs.argmax(dim=1)).sum().detach()

        accuracy = correct.item() / len(self.train_loader.dataset) * 100
        self._write_epoch_metrics(accuracy=accuracy,
                                  losses=losses,
                                  is_train=True)
        return accuracy

    def _eval_epoch(self) -> float:
        self.tic = time()
        self.model.eval()
        torch.set_grad_enabled(False)  # Disable gradient calculations for faster evaluation.
        losses = list()
        correct = torch.tensor(
            0,
            device=self.device)  # Counter for number of correct predictions.

        for inputs, targets in self.eval_loader:
            targets = targets.to(self.device)
            outputs: Tensor = self.model(
                inputs.to(self.device)
            )  # Asynchronous transfer is impossible for evaluation.
            loss = self.loss_func(outputs, targets)
            losses.append(loss.detach())
            correct += (targets == outputs.argmax(dim=1)).sum().detach()

        accuracy = correct.item() / len(self.eval_loader.dataset) * 100
        self._write_epoch_metrics(accuracy=accuracy,
                                  losses=losses,
                                  is_train=False)
        return accuracy

    def _write_epoch_metrics(self, accuracy: float, losses: list,
                             is_train: bool):
        phase = 'Train' if is_train else 'Eval'
        # epoch_loss is not a true mean because of the possibly smaller size of the last mini-batch, but this will do.
        with torch.no_grad(
        ):  # Small speed-up by removing unnecessary gradient calculations.
            epoch_loss = torch.stack(losses).mean().item(
            )  # Minimizing device to host data transfer this way.
        self.writer.add_scalar(tag=f'{phase}/epoch_loss',
                               scalar_value=epoch_loss,
                               global_step=self.epoch)
        self.writer.add_scalar(tag=f'{phase}/epoch_accuracy',
                               scalar_value=accuracy,
                               global_step=self.epoch)
        toc = int(time() - self.tic)
        self.logger.info(
            f'Epoch {self.epoch:02d} {phase} loss: {epoch_loss:.3f}, accuracy {accuracy:.1f}%. Time: {toc}s'
        )

    def _write_learning_rates(self):
        for idx, group in enumerate(self.optimizer.param_groups,
                                    start=1):  # Recording learning rate.
            self.writer.add_scalar(tag=f'Learning Rate {idx}',
                                   scalar_value=group['lr'],
                                   global_step=self.epoch)

    def _scheduler_step(self, metrics):
        if self.scheduler is not None:  # No learning rate scheduling if scheduler is None.
            if isinstance(self.scheduler,
                          optim.lr_scheduler.ReduceLROnPlateau):
                self.scheduler.step(metrics=metrics)
            else:
                self.scheduler.step()

    def _train_model(self, num_epochs: int) -> float:
        best_acc = 0.
        for epoch in range(1, num_epochs + 1):  # 1 based indexing.
            self.epoch = epoch  # Update epoch.
            train_epoch_acc = self._train_epoch()
            eval_epoch_acc = self._eval_epoch()
            best_acc = max(best_acc,
                           eval_epoch_acc)  # Update best performance if improved.
            self._write_learning_rates()
            self.manager.save(metric=eval_epoch_acc,
                              epoch=self.epoch)  # Save checkpoint.
            over_fit = train_epoch_acc - eval_epoch_acc  # Positive values indicate over-fitting.
            self.writer.add_scalar(tag='Over-fitting',
                                   scalar_value=over_fit,
                                   global_step=self.epoch)
            self.logger.info(
                f'Epoch {self.epoch:02d} Over-fitting: {over_fit:.3f}.')
            self._scheduler_step(metrics=eval_epoch_acc
                                 )  # Scheduler step for all scheduler types.
        return best_acc

    def train_model(self, num_epochs: int) -> float:
        try:  # Including safeguards against keyboard interruption.
            best_acc = self._train_model(num_epochs=num_epochs)
            self.writer.flush()
            toc_toc = int(time() - self.tic_tic)
            self.logger.info(
                f'Finished Training. Best performance: {best_acc:.2f}%. '
                f'Time: {toc_toc // 60}min {toc_toc % 60}s.')
            return best_acc
        except KeyboardInterrupt:
            self.writer.flush()  # Write to tensorboard before terminating.
            self.logger.info('Training interrupted before completion.')
            return -1
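A usage sketch for the trainer above, assuming a torchvision ResNet already moved to the target device; paths and hyperparameters are illustrative only:

from pathlib import Path
import torch
from torch import optim
from torchvision.models import resnet18

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = resnet18(num_classes=10).to(device)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max')

trainer = ClassificationModelTrainer(
    model=model, optimizer=optimizer, scheduler=scheduler,
    dataset='CIFAR10', batch_size=128, num_workers=4,
    data_path=Path('./data'), log_path=Path('./logs'),
    checkpoint_path=Path('./checkpoints'))
best_acc = trainer.train_model(num_epochs=10)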
def train(args):  # noqa: C901
    train_start_time = time.perf_counter()
    assert num_gpus > 0, "Found 0 cuda devices, CPU training is not supported."
    total_batch_size = args.batch_size * num_gpus
    assert total_batch_size % args.num_workers == 0, (
        f"batch_size * num_gpus ({total_batch_size}) must be divisible by num_workers "
        f"({args.num_workers}).")

    with open(os.path.join(args.model_dir, "hyperparameters.yml"), "w") as f:
        yaml.dump(vars(args), f)

    # initialization of tensorboard summary writers
    date_time = datetime.datetime.now().strftime(_STRFTIME_FORMAT)
    writer = SummaryWriter(
        os.path.join(args.tensorboard_dir, f"logs/{date_time}"))
    train_writer = SummaryWriter(
        os.path.join(args.tensorboard_dir, f"logs/{date_time}/train"))
    val_writer = SummaryWriter(
        os.path.join(args.tensorboard_dir, f"logs/{date_time}/val"))

    # get weights path, selecting the best weights if weights == "best"
    weights_path = _get_weights_path(args.weights_dir, args.weights)
    # create the correct data structure splitting input data in train and val sets
    prepare_annotations(args.data_dir, args.classes, ["train", "val"])

    torch.cuda.manual_seed(args.seed)

    train_loader = _get_train_data_loader(args)
    val_loader = _get_val_data_loader(args)

    model = EfficientDetBackbone(
        num_classes=len(args.classes),
        compound_coef=args.compound_coef,
        ratios=args.anchors_ratios,
        scales=args.anchors_scales,
    )
    _init_weights(model, weights_path)

    if args.freeze_backbone:
        logger.info("Freezing backbone")
        model.apply(_freeze_submodule_if_backbone)

    # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
    # use synchronized batch normalization when the batch size per gpu is too small
    if args.batch_size < 4:
        model.apply(replace_w_sync_bn)
        use_sync_bn = True
        logger.info("Using Synchronized Batch Normalization")
    else:
        use_sync_bn = False

    # warp the model with loss function, to reduce the memory usage on gpu0 and speedup
    model = ModelWithLoss(model)
    model = model.cuda()

    if num_gpus > 1:
        # TODO: see if there are better way to parallelize
        model = CustomDataParallel(model, num_gpus)
        if use_sync_bn:
            patch_replication_callback(model)

    steps_per_epoch = len(train_loader)
    last_step, es_baseline = _get_last_step_and_es_baseline(
        weights_path, args.resume_training)
    es = EarlyStopping(args,
                       baseline=es_baseline,
                       best_epoch=last_step // steps_per_epoch - 1)
    optimizer = _get_optimizer(model, args)
    scheduler = _get_scheduler(optimizer, steps_per_epoch, args)
    model.train()
    logger.info(f"Starting training from step {last_step}")

    for epoch in range(args.epochs):
        if epoch in args.milestones:
            for group in optimizer.param_groups:
                if args.scheduler == "onecyclelr":
                    group["max_lr"] *= args.multisteplr_gamma
                    group["min_lr"] *= args.multisteplr_gamma
                else:
                    group["lr"] *= args.multisteplr_gamma

        last_epoch = last_step // steps_per_epoch
        if epoch < last_epoch:
            if scheduler is not None:
                for _ in range(steps_per_epoch):
                    scheduler.step()

            continue

        train_loader_iter = iter(train_loader)
        for batch_idx in range(steps_per_epoch):
            iter_start_time = time.perf_counter()
            data_start_time = time.perf_counter()
            data = next(train_loader_iter)
            data_time = time.perf_counter() - data_start_time
            if batch_idx < (last_step - last_epoch * steps_per_epoch):
                if scheduler is not None:
                    scheduler.step()

                continue

            imgs = data["img"]
            annotations = data["annot"]
            # if only one gpu, just send it to cuda:0 elif multiple gpus,
            # send it to multiple gpus in CustomDataParallel
            if num_gpus == 1:
                imgs = imgs.cuda()
                annotations = annotations.cuda()

            optimizer.zero_grad()
            loss_cls, loss_box_reg = model(imgs, annotations)
            loss_cls = loss_cls.mean()
            loss_box_reg = loss_box_reg.mean()
            total_loss = loss_cls + loss_box_reg
            if total_loss == 0 or not torch.isfinite(total_loss):
                continue

            total_loss.backward()
            if args.clip_gradients_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.clip_gradients_norm)

            lr = optimizer.param_groups[0]["lr"]
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            date_time = datetime.datetime.now().strftime("%m/%d %H:%M:%S")
            eta = datetime.timedelta(seconds=round(time.perf_counter() -
                                                   train_start_time))
            max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
            iter_time = time.perf_counter() - iter_start_time
            logger.info(f"[{date_time} train]:  "
                        f"eta: {eta}  "
                        f"epoch: {epoch + 1}/{args.epochs}  "
                        f"batch: {batch_idx + 1}/{steps_per_epoch}  "
                        f"loss_cls: {loss_cls.item():.4f}  "
                        f"loss_box_reg: {loss_box_reg.item():.4f}  "
                        f"total_loss: {total_loss.item():.4f}  "
                        f"time: {iter_time:.4f}  "
                        f"data_time: {data_time:.4f}  "
                        f"lr: {lr:.6f}  "
                        f"max_mem: {max_mem_mb:.0f}M")
            writer.add_scalar("hp/lr", lr, last_step)
            if args.cycle_momentum:
                momentum = optimizer.param_groups[0]["momentum"]
                writer.add_scalar("hp/momentum", momentum, last_step)

            writer.add_scalar("usage/max_mem", max_mem_mb, last_step)
            writer.flush()

            train_writer.add_scalar("loss/total_loss", total_loss.item(),
                                    last_step)
            train_writer.add_scalar("loss/loss_cls", loss_cls.item(),
                                    last_step)
            train_writer.add_scalar("loss/loss_box_reg", loss_box_reg.item(),
                                    last_step)
            train_writer.add_scalar("time/time", iter_time, last_step)
            train_writer.add_scalar("time/data_time", data_time, last_step)
            train_writer.flush()

            last_step += 1

        # See https://github.com/pytorch/pytorch/issues/1355#issuecomment-658660582.
        del train_loader_iter

        if epoch % args.val_interval == 0 or epoch + 1 == args.epochs:
            total_val_loss = validate(model, val_loader, last_step - 1, epoch,
                                      args.epochs, val_writer)
            _save_model(
                model,
                args.checkpoints_dir,
                args.compound_coef,
                epoch,
                last_step,
                total_val_loss,
            )
            if es.step(epoch, total_val_loss):
                break

            model.train()

    model_params = {
        "classes": args.classes,
        "compound_coef": args.compound_coef,
        "anchors_scales": args.anchors_scales,
        "anchors_ratios": args.anchors_ratios,
    }
    with open(os.path.join(args.model_dir, "model_params.yml"), "w") as f:
        yaml.dump(model_params, f)

    writer.close()
    train_writer.close()
    val_writer.close()

    best_weights_path = _get_best_weights_path(args.checkpoints_dir)
    shutil.copyfile(best_weights_path, os.path.join(args.model_dir,
                                                    "model.pth"))

    evaluate(
        args.model_dir,
        args.data_dir,
        eval_set="val",
        threshold=args.eval_threshold,
        nms_threshold=args.eval_nms_threshold,
        max_imgs=args.eval_max_imgs,
        use_float16=args.use_float16,
        device=args.eval_device,
    )
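The three writers above follow a standard TensorBoard layout: train and val series live in sibling subdirectories of one run directory, so scalars logged under the same tag overlay in a single chart. A standalone sketch of the pattern:

from torch.utils.tensorboard import SummaryWriter

train_writer = SummaryWriter('logs/run1/train')
val_writer = SummaryWriter('logs/run1/val')
for step in range(100):
    # Same tag, different writers: TensorBoard draws both curves together.
    train_writer.add_scalar('loss/total_loss', 1.0 / (step + 1), step)
    if step % 10 == 0:
        val_writer.add_scalar('loss/total_loss', 1.2 / (step + 1), step)
train_writer.close()
val_writer.close()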
Example #11
0
class Learner(ABC):
    """Abstract training and prediction routines for a model.

    This can be subclassed to handle different computer vision tasks. If a model_path
    is passed to the constructor, the Learner can only be used for prediction (ie. only
    predict and numpy_predict should be called). Otherwise, the Learner can be used for
    training using the main() method.

    Note that the validation set is used to validate at the end of each epoch, and the
    test set is only used at the end of training. It's possible to set these to the same
    dataset if desired.
    """
    def __init__(self,
                 cfg: LearnerConfig,
                 tmp_dir: str,
                 model_path: Optional[str] = None,
                 model_def_path: Optional[str] = None,
                 loss_def_path: Optional[str] = None,
                 training: bool = True):
        """Constructor.

        Args:
            cfg (LearnerConfig): Configuration.
            tmp_dir (str): Root of temp dirs.
            model_path (str, optional): A local path to model weights.
                Defaults to None.
            model_def_path (str, optional): A local path to a directory with a
                hubconf.py. If provided, the model definition is imported from
                here. Defaults to None.
            loss_def_path (str, optional): A local path to a directory with a
                hubconf.py. If provided, the loss function definition is
                imported from here. Defaults to None.
            training (bool, optional): Whether the model is to be used for
                training or prediction. If False, the model is put in eval mode
                and the loss function, optimizer, etc. are not initialized.
                Defaults to True.
        """
        log_system_details()
        self.cfg = cfg
        self.tmp_dir = tmp_dir

        self.preview_batch_limit = self.cfg.data.preview_batch_limit

        # TODO make cache dirs configurable
        torch_cache_dir = '/opt/data/torch-cache'
        os.environ['TORCH_HOME'] = torch_cache_dir
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.data_cache_dir = '/opt/data/data-cache'
        make_dir(self.data_cache_dir)

        if FileSystem.get_file_system(cfg.output_uri) == LocalFileSystem:
            self.output_dir = cfg.output_uri
            make_dir(self.output_dir)
        else:
            self.output_dir = get_local_path(cfg.output_uri, tmp_dir)
            make_dir(self.output_dir, force_empty=True)

            if training and not cfg.overfit_mode:
                self.sync_from_cloud()

        self.modules_dir = join(self.output_dir, MODULES_DIRNAME)

        self.setup_model(model_def_path=model_def_path)

        if model_path is not None:
            if isfile(model_path):
                log.info(f'Loading model weights from: {model_path}')
                self.model.load_state_dict(
                    torch.load(model_path, map_location=self.device))
            else:
                raise Exception(
                    'Model could not be found at {}'.format(model_path))
        if training:
            self.setup_training(loss_def_path=loss_def_path)
        else:
            self.model.eval()

    def main(self):
        """Main training sequence.

        This plots the dataset, runs a training and validation loop (which will resume if
        interrupted), logs stats, plots predictions, and syncs results to the cloud.
        """
        self.run_tensorboard()
        cfg = self.cfg
        self.log_data_stats()
        if not cfg.predict_mode:
            self.plot_dataloaders(self.preview_batch_limit)
            if cfg.overfit_mode:
                self.overfit()
            else:
                self.train()
                if cfg.save_model_bundle:
                    self.save_model_bundle()

        self.load_checkpoint()
        if cfg.eval_train:
            self.eval_model('train')
        self.eval_model('test')
        self.sync_to_cloud()
        self.stop_tensorboard()

    def setup_training(self, loss_def_path=None):
        log.info(self.cfg)
        log.info(f'Using device: {self.device}')

        # ds = dataset, dl = dataloader
        self.train_ds = None
        self.train_dl = None
        self.valid_ds = None
        self.valid_dl = None
        self.test_ds = None
        self.test_dl = None

        self.config_path = join(self.output_dir, 'learner-config.json')
        str_to_file(self.cfg.json(), self.config_path)

        self.log_path = join(self.output_dir, 'log.csv')
        self.train_state_path = join(self.output_dir, 'train-state.json')
        model_bundle_fname = basename(self.cfg.get_model_bundle_uri())
        self.model_bundle_path = join(self.output_dir, model_bundle_fname)
        self.metric_names = self.build_metric_names()

        self.last_model_path = join(self.output_dir, 'last-model.pth')
        self.load_checkpoint()

        self.setup_loss(loss_def_path=loss_def_path)
        self.opt = self.build_optimizer()
        self.setup_data()
        self.start_epoch = self.get_start_epoch()
        self.steps_per_epoch = len(self.train_ds) // self.cfg.solver.batch_sz
        self.step_scheduler = self.build_step_scheduler()
        self.epoch_scheduler = self.build_epoch_scheduler()
        self.setup_tensorboard()

    def sync_to_cloud(self):
        """Sync any output to the cloud at output_uri."""
        sync_to_dir(self.output_dir, self.cfg.output_uri)

    def sync_from_cloud(self):
        """Sync any previous output in the cloud to output_dir."""
        sync_from_dir(self.cfg.output_uri, self.output_dir)

    def setup_tensorboard(self):
        """Setup for logging stats to TB."""
        self.tb_writer = None
        if self.cfg.log_tensorboard:
            self.tb_log_dir = join(self.output_dir, 'tb-logs')
            make_dir(self.tb_log_dir)
            self.tb_writer = SummaryWriter(log_dir=self.tb_log_dir)

    def run_tensorboard(self):
        """Run TB server serving logged stats."""
        if self.cfg.run_tensorboard:
            log.info('Starting tensorboard process')
            self.tb_process = Popen([
                'tensorboard', '--bind_all',
                '--logdir={}'.format(self.tb_log_dir)
            ])
            terminate_at_exit(self.tb_process)

    def stop_tensorboard(self):
        """Stop TB logging and server if it's running."""
        if self.cfg.log_tensorboard:
            self.tb_writer.close()
            if self.cfg.run_tensorboard:
                self.tb_process.terminate()

    def setup_model(self, model_def_path: Optional[str] = None) -> None:
        """Setup self.model.

        Args:
            model_def_path (str, optional): Model definition path. Will be
            available when loading from a bundle. Defaults to None.
        """
        ext_cfg = self.cfg.model.external_def
        if ext_cfg is not None:
            hubconf_dir = self._get_external_module_dir(
                ext_cfg, model_def_path)
            self.model = self.load_external_module(ext_cfg=ext_cfg,
                                                   hubconf_dir=hubconf_dir)
        else:
            self.model = self.build_model()
        self.model.to(self.device)
        self.load_init_weights()

    @abstractmethod
    def build_model(self) -> nn.Module:
        """Build a PyTorch model."""
        pass

    def setup_loss(self, loss_def_path: Optional[str] = None) -> None:
        """Setup self.loss.

        Args:
            loss_def_path (str, optional): Loss definition path. Will be
            available when loading from a bundle. Defaults to None.
        """
        ext_cfg = self.cfg.solver.external_loss_def
        if ext_cfg is not None:
            hubconf_dir = self._get_external_module_dir(ext_cfg, loss_def_path)
            self.loss = self.load_external_module(ext_cfg=ext_cfg,
                                                  hubconf_dir=hubconf_dir)
        else:
            self.loss = self.build_loss()

        if self.loss is not None and isinstance(self.loss, nn.Module):
            self.loss.to(self.device)

    def build_loss(self) -> nn.Module:
        """Build a loss Callable."""
        pass

    def _get_external_module_dir(
            self,
            ext_cfg: ExternalModuleConfig,
            existing_def_path: Optional[str] = None) -> Optional[str]:
        """Determine correct dir, taking cfg options and existing_def_path into
        account.

        Args:
            ext_cfg (ExternalModuleConfig): Config describing the module.
            existing_def_path (str, optional): Loss definition path.
            Will be available when loading from a bundle. Defaults to None.

        Returns:
            Optional[str]: The directory to import the module definition
                from: dir_from_cfg if it exists and force_reload is not set,
                otherwise existing_def_path (possibly None).
        """
        dir_from_cfg = get_hubconf_dir_from_cfg(ext_cfg,
                                                parent=self.modules_dir)
        if isdir(dir_from_cfg) and not ext_cfg.force_reload:
            return dir_from_cfg
        return existing_def_path

    def load_external_module(self,
                             ext_cfg: ExternalModuleConfig,
                             save_dir: Optional[str] = None,
                             hubconf_dir: Optional[str] = None,
                             tmp_dir: Optional[str] = None) -> Any:
        """Load an external module via torch.hub.

        Note: Loading a PyTorch module is the typical use case, but there are
        no type restrictions on the object loaded through torch.hub.

        Args:
            ext_cfg (ExternalModuleConfig): Config describing the module.
            save_dir (str, optional): The module def will be saved here.
                Defaults to self.modules_dir.
            hubconf_dir (str, optional): Path to existing definition.
                If provided, the definition will not be fetched from the source
                specified by ext_cfg. Defaults to None.
            tmp_dir (str, optional): Temporary directory to use for downloads
                etc. Defaults to self.tmp_dir.

        Returns:
            nn.Module: The module loaded via torch.hub.
        """
        if hubconf_dir is not None:
            log.info(f'Using existing module definition at: {hubconf_dir}')
            module = torch_hub_load_local(hubconf_dir=hubconf_dir,
                                          entrypoint=ext_cfg.entrypoint,
                                          *ext_cfg.entrypoint_args,
                                          **ext_cfg.entrypoint_kwargs)
            return module

        save_dir = self.modules_dir if save_dir is None else save_dir
        tmp_dir = self.tmp_dir if tmp_dir is None else tmp_dir

        hubconf_dir = get_hubconf_dir_from_cfg(ext_cfg, parent=save_dir)
        if ext_cfg.github_repo is not None:
            log.info(f'Fetching module definition from: {ext_cfg.github_repo}')
            module = torch_hub_load_github(repo=ext_cfg.github_repo,
                                           hubconf_dir=hubconf_dir,
                                           tmp_dir=save_dir,
                                           entrypoint=ext_cfg.entrypoint,
                                           *ext_cfg.entrypoint_args,
                                           **ext_cfg.entrypoint_kwargs)
        else:
            log.info(f'Fetching module definition from: {ext_cfg.uri}')
            module = torch_hub_load_uri(uri=ext_cfg.uri,
                                        hubconf_dir=hubconf_dir,
                                        tmp_dir=tmp_dir,
                                        entrypoint=ext_cfg.entrypoint,
                                        *ext_cfg.entrypoint_args,
                                        **ext_cfg.entrypoint_kwargs)
        return module

    def unzip_data(self, uri: Union[str, List[str]]) -> List[str]:
        """Unzip dataset zip files.

        Args:
            uri: a list of URIs of zip files or the URI of a directory containing
                zip files

        Returns:
            paths to directories that each contain contents of one zip file
        """
        data_dirs = []

        if isinstance(uri, list):
            zip_uris = uri
        else:
            zip_uris = ([uri] if uri.endswith('.zip') else list_paths(
                uri, 'zip'))

        for zip_ind, zip_uri in enumerate(zip_uris):
            zip_path = get_local_path(zip_uri, self.data_cache_dir)
            if not isfile(zip_path):
                zip_path = download_if_needed(zip_uri, self.data_cache_dir)
            with zipfile.ZipFile(zip_path, 'r') as zipf:
                data_dir = join(self.tmp_dir, 'data', str(uuid.uuid4()),
                                str(zip_ind))
                data_dirs.append(data_dir)
                zipf.extractall(data_dir)

        return data_dirs

    def get_bbox_params(self) -> Optional[A.BboxParams]:
        """Returns BboxParams used by albumentations for data augmentation."""
        return None

    def get_data_transforms(self) -> Tuple[A.BasicTransform, A.BasicTransform]:
        """Get albumentations transform objects for data augmentation.

        Returns:
           1st tuple arg: a transform that doesn't do any data augmentation
           2nd tuple arg: a transform with data augmentation
        """
        cfg = self.cfg
        bbox_params = self.get_bbox_params()
        base_tfs = [A.Resize(cfg.data.img_sz, cfg.data.img_sz)]
        if cfg.data.base_transform is not None:
            base_tfs.append(A.from_dict(cfg.data.base_transform))
        base_transform = A.Compose(base_tfs, bbox_params=bbox_params)

        if cfg.data.aug_transform is not None:
            aug_transform = A.from_dict(cfg.data.aug_transform)
            aug_transform = A.Compose([aug_transform, base_transform],
                                      bbox_params=bbox_params)
            return base_transform, aug_transform

        augmentors_dict = {
            'Blur': A.Blur(),
            'RandomRotate90': A.RandomRotate90(),
            'HorizontalFlip': A.HorizontalFlip(),
            'VerticalFlip': A.VerticalFlip(),
            'GaussianBlur': A.GaussianBlur(),
            'GaussNoise': A.GaussNoise(),
            'RGBShift': A.RGBShift(),
            'ToGray': A.ToGray()
        }
        aug_transforms = []
        for augmentor in cfg.data.augmentors:
            try:
                aug_transforms.append(augmentors_dict[augmentor])
            except KeyError as e:
                log.warning(
                    '{0} is an unknown augmentor. Continuing without {0}. \
                    Known augmentors are: {1}'.format(
                        e, list(augmentors_dict.keys())))
        aug_transforms.append(base_transform)
        aug_transform = A.Compose(aug_transforms, bbox_params=bbox_params)

        return base_transform, aug_transform
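    # Illustration (not part of the original class): with
    # cfg.data.augmentors = ['HorizontalFlip', 'GaussNoise'], the fallback
    # path above builds roughly
    #     A.Compose([A.HorizontalFlip(), A.GaussNoise(), base_transform],
    #               bbox_params=bbox_params)
    # while unknown augmentor names are skipped with a warning instead of
    # raising.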

    def get_collate_fn(self) -> Optional[callable]:
        """Returns a custom collate_fn to use in DataLoader.

        None is returned if default collate_fn should be used.

        See https://pytorch.org/docs/stable/data.html#working-with-collate-fn
        """
        return None

    def _get_datasets(
        self,
        uri: Optional[Union[str, List[str]]] = None
    ) -> Tuple[Dataset, Dataset, Dataset]:
        """Gets Datasets for a single group of chips.

        Returns:
            train, validation, and test DataSets."""
        if isinstance(self.cfg.data, ImageDataConfig):
            return self._get_image_datasets(uri)

        if isinstance(self.cfg.data, GeoDataConfig):
            return self._get_geo_datasets()

        raise TypeError('Learner.cfg.data')

    def _get_image_datasets(
            self, uri: Union[str,
                             List[str]]) -> Tuple[Dataset, Dataset, Dataset]:
        """Gets image training, validation, and test datasets from a single
        zip file.

        Args:
            uri (Union[str, List[str]]): Uri of a zip file containing the
                images.

        Returns:
            Tuple[Dataset, Dataset, Dataset]: Training, validation, and test
                DataSets.
        """
        cfg = self.cfg
        data_dirs = self.unzip_data(uri)

        train_dirs = [join(d, 'train') for d in data_dirs if isdir(d)]
        val_dirs = [join(d, 'valid') for d in data_dirs if isdir(d)]

        train_dirs = [d for d in train_dirs if isdir(d)]
        val_dirs = [d for d in val_dirs if isdir(d)]

        base_transform, aug_transform = self.get_data_transforms()
        train_tf = aug_transform if not cfg.overfit_mode else base_transform
        val_tf, test_tf = base_transform, base_transform

        train_ds, val_ds, test_ds = cfg.data.make_datasets(
            train_dirs=train_dirs,
            val_dirs=val_dirs,
            test_dirs=val_dirs,
            train_tf=train_tf,
            val_tf=val_tf,
            test_tf=test_tf)
        return train_ds, val_ds, test_ds

    def _get_geo_datasets(self) -> Tuple[Dataset, Dataset, Dataset]:
        """Gets geo datasets.

        Returns:
            train, validation, and test DataSets."""
        cfg = self.cfg
        base_transform, aug_transform = self.get_data_transforms()
        train_tf = aug_transform if not cfg.overfit_mode else base_transform
        val_tf, test_tf = base_transform, base_transform

        train_ds, val_ds, test_ds = cfg.data.make_datasets(
            tmp_dir=self.tmp_dir,
            train_tf=train_tf,
            val_tf=val_tf,
            test_tf=test_tf)
        return train_ds, val_ds, test_ds

    def get_datasets(self) -> Tuple[Dataset, Dataset, Dataset]:
        """Returns train, validation, and test DataSets."""
        cfg = self.cfg
        if isinstance(cfg.data, GeoDataConfig):
            return self._get_datasets()
        if cfg.data.group_uris is None:
            return self._get_datasets(cfg.data.uri)

        if cfg.data.uri is not None:
            log.warn('Both DataConfig.uri and DataConfig.group_uris '
                     'specified. Only DataConfig.group_uris will be used.')
        train_ds_lst, valid_ds_lst, test_ds_lst = [], [], []

        group_sizes = None
        if cfg.data.group_train_sz is not None:
            group_sizes = cfg.data.group_train_sz
        elif cfg.data.group_train_sz_rel is not None:
            group_sizes = cfg.data.group_train_sz_rel
        if not sequence_like(group_sizes):
            group_sizes = [group_sizes] * len(cfg.data.group_uris)

        for uri, sz in zip(cfg.data.group_uris, group_sizes):
            train_ds, valid_ds, test_ds = self._get_datasets(uri)
            if sz is not None:
                if isinstance(sz, float):
                    sz = int(len(train_ds) * sz)
                train_inds = list(range(len(train_ds)))
                random.seed(1234)
                random.shuffle(train_inds)
                train_inds = train_inds[:sz]
                train_ds = Subset(train_ds, train_inds)
            train_ds_lst.append(train_ds)
            valid_ds_lst.append(valid_ds)
            test_ds_lst.append(test_ds)

        train_ds, valid_ds, test_ds = (ConcatDataset(train_ds_lst),
                                       ConcatDataset(valid_ds_lst),
                                       ConcatDataset(test_ds_lst))
        return train_ds, valid_ds, test_ds

    def get_train_sampler(self, train_ds: Dataset) -> Optional[Sampler]:
        """Return a sampler to use for the training dataloader or None to not use any."""
        return None

    def setup_data(self):
        """Set the the DataSet and DataLoaders for train, validation, and test sets."""
        cfg = self.cfg
        batch_sz = self.cfg.solver.batch_sz
        num_workers = self.cfg.data.num_workers

        train_ds, valid_ds, test_ds = self.get_datasets()
        if len(train_ds) < batch_sz:
            raise ConfigError(
                'Training dataset has fewer elements than batch size.')
        if len(valid_ds) < batch_sz:
            raise ConfigError(
                'Validation dataset has fewer elements than batch size.')
        if len(test_ds) < batch_sz:
            raise ConfigError(
                'Test dataset has fewer elements than batch size.')

        if cfg.overfit_mode:
            train_ds = Subset(train_ds, range(batch_sz))
            valid_ds = train_ds
            test_ds = train_ds
        elif cfg.test_mode:
            train_ds = Subset(train_ds, range(batch_sz))
            valid_ds = Subset(valid_ds, range(batch_sz))
            test_ds = Subset(test_ds, range(batch_sz))

        if cfg.data.train_sz is not None or cfg.data.train_sz_rel is not None:
            train_inds = list(range(len(train_ds)))
            random.seed(1234)
            random.shuffle(train_inds)
            train_sz = (cfg.data.train_sz if cfg.data.train_sz is not None else
                        int(round(len(train_ds) * cfg.data.train_sz_rel)))
            train_inds = train_inds[0:train_sz]
            train_ds = Subset(train_ds, train_inds)

        train_sampler = self.get_train_sampler(train_ds)
        train_shuffle = train_sampler is None

        collate_fn = self.get_collate_fn()
        train_dl = DataLoader(train_ds,
                              shuffle=train_shuffle,
                              batch_size=batch_sz,
                              drop_last=True,
                              num_workers=num_workers,
                              pin_memory=True,
                              collate_fn=collate_fn,
                              sampler=train_sampler)
        valid_dl = DataLoader(valid_ds,
                              shuffle=True,
                              batch_size=batch_sz,
                              num_workers=num_workers,
                              pin_memory=True,
                              collate_fn=collate_fn)
        test_dl = DataLoader(test_ds,
                             shuffle=True,
                             batch_size=batch_sz,
                             num_workers=num_workers,
                             pin_memory=True,
                             collate_fn=collate_fn)

        self.train_ds, self.valid_ds, self.test_ds = (train_ds, valid_ds,
                                                      test_ds)
        self.train_dl, self.valid_dl, self.test_dl = (train_dl, valid_dl,
                                                      test_dl)

    def log_data_stats(self):
        """Log stats about each DataSet."""
        if self.train_ds:
            log.info('train_ds: {} items'.format(len(self.train_ds)))
        if self.valid_ds:
            log.info('valid_ds: {} items'.format(len(self.valid_ds)))
        if self.test_ds:
            log.info('test_ds: {} items'.format(len(self.test_ds)))

    def build_optimizer(self) -> optim.Optimizer:
        """Returns optimizer."""
        return optim.Adam(self.model.parameters(), lr=self.cfg.solver.lr)

    def build_step_scheduler(self) -> _LRScheduler:
        """Returns an LR scheduler that changes the LR each step.

        This is used to implement the "one cycle" schedule popularized by
        fastai.
        """
        scheduler = None
        cfg = self.cfg
        if cfg.solver.one_cycle and cfg.solver.num_epochs > 1:
            total_steps = cfg.solver.num_epochs * self.steps_per_epoch
            step_size_up = (cfg.solver.num_epochs // 2) * self.steps_per_epoch
            step_size_down = total_steps - step_size_up
            scheduler = CyclicLR(self.opt,
                                 base_lr=cfg.solver.lr / 10,
                                 max_lr=cfg.solver.lr,
                                 step_size_up=step_size_up,
                                 step_size_down=step_size_down,
                                 cycle_momentum=False)
            for _ in range(self.start_epoch * self.steps_per_epoch):
                scheduler.step()
        return scheduler

    def build_epoch_scheduler(self) -> _LRScheduler:
        """Returns an LR scheduler tha changes the LR each epoch.

        This is used to divide the LR by 10 at certain epochs.
        """
        scheduler = None
        if self.cfg.solver.multi_stage:
            scheduler = MultiStepLR(self.opt,
                                    milestones=self.cfg.solver.multi_stage,
                                    gamma=0.1)
            for _ in range(self.start_epoch):
                scheduler.step()
        return scheduler

    def build_metric_names(self) -> List[str]:
        """Returns names of metrics used to validate model at each epoch."""
        metric_names = [
            'epoch', 'train_time', 'valid_time', 'train_loss', 'val_loss',
            'avg_f1', 'avg_precision', 'avg_recall'
        ]

        for label in self.cfg.data.class_names:
            metric_names.extend([
                '{}_f1'.format(label), '{}_precision'.format(label),
                '{}_recall'.format(label)
            ])
        return metric_names

    @abstractmethod
    def train_step(self, batch: Any, batch_ind: int) -> MetricDict:
        """Compute loss for a single training batch.

        Args:
            batch: batch data needed to compute loss
            batch_ind: index of batch within epoch

        Returns:
            dict with 'train_loss' as key and possibly other losses
        """
        pass

    @abstractmethod
    def validate_step(self, batch: Any, batch_ind: int) -> MetricDict:
        """Compute metrics on validation batch.

        Args:
            batch: batch data needed to compute validation metrics
            batch_ind: index of batch within epoch

        Returns:
            dict with metric names mapped to metric values
        """
        pass

    def train_end(self, outputs: List[MetricDict],
                  num_samples: int) -> MetricDict:
        """Aggregate the ouput of train_step at the end of the epoch.

        Args:
            outputs: a list of outputs of train_step
            num_samples: total number of training samples processed in epoch
        """
        metrics = {}
        for k in outputs[0].keys():
            metrics[k] = torch.stack([o[k] for o in outputs
                                      ]).sum().item() / num_samples
        return metrics

    def validate_end(self, outputs: List[MetricDict],
                     num_samples: int) -> MetricDict:
        """Aggregate the ouput of validate_step at the end of the epoch.

        Args:
            outputs: a list of outputs of validate_step
            num_samples: total number of validation samples processed in epoch
        """
        metrics = {}
        for k in outputs[0].keys():
            metrics[k] = torch.stack([o[k] for o in outputs
                                      ]).sum().item() / num_samples
        return metrics

    def post_forward(self, x: Any) -> Any:
        """Post process output of call to model().

        Useful for when predictions are inside a structure returned by model().
        """
        return x

    def prob_to_pred(self, x: Tensor) -> Tensor:
        """Convert a Tensor with prediction probabilities to class ids.

        The class ids should be the classes with the maximum probability.
        """
        raise NotImplementedError()

    def to_batch(self, x: Tensor) -> Tensor:
        """Ensure that image array has batch dimension.

        Args:
            x: assumed to be either image or batch of images

        Returns:
            x with extra batch dimension of length 1 if needed
        """
        if x.ndim == 3:
            x = x[None, ...]
        return x

    def normalize_input(self, x: np.ndarray) -> np.ndarray:
        """If x.dtype is a subtype of np.unsignedinteger, normalize it to
        [0, 1] using the max possible value of that dtype. Otherwise, assume
        it is in [0, 1] already and do nothing.

        Args:
            x (np.ndarray): an image or batch of images
        Returns:
            the same array scaled to [0, 1].
        """
        if np.issubdtype(x.dtype, np.unsignedinteger):
            max_val = np.iinfo(x.dtype).max
            x = x.astype(float) / max_val
        return x
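    # Worked example (illustrative): for a uint8 input the max value is 255,
    # so np.array([0, 128, 255], dtype=np.uint8) maps to [0.0, 0.502, 1.0];
    # float inputs pass through unchanged.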

    def predict(self, x: Tensor, raw_out: bool = False) -> Any:
        """Make prediction for an image or batch of images.

        Args:
            x (Tensor): Image or batch of images as a float Tensor with pixel
                values normalized to [0, 1].
            raw_out (bool): if True, return prediction probabilities

        Returns:
            the predictions, in probability form if raw_out is True, in class_id form
                otherwise
        """
        x = self.to_batch(x).float()
        x = self.to_device(x, self.device)
        with torch.no_grad():
            out = self.model(x)
            if not raw_out:
                out = self.prob_to_pred(self.post_forward(out))
        out = self.to_device(out, 'cpu')
        return out

    def output_to_numpy(self, out: Tensor) -> np.ndarray:
        """Convert output of model to numpy format.

        Args:
            out: the output of the model in PyTorch format

        Returns: the output of the model in numpy format
        """
        return out.numpy()

    def numpy_predict(self,
                      x: np.ndarray,
                      raw_out: bool = False) -> np.ndarray:
        """Make a prediction using an image or batch of images in numpy format.
        If x.dtype is a subtype of np.unsignedinteger, it will be normalized
        to [0, 1] using the max possible value of that dtype. Otherwise, x will
        be assumed to be in [0, 1] already and will be cast to torch.float32
        directly.

        Args:
            x: (ndarray) of shape [height, width, channels] or
                [batch_sz, height, width, channels]
            raw_out: if True, return prediction probabilities

        Returns:
            predictions using numpy arrays
        """
        transform, _ = self.get_data_transforms()
        x = self.normalize_input(x)
        x = self.to_batch(x)
        x = np.stack([transform(image=img)['image'] for img in x])
        x = torch.from_numpy(x)
        x = x.permute((0, 3, 1, 2))
        out = self.predict(x, raw_out=raw_out)
        return self.output_to_numpy(out)

    def predict_dataloader(self,
                           dl: DataLoader,
                           one_batch: bool = False,
                           return_x: bool = True):
        """Make predictions over all batches in a DataLoader.

        Args:
            dl: the DataLoader
            one_batch: if True, just makes predictions over the first batch
            return_x: if True, returns all the inputs in addition to the predictions and
                targets

        Returns:
            if return_x: (x, y, z) ie. all images, labels, predictions for dl
            else: (y, z) ie. all labels, predictions for dl
        """
        self.model.eval()

        xs, ys, zs = [], [], []
        with torch.no_grad():
            for x, y in dl:
                x = self.to_device(x, self.device)
                z = self.prob_to_pred(self.post_forward(self.model(x)))
                x = self.to_device(x, 'cpu')
                z = self.to_device(z, 'cpu')
                if one_batch:
                    return x, y, z
                if return_x:
                    xs.append(x)
                ys.append(y)
                zs.append(z)

        if return_x:
            return torch.cat(xs), torch.cat(ys), torch.cat(zs)
        return torch.cat(ys), torch.cat(zs)

    def get_dataloader(self, split: str) -> DataLoader:
        """Get the DataLoader for a split.

        Args:
            split: a split name which can be train, valid, or test
        """
        if split == 'train':
            return self.train_dl
        elif split == 'valid':
            return self.valid_dl
        elif split == 'test':
            return self.test_dl
        else:
            raise ValueError('{} is not a valid split'.format(split))

    @abstractmethod
    def plot_xyz(self, ax, x: Tensor, y, z=None):
        """Plot image, ground truth labels, and predicted labels.

        Args:
            ax: matplotlib axis on which to plot
            x: image
            y: ground truth labels
            z: optional predicted labels
        """
        pass

    def plot_batch(self,
                   x: Tensor,
                   y,
                   output_path: str,
                   z=None,
                   batch_limit: Optional[int] = None):
        """Plot a whole batch in a grid using plot_xyz.

        Args:
            x: batch of images
            y: ground truth labels
            output_path: local path where to save plot image
            z: optional predicted labels
            batch_limit: optional limit on (rendered) batch size
        """
        batch_sz = x.shape[0]
        batch_sz = min(batch_sz,
                       batch_limit) if batch_limit is not None else batch_sz
        if batch_sz == 0:
            return
        ncols = nrows = math.ceil(math.sqrt(batch_sz))
        fig = plt.figure(constrained_layout=True,
                         figsize=(3 * ncols, 3 * nrows))
        grid = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig)

        # (N, c, h, w) --> (N, h, w, c)
        x = x.permute(0, 2, 3, 1)

        # apply transform, if given
        if self.cfg.data.plot_options.transform is not None:
            tf = A.from_dict(self.cfg.data.plot_options.transform)
            imgs = [tf(image=img)['image'] for img in x.numpy()]
            x = torch.from_numpy(np.stack(imgs))

        for i in range(batch_sz):
            ax = fig.add_subplot(grid[i])
            if z is None:
                self.plot_xyz(ax, x[i], y[i])
            else:
                self.plot_xyz(ax, x[i], y[i], z=z[i])

        make_dir(output_path, use_dirname=True)
        plt.savefig(output_path)
        plt.close()

    def plot_predictions(self, split: str, batch_limit: Optional[int] = None):
        """Plot predictions for a split.

        Uses the first batch for the corresponding DataLoader.

        Args:
            split: dataset split. Can be train, valid, or test.
            batch_limit: optional limit on (rendered) batch size
        """
        log.info('Plotting predictions...')
        dl = self.get_dataloader(split)
        output_path = join(self.output_dir, '{}_preds.png'.format(split))
        x, y, z = self.predict_dataloader(dl, one_batch=True)
        self.plot_batch(x, y, output_path, z=z, batch_limit=batch_limit)

    def plot_dataloader(self,
                        dl: DataLoader,
                        output_path: str,
                        batch_limit: Optional[int] = None):
        """Plot images and ground truth labels for a DataLoader."""
        x, y = next(iter(dl))
        self.plot_batch(x, y, output_path, batch_limit=batch_limit)

    def plot_dataloaders(self, batch_limit: Optional[int] = None):
        """Plot images and ground truth labels for all DataLoaders."""
        if self.train_dl:
            self.plot_dataloader(
                self.train_dl, join(self.output_dir, 'dataloaders/train.png'),
                batch_limit)
        if self.valid_dl:
            self.plot_dataloader(
                self.valid_dl, join(self.output_dir, 'dataloaders/valid.png'),
                batch_limit)
        if self.test_dl:
            self.plot_dataloader(self.test_dl,
                                 join(self.output_dir, 'dataloaders/test.png'),
                                 batch_limit)

    @staticmethod
    def from_model_bundle(model_bundle_uri: str,
                          tmp_dir: str,
                          cfg: Optional[LearnerConfig] = None,
                          training: bool = False):
        """Create a Learner from a model bundle."""
        model_bundle_path = download_if_needed(model_bundle_uri, tmp_dir)
        model_bundle_dir = join(tmp_dir, 'model-bundle')
        unzip(model_bundle_path, model_bundle_dir)

        model_path = join(model_bundle_dir, 'model.pth')

        if cfg is None:
            config_path = join(model_bundle_dir, 'pipeline-config.json')

            config_dict = file_to_json(config_path)
            config_dict = upgrade_config(config_dict)

            cfg = build_config(config_dict)
            cfg = cfg.learner

        hub_dir = join(model_bundle_dir, MODULES_DIRNAME)
        model_def_path = None
        loss_def_path = None

        # retrieve existing model definition, if available
        ext_cfg = cfg.model.external_def
        if ext_cfg is not None:
            model_def_path = get_hubconf_dir_from_cfg(ext_cfg, parent=hub_dir)
            log.info(
                f'Using model definition found in bundle: {model_def_path}')

        # retrieve existing loss function definition, if available
        ext_cfg = cfg.solver.external_loss_def
        if ext_cfg is not None and training:
            loss_def_path = get_hubconf_dir_from_cfg(ext_cfg, parent=hub_dir)
            log.info(f'Using loss definition found in bundle: {loss_def_path}')

        return cfg.build(tmp_dir=tmp_dir,
                         model_path=model_path,
                         model_def_path=model_def_path,
                         loss_def_path=loss_def_path,
                         training=training)

    def save_model_bundle(self):
        """Save a model bundle.

        This is a zip file with the model weights in .pth format and a serialized
        copy of the LearnerConfig, which allows for making predictions in the future.
        """
        from rastervision.pytorch_learner.learner_pipeline_config import (
            LearnerPipelineConfig)

        log.info('Creating bundle.')
        model_bundle_dir = join(self.tmp_dir, 'model-bundle')
        make_dir(model_bundle_dir, force_empty=True)

        shutil.copyfile(self.last_model_path,
                        join(model_bundle_dir, 'model.pth'))

        # copy modules into bundle
        if isdir(self.modules_dir):
            log.info('Copying modules into bundle.')
            bundle_modules_dir = join(model_bundle_dir, MODULES_DIRNAME)
            if isdir(bundle_modules_dir):
                shutil.rmtree(bundle_modules_dir)
            shutil.copytree(self.modules_dir, bundle_modules_dir)

        pipeline_cfg = LearnerPipelineConfig(learner=self.cfg)
        save_pipeline_config(pipeline_cfg,
                             join(model_bundle_dir, 'pipeline-config.json'))
        zipdir(model_bundle_dir, self.model_bundle_path)

    def get_start_epoch(self) -> int:
        """Get start epoch.

        If training was interrupted, this returns the last complete epoch + 1.
        """
        start_epoch = 0
        if isfile(self.log_path):
            with open(self.log_path) as log_file:
                last_line = log_file.readlines()[-1]
            last_epoch = int(last_line.split(',')[0].strip())
            start_epoch = last_epoch + 1
        return start_epoch

    def load_init_weights(self):
        """Load the weights to initialize model."""
        if self.cfg.model.init_weights:
            weights_path = download_if_needed(self.cfg.model.init_weights,
                                              self.tmp_dir)
            self.model.load_state_dict(torch.load(weights_path,
                                                  map_location=self.device),
                                       strict=self.cfg.model.load_strict)

    def load_checkpoint(self):
        """Load last weights from previous run if available."""
        if isfile(self.last_model_path):
            log.info('Loading checkpoint from {}'.format(self.last_model_path))
            self.model.load_state_dict(
                torch.load(self.last_model_path, map_location=self.device))

    def to_device(self, x: Any, device: str) -> Any:
        """Load Tensors onto a device.

        Args:
            x: some object with Tensors in it
            device: 'cpu' or 'cuda'

        Returns:
            x but with any Tensors in it on the device
        """
        if isinstance(x, list):
            return [_x.to(device) for _x in x]
        else:
            return x.to(device)

    def train_epoch(self) -> MetricDict:
        """Train for a single epoch."""
        start = time.time()
        self.model.train()
        num_samples = 0
        outputs = []
        with click.progressbar(self.train_dl, label='Training') as bar:
            for batch_ind, (x, y) in enumerate(bar):
                x = self.to_device(x, self.device)
                y = self.to_device(y, self.device)
                batch = (x, y)
                self.opt.zero_grad()
                output = self.train_step(batch, batch_ind)
                output['train_loss'].backward()
                self.opt.step()
                # detach tensors in the output, if any, to avoid memory leaks
                for k, v in output.items():
                    output[k] = v.detach() if isinstance(v, Tensor) else v
                outputs.append(output)
                if self.step_scheduler:
                    self.step_scheduler.step()
                num_samples += x.shape[0]
        metrics = self.train_end(outputs, num_samples)
        end = time.time()
        train_time = datetime.timedelta(seconds=end - start)
        metrics['train_time'] = str(train_time)
        return metrics

    def validate_epoch(self, dl: DataLoader) -> MetricDict:
        """Validate for a single epoch."""
        start = time.time()
        self.model.eval()
        num_samples = 0
        outputs = []
        with torch.no_grad():
            with click.progressbar(dl, label='Validating') as bar:
                for batch_ind, (x, y) in enumerate(bar):
                    x = self.to_device(x, self.device)
                    y = self.to_device(y, self.device)
                    batch = (x, y)
                    output = self.validate_step(batch, batch_ind)
                    outputs.append(output)
                    num_samples += x.shape[0]
        end = time.time()
        validate_time = datetime.timedelta(seconds=end - start)

        metrics = self.validate_end(outputs, num_samples)
        metrics['valid_time'] = str(validate_time)
        return metrics

    def overfit(self):
        """Optimize model using the same batch repeatedly."""
        self.on_overfit_start()

        x, y = next(iter(self.train_dl))
        x = self.to_device(x, self.device)
        y = self.to_device(y, self.device)
        batch = (x, y)

        with click.progressbar(range(self.cfg.solver.overfit_num_steps),
                               label='Overfitting') as bar:
            for step in bar:
                # reset gradients so they don't accumulate across steps
                self.opt.zero_grad()
                loss = self.train_step(batch, step)['train_loss']
                loss.backward()
                self.opt.step()

                if (step + 1) % 25 == 0:
                    log.info('\nstep: {}'.format(step))
                    log.info('train_loss: {}'.format(loss))

        torch.save(self.model.state_dict(), self.last_model_path)

    def train(self):
        """Training loop that will attempt to resume training if appropriate."""
        self.on_train_start()

        if self.start_epoch > 0 and self.start_epoch <= self.cfg.solver.num_epochs:
            log.info('Resuming training from epoch {}'.format(
                self.start_epoch))

        for epoch in range(self.start_epoch, self.cfg.solver.num_epochs):
            log.info('epoch: {}'.format(epoch))
            train_metrics = self.train_epoch()
            if self.epoch_scheduler:
                self.epoch_scheduler.step()
            valid_metrics = self.validate_epoch(self.valid_dl)
            metrics = dict(epoch=epoch, **train_metrics, **valid_metrics)
            log.info('metrics: {}'.format(metrics))

            self.on_epoch_end(epoch, metrics)

    def on_overfit_start(self):
        """Hook that is called at start of overfit routine."""
        pass

    def on_train_start(self):
        """Hook that is called at start of train routine."""
        pass

    def on_epoch_end(self, curr_epoch, metrics):
        """Hook that is called at end of epoch.

        Writes metrics to CSV and TB, and saves model.
        """
        if not isfile(self.log_path):
            with open(self.log_path, 'w') as log_file:
                log_writer = csv.writer(log_file)
                row = self.metric_names
                log_writer.writerow(row)

        with open(self.log_path, 'a') as log_file:
            log_writer = csv.writer(log_file)
            row = [metrics[k] for k in self.metric_names]
            log_writer.writerow(row)

        if self.cfg.log_tensorboard:
            for key, val in metrics.items():
                if isinstance(val, numbers.Number):
                    self.tb_writer.add_scalar(key, val, curr_epoch)
            for name, param in self.model.named_parameters():
                self.tb_writer.add_histogram(name, param, curr_epoch)
            self.tb_writer.flush()

        torch.save(self.model.state_dict(), self.last_model_path)

        if (curr_epoch + 1) % self.cfg.solver.sync_interval == 0:
            self.sync_to_cloud()

    def eval_model(self, split: str):
        """Evaluate model using a particular dataset split.

        Gets validation metrics and saves them along with prediction plots.

        Args:
            split: the dataset split to use: train, valid, or test.
        """
        log.info('Evaluating on {} set...'.format(split))
        dl = self.get_dataloader(split)
        metrics = self.validate_epoch(dl)
        log.info('metrics: {}'.format(metrics))
        json_to_file(metrics,
                     join(self.output_dir, '{}_metrics.json'.format(split)))
        self.plot_predictions(split, self.preview_batch_limit)
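The Learner above exposes two entry points that are easy to miss in the bulk of the class: from_model_bundle for restoring a trained model and numpy_predict for inference on raw arrays. A minimal usage sketch, assuming a hypothetical bundle URI and that the bundled config resolves to a concrete Learner subclass:

import numpy as np

# hypothetical paths; substitute your own bundle and temp directory
learner = Learner.from_model_bundle('s3://my-bucket/model-bundle.zip', '/tmp/rv')

# a uint8 image is normalized to [0, 1] internally by normalize_input()
img = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)
class_ids = learner.numpy_predict(img)            # class-id predictions
probs = learner.numpy_predict(img, raw_out=True)  # per-class probabilities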
Example #12
def main():
    dir_weight = os.path.join(dir_save, 'weight')
    dir_log = os.path.join(dir_save, 'log')
    os.makedirs(dir_weight, exist_ok=True)
    writer = SummaryWriter(dir_log)

    indexes = [
        int(os.path.splitext(path)[0]) for path in os.listdir(dir_weight)
    ]
    current_step = max(indexes) if indexes else 0

    image_size = 768
    lr = 1e-3
    batch_size = 12
    num_workers = 4

    max_step = 12000
    lr_cfg = [[7500, lr], [max_step, lr / 10]]
    warm_up = [500, lr / 50, lr]
    save_interval = 1000

    aug = Compose([
        ops.ToFloat(),
        ops.PhotometricDistort(),
        ops.RandomHFlip(),
        ops.RandomVFlip(),
        ops.RandomRotate90(),
        ops.ResizeJitter([0.8, 1.2]),
        ops.PadSquare(),
        ops.Resize(image_size),
    ])
    dataset = HRSC2016(dir_dataset, ['trainval'], aug)
    loader = DataLoader(dataset,
                        batch_size,
                        shuffle=True,
                        num_workers=num_workers,
                        pin_memory=True,
                        drop_last=True,
                        collate_fn=dataset.collate)
    num_classes = len(dataset.names)

    prior_box = {
        'strides': [8, 16, 32, 64, 128],
        'sizes': [3] * 5,
        'aspects': [[1.5, 3, 5, 8]] * 5,
        'scales': [[2**0, 2**(1 / 3), 2**(2 / 3)]] * 5,
    }

    cfg = {
        'prior_box': prior_box,
        'num_classes': num_classes,
        'extra': 2,
    }

    model = RDD(backbone(fetch_feature=True), cfg)
    model.build_pipe(shape=[2, 3, image_size, image_size])
    if current_step:
        model.restore(os.path.join(dir_weight, '%d.pth' % current_step))
    else:
        model.init()
    if len(device_ids) > 1:
        model = convert_model(model)
        model = CustomDetDataParallel(model, device_ids)
    model = model.cuda()
    optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          momentum=0.9,
                          weight_decay=5e-4)
    training = True

    while training and current_step < max_step:
        tqdm_loader = tqdm.tqdm(loader)
        for images, targets, infos in tqdm_loader:
            current_step += 1
            adjust_lr_multi_step(optimizer, current_step, lr_cfg, warm_up)

            images = images.cuda() / 255
            losses = model(images, targets)
            loss = sum(losses.values())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            for key, val in list(losses.items()):
                losses[key] = val.item()
                writer.add_scalar(key, val, global_step=current_step)
            writer.flush()
            tqdm_loader.set_postfix(losses)
            tqdm_loader.set_description(f'<{current_step}/{max_step}>')

            if current_step % save_interval == 0:
                save_path = os.path.join(dir_weight, '%d.pth' % current_step)
                state_dict = model.state_dict() if len(
                    device_ids) == 1 else model.module.state_dict()
                torch.save(state_dict, save_path)
                cache_file = os.path.join(
                    dir_weight, '%d.pth' % (current_step - save_interval))
                if os.path.exists(cache_file):
                    os.remove(cache_file)

            if current_step >= max_step:
                training = False
                writer.close()
                break
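adjust_lr_multi_step is defined elsewhere in this project; a plausible minimal sketch consistent with the lr_cfg and warm_up formats used above (a linear warm-up over the first 500 steps, then a piecewise-constant schedule):

def adjust_lr_multi_step(optimizer, step, lr_cfg, warm_up=None):
    # warm_up = [num_warm_steps, lr_start, lr_end]: linear ramp
    if warm_up is not None and step <= warm_up[0]:
        num_steps, lr_start, lr_end = warm_up
        lr = lr_start + (lr_end - lr_start) * step / num_steps
    else:
        # lr_cfg = [[boundary_step, lr], ...]: piecewise-constant schedule
        lr = lr_cfg[-1][1]
        for boundary, value in lr_cfg:
            if step <= boundary:
                lr = value
                break
    for group in optimizer.param_groups:
        group['lr'] = lr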
Example #13
class OwnRnn(nn.Module):
    def __init__(self,
                 log_dir="out/rnn",
                 num_actions=2,
                 num_states=2,
                 optimizer_str=None,
                 learning_rate=7e-4):
        super(OwnRnn, self).__init__()

        # episode buffer; holds info for the later gradient calculation of a single episode
        # usually contains 200 trials, one trial per row
        self.epbuffer = []

        # the +2 adds one value for the reward and one for the timestep
        self.input_size = num_actions + num_states + 2
        self.num_actions = num_actions
        self.num_states = num_states

        self.num_rnn_units = 48

        # input of the RNN has shape [sequence length, batch, inputs]
        # output of the RNN has shape [sequence length (trials per episode), batch (number of episodes), num hidden unit activations]
        # hence output[1, 2, 30] is the activation of the 30th unit in response to the first input of the 2nd sample
        # (a batch holds multiple samples; each sample presents the network with a number of consecutive inputs; here only 1)

        # we don't use a plain RNN but an LSTM now
        # self.recurrent_layer = torch.nn.RNN(input_size=self.input_size, hidden_size=self.num_rnn_units, num_layers=1, nonlinearity='relu')

        # following https://github.com/ikostrikov/pytorch-a3c/blob/master/model.py
        # note: nn.LSTMCell processes one step at a time; we use the full nn.LSTM instead
        self.lstm = nn.LSTM(self.input_size, self.num_rnn_units)

        self.lstm.bias_ih_l0.data.fill_(0)
        self.lstm.bias_hh_l0.data.fill_(0)

        self.action_outp_layer = nn.Linear(in_features=self.num_rnn_units,
                                           out_features=num_actions)
        self.value_outp_layer = nn.Linear(in_features=self.num_rnn_units,
                                          out_features=1)

        self.action_outp_layer.weight.data = normalized_columns_initializer(
            self.action_outp_layer.weight.data, 0.01)
        self.value_outp_layer.weight.data = normalized_columns_initializer(
            self.value_outp_layer.weight.data, 0.01)

        # softmax for turning the policy logits into action distributions later on
        # applied over dim=2 because that dimension holds the per-action outputs
        self.act_smx = nn.Softmax(dim=2)

        # hyperparameter
        self.gamma = .9

        # set the optimizer
        if optimizer_str == 'RMS':
            self.optimizer = torch.optim.RMSprop(self.parameters(),
                                                 lr=learning_rate)
        elif optimizer_str == 'Adam':
            self.optimizer = torch.optim.Adam(self.parameters(),
                                              lr=learning_rate)
        else:
            self.optimizer = torch.optim.RMSprop(self.parameters(),
                                                 lr=learning_rate)

        # define the logfolders
        # used for the matplotlib images
        self.output_pref = log_dir + "/ownrnn_ep-"
        self.wr = SummaryWriter(log_dir=log_dir + '/tb')

        return

    def __del__(self):
        self.wr.close()

    def do_training(self,
                    taskenv,
                    num_episodes=20000,
                    single_episode_length=200,
                    stats_every_x_episodes=500):

        start_time = time.time()

        min_acc_episode_reward = 0

        print("Starting...")

        for x in range(num_episodes + 1):

            taskenv.reset()

            # run the network twice: first accumulate the episode rollout (one action at a time),
            # detaching the episode accumulation from the gradient calculation, which then runs as a batch in the update-weights function

            # first run
            # each episode has 200 game steps / trials
            episode_buffer = self.run_x_times(single_episode_length, taskenv)

            # second run
            # estimates the loss without interacting with the environment again
            info = self.calc_loss_and_update_weights(episode_buffer, t=x)

            ### from here on, we only have logging functions, can be edited as pleased

            new_highscore = info['acc_ep_reward'] > min_acc_episode_reward

            # progress report line, printed to the console every 20 episodes
            if x % 20 == 0 or new_highscore:

                if new_highscore:
                    min_acc_episode_reward = info['acc_ep_reward']
                minstr = "\t(New Highscore)" if new_highscore else ""

                print(int(x * 100 / num_episodes), "% - ", "Episode: ", x,
                      " \tLoss = ", info['loss'], " \tAccRew = ",
                      info['acc_ep_reward'], minstr)

            # from here on we only keep track of stats, i.e. value distributions of parameters/weights and their gradients

            if x % stats_every_x_episodes == 0:

                # print to the console, every stats_every_x_episodes, the stats for the most recent episode / gradient update
                print("###### Beg Netw-Starts ######")
                print("Name\t\t\t\t\tavgp\tmedp\tstdp\tminp\tmaxp\tsump")
                print_tensor_param_stats(self.value_outp_layer)
                print_tensor_param_stats(self.action_outp_layer)
                print_tensor_param_stats(self.lstm)
                print("...")
                print_tensor_param_stats(self.value_outp_layer, grad=True)
                print_tensor_param_stats(self.action_outp_layer, grad=True)
                print_tensor_param_stats(self.lstm, grad=True)
                print("###### End Netw-Starts ######")

                # console output of elapsed and estimated remaining time for all episodes to complete
                elapsed_time = time.time() - start_time
                exp_total_duration = (elapsed_time / (x + 1e-5)) * num_episodes
                remain_time = exp_total_duration - elapsed_time

                print("-------------------------------")
                print("Total Runtime so far    : ",
                      time.strftime("%H:%M:%S", time.gmtime(elapsed_time)))
                print("Expected reminaing time : ",
                      time.strftime("%H:%M:%S", time.gmtime(remain_time)))
                print(
                    "Expected total duration : ",
                    time.strftime("%H:%M:%S", time.gmtime(exp_total_duration)))
                print("-------------------------------")

                self.plot(x, taskenv.stats)
                taskenv.stats = np.zeros((2, 2, 2))

            # log more detailed stats to tensorboard every 2000 episodes (do this more rarely as it may be somewhat resource intensive)
            if x % 2000 == 0:

                log_step = int(x / stats_every_x_episodes) + 1
                add_tb_param_histograms(self.wr, self.action_outp_layer,
                                        log_step)
                add_tb_param_histograms(self.wr, self.value_outp_layer,
                                        log_step)
                self.wr.flush()
                add_tb_param_histograms(self.wr, self.lstm, log_step)
                self.wr.flush()
                add_tb_param_histograms(self.wr,
                                        self.action_outp_layer,
                                        log_step,
                                        grad=True)
                add_tb_param_histograms(self.wr,
                                        self.value_outp_layer,
                                        log_step,
                                        grad=True)
                self.wr.flush()
                add_tb_param_histograms(self.wr,
                                        self.lstm,
                                        log_step,
                                        grad=True)
                self.wr.flush()

                print("addied histograms: ", 1)
                print("-------------------------------")

            ### end if episode % 100 = 0
        ### End for;

    ### end do training

    # this is basically our episode
    def run_x_times(self, number_of_feedforward_steps, taskenv):

        # the environment never really terminates (each feedforward step is one trial),
        # so instead we define an episode as a fixed number of draws:

        # number_of_feedforward_steps ~ our batch size / episode length

        oh_prev_action = F.one_hot(ts(0), self.num_actions)
        oh_prev_reached_state = F.one_hot(ts(0), self.num_states)
        prev_receivd_rewrd = ts([0])

        # initialize the hidden states / recurrent input to zeros
        cx = torch.zeros(1, self.num_rnn_units).view(1, 1, -1)
        hx = torch.zeros(1, self.num_rnn_units).view(1, 1, -1)

        # one batch ~ one episode of 200 trials, will be saved internally
        self.epbuffer = []

        taskenv.reset()

        for i in range(number_of_feedforward_steps):
            # i is also our timestep variable

            cinput = torch.cat(
                (oh_prev_action, oh_prev_reached_state, prev_receivd_rewrd,
                 ts([i])), 0).float().view(1, 1, self.input_size)

            # run the lstm on a single trial
            out, (hx, cx) = self.lstm(cinput, (hx, cx))

            # LSTM output is used as input for two different layers:
            # estimated value of the current state
            # ~ a cumulative estimate of future reward, roughly
            value = self.value_outp_layer(out)
            policy_out = self.action_outp_layer(out)

            # draw action from the last and only action distribution
            policy_distrib = self.act_smx(policy_out).contiguous()
            act_distr = torch.distributions.Categorical(
                policy_distrib.view(-1, self.num_actions)[-1])
            act = act_distr.sample()

            # mean entropy of our action distribution (not needed)
            # acc_entropy = acc_entropy+act_distr.entropy().mean()

            # execute action in the task_environment
            [reached_state, reward] = taskenv.conduct_action(act.item())
            # out: reached state is either 0 or 1; reward also either 0 or 1

            # do not keep any gradient-relevant info, i.e. don't save any tensors
            # save as: action taken -> state reached upon that action -> reward received for that state, and the predicted value of that action, all @ timestep i
            self.epbuffer.append(
                [act.item(), reached_state, reward, i,
                 value.item()])

            #self.a2c_ts_out_buffer.append(SavedAction(act_distr.log_prob(act), value))

            # prepare vars for the next trial
            oh_prev_reached_state = F.one_hot(ts(reached_state),
                                              self.num_states)
            oh_prev_action = F.one_hot(act, self.num_actions)
            prev_receivd_rewrd = ts([reward])

        # end of for number of feedforward steps

        return self.epbuffer

    def calc_loss_and_update_weights(self, epbuffer=None, t=0):

        # run the entire set of 200 trials again, not one by one but as a batch,
        # using exactly the same data & actions as before;
        # this gives a cleaner way to backpropagate, since the entire input is one matrix

        # by default take the internal buffer; it can also be given explicitly
        # to make the code more readable
        if epbuffer is None:
            epbuffer = self.epbuffer
        epbuffer = np.array(epbuffer)  # make sure to convert to numpy

        ## prepare the input

        actions = epbuffer[:, 0]  # based on the policy head output of the A2C
        reached_states = epbuffer[:, 1]
        # cast to int64: the array may happen to be of dtype object, and these
        # values are later turned into tensors
        rewards = epbuffer[:, 2].astype(np.int64)
        timesteps = epbuffer[:, 3].astype(np.int64)
        pred_values = epbuffer[:, 4]  # based on the value head output of the A2C

        prev_actions = [0] + actions[:-1].tolist()  # previously conducted actions
        prev_reached_states = [0] + reached_states[:-1].tolist()  # states previously reached through those actions
        prev_rewards = [0] + rewards[:-1].tolist()  # the result of the previous state

        # network needs tensors as input
        ohprev_actions = F.one_hot(ts(prev_actions).long(),
                                   self.num_actions).long()
        ohprev_reached_states = F.one_hot(
            ts(prev_reached_states).long(), self.num_states).long()
        prev_rewards = ts(prev_rewards).view(len(epbuffer), 1)
        timesteps_ts = ts(timesteps.tolist()).view(len(epbuffer), 1)

        #prev_reached_states = ts(prev_reached_states.tolist()).view(len(epbuffer),2)

        # merge them all horizontally (i.e. one array row contains one trial)
        cinput = torch.cat((ohprev_actions, ohprev_reached_states,
                            prev_rewards, timesteps_ts), 1)

        # transform it all into the right input shape,
        # i.e. add another dimension for the episode id; we only have one episode of 200 trials to process
        # [trials per episode ~ 200, number of episodes ~ 1, input size ~ action+state+rew+ts]
        cinput = cinput.float().view(len(epbuffer), 1, self.input_size)

        ## run the network

        # initialize the recurrent state of the LSTM; start from zero, as at the beginning of each episode (~200 trials)
        cx = torch.zeros(1, self.num_rnn_units).view(1, 1, -1)
        hx = torch.zeros(1, self.num_rnn_units).view(1, 1, -1)

        # feed the input into the LSTM nodes and get the output
        out, (hx, cx) = self.lstm(cinput, (hx, cx))

        # two heads for the A2C algorithm (actor critic);
        # feed in the output gathered from the hidden nodes of the LSTM
        values = self.value_outp_layer(out)
        policy_out1 = self.action_outp_layer(out)
        policy_out = self.act_smx(policy_out1)

        ## do the loss calculation

        # calculate the policy loss (has biggest influence)

        ohactions = F.one_hot(ts(actions.tolist()).long(), self.num_actions)
        resp_outps = torch.sum(policy_out.squeeze() * ohactions, dim=1)

        value_plus = np.asarray(pred_values.tolist() + [0.0])
        #value_plus = np.asarray(values.squeeze().tolist() + [0.0])
        und_adv = rewards + self.gamma * value_plus[1:] - value_plus[:-1]
        advantages = discount(und_adv, self.gamma)

        policy_loss = -torch.sum(
            torch.log(resp_outps + 1e-7) * ts(advantages.copy()))

        # calculate the value loss
        # compute the targets for the value head
        rewards_plus = np.asarray(rewards.tolist() + [0.0])
        # equals target_v, these are our targets, because they are the only real value we have
        disc_cumm_future_rewards = discount(rewards_plus, self.gamma)[:-1]

        # create a copy of the numpy array so that consecutive items also sit in
        # consecutive memory locations, which is required for the conversion to a tensor
        diff = ts(disc_cumm_future_rewards.copy()) - values.squeeze()
        value_loss = 0.5 * torch.sum(diff * diff)

        # calculate the entropy loss
        # how certain is the network of its own decision
        entropy_loss = -torch.sum(policy_out * torch.log(policy_out + 1e-7))

        # combine it all into one loss
        loss = 0.05 * value_loss + policy_loss - 0.05 * entropy_loss

        # reset the gradient
        self.optimizer.zero_grad()

        # calculate the gradient
        #loss.backward(retain_graph=True);
        loss.backward()

        # make sure the gradient is not too big
        torch.nn.utils.clip_grad_norm_(self.parameters(), 999.0)

        ### Here do all the bookkeeping
        # gradient will be applied afterwards
        self.wr.add_scalars(
            'losses', {
                'loss': loss,
                'val_loss': value_loss,
                'pol_loss': policy_loss,
                'ent_loss': entropy_loss
            }, t)

        self.wr.add_scalar('sum_rewards', rewards.sum(), t)

        # plot the parameters before the gradients have been applied
        self.wr.add_scalars(
            'ValueLayerParams',
            get_tb_dir_for_tensor_param_stats(self.value_outp_layer), t)
        self.wr.add_scalars(
            'PolcyLayerParams',
            get_tb_dir_for_tensor_param_stats(self.action_outp_layer), t)
        self.wr.add_scalars('LSTMNLayerParams',
                            get_tb_dir_for_tensor_param_stats(self.lstm), t)

        self.wr.add_scalars(
            'ValueLayerCGrads',
            get_tb_dir_for_tensor_param_stats(self.value_outp_layer,
                                              grad=True), t)
        self.wr.add_scalars(
            'PolcyLayerCGrads',
            get_tb_dir_for_tensor_param_stats(self.action_outp_layer,
                                              grad=True), t)
        self.wr.add_scalars(
            'LSTMNLayerCGrads',
            get_tb_dir_for_tensor_param_stats(self.lstm, grad=True), t)

        # apply the gradient
        self.optimizer.step()

        return {'loss': loss.item(), 'acc_ep_reward': rewards.sum().item()}

    ## from here on, we only have plotting functions
    def plot(self, episode_count, transition_count):

        fig, ax = plt.subplots()
        x = np.arange(2)
        ax.set_ylim([0.0, 1.0])
        ax.set_ylabel('Stay Probability')

        row_sums = transition_count.sum(axis=-1)
        stay_probs = transition_count / row_sums[:, :, np.newaxis]

        # own rsc
        uncommon = [stay_probs[1, 0, 1], stay_probs[0, 0, 1]]
        common = [stay_probs[1, 1, 1], stay_probs[0, 1, 1]]

        ax.set_xticks([1.3, 3.3])
        ax.set_xticklabels(['Last trial rewarded', 'Last trial not rewarded'])

        c = plt.bar([1, 3], common, color='b', width=0.5)
        uc = plt.bar([1.8, 3.8], uncommon, color='r', width=0.5)
        ax.legend((c[0], uc[0]), ('common', 'uncommon'))

        path = self.output_pref + str(episode_count) + ".png"
        plt.savefig(path)
        print("Saved plot as: ", path)
Example #14
    train(model,
          dataloader,
          epoch_num,
          criterion,
          optimizer,
          scheduler,
          device,
          tensorboard_writer,
          os.path.join(save_model_path, model_name, 'Figures'),
          plot_steps=plot_steps,
          stop_condition=stop_condition,
          sample_weights=sample_weights)
    torch.save(model.state_dict(), save_model_path_filename)

    tensorboard_writer.flush()
    tensorboard_writer.close()
else:  # Load old model with the given name
    model.load_state_dict(torch.load(save_model_path_filename), strict=False)

# TESTING
validation_dataset = UnlabeledDataset(valid_data_path,
                                      transformations=transformation_pipeline)
valid_dataloader = DataLoader(validation_dataset,
                              batch_size=1,
                              shuffle=False,
                              num_workers=num_workers)

valid_outputs, valid_losses, valid_filenames = test(model, valid_dataloader,
                                                    criterion, device)
print("Validation outputs")
Example #15
def main():
    # dataset
    trainset, valset, testset = build_datasets(args.dataset, args.base_size,
                                               args.crop_size)

    # define the student/teacher models
    student = BiSeNet(trainset.num_classes,
                      context_path='resnet18',
                      in_planes=32)
    teacher = BiSeNet(trainset.num_classes,
                      context_path='resnet101',
                      in_planes=64)
    print_model_parm_nums(student,
                          'student')  # student: Number of params: 5.66 M
    print_model_parm_nums(teacher,
                          'teacher')  # teacher: Number of params: 132.92 M

    # load the already-trained student/teacher models
    device = f'cuda:{args.gpu_ids}'
    load_state_dict(
        student,
        'runs/SUNRGBD/res18_inp32_deconv_Jul27_100319/checkpoint.pth.tar',
        device)
    load_state_dict(
        teacher,
        'runs/SUNRGBD/res101_inp64_deconv_Jul26_205859/checkpoint.pth.tar',
        device)

    class_weights = None
    if args.use_balanced_weights:  # default false
        class_weights = np.array([  # med_freq
            0.382900,
            0.452448,
            0.637584,
            0.377464,
            0.585595,
            0.479574,
            0.781544,
            0.982534,
            1.017466,
            0.624581,
            2.589096,
            0.980794,
            0.920340,
            0.667984,
            1.172291,  # 15
            0.862240,
            0.921714,
            2.154782,
            1.187832,
            1.178115,  # 20
            1.848545,
            1.428922,
            2.849658,
            0.771605,
            1.656668,  # 25
            4.483506,
            2.209922,
            1.120280,
            2.790182,
            0.706519,  # 30
            3.994768,
            2.220004,
            0.972934,
            1.481525,
            5.342475,  # 35
            0.750738,
            4.040773  # 37
        ])

    saver = Saver(args, timestamp=get_curtime())
    writer = SummaryWriter(saver.experiment_dir)

    trainer = Trainer(args, student, teacher, trainset, valset, testset,
                      class_weights, saver, writer)

    start_epoch = 0

    miou_caches = AccCaches(patience=5)  # miou
    for epoch in range(start_epoch, args.epochs):
        trainer.training(epoch)
        if epoch % args.eval_interval == (args.eval_interval - 1):
            miou, pixelAcc = trainer.validation(epoch)
            miou_caches.add(epoch, miou)
            if miou_caches.full():
                print('acc caches:', miou_caches.accs)
                print('best epoch:', trainer.best_epoch, 'best miou:',
                      trainer.best_mIoU)
                _, max_miou = miou_caches.max_cache_acc()
                if max_miou < trainer.best_mIoU:
                    print('end training')
                    break

    print('valid')
    print('best mIoU:', trainer.best_mIoU, 'pixelAcc:', trainer.best_pixelAcc)

    # test
    epoch = trainer.load_best_checkpoint()
    test_mIoU, test_pixelAcc = trainer.validation(epoch, test=True)
    print('test')
    print('best mIoU:', test_mIoU, 'pixelAcc:', test_pixelAcc)

    writer.flush()
    writer.close()
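AccCaches is this project's early-stopping helper; a minimal sketch consistent with how it is used above (add, full, accs, and max_cache_acc over a sliding window of the last patience entries):

class AccCaches:
    def __init__(self, patience):
        self.patience = patience
        self.accs = []  # list of (epoch, acc) pairs, newest last

    def full(self):
        return len(self.accs) >= self.patience

    def add(self, epoch, acc):
        if self.full():
            self.accs.pop(0)  # keep only the most recent patience entries
        self.accs.append((epoch, acc))

    def max_cache_acc(self):
        # (epoch, acc) of the best entry in the window
        return max(self.accs, key=lambda pair: pair[1])

Training then stops once even the best mIoU among the last few evaluations no longer beats the overall best, i.e. the model has stopped improving.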
Example #16
def train(model_name, fold, run=None, resume_epoch=-1):
    model_str = build_model_str(model_name, fold, run)

    model_info = MODELS[model_name]

    checkpoints_dir = f'{BaseConfig.checkpoints_dir}/{model_str}'
    tensorboard_dir = f'{BaseConfig.tensorboard_dir}/{model_str}'
    oof_dir = f'{BaseConfig.oof_dir}/{model_str}'
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(oof_dir, exist_ok=True)
    print('\n', model_name, '\n')

    logger = SummaryWriter(log_dir=tensorboard_dir)

    model = model_info.factory(**model_info.args)
    model = model.cuda()

    # try:
    #     torchsummary.summary(model, (4, 512, 512))
    #     print('\n', model_name, '\n')
    # except:
    #     raise
    #     pass

    model = torch.nn.DataParallel(model).cuda()

    dataset_train = dataset.IntracranialDataset(
        csv_file='5fold-rev3.csv',
        folds=[f for f in range(BaseConfig.nb_folds) if f != fold],
        preprocess_func=albumentations.Compose([
            albumentations.ShiftScaleRotate(shift_limit=16. / 256, scale_limit=0.05, rotate_limit=30,
                                            interpolation=cv2.INTER_LINEAR,
                                            border_mode=cv2.BORDER_REPLICATE,
                                            p=0.7),
            albumentations.Flip(),
            albumentations.RandomRotate90(),
        ]),
        **{**model_info.dataset_args, "segmentation_oversample": 1}
    )

    dataset_valid = dataset.IntracranialDataset(
        csv_file='5fold.csv',
        folds=[fold],
        preprocess_func=None,
        **{**model_info.dataset_args, "segmentation_oversample": 1}
    )

    data_loaders = {
        'train': DataLoader(dataset_train,
                            num_workers=8,
                            shuffle=True,
                            batch_size=model_info.batch_size),
        'val': DataLoader(dataset_valid,
                          shuffle=False,
                          num_workers=8,
                          batch_size=model_info.batch_size)
    }

    dataset_train_1_slice = None
    if model_info.single_slice_steps > 0:
        dataset_train_1_slice = dataset.IntracranialDataset(
            csv_file='5fold-rev3.csv',
            folds=[f for f in range(BaseConfig.nb_folds) if f != fold],
            preprocess_func=albumentations.Compose([
                albumentations.ShiftScaleRotate(shift_limit=16. / 256, scale_limit=0.05, rotate_limit=30,
                                                interpolation=cv2.INTER_LINEAR,
                                                border_mode=cv2.BORDER_REPLICATE,
                                                p=0.75),
                albumentations.Flip(),
                albumentations.RandomRotate90()
            ]),
            **{**model_info.dataset_args, "num_slices": 1}
        )

        dataset_valid_1_slice = dataset.IntracranialDataset(
            csv_file='5fold.csv',
            folds=[fold],
            preprocess_func=None,
            **{**model_info.dataset_args, "num_slices": 1, "segmentation_oversample": 1}
        )

        data_loaders['train_1_slice'] = DataLoader(
            dataset_train_1_slice,
            num_workers=8,
            shuffle=True,
            batch_size=model_info.batch_size * 2)
        data_loaders['val_1_slice'] = DataLoader(
            dataset_valid_1_slice,
            shuffle=False,
            num_workers=8,
            batch_size=model_info.batch_size * 2)

    model.train()
    optimizer = radam.RAdam(model.parameters(), lr=model_info.initial_lr)

    milestones = [5, 10, 16]
    if model_info.optimiser_milestones:
        milestones = model_info.optimiser_milestones
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.2)

    print(f'Num training images: {len(dataset_train)} validation images: {len(dataset_valid)}')

    if resume_epoch > -1:
        checkpoint = torch.load(f'{checkpoints_dir}/{resume_epoch:03}.pt')
        print('load', f'{checkpoints_dir}/{resume_epoch:03}.pt')
        model.module.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    class_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 2.0]).cuda()

    def criterium(y_pred, y_true):
        return F.binary_cross_entropy_with_logits(y_pred, y_true, class_weights.repeat(y_pred.shape[0], 1))

    def criterium_mask(y_pred, y_true, have_segmentation):
        if not max(have_segmentation):
            return 0
        return F.binary_cross_entropy(y_pred[have_segmentation], y_true[have_segmentation]) * 10

    # criterium = nn.BCEWithLogitsLoss()

    # fit new layers first:
    if resume_epoch == -1 and model_info.is_pretrained:
        model.train()
        model.module.freeze_encoder()
        data_loader = data_loaders.get('train_1_slice', data_loaders['train'])
        pre_fit_steps = 50000 // model_info.batch_size
        data_iter = tqdm(enumerate(data_loader), total=pre_fit_steps)
        epoch_loss = []
        epoch_loss_mask = []
        initial_optimizer = radam.RAdam(model.parameters(), lr=1e-4)
        for iter_num, data in data_iter:
            if iter_num > pre_fit_steps:
                break
            with torch.set_grad_enabled(True):
                img = data['image'].float().cuda()
                labels = data['labels'].cuda()
                segmentation_labels = data['seg'].cuda()
                have_segmentation = data['have_segmentation']
                have_any_segmentation = max(have_segmentation)

                pred, segmentation = model(img)

                loss_cls = criterium(pred, labels)
                loss_mask = criterium_mask(segmentation, F.max_pool2d(segmentation_labels, 4), have_segmentation)
                (loss_cls + loss_mask).backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 100.0)
                initial_optimizer.step()
                initial_optimizer.zero_grad()
                epoch_loss.append(float(loss_cls))
                if have_any_segmentation:
                    epoch_loss_mask.append(float(loss_mask))

                data_iter.set_description(
                    f'Loss: Running {np.mean(epoch_loss[-500:]):1.4f} Avg {np.mean(epoch_loss):1.4f}' +
                    f' Running mask {np.mean(epoch_loss_mask[-500:]):1.4f} Mask {np.mean(epoch_loss_mask):1.4f}')
    model.module.unfreeze_encoder()

    for epoch_num in range(resume_epoch + 1, 8):
        if epoch_num > 3 and dataset_train_1_slice is not None:
            dataset_train_1_slice.segmentation_oversample = 1

        for phase in ['train', 'val']:
            model.train(phase == 'train')
            epoch_loss = []
            epoch_loss_mask = []
            epoch_labels = []
            epoch_predictions = []
            epoch_sample_paths = []

            if hasattr(model.module, 'on_epoch'):
                model.module.on_epoch(epoch_num)

            if epoch_num < model_info.single_slice_steps:
                data_loader = data_loaders[phase + '_1_slice']
                print("use 1 slice input")
            else:
                data_loader = data_loaders[phase]
                print("use N slices input")

            # if epoch_num == model_info.single_slice_steps:
            #     print("train only conv slices/fn layers")
            #     model.module.freeze_encoder_full()
            #
            # if epoch_num == model_info.single_slice_steps+1:
            #     print("train all")
            #     model.module.unfreeze_encoder()
            #
            # if -1 < model_info.freeze_bn_step <= epoch_num:
            #     print("freeze bn")
            #     model.module.freeze_bn()

            data_iter = tqdm(enumerate(data_loader), total=len(data_loader))
            for iter_num, data in data_iter:
                img = data['image'].float().cuda()
                labels = data['labels'].float().cuda()
                segmentation_labels = data['seg'].cuda()
                have_segmentation = data['have_segmentation']
                have_any_segmentation = max(have_segmentation)

                with torch.set_grad_enabled(phase == 'train'):
                    pred, segmentation = model(img)

                    loss_cls = criterium(pred, labels)
                    loss_mask = criterium_mask(segmentation, F.max_pool2d(segmentation_labels, 4), have_segmentation)

                    if phase == 'train':
                        ((loss_cls + loss_mask) / model_info.accumulation_steps).backward()
                        if (iter_num + 1) % model_info.accumulation_steps == 0:
                            torch.nn.utils.clip_grad_norm_(model.parameters(), 16.0)
                            optimizer.step()
                            optimizer.zero_grad()

                    epoch_loss.append(float(loss_cls))
                    if have_any_segmentation:
                        epoch_loss_mask.append(float(loss_mask))

                    epoch_labels.append(labels.detach().cpu().numpy())
                    epoch_predictions.append(torch.sigmoid(pred).detach().cpu().numpy())
                    epoch_sample_paths += data['path']

                data_iter.set_description(
                    f'Loss: Running {np.mean(epoch_loss[-500:]):1.4f} Avg {np.mean(epoch_loss):1.4f}' +
                    f' Running mask {np.mean(epoch_loss_mask[-500:]):1.4f} Mask {np.mean(epoch_loss_mask):1.4f}')

            epoch_labels = np.row_stack(epoch_labels)
            epoch_predictions = np.row_stack(epoch_predictions)

            logger.add_scalar(f'loss_{phase}', np.mean(epoch_loss), epoch_num)
            logger.add_scalar(f'loss_mask_{phase}', np.mean(epoch_loss_mask), epoch_num)
            logger.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch_num)  # scheduler.get_lr()[0]
            try:
                log_metrics(logger=logger, phase=phase, epoch_num=epoch_num, y=epoch_labels, y_hat=epoch_predictions)
            except Exception:
                pass
            logger.flush()

            if phase == 'val':
                scheduler.step(epoch=epoch_num)
                torch.save(
                    {
                        'epoch': epoch_num,
                        'sample_paths': epoch_sample_paths,
                        'epoch_labels': epoch_labels,
                        'epoch_predictions': epoch_predictions,
                    },
                    f'{oof_dir}/{epoch_num:03}.pt'
                )
            else:
                # print(f'{checkpoints_dir}/{epoch_num:03}.pt')
                torch.save(
                    {
                        'epoch': epoch_num,
                        'model_state_dict': model.module.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                    },
                    f'{checkpoints_dir}/{epoch_num:03}.pt'
                )
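The training loop above interleaves gradient accumulation with gradient clipping; reduced to its skeleton (the names and the accumulation_steps value here are illustrative, not taken from the example):

accumulation_steps = 4  # illustrative value

for iter_num, (x, y) in enumerate(loader):
    loss = criterion(model(x), y)
    # scale the loss so the accumulated gradient matches a large-batch gradient
    (loss / accumulation_steps).backward()
    if (iter_num + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 16.0)
        optimizer.step()
        optimizer.zero_grad()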
Example #17
def trainAttention(model=NaiveAttention, dimbed=50, dimout=2, maxiter=epochs, epsilon=0.01, reg=None, verbose=True):
    # The "reg" parameter selects whether or not to regularize with an entropy criterion.
    writer = SummaryWriter("runs")
    
    if model == NaiveAttention:
        print("\n//////////////////// Attention-based LinNet : naive baseline /////////////////\n")
        name = "attentionbase.pch"
        etiq = "/Base_SGD"
    elif model == SimpleAttention:
        print("\n///////////////// Attention-based LinNet : basic implementation //////////////\n")
        name = "attentionclassic.pch"
        etiq = "/Classic_SGD"
    elif model == FurtherAttention:
        print("\n///////////////// Attention-based LinNet : further improvements //////////////\n")
        name = "attentionfurther.pch"
        etiq = "/Further_SGD_regul"
    elif model == LSTMAttention:
        print("\n//////////////////// Attention-based LinNet : adding an LSTM /////////////////\n")
        name = "attentionlstm.pch"
        etiq = "/LSTM_SGD"
    elif model == BILSTMAttention:
        print("\n//////////////////// Attention-based LinNet : adding an BiLSTM /////////////////\n")
        name = "attentionbilstm.pch"
        etiq = "/BiLSTM_SGD"
    # Creating a checkpointed model
    savepath = Path(name)
    if savepath.is_file():
        print("Restarting from previous state.")
        with savepath.open("rb") as fp :
            state = torch.load(fp)
    else:
        lin = model(dimbed,dimout).to(device)
        optim = torch.optim.SGD(params=lin.parameters(),lr=epsilon)
        state = State(lin,optim)
    
    loss = nn.CrossEntropyLoss()
    
    # Training the model
    for epoch in tqdm(range(state.epoch,maxiter)):
        state.model = state.model.train()
        losstrain = 0
        accytrain = 0
        divtrain = 0
        for x, y in train_loader:
            state.optim.zero_grad()
            y = y.to(device)
            if model == NaiveAttention:
                preds = state.model(x)
                penalty = 0
            else:
                preds, attns = state.model(x)
                entropytrain = Categorical(probs=attns.squeeze(2).t()).entropy()
                penalty = reg * torch.sum(entropytrain) if reg else 0
            ltrain = loss(preds, y.long()) + penalty
            
            ltrain.backward()
            state.optim.step()
            state.iteration += 1
            acctr = sum((preds.argmax(1) == y)).item() / y.shape[0]
            losstrain += ltrain.item()  # .item() so the graph is not kept alive
            accytrain += acctr
            divtrain += 1
            
        #if model != NaiveAttention:
        #    entropytrain = Categorical(probs = attns.squeeze(2).t()).entropy()
            
        state.model = state.model.eval()
        losstest = 0
        accytest = 0
        divtest = 0
        for x, y in test_loader:
            with torch.no_grad():
                y = y.to(device)
                if model == NaiveAttention :
                    preds = state.model(x)
                else :
                    preds, attns = state.model(x)                
                ltest = loss(preds, y.long())
                accts = sum((preds.argmax(1) == y)).item() / y.shape[0]
            losstest += ltest.item()
            accytest += accts
            divtest += 1
            
        # Saving the loss
        writer.add_scalars('Attention/Loss'+etiq,{'train':losstrain/divtrain,'test':losstest/divtest},epoch)
        writer.add_scalars('Attention/Accuracy'+etiq,{'train':accytrain/divtrain,'test':accytest/divtest},epoch)
        
        if model != NaiveAttention :
            entropytest = Categorical(probs = attns.squeeze(2).t()).entropy()
            writer.add_histogram('Attention/EntropyTest'+etiq,entropytest,epoch)
            writer.add_histogram('Attention/EntropyTrain'+etiq,entropytrain,epoch)
        
        if verbose:
            print('\nLOSS: \t\ttrain', losstrain/divtrain, '\t\ttest', losstest/divtest)
            print('ACCURACY: \ttrain', accytrain/divtrain, '\t\ttest', accytest/divtest)
        
        # Saving the current state after each epoch
        with savepath.open("wb") as fp:
            state.epoch = epoch+1
            torch.save(state, fp)
            
    print("\n\n\033[1mDone.\033[0m\n")
    writer.flush()
    writer.close()

# ////////////////////////////////////////////////////////////////////////////////////////////////// </training loop> ////
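State is the checkpoint container used above; a minimal sketch consistent with the attributes accessed (model, optim, epoch, iteration):

class State:
    def __init__(self, model, optim):
        self.model = model
        self.optim = optim
        self.epoch = 0      # next epoch to run when resuming
        self.iteration = 0  # global step counter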
Example #18
def main(args=None):
	parser = argparse.ArgumentParser(description='Simple training script for training a RetinaNet network.')

	parser.add_argument('--dataset', help='Dataset type, must be one of csv or coco.')
	parser.add_argument('--coco_path', help='Path to COCO directory')
	parser.add_argument('--csv_train', help='Path to file containing training annotations (see readme)')
	parser.add_argument('--csv_classes', help='Path to file containing class list (see readme)')
	parser.add_argument('--csv_val', help='Path to file containing validation annotations (optional, see readme)')

	parser.add_argument('--depth', help='Resnet depth, must be one of 18, 34, 50, 101, 152', type=int, default=50)
	parser.add_argument('--config', help='Config file path that contains scale and ratio values', type=str)
	parser.add_argument('--epochs', help='Number of epochs', type=int, default=50)
	parser.add_argument('--init-lr', help='Initial learning rate for training process', type=float, default=1e-3)
	parser.add_argument('--batch-size', help='Number of input images per step', type=int, default=1)
	parser.add_argument('--num-workers', help='Number of workers used in the dataloader', type=int, default=1)

	# For resuming training from saved checkpoint
	parser.add_argument('--resume', help='Whether to resume training from checkpoint', action='store_true')
	parser.add_argument('--saved-ckpt', help='Resume training from this checkpoint', type=str)

	parser.add_argument('--multi-gpus', help='Use multiple GPUs for training', action='store_true')
	parser.add_argument('--snapshots', help='Location to save training snapshots', type=str, default="snapshots")

	parser.add_argument('--log-dir', help='Location to save training logs', type=str, default="logs")
	parser.add_argument('--expr-augs', help='Allow use of experimental augmentation methods', action='store_true')
	parser.add_argument('--aug-methods', help='(Experimental) Augmentation methods to use, separated by commas', type=str, default="rotate,hflip,brightness,contrast")
	parser.add_argument('--aug-prob', help='Probability of applying (experimental) augmentation, in range [0.,1.]', type=float, default=0.5)

	parser = parser.parse_args(args)

	train_transforms = [Normalizer(), Resizer(), Augmenter()]

	# Define transform methods
	if parser.expr_augs:
		aug_map = get_aug_map(p=parser.aug_prob)
		aug_methods = parser.aug_methods.split(",")
		for aug in aug_methods:
			if aug in aug_map.keys():
				train_transforms.append(aug_map[aug])
			else:
				print(f"{aug} is not available.")

	# Create the data loaders
	if parser.dataset == 'coco':

		if parser.coco_path is None:
			raise ValueError('Must provide --coco_path when training on COCO.')

		dataset_train = CocoDataset(parser.coco_path, set_name='train2017',
									transform=transforms.Compose(train_transforms))
		dataset_val = CocoDataset(parser.coco_path, set_name='val2017',
								  transform=transforms.Compose([Normalizer(), Resizer()]))

	elif parser.dataset == 'csv':

		if parser.csv_train is None:
			raise ValueError('Must provide --csv_train when training on CSV.')

		if parser.csv_classes is None:
			raise ValueError('Must provide --csv_classes when training on CSV.')

		dataset_train = CSVDataset(train_file=parser.csv_train, class_list=parser.csv_classes,
								   transform=transforms.Compose(train_transforms))

		if parser.csv_val is None:
			dataset_val = None
			print('No validation annotations provided.')
		else:
			dataset_val = CSVDataset(train_file=parser.csv_val, class_list=parser.csv_classes,
									 transform=transforms.Compose([Normalizer(), Resizer()]))

	else:
		raise ValueError('Dataset type not understood (must be csv or coco), exiting.')

	sampler = AspectRatioBasedSampler(dataset_train, batch_size=parser.batch_size, drop_last=False)
	dataloader_train = DataLoader(dataset_train, num_workers=parser.num_workers, collate_fn=collater, batch_sampler=sampler)

	if dataset_val is not None:
		sampler_val = AspectRatioBasedSampler(dataset_val, batch_size=parser.batch_size, drop_last=False)
		dataloader_val = DataLoader(dataset_val, num_workers=parser.num_workers, collate_fn=collater, batch_sampler=sampler_val)

	config = dict({"scales": None,
					"ratios": None})
	
	if parser.config:
		config = load_config(parser.config, config)

	if parser.depth == 18:
		retinanet = model.resnet18(num_classes=dataset_train.num_classes(), pretrained=True, ratios=config["ratios"], scales=config["scales"])
	elif parser.depth == 34:
		retinanet = model.resnet34(num_classes=dataset_train.num_classes(), pretrained=True, ratios=config["ratios"], scales=config["scales"])
	elif parser.depth == 50:
		retinanet = model.resnet50(num_classes=dataset_train.num_classes(), pretrained=True, ratios=config["ratios"], scales=config["scales"])
	elif parser.depth == 101:
		retinanet = model.resnet101(num_classes=dataset_train.num_classes(), pretrained=True, ratios=config["ratios"], scales=config["scales"])
	elif parser.depth == 152:
		retinanet = model.resnet152(num_classes=dataset_train.num_classes(), pretrained=True, ratios=config["ratios"], scales=config["scales"])
	else:
		raise ValueError('Unsupported model depth, must be one of 18, 34, 50, 101, 152')

	optimizer = optim.Adam(retinanet.parameters(), lr=parser.init_lr)

	if parser.resume:
		if not parser.saved_ckpt:
			print("No saved checkpoint provided for resuming training. Exiting now...")
			return 
		if not os.path.exists(parser.saved_ckpt):
			print("Invalid saved checkpoint path. Exiting now...")
			return

		# Restore last state
		retinanet, optimizer, start_epoch = load_ckpt(parser.saved_ckpt, retinanet, optimizer)
		if parser.epochs <= start_epoch:
			print("Number of epochs must be higher than number of trained epochs of saved checkpoint.")
			return

	if torch.cuda.is_available():
		print("Using GPU for training process")
		if parser.multi_gpus:
			print("Using multiple GPUs for training process")
			retinanet = torch.nn.DataParallel(retinanet.cuda(), device_ids=[0, 1])
		else:
			retinanet = torch.nn.DataParallel(retinanet.cuda())
	else:
		retinanet = torch.nn.DataParallel(retinanet)

	retinanet.training = True

	scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=1, verbose=True)

	loss_hist = collections.deque(maxlen=500)

	retinanet.train()
	retinanet.module.freeze_bn()

	print('Num training images: {}'.format(len(dataset_train)))

	# Tensorboard writer
	writer = SummaryWriter(parser.log_dir)

	# Save snapshots dir
	if not os.path.exists(parser.snapshots):
		os.makedirs(parser.snapshots)

	best_mAP = 0
	start_epoch = 0 if not parser.resume else start_epoch 

	for epoch_num in range(start_epoch, parser.epochs):

		retinanet.train()
		retinanet.module.freeze_bn()

		epoch_loss = []
		epoch_csf_loss = []
		epoch_reg_loss = []

		for iter_num, data in enumerate(dataloader_train):
			try:
				optimizer.zero_grad()

				if torch.cuda.is_available():
					with torch.cuda.device(0):
						classification_loss, regression_loss = retinanet([data['img'].cuda().float(), data['annot']])
				else:
					classification_loss, regression_loss = retinanet([data['img'].float(), data['annot']])
					
				classification_loss = classification_loss.mean()
				regression_loss = regression_loss.mean()

				loss = classification_loss + regression_loss
				epoch_csf_loss.append(float(classification_loss))
				epoch_reg_loss.append(float(regression_loss))

				if bool(loss == 0):
					continue

				loss.backward()

				torch.nn.utils.clip_grad_norm_(retinanet.parameters(), 0.1)

				optimizer.step()

				loss_hist.append(float(loss))

				epoch_loss.append(float(loss))

				print(
					'\rEpoch: {}/{} | Iteration: {}/{} | Classification loss: {:1.5f} | Regression loss: {:1.5f} | Running loss: {:1.5f}'.format(
						(epoch_num + 1), parser.epochs, (iter_num + 1), len(dataloader_train), float(classification_loss), float(regression_loss), np.mean(loss_hist)), end='')

				del classification_loss
				del regression_loss
			except Exception as e:
				print(e)
				continue

		# writer.add_scalar("Loss/train", loss, epoch_num)

		_epoch_loss = np.mean(epoch_loss)
		_epoch_csf_loss = np.mean(epoch_csf_loss)
		_epoch_reg_loss = np.mean(epoch_reg_loss)

		if parser.dataset == 'coco':

			print('Evaluating dataset')

			coco_eval.evaluate_coco(dataset_val, retinanet)

			scheduler.step(_epoch_loss)

		elif parser.dataset == 'csv' and parser.csv_val is not None:

			print('\nEvaluating dataset')

			APs = csv_eval.evaluate(dataset_val, retinanet)
			mAP = round(mean(APs[ap][0] for ap in APs.keys()), 5)
			print("mAP: %f" %mAP)
			writer.add_scalar("validate/mAP", mAP, epoch_num)
			
			# Step the lr_scheduler with the mAP value
			scheduler.step(mAP)


		lr = get_lr(optimizer)
		writer.add_scalar("train/classification-loss", _epoch_csf_loss, epoch_num)
		writer.add_scalar("train/regression-loss", _epoch_reg_loss, epoch_num)
		writer.add_scalar("train/loss", _epoch_loss, epoch_num)
		writer.add_scalar("train/learning-rate", lr, epoch_num)

		# Save model file, optimizer and epoch number

		checkpoint = {
		    'epoch': epoch_num,
		    'state_dict': retinanet.state_dict(),
		    'optimizer': optimizer.state_dict(),
		}

		# torch.save(retinanet.module, os.path.join(parser.snapshots, '{}_retinanet_{}.pt'.format(parser.dataset, epoch_num)))
		
		# Check whether this epoch's model achieves the highest mAP value
		# (mAP is only computed for the csv dataset with a validation set)
		is_best = False
		if parser.dataset == 'csv' and parser.csv_val is not None and best_mAP < mAP:
			best_mAP = mAP
			is_best = True

		save_ckpt(checkpoint, is_best, parser.snapshots, '{}_retinanet_{}.pt'.format(parser.dataset, epoch_num + 1))

		print('\n')

	retinanet.eval()

	torch.save(retinanet, 'model_final.pt')

	writer.flush()
	writer.close()
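The script calls get_lr, save_ckpt, and load_ckpt without defining them; below are hedged sketches consistent with how they are used here (the real project may implement them differently):

import os
import shutil
import torch

def get_lr(optimizer):
    # Learning rate of the first param group (assumes a single-group optimizer).
    return optimizer.param_groups[0]['lr']

def save_ckpt(state, is_best, ckpt_dir, filename):
    # Save the checkpoint dict and keep a copy of the best one so far.
    path = os.path.join(ckpt_dir, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(ckpt_dir, 'model_best.pt'))

def load_ckpt(path, model, optimizer):
    # Restore model/optimizer state; return the epoch to resume from.
    ckpt = torch.load(path)
    model.load_state_dict(ckpt['state_dict'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return model, optimizer, ckpt['epoch'] + 1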
Example #19
class Visualizer():
    """
        Defines the functions needed to write statistics and the output images to a tensorboard file.
    
    """
    def __init__(self, log_dir):
        """
            log_dir = folder where the tensorboard log should be stored
        """

        self.summary_writer = SummaryWriter(log_dir=log_dir)

    # Write the learning rate to the tensorboard file
    def write_lr(self, optim, globaliter):
        for i, param_group in enumerate(optim.param_groups):
            self.summary_writer.add_scalar('learning_rate/lr_'+ str(i) , param_group['lr'], globaliter)
        self.summary_writer.flush()
    
    # Write statistics in text form (training time or maximum GPU memory usage) to the tensorboard file.
    def write_text(self, value, statType):
        if statType == 'Time':
            self.summary_writer.add_text('Time','Total Training time: ' + value + ' seconds')
        elif statType == 'Memory':
            self.summary_writer.add_text('GPU', 'Maximum GPU usage: ' + value + ' MiB')
        self.summary_writer.flush()
    
    # Write the training loss to the tensorboard file    
    def write_loss_train(self, value, globaliter):
        self.summary_writer.add_scalar('Loss/train', value, globaliter)
        self.summary_writer.flush()

    # Write the validation loss to the tensorboard file    
    def write_loss_validation(self, value, globaliter, if_testtimes=False):
        if if_testtimes:
            postfix = '_testtimes'
        else:
            postfix = ''

        self.summary_writer.add_scalar('Loss/validation'+postfix, value, globaliter)
        self.summary_writer.flush()
    
    # Write the output images to the tensorboard file    
    def write_image(self, images, epoch, if_predict=False, if_testtimes=False, includeHeading=False):
        """
            images: Prediction or ground truth in image form
            epoch: Current epoch of the training
            if_predict: Is it a prediction or a ground truth?
            if_testtimes: Is the validation happening on the testtimes only?
            includeHeading: Is the heading channel included in the training?
        """
        
        
        if if_testtimes:
            postfix = '_testtimes'
        else:
            postfix = ''
            
        # restructure the data, save volume, speed and heading separately
        if len(images.shape) == 4:
            _, _, row, col = images.shape
            vol_batch = torch.zeros((3, 1 , row, col))
            speed_batch = torch.zeros((3, 1 , row, col))
            if includeHeading:
                head_batch = torch.zeros((3, 1 , row, col))

            # volume
            vol_batch[0] = images[0,0,:,:]
            vol_batch[1] = images[0,3,:,:]
            vol_batch[2] = images[0,6,:,:]
            # speed
            speed_batch[0] = images[0,1,:,:]
            speed_batch[1] = images[0,4,:,:]
            speed_batch[2] = images[0,7,:,:]

            # heading
            if includeHeading:
                head_batch[0] = images[0,2,:,:]
                head_batch[1] = images[0,5,:,:]
                head_batch[2] = images[0,8,:,:]

        else:
            
            _, _, _, row, col = images.shape
            vol_batch = torch.zeros((3, 1 , row, col))
            speed_batch = torch.zeros((3, 1 , row, col))
            if includeHeading:
                head_batch = torch.zeros((3, 1 , row, col))

            # volume
            vol_batch[0] = images[0,0,0,:,:]
            vol_batch[1] = images[0,1,0,:,:]
            vol_batch[2] = images[0,2,0,:,:]

            # speed
            speed_batch[0] = images[0,0,1,:,:]
            speed_batch[1] = images[0,1,1,:,:]
            speed_batch[2] = images[0,2,1,:,:]
            # heading
            if includeHeading:
                head_batch[0] = images[0,0,2,:,:]
                head_batch[1] = images[0,1,2,:,:]
                head_batch[2] = images[0,2,2,:,:]
            
        # add images to the tensorboard file
        if if_predict:
            
            vol_batch = torchvision.utils.make_grid(vol_batch, normalize=True)
            self.summary_writer.add_image('prediction'+postfix+'/volume', vol_batch, epoch)

            speed_batch = torchvision.utils.make_grid(speed_batch, normalize=True)
            self.summary_writer.add_image('prediction'+postfix+'/speed', speed_batch, epoch)

            if includeHeading:
                head_batch = torchvision.utils.make_grid(head_batch, normalize=True)
                self.summary_writer.add_image('prediction'+postfix+'/heading', head_batch, epoch)

        else:
            vol_batch = torchvision.utils.make_grid(vol_batch, normalize=True)
            self.summary_writer.add_image('ground_truth'+postfix+'/volume', vol_batch, epoch)

            speed_batch = torchvision.utils.make_grid(speed_batch, normalize=True)
            self.summary_writer.add_image('ground_truth'+postfix+'/speed', speed_batch, epoch)

            if includeHeading:
                head_batch = torchvision.utils.make_grid(head_batch, normalize=True)
                self.summary_writer.add_image('ground_truth'+postfix+'/heading', head_batch, epoch)
        
        # Apply changes to tensorboard file
        self.summary_writer.flush()

    def close(self):
        self.summary_writer.close()
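A hedged usage sketch for the class above; loader, train_step, and optimizer are illustrative stand-ins, not part of the class:

viz = Visualizer(log_dir='runs/experiment_1')
for it, (x, y) in enumerate(loader):        # hypothetical dataloader
    loss = train_step(x, y)                 # hypothetical training step
    viz.write_loss_train(loss, it)
    viz.write_lr(optimizer, it)
viz.write_image(y, epoch=0)                 # log ground-truth grids
viz.write_text('123', 'Time')               # total training time in seconds
viz.close()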
Example #20
def train_dynamics(config,
                   train_dir, # str: directory to save output
                   multi_episode_dict, # multi_episode_dict
                   ):

    use_precomputed_keypoints = config['dataset']['visual_observation']['enabled'] and config['dataset']['visual_observation']['descriptor_keypoints']

    # set random seed for reproduction
    set_seed(config['train']['random_seed'])

    st_epoch = config['train']['resume_epoch'] if config['train']['resume_epoch'] > 0 else 0
    tee = Tee(os.path.join(train_dir, 'train_st_epoch_%d.log' % st_epoch), 'w')

    tensorboard_dir = os.path.join(train_dir, "tensorboard")
    if not os.path.exists(tensorboard_dir):
        os.makedirs(tensorboard_dir)

    writer = SummaryWriter(log_dir=tensorboard_dir)

    # save the config
    save_yaml(config, os.path.join(train_dir, "config.yaml"))


    action_function = ActionFunctionFactory.function_from_config(config)
    observation_function = ObservationFunctionFactory.function_from_config(config)

    datasets = {}
    dataloaders = {}
    data_n_batches = {}
    for phase in ['train', 'valid']:
        print("Loading data for %s" % phase)
        datasets[phase] = MultiEpisodeDataset(config,
                                              action_function=action_function,
                                              observation_function=observation_function,
                                              episodes=multi_episode_dict,
                                              phase=phase)

        dataloaders[phase] = DataLoader(
            datasets[phase], batch_size=config['train']['batch_size'],
            shuffle=True if phase == 'train' else False,
            num_workers=config['train']['num_workers'], drop_last=True)

        data_n_batches[phase] = len(dataloaders[phase])

    use_gpu = torch.cuda.is_available()

    # compute normalization parameters if not starting from a pre-trained network

    # define the model for dynamics prediction

    model_dy = build_visual_dynamics_model(config)
    K = config['vision_net']['num_ref_descriptors']

    print("model_dy.vision_net._reference_descriptors.shape", model_dy.vision_net._ref_descriptors.shape)
    print("model_dy.vision_net.descriptor_dim", model_dy.vision_net.descriptor_dim)
    print("model_dy #params: %d" % count_trainable_parameters(model_dy))

    camera_name = config['vision_net']['camera_name']
    W = config['env']['rgbd_sensors']['sensor_list'][camera_name]['width']
    H = config['env']['rgbd_sensors']['sensor_list'][camera_name]['height']
    diag = np.sqrt(W**2 + H**2) # use this to scale the loss

    # sample reference descriptors unless using precomputed keypoints
    if not use_precomputed_keypoints:
        # sample reference descriptors
        episode_names = list(datasets["train"].episode_dict.keys())
        episode_names.sort()
        episode_name = episode_names[0]
        episode = datasets["train"].episode_dict[episode_name]
        episode_idx = 0
        camera_name = config["vision_net"]["camera_name"]
        image_data = episode.get_image_data(camera_name, episode_idx)
        des_img = torch.Tensor(image_data['descriptor'])
        mask_img = torch.Tensor(image_data['mask'])
        ref_descriptor_dict = sample_descriptors(des_img,
                                                 mask_img,
                                                 config['vision_net']['num_ref_descriptors'])

        model_dy.vision_net._ref_descriptors.data = ref_descriptor_dict['descriptors']
        model_dy.vision_net.reference_image = image_data['rgb']
        model_dy.vision_net.reference_indices = ref_descriptor_dict['indices']
    else:
        metadata_file = os.path.join(get_data_root(), config['dataset']['descriptor_keypoints_dir'], 'metadata.p')
        descriptor_metadata = load_pickle(metadata_file)

        # [32, 2]
        ref_descriptors = torch.Tensor(descriptor_metadata['ref_descriptors'])

        # [K, 2]
        ref_descriptors = ref_descriptors[:K]
        model_dy.vision_net._ref_descriptors.data = ref_descriptors
        model_dy.vision_net._ref_descriptors_metadata = descriptor_metadata

        # this is just a sanity check
        assert model_dy.vision_net.num_ref_descriptors == K

    print("reference_descriptors", model_dy.vision_net._ref_descriptors)

    # criterion
    criterionMSE = nn.MSELoss()
    l1Loss = nn.L1Loss()

    # optimizer
    params = model_dy.parameters()
    lr = float(config['train']['lr'])
    optimizer = optim.Adam(params, lr=lr, betas=(config['train']['adam_beta1'], 0.999))

    # setup scheduler
    sc = config['train']['lr_scheduler']
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  factor=sc['factor'],
                                  patience=sc['patience'],
                                  threshold_mode=sc['threshold_mode'],
                                  cooldown=sc['cooldown'],
                                  verbose=True)

    if use_gpu:
        print("using gpu")
        model_dy = model_dy.cuda()

    print("model_dy.vision_net._ref_descriptors.device", model_dy.vision_net._ref_descriptors.device)
    print("model_dy.vision_net #params: %d" %(count_trainable_parameters(model_dy.vision_net)))


    best_valid_loss = np.inf
    global_iteration = 0
    epoch_counter_external = 0

    try:
        for epoch in range(st_epoch, config['train']['n_epoch']):
            phases = ['train', 'valid']
            epoch_counter_external = epoch

            writer.add_scalar("Training Params/epoch", epoch, global_iteration)
            for phase in phases:
                model_dy.train(phase == 'train')

                meter_loss_rmse = AverageMeter()
                step_duration_meter = AverageMeter()


                # bar = ProgressBar(max_value=data_n_batches[phase])
                loader = dataloaders[phase]

                for i, data in enumerate(loader):

                    step_start_time = time.time()

                    global_iteration += 1

                    with torch.set_grad_enabled(phase == 'train'):
                        n_his, n_roll = config['train']['n_history'], config['train']['n_rollout']
                        n_samples = n_his + n_roll

                        if DEBUG:
                            print("global iteration: %d" %(global_iteration))


                        # visual_observations = data['visual_observations']
                        visual_observations_list = data['visual_observations_list']
                        observations = data['observations']
                        actions = data['actions']

                        if use_gpu:
                            observations = observations.cuda()
                            actions = actions.cuda()

                        # states, actions = data
                        assert actions.size(1) == n_samples

                        B = actions.size(0)
                        loss_mse = 0.


                        # compute the output of the visual model for all timesteps
                        visual_model_output_list = []
                        for visual_obs in visual_observations_list:
                            # visual_obs is a dict containing observation for a single
                            # time step (of course across a batch however)
                            # visual_obs[<camera_name>]['rgb_tensor'] has shape [B, 3, H, W]

                            # probably need to cast input to cuda
                            dynamics_net_input = None
                            if use_precomputed_keypoints:
                                # note precomputed descriptors stored on disk are of size
                                # K = 32. We need to trim it down to the appropriate size
                                # [B, K_disk, 2] where K_disk is num keypoints on disk
                                keypoints = visual_obs[camera_name]['descriptor_keypoints']


                                # [B, 32, 2] where K is num keypoints
                                keypoints = keypoints[:,:K]

                                if DEBUG:
                                    print("keypoints.shape", keypoints.shape)

                                dynamics_net_input = keypoints.flatten(start_dim=1)
                            else:
                                out_dict = model_dy.vision_net.forward(visual_obs)

                                # [B, vision_model_out_dim]
                                dynamics_net_input = out_dict['dynamics_net_input']

                            visual_model_output_list.append(dynamics_net_input)

                        # concatenate this into a tensor
                        # [B, n_samples, vision_model_out_dim]
                        visual_model_output = torch.stack(visual_model_output_list, dim=1)

                        # cast this to float so it can be concatenated below
                        visual_model_output = visual_model_output.type_as(observations)

                        if DEBUG:
                            print('visual_model_output.shape', visual_model_output.shape)
                            print("observations.shape", observations.shape)
                            print("actions.shape", actions.shape)

                        # states is gotten by concatenating visual_observations and observations
                        # [B, n_samples, vision_model_out_dim + obs_dim]
                        states = torch.cat((visual_model_output, observations), dim=-1)

                        # state_cur: B x n_his x state_dim
                        state_cur = states[:, :n_his]

                        if DEBUG:
                            print("states.shape", states.shape)

                        for j in range(n_roll):

                            if DEBUG:
                                print("n_roll j: %d" %(j))

                            state_des = states[:, n_his + j]

                            # action_cur: B x n_his x action_dim
                            action_cur = actions[:, j : j + n_his] if actions is not None else None

                            # state_pred: B x state_dim
                            # (renamed from `input` to avoid shadowing the Python builtin)
                            dyn_input = {'observation': state_cur,
                                         'action': action_cur,
                                         }

                            if DEBUG:
                                print("state_cur.shape", state_cur.shape)
                                print("action_cur.shape", action_cur.shape)

                            state_pred = model_dy.dynamics_net(dyn_input)

                            # normalize by diag to ensure the loss is in [0,1] range
                            loss_mse_cur = criterionMSE(state_pred/diag, state_des/diag)
                            loss_mse += loss_mse_cur / n_roll

                            # l1Loss
                            loss_l1 = l1Loss(state_pred, state_des)

                            # update state_cur
                            # state_pred.unsqueeze(1): B x 1 x state_dim
                            # state_cur: B x n_his x state_dim
                            state_cur = torch.cat([state_cur[:, 1:], state_pred.unsqueeze(1)], 1)

                        # update the running RMSE once per batch, after the full rollout
                        meter_loss_rmse.update(np.sqrt(loss_mse.item()), B)

                    step_duration_meter.update(time.time() - step_start_time)
                    if phase == 'train':
                        optimizer.zero_grad()
                        loss_mse.backward()
                        optimizer.step()

                    if (i % config['train']['log_per_iter'] == 0) or (global_iteration % config['train']['log_per_iter'] == 0):
                        log = '%s [%d/%d][%d/%d] LR: %.6f' % (
                            phase, epoch, config['train']['n_epoch'], i, data_n_batches[phase],
                            get_lr(optimizer))
                        log += ', rmse: %.6f (%.6f)' % (
                            np.sqrt(loss_mse.item()), meter_loss_rmse.avg)

                        log += ', step time %.6f' %(step_duration_meter.avg)
                        step_duration_meter.reset()


                        print(log)

                        # log data to tensorboard
                        # only do it once we have reached 100 iterations
                        if global_iteration > 100:
                            writer.add_scalar("Params/learning rate", get_lr(optimizer), global_iteration)
                            writer.add_scalar("Loss_MSE/%s" %(phase), loss_mse.item(), global_iteration)
                            writer.add_scalar("L1/%s" %(phase), loss_l1.item(), global_iteration)
                            writer.add_scalar("L1_fraction/%s" %(phase), loss_l1.item()/diag, global_iteration)
                            writer.add_scalar("RMSE average loss/%s" %(phase), meter_loss_rmse.avg, global_iteration)

                    if phase == 'train' and i % config['train']['ckp_per_iter'] == 0:
                        save_model(model_dy, '%s/net_dy_epoch_%d_iter_%d' % (train_dir, epoch, i))

                log = '%s [%d/%d] Loss: %.6f, Best valid: %.6f' % (
                    phase, epoch, config['train']['n_epoch'], meter_loss_rmse.avg, best_valid_loss)
                print(log)

                if phase == 'valid':
                    if config['train']['lr_scheduler']['enabled']:
                        scheduler.step(meter_loss_rmse.avg)

                    # print("\nPhase == valid")
                    # print("meter_loss_rmse.avg", meter_loss_rmse.avg)
                    # print("best_valid_loss", best_valid_loss)
                    if meter_loss_rmse.avg < best_valid_loss:
                        best_valid_loss = meter_loss_rmse.avg
                        save_model(model_dy, '%s/net_best_dy' % (train_dir))

                writer.flush() # flush SummaryWriter events to disk

    except KeyboardInterrupt:
        # save network if we have a keyboard interrupt
        save_model(model_dy, '%s/net_dy_epoch_%d_keyboard_interrupt' % (train_dir, epoch_counter_external))
        writer.flush() # flush SummaryWriter events to disk
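AverageMeter is used throughout the loop above but not defined in the snippet; a minimal sketch matching the update/reset/avg interface it relies on:

class AverageMeter:
    """Running average of a scalar; a minimal sketch of the assumed helper."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count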
Example #21
class Trainer:
    """
    Trainer is a simple but feature-complete training and eval loop for PyTorch,
    optimized for Transformers.
    """

    model: PreTrainedModel
    args: TrainingArguments
    data_collator: DataCollator
    train_dataset: Optional[Dataset]
    eval_dataset: Optional[Dataset]
    compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None
    prediction_loss_only: bool
    tb_writer: Optional["SummaryWriter"] = None
    optimizers: Optional[Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]] = None
    global_step: Optional[int] = None
    epoch: Optional[float] = None

    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        prediction_loss_only=False,
        tb_writer: Optional["SummaryWriter"] = None,
        optimizers: Optional[Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]] = None,
    ):
        """
        Trainer is a simple but feature-complete training and eval loop for PyTorch,
        optimized for Transformers.

        Args:
            prediction_loss_only:
                (Optional) in evaluation and prediction, only return the loss
        """
        self.model = model.to(args.device)
        self.args = args
        self.data_collator = data_collator if data_collator is not None else default_data_collator
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.compute_metrics = compute_metrics
        self.prediction_loss_only = prediction_loss_only
        self.optimizers = optimizers
        if tb_writer is not None:
            self.tb_writer = tb_writer
        elif is_tensorboard_available() and self.is_world_master():
            self.tb_writer = SummaryWriter(log_dir=self.args.logging_dir)
        if not is_tensorboard_available():
            logger.warning(
                "You are instantiating a Trainer but Tensorboard is not installed. You should consider installing it."
            )
        if is_wandb_available():
            self._setup_wandb()
        else:
            logger.info(
                "You are instantiating a Trainer but W&B is not installed. To use wandb logging, "
                "run `pip install wandb; wandb login` see https://docs.wandb.com/huggingface."
            )
        set_seed(self.args.seed)
        # Create output directory if needed
        if self.is_world_master():
            os.makedirs(self.args.output_dir, exist_ok=True)
        if is_torch_tpu_available():
            # Set an xla_device flag on the model's config.
            # We'll find a more elegant way and not need to do this in the future.
            self.model.config.xla_device = True
        if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
            self.data_collator = self.data_collator.collate_batch
            warnings.warn(
                (
                    "The `data_collator` should now be a simple callable (function, class with `__call__`), classes "
                    + "with a `collate_batch` are deprecated and won't be supported in a future version."
                ),
                FutureWarning,
            )

    def get_train_dataloader(self) -> DataLoader:
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        if is_torch_tpu_available():
            train_sampler = get_tpu_sampler(self.train_dataset)
        else:
            train_sampler = (
                RandomSampler(self.train_dataset)
                if self.args.local_rank == -1
                else DistributedSampler(self.train_dataset)
            )

        data_loader = DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")

        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        if is_torch_tpu_available():
            sampler = SequentialDistributedSampler(
                eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
            )
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(eval_dataset)
        else:
            sampler = SequentialSampler(eval_dataset)

        data_loader = DataLoader(
            eval_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        # We use the same batch_size as for eval.
        if is_torch_tpu_available():
            sampler = SequentialDistributedSampler(
                test_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
            )
        elif self.args.local_rank != -1:
            sampler = SequentialDistributedSampler(test_dataset)
        else:
            sampler = SequentialSampler(test_dataset)

        data_loader = DataLoader(
            test_dataset,
            sampler=sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

        return data_loader

    def get_optimizers(
        self, num_training_steps: int
    ) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
        """
        Setup the optimizer and the learning rate scheduler.

        We provide a reasonable default that works well.
        If you want to use something else, you can pass a tuple in the Trainer's init,
        or override this method in a subclass.
        """
        if self.optimizers is not None:
            return self.optimizers
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if "relational_transformer" not in n and not any(nd in n for nd in no_decay)],
                "weight_decay": self.args.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if "relational_transformer" in n and not any(nd in n for nd in no_decay)],
                "weight_decay": self.args.weight_decay,
                "lr": 7e-5
            }
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=num_training_steps
        )
        return optimizer, scheduler
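    # A hedged note on the grouping above (a local modification, not upstream
    # Trainer behavior): weight decay is skipped for bias/LayerNorm parameters,
    # and parameters of the (assumed) `relational_transformer` submodule get a
    # hard-coded lr of 7e-5. Note that no-decay parameters of that submodule
    # fall into the second group and therefore keep the base learning rate.
    # Illustrative inspection of the resulting groups:
    #
    #   optimizer, scheduler = trainer.get_optimizers(num_training_steps=1000)
    #   for group in optimizer.param_groups:
    #       print(len(group["params"]), group["lr"], group["weight_decay"])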

    def _setup_wandb(self):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can override this method to customize the setup if needed.  Find more information at https://docs.wandb.com/huggingface
        You can also override the following environment variables:

        Environment:
            WANDB_WATCH:
                (Optional, ["gradients", "all", "false"]) "gradients" by default, set to "false" to disable gradient logging
                or "all" to log gradients and parameters
            WANDB_PROJECT:
                (Optional): str - "huggingface" by default, set this to a custom string to store results in a different project
            WANDB_DISABLED:
                (Optional): boolean - defaults to false, set to "true" to disable wandb entirely
        """
        if self.is_world_master():
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            wandb.init(project=os.getenv("WANDB_PROJECT", "huggingface"), config=vars(self.args))
            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                wandb.watch(
                    self.model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, self.args.logging_steps)
                )
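    # A hedged usage sketch (an assumption, not part of the upstream API): the
    # W&B integration above is driven entirely by environment variables, e.g.
    #
    #   os.environ["WANDB_PROJECT"] = "my-project"  # hypothetical project name
    #   os.environ["WANDB_WATCH"] = "false"         # skip gradient logging
    #   trainer = Trainer(model=model, args=args)   # _setup_wandb runs in __init__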

    def num_examples(self, dataloader: DataLoader) -> int:
        """
        Helper to get num of examples from a DataLoader, by accessing its Dataset.
        """
        return len(dataloader.dataset)

    def train(self, model_path: Optional[str] = None):
        """
        Main training entry point.

        Args:
            model_path:
                (Optional) Local path to model if model to train has been instantiated from a local path
                If present, we will try reloading the optimizer/scheduler states from there.
        """
        train_dataloader = self.get_train_dataloader()
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = (
                self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
            )
        else:
            t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs

        optimizer, scheduler = self.get_optimizers(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(model, optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split("/")[0])
                epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
                steps_trained_in_current_epoch = self.global_step % (
                    len(train_dataloader) // self.args.gradient_accumulation_steps
                )

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss = 0.0
        logging_loss = 0.0
        model.zero_grad()
        train_iterator = trange(
            epochs_trained, int(num_train_epochs), desc="Epoch", disable=not self.is_local_master() or not self.args.logging_tqdm
        )
        for epoch in train_iterator:
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = tqdm(parallel_loader, desc="Iteration", disable=not self.is_local_master() or not self.args.logging_tqdm)
            else:
                epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=not self.is_local_master() or not self.args.logging_tqdm)

            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

                tr_loss += self._training_step(model, inputs, optimizer)

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(optimizer)
                    else:
                        optimizer.step()

                    scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        logs: Dict[str, float] = {}
                        logs["loss"] = (tr_loss - logging_loss) / self.args.logging_steps
                        # backward compatibility for pytorch schedulers
                        logs["learning_rate"] = (
                            scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else scheduler.get_lr()[0]
                        )
                        logging_loss = tr_loss

                        self._log(logs)

                    if (self.args.eval_steps > 0 and self.global_step % self.args.eval_steps == 0):
                        if self.args.evaluate_during_training:
                            self.evaluate()

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert model.module is self.model
                        else:
                            assert model is self.model
                        # Save model checkpoint
                        output_dir = os.path.join(self.args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}")

                        self.save_model(output_dir)

                        if self.is_world_master():
                            self._rotate_checkpoints()

                        if is_torch_tpu_available():
                            xm.rendezvous("saving_optimizer_states")
                            xm.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            xm.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                        elif self.is_world_master():
                            torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                            torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

                if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                    epoch_iterator.close()
                    break
            if self.args.max_steps > 0 and self.global_step > self.args.max_steps:
                train_iterator.close()
                break
            if self.args.tpu_metrics_debug:
                # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                xm.master_print(met.metrics_report())

        if self.tb_writer:
            self.tb_writer.close()

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return TrainOutput(self.global_step, tr_loss / self.global_step)

    def _log(self, logs: Dict[str, float], iterator: Optional[tqdm] = None) -> None:
        if self.epoch is not None:
            logs["epoch"] = self.epoch
        if self.global_step is None:
            # when logging evaluation metrics without training
            self.global_step = 0
        if self.tb_writer:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, self.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()
        if is_wandb_available():
            if self.is_world_master():
                wandb.log(logs, step=self.global_step)
        output = {**logs, **{"step": self.global_step}}
        if iterator is not None:
            iterator.write(output)
        else:
            logger.info(output)

    def _training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], optimizer: torch.optim.Optimizer
    ) -> float:
        model.train()
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.args.device)

        outputs = model(**inputs)
        loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        return loss.item()

    def is_local_master(self) -> bool:
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=True)
        else:
            return self.args.local_rank in [-1, 0]

    def is_world_master(self) -> bool:
        """
        This will be True only in one process, even in distributed mode,
        even when training on multiple machines.
        """
        if is_torch_tpu_available():
            return xm.is_master_ordinal(local=False)
        else:
            return self.args.local_rank == -1 or torch.distributed.get_rank() == 0

    def save_model(self, output_dir: Optional[str] = None):
        """
        Saving best-practices: if you use default names for the model,
        you can reload it using from_pretrained().

        Will only save from the world_master process (unless in TPUs).
        """

        if is_torch_tpu_available():
            self._save_tpu(output_dir)
        elif self.is_world_master():
            self._save(output_dir)

    def _save_tpu(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        logger.info("Saving model checkpoint to %s", output_dir)

        if xm.is_master_ordinal():
            os.makedirs(output_dir, exist_ok=True)
            torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if not isinstance(self.model, PreTrainedModel):
            raise ValueError("Trainer.model appears to not be a PreTrainedModel")

        xm.rendezvous("saving_checkpoint")
        self.model.save_pretrained(output_dir)

    def _save(self, output_dir: Optional[str] = None):
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info("Saving model checkpoint to %s", output_dir)
        # Save a trained model and configuration using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        # if not isinstance(self.model, PreTrainedModel):
        #     raise ValueError("Trainer.model appears to not be a PreTrainedModel")
        self.model.save_pretrained(output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(self.args, os.path.join(output_dir, "training_args.bin"))

    def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(self.args.output_dir).glob(f"{checkpoint_prefix}-*")]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
                if regex_match and regex_match.groups():
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        return checkpoints_sorted

    def _rotate_checkpoints(self, use_mtime=False) -> None:
        if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
            return

        # Check if we should delete older checkpoint(s)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime)
        if len(checkpoints_sorted) <= self.args.save_total_limit:
            return

        number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - self.args.save_total_limit)
        checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
        for checkpoint in checkpoints_to_be_deleted:
            logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
            shutil.rmtree(checkpoint)

    def evaluate(
        self, eval_dataset: Optional[Dataset] = None, prediction_loss_only: Optional[bool] = None,
    ) -> Dict[str, float]:
        """
        Run evaluation and return metrics.

        The calling script will be responsible for providing a method to compute metrics, as they are
        task-dependent.

        Args:
            eval_dataset: (Optional) Pass a dataset if you wish to override
            the one on the instance.
        Returns:
            A dict containing:
                - the eval loss
                - the potential metrics computed from the predictions
        """
        eval_dataloader = self.get_eval_dataloader(eval_dataset)

        output = self._prediction_loop(eval_dataloader, description="Evaluation")

        self._log(output.metrics)

        if self.args.tpu_metrics_debug:
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        return output.metrics

    def predict(self, test_dataset: Dataset) -> PredictionOutput:
        """
        Run prediction and return predictions and potential metrics.

        Depending on the dataset and your use case, your test dataset may contain labels.
        In that case, this method will also return metrics, like in evaluate().
        """
        test_dataloader = self.get_test_dataloader(test_dataset)

        return self._prediction_loop(test_dataloader, description="Prediction")

    def _prediction_loop(
        self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by `evaluate()` and `predict()`.

        Works both with or without labels.

        NOTE: one issue is the size of the predictions and labels. The current code assumes
        that predictions and labels from different batches share the same sequence length,
        which is not true for our application; to make this more general, the predictions
        and labels are flattened and their per-example sizes are tracked.

        """

        prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else self.prediction_loss_only

        model = self.model
        # multi-gpu eval
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)
        # Note: in torch.distributed mode, there's no point in wrapping the model
        # inside a DistributedDataParallel as we'll be under `no_grad` anyways.

        batch_size = dataloader.batch_size
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", self.num_examples(dataloader))
        logger.info("  Batch size = %d", batch_size)
        eval_losses: List[float] = []
        preds: Optional[torch.Tensor] = None
        preds_size: Optional[torch.Tensor] = None
        label_ids: Optional[torch.Tensor] = None
        label_size: Optional[torch.Tensor] = None
        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)

        for inputs in tqdm(dataloader, desc=description):
            has_labels = any(inputs.get(k) is not None for k in ["labels", "lm_labels", "masked_lm_labels"])

            for k, v in inputs.items():
                if isinstance(v, torch.Tensor):
                    inputs[k] = v.to(self.args.device)

            with torch.no_grad():
                outputs = model(**inputs)
                if has_labels:
                    step_eval_loss, logits = outputs[:2]
                    eval_losses += [step_eval_loss.mean().item()]
                else:
                    logits = outputs[0]

            if not prediction_loss_only:
                # Change the way of concat
                # We need to make sure that the size of preds and labels is (batch_size, sequence_length)
                if preds is None:
                    preds = logits.detach()
                    preds_size = preds.new_full(size=preds.size()[:1], fill_value=preds.size(1)).detach()
                    preds = preds.view(-1)
                else:
                    preds_size = torch.cat((preds_size, logits.new_full(size=logits.size()[:1], fill_value=logits.size(1)).detach()), dim=0)
                    preds = torch.cat((preds, logits.detach().view(-1)), dim=0)

                if inputs.get("labels") is not None:
                    if label_ids is None:
                        label_ids = inputs["labels"].detach()
                        label_size = label_ids.new_full(size=label_ids.size()[:1], fill_value=label_ids.size(1)).detach()
                        label_ids = label_ids.view(-1)
                    else:
                        label_size = torch.cat((label_size, inputs["labels"].new_full(size=inputs["labels"].size()[:1], fill_value=inputs["labels"].size(1)).detach()), dim=0)
                        label_ids = torch.cat((label_ids, inputs["labels"].detach().view(-1)), dim=0)

        if self.args.local_rank != -1:
            # In distributed mode, concatenate all results from all nodes:
            if preds is not None:
                # preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
                preds, preds_size = self.distributed_concat_with_size(preds, preds_size, num_total_examples=self.num_examples(dataloader))
            if label_ids is not None:
                # label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
                label_ids, label_size = self.distributed_concat_with_size(label_ids, label_size, num_total_examples=self.num_examples(dataloader))
        elif is_torch_tpu_available():
            # tpu-comment: Get all predictions and labels from all worker shards of eval dataset
            # NOTE: We do not modify this for now.
            if preds is not None:
                preds = xm.mesh_reduce("eval_preds", preds, torch.cat)
            if label_ids is not None:
                label_ids = xm.mesh_reduce("eval_label_ids", label_ids, torch.cat)

        # Finally, turn the aggregated tensors into numpy arrays.
        if preds is not None:
            preds = preds.cpu().numpy()
            preds_size = preds_size.cpu().numpy()
        if label_ids is not None:
            label_ids = label_ids.cpu().numpy()
            label_size = label_size.cpu().numpy()
        if self.compute_metrics is not None and preds is not None and label_ids is not None:
            # metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
            metrics = self.compute_metrics(EvalPredictionWithSize(predictions=preds, predictions_size=preds_size, label_ids=label_ids, label_size=label_size))
        else:
            metrics = {}
        if len(eval_losses) > 0:
            metrics["eval_loss"] = np.mean(eval_losses)

        # Prefix all keys with eval_
        for key in list(metrics.keys()):
            if not key.startswith("eval_"):
                metrics[f"eval_{key}"] = metrics.pop(key)

        # return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
        return PredictionOutputWithSize(predictions=preds, predictions_size=preds_size, label_ids=label_ids, label_size=label_size, metrics=metrics)

    def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
        assert self.args.local_rank != -1

        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)

        # truncate the dummy elements added by SequentialDistributedSampler
        output = concat[:num_total_examples]
        return output

    def distributed_concat_tensor(self, tensor: torch.Tensor):
        assert self.args.local_rank != -1

        output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(output_tensors, tensor)

        concat = torch.cat(output_tensors, dim=0)
        return concat

    def distributed_concat_varsize_tensor(self, tensor: torch.Tensor):
        assert self.args.local_rank != -1

        sizes = self.distributed_concat_tensor(tensor.new_full(size=(1,), fill_value=tensor.size(0)))
        max_size = sizes.max().item()

        padded = tensor.new_zeros(max_size)
        padded[:tensor.size(0)] = tensor

        padded_agg = self.distributed_concat_tensor(padded)
        slices = []
        for i, size in enumerate(sizes):
            start_idx = i * max_size
            end_idx = start_idx + size.item()
            slices.append(padded_agg[start_idx: end_idx])
        ret = torch.cat(slices, dim=0)
        return ret


    def distributed_concat_with_size(self, tensor: torch.Tensor, size: torch.Tensor, num_total_examples: int) -> Tuple[torch.Tensor, torch.Tensor]:
        assert self.args.local_rank != -1

        # output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
        # output_sizes = [size.clone() for _ in range(torch.distributed.get_world_size())]
        # torch.distributed.all_gather(output_tensors, tensor)
        # torch.distributed.all_gather(output_sizes, size)
        # concat = torch.cat(output_tensors, dim=0)
        # concat_sizes = torch.cat(output_sizes, dim=0)
        concat_sizes = self.distributed_concat_varsize_tensor(size)
        concat = self.distributed_concat_varsize_tensor(tensor)

        # output_sizes = concat_sizes[:num_total_examples]

        assert concat_sizes.sum() == concat.size(0)
        return concat, concat_sizes
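
A minimal single-process sketch of the padding scheme behind distributed_concat_varsize_tensor and distributed_concat_with_size above: every rank pads its variable-length tensor to the global maximum length, the padded tensors are gathered, and the per-rank sizes are used to slice the real values back out. fake_all_gather stands in for torch.distributed.all_gather so the snippet runs without a process group; the helper names are illustrative, not from the source.

import torch

def fake_all_gather(per_rank):
    # stand-in for torch.distributed.all_gather: every rank sees all tensors
    return list(per_rank)

def concat_varsize(per_rank):
    # 1) gather the per-rank sizes so the global maximum is known everywhere
    sizes = torch.cat(fake_all_gather([t.new_full((1,), t.size(0)) for t in per_rank]))
    max_size = int(sizes.max())
    # 2) pad each tensor to max_size so all_gather sees equally shaped tensors
    padded = []
    for t in per_rank:
        p = t.new_zeros(max_size)
        p[:t.size(0)] = t
        padded.append(p)
    agg = torch.cat(fake_all_gather(padded), dim=0)
    # 3) slice the real (unpadded) values back out, in rank order
    slices = [agg[i * max_size: i * max_size + int(s)] for i, s in enumerate(sizes)]
    return torch.cat(slices), sizes

preds, preds_size = concat_varsize([torch.tensor([1., 2., 3.]), torch.tensor([4.])])
assert preds.tolist() == [1., 2., 3., 4.] and int(preds_size.sum()) == preds.size(0)
# per-example sequences can be recovered later from the sizes, e.g.
# np.split(preds.numpy(), np.cumsum(preds_size.numpy().astype(int))[:-1])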
Example #22
class XvalMerge(object):
    def __init__(self, args, settings):
        self.epoch = args.epochs
        self.elbo = []
        self.elbo_list = []
        self.q_names = []
        self.q_values = []
        self.splits = []
        self.theta = []
        # self.normalized_iws = []
        # self.precisions = []
        # self.X_predict = []
        # self.X_states = []
        self.iw_predict_mu = []
        self.iw_predict_std = []
        self.iw_states = []
        # from data_pair.test
        self.data_ids = []
        self.devices = []
        self.treatments = []
        self.X_obs = []
        # Attributes initialized elsewhere
        self.chunk_sizes = None
        self.ids = None
        self.species_names = None
        self.times = None
        self.xval_writer = None
        self.settings = settings.data
        self.trainer = settings.trainer

    def add(self, split_idx, data_pair, val_results):
        if split_idx == 1:
            self.q_names = val_results.q_names
            self.species_names = val_results.species_names
            self.times = data_pair.train.dataset.times
        self.elbo.append(val_results.elbo)
        self.elbo_list.append(val_results.elbo_list)
        self.q_values.append(val_results.q_values)
        self.splits.append(split_idx)
        self.theta.append(val_results.theta)
        # self.normalized_iws.append(val_results.normalized_iws)
        # self.precisions.append(val_results.precisions)
        # self.X_predict.append(val_results.x_predict)
        # self.X_states.append(val_results.x_states)
        self.iw_predict_mu.append(val_results.iw_predict_mu)
        self.iw_predict_std.append(val_results.iw_predict_std)
        self.iw_states.append(val_results.iw_states)

        self.data_ids.append(data_pair.test.indices)
        dataset = data_pair.test.dataset[data_pair.test.indices]
        self.devices.append(dataset["devices"])
        self.treatments.append(dataset["inputs"].cpu().detach().numpy())
        self.X_obs.append(dataset["observations"].cpu().detach().numpy())

    def finalize(self):
        print("Preparing cross-validation results")
        self.elbo = np.array(self.elbo)
        self.elbo_list = np.array(self.elbo_list)
        self.q_values = [
            np.concatenate([np.array(q[i], ndmin=1) for q in self.q_values])
            for i, _ in enumerate(self.q_names)
        ]
        # self.normalized_iws = np.concatenate(self.normalized_iws, 0)
        # self.precisions = np.concatenate(self.precisions, 0)
        # self.X_predict = np.concatenate(self.X_predict, 0)
        # self.X_states = np.concatenate(self.X_states, 0)
        self.iw_predict_mu = np.concatenate(self.iw_predict_mu, 0)
        self.iw_predict_std = np.concatenate(self.iw_predict_std, 0)
        self.iw_states = np.concatenate(self.iw_states, 0)

        self.devices = np.concatenate(self.devices, 0)
        self.treatments = np.concatenate(self.treatments, 0)
        self.X_obs = np.concatenate(self.X_obs, 0)

        self.chunk_sizes = np.array([len(ids) for ids in self.data_ids],
                                    dtype=object)
        self.ids = np.hstack(self.data_ids)

    def prepare(self):
        '''Importance-weighted means and stds over time'''
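        # NOTE: this method relies on self.normalized_iws, self.X_predict,
        # self.precisions and self.X_states, which are commented out in
        # __init__/add/finalize above; finalize() now stores the precomputed
        # iw_predict_mu/iw_predict_std/iw_states arrays directly.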
        importance_weights = self.normalized_iws[:, :, np.newaxis, np.newaxis]
        self.iw_predict_mu = np.sum(importance_weights * self.X_predict, 1)
        self.iw_predict_std = np.sqrt(
            np.sum(
                importance_weights *
                (self.X_predict**2 + 1.0 / self.precisions), 1) -
            self.iw_predict_mu**2)
        self.iw_states = np.sum(importance_weights * self.X_states, 1)

    def save(self):
        location = self.trainer.tb_log_dir
        print("Saving results to %s" % location)

        def save(base, data):
            np.save(os.path.join(location, base + ".npy"), data)

        def savetxt(base, data):
            np.savetxt(
                os.path.join(location, base + ".txt"),
                np.array(data, dtype=str),
                delimiter=" ",
                fmt="%s",
            )

        print("Saving to: %s" % location)
        save("xval_elbo", self.elbo)
        save("xval_elbo_list", self.elbo_list)
        savetxt("xval_q_names", self.q_names)
        save("xval_q_values", self.q_values)
        save("xval_theta", self.theta)

        save("xval_iw_predict_mu", self.iw_predict_mu)
        save("xval_iw_predict_std", self.iw_predict_std)
        save("xval_iw_states", self.iw_states)
        # save("xval_normalized_iws", self.normalized_iws)
        # save("xval_precisions", self.precisions)
        # save("xval_X_predict", self.X_predict)
        # save("xval_X_states", self.X_states)

        savetxt("xval_device_names", self.settings.devices)
        save("xval_devices", self.devices)
        save("xval_treatments", self.treatments)
        save("xval_X_obs", self.X_obs)

        save("xval_chunk_sizes", self.chunk_sizes)
        save("xval_ids", self.ids)
        savetxt("xval_names", self.species_names)
        save("xval_times", self.times)

    def load(self, location=None):
        if location is None:
            location = self.trainer.tb_log_dir
        print("Loading results from %s" % location)

        def load(base):
            return np.load(os.path.join(location, base + ".npy"),
                           allow_pickle=True)

        def loadtxt(base):
            return np.loadtxt(os.path.join(location, base + ".txt"),
                              dtype=str,
                              delimiter=" ")

        self.elbo = load("xval_elbo")
        self.elbo_list = load("xval_elbo_list")
        self.q_names = loadtxt("xval_q_names")
        self.q_values = load("xval_q_values")
        self.theta = load("xval_theta")
        # self.normalized_iws = load("xval_normalized_iws")
        # self.precisions = load("xval_precisions")
        # self.X_states = load("xval_X_states")
        # self.X_predict = load("xval_X_predict")
        self.iw_predict_mu = load("xval_iw_predict_mu")
        self.iw_predict_std = load("xval_iw_predict_std")
        self.iw_states = load("xval_iw_states")

        # self.device_names = loadtxt("xval_device_names.txt")
        self.devices = load("xval_devices")
        self.treatments = load("xval_treatments")
        self.X_obs = load("xval_X_obs")

        self.chunk_sizes = load("xval_chunk_sizes")
        self.ids = load("xval_ids")
        self.species_names = loadtxt("xval_names")
        self.times = load("xval_times")

    def make_writer(self, location=None):
        if location is None:
            location = self.trainer.tb_log_dir
        self.xval_writer = SummaryWriter(os.path.join(location, "xval"))

    def close_writer(self):
        self.xval_writer.close()

    def save_figs(self, f, tag):
        # pp.close(f)
        f.savefig(os.path.join(self.trainer.tb_log_dir, "%s.png" % tag),
                  bbox_inches="tight")
        f.savefig(os.path.join(self.trainer.tb_log_dir, "%s.pdf" % tag),
                  bbox_inches="tight")

    def mark_completed(self, node_name):
        location = self.trainer.tb_log_dir
        filepath = os.path.join(location, "completed.txt")
        with open(filepath, "w") as file:
            file.write(node_name)

    def make_images(self):
        device_ids = list(range(len(self.settings.devices)))

        print("Making summary figure")
        f_summary = plotting.plot_prediction_summary(
            self.settings.devices,
            self.species_names,
            self.times,
            self.X_obs,
            self.iw_predict_mu,
            self.iw_predict_std,
            self.devices,
            "-",
        )
        self.save_figs(f_summary, "xval_fit")
        self.xval_writer.add_figure("Summary", f_summary, self.epoch)
        self.xval_writer.flush()

        if self.settings.separate_conditions:
            print("Making treatment figure")
            f_treatments = plotting.xval_treatments(self, device_ids)
            self.save_figs(f_treatments, "xval_treatments")
            self.xval_writer.add_figure("Treatment", f_treatments, self.epoch)
            self.xval_writer.flush()

        print("Making species figure")
        f_species = plotting.species_summary(
            self.species_names,
            self.treatments,
            self.devices,
            self.times,
            self.iw_states,
            device_ids,
            self.settings,
        )
        self.save_figs(f_species, "xval_species")
        self.xval_writer.add_figure("Species", f_species, self.epoch)
        self.xval_writer.flush()

        print("Making global parameters figure")
        f_gparas = plotting.xval_global_parameters(self)
        if f_gparas is not None:
            self.save_figs(f_gparas, "xval_global_parameters")
            self.xval_writer.add_figure("Parameters/Globals", f_gparas,
                                        self.epoch)
            self.xval_writer.flush()

        print("Making variable parameters figure")
        f_vparas = plotting.xval_variable_parameters(self)
        if f_vparas is not None:
            self.save_figs(f_vparas, "xval_variable_parameters")
            self.xval_writer.add_figure("Parameters/Variable", f_vparas,
                                        self.epoch)
            self.xval_writer.flush()

        print("Making summary device figures")
        for u in device_ids:
            print("- %s" % self.settings.pretty_devices[u])
            device = self.settings.devices[u]
            f_summary_i = plotting.xval_fit_summary(
                self, u, separatedInputs=self.settings.separate_conditions)
            self.save_figs(f_summary_i, "xval_summary_%s" % device)
            self.xval_writer.add_figure("Device_Summary/" + device,
                                        f_summary_i, self.epoch)
        self.xval_writer.flush()

        print("Making individual device figures")
        for u in device_ids:
            print("- %s" % self.settings.pretty_devices[u])
            device = self.settings.devices[u]
            if self.settings.separate_conditions:
                # TODO: handle the single-treatment case; the 2-treatments function fails when there is just one treatment
                f_indiv_i = plotting.xval_individual_2treatments(self, u)
            else:
                f_indiv_i = plotting.xval_individual(self, u)
            self.save_figs(f_indiv_i, "xval_individual_%s" % device)
            self.xval_writer.add_figure("Device_Individual/" + device,
                                        f_indiv_i, self.epoch)
        self.xval_writer.flush()
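
The class above pushes matplotlib figures to TensorBoard via add_figure. A minimal, self-contained sketch of that pattern (the log directory, tag, and data are placeholders):

import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("runs/xval_demo")  # placeholder log dir
fig, ax = plt.subplots()
ax.plot([0, 1, 2], [1.0, 0.5, 0.25], label="elbo")
ax.legend()
writer.add_figure("Summary", fig, global_step=0)  # appears under the Images tab
writer.flush()
writer.close()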
Example #23
def main():
    # Settings:
    batch_size = 16
    max_steps = 0
    max_epochs = 10
    # Parse argument:
    args = parseArgument()
    if 'train_log_dir' not in args:
        args['train_log_dir'] = './'
    if 'parameter_dir' not in args:
        args['parameter_dir'] = './'
    if 'max_steps' in args:
        max_steps = args['max_steps']
    if 'max_epochs' in args:
        max_epochs = args['max_epochs']
    os.makedirs(args['train_log_dir'], exist_ok=True)
    os.makedirs(args['parameter_dir'], exist_ok=True)
    # validation: if 'test_images' and 'test_anns' are set, run validation every epoch.
    validation = ('test_images' in args and 'test_anns' in args)
    # Define model: use torchvision's resnet18 and pretrained weights
    model = torchvision.models.resnet18(pretrained=True)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model = model.to(device)
    # Define optimizer:
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # Define loss function:
    criterion = nn.CrossEntropyLoss().to(device)
    # Define DataLoader: training data loader and validation data loader
    train_loader = DataLoader(CustomDataset(args['images'], args['anns']),
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=worker)
    valid_loader = DataLoader(CustomDataset(args['test_images'],
                                            args['test_anns']),
                              batch_size=1,
                              shuffle=False,
                              num_workers=worker) if validation else None
    # Setup Tensorboard:
    writer = SummaryWriter(
        log_dir=args['train_log_dir']
    ) if TENSORBOARD_AVAILABLE and 'train_log_dir' in args else None
    # Main-loop:
    cur_step = 0
    for e in range(max_epochs):
        # Training:
        model.train()
        train_loss = 0
        for i, feed in enumerate(train_loader):
            images, labels = feed
            images = images.to(device)
            labels = labels.to(device)
            output = model(images)
            loss = criterion(output, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            # increment step
            cur_step = cur_step + 1
            # Check step-over: if current steps are over max steps, exit training loop.
            if max_steps > 0 and cur_step >= max_steps:
                break
        # criterion already averages over the batch, so divide by the number of batches only
        train_loss /= len(train_loader)
        # Write to Tensorboard
        if writer is not None:
            writer.add_scalar('loss_train', train_loss, e)
        # Validation:
        if valid_loader is not None:
            valid_accuracy = 0
            model.eval()
            with torch.no_grad():
                valid_loss = 0
                count = 0
                match = 0
                for i, feed in enumerate(valid_loader):
                    images, labels = feed
                    images = images.to(device)
                    labels = labels.to(device)
                    output = model(images)
                    loss = criterion(output, labels)
                    valid_loss += loss.item()
                    # calculate accuracy:
                    t = labels.cpu().numpy()
                    p = torch.argmax(output).cpu().numpy()
                    count = count + 1
                    if t == p:
                        match = match + 1
                valid_loss /= len(valid_loader)
                valid_accuracy = match / count if count > 0 else 0.0
            # Write to Tensorboard:
            if writer is not None:
                writer.add_scalar('loss_valid', valid_loss, e)
                writer.add_scalar('accuracy', valid_accuracy, e)
        # Print log:
        if valid_loader is not None:
            print('Epoch {:4}/{:4}, train-loss {:e}, valid-loss {:e}'.format(
                e, max_epochs, train_loss, valid_loss))
        else:
            print('Epoch {:4}/{:4}, train-loss {:e}'.format(
                e, max_epochs, train_loss))
        # Check step-over: if current steps are over max steps, exit training loop.
        if max_steps > 0 and cur_step >= max_steps:
            break
    # Save model parameter:
    torch.save(model.state_dict(),
               os.path.join(args['parameter_dir'], 'trained_parameter.pth'))
    # Close Tensorboard:
    if writer is not None:
        writer.flush()
        writer.close()
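
The validation loop above works only because batch_size=1, so torch.argmax over the whole output happens to yield the predicted class. A hedged sketch of the batched equivalent (the tensor shapes are assumptions):

import torch

output = torch.randn(8, 10)              # (batch, num_classes) logits
labels = torch.randint(0, 10, (8,))      # ground-truth class indices
preds = output.argmax(dim=1)             # per-sample predicted class
match = (preds == labels).sum().item()   # correct predictions in this batch
accuracy = match / labels.size(0)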
Example #24
class TensorBoardImageHandler(object):
    """
    TensorBoardImageHandler is an Ignite Event handler that can visualise images, labels and outputs as 2D/3D images.
    2D output (shape ``(Batch, channel, H, W)``) is shown as a simple image using the first element in the batch;
    for 3D-to-ND output (shape ``(Batch, channel, H, W, D)``), up to ``self.max_channels`` channels are shown,
    with the last three dimensions rendered as an animated GIF along the last axis (typically depth).

    It can be used for any Ignite Engine (trainer, validator and evaluator).
    User can easily add it to engine for any expected Event, for example: ``EPOCH_COMPLETED``,
    ``ITERATION_COMPLETED``. The expected data source is ignite's ``engine.state.batch`` and ``engine.state.output``.

    Default behavior:
        - Shows y_pred as images (GIF for 3D) on TensorBoard when the event is triggered.
        - Use ``batch_transform`` and ``output_transform`` to specify
          how many images to show and which channel to show.
        - Expects ``batch_transform(engine.state.batch)`` to return data in the
          format (image[N, channel, ...], label[N, channel, ...]).
        - Expects ``output_transform(engine.state.output)`` to return a torch
          tensor in the format (y_pred[N, channel, ...], loss).

    """
    def __init__(
        self,
        summary_writer: Optional[SummaryWriter] = None,
        log_dir: str = "./runs",
        interval: int = 1,
        epoch_level: bool = True,
        batch_transform: Callable = lambda x: x,
        output_transform: Callable = lambda x: x,
        global_iter_transform: Callable = lambda x: x,
        index: int = 0,
        max_channels: int = 1,
        max_frames: int = 64,
    ) -> None:
        """
        Args:
            summary_writer: user can specify TensorBoard SummaryWriter,
                default to create a new writer.
            log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`.
            interval: plot content from engine.state every N epochs or every N iterations, default is 1.
            epoch_level: plot content from engine.state every N epochs or N iterations. `True` is epoch level,
                `False` is iteration level.
            batch_transform: a callable that is used to transform the
                ``ignite.engine.batch`` into expected format to extract several label data.
            output_transform: a callable that is used to transform the
                ``ignite.engine.output`` into expected format to extract several output data.
            global_iter_transform: a callable that is used to customize global step number for TensorBoard.
                For example, in evaluation, the evaluator engine needs to know current epoch from trainer.
            index: plot which element in a data batch, default is the first element.
            max_channels: number of channels to plot.
            max_frames: number of frames for 2D-t plot.
        """
        self._writer = SummaryWriter(
            log_dir=log_dir) if summary_writer is None else summary_writer
        self.interval = interval
        self.epoch_level = epoch_level
        self.batch_transform = batch_transform
        self.output_transform = output_transform
        self.global_iter_transform = global_iter_transform
        self.index = index
        self.max_frames = max_frames
        self.max_channels = max_channels

    def attach(self, engine: Engine) -> None:
        """
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        if self.epoch_level:
            engine.add_event_handler(
                Events.EPOCH_COMPLETED(every=self.interval), self)
        else:
            engine.add_event_handler(
                Events.ITERATION_COMPLETED(every=self.interval), self)

    def __call__(self, engine: Engine) -> None:
        """
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.

        Raises:
            TypeError: When ``batch_transform(engine.state.batch)[0]`` type is not in
                ``Optional[Union[numpy.ndarray, torch.Tensor]]``.
            TypeError: When ``batch_transform(engine.state.batch)[1]`` type is not in
                ``Optional[Union[numpy.ndarray, torch.Tensor]]``.
            TypeError: When ``output_transform(engine.state.output)`` type is not in
                ``Optional[Union[numpy.ndarray, torch.Tensor]]``.

        """
        step = self.global_iter_transform(
            engine.state.epoch if self.epoch_level else engine.state.iteration)
        show_images = self.batch_transform(engine.state.batch)[0]
        if torch.is_tensor(show_images):
            show_images = show_images.detach().cpu().numpy()
        if show_images is not None:
            if not isinstance(show_images, np.ndarray):
                raise TypeError(
                    "batch_transform(engine.state.batch)[0] must be None or one of "
                    f"(numpy.ndarray, torch.Tensor) but is {type(show_images).__name__}."
                )
            plot_2d_or_3d_image(show_images, step, self._writer, self.index,
                                self.max_channels, self.max_frames, "input_0")

        show_labels = self.batch_transform(engine.state.batch)[1]
        if torch.is_tensor(show_labels):
            show_labels = show_labels.detach().cpu().numpy()
        if show_labels is not None:
            if not isinstance(show_labels, np.ndarray):
                raise TypeError(
                    "batch_transform(engine.state.batch)[1] must be None or one of "
                    f"(numpy.ndarray, torch.Tensor) but is {type(show_labels).__name__}."
                )
            plot_2d_or_3d_image(show_labels, step, self._writer, self.index,
                                self.max_channels, self.max_frames, "input_1")

        show_outputs = self.output_transform(engine.state.output)
        if torch.is_tensor(show_outputs):
            show_outputs = show_outputs.detach().cpu().numpy()
        if show_outputs is not None:
            if not isinstance(show_outputs, np.ndarray):
                raise TypeError(
                    "output_transform(engine.state.output) must be None or one of "
                    f"(numpy.ndarray, torch.Tensor) but is {type(show_outputs).__name__}."
                )
            plot_2d_or_3d_image(show_outputs, step, self._writer, self.index,
                                self.max_channels, self.max_frames, "output")

        self._writer.flush()
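
A minimal sketch of wiring the handler above into an Ignite engine. The dummy step function and data are placeholders; ignite must be importable, and plot_2d_or_3d_image (e.g. from monai.visualize) is assumed to be available to the handler:

import torch
from ignite.engine import Engine

def _step(engine, batch):
    images, labels = batch
    y_pred = images * 0.5                 # dummy "model" output
    return y_pred, torch.tensor(0.0)      # (y_pred, loss)

trainer = Engine(_step)
handler = TensorBoardImageHandler(log_dir="./runs/demo",
                                  output_transform=lambda x: x[0])
handler.attach(trainer)                   # fires on every EPOCH_COMPLETED

data = [(torch.rand(2, 1, 16, 16), torch.rand(2, 1, 16, 16))]
trainer.run(data, max_epochs=1)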
Example #25
torch.set_grad_enabled(True)

EPOCHS = 5000
init_epoch = 0
print("Entrenamiento")
model.train()
for epoch in range(init_epoch, init_epoch + EPOCHS):
    train(epoch)
    #torch.save(model, "{}/{}-epoch-{}{}".format(folder, model_name, epoch, ext))
    #print("Model Saved Successfully")

torch.save(model, full_path)

print("Model Saved Successfully")
writer.flush()
writer.close()

from pandas import DataFrame


def test(model):
    #model_load.eval()
    y_real = []
    y_predict = []
    y_distance = []
    n_correct = 0
    examples = 0
    df = list()
    datalen = len(training_set)
    i = 0
Example #26
class DDQNTrainer(TrainerBase):
    def __init__(self, env, agent, args):
        super(DDQNTrainer, self).__init__()
        self.env = env
        self.agent = agent
        self.num_episodes = args.num_episodes
        self.buffer = Buffer(args.buffer_size, args.batch_size)
        self.max_steps = args.max_steps

        assert len(
            args.solved) == 2, 'args.solved has to have length of exactly 2!'
        self.solved_r = args.solved[0]
        self.solved_ep = args.solved[1]
        self.render = args.render

        self.writer = SummaryWriter(args.tensorboard)

    def train(self):
        reward_history = []
        self.buffer.reset()
        # For each update

        if self.render:
            self.env.render()

        # For each episode
        for episode in range(self.num_episodes):
            episode_reward = 0
            episode_losses = []
            state = self.env.reset()

            # Until max_steps is reached
            for i in range(self.max_steps):

                # Get action
                action = self.agent.get_action(episode, state)

                # Take step based on action
                state_, reward, done, _ = self.env.step(action)
                episode_reward += reward

                # Store transition within the buffer for batched learning
                transition = [state, action, reward, state_, done]

                self.buffer.insert_transition(transition)

                # If the buffer is full, start learning using a random sample of transitions from the buffer
                if self.buffer.is_full():
                    batch = self.buffer.sample_buffer()
                    loss = self.agent.learn(batch)
                    episode_losses.append(loss)

                # Update the target network every 100 episodes (note: as written,
                # this fires on every step of each 100th episode; move it outside
                # the step loop to update exactly once per 100 episodes)
                if episode % 100 == 0:
                    self.agent.update_target()

                # If the episode has finished (max_steps reached or from env)
                if done or i == self.max_steps - 1:
                    print('Episode: ', episode, 'Reward: %i' % episode_reward)
                    reward_history.append(episode_reward)
                    self.writer.add_scalar("reward", episode_reward, episode)
                    self.writer.flush()

                    if len(reward_history) > self.solved_ep:
                        reward_history.pop(0)
                        if (sum(reward_history) /
                                len(reward_history)) >= self.solved_r:
                            print('Env has been solved at episode ' +
                                  str(episode) + '!')
                            self.writer.close()
                            exit()
                    break
                else:
                    state = state_

            if len(episode_losses) > 0:
                self.writer.add_scalar(
                    "loss/loss",
                    sum(episode_losses) / len(episode_losses), episode)
                self.writer.flush()

        self.writer.close()
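
As the comment above notes, the target-network update sits inside the step loop, so it fires on every step of each 100th episode. A toy sketch of the episode-level placement (ToyAgent is illustrative only):

class ToyAgent:
    def __init__(self):
        self.online, self.target = 0, 0
    def learn(self):
        self.online += 1                  # stands in for a gradient step
    def update_target(self):
        self.target = self.online         # hard update, as in DDQN

agent = ToyAgent()
for episode in range(300):
    if episode % 100 == 0:                # exactly once per 100 episodes
        agent.update_target()
    for step in range(50):                # step loop no longer touches the target
        agent.learn()
print(agent.target, agent.online)         # the target lags the online network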
Example #27
class Trainer:
    def __init__(
        self,
        model: nn.Module,
        logger: Logger,
        prefix: str = "",
        checkpoint_dir: Union[str, None] = None,
        summary_dir: Union[str, None] = None,
        n_summaries: int = 4,
        input_shape: tuple = None,
        start_scratch: bool = False,
        #model_name: str="model",
    ):
        """
        Class which implements network training, validation and testing as well as writing checkpoints, logs, summaries, and saving the final model.

        :param Union[str, None] checkpoint_dir: the type is either str or None (default: None)
        :param int n_summaries: number of images as samples at different phases to visualize on tensorboard
        """
        #self.model_name=model_name
        self.model = model
        self.logger = logger
        self.prefix = prefix

        self.logger.info("Init summary writer")

        if summary_dir is not None:
            run_name = prefix + "_" if prefix != "" else ""
            run_name += "{time}-{host}".format(
                time=time.strftime("%y-%m-%d-%H-%M", time.localtime()),
                host=os.uname()[1],
            )
            summary_dir = os.path.join(summary_dir, run_name)
        # always set the attribute (fit() hands it to the delve tracker) and
        # log into the run-specific directory rather than the bare root
        self.summary_dir = summary_dir

        self.n_summaries = n_summaries
        self.writer = SummaryWriter(summary_dir)

        if input_shape is not None:
            dummy_input = torch.rand(input_shape)
            self.logger.info("Writing graph to summary")
            self.writer.add_graph(self.model, dummy_input)

        if checkpoint_dir is not None:
            self.cp = CheckpointHandler(checkpoint_dir,
                                        prefix=prefix,
                                        logger=self.logger)
        else:
            self.cp = None

        self.start_scratch = start_scratch

    def fit(
        self,
        train_dataloader,
        val_dataloader,
        train_ds,
        val_ds,
        loss_fn,
        optimizer,
        n_epochs,
        val_interval,
        patience_early_stopping,
        device,
        metrics: Union[list, dict] = [],
        val_metric: Union[int, str] = "loss",
        val_metric_mode: str = "min",
        start_epoch=0,
    ):
        """
        train and validate the networks

        :param int n_epochs: maximum number of training epochs (default: 500)
        :param int val_interval: run validation every ``val_interval`` epochs (ARGS.patience_early_stopping)
        :param int patience_early_stopping: terminate training after this many epochs without improvement (checked every ``val_interval`` epochs)
        """

        self.logger.info("Init model on device '{}'".format(device))
        self.model = self.model.to(device)

        # initialize delve
        self.tracker = CheckLayerSat(self.summary_dir,
                                     save_to="plotcsv",
                                     modules=self.model,
                                     device=device)

        best_model = copy.deepcopy(self.model.state_dict())
        best_metric = 0.0 if val_metric_mode == "max" else float("inf")

        # since we validate only every val_interval epochs rather than every epoch,
        # scale the early-stopping patience to the number of validation rounds
        patience_stopping = math.ceil(patience_early_stopping / val_interval)
        patience_stopping = int(max(1, patience_stopping))
        early_stopping = EarlyStoppingCriterion(mode=val_metric_mode,
                                                patience=patience_stopping)

        if not self.start_scratch and self.cp is not None:
            checkpoint = self.cp.read_latest()
            if checkpoint is not None:
                try:
                    try:
                        self.model.load_state_dict(checkpoint["modelState"])
                    except RuntimeError as e:
                        self.logger.error(
                            "Failed to restore checkpoint: "
                            "Checkpoint has different parameters")
                        self.logger.error(e)
                        raise SystemExit

                    optimizer.load_state_dict(
                        checkpoint["trainState"]["optState"])
                    start_epoch = checkpoint["trainState"]["epoch"] + 1
                    best_metric = checkpoint["trainState"]["best_metric"]
                    best_model = checkpoint["trainState"]["best_model"]
                    early_stopping.load_state_dict(
                        checkpoint["trainState"]["earlyStopping"])
                    #scheduler.load_state_dict(checkpoint["trainState"]["scheduler"])
                    self.logger.info(
                        "Resuming with epoch {}".format(start_epoch))
                except KeyError:
                    self.logger.error("Failed to restore checkpoint")
                    raise

        since = time.time()

        self.logger.info("Start training model " + self.prefix)

        try:
            if val_metric_mode == "min":
                val_comp = operator.lt  # to run standard operator as function
            else:
                val_comp = operator.gt
            for epoch in range(start_epoch, n_epochs):
                self.train(epoch, train_dataloader, train_ds, loss_fn,
                           optimizer, device)

                if epoch % val_interval == 0 or epoch == n_epochs - 1:
                    # first, get val_loss for further comparison
                    val_loss = self.validate(epoch,
                                             val_dataloader,
                                             val_ds,
                                             loss_fn,
                                             device,
                                             phase="val")
                    if val_metric == "loss":
                        val_result = val_loss
                        # add metrics for delve to keep track of
                        self.tracker.add_scalar("loss", val_loss)
                        # add saturation to the mix
                        self.tracker.add_saturations()
                    else:
                        val_result = metrics[val_metric].get()

                    # compare to see if improvement occurs
                    if val_comp(val_result, best_metric):
                        best_metric = val_result  # update best_metric with the loss (smaller than previous)
                        best_model = copy.deepcopy(self.model.state_dict())
                        """previously, deadlock occurred, which seemed to be related to cp. comment self.cp.write() to see if freezing goes away."""
                        # write checkpoint
                        self.cp.write({
                            "modelState": self.model.state_dict(),
                            "trainState": {
                                "epoch": epoch,
                                "best_metric": best_metric,
                                "best_model": best_model,
                                "optState": optimizer.state_dict(),
                                "earlyStopping": early_stopping.state_dict(),
                            },
                        })

                    # test if the number of accumulated no-improvement epochs is bigger than patience
                    if early_stopping.step(val_result):
                        self.logger.info(
                            "No improvement over the last {} epochs. Training is stopped."
                            .format(patience_early_stopping))
                        break
        except Exception:
            import traceback
            self.logger.warning(traceback.format_exc())
            self.logger.warning("Aborting...")
            self.logger.close()
            raise SystemExit

        # option: load the best model to run a test on the test dataset and log the final metric (alongside the best metric)
        # for the AE, the data is split only into train and validation sets, without a test dataset

        time_elapsed = time.time() - since
        self.logger.info("Training complete in {:.0f}m {:.0f}s".format(
            time_elapsed // 60, time_elapsed % 60))

        self.logger.info("Best val metric: {:4f}".format(best_metric))

        # close delve tracker
        self.tracker.close()

        return self.model

    def train(self, epoch, train_dataloader, train_ds, loss_fn, optimizer,
              device):
        """
        Training of one epoch on training data, loss function, optimizer, and respective metrics
        """
        self.logger.debug("train|{}|start".format(epoch))

        self.model.train()

        epoch_start = time.time()
        start_data_loading = epoch_start
        data_loading_time = m.Sum(torch.device("cpu"))

        train_running_loss = 0.0
        for i, (train_specs, label) in enumerate(train_dataloader):
            train_specs = train_specs.to(device)
            call_label = None

            if "call" in label:
                call_label = label["call"].to(
                    device, non_blocking=True, dtype=torch.int64
                )  #  e.g. tensor([True, True, True, True, True, True])

            if "ground_truth" in label:
                ground_truth = label["ground_truth"].to(device,
                                                        non_blocking=True)

            data_loading_time.update(
                torch.Tensor([(time.time() - start_data_loading)]))
            optimizer.zero_grad()

            # compute reconstructions
            outputs = self.model(train_specs)

            # compute training reconstruction loss, when augmentation is used
            # loss = loss_fn(outputs, ground_truth)

            # compute training reconstruction loss, when no augmentation is used
            loss = loss_fn(outputs, train_specs)

            # compute accumulated gradients
            loss.backward()

            # perform parameter update based on current gradients
            optimizer.step()

            # add the mini-batch training loss to the epoch loss:
            # loss.item() is the cost averaged over the current batch, so
            # loss.item() * batch_size recovers the total (un-averaged) batch loss
            train_running_loss += loss.item() * train_specs.size(0)

            prediction = None
            #print("label is ", label, "call_label is ", call_label)

            if i % 2 == 0:
                self.write_summaries(
                    features=train_specs,
                    #labels=call_label,
                    #prediction=prediction,
                    reconstructed=outputs,
                    file_names=label["file_name"],
                    epoch=epoch,
                    phase="train",
                )
            start_data_loading = time.time()

        # compute the epoch training loss
        train_epoch_loss = train_running_loss / len(train_ds)

        self.write_scalar_summaries_logs(
            loss=train_epoch_loss,
            #metrics=metrics,
            lr=optimizer.param_groups[0]["lr"],
            epoch_time=time.time() - epoch_start,
            data_loading_time=data_loading_time.get(),
            epoch=epoch,
            phase="train",
        )

        self.writer.flush()

        return train_epoch_loss

    def validate(self,
                 epoch,
                 val_dataloader,
                 val_ds,
                 loss_fn,
                 device,
                 phase="val"):
        self.logger.debug("{}|{}|start".format(phase, epoch))
        self.model.eval()

        val_running_loss = 0.0
        with torch.no_grad():
            epoch_start = time.time()
            start_data_loading = epoch_start
            data_loading_time = m.Sum(torch.device("cpu"))

            for i, (val_specs, label) in enumerate(val_dataloader):
                val_specs = val_specs.to(device)
                if "call" in label:
                    call_label = label["call"].to(device,
                                                  non_blocking=True,
                                                  dtype=torch.int64)  # bool

                data_loading_time.update(
                    torch.Tensor([(time.time() - start_data_loading)]))

                # instead of converting the spectrogram to a color image, we save the 1-channel outputs produced directly by the network
                if i % 2 == 0:
                    #grid = make_grid(val_specs)
                    self.writer.add_images("Original", val_specs,
                                           epoch)  #val_specs

                outputs = self.model(val_specs)

                if i % 2 == 0:
                    # tb = SummaryWriter()
                    #grid = make_grid(outputs)
                    self.writer.add_images("Reconstructed", outputs,
                                           epoch)  #outputs

                loss = loss_fn(outputs, val_specs)

                val_running_loss += loss.item() * val_specs.size(0)

                prediction = None

                if i % 2 == 0:
                    self.write_summaries(
                        features=val_specs,  # original
                        #labels=call_label,
                        #prediction=prediction,
                        reconstructed=outputs,
                        file_names=label["file_name"],
                        epoch=epoch,
                        phase=phase,
                    )
                start_data_loading = time.time()

            val_epoch_loss = val_running_loss / len(val_ds)

            self.write_scalar_summaries_logs(
                loss=val_epoch_loss,
                #metrics=metrics,
                epoch_time=time.time() - epoch_start,
                data_loading_time=data_loading_time.get(),
                epoch=epoch,
                phase=phase,
            )

            self.writer.flush()

            return val_epoch_loss

    def write_summaries(
        self,
        features,
        #labels=None, #  tensor([True, True, True, True, True, True])
        #prediction=None,
        reconstructed=None,
        file_names=None,
        epoch=None,
        phase="train",
    ):
        #"""Writes image summary per partition (spectrograms and the corresponding predictions)"""
        """Writes image summary per partition (spectrograms and reconstructed)"""

        with torch.no_grad():
            self.write_img_summaries(
                features,
                #labels=labels,
                #prediction=prediction,
                reconstructed=reconstructed,
                file_names=file_names,
                epoch=epoch + 1,
                phase=phase,
            )

    def write_img_summaries(
        self,
        features,
        #labels=None,
        #prediction=None,
        reconstructed=None,
        file_names=None,
        epoch=None,
        phase="train",
    ):
        """
        Writes image summary per partition with respect to the prediction output (true predictions - true positive/negative, false
        predictions - false positive/negative)
        """

        with torch.no_grad():
            if file_names is not None:
                if isinstance(file_names, torch.Tensor):
                    file_names = file_names.cpu().numpy()
                elif isinstance(file_names, list):
                    file_names = np.asarray(file_names)
            #if labels is not None and prediction is not None:
            if reconstructed is not None:
                features = features.cpu()
                #labels = labels.cpu()
                #prediction = prediction.cpu()
                reconstructed = reconstructed.cpu()

                self.writer.add_images(
                    tag=phase + "/input",
                    img_tensor=features[:self.n_summaries],
                    #img_tensor=prepare_img(
                    #    features, num_images=self.n_summaries, file_names=file_names
                    #),
                    global_step=epoch,
                )

                self.writer.add_images(
                    tag=phase + "/reconstructed",
                    img_tensor=reconstructed[:self.n_summaries],
                    # img_tensor=prepare_img(
                    #    features, num_images=self.n_summaries, file_names=file_names
                    # ),
                    global_step=epoch,
                )
                """ below are needed to visualize true positive/negative examples"""
                """for label in torch.unique(labels): #  tensor(1, device='cuda:0')
                    label = label.item() # Returns the value of this tensor as a standard Python number: 1
                    l_i = torch.eq(labels, label)

                    t_i = torch.eq(prediction, label) * l_i
                    name_t = "true_{}".format("positive" if label else "negative")
                    try:
                        self.writer.add_image(
                            tag=phase + "/" + name_t,
                            img_tensor=prepare_img(
                                features[t_i],
                                num_images=self.n_summaries,
                                file_names=file_names[t_i.numpy() == 1],
                            ),
                            global_step=epoch,
                        )
                    except ValueError:
                        pass

                    f_i = torch.ne(prediction, label) * l_i
                    name_f = "false_{}".format("negative" if label else "positive")
                    try:
                        self.writer.add_image(
                            tag=phase + "/" + name_f,
                            img_tensor=prepare_img(
                                features[f_i],
                                num_images=self.n_summaries,
                                file_names=file_names[f_i.numpy() == 1],
                            ),
                            global_step=epoch,
                        )
                    except ValueError:
                        pass
            else:
                self.writer.add_image(
                    tag=phase + "/input",
                    img_tensor=prepare_img(
                        features, num_images=self.n_summaries, file_names=file_names
                    ),
                    global_step=epoch,
                )"""

    """
    Writes scalar summary per partition including loss, confusion matrix, accuracy, recall, f1-score, true positive rate,
    false positive rate, precision, data_loading_time, epoch time
    """

    def write_scalar_summaries_logs(
        self,
        loss: float,
        metrics: Union[list, dict] = [],
        lr: float = None,
        epoch_time: float = None,
        data_loading_time: float = None,
        epoch=None,
        phase="train",
    ):
        with torch.no_grad():
            log_str = phase
            if epoch is not None:
                log_str += "|{}".format(epoch)
            self.writer.add_scalar(phase + "/epoch_loss", loss, epoch)
            log_str += "|loss:{:0.3f}".format(loss)
            if isinstance(metrics, dict):
                for name, metric in metrics.items():
                    self.writer.add_scalar(phase + "/" + name, metric.get(),
                                           epoch)
                    log_str += "|{}:{:0.3f}".format(name, metric.get())
            else:
                for i, metric in enumerate(metrics):
                    self.writer.add_scalar(phase + "/metric_" + str(i),
                                           metric.get(), epoch)
                    log_str += "|m_{}:{:0.3f}".format(i, metric.get())
            if lr is not None:
                self.writer.add_scalar("lr", lr, epoch)
                log_str += "|lr:{:0.2e}".format(lr)
            if epoch_time is not None:
                self.writer.add_scalar(phase + "/time", epoch_time, epoch)
                log_str += "|t:{:0.1f}".format(epoch_time)
            if data_loading_time is not None:
                self.writer.add_scalar(phase + "/data_loading_time",
                                       data_loading_time, epoch)
            self.logger.info(log_str)
Example #28
class A2CTrainer(TrainerBase):
    def __init__(self, env, agent, args):
        super(A2CTrainer, self).__init__()
        self.env = env
        self.agent = agent
        self.num_episodes = args.num_episodes
        self.buffer = Buffer(args.buffer_size, args.batch_size)

        assert len(
            args.solved) == 2, 'args.solved has to have length of exactly 2!'
        self.solved_r = args.solved[0]
        self.solved_ep = args.solved[1]
        self.render = args.render

        self.writer = SummaryWriter(args.tensorboard)

    def train(self):
        reward_history = []
        # For each update

        if self.render:
            self.env.render()

        # For each episode
        for episode in range(self.num_episodes):
            episode_reward = 0
            self.buffer.reset()
            state = self.env.reset()

            # While there is room in the buffer
            for i in range(self.buffer.buffer_size):
                # Get action
                action, log_probs, entropy = self.agent.get_action(state)

                # Take step
                state_, reward, done, _ = self.env.step(action)
                episode_reward += reward

                # Store transition in buffer for TD learning
                transition = [state, log_probs, entropy, reward, done]
                self.buffer.insert_transition(transition)

                # If the episode is finished (max_steps reached or from env)
                if done or i == self.buffer.buffer_size - 1:
                    print('Episode: ', episode, 'Reward: %i' % episode_reward)
                    reward_history.append(episode_reward)
                    self.writer.add_scalar("reward", episode_reward, episode)
                    self.writer.flush()

                    if len(reward_history) > self.solved_ep:
                        reward_history.pop(0)
                        if (sum(reward_history) /
                                len(reward_history)) >= self.solved_r:
                            print('Env has been solved at episode ' +
                                  str(episode) + '!')
                            self.writer.close()
                            exit()
                    break
                else:
                    state = state_

            # Get the estimated value of the next state: if the episode ended,
            # the next value is 0, otherwise get it from the critic
            if self.buffer.buffer[-1][-1]:
                R = T.as_tensor([0])
            else:
                _, R = self.agent.model(state)
                R = R.detach()

            # Get transitions from buffer to train with
            transitions = self.buffer.get_buffer()
            states, log_probs, entropys, rewards, dones = self.agent.convert_to_tensors(
                transitions)

            # Calculate the discounted rewards
            returns = self.agent.calculate_returns(rewards, dones, R)
            actor_loss, critic_loss = self.agent.learn(states, log_probs,
                                                       returns, entropys)

            self.writer.add_scalar("loss/actor", actor_loss, episode)
            self.writer.add_scalar("loss/critic", critic_loss, episode)
            self.writer.flush()

        self.writer.close()
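
agent.calculate_returns is not shown in this example; a common implementation of the discounted-return bootstrap it presumably performs is sketched below (gamma, the tensor layout, and passing R as a float are assumptions):

import torch

def calculate_returns(rewards, dones, R, gamma=0.99):
    # walk the buffer backwards, bootstrapping from the critic's estimate R;
    # a terminal step (done=1) cuts the discounted tail to zero
    returns = []
    for r, d in zip(reversed(rewards.tolist()), reversed(dones.tolist())):
        R = r + gamma * R * (1.0 - d)
        returns.append(R)
    returns.reverse()
    return torch.tensor(returns)

rewards = torch.tensor([1.0, 1.0, 1.0])
dones = torch.tensor([0.0, 0.0, 1.0])
print(calculate_returns(rewards, dones, R=0.0))  # tensor([2.9701, 1.9900, 1.0000])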
Example #29
def train(model_name, fold, run=None, resume_epoch=-1, use_apex=False):
    model_str = build_model_str(model_name, fold, run)

    model_info = MODELS[model_name]

    checkpoints_dir = f'{BaseConfig.checkpoints_dir}/{model_str}'
    tensorboard_dir = f'{BaseConfig.tensorboard_dir}/{model_str}'
    oof_dir = f'{BaseConfig.oof_dir}/{model_str}'
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(oof_dir, exist_ok=True)
    print('\n', model_name, '\n')

    logger = SummaryWriter(log_dir=tensorboard_dir)

    model = model_info.factory(**model_info.args)
    # model = torch.nn.DataParallel(model).cuda()
    model = model.cuda()

    augmentations = [
        albumentations.ShiftScaleRotate(shift_limit=16. / 256,
                                        scale_limit=0.05,
                                        rotate_limit=30,
                                        interpolation=cv2.INTER_LINEAR,
                                        border_mode=cv2.BORDER_REPLICATE,
                                        p=0.80),
    ]
    if model_info.use_vflip:
        augmentations += [
            albumentations.Flip(),
            albumentations.RandomRotate90()
        ]
    else:
        augmentations += [albumentations.HorizontalFlip()]

    dataset_train = dataset.IntracranialDataset(
        csv_file='5fold-test-rev3.csv',
        folds=[f for f in range(BaseConfig.nb_folds) if f != fold],
        preprocess_func=albumentations.Compose(augmentations),
        **model_info.dataset_args)

    dataset_valid = dataset.IntracranialDataset(csv_file='5fold-test-rev3.csv',
                                                folds=[fold],
                                                preprocess_func=None,
                                                **model_info.dataset_args)

    data_loaders = {
        'train':
        DataLoader(dataset_train,
                   num_workers=8,
                   shuffle=True,
                   batch_size=model_info.batch_size),
        'val':
        DataLoader(dataset_valid,
                   shuffle=False,
                   num_workers=8,
                   batch_size=model_info.batch_size)
    }

    if model_info.single_slice_steps > 0:
        augmentations = [
            albumentations.ShiftScaleRotate(shift_limit=16. / 256,
                                            scale_limit=0.05,
                                            rotate_limit=30,
                                            interpolation=cv2.INTER_LINEAR,
                                            border_mode=cv2.BORDER_REPLICATE,
                                            p=0.80),
        ]
        if model_info.use_vflip:
            augmentations += [
                albumentations.Flip(),
                albumentations.RandomRotate90()
            ]
        else:
            augmentations += [albumentations.HorizontalFlip()]

        dataset_train_1_slice = dataset.IntracranialDataset(
            csv_file='5fold-test-rev3.csv',
            folds=[f for f in range(BaseConfig.nb_folds) if f != fold],
            preprocess_func=albumentations.Compose(augmentations),
            **{
                **model_info.dataset_args, "num_slices": 1
            })

        dataset_valid_1_slice = dataset.IntracranialDataset(
            csv_file='5fold-test-rev3.csv',
            folds=[fold],
            preprocess_func=None,
            **{
                **model_info.dataset_args, "num_slices": 1
            })

        data_loaders['train_1_slice'] = DataLoader(
            dataset_train_1_slice,
            num_workers=8,
            shuffle=True,
            batch_size=model_info.batch_size * 2)
        data_loaders['val_1_slice'] = DataLoader(
            dataset_valid_1_slice,
            shuffle=False,
            num_workers=8,
            batch_size=model_info.batch_size * 2)

    model.train()

    class_weights = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 2.0]).cuda()

    def criterium(y_pred, y_true):
        return F.binary_cross_entropy_with_logits(
            y_pred, y_true, class_weights.repeat(y_pred.shape[0], 1))

    # fit the new layers first:
    if resume_epoch == -1 and model_info.is_pretrained:
        model.train()
        model.freeze_encoder()
        data_loader = data_loaders.get('train_1_slice', data_loaders['train'])
        pre_fit_steps = 40000 // model_info.batch_size
        data_iter = tqdm(enumerate(data_loader), total=pre_fit_steps)
        epoch_loss = []
        initial_optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        for iter_num, data in data_iter:
            if iter_num > pre_fit_steps:
                break
            with torch.set_grad_enabled(True):
                img = data['image'].float().cuda()
                labels = data['labels'].cuda()
                pred = model(img)
                loss = criterium(pred, labels)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 100.0)
                initial_optimizer.step()
                initial_optimizer.zero_grad()
                epoch_loss.append(float(loss))

                data_iter.set_description(
                    f'Loss: Running {np.mean(epoch_loss[-500:]):1.4f} Avg {np.mean(epoch_loss):1.4f}'
                )
        model.unfreeze_encoder()

    optimizer = radam.RAdam(model.parameters(), lr=model_info.initial_lr)
    if use_apex:
        model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

    milestones = [5, 10, 16]
    if model_info.optimiser_milestones:
        milestones = model_info.optimiser_milestones
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=milestones,
                                               gamma=0.2)

    print(
        f'Num training images: {len(dataset_train)} validation images: {len(dataset_valid)}'
    )

    if resume_epoch > -1:
        checkpoint = torch.load(f'{checkpoints_dir}/{resume_epoch:03}.pt')
        print('load', f'{checkpoints_dir}/{resume_epoch:03}.pt')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if 'amp' in checkpoint:
            amp.load_state_dict(checkpoint['amp'])

    for epoch_num in range(resume_epoch + 1, 7):
        for phase in ['train', 'val']:
            model.train(phase == 'train')
            epoch_loss = []
            epoch_labels = []
            epoch_predictions = []
            epoch_sample_paths = []

            # Optional per-epoch hook defined on some models
            if hasattr(model, 'on_epoch'):
                model.on_epoch(epoch_num)

            if epoch_num < model_info.single_slice_steps:
                data_loader = data_loaders[phase + '_1_slice']
                print("use 1 slice input")
            else:
                data_loader = data_loaders[phase]
                print("use N slices input")

            data_iter = tqdm(enumerate(data_loader),
                             total=len(data_loader),
                             ncols=200)
            for iter_num, data in data_iter:
                img = data['image'].float().cuda()
                labels = data['labels'].float().cuda()

                with torch.set_grad_enabled(phase == 'train'):
                    pred = model(img)
                    loss = criterium(pred, labels)

                    if phase == 'train':
                        if use_apex:
                            with amp.scale_loss(
                                    loss / model_info.accumulation_steps,
                                    optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            (loss / model_info.accumulation_steps).backward()

                        if (iter_num + 1) % model_info.accumulation_steps == 0:
                            optimizer.step()
                            optimizer.zero_grad()

                    epoch_loss.append(float(loss))

                    epoch_labels.append(labels.detach().cpu().numpy())
                    epoch_predictions.append(
                        torch.sigmoid(pred).detach().cpu().numpy())
                    epoch_sample_paths += data['path']

                data_iter.set_description(
                    f'{epoch_num} Loss: Running {np.mean(epoch_loss[-1000:]):1.4f} Avg {np.mean(epoch_loss):1.4f}'
                )

            logger.add_scalar(f'loss_{phase}', np.mean(epoch_loss), epoch_num)
            logger.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch_num)
            try:
                epoch_labels = np.row_stack(epoch_labels)
                epoch_predictions = np.row_stack(epoch_predictions)
                print(epoch_labels.shape, epoch_predictions.shape)
                log_metrics(logger=logger,
                            phase=phase,
                            epoch_num=epoch_num,
                            y=epoch_labels,
                            y_hat=epoch_predictions)
            except Exception:
                # Metric computation can fail on a degenerate epoch; skip logging then
                pass
            logger.flush()

            if phase == 'val':
                scheduler.step(epoch=epoch_num)
                torch.save(
                    {
                        'epoch': epoch_num,
                        'sample_paths': epoch_sample_paths,
                        'epoch_labels': epoch_labels,
                        'epoch_predictions': epoch_predictions,
                    }, f'{oof_dir}/{epoch_num:03}.pt')
            else:
                torch.save(
                    {
                        'epoch': epoch_num,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        # amp.state_dict() is only valid when apex is in use
                        **({'amp': amp.state_dict()} if use_apex else {}),
                    }, f'{checkpoints_dir}/{epoch_num:03}.pt')
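
The training loop above scales the loss by model_info.accumulation_steps and only steps the optimizer every accumulation_steps iterations. A self-contained sketch of that gradient-accumulation pattern (the model, data, and step counts here are placeholders, not taken from the example):

import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4  # illustrative value

for iter_num in range(16):
    x, y = torch.randn(8, 10), torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    # Gradients accumulate in .grad across iterations; scaling the loss keeps
    # the effective gradient equal to that of one large batch
    (loss / accumulation_steps).backward()
    if (iter_num + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
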
Example #30
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        try:
            # set global_step to the global_step of the last saved checkpoint from the model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    # Added here for reproducibility
    set_seed(args)

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            if args.model_type in [
                    "xlm", "roberta", "distilbert", "camembert"
            ]:
                del inputs["token_type_ids"]

            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
                if hasattr(model, "config") and hasattr(
                        model.config, "lang2id"):
                    inputs.update({
                        "langs":
                        (torch.ones(batch[0].shape, dtype=torch.int64) *
                         args.lang_id).to(args.device)
                    })

            outputs = model(**inputs)
            # model outputs are always a tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                            print("eval_{} : {}".format(key, value))
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss
                    tb_writer.flush()

                # Save model checkpoint
                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(
                        model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
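
Example #30 guards every writer call with args.local_rank in [-1, 0] so that only the main process logs; otherwise every distributed worker would write duplicate events to the same run. A minimal sketch of that rank-zero pattern in isolation (the function name and call site are illustrative):

from torch.utils.tensorboard import SummaryWriter

def make_writer(local_rank):
    # local_rank is -1 for single-process runs and 0 for the main
    # distributed worker; only those ranks get a real writer
    if local_rank in [-1, 0]:
        return SummaryWriter()  # defaults to ./runs/<timestamp>
    return None

writer = make_writer(local_rank=-1)  # illustrative single-process call
if writer is not None:
    writer.add_scalar('loss', 0.5, 1)
    writer.close()
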