Example #1
0
class TrainManager:
    """ Manages training loop, validations, learning rate scheduling
    and early stopping."""
    def __init__(self, model: Model, config: dict) -> None:
        """
        Creates a new TrainManager for a model, as specified in the configuration.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        """
        train_config = config["training"]

        # files for logging and storing
        self.model_dir = make_model_dir(train_config["model_dir"],
                                        overwrite=train_config.get(
                                            "overwrite", False))
        self.logger = make_logger("{}/train.log".format(self.model_dir))
        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = "{}/validations.txt".format(self.model_dir)
        self.tb_writer = SummaryWriter(log_dir=self.model_dir +
                                       "/tensorboard/")

        # model
        self.model = model
        self.pad_index = self.model.pad_index
        self.bos_index = self.model.bos_index
        self._log_parameters_list()

        # objective
        self.label_smoothing = train_config.get("label_smoothing", 0.0)
        self.loss = XentLoss(pad_index=self.pad_index,
                             smoothing=self.label_smoothing)
        self.normalization = train_config.get("normalization", "batch")
        if self.normalization not in ["batch", "tokens"]:
            raise ConfigurationError("Invalid normalization. "
                                     "Valid options: 'batch', 'tokens'.")

        # optimization
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)

        self.clip_grad_fun = build_gradient_clipper(config=train_config)
        self.optimizer = build_optimizer(config=train_config,
                                         parameters=model.parameters())

        # validation & early stopping
        self.validation_freq = train_config.get("validation_freq", 1000)
        self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
        self.ckpt_queue = queue.Queue(
            maxsize=train_config.get("keep_last_ckpts", 5))
        self.eval_metric = train_config.get("eval_metric", "bleu")
        if self.eval_metric not in ['bleu', 'chrf']:
            raise ConfigurationError("Invalid setting for 'eval_metric', "
                                     "valid options: 'bleu', 'chrf'.")
        self.early_stopping_metric = train_config.get("early_stopping_metric",
                                                      "eval_metric")

        # if we schedule after BLEU/chrf, we want to maximize it, else minimize
        # early_stopping_metric decides on how to find the early stopping point:
        # ckpts are written when there's a new high/low score for this metric
        if self.early_stopping_metric in ["ppl", "loss"]:
            self.minimize_metric = True
        elif self.early_stopping_metric == "eval_metric":
            if self.eval_metric in ["bleu", "chrf"]:
                self.minimize_metric = False
            else:  # eval metric that has to get minimized (not yet implemented)
                self.minimize_metric = True
        else:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', 'eval_metric'.")

        # learning rate scheduling
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer,
            hidden_size=config["model"]["encoder"]["hidden_size"])

        # data & batch handling
        self.level = config["data"]["level"]
        if self.level not in ["word", "bpe", "char"]:
            raise ConfigurationError("Invalid segmentation level. "
                                     "Valid options: 'word', 'bpe', 'char'.")
        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        self.batch_type = train_config.get("batch_type", "sentence")
        self.eval_batch_size = train_config.get("eval_batch_size",
                                                self.batch_size)
        self.eval_batch_type = train_config.get("eval_batch_type",
                                                self.batch_type)

        self.batch_multiplier = train_config.get("batch_multiplier", 1)

        # generation
        self.max_output_length = train_config.get("max_output_length", None)

        # CPU / GPU
        self.use_cuda = train_config["use_cuda"]
        if self.use_cuda:
            self.model.cuda()
            self.loss.cuda()

        # initialize training statistics
        self.steps = 0
        # stop training if this flag is set, e.g. because the learning rate fell below its minimum
        self.stop = False
        self.total_tokens = 0
        self.best_ckpt_iteration = 0
        # initial values for best scores
        self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf
        # comparison function for scores
        self.is_best = lambda score: score < self.best_ckpt_score \
            if self.minimize_metric else score > self.best_ckpt_score
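        # e.g. when eval_metric is "bleu" or "chrf" a higher score is a new best,
        # while for "ppl" or "loss" a lower score is a new best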

        # model parameters
        if "load_model" in train_config.keys():
            model_load_path = train_config["load_model"]
            self.logger.info("Loading model from %s", model_load_path)
            reset_best_ckpt = train_config.get("reset_best_ckpt", False)
            reset_scheduler = train_config.get("reset_scheduler", False)
            reset_optimizer = train_config.get("reset_optimizer", False)
            self.init_from_checkpoint(model_load_path,
                                      reset_best_ckpt=reset_best_ckpt,
                                      reset_scheduler=reset_scheduler,
                                      reset_optimizer=reset_optimizer)

    def _save_checkpoint(self) -> None:
        """
        Save the model's current parameters and the training state to a
        checkpoint.

        The training state contains the total number of training steps,
        the total number of training tokens,
        the best checkpoint score and iteration so far,
        and optimizer and scheduler states.

        """
        model_path = "{}/{}.ckpt".format(self.model_dir, self.steps)
        state = {
            "steps": self.steps,
            "total_tokens": self.total_tokens,
            "best_ckpt_score": self.best_ckpt_score,
            "best_ckpt_iteration": self.best_ckpt_iteration,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict() if \
            self.scheduler is not None else None,
        }
        torch.save(state, model_path)
        if self.ckpt_queue.full():
            to_delete = self.ckpt_queue.get()  # delete oldest ckpt
            try:
                os.remove(to_delete)
            except FileNotFoundError:
                self.logger.warning(
                    "Wanted to delete old checkpoint %s but "
                    "file does not exist.", to_delete)

        self.ckpt_queue.put(model_path)

        best_path = "{}/best.ckpt".format(self.model_dir)
        try:
            # create/modify symbolic link for best checkpoint
            symlink_update("{}.ckpt".format(self.steps), best_path)
        except OSError:
            # overwrite best.ckpt
            torch.save(state, best_path)

    def init_from_checkpoint(self,
                             path: str,
                             reset_best_ckpt: bool = False,
                             reset_scheduler: bool = False,
                             reset_optimizer: bool = False) -> None:
        """
        Initialize the trainer from a given checkpoint file.

        This checkpoint file contains not only model parameters, but also
        scheduler and optimizer states, see `self._save_checkpoint`.

        :param path: path to checkpoint
        :param reset_best_ckpt: reset tracking of the best checkpoint,
                                use for domain adaptation with a new dev
                                set or when using a new metric for fine-tuning.
        :param reset_scheduler: reset the learning rate scheduler, and do not
                                use the one stored in the checkpoint.
        :param reset_optimizer: reset the optimizer, and do not use the one
                                stored in the checkpoint.
        """
        model_checkpoint = load_checkpoint(path=path, use_cuda=self.use_cuda)

        # restore model and optimizer parameters
        self.model.load_state_dict(model_checkpoint["model_state"])

        if not reset_optimizer:
            self.optimizer.load_state_dict(model_checkpoint["optimizer_state"])
        else:
            self.logger.info("Reset optimizer.")

        if not reset_scheduler:
            if model_checkpoint["scheduler_state"] is not None and \
                    self.scheduler is not None:
                self.scheduler.load_state_dict(
                    model_checkpoint["scheduler_state"])
        else:
            self.logger.info("Reset scheduler.")

        # restore counts
        self.steps = model_checkpoint["steps"]
        self.total_tokens = model_checkpoint["total_tokens"]

        if not reset_best_ckpt:
            self.best_ckpt_score = model_checkpoint["best_ckpt_score"]
            self.best_ckpt_iteration = model_checkpoint["best_ckpt_iteration"]
        else:
            self.logger.info("Reset tracking of the best checkpoint.")

        # move parameters to cuda
        if self.use_cuda:
            self.model.cuda()

    # pylint: disable=unnecessary-comprehension
    def train_and_validate(self, train_data: Dataset, valid_data: Dataset) \
            -> None:
        """
        Train the model and validate it from time to time on the validation set.

        :param train_data: training data
        :param valid_data: validation data
        """
        train_iter = make_data_iter(train_data,
                                    batch_size=self.batch_size,
                                    batch_type=self.batch_type,
                                    train=True,
                                    shuffle=self.shuffle)
        for epoch_no in range(self.epochs):
            self.logger.info("EPOCH %d", epoch_no + 1)

            if self.scheduler is not None and self.scheduler_step_at == "epoch":
                self.scheduler.step(epoch=epoch_no)

            self.model.train()

            # Reset statistics for each epoch.
            start = time.time()
            total_valid_duration = 0
            start_tokens = self.total_tokens
            count = self.batch_multiplier - 1
            epoch_loss = 0

            for batch in iter(train_iter):
                # reactivate training
                self.model.train()
                # create a Batch object from torchtext batch
                batch = Batch(batch, self.pad_index, use_cuda=self.use_cuda)

                # only update every batch_multiplier batches
                # see https://medium.com/@davidlmorton/
                # increasing-mini-batch-size-without-increasing-
                # memory-6794e10db672
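                # illustrative note: with batch_multiplier k the counter cycles
                # k-1, ..., 1, 0, so the gradients of k consecutive batches are
                # accumulated before each optimizer step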
                update = count == 0
                # print(count, update, self.steps)
                batch_loss = self._train_batch(batch, update=update)
                self.tb_writer.add_scalar("train/train_batch_loss", batch_loss,
                                          self.steps)
                count = self.batch_multiplier if update else count
                count -= 1
                epoch_loss += batch_loss.detach().cpu().numpy()

                if self.scheduler is not None and \
                        self.scheduler_step_at == "step" and update:
                    self.scheduler.step()

                # log learning progress
                if self.steps % self.logging_freq == 0 and update:
                    elapsed = time.time() - start - total_valid_duration
                    elapsed_tokens = self.total_tokens - start_tokens
                    self.logger.info(
                        "Epoch %3d Step: %8d Batch Loss: %12.6f "
                        "Tokens per Sec: %8.0f, Lr: %.6f", epoch_no + 1,
                        self.steps, batch_loss, elapsed_tokens / elapsed,
                        self.optimizer.param_groups[0]["lr"])
                    start = time.time()
                    total_valid_duration = 0
                    start_tokens = self.total_tokens

                # validate on the entire dev set
                if self.steps % self.validation_freq == 0 and update:
                    valid_start_time = time.time()

                    valid_score, valid_loss, valid_ppl, valid_sources, \
                        valid_sources_raw, valid_references, valid_hypotheses, \
                        valid_hypotheses_raw, valid_attention_scores = \
                        validate_on_data(
                            logger=self.logger,
                            batch_size=self.eval_batch_size,
                            data=valid_data,
                            eval_metric=self.eval_metric,
                            level=self.level, model=self.model,
                            use_cuda=self.use_cuda,
                            max_output_length=self.max_output_length,
                            loss_function=self.loss,
                            beam_size=1,  # greedy validations
                            batch_type=self.eval_batch_type
                        )

                    self.tb_writer.add_scalar("valid/valid_loss", valid_loss,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_score", valid_score,
                                              self.steps)
                    self.tb_writer.add_scalar("valid/valid_ppl", valid_ppl,
                                              self.steps)

                    if self.early_stopping_metric == "loss":
                        ckpt_score = valid_loss
                    elif self.early_stopping_metric in ["ppl", "perplexity"]:
                        ckpt_score = valid_ppl
                    else:
                        ckpt_score = valid_score

                    new_best = False
                    if self.is_best(ckpt_score):
                        self.best_ckpt_score = ckpt_score
                        self.best_ckpt_iteration = self.steps
                        self.logger.info(
                            'Hooray! New best validation result [%s]!',
                            self.early_stopping_metric)
                        if self.ckpt_queue.maxsize > 0:
                            self.logger.info("Saving new checkpoint.")
                            new_best = True
                            self._save_checkpoint()

                    if self.scheduler is not None \
                            and self.scheduler_step_at == "validation":
                        self.scheduler.step(ckpt_score)

                    # append to validation report
                    self._add_report(valid_score=valid_score,
                                     valid_loss=valid_loss,
                                     valid_ppl=valid_ppl,
                                     eval_metric=self.eval_metric,
                                     new_best=new_best)

                    self._log_examples(
                        sources_raw=[v for v in valid_sources_raw],
                        sources=valid_sources,
                        hypotheses_raw=valid_hypotheses_raw,
                        hypotheses=valid_hypotheses,
                        references=valid_references)

                    valid_duration = time.time() - valid_start_time
                    total_valid_duration += valid_duration
                    self.logger.info(
                        'Validation result (greedy) at epoch %3d, '
                        'step %8d: %s: %6.2f, loss: %8.4f, ppl: %8.4f, '
                        'duration: %.4fs', epoch_no + 1, self.steps,
                        self.eval_metric, valid_score, valid_loss, valid_ppl,
                        valid_duration)

                    # store validation set outputs
                    self._store_outputs(valid_hypotheses)

                    # store attention plots for selected valid sentences
                    if valid_attention_scores:
                        store_attention_plots(
                            attentions=valid_attention_scores,
                            targets=valid_hypotheses_raw,
                            sources=[s for s in valid_data.src],
                            indices=self.log_valid_sents,
                            output_prefix="{}/att.{}".format(
                                self.model_dir, self.steps),
                            tb_writer=self.tb_writer,
                            steps=self.steps)

                if self.stop:
                    break
            if self.stop:
                self.logger.info(
                    'Training ended since minimum lr %f was reached.',
                    self.learning_rate_min)
                break

            self.logger.info('Epoch %3d: total training loss %.2f',
                             epoch_no + 1, epoch_loss)
        else:
            self.logger.info('Training ended after %3d epochs.', epoch_no + 1)
        self.logger.info(
            'Best validation result (greedy) at step '
            '%8d: %6.2f %s.', self.best_ckpt_iteration, self.best_ckpt_score,
            self.early_stopping_metric)

        self.tb_writer.close()  # close Tensorboard writer

    def _train_batch(self, batch: Batch, update: bool = True) -> Tensor:
        """
        Train the model on one batch: Compute the loss, make a gradient step.

        :param batch: training batch
        :param update: if False, only accumulate gradients; if True, also make a parameter update
        :return: normalized loss for the batch
        """
        batch_loss = self.model.get_loss_for_batch(batch=batch,
                                                   loss_function=self.loss)

        # normalize batch loss
        if self.normalization == "batch":
            normalizer = batch.nseqs
        elif self.normalization == "tokens":
            normalizer = batch.ntokens
        else:
            raise NotImplementedError("Only normalize by 'batch' or 'tokens'")

        norm_batch_loss = batch_loss / normalizer
        # division needed since loss.backward sums the gradients until updated
        norm_batch_multiply = norm_batch_loss / self.batch_multiplier
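        # e.g. with batch_multiplier 2, each of the two accumulated batches
        # contributes half of its normalized loss, so the accumulated gradient
        # approximates a single update on the combined batch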

        # compute gradients
        norm_batch_multiply.backward()

        if self.clip_grad_fun is not None:
            # clip gradients (in-place)
            self.clip_grad_fun(params=self.model.parameters())

        if update:
            # make gradient step
            self.optimizer.step()
            self.optimizer.zero_grad()

            # increment step counter
            self.steps += 1

        # increment token counter
        self.total_tokens += batch.ntokens

        return norm_batch_loss

    def _add_report(self,
                    valid_score: float,
                    valid_ppl: float,
                    valid_loss: float,
                    eval_metric: str,
                    new_best: bool = False) -> None:
        """
        Append a one-line report to the validation logging file.

        :param valid_score: validation evaluation score [eval_metric]
        :param valid_ppl: validation perplexity
        :param valid_loss: validation loss (sum over whole validation set)
        :param eval_metric: evaluation metric, e.g. "bleu"
        :param new_best: whether this is a new best model
        """
        current_lr = -1
        # ignores other param groups for now
        for param_group in self.optimizer.param_groups:
            current_lr = param_group['lr']

        if current_lr < self.learning_rate_min:
            self.stop = True

        with open(self.valid_report_file, 'a') as opened_file:
            opened_file.write(
                "Steps: {}\tLoss: {:.5f}\tPPL: {:.5f}\t{}: {:.5f}\t"
                "LR: {:.8f}\t{}\n".format(self.steps, valid_loss, valid_ppl,
                                          eval_metric, valid_score, current_lr,
                                          "*" if new_best else ""))

    def _log_parameters_list(self) -> None:
        """
        Write all model parameters (name, shape) to the log.
        """
        model_parameters = filter(lambda p: p.requires_grad,
                                  self.model.parameters())
        n_params = sum([np.prod(p.size()) for p in model_parameters])
        self.logger.info("Total params: %d", n_params)
        trainable_params = [
            n for (n, p) in self.model.named_parameters() if p.requires_grad
        ]
        self.logger.info("Trainable parameters: %s", sorted(trainable_params))
        assert trainable_params

    def _log_examples(self,
                      sources: List[str],
                      hypotheses: List[str],
                      references: List[str],
                      sources_raw: List[List[str]] = None,
                      hypotheses_raw: List[List[str]] = None,
                      references_raw: List[List[str]] = None) -> None:
        """
        Log the first `self.log_valid_sents` sentences from the given examples.

        :param sources: decoded sources (list of strings)
        :param hypotheses: decoded hypotheses (list of strings)
        :param references: decoded references (list of strings)
        :param sources_raw: raw sources (list of list of tokens)
        :param hypotheses_raw: raw hypotheses (list of list of tokens)
        :param references_raw: raw references (list of list of tokens)
        """
        for p in self.log_valid_sents:

            if p >= len(sources):
                continue

            self.logger.info("Example #%d", p)

            if sources_raw is not None:
                self.logger.debug("\tRaw source:     %s", sources_raw[p])
            if references_raw is not None:
                self.logger.debug("\tRaw reference:  %s", references_raw[p])
            if hypotheses_raw is not None:
                self.logger.debug("\tRaw hypothesis: %s", hypotheses_raw[p])

            self.logger.info("\tSource:     %s", sources[p])
            self.logger.info("\tReference:  %s", references[p])
            self.logger.info("\tHypothesis: %s", hypotheses[p])

    def _store_outputs(self, hypotheses: List[str]) -> None:
        """
        Write current validation outputs to a file in `self.model_dir`.

        :param hypotheses: list of strings
        """
        current_valid_output_file = "{}/{}.hyps".format(
            self.model_dir, self.steps)
        with open(current_valid_output_file, 'w') as opened_file:
            for hyp in hypotheses:
                opened_file.write("{}\n".format(hyp))
Example #2
0
class Logger:
    def __init__(self, logpath):
        self.logpath = logpath
        self.writer = None

    def create_summarywriter(self):
        if self.writer is None:
            self.writer = SummaryWriter(self.logpath)

    def write_vls(self, data_blob, outputs, flowselector, step):
        img1 = data_blob['img1'][0].permute([1, 2, 0]).numpy().astype(np.uint8)
        img2 = data_blob['img2'][0].permute([1, 2, 0]).numpy().astype(np.uint8)

        figmask_flow = tensor2disp(flowselector, vmax=1, viewind=0)

        depthpredvls = tensor2disp(1 / outputs[('depth', 2)],
                                   vmax=1,
                                   viewind=0)
        flowvls = flow_to_image(
            outputs[('flowpred', 2)][0].detach().cpu().permute([1, 2,
                                                                0]).numpy(),
            rad_max=10)
        imgrecon = tensor2rgb(outputs[('reconImg', 2)], viewind=0)

        img_val_up = np.concatenate([np.array(img1), np.array(img2)], axis=1)
        img_val_mid2 = np.concatenate(
            [np.array(depthpredvls),
             np.array(figmask_flow)], axis=1)
        img_val_mid3 = np.concatenate(
            [np.array(imgrecon), np.array(flowvls)], axis=1)
        img_val = np.concatenate([
            np.array(img_val_up),
            np.array(img_val_mid2),
            np.array(img_val_mid3)
        ],
                                 axis=0)
        self.writer.add_image('predvls', (torch.from_numpy(img_val).float() /
                                          255).permute([2, 0, 1]), step)

        X = self.vls_sampling(img1, img2, data_blob['depthmap'],
                              data_blob['flowpred'], outputs)
        self.writer.add_image('X', (torch.from_numpy(X).float() / 255).permute(
            [2, 0, 1]), step)

    def vls_sampling(self, img1, img2, depthgt, flowpred, outputs):
        depthgtnp = depthgt[0].squeeze().cpu().numpy()

        h, w, _ = img1.shape
        xx, yy = np.meshgrid(range(w), range(h), indexing='xy')
        selector = (depthgtnp > 0)

        slRange_sel = (np.mod(xx, 4) == 0) * (np.mod(yy, 4) == 0) * selector
        dsratio = 4

        xxfsl = xx[slRange_sel]
        yyfsl = yy[slRange_sel]
        rndidx = np.random.randint(0, xxfsl.shape[0], 1).item()

        xxfsl_sel = xxfsl[rndidx]
        yyfsl_sel = yyfsl[rndidx]

        slvlsxx_fg = (outputs['sample_pts'][
            0, :, int(yyfsl_sel / dsratio),
            int(xxfsl_sel / dsratio), 0].detach().cpu().numpy() + 1) / 2 * w
        slvlsyy_fg = (outputs['sample_pts'][
            0, :, int(yyfsl_sel / dsratio),
            int(xxfsl_sel / dsratio), 1].detach().cpu().numpy() + 1) / 2 * h
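        # sample_pts are assumed to follow the grid_sample convention of
        # normalized [-1, 1] coordinates, hence the (x + 1) / 2 * w mapping
        # back to pixel coordinates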

        flow_predx = flowpred[0, 0, yyfsl_sel, xxfsl_sel].cpu().numpy()
        flow_predy = flowpred[0, 1, yyfsl_sel, xxfsl_sel].cpu().numpy()

        fig = plt.figure(figsize=(16, 9))
        canvas = FigureCanvasAgg(fig)
        fig.add_subplot(2, 1, 1)
        plt.imshow(img1)
        plt.scatter(xxfsl_sel, yyfsl_sel, 3, 'r')
        plt.title("Input")

        fig.add_subplot(2, 1, 2)
        plt.scatter(slvlsxx_fg, slvlsyy_fg, 3, 'b')
        plt.scatter(xxfsl_sel + flow_predx, yyfsl_sel + flow_predy, 3, 'r')
        plt.imshow(img2)
        plt.title("Sampling Arae")

        fig.tight_layout()  # Or equivalently,  "plt.tight_layout()"
        canvas.draw()
        buf = canvas.buffer_rgba()
        plt.close()
        X = np.asarray(buf)
        return X

    def write_vls_eval(self, data_blob, outputs, tagname, step):
        img1 = data_blob['img1'][0].permute([1, 2, 0]).numpy().astype(np.uint8)
        img2 = data_blob['img2'][0].permute([1, 2, 0]).numpy().astype(np.uint8)

        inputdepth = tensor2disp(1 / data_blob['depthmap'], vmax=1, viewind=0)
        depthpredvls = tensor2disp(1 / outputs[('depth', 2)],
                                   vmax=1,
                                   viewind=0)
        flowvls = flow_to_image(
            outputs[('flowpred', 2)][0].detach().cpu().permute([1, 2,
                                                                0]).numpy(),
            rad_max=10)
        imgrecon = tensor2rgb(outputs[('reconImg', 2)], viewind=0)

        img_val_up = np.concatenate([np.array(img1), np.array(img2)], axis=1)
        img_val_mid2 = np.concatenate(
            [np.array(inputdepth),
             np.array(depthpredvls)], axis=1)
        img_val_mid3 = np.concatenate(
            [np.array(imgrecon), np.array(flowvls)], axis=1)
        img_val = np.concatenate([
            np.array(img_val_up),
            np.array(img_val_mid2),
            np.array(img_val_mid3)
        ],
                                 axis=0)
        self.writer.add_image('{}_predvls'.format(tagname),
                              (torch.from_numpy(img_val).float() /
                               255).permute([2, 0, 1]), step)

        X = self.vls_sampling(img1, img2, data_blob['depthmap'],
                              data_blob['flowpred'], outputs)
        self.writer.add_image('{}_X'.format(tagname),
                              (torch.from_numpy(X).float() / 255).permute(
                                  [2, 0, 1]), step)

    def write_dict(self, results, step):
        for key in results:
            self.writer.add_scalar(key, results[key], step)

    def close(self):
        self.writer.close()
Example #3
0
def train():
    from torch.utils.tensorboard import SummaryWriter
    # load configs
    parser = argparse.ArgumentParser(description="Phrase Classification Training")
    parser.add_argument('-c', '--config_file', default=None, help="path to config file")
    parser.add_argument('-o', '--opts', default=None, nargs=argparse.REMAINDER,
                        help="Modify config options using the command-line. E.g. TRAIN.INIT_LR 0.01",)
    args = parser.parse_args()

    if args.config_file is not None:
        cfg.merge_from_file(args.config_file)
    if args.opts is not None:
        cfg.merge_from_list(args.opts)

    cfg.freeze()
    print(cfg.dump())

    if not os.path.exists(cfg.OUTPUT_PATH):
        os.makedirs(cfg.OUTPUT_PATH)
    with open(os.path.join(cfg.OUTPUT_PATH, 'train.cfg'), 'w') as f:
        f.write(cfg.dump())

    # set random seed
    torch.manual_seed(cfg.RAND_SEED)
    np.random.seed(cfg.RAND_SEED)
    random.seed(cfg.RAND_SEED)

    # make data_loader, model, criterion, optimizer
    dataset = PhraseClassifyDataset(split=cfg.TRAIN_SPLIT, is_train=True, cached_resnet_feats=None)
    train_data_loader = DataLoader(dataset, batch_size=cfg.TRAIN.BATCH_SIZE, shuffle=True, drop_last=True)

    eval_data_loader = None
    if cfg.TRAIN.EVAL_EVERY_EPOCH > 0:
        eval_dataset = PhraseClassifyDataset(split=cfg.EVAL_SPLIT, is_train=False, cached_resnet_feats=None)
        eval_data_loader = DataLoader(eval_dataset, batch_size=64, shuffle=False)

    model: PhraseClassifier = PhraseClassifier(class_num=len(dataset.phrases), pretrained_backbone=True,
                                               fc_dims=cfg.MODEL.FC_DIMS, use_feats=cfg.MODEL.BACKBONE_FEATS)
    if not cfg.TRAIN.TUNE_BACKBONE:
        model.img_encoder.requires_grad = False
        model.img_encoder.eval()

    if len(cfg.MODEL.LOAD_WEIGHTS) > 0:
        model.load_state_dict(torch.load(cfg.MODEL.LOAD_WEIGHTS))

    model.train()
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    # re-weight loss based on phrase frequency, more weights on positive samples
    if cfg.TRAIN.LOSS_REWEIGHT:
        class_weights = get_class_weights(dataset.phrase_freq)
        class_weights = torch.from_numpy(class_weights).to(device)
        criterion = nn.BCEWithLogitsLoss(reduction='mean', pos_weight=class_weights)
    else:
        criterion = nn.BCEWithLogitsLoss(reduction='mean')
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                 lr=cfg.TRAIN.INIT_LR, weight_decay=cfg.TRAIN.WEIGHT_DECAY,
                                 betas=(cfg.TRAIN.ADAM.ALPHA, cfg.TRAIN.ADAM.BETA),
                                 eps=cfg.TRAIN.ADAM.EPSILON)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.TRAIN.LR_DECAY_EPOCH,
                                                   gamma=cfg.TRAIN.LR_DECAY_GAMMA)

    # make tensorboard writer and dirs
    checkpoint_dir = os.path.join(cfg.OUTPUT_PATH, 'checkpoints')
    tb_dir = os.path.join(cfg.OUTPUT_PATH, 'tensorboard')
    tb_writer = SummaryWriter(tb_dir)
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    if not os.path.exists(tb_dir):
        os.makedirs(tb_dir)

    # training loop
    step = 1
    epoch = 1
    loss = None
    pred_labels = None
    while epoch <= cfg.TRAIN.MAX_EPOCH:
        lr = optimizer.param_groups[0]['lr']

        for _, imgs, labels in train_data_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            pred_labels = model(imgs)
            loss = criterion(pred_labels, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step <= 20:
                print('[%s] epoch-%d step-%d: loss %.4f; lr %.4f'
                      % (time.strftime('%m/%d %H:%M:%S'), epoch, step, loss, lr))
            # if epoch == 1 and step == 2:  # debug
            #     do_eval(model, eval_data_loader, device, visualize_path=visualize_path, add_to_summary_name='debug')
            step += 1

        lr_scheduler.step(epoch=epoch)
        print('[%s] epoch-%d step-%d: loss %.4f; lr %.4f'
              % (time.strftime('%m/%d %H:%M:%S'), epoch, step, loss, lr))

        tb_writer.add_scalar('train/loss', loss, epoch)
        tb_writer.add_scalar('train/lr', lr, epoch)
        tb_writer.add_scalar('step', step, epoch)
        tb_writer.add_histogram('pred_labels', pred_labels, epoch)

        vis = None
        if epoch % cfg.TRAIN.CHECKPOINT_EVERY_EPOCH == 0:
            vis = os.path.join(cfg.OUTPUT_PATH, 'eval_visualize_epoch%03d' % epoch)
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'epoch%03d.pth' % epoch))

        if epoch % cfg.TRAIN.EVAL_EVERY_EPOCH == 0 and cfg.TRAIN.EVAL_EVERY_EPOCH > 0:
            p2i_result, i2p_result = do_eval(model, eval_data_loader, device, visualize_path=vis)
            for m, v in p2i_result.items():
                tb_writer.add_scalar('eval_p2i/%s' % m, v, epoch)
            for m, v in i2p_result.items():
                tb_writer.add_scalar('eval_i2p/%s' % m, v, epoch)
            model.train()
            if not cfg.TRAIN.TUNE_BACKBONE:
                model.img_encoder.eval()

        epoch += 1

    tb_writer.close()
    visualize_path = os.path.join(cfg.OUTPUT_PATH, 'eval_visualize')
    p2i_result, i2p_result = do_eval(model, eval_data_loader, device, visualize_path=visualize_path,
                                     add_to_summary_name='%s:epoch-%d' % (cfg.OUTPUT_PATH, epoch))
    return p2i_result, i2p_result
Example #4
0
def train(args, training_features, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0] and args.log_dir:
        tb_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        tb_writer = None

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
    else:
        amp = None

    # model recover
    recover_step = utils.get_max_epoch_model(args.output_dir)

    if recover_step:
        checkpoint_state_dict = utils.get_checkpoint_state_dict(
            args.output_dir, recover_step)
    else:
        checkpoint_state_dict = None

    model.to(args.device)
    model, optimizer = prepare_for_training(args,
                                            model,
                                            checkpoint_state_dict,
                                            amp=amp)

    per_node_train_batch_size = args.per_gpu_train_batch_size * args.n_gpu * args.gradient_accumulation_steps
    train_batch_size = per_node_train_batch_size * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)
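    # e.g. (hypothetical numbers) per_gpu_train_batch_size=8, n_gpu=4 and
    # gradient_accumulation_steps=2 give a per-node batch of 64; with a world
    # size of 2 the effective train_batch_size is 128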
    global_step = recover_step if recover_step else 0

    if args.num_training_steps == -1:
        args.num_training_steps = args.num_training_epochs * len(
            training_features) / train_batch_size

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.num_training_steps,
        last_epoch=-1)

    if checkpoint_state_dict:
        scheduler.load_state_dict(checkpoint_state_dict["lr_scheduler"])

    train_dataset = utils.Seq2seqDatasetForBert(
        features=training_features,
        max_source_len=args.max_source_seq_length,
        max_target_len=args.max_target_seq_length,
        vocab_size=tokenizer.vocab_size,
        cls_id=tokenizer.cls_token_id,
        sep_id=tokenizer.sep_token_id,
        pad_id=tokenizer.pad_token_id,
        mask_id=tokenizer.mask_token_id,
        random_prob=args.random_prob,
        keep_prob=args.keep_prob,
        offset=train_batch_size * global_step,
        num_training_instances=train_batch_size * args.num_training_steps,
        source_mask_prob=args.source_mask_prob,
        target_mask_prob=args.target_mask_prob,
        mask_way=args.mask_way,
        num_max_mask_token=args.num_max_mask_token,
    )

    logger.info("Check dataset:")
    for i in range(5):
        source_ids, target_ids = train_dataset.__getitem__(i)[:2]
        logger.info("Instance-%d" % i)
        logger.info("Source tokens = %s" %
                    " ".join(tokenizer.convert_ids_to_tokens(source_ids)))
        logger.info("Target tokens = %s" %
                    " ".join(tokenizer.convert_ids_to_tokens(target_ids)))

    logger.info("Mode = %s" % str(model))

    # Train!
    logger.info("  ***** Running training *****  *")
    logger.info("  Num examples = %d", len(training_features))
    logger.info("  Num Epochs = %.2f",
                len(train_dataset) / len(training_features))
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info("  Batch size per node = %d", per_node_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        train_batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", args.num_training_steps)

    if args.num_training_steps <= global_step:
        logger.info(
            "Training is done. Please use a new dir or clean this dir!")
    else:
        # The training features are shuffled
        train_sampler = SequentialSampler(train_dataset) \
            if args.local_rank == -1 else DistributedSampler(train_dataset, shuffle=False)
        train_dataloader = DataLoader(
            train_dataset,
            sampler=train_sampler,
            batch_size=per_node_train_batch_size //
            args.gradient_accumulation_steps,
            collate_fn=utils.batch_list_to_batch_tensors)

        train_iterator = tqdm.tqdm(train_dataloader,
                                   initial=global_step *
                                   args.gradient_accumulation_steps,
                                   desc="Iter (loss=X.XXX, lr=X.XXXXXXX)",
                                   disable=args.local_rank not in [-1, 0])

        model.train()
        model.zero_grad()

        tr_loss, logging_loss = 0.0, 0.0

        for step, batch in enumerate(train_iterator):
            if global_step > args.num_training_steps:
                break
            batch = tuple(t.to(args.device) for t in batch)
            if args.mask_way == 'v2':
                inputs = {
                    'source_ids': batch[0],
                    'target_ids': batch[1],
                    'label_ids': batch[2],
                    'pseudo_ids': batch[3],
                    'num_source_tokens': batch[4],
                    'num_target_tokens': batch[5]
                }
            elif args.mask_way == 'v1' or args.mask_way == 'v0':
                inputs = {
                    'source_ids': batch[0],
                    'target_ids': batch[1],
                    'masked_ids': batch[2],
                    'masked_pos': batch[3],
                    'masked_weight': batch[4],
                    'num_source_tokens': batch[5],
                    'num_target_tokens': batch[6]
                }
            loss = model(**inputs)
            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training

            train_iterator.set_description(
                'Iter (loss=%5.3f) lr=%9.7f' %
                (loss.item(), scheduler.get_lr()[0]))

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            logging_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logger.info("")
                    logger.info(" Step [%d ~ %d]: %.2f",
                                global_step - args.logging_steps, global_step,
                                logging_loss)
                    logging_loss = 0.0

                if args.local_rank in [-1, 0] and args.save_steps > 0 and \
                        (global_step % args.save_steps == 0 or global_step == args.num_training_steps):

                    save_path = os.path.join(args.output_dir,
                                             "ckpt-%d" % global_step)
                    os.makedirs(save_path, exist_ok=True)
                    model_to_save = model.module if hasattr(
                        model, "module") else model
                    model_to_save.save_pretrained(save_path)

                    optim_to_save = {
                        "optimizer": optimizer.state_dict(),
                        "lr_scheduler": scheduler.state_dict(),
                    }
                    if args.fp16:
                        optim_to_save["amp"] = amp.state_dict()
                    torch.save(optim_to_save,
                               os.path.join(save_path, utils.OPTIM_NAME))

                    logger.info("Saving model checkpoint %d into %s",
                                global_step, save_path)

    if args.local_rank in [-1, 0] and tb_writer:
        tb_writer.close()
Example #5
0
class TransUnetTrainer:
    def __init__(self, net, optimizer, loss, scheduler, save_dir, save_from,
                 logger):
        self.net = net
        self.optimizer = optimizer
        self.loss = loss
        self.scheduler = scheduler
        self.save_dir = save_dir
        self.save_from = save_from
        self.writer = SummaryWriter()
        self.logger = logger

    def val(self, test_loader, epoch):
        len_test = len(test_loader)

        for i, pack in enumerate(test_loader, start=1):
            image, gt = pack
            self.net.eval()

            # gt = gt[0][0]
            # gt = np.asarray(gt, np.float32)
            res2 = 0
            image = image.cuda()
            gt = gt.cuda()

            (
                loss_recordx2,
                loss_recordx3,
                loss_recordx4,
                loss_record2,
                loss_record3,
                loss_record4,
                loss_record5,
            ) = (
                AvgMeter(),
                AvgMeter(),
                AvgMeter(),
                AvgMeter(),
                AvgMeter(),
                AvgMeter(),
                AvgMeter(),
            )

            res5, res4, res3, res2 = self.net(image)

            loss5 = self.loss(res5, gt)
            loss4 = self.loss(res4, gt)
            loss3 = self.loss(res3, gt)
            loss2 = self.loss(res2, gt)
            loss = loss2 + loss3 + loss4 + loss5

            loss_record2.update(loss2.data, 1)
            loss_record3.update(loss3.data, 1)
            loss_record4.update(loss4.data, 1)
            loss_record5.update(loss5.data, 1)

            self.writer.add_scalar("Loss1_test", loss_record2.show(),
                                   (epoch - 1) * len(test_loader) + i)
            # writer.add_scalar("Loss2", loss_record3.show(), (epoch-1)*len(train_loader) + i)
            # writer.add_scalar("Loss3", loss_record4.show(), (epoch-1)*len(train_loader) + i)
            # writer.add_scalar("Loss4", loss_record5.show(), (epoch-1)*len(train_loader) + i)

            if i == len_test - 1:
                self.logger.info(
                    "TEST:{} Epoch [{:03d}/{:03d}], with lr = {}, Step [{:04d}],\
                    [loss_record2: {:.4f},loss_record3: {:.4f},loss_record4: {:.4f},loss_record5: {:.4f}]"
                    .format(
                        datetime.now(),
                        epoch,
                        epoch,
                        self.optimizer.param_groups[0]["lr"],
                        i,
                        loss_record2.show(),
                        loss_record3.show(),
                        loss_record4.show(),
                        loss_record5.show(),
                    ))

    def fit(
        self,
        train_loader,
        is_val=False,
        test_loader=None,
        img_size=352,
        start_from=0,
        num_epochs=200,
        batchsize=16,
        clip=0.5,
        fold=4,
    ):

        size_rates = [0.75, 1, 1.25]
        rate = 1
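        # note: size_rates is declared but only rate == 1 is used below, i.e.
        # multi-scale resizing is effectively disabled in this snippet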

        test_fold = f"fold{fold}"
        start = timeit.default_timer()
        for epoch in range(start_from, num_epochs):

            self.net.train()
            loss_all, loss_record2, loss_record3, loss_record4, loss_record5 = (
                AvgMeter(),
                AvgMeter(),
                AvgMeter(),
                AvgMeter(),
                AvgMeter(),
            )
            for i, pack in enumerate(train_loader, start=1):

                self.optimizer.zero_grad()

                # ---- data prepare ----
                images, gts = pack
                # images, gts, paths, oriimgs = pack

                images = Variable(images).cuda()
                gts = Variable(gts).cuda()

                lateral_map_5 = self.net(images)
                loss5 = self.loss(lateral_map_5, gts)

                loss5.backward()
                clip_gradient(self.optimizer, clip)
                self.optimizer.step()

                if rate == 1:
                    loss_record5.update(loss5.data, batchsize)
                    self.writer.add_scalar(
                        "Loss5",
                        loss_record5.show(),
                        (epoch - 1) * len(train_loader) + i,
                    )

                total_step = len(train_loader)
                if i % 25 == 0 or i == total_step:
                    self.logger.info(
                        "{} Epoch [{:03d}/{:03d}], with lr = {}, Step [{:04d}/{:04d}],\
                        [loss_record5: {:.4f}]".format(
                            datetime.now(),
                            epoch,
                            epoch,
                            self.optimizer.param_groups[0]["lr"],
                            i,
                            total_step,
                            loss_record5.show(),
                        ))

            if is_val:
                self.val(test_loader, epoch)

            os.makedirs(self.save_dir, exist_ok=True)
            if (epoch + 1) % 3 == 0 and epoch > self.save_from or epoch == 23:
                torch.save(
                    {
                        "model_state_dict": self.net.state_dict(),
                        "lr": self.optimizer.param_groups[0]["lr"],
                    },
                    os.path.join(self.save_dir,
                                 "PraNetDG-" + test_fold + "-%d.pth" % epoch),
                )
                self.logger.info(
                    "[Saving Snapshot:]" +
                    os.path.join(self.save_dir, "PraNetDG-" + test_fold +
                                 "-%d.pth" % epoch))

            self.scheduler.step()

        self.writer.flush()
        self.writer.close()
        end = timeit.default_timer()

        self.logger.info("Training cost: " + str(end - start) + "seconds")
Example #6
0
class Network(object):
    def __init__(self,
                 model_fn,
                 flags,
                 train_loader,
                 test_loader,
                 ckpt_dir=os.path.join(os.path.abspath(''), 'models'),
                 inference_mode=False,
                 saved_model=None):
        self.model_fn = model_fn  # The model maker function
        self.flags = flags  # The Flags containing the specs
        if inference_mode:  # If inference mode, use saved model
            if saved_model.startswith('models/'):
                saved_model = saved_model.replace('models/', '')
            self.ckpt_dir = os.path.join(ckpt_dir, saved_model)
            self.saved_model = saved_model
            print("This is inference mode, the ckpt is", self.ckpt_dir)
        else:  # training mode, create a new ckpt folder
            if flags.model_name is None:  # use custom name if provided
                self.ckpt_dir = os.path.join(
                    ckpt_dir, time.strftime('%Y%m%d_%H%M%S', time.localtime()))
            else:
                self.ckpt_dir = os.path.join(ckpt_dir, flags.model_name)
        self.model = self.create_model()  # The model itself
        self.loss = self.make_loss()  # The loss function
        self.optm = None  # The optimizer: Initialized at train() due to GPU
        self.optm_eval = None  # The eval optimizer: Initialized at evaluation time due to GPU
        self.lr_scheduler = None  # The lr scheduler: Initialized at train() due to GPU
        self.train_loader = train_loader  # The train data loader
        self.test_loader = test_loader  # The test data loader
        self.log = SummaryWriter(
            self.ckpt_dir
        )  # Create a summary writer for keeping the summary to the tensor board
        if not os.path.isdir(self.ckpt_dir) and not inference_mode:
            os.mkdir(self.ckpt_dir)
        self.best_validation_loss = float('inf')  # Set the BVL to large number

    def make_optimizer_eval(self, geometry_eval, optimizer_type=None):
        """
        The function to make the optimizer during evaluation time.
        The difference from optm is that it has no regularization and only optimizes the geometry_eval tensor
        :return: the optimizer_eval
        """
        if optimizer_type is None:
            optimizer_type = self.flags.optim
        if optimizer_type == 'Adam':
            op = torch.optim.Adam([geometry_eval], lr=self.flags.lr)
        elif optimizer_type == 'RMSprop':
            op = torch.optim.RMSprop([geometry_eval], lr=self.flags.lr)
        elif optimizer_type == 'SGD':
            op = torch.optim.SGD([geometry_eval], lr=self.flags.lr)
        else:
            raise Exception(
                "Your Optimizer is neither Adam, RMSprop or SGD, please change in param or contact Ben"
            )
        return op

    def create_model(self):
        """
        Function to create the network module from provided model fn and flags
        :return: the created nn module
        """
        model = self.model_fn(self.flags)
        # summary(model, input_size=(128, 8))
        print(model)
        return model

    def make_loss(self,
                  logit=None,
                  labels=None,
                  G=None,
                  return_long=False,
                  epoch=None):
        """
        Create a tensor that represents the loss. This is consistent both at training time
        and at inference time for the backward model.
        :param logit: The output of the network
        :param labels: The ground truth labels
        :param G: The geometry tensor; if given, a boundary penalty (and optionally
                  a pairwise MD loss) is added
        :param return_long: The flag to return a long list of losses instead of a single loss value;
                            this is for the forward filtering part to consider the loss
        :param epoch: The current epoch, used to decide whether the pairwise MD loss is active
        :return: the total loss
        """
        if logit is None:
            return None
        MSE_loss = nn.functional.mse_loss(logit, labels)  # The MSE Loss
        BDY_loss = 0
        MD_loss = 0
        if G is not None:  # This is using the boundary loss
            X_range, X_lower_bound, X_upper_bound = \
                self.get_boundary_lower_bound_uper_bound()
            X_mean = (X_lower_bound + X_upper_bound) / 2  # Get the mean
            relu = torch.nn.ReLU()
            BDY_loss_all = 1 * relu(
                torch.abs(G - self.build_tensor(X_mean)) -
                0.5 * self.build_tensor(X_range))
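            # relu(|G - mean| - range / 2) is zero for geometries inside the
            # allowed [lower, upper] box and grows linearly with the distance
            # outside it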
            BDY_loss = 0.1 * torch.sum(BDY_loss_all)
            #BDY_loss = self.flags.BDY_strength*torch.sum(BDY_loss_all)

        # Adding a pairwise MD loss for back propagation, it needs to be open as well as in the signified start and end epoch
        if self.flags.md_coeff > 0 and G is not None and epoch > self.flags.md_start and epoch < self.flags.md_end:
            pairwise_dist_mat = torch.cdist(
                G, G, p=2)  # Calculate the pairwise distance
            MD_loss = torch.mean(
                relu(-pairwise_dist_mat + self.flags.md_radius))
            MD_loss *= self.flags.md_coeff
            #print('MD_loss = ', MD_loss)
            #print('MSE loss = ', MSE_loss)

        self.MSE_loss = MSE_loss
        self.Boundary_loss = BDY_loss
        return torch.add(torch.add(MSE_loss, BDY_loss), MD_loss)

    def build_tensor(self, nparray, requires_grad=False):
        return torch.tensor(nparray,
                            requires_grad=requires_grad,
                            device='cuda',
                            dtype=torch.float)

    def make_optimizer(self, optimizer_type=None):
        """
        Make the corresponding optimizer from the flags. Only the optimizers below are allowed; feel free to add more.
        :return:
        """
        # For eval mode to change to other optimizers
        if optimizer_type is None:
            optimizer_type = self.flags.optim
        if optimizer_type == 'Adam':
            op = torch.optim.Adam(self.model.parameters(),
                                  lr=self.flags.lr,
                                  weight_decay=self.flags.reg_scale)
        elif optimizer_type == 'RMSprop':
            op = torch.optim.RMSprop(self.model.parameters(),
                                     lr=self.flags.lr,
                                     weight_decay=self.flags.reg_scale)
        elif optimizer_type == 'SGD':
            op = torch.optim.SGD(self.model.parameters(),
                                 lr=self.flags.lr,
                                 weight_decay=self.flags.reg_scale)
        else:
            raise Exception(
                "Your Optimizer is neither Adam, RMSprop or SGD, please change in param or contact Ben"
            )
        return op

    def make_lr_scheduler(self, optm):
        """
        Make the learning rate scheduler as instructed. More modes can be added to this, current supported ones:
        1. ReduceLROnPlateau (decrease lr when validation error stops improving)
        :return:
        """
        return lr_scheduler.ReduceLROnPlateau(optimizer=optm,
                                              mode='min',
                                              factor=self.flags.lr_decay_rate,
                                              patience=10,
                                              verbose=True,
                                              threshold=1e-4)

    def save(self):
        """
        Saving the model state dict to the current checkpoint folder with name best_model.pt
        :return: None
        """
        #torch.save(self.model, os.path.join(self.ckpt_dir, 'best_model_forward.pt'))
        torch.save(self.model.state_dict(),
                   os.path.join(self.ckpt_dir, 'best_model.pt'))

    def load(self):
        """
        Loading the model state dict from the checkpoint folder file best_model.pt
        :return:
        """
        if torch.cuda.is_available():
            #self.model = torch.load(os.path.join(self.ckpt_dir, 'best_model_forward.pt'))
            self.model.load_state_dict(
                torch.load(os.path.join(self.ckpt_dir, 'best_model.pt')))
        else:
            #self.model = torch.load(os.path.join(self.ckpt_dir, 'best_model_forward.pt'), map_location=torch.device('cpu'))
            self.model.load_state_dict(
                torch.load(os.path.join(self.ckpt_dir, 'best_model.pt'),
                           map_location=torch.device('cpu')))

    def train(self):
        """
        The major training function. This would start the training using information given in the flags
        :return: None
        """

        pytorch_total_params = sum(p.numel() for p in self.model.parameters()
                                   if p.requires_grad)
        print("Total Number of Parameters: {}".format(pytorch_total_params))

        cuda = True if torch.cuda.is_available() else False
        if cuda:
            self.model.cuda()

        # Construct optimizer after the model moved to GPU
        self.optm = self.make_optimizer()
        self.lr_scheduler = self.make_lr_scheduler(self.optm)

        # Time keeping
        tk = time_keeper(
            time_keeping_file=os.path.join(self.ckpt_dir, 'training time.txt'))

        a = self.flags.train_step
        for epoch in range(self.flags.train_step):
            # Set to Training Mode
            train_loss = 0
            # boundary_loss = 0                 # Unnecessary during training since we provide geometries
            self.model.train()
            for j, (geometry, spectra) in enumerate(self.train_loader):
                if cuda:
                    geometry = geometry.cuda()  # Put data onto GPU
                    spectra = spectra.cuda()  # Put data onto GPU
                self.optm.zero_grad()  # Zero the gradient first
                logit = self.model(geometry)  # Get the output
                loss = self.make_loss(logit, spectra)  # Get the loss tensor
                loss.backward()  # Calculate the backward gradients
                self.optm.step()  # Move one step the optimizer
                train_loss += loss.detach()  # Aggregate the loss (detached so the graph is not kept across batches)

            # Calculate the avg loss of training
            train_avg_loss = train_loss.cpu().data.numpy() / (j + 1)

            if epoch % self.flags.eval_step == 0:  # For eval steps, do the evaluations and tensor board
                # Record the training loss to the tensorboard
                self.log.add_scalar('Loss/train', train_avg_loss, epoch)
                # self.log.add_scalar('Loss/BDY_train', boundary_avg_loss, epoch)

                # Set to Evaluation Mode
                self.model.eval()
                #print("Doing Evaluation on the model now")
                test_loss = 0
                for j, (geometry, spectra) in enumerate(
                        self.test_loader):  # Loop through the eval set
                    if cuda:
                        geometry = geometry.cuda()
                        spectra = spectra.cuda()
                    logit = self.model(geometry)
                    loss = self.make_loss(logit, spectra)  # compute the loss
                    test_loss += loss.detach()  # Aggregate the loss (detached; no gradient needed for evaluation)

                # Record the testing loss to the tensorboard
                test_avg_loss = test_loss.cpu().data.numpy() / (j + 1)
                self.log.add_scalar('Loss/test', test_avg_loss, epoch)

                print("This is Epoch %d, training loss %.5f, validation loss %.5f" \
                      % (epoch, train_avg_loss, test_avg_loss ))

                # Model improving, save the model down
                if test_avg_loss < self.best_validation_loss:
                    self.best_validation_loss = test_avg_loss
                    self.save()
                    print("Saving the model down...")

                    if self.best_validation_loss < self.flags.stop_threshold:
                        print("Training finished EARLIER at epoch %d, reaching loss of %.5f" %\
                              (epoch, self.best_validation_loss))
                        break

            # Learning rate decay upon plateau
            self.lr_scheduler.step(train_avg_loss)
        self.log.close()
        tk.record(1)  # Record at the end of the training

    def validate_model(self,
                       save_dir='data/',
                       save_all=False,
                       MSE_Simulator=False,
                       save_misc=False,
                       save_Simulator_Ypred=True):
        """
        The function to evaluate how good the model is (outputs the validation loss).
        Note that Ypred and Ytruth still refer to spectra, while Xpred and Xtruth still refer to geometries.
        #Assumes testloader was modified to be one big tensor
        :return:
        """

        self.load()  # load the model as constructed

        cuda = True if torch.cuda.is_available() else False
        if cuda:
            self.model.cuda()
        self.model.eval()
        saved_model_str = self.saved_model.replace('/', '_')
        # Get the file names
        Ypred_file = os.path.join(
            save_dir, 'test_Ypred_{}.csv'.format(saved_model_str))  # Predicted spectra (model output)
        Xtruth_file = os.path.join(
            save_dir, 'test_Xtruth_{}.csv'.format(saved_model_str))  # Input geometries
        Ytruth_file = os.path.join(
            save_dir, 'test_Ytruth_{}.csv'.format(saved_model_str))  # Ground-truth spectra to compare against
        Xpred_file = os.path.join(
            save_dir, 'test_Xpred_{}.csv'.format(saved_model_str))  # Not written here; the geometries are the inputs
        print("evaluation output pattern:", Ypred_file)

        # Time keeping
        tk = time_keeper(
            time_keeping_file=os.path.join(save_dir, 'evaluation_time.txt'))

        # Open those files for writing
        with open(Xtruth_file,
                  'w') as fxt, open(Ytruth_file,
                                    'w') as fyt, open(Ypred_file, 'w') as fyp:

            # Loop through the eval data and evaluate
            geometry, spectra = next(iter(self.test_loader))

            if cuda:
                geometry = geometry.cuda()
                spectra = spectra.cuda()

            # Initialize the geometry first
            Ypred = self.model(geometry).cpu().data.numpy()
            Ytruth = spectra.cpu().data.numpy()

            MSE_List = np.mean(np.power(Ypred - Ytruth, 2), axis=1)
            mse = np.mean(MSE_List)
            print(mse)

            np.savetxt(fxt, geometry.cpu().data.numpy())
            np.savetxt(fyt, Ytruth)
            if self.flags.data_set != 'Yang':
                np.savetxt(fyp, Ypred)

        return Ypred_file, Ytruth_file

    def modulized_bp_ff(self,
                        X_init_mat,
                        Ytruth,
                        FF,
                        save_dir='data/',
                        save_all=True):
        """
        The "evaluation" function for the modulized backprop and forward filtering. It takes the X_init_mat as the different initializations of the X values and do evaluate function on that instead of taking evaluation data from the data loader
        :param X_init_mat: The input initialization of X positions, numpy array of shape (#init, #point, #xdim) usually (2048, 1000, xdim)
        :param Yturth: The Ytruth numpy array of shape (#point, #ydim)
        :param save_dir: The directory to save the results
        :param FF(forward_filtering): The flag to control whether use forward filtering or not
        """
        self.load()  # load the model as constructed
        try:
            bs = self.flags.backprop_step  # for previous flag files that did not include this field
        except AttributeError:
            print("No 'backprop_step' attribute found; caught the error and defaulting to 300")
            self.flags.backprop_step = 300
        cuda = True if torch.cuda.is_available() else False
        if cuda:
            self.model.cuda()
        self.model.eval()
        saved_model_str = self.saved_model.replace('/', '_')

        # Prepare Ytruth into tensor
        Yt = self.build_tensor(Ytruth, requires_grad=False)
        print("shape of Yt in modulized bp ff is:", Yt.size())
        print("shape of the X_init_mat is:", np.shape(X_init_mat))
        # Loop through #points
        for ind in range(np.shape(X_init_mat)[1]):
            Xpred, Ypred, loss = self.evaluate_one(
                Yt[ind, :],
                save_dir=save_dir,
                save_all=save_all,
                ind=ind,
                init_from_Xpred=X_init_mat[:, ind, :],
                FF=FF)
        return None

    def evaluate(self,
                 save_dir='data/',
                 save_all=False,
                 MSE_Simulator=False,
                 save_misc=False,
                 save_Simulator_Ypred=True,
                 noise_level=0):
        """
        The function to evaluate how good the Neural Adjoint is and output results
        :param save_dir: The directory to save the results
        :param save_all: Save all the results instead of the best one (T_200 is the top 200 ones)
        :param MSE_Simulator: Use simulator loss to sort (DO NOT ENABLE THIS, THIS IS OK ONLY IF YOUR APPLICATION IS FAST VERIFYING)
        :param save_misc: save all the details that are probably useless
        :param save_Simulator_Ypred: Save the Ypred that the Simulator gives
        (This is useful as it gives us the true Ypred instead of the Ypred that the network "thinks" it gets, which is
        usually inaccurate due to forward model error)
        :return:
        """
        self.load()  # load the model as constructed
        try:
            bs = self.flags.backprop_step  # for previous flag files that did not include this field
        except AttributeError:
            print("No 'backprop_step' attribute found; caught the error and defaulting to 300")
            self.flags.backprop_step = 300
        cuda = True if torch.cuda.is_available() else False
        if cuda:
            self.model.cuda()
        self.model.eval()
        saved_model_str = self.saved_model.replace('/', '_')
        # Get the file names
        Ypred_file = os.path.join(save_dir,
                                  'test_Ypred_{}.csv'.format(saved_model_str))
        Xtruth_file = os.path.join(
            save_dir, 'test_Xtruth_{}.csv'.format(saved_model_str))
        Ytruth_file = os.path.join(
            save_dir, 'test_Ytruth_{}.csv'.format(saved_model_str))
        Xpred_file = os.path.join(save_dir,
                                  'test_Xpred_{}.csv'.format(saved_model_str))
        print("evalution output pattern:", Ypred_file)

        # Time keeping
        tk = time_keeper(
            time_keeping_file=os.path.join(save_dir, 'evaluation_time.txt'))

        # Open those files to append
        with open(Xtruth_file, 'a') as fxt,open(Ytruth_file, 'a') as fyt,\
                open(Ypred_file, 'a') as fyp, open(Xpred_file, 'a') as fxp:
            # Loop through the eval data and evaluate
            for ind, (geometry, spectra) in enumerate(self.test_loader):

                if cuda:
                    geometry = geometry.cuda()
                    spectra = spectra.cuda()
                # Initialize the geometry first
                Xpred, Ypred, loss = self.evaluate_one(
                    spectra,
                    save_dir=save_dir,
                    save_all=save_all,
                    ind=ind,
                    MSE_Simulator=MSE_Simulator,
                    save_misc=save_misc,
                    save_Simulator_Ypred=save_Simulator_Ypred,
                    noise_level=noise_level)
                tk.record(
                    ind)  # Keep the time after each evaluation for backprop
                # self.plot_histogram(loss, ind)                                # Debugging purposes
                np.savetxt(fxt, geometry.cpu().data.numpy())
                np.savetxt(fyt, spectra.cpu().data.numpy())
                if 'Yang' not in self.flags.data_set:
                    np.savetxt(fyp, Ypred)
                np.savetxt(fxp, Xpred)

        return Ypred_file, Ytruth_file

    def evaluate_one(self,
                     target_spectra,
                     save_dir='data/',
                     MSE_Simulator=False,
                     save_all=False,
                     ind=None,
                     save_misc=False,
                     save_Simulator_Ypred=True,
                     init_from_Xpred=None,
                     FF=True,
                     save_MSE_each_epoch=False,
                     noise_level=0):
        """
        The function called during evaluation; it evaluates one target y using # different trials (candidates).
        :param target_spectra: The target spectra/y to backprop to
        :param save_dir: The directory to save to when the save_all flag is true
        :param MSE_Simulator: Use the simulator loss to pick the best candidate instead of the default NN output logit
        :param save_all: The multi-evaluation mode where every trial is saved (instead of only the best) during backpropagation
        :param ind: The index of this target_spectra in the batch
        :param save_misc: Flag to print miscellaneous information for debugging purposes, usually written to best_mse
        :param noise_level: For datasets that need an extra level of exploration, add Gaussian noise of this level to the resulting geometry
        :param FF (forward_filtering): [defaults to True for historical reasons] The flag controlling whether forward filtering is used
        :return: Xpred_best: The single best Xpred corresponding to the best Ypred reached by backprop
        :return: Ypred_best: The single best Ypred reached by backprop
        :return: MSE_list: The list of MSEs at the last stage
        """

        # Initialize the geometry_eval or the initial guess xs
        geometry_eval = self.initialize_geometry_eval(init_from_Xpred)
        # Set up the learning schedule and optimizer
        self.optm_eval = self.make_optimizer_eval(geometry_eval)
        self.lr_scheduler = self.make_lr_scheduler(self.optm_eval)

        # expand the target spectra to eval batch size
        target_spectra_expand = target_spectra.expand(
            [self.flags.eval_batch_size, -1])

        # # Extra for early stopping
        loss_list = []
        # end_lr = self.flags.lr / 8
        # print(self.optm_eval)
        # param_group_1 = self.optm_eval.param_groups[0]
        # if self.flags.data_set == 'Chen':
        #     stop_threshold = 1e-4
        # elif self.flags.data_set == 'Peurifoy':
        #     stop_threshold = 1e-3
        # else:
        #     stop_threshold = 1e-3

        # Begin NA
        begin = time.time()
        for i in range(self.flags.backprop_step):
            # Map the uniform [0, 1] initialization onto the dataset range (here [-1, 1]); this must stay inside the loop so the transform is part of the gradient graph
            if init_from_Xpred is None:
                geometry_eval_input = self.initialize_from_uniform_to_dataset_distrib(
                    geometry_eval)
            else:
                geometry_eval_input = geometry_eval
            #if save_misc and ind == 0 and i == 0:                       # save the modified initial guess to verify distribution
            #    np.savetxt('geometry_initialization.csv',geometry_eval_input.cpu().data.numpy())
            self.optm_eval.zero_grad()  # Zero the gradient first
            logit = self.model(geometry_eval_input)  # Get the output
            #####################################################
            # Boundary loss controlled here: with Boundary Loss #
            #####################################################
            loss = self.make_loss(logit,
                                  target_spectra_expand,
                                  G=geometry_eval_input,
                                  epoch=i)  # Get the loss
            loss.backward()  # Calculate the Gradient
            # update weights and learning rate scheduler
            self.optm_eval.step()  # Move one step the optimizer
            loss_np = loss.data
            self.lr_scheduler.step(loss_np)
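            # Note that the ReduceLROnPlateau scheduler is stepped on every backprop iteration here
            # (not once per epoch), so its patience is counted in iterations of this inner loop.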
            # Extra step of recording the MSE loss of each epoch
            if save_MSE_each_epoch:
                loss_list.append(np.copy(loss_np.cpu()))
            # Comment the below 2 for maximum performance
            #if loss_np < stop_threshold or param_group_1['lr'] < end_lr:
            #    break;
        if save_MSE_each_epoch:
            with open(
                    'data/{}_MSE_progress_point_{}.txt'.format(
                        self.flags.data_set, ind), 'a') as epoch_file:
                np.savetxt(epoch_file, loss_list)

        if save_all:  # If saving all the candidate results instead of only the best one
            mse_loss = np.reshape(
                np.sum(np.square(logit.cpu().data.numpy() -
                                 target_spectra_expand.cpu().data.numpy()),
                       axis=1), [-1, 1])
            BDY_loss = self.get_boundary_loss_list_np(
                geometry_eval_input.cpu().data.numpy())
            BDY_strength = 0.5
            mse_loss += BDY_strength * np.reshape(BDY_loss, [-1, 1])
            # The strategy of re-using the BPed result. Save two versions of file: one with FF and one without
            mse_loss = np.concatenate(
                (mse_loss,
                 np.reshape(np.arange(self.flags.eval_batch_size), [-1, 1])),
                axis=1)
            loss_sort = mse_loss[mse_loss[:, 0].argsort(
                kind='mergesort')]  # Sort the loss list
            loss_sort_FF_off = mse_loss
            exclude_top = 0
            trail_nums = 200
            good_index = loss_sort[exclude_top:trail_nums + exclude_top,
                                   1].astype('int')  # Get the indices
            good_index_FF_off = loss_sort_FF_off[exclude_top:trail_nums +
                                                 exclude_top, 1].astype(
                                                     'int')  # Get the indices
            #print("In the save_all function, the top 10 indices are:", good_index[:10])
            if init_from_Xpred is None:
                saved_model_str = self.saved_model.replace(
                    '/', '_') + 'inference' + str(ind)
            else:
                saved_model_str = self.saved_model.replace(
                    '/', '_') + 'modulized_inference' + str(ind)
            # Adding some random noise to the result
            #print("Adding random noise to the output for increasing the diversity!!")
            geometry_eval_input += torch.randn_like(
                geometry_eval_input) * noise_level

            Ypred_file = os.path.join(
                save_dir, 'test_Ypred_point{}.csv'.format(saved_model_str))
            Yfake_file = os.path.join(
                save_dir, 'test_Yfake_point{}.csv'.format(saved_model_str))
            Xpred_file = os.path.join(
                save_dir, 'test_Xpred_point{}.csv'.format(saved_model_str))
            if 'Yang' not in self.flags.data_set:  # The Yang (meta-material) dataset has no simple simulator, so only the other datasets go through the simulator here
                # 2 options: simulator/logit
                Ypred = simulator(
                    self.flags.data_set,
                    geometry_eval_input.cpu().data.numpy()[good_index, :])
                with open(Xpred_file, 'a') as fxp, open(Ypred_file,
                                                        'a') as fyp:
                    np.savetxt(fyp, Ypred)
                    np.savetxt(
                        fxp,
                        geometry_eval_input.cpu().data.numpy()[good_index, :])
            else:
                with open(Xpred_file, 'a') as fxp:
                    np.savetxt(
                        fxp,
                        geometry_eval_input.cpu().data.numpy()[good_index, :])

        ###################################
        # From candidates choose the best #
        ###################################
        Ypred = logit.cpu().data.numpy()
        # calculate the MSE list and get the best one
        MSE_list = np.mean(np.square(Ypred -
                                     target_spectra_expand.cpu().data.numpy()),
                           axis=1)
        BDY_list = self.get_boundary_loss_list_np(
            geometry_eval_input.cpu().data.numpy())
        MSE_list += BDY_list
        best_estimate_index = np.argmin(MSE_list)
        #print("The best performing one is:", best_estimate_index)
        Xpred_best = np.reshape(
            np.copy(geometry_eval_input.cpu().data.numpy()[
                best_estimate_index, :]), [1, -1])
        if save_Simulator_Ypred and self.flags.data_set != 'Yang':
            begin = time.time()
            Ypred = simulator(self.flags.data_set,
                              geometry_eval_input.cpu().data.numpy())
            #print("SIMULATOR: ",time.time()-begin)
            if len(np.shape(Ypred)) == 1:  # If this is the ballistics dataset, which only has a 1-d y
                Ypred = np.reshape(Ypred, [-1, 1])
        Ypred_best = np.reshape(np.copy(Ypred[best_estimate_index, :]),
                                [1, -1])

        return Xpred_best, Ypred_best, MSE_list

    def get_boundary_loss_list_np(self, Xpred):
        """
        Return the boundary loss in the form of numpy array
        :param Xpred: input numpy array of prediction
        """
        X_range, X_lower_bound, X_upper_bound = self.get_boundary_lower_bound_uper_bound(
        )
        X_mean = (X_lower_bound + X_upper_bound) / 2  # Get the mean
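        # Hinge penalty: |x - mean| - 0.5 * range is negative inside [lower, upper] and positive outside,
        # so after the maximum with 0 only out-of-range coordinates contribute to the boundary loss.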
        BDY_loss = np.mean(np.maximum(0,
                                      np.abs(Xpred - X_mean) - 0.5 * X_range),
                           axis=1)
        return BDY_loss

    def initialize_geometry_eval(self, init_from_Xpred):
        """
        Initialize geometry_eval according to the dataset; the two cases need different handling.
        :param init_from_Xpred: Initialize from an Xpred file; this is used for the modulized trials
        :return: The initialized geometry_eval

        """
        if init_from_Xpred is not None:
            geometry_eval = self.build_tensor(init_from_Xpred,
                                              requires_grad=True)
        else:
            geometry_eval = torch.rand(
                [self.flags.eval_batch_size, self.flags.linear[0]],
                requires_grad=True,
                device='cuda')
        #geomtry_eval = torch.randn([self.flags.eval_batch_size, self.flags.linear[0]], requires_grad=True, device='cuda')
        return geometry_eval

    def initialize_from_uniform_to_dataset_distrib(self, geometry_eval):
        """
        Since the initialization of the backprop is uniform over [0, 1], this function transforms that distribution
        into a suitable prior distribution for each dataset. The numbers are acquired from the min/max statistics
        of the X prior given in the training set and the data generation process.
        :param geometry_eval: The input uniform distribution from [0,1]
        :return: The transformed initial guess from prior distribution
        """
        X_range, X_lower_bound, X_upper_bound = self.get_boundary_lower_bound_uper_bound(
        )
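        # Affine map: u in [0, 1]  ->  u * range + lower_bound, i.e. uniform over [lower_bound, upper_bound]
        # (for the normalized datasets below this is simply [-1, 1] in every dimension).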
        geometry_eval_input = geometry_eval * self.build_tensor(
            X_range) + self.build_tensor(X_lower_bound)
        return geometry_eval_input
        #return geometry_eval

    def get_boundary_lower_bound_uper_bound(self):
        """
        Because each batched dataset is a random subset of the training set, the mean and range would fluctuate.
        Therefore we pre-calculate the mean, lower bound and upper bound to avoid that fluctuation. Replace the
        mean and bounds of your dataset here.
        :return:
        """
        if self.flags.data_set == 'Chen':
            dim = 5
        elif self.flags.data_set == 'Peurifoy':
            dim = 8
        elif self.flags.data_set == 'Yang_sim':
            dim = 14
        else:
            sys.exit(
                "While getting the boundary-loss bounds: your data_set entry is not recognized, please check it!"
            )

        return (np.array([2 for i in range(dim)]),
                np.array([-1 for i in range(dim)]),
                np.array([1 for i in range(dim)]))  # X_range, X_lower_bound, X_upper_bound

    def predict(self, Xpred_file, no_save=False, load_state_dict=None):
        """
        The prediction function, takes Xpred file and write Ypred file using trained model
        :param Xpred_file: Xpred file by (usually VAE) for meta-material
        :param no_save: do not save the txt file but return the np array
        :param load_state_dict: If None, load model using self.load() (default way), If a dir, load state_dict from that dir
        :return: pred_file, truth_file to compare
        """
        print("entering predict function")
        if load_state_dict is None:
            self.load()  # load the model in the usual way
        else:
            self.model.load_state_dict(torch.load(load_state_dict))

        Ypred_file = Xpred_file.replace('Xpred', 'Ypred')
        Ytruth_file = Ypred_file.replace('Ypred', 'Ytruth')
        Xpred = pd.read_csv(Xpred_file, header=None,
                            delimiter=',')  # Read the input
        if len(Xpred.columns) == 1:  # The file is delimited by ' ' rather than ','
            Xpred = pd.read_csv(Xpred_file, header=None, delimiter=' ')
        Xpred.info()
        print(Xpred.head())
        print("Xpred shape", np.shape(Xpred.values))
        Xpred_tensor = torch.from_numpy(Xpred.values).to(torch.float)
        cuda = True if torch.cuda.is_available() else False
        # Put into evaluation mode
        self.model.eval()
        if cuda:
            self.model.cuda()
        # Get small chunks for the evaluation
        chunk_size = 1000
        Ypred_mat = np.zeros([len(Xpred_tensor), 2000])  # NOTE: hard-coded output dimension (2000), presumably the spectra length of the meta-material dataset
        for i in range(int(np.ceil(len(Xpred_tensor) / chunk_size))):  # ceil so the last partial chunk is also processed
            Xpred = Xpred_tensor[i * chunk_size:(i + 1) * chunk_size, :]
            if cuda:
                Xpred = Xpred.cuda()
            Ypred = self.model(Xpred).cpu().data.numpy()
            Ypred_mat[i * chunk_size:(i + 1) * chunk_size, :] = Ypred
        if load_state_dict is not None:
            Ypred_file = Ypred_file.replace('Ypred',
                                            'Ypred' + load_state_dict[-7:-4])
        elif self.flags.model_name is not None:
            Ypred_file = Ypred_file.replace('Ypred',
                                            'Ypred' + self.flags.model_name)
        if no_save:  # If instructed, don't save the file; just return the array
            return Ypred_mat, Ytruth_file
        np.savetxt(Ypred_file, Ypred_mat)

        return Ypred_file, Ytruth_file

    def plot_histogram(self, loss, ind):
        """
        Plot the loss histogram to see the loss distribution
        """
        f = plt.figure()
        plt.hist(loss, bins=100)
        plt.xlabel('MSE loss')
        plt.ylabel('cnt')
        plt.suptitle('(Avg MSE={:4e})'.format(np.mean(loss)))
        plt.savefig(os.path.join('data', 'loss{}.png'.format(ind)))
        return None

    def predict_inverse(self,
                        Ytruth_file,
                        multi_flag,
                        save_dir='data/',
                        prefix=''):
        self.load()  # load the model as constructed
        cuda = True if torch.cuda.is_available() else False
        if cuda:
            self.model.cuda()
        self.model.eval()
        saved_model_str = self.saved_model.replace('/', '_') + prefix

        Ytruth = pd.read_csv(Ytruth_file, header=None,
                             delimiter=',')  # Read the input
        if len(Ytruth.columns) == 1:  # The file is delimited by ' ' rather than ','
            Ytruth = pd.read_csv(Ytruth_file, header=None, delimiter=' ')
        Ytruth_tensor = torch.from_numpy(Ytruth.values).to(torch.float)
        print('shape of Ytruth tensor :', Ytruth_tensor.shape)

        # Get the file names
        Ypred_file = os.path.join(save_dir,
                                  'test_Ypred_{}.csv'.format(saved_model_str))
        Ytruth_file = os.path.join(
            save_dir, 'test_Ytruth_{}.csv'.format(saved_model_str))
        Xpred_file = os.path.join(save_dir,
                                  'test_Xpred_{}.csv'.format(saved_model_str))
        # keep time
        tk = time_keeper(os.path.join(save_dir, 'evaluation_time.txt'))

        # Set the save_simulator_ytruth
        save_Simulator_Ypred = True
        if 'Yang' in self.flags.data_set:
            save_Simulator_Ypred = False

        if cuda:
            Ytruth_tensor = Ytruth_tensor.cuda()
        print('model in eval:', self.model)

        # Open those files to append
        with open(Ytruth_file,
                  'a') as fyt, open(Ypred_file,
                                    'a') as fyp, open(Xpred_file, 'a') as fxp:
            np.savetxt(fyt, Ytruth_tensor.cpu().data.numpy())
            for ind in range(len(Ytruth_tensor)):
                spectra = Ytruth_tensor[ind, :]
                Xpred, Ypred, loss = self.evaluate_one(
                    spectra,
                    save_dir=save_dir,
                    save_all=multi_flag,
                    ind=ind,
                    MSE_Simulator=False,
                    save_misc=False,
                    save_Simulator_Ypred=save_Simulator_Ypred)

                np.savetxt(fxp, Xpred)
                if self.flags.data_set != 'Yang_sim':
                    Ypred = simulator(self.flags.data_set, Xpred)
                    np.savetxt(fyp, Ypred)
                tk.record(1)
        return Ypred_file, Ytruth_file
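
# A minimal usage sketch for this trainer (hypothetical names: the enclosing class, its constructor
# signature and the flags/loader objects are assumptions and are not shown in this snippet):
#
#     net = ForwardTrainer(flags, train_loader, test_loader)   # hypothetical constructor
#     net.train()                                              # forward-model training loop above
#     net.evaluate(save_dir='data/', save_all=False)           # Neural Adjoint backprop evaluation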
Example #7
0
def train():
    """
    Main SNLI training loop.
    """
    cfg = get_args()

    # overwrite save_path or warn to specify another path
    if os.path.exists(cfg.save_path):
        if cfg.overwrite:
            shutil.rmtree(cfg.save_path)
        else:
            raise RuntimeError(
                "save_path already exists; specify a different path")

    makedirs(cfg.save_path)

    device = get_device()
    print("device:", device)

    writer = SummaryWriter(log_dir=cfg.save_path)  # TensorBoard

    print("Loading data... ", end="")
    glove_words = load_glove_words(cfg.word_vectors)
    input_field, label_field, not_in_glove = get_data_fields(glove_words)
    train_data, dev_data, test_data = SNLI.splits(input_field, label_field)
    print("Done")

    print("First train sentence:",
          "[prem]: " + " ".join(train_data[0].premise),
          "[hypo]: " + " ".join(train_data[0].hypothesis),
          "[lab]:  " + train_data[0].label,
          sep="\n",
          end="\n\n")

    # build vocabularies
    std = 1.
    input_field.build_vocab(train_data,
                            dev_data,
                            test_data,
                            unk_init=lambda x: x.normal_(mean=0, std=std),
                            vectors=cfg.word_vectors,
                            vectors_cache=None)
    label_field.build_vocab(train_data)

    print("Words not in glove:", len(not_in_glove))

    cfg.n_embed = len(input_field.vocab)
    cfg.output_size = len(label_field.vocab)
    cfg.n_cells = cfg.n_layers
    cfg.pad_idx = input_field.vocab.stoi[PAD_TOKEN]
    cfg.unk_idx = input_field.vocab.stoi[UNK_TOKEN]
    cfg.init_idx = input_field.vocab.stoi[INIT_TOKEN]

    # normalize word embeddings (each word embedding has L2 norm of 1.)
    if cfg.normalize_embeddings:
        with torch.no_grad():
            input_field.vocab.vectors /= input_field.vocab.vectors.norm(
                2, dim=-1, keepdim=True)

    # zero out padding
    with torch.no_grad():
        input_field.vocab.vectors[cfg.pad_idx].zero_()

    # save vocabulary (not really needed but could be useful)
    with open(os.path.join(cfg.save_path, "vocab.txt"),
              mode="w",
              encoding="utf-8") as f:
        for t in input_field.vocab.itos:
            f.write(t + "\n")

    train_iter, dev_iter, test_iter = data.BucketIterator.splits(
        (train_data, dev_data, test_data),
        batch_size=cfg.batch_size,
        device=device)

    print_config(cfg)

    # double the number of cells for bidirectional networks
    if cfg.birnn:
        cfg.n_cells *= 2

    if cfg.resume_snapshot:
        ckpt = torch.load(cfg.resume_snapshot, map_location=device)
        cfg = ckpt["cfg"]
        model_state = ckpt["model"]

    # build model
    model = build_model(cfg, input_field.vocab)

    if cfg.resume_snapshot:
        model.load_state_dict(model_state)

    # load Glove word vectors
    if cfg.word_vectors:
        with torch.no_grad():
            model.embed.weight.data.copy_(input_field.vocab.vectors)

    model.to(device)

    print_parameters(model)
    print(model)

    trainable_parameters = list(
        filter(lambda p: p.requires_grad, model.parameters()))
    opt = Adam(trainable_parameters, lr=cfg.lr, weight_decay=cfg.weight_decay)

    scheduler = ReduceLROnPlateau(opt,
                                  "max",
                                  patience=cfg.patience,
                                  factor=cfg.lr_decay,
                                  min_lr=cfg.min_lr,
                                  verbose=True)

    if cfg.eval_every == -1:
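        # -1 means "evaluate once per epoch": ceil(#train examples / batch size) iterations per epoch.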
        cfg.eval_every = int(np.ceil(len(train_data) / cfg.batch_size))
        print("Eval every: %d" % cfg.eval_every)

    iterations = 0
    start = time.time()
    best_dev_acc = -1
    train_iter.repeat = False

    for epoch in range(cfg.epochs):
        train_iter.init_epoch()
        n_correct, n_total = 0, 0
        for batch_idx, batch in enumerate(train_iter):

            # switch model to training mode, clear gradient accumulators
            model.train()
            opt.zero_grad()

            iterations += 1

            # forward pass
            output = model(batch)

            # calculate accuracy of predictions in the current batch
            n_correct += get_n_correct(batch, output)
            n_total += batch.batch_size
            train_acc = 100. * n_correct / n_total

            # calculate loss of the network output with respect to train labels
            loss, optional = model.get_loss(output, batch.label)

            # backpropagate and update optimizer learning rate
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           cfg.max_grad_norm)
            opt.step()

            # checkpoint model periodically
            if iterations % cfg.save_every == 0:
                ckpt = {
                    "model": model.state_dict(),
                    "cfg": cfg,
                    "iterations": iterations,
                    "epoch": epoch,
                    "best_dev_acc": best_dev_acc,
                    "optimizer": opt.state_dict()
                }
                save_checkpoint(ckpt,
                                cfg.save_path,
                                iterations,
                                delete_old=True)

            # print progress message
            if iterations % cfg.print_every == 0:
                writer.add_scalar('train/loss', loss.item(), iterations)
                writer.add_scalar('train/acc', train_acc, iterations)
                for k, v in optional.items():
                    writer.add_scalar('train/' + k, v, iterations)

                opt_s = make_kv_string(optional)
                elapsed = int(time.time() - start)
                print("{:02d}:{:02d}:{:02d} epoch {:03d} "
                      "iter {:08d} loss {:.4f} {}".format(
                          elapsed // 3600, elapsed % 3600 // 60, elapsed % 60,
                          epoch, iterations, loss.item(), opt_s))

            # evaluate performance on validation set periodically
            if iterations % cfg.eval_every == 0:

                # switch model to evaluation mode
                model.eval()
                dev_iter.init_epoch()
                test_iter.init_epoch()

                # calculate accuracy on validation set
                dev_eval = evaluate(model, model.criterion, dev_iter)
                for k, v in dev_eval.items():
                    writer.add_scalar('dev/%s' % k, v, iterations)

                dev_eval_str = make_kv_string(dev_eval)
                print("# Evaluation dev : epoch {:2d} iter {:08d} {}".format(
                    epoch, iterations, dev_eval_str))

                # calculate accuracy on test set
                test_eval = evaluate(model, model.criterion, test_iter)
                for k, v in test_eval.items():
                    writer.add_scalar('test/%s' % k, v, iterations)

                test_eval_str = make_kv_string(test_eval)
                print("# Evaluation test: epoch {:2d} iter {:08d} {}".format(
                    epoch, iterations, test_eval_str))

                # update learning rate scheduler
                if isinstance(scheduler, ExponentialLR):
                    scheduler.step()
                else:
                    scheduler.step(dev_eval["acc"])

                # update best validation set accuracy
                if dev_eval["acc"] > best_dev_acc:

                    for k, v in dev_eval.items():
                        writer.add_scalar('best/dev/%s' % k, v, iterations)

                    for k, v in test_eval.items():
                        writer.add_scalar('best/test/%s' % k, v, iterations)

                    print("# New highscore {} iter {}".format(
                        dev_eval["acc"], iterations))

                    # print examples for highscore
                    dev_iter.init_epoch()
                    print_examples(model,
                                   dev_iter,
                                   input_field.vocab,
                                   label_field.vocab,
                                   cfg.save_path,
                                   iterations,
                                   n=5,
                                   writer=writer)

                    # found a model with better validation set accuracy
                    best_dev_acc = dev_eval["acc"]

                    # save model, delete previous 'best_*' files
                    ckpt = {
                        "model": model.state_dict(),
                        "cfg": cfg,
                        "iterations": iterations,
                        "epoch": epoch,
                        "best_dev_acc": best_dev_acc,
                        "best_test_acc": test_eval["acc"],
                        "optimizer": opt.state_dict()
                    }
                    save_checkpoint(ckpt,
                                    cfg.save_path,
                                    iterations,
                                    prefix="best_ckpt",
                                    dev_acc=dev_eval["acc"],
                                    test_acc=test_eval["acc"],
                                    delete_old=True)

                if opt.param_groups[0]["lr"] < cfg.stop_lr_threshold:
                    print("Learning rate too low, stopping")
                    writer.close()
                    exit()

    writer.close()
Example #8
0
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        try:
            # set global_step to the global_step of the last saved checkpoint from the model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    # Added here for reproducibility
    set_seed(args)

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            if args.model_type in [
                    "xlm", "roberta", "distilbert", "camembert"
            ]:
                del inputs["token_type_ids"]

            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
                if hasattr(model, "config") and hasattr(
                        model.config, "lang2id"):
                    inputs.update({
                        "langs":
                        (torch.ones(batch[0].shape, dtype=torch.int64) *
                         args.lang_id).to(args.device)
                    })

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                # Save model checkpoint
                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(
                        model, "module") else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #9
0
def main():
    # Command line options
    args = parser.parse_args()
    print("Command line options:")
    for arg in vars(args):
        print(arg, getattr(args, arg))

    # import the correct loss and training functions depending which model to optimize
    # TODO: these could easily be refactored into one function, but we kept it this way for modularity
    if args.train_var:
        if args.joint:
            from lib.Training.train import train_var_joint as train
            from lib.Training.validate import validate_var_joint as validate
            from lib.Training.loss_functions import var_loss_function_joint as criterion
        else:
            from lib.Training.train import train_var as train
            from lib.Training.validate import validate_var as validate
            from lib.Training.loss_functions import var_loss_function as criterion
    else:
        if args.joint:
            from lib.Training.train import train_joint as train
            from lib.Training.validate import validate_joint as validate
            from lib.Training.loss_functions import loss_function_joint as criterion
        else:
            from lib.Training.train import train as train
            from lib.Training.validate import validate as validate
            from lib.Training.loss_functions import loss_function as criterion

    # Check whether GPU is available and can be used
    # if CUDA is found then device is set accordingly
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Launch a writer for the tensorboard summary writer instance
    save_path = 'runs/' + strftime("%Y-%m-%d_%H-%M-%S", gmtime()) + '_' + args.dataset + '_' + args.architecture +\
                '_dropout_' + str(args.dropout)

    if args.train_var:
        save_path += '_variational_samples_' + str(
            args.var_samples) + '_latent_dim_' + str(args.var_latent_dim)

    if args.joint:
        save_path += '_joint'

    # if we are resuming a previous training, note it in the name
    if args.resume:
        save_path = save_path + '_resumed'
    writer = SummaryWriter(save_path)

    # saving the parsed args to file
    log_file = os.path.join(save_path, "stdout")
    log = open(log_file, "a")
    for arg in vars(args):
        log.write(arg + ':' + str(getattr(args, arg)) + '\n')

    # Dataset loading
    data_init_method = getattr(datasets, args.dataset)
    dataset = data_init_method(torch.cuda.is_available(), args)
    # get the number of classes from the class dictionary
    num_classes = dataset.num_classes

    # add command line options to TensorBoard
    args_to_tensorboard(writer, args)

    log.close()

    # Get a sample input from the data loader to infer color channels/size
    net_input, _ = next(iter(dataset.train_loader))
    # get the amount of color channels in the input images
    num_colors = net_input.size(1)

    # import model from architectures class
    net_init_method = getattr(architectures, args.architecture)

    # build the model
    model = net_init_method(device, num_classes, num_colors, args)

    # Parallel container for multi GPU use and cast to available device
    model = torch.nn.DataParallel(model).to(device)
    print(model)

    # Initialize the weights of the model, by default according to He et al.
    print("Initializing network with: " + args.weight_init)
    WeightInitializer = WeightInit(args.weight_init)
    WeightInitializer.init_model(model)

    # Define optimizer and loss function (criterion)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    epoch = 0
    best_prec = 0
    best_loss = random.getrandbits(128)
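    # getrandbits(128) just produces a huge sentinel so the first validation loss always improves on it;
    # float('inf') would arguably be the more conventional choice.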

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            epoch = checkpoint['epoch']
            best_prec = checkpoint['best_prec']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # optimize until final amount of epochs is reached.
    while epoch < args.epochs:
        # train
        train(dataset, model, criterion, epoch, optimizer, writer, device,
              args)

        # evaluate on validation set
        prec, loss = validate(dataset, model, criterion, epoch, writer, device,
                              args)

        # remember best prec@1 and save checkpoint
        is_best = loss < best_loss
        best_loss = min(loss, best_loss)
        best_prec = max(prec, best_prec)
        save_checkpoint(
            {
                'epoch': epoch,
                'arch': args.architecture,
                'state_dict': model.state_dict(),
                'best_prec': best_prec,
                'best_loss': best_loss,
                'optimizer': optimizer.state_dict()
            }, is_best, save_path)

        # increment epoch counters
        epoch += 1

    writer.close()
Example #10
0
class RunManager():
    """
    Description:
        Class for logging the results of one run to a tensorboard.
    Arguments:
        -path[string]:      The path to save the tensorboard data to.
    """
    
    def __init__(self, path, run, model, example_data, run_count):
        #save the variables
        self.path = path
        self.run = run

        """
        Epoch Variables
        """
        self.epoch_count = -1
        self.epoch_train_loss = 0
        self.epoch_test_loss = 0
        self.epoch_train_num_correct = 0
        self.epoch_test_num_correct = 0
        self.epoch_test_choices = torch.Tensor()

        """
        Run Variables
        """
        self.run_count = run_count
        self.run_best_test_accuracy = 0
        self.run_best_specific_profit_stability = 0

        #create the tb for the run and save the graph of the network
        run_dict = dataclasses.asdict(run)
        del run_dict["features"]
        del run_dict["epochs"]
        del run_dict["test_percentage"]
        run_string = json.dumps(run_dict).replace('"', "").replace(":", "=").replace(" ", "")
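        #run_string turns the remaining hyperparameters into a filesystem-friendly suffix,
        #e.g. {"lr": 0.001, "batch_size": 32} -> "{lr=0.001,batch_size=32}", so every run gets its own folder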
        directory = f"{self.path}/Run{self.run_count}{run_string}"
        self.tb = SummaryWriter(log_dir=directory)
        self.tb.add_graph(model, input_to_model=example_data)

        #save the directory
        self.log_directory = directory

    def end_run(self):
        #save the hyperparameters
        metrics = {
            "ZMax Test Accuracy": self.run_best_test_accuracy,
            "ZMax Specific Profit Stability": self.run_best_specific_profit_stability
        }
        HPs = dataclasses.asdict(self.run)
        HPs["features"] = str(HPs["features"])
        self.tb.add_hparams(hparam_dict=HPs, metric_dict=metrics)

        #close the tensorboard
        self.tb.close()

    def begin_epoch(self):
        #update epoch count
        self.epoch_count += 1

        #reset the variables
        self.epoch_train_loss = 0
        self.epoch_test_loss = 0
        self.epoch_train_num_correct = 0
        self.epoch_test_num_correct = 0
        self.epoch_test_choices = torch.Tensor()

    def log_training(self, num_train_samples):
        #calculate the metrics
        loss = (self.epoch_train_loss/num_train_samples)*100
        accuracy = (self.epoch_train_num_correct/num_train_samples)*100

        #add the metrics to the tensorboard
        self.tb.add_scalar('Train Loss', loss, self.epoch_count)
        self.tb.add_scalar('Train Accuracy', accuracy, self.epoch_count)

    def track_train_metrics(self, loss, preds, labels):
        #track train loss
        self.epoch_train_loss += loss

        #track train num correct
        self.epoch_train_num_correct += self._get_num_correct(preds, labels)

    def log_testing(self, num_test_samples, performance_data=None, trading_activity_interval=(500,560)): 
        #calculate the metrics
        loss = (self.epoch_test_loss/num_test_samples)*100
        accuracy = (self.epoch_test_num_correct/num_test_samples) * 100
        specific_profit = performance_data["specific_profit"]
        specific_profit_rate = performance_data["specific_profit_rate"]
        specific_profit_stability = performance_data["specific_profit_stability"]

        #update best variables
        self.run_best_test_accuracy = max(self.run_best_test_accuracy, accuracy)
        self.run_best_specific_profit_stability = max(self.run_best_specific_profit_stability, specific_profit_stability)

        #add the metrics to the tensorboard
        self.tb.add_scalar('Test Loss', loss, self.epoch_count)
        self.tb.add_scalar('Test Accuracy', accuracy, self.epoch_count)
        self.tb.add_histogram("Choices", self.epoch_test_choices, self.epoch_count)

        #add the performance data
        self.tb.add_scalar('Specific Profit', specific_profit, self.epoch_count)
        self.tb.add_scalar('Specific Profit Rate', specific_profit_rate, self.epoch_count)
        self.tb.add_scalar('Specific Profit Stability', specific_profit_stability, self.epoch_count)

        """
        Plots
        """
        #get the interval infos
        interval_info = performance_data["interval_info"]
        #get the trading frame
        trading_frame = performance_data["trading_frame"]
        
        #Specific Profit Figure
        fig, (ax1, ax2) = plt.subplots(nrows=2)
        

        y = trading_frame.loc[:,"specific_profit"].to_numpy().astype(np.double)
        mask = np.isfinite(y)
        x = np.arange(0,len(y),1)
        ax1.plot(x[mask], y[mask])
        
        y = trading_frame.loc[:,"specific_profit_accumulated"].to_numpy().astype(np.double)
        mask = np.isfinite(y)
        x = np.arange(0,len(y),1)
        ax2.plot(x[mask], y[mask], drawstyle="steps")
        ax2.plot(trading_frame["specific_profit_accumulated"], drawstyle="steps", linestyle="--")

        ax1.set_title("Specific Profit", fontsize="7")
        ax2.set_title("Accumulated Specific Profit", fontsize="7")
        ax1.tick_params(labelsize=7)
        ax2.tick_params(labelsize=7)
        ax1.set_xlim(left=0)
        ax2.set_xlim(left=0)
        fig.tight_layout()
        
        self.tb.add_figure("Specific Profit Graph", fig, self.epoch_count)

        #clear the figure
        fig.clear()
        del(fig)

        #trading Activity Figure
        fig, ax = plt.subplots()

        tas = trading_activity_interval[0]
        tae = trading_activity_interval[1]
        
        ax.plot(trading_frame.loc[tas:tae, "close"], color="black")
        ax.plot(trading_frame.loc[tas:tae, "hold"], marker='o', linestyle="", color="gray", markersize=4)
        ax.plot(trading_frame.loc[tas:tae, "buy"], marker='o', linestyle="", color="green", markersize=4)
        ax.plot(trading_frame.loc[tas:tae, "sell"], marker='o', linestyle="", color="red", markersize=4)

        title = f"Date: {interval_info['date_interval']}, Movement: {interval_info['movement']}"
        ax.set_title(title, fontsize="7")
        ax.tick_params(labelsize=7)
        fig.tight_layout()
        
        self.tb.add_figure("Trading Activity", fig, self.epoch_count)

        #clear the figure
        fig.clear()
        del(fig)

        #close all the figures
        plt.close("all")

    def track_test_metrics(self, loss, preds, labels):
        #track test loss
        self.epoch_test_loss += loss

        #track test num correct
        self.epoch_test_num_correct += self._get_num_correct(preds, labels)

        #track choice distribution
        self.epoch_test_choices = torch.cat((self.epoch_test_choices, preds.argmax(dim=1).to('cpu')), dim=0)

    @torch.no_grad()
    def _get_num_correct(self, preds, labels):
        return preds.argmax(dim=1).eq(labels).sum().item()
    
    def checkpoint(self):
        return {"epoch": self.epoch_count,
                "best_test_accuracy": self.run_best_test_accuracy,
                "best_specific_profit_stability": self.run_best_specific_profit_stability}

    def load_checkpoint(self, checkpoint):
        self.epoch_count = checkpoint["epoch"]
        self.run_best_test_accuracy = checkpoint["best_test_accuracy"]
        self.run_best_specific_profit_stability = checkpoint["best_specific_profit_stability"]
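For reference, a hypothetical driver loop for this class (the run dataclass, model, example_data, criterion and train_loader are assumed to exist elsewhere) might look like:

# Hypothetical usage sketch for RunManager; run, model, example_data,
# criterion and train_loader are assumptions, not part of the example above.
manager = RunManager(path="runs", run=run, model=model,
                     example_data=example_data, run_count=0)
for _ in range(run.epochs):
    manager.begin_epoch()
    for samples, labels in train_loader:
        preds = model(samples)
        loss = criterion(preds, labels)
        manager.track_train_metrics(loss.item(), preds, labels)
    manager.log_training(num_train_samples=len(train_loader.dataset))
manager.end_run()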
Example #11
def train_gail(args):
    # Tensorboard
    writer = SummaryWriter('runs/{}_SAC_{}_{}_seed[{}]-[{}]'.format(
        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"), args.env_name,
        args.policy, args.seed, args.scrip_num))

    # Environment
    env = gym.make(args.env_name)
    device = torch.device(
        "cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # Seed
    env.seed(args.seed)
    env.action_space.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # Define actor, using SAC generator
    generator = SAC(env.observation_space.shape[0], env.action_space, args)

    # Load expert data
    expert_dataset = expert_data.MuJoCoExpertData(
        expert_path=args.expert_path, traj_limitation=args.traj_limitation)
    state_mean, state_std, action_mean, action_std = expert_dataset.state_mean.numpy(), expert_dataset.state_std.numpy(), \
                                                     expert_dataset.action_mean.numpy(), expert_dataset.action_std.numpy()

    # Define replay buffer
    generator_memory = ReplayMemory(args.gen_replay_size, args.seed)
    # discriminator_memory = Discriminator_ReplayMemory(args.dis_replay_size, args.seed)

    # pretrained by BC algorithm
    # if args.pretrain_bc:
    #     import behavior_clone
    #     generator.behavior_clone()

    # Define discriminator
    discriminator = DiscriminatorModel(env.observation_space.shape[0],
                                       env.action_space.shape[0], args)

    # Define the class what generates a segment
    normal_list = [state_mean, state_std, action_mean, action_std]

    segment_generator = generate_segment(env,
                                         generator,
                                         args.segment_len,
                                         device,
                                         normal=normal_list,
                                         is_normal=args.is_normal)

    # pretrain the discriminator
    # if args.pretrain:
    #     # sampling expert data set from expert demonstrations, type: torch tensor, device:cpu
    #     expert_state, expert_action, expert_next_state = expert_dataset.get_next_expert_batch(args.expert_batch)
    #     total_data = np.concatenate((expert_state, expert_action))
    #     discriminator_reward = discriminator.discriminator(
    #         torch.from_numpy(total_data).to(torch.float32).to(device))
    #     for i in range(args.d_steps):
    #         discriminator.pretrain_discriminator(dsa=discriminator_reward, expert_batch=args.expert_batch)
    #
    #     if len(memory) < args.replay_size:
    #         for i in range(args.expert_batch):
    #             memory.push(expert_state[i], expert_action["action"][i], 1,
    #                         expert_next_state["next_state"][i], 1.0)
    #     for _ in range(args.g_steps):
    #         critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = generator.update_parameters(memory,
    #                                                                                                  args.batch_size,
    #                                                                                                  0)

    color_text = Colored()
    avg_reward1 = 0
    iters, sample_interval, print_interval, writer_interval = 0, 1, 2, 20
    critic_1_loss, critic_2_loss, policy_loss = 0, 0, 0
    generator_accuracy, expert_accuracy, dis_loss_gen, dis_loss_exp = 0, 0, 0, 0
    csv_path = "data/{}/{}_seed[{}].csv".format(args.env_name, args.env_name,
                                                args.seed)
    while True:
        # sampling generator data set from environment, type= numpy array, device:
        if iters % sample_interval == 0:
            with timed("sampling from environment segment len:{}".format(
                    args.segment_len)):
                gene_data = segment_generator.__next__()
                # seg["state"]:arrary, seg["action"]:array, seg["entropy"]:tensor,shape:segment_len*action_space
        else:
            gene_data = segment_generator.__next__()
        # print("args.is_normal", args.is_normal)
        # print("generator_replay_batch_size", args.generator_replay_batch_size)
        # print("gen_replay_size", args.gen_replay_size)
        # sampling expert data set from expert demonstrations, type: torch tensor, device:cpu
        expert_state, expert_action, expert_next_state = expert_dataset.get_next_expert_batch(
            args.expert_batch)

        # a = np.concatenate((gene_data["state"], gene_data["action"]), axis=1)
        # b = np.concatenate((expert_state, expert_action), axis=1)
        # In total_data: 0 - (args.segment_len-1) row is generated data,
        # args.segment_len - (args.segment_len+args.expert_batch-1) row is expert data
        if args.is_normal:
            gene_state_memory = (gene_data["state"] - state_mean) / state_std
            gene_action_memory = (gene_data["action"] -
                                  action_mean) / action_std
            gene_next_state_memory = (gene_data["next_state"] -
                                      state_mean) / state_std
            expert_state_memory = (expert_state - state_mean) / state_std
            expert_action_memory = (expert_action - action_mean) / action_std
            expert_next_state_memory = (expert_next_state -
                                        state_mean) / state_std
        else:
            gene_state_memory = gene_data["state"]
            gene_action_memory = gene_data["action"]
            gene_next_state_memory = gene_data["next_state"]
            expert_state_memory = expert_state
            expert_action_memory = expert_action
            expert_next_state_memory = expert_next_state

        # gene_total_data_normal = np.concatenate((gen_state_normal, gen_action_normal), axis=1)

        # exp_state_normal = (expert_state - state_mean) / state_std
        # exp_action_normal = (expert_action - action_mean) / action_std
        # exp_next_state_normal = (expert_next_state - state_mean) / state_std
        # expert_total_data_normal = np.concatenate((exp_state_normal, exp_action_normal), axis=1)

        # total_data_normal = np.vstack((gene_total_data_normal, expert_total_data_normal))
        total_reward = discriminator.discriminator(
            torch.from_numpy(
                np.vstack(
                    (np.concatenate((gene_data["state"], gene_data["action"]),
                                    axis=1),
                     np.concatenate((expert_state, expert_action),
                                    axis=1)))).to(torch.float32).to(device))
        # expert_accuracy = np.array(total_reward[args.segment_len:].cpu() > 0.5).mean()
        # generator_accuracy = np.array(total_reward[:args.segment_len].cpu() < 0.5).mean()
        # gene_reward = total_reward[:args.segment_len]
        # # # Creating a mix_index to random mix expert data and generate data
        # # # In the original total_data, 0-99 row is generated data, 100-199 row is expert data
        # # mix_index = np.arange(2*args.segment_len)
        # # np.random.shuffle(mix_index)
        # # # the data what is mixed, shape:(2*args.segment_len, state_space+action_space)
        # # total_data = total_data[mix_index, :]
        # # Compute reward from discriminator
        # # discriminator_reward = discriminator.discriminator(torch.from_numpy(total_data).to(torch.float32).to(device))
        # # Computer accuracy of discriminator
        # # generator_accuracy = np.array(discriminator_reward[:args.segment_len - 1].cpu() < 0.5).mean()
        # # expert_accuracy = np.array(discriminator_reward[args.segment_len:].cpu() > 0.5).mean()
        #
        # # if expert_accuracy < 0.9:
        if iters % args.dis_interval == 0:
            for i in range(args.d_steps):
                # Compute reward from discriminator
                dis_loss_exp = discriminator.update_discriminator(
                    dsa=total_reward,
                    segment_len=args.segment_len,
                    exp_current_len=args.dis_upd_exp_len,
                    gen_current_len=args.dis_upd_gen_len)
                generator_accuracy = np.array(
                    total_reward[:args.segment_len].cpu() < 0.5).mean()
                expert_accuracy = np.array(
                    total_reward[args.segment_len:].cpu() > 0.5).mean()

        # Push generated data and expert data to replay buffer
        for i in range(args.segment_len):
            if generator_accuracy > 0.6 or generator_accuracy < 0.4:
                generator_memory.push(
                    gene_state_memory[i], gene_action_memory[i],
                    np.array([0])
                    if total_reward[:args.segment_len][i] < 0.2 else np.array([
                        -torch.log(1 - total_reward[:args.segment_len][i] +
                                   args.reward_baseline).item()
                    ]), gene_next_state_memory[i], gene_data["mask"][i])
            else:
                generator_memory.push(
                    gene_state_memory[i], gene_action_memory[i],
                    np.array([
                        -torch.log(torch.tensor(args.reward_baseline)).item()
                    ]), gene_next_state_memory[i], gene_data["mask"][i])
            # -torch.log(1 - gene_reward[i] + 1e-1).item()
        for i in range(args.expert_batch):
            generator_memory.push(
                expert_state_memory[i], expert_action_memory[i],
                np.array(
                    [-torch.log(torch.tensor(args.reward_baseline)).item()]),
                expert_next_state_memory[i], np.array([1]))
        if len(generator_memory) < args.generator_replay_batch_size:
            continue

        for _ in range(args.g_steps):
            critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha = \
                generator.update_parameters(generator_memory, args.generator_replay_batch_size, 0,
                                            gene_entropy=torch.mean(gene_data["entropy"]).item())

        if iters % print_interval == 0:
            # Test generator
            avg_reward = 0.
            episodes = 4
            for _ in range(episodes):
                state = env.reset()
                episode_reward = 0
                done = False
                while not done:
                    # state_normal = (state - state_mean) / state_std
                    action, _ = generator.select_action(state, evaluate=True)

                    # action_normal = (action - action_mean) / action_std
                    next_state, reward, done, _ = env.step(action)
                    episode_reward += reward

                    state = next_state
                avg_reward += episode_reward
            avg_reward /= episodes
            avg_reward1 = avg_reward
            # print
            print(
                "iters:{}, total_steps:{}, episode_return:{}, verified_episode_return:{}, "
                "generator_accuracy:{}, expert_accuracy:{}".format(
                    iters, gene_data["total_steps"], round(avg_reward, 4),
                    round(gene_data["episode_return"], 3),
                    round(generator_accuracy, 3), round(expert_accuracy, 3)))
            print(
                "scrip_num:{}, disc_loss_gen:{}, dis_loss_exp:{}, critic_1_loss:{}, critic_2_loss:{}, policy_loss:{}"
                .format(args.scrip_num, round(dis_loss_gen, 4),
                        round(dis_loss_exp, 4), round(critic_1_loss, 4),
                        round(critic_2_loss, 4), round(policy_loss, 4)))

            # writer data to csv
            # print(color_text.cyan("--" * 10, ))
            # data = [round(avg_reward, 4, 3), iters]
            # with open(csv_path, "a+", newline='') as f:
            #     print_text = "-----------{} added a new line!------------".format(csv_path)
            #     print(color_text.cyan(print_text))
            #     csv_writer = csv.writer(f)
            #     csv_writer.writerow(data)
            # verifying the generator performance
        if iters % writer_interval == 0:
            # Writing data to tensorboard
            writer.add_scalar("verified_episode_return", round(avg_reward1, 4),
                              gene_data["total_steps"])
            writer.add_scalar("policy_loss", round(policy_loss, 4), iters)
            writer.add_scalar("critic_1_loss", round(critic_1_loss, 4), iters)
            writer.add_scalar("critic_2_loss", round(critic_2_loss, 4), iters)
            writer.add_scalar("expert_accuracy", round(expert_accuracy, 4),
                              iters)
            writer.add_scalar("generator_accuracy",
                              round(generator_accuracy, 4), iters)
            writer.add_scalar("scrip_num", args.scrip_num, iters)

        iters += 1
        if gene_data["total_steps"] > args.num_steps:
            break
    print("done!")
    # Update the discriminator first, then update the generator asynchronously.
    env.close()
    writer.close()
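The reward written into the replay buffer above is the usual GAIL shaping -log(1 - D(s, a)); stripped of the bookkeeping, the transformation is simply:

import torch

def gail_reward(d_out: torch.Tensor, reward_baseline: float) -> torch.Tensor:
    # d_out is the discriminator's estimate that (s, a) comes from the expert;
    # the closer it is to 1, the larger the reward handed to the generator.
    # reward_baseline plays the same numerical-stability role as args.reward_baseline above.
    return -torch.log(1.0 - d_out + reward_baseline)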
Example #12
def main(argv):
    writer = SummaryWriter()

    torch.manual_seed(FLAGS.random_seed)

    np.random.seed(FLAGS.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(FLAGS.random_seed)
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

    device = torch.device(FLAGS.device)

    kwargs = {
        "num_workers": 1,
        "pin_memory": True
    } if FLAGS.device == "cuda" else {}
    train_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            root=".",
            train=True,
            download=True,
            transform=torchvision.transforms.Compose([
                #                    torchvision.transforms.RandomCrop(size=[28,28], padding=4),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307, ), (0.3081, )),
            ]),
        ),
        batch_size=FLAGS.batch_size,
        shuffle=True,
        **kwargs,
    )
    test_loader = torch.utils.data.DataLoader(
        torchvision.datasets.MNIST(
            root=".",
            train=False,
            transform=torchvision.transforms.Compose([
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.1307, ), (0.3081, )),
            ]),
        ),
        batch_size=FLAGS.batch_size,
        **kwargs,
    )

    label = os.environ.get("SLURM_JOB_ID", str(uuid.uuid4()))
    if FLAGS.prefix:
        path = f"runs/mnist/{FLAGS.prefix}/{label}"
    else:
        path = f"runs/mnist/{label}"

    os.makedirs(path, exist_ok=True)
    os.chdir(path)
    FLAGS.append_flags_into_file("flags.txt")

    input_features = 28 * 28
    output_features = 10

    model = LIFConvNet(
        input_features,
        FLAGS.seq_length,
        model=FLAGS.model,
        device=device,
        refrac=FLAGS.refrac,
        only_first_spike=FLAGS.only_first_spike,
    ).to(device)

    if FLAGS.optimizer == "sgd":
        optimizer = torch.optim.SGD(model.parameters(), lr=FLAGS.learning_rate)
    elif FLAGS.optimizer == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=FLAGS.learning_rate)

    if FLAGS.only_output:
        optimizer = torch.optim.Adam(model.out.parameters(),
                                     lr=FLAGS.learning_rate)

    training_losses = []
    mean_losses = []
    test_losses = []
    accuracies = []

    for epoch in range(FLAGS.epochs):
        training_loss, mean_loss = train(model,
                                         device,
                                         train_loader,
                                         optimizer,
                                         epoch,
                                         writer=writer)
        test_loss, accuracy = test(model,
                                   device,
                                   test_loader,
                                   epoch,
                                   writer=writer)

        training_losses += training_loss
        mean_losses.append(mean_loss)
        test_losses.append(test_loss)
        accuracies.append(accuracy)

        max_accuracy = np.max(np.array(accuracies))

        if (epoch % FLAGS.model_save_interval == 0) and FLAGS.save_model:
            model_path = f"mnist-{epoch}.pt"
            save(
                model_path,
                model=model,
                optimizer=optimizer,
                epoch=epoch,
                is_best=accuracy >= max_accuracy,
            )

    np.save("training_losses.npy", np.array(training_losses))
    np.save("mean_losses.npy", np.array(mean_losses))
    np.save("test_losses.npy", np.array(test_losses))
    np.save("accuracies.npy", np.array(accuracies))
    model_path = f"mnist-final.pt"
    save(
        model_path,
        epoch=epoch,
        model=model,
        optimizer=optimizer,
        is_best=accuracy >= max_accuracy,
    )
    writer.close()
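save() is not shown in this example; a plausible helper consistent with the keyword arguments used at the call sites above (path, model, optimizer, epoch, is_best) could be:

import torch

def save(path, model=None, optimizer=None, epoch=None, is_best=False):
    # Hypothetical helper matching the call sites above: bundle everything
    # needed to resume training into a single checkpoint file.
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "is_best": is_best,
        },
        path,
    )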
Example #13
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Num seq size = %d", len(train_dataset[0]))  # 512
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            inputs, labels = mask_tokens(batch, tokenizer,
                                         args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs,
                            masked_lm_labels=labels) if args.mlm else model(
                                inputs, labels=labels)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = 'checkpoint'
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        '{}-{}'.format(checkpoint_prefix, global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, 'module'
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
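Reduced to its essentials, the gradient-accumulation pattern used above looks like the following sketch (model, dataloader, optimizer, max_grad_norm, accumulation_steps and a compute_loss helper are assumptions here):

import torch

# Minimal gradient-accumulation sketch; compute_loss is a stand-in for the
# forward pass that returns a scalar loss.
model.zero_grad()
for step, batch in enumerate(dataloader):
    loss = compute_loss(model, batch) / accumulation_steps  # scale per micro-batch
    loss.backward()                                         # gradients accumulate in .grad
    if (step + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        model.zero_grad()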
Example #14
def train_engine(__C):
    # define network
    net = get_network(__C)
    net = net.cuda()

    __C.batch_size = __C.batch_size // __C.gradient_accumulation_steps

    # define dataloader
    train_loader = get_train_loader(__C)
    test_loader = get_test_loader(__C)

    # define optimizer and loss function
    if __C.label_smoothing:
        loss_function = LabelSmoothingCrossEntropy(__C.smoothing)
    else:
        loss_function = nn.CrossEntropyLoss()

    # define optimizer and training parameters
    if __C.no_bias_decay:
        params = split_weights(net)
    else:
        params = net.parameters()
    optimizer = optim.SGD(params, lr=__C.lr, momentum=0.9, weight_decay=5e-4)

    # define optimizer scheduler
    # len(train_loader) is the number of steps in one epoch
    warmup_steps = __C.warmup_steps
    total_steps = __C.num_steps
    # change epochs into steps
    __C.milestones = [m * len(train_loader) for m in __C.milestones]
    if __C.decay_type == 'multi_step':
        train_scheduler = WarmupMultiStepSchedule(__C,
                                                  optimizer,
                                                  warmup_steps=warmup_steps,
                                                  t_total=total_steps)
    elif __C.decay_type == 'cosine':
        train_scheduler = WarmupCosineSchedule(optimizer,
                                               warmup_steps=warmup_steps,
                                               t_total=total_steps)
    elif __C.decay_type == 'linear':
        train_scheduler = WarmupLinearSchedule(optimizer,
                                               warmup_steps=warmup_steps,
                                               t_total=total_steps)

    # define tensorboard writer
    writer = SummaryWriter(
        log_dir=os.path.join(__C.tensorboard_log_dir, __C.model, __C.version))

    # define model save dir
    checkpoint_path = os.path.join(__C.ckpts_dir, __C.model, __C.version)
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    checkpoint_path = os.path.join(checkpoint_path,
                                   '{net}-{global_step}-{type}.pth')

    # define log save dir
    log_path = os.path.join(__C.result_log_dir, __C.model)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    log_path = os.path.join(log_path, __C.version + '.txt')

    # write the hyper parameters to log
    logfile = open(log_path, 'a+')
    logfile.write(str(__C))
    logfile.close()

    # Train!
    logger.info("  ***** Running training *****")
    logger.info("  Total optimization steps = %d", __C.num_steps)
    logger.info("  Instantaneous batch size per GPU = %d", __C.batch_size)
    logger.info("  Gradient Accumulation steps = %d",
                __C.gradient_accumulation_steps)

    net.zero_grad()
    losses = AverageMeter()
    global_step, best_acc = 0, 0
    while True:
        net.train()
        epoch_iterator = tqdm(train_loader,
                              desc="Training (X / X Steps) (loss=X.X)",
                              bar_format="{l_bar}{r_bar}",
                              dynamic_ncols=True)
        for step, (images, labels) in enumerate(epoch_iterator):
            images = images.cuda()
            labels = labels.cuda()
            train_outputs = net(images)
            loss = loss_function(train_outputs, labels)

            if __C.gradient_accumulation_steps > 1:
                loss = loss / __C.gradient_accumulation_steps
            loss.backward()

            if (step + 1) % __C.gradient_accumulation_steps == 0:
                losses.update(loss.item() * __C.gradient_accumulation_steps)
                torch.nn.utils.clip_grad_norm_(net.parameters(),
                                               __C.max_grad_norm)
                optimizer.step()
                train_scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                epoch_iterator.set_description(
                    "Training (%d / %d Steps) (loss=%2.5f)" %
                    (global_step, total_steps, losses.val))

                writer.add_scalar("[Step] Train/loss",
                                  scalar_value=losses.val,
                                  global_step=global_step)
                writer.add_scalar("[Step] Train/lr",
                                  scalar_value=train_scheduler.get_lr()[0],
                                  global_step=global_step)

                if global_step % __C.eval_every == 0:
                    accuracy = valid(__C,
                                     model=net,
                                     writer=writer,
                                     test_loader=test_loader,
                                     global_step=global_step,
                                     loss_function=loss_function)
                    if best_acc < accuracy:
                        torch.save(
                            net.state_dict(),
                            checkpoint_path.format(net=__C.model,
                                                   global_step=global_step,
                                                   type='best'))
                        best_acc = accuracy
                    net.train()

                if global_step >= total_steps:
                    break
        losses.reset()
        if global_step >= total_steps:
            break

    writer.close()
    logger.info("Best Accuracy: \t%f" % best_acc)
    logger.info("End Training!")
Example #15
class metaLogger(object):
    def __init__(self, args, flush_sec=5):
        self.log_path = args.j_dir + "/log/"
        self.tb_path = args.j_dir + "/tb/"
        # self.ckpt_status = "curr"
        self.ckpt_status = self.get_ckpt_status(args.j_dir, args.j_id)
        self.log_dict = self.load_log(self.log_path)
        self.writer = SummaryWriter(log_dir=self.tb_path, flush_secs=flush_sec)
        self.enable_wandb = args.enable_wandb

        if self.enable_wandb:
            self.wandb_log = wandb.init(name=args.j_dir.split("/")[-1],
                                        project=args.wandb_project,
                                        dir=args.j_dir,
                                        id=str(args.j_id),
                                        resume=True)
            self.wandb_log.config.update(args)

    def get_ckpt_status(self, j_dir, j_id):

        status = "curr"
        ckpt_dir = j_dir + "/" + str(j_id) + "/"
        ckpt_location_prev = os.path.join(ckpt_dir, "ckpt_prev.pth")
        ckpt_location_curr = os.path.join(ckpt_dir, "ckpt_curr.pth")
        if os.path.exists(ckpt_location_curr) and os.path.exists(
                ckpt_location_prev):
            try:
                torch.load(ckpt_location_curr)
            except TypeError:
                status = "prev"
        else:
            status = "curr"
        return status

    def load_log(self, log_path):
        if self.ckpt_status == "curr":
            try:
                log_dict = torch.load(log_path + "/log_curr.pth")
            except FileNotFoundError:
                log_dict = defaultdict(lambda: list())
            except TypeError:
                log_dict = torch.load(log_path + "/log_prev.pth")
                self.ckpt_status = "prev"
        else:
            log_dict = torch.load(log_path + "/log_prev.pth")
        return log_dict

    def add_scalar(self, name, val, step):
        self.writer.add_scalar(name, val, step)
        try:
            self.log_dict[name] += [(time.time(), int(step), float(val))]
        except KeyError:
            self.log_dict[name] = [(time.time(), int(step), float(val))]

        if self.enable_wandb:
            if "_itr" in name:
                self.wandb_log.log({"iteration": step, name: float(val)})
            else:
                self.wandb_log.log({"epoch": step, name: float(val)})

    def add_scalars(self, name, val_dict, step):
        self.writer.add_scalars(name, val_dict, step)
        for key, val in val_dict.items():
            self.log_dict[name + key] += [(time.time(), int(step), float(val))]

    def add_figure(self, name, val, step):
        self.writer.add_figure(name, val, step)
        val.savefig(self.log_path + "/" + name + ".png")

    def save_log(self):
        try:
            os.makedirs(self.log_path)
        except os.error:
            pass

        log_curr = os.path.join(self.log_path, "log_curr.pth")
        log_prev = os.path.join(self.log_path, "log_prev.pth")

        # no existing logs
        if not (os.path.exists(log_curr) or os.path.exists(log_prev)):
            torch.save(dict(self.log_dict), log_curr)

        elif os.path.exists(log_curr):
            # overwrite log_prev with log_curr
            cmd = "cp -r {} {}".format(log_curr, log_prev)
            os.system(cmd)
            torch.save(dict(self.log_dict), log_curr)

    # def log_obj(self, name, val):
    # self.logobj[name] = val

    # def log_objs(self, name, val, step=None):
    # self.logobj[name] += [(time.time(), step, val)]

    # def log_vector(self, name, val, step=None):
    # name += '_v'
    # if step is None:
    # step = len(self.logobj[name])
    # self.logobj[name] += [(time.time(), step, list(val.flatten()))]

    def close(self):
        self.writer.close()
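A hypothetical way to drive metaLogger from a training loop (args carrying j_dir, j_id, enable_wandb and wandb_project as in __init__, plus a train_one_epoch helper and num_epochs, are assumptions):

# Hypothetical usage; train_one_epoch is a placeholder for the actual
# training step and num_epochs for the configured epoch budget.
logger = metaLogger(args)
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer)
    logger.add_scalar("train/loss", train_loss, epoch)
    logger.save_log()   # copies log_curr.pth to log_prev.pth before overwriting
logger.close()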
Example #16
def run_training_test(root_dir, device=torch.device("cuda:0"), cachedataset=False):
    monai.config.print_config()
    images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(root_dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images[:20], segs[:20])]
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images[-20:], segs[-20:])]

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"]),
            ScaleIntensityd(keys=["img", "seg"]),
            RandCropByPosNegLabeld(
                keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
            ToTensord(keys=["img", "seg"]),
        ]
    )
    train_transforms.set_random_state(1234)
    val_transforms = Compose(
        [
            LoadNiftid(keys=["img", "seg"]),
            AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
            Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"]),
            ScaleIntensityd(keys=["img", "seg"]),
            ToTensord(keys=["img", "seg"]),
        ]
    )

    # create a training data loader
    if cachedataset:
        train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.8)
    else:
        train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4, collate_fn=list_data_collate)
    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True, to_onehot_y=False, sigmoid=True, reduction="mean")

    # create UNet, DiceLoss and Adam optimizer
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 5e-4)

    # start a typical PyTorch training
    val_interval = 2
    best_metric, best_metric_epoch = -1, -1
    epoch_loss_values = list()
    metric_values = list()
    writer = SummaryWriter(log_dir=os.path.join(root_dir, "runs"))
    model_filename = os.path.join(root_dir, "best_metric_model.pth")
    for epoch in range(6):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{6}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss:{loss.item():0.4f}")
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch +1} average loss:{epoch_loss:0.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                metric_sum = 0.0
                metric_count = 0
                val_images = None
                val_labels = None
                val_outputs = None
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                    sw_batch_size, roi_size = 4, (96, 96, 96)
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    value = dice_metric(y_pred=val_outputs, y=val_labels)
                    not_nans = dice_metric.not_nans.item()
                    metric_count += not_nans
                    metric_sum += value.item() * not_nans
                metric = metric_sum / metric_count
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print("saved new best metric model")
                print(
                    f"current epoch {epoch +1} current mean dice: {metric:0.4f} "
                    f"best mean dice: {best_metric:0.4f} at epoch {best_metric_epoch}"
                )
                writer.add_scalar("val_mean_dice", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")
    print(f"train completed, best_metric: {best_metric:0.4f}  at epoch: {best_metric_epoch}")
    writer.close()
    return epoch_loss_values, best_metric, best_metric_epoch
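The validation block above averages the Dice score weighted by the number of non-NaN results returned for each batch; in isolation, that aggregation is just:

# Weighted running mean of a per-batch metric, skipping NaN cases;
# batch_metrics is assumed to be a list of (value, not_nans) pairs.
metric_sum, metric_count = 0.0, 0
for value, not_nans in batch_metrics:
    metric_sum += value * not_nans
    metric_count += not_nans
mean_metric = metric_sum / max(metric_count, 1)  # max(..., 1) guards the empty case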
Example #17
def train(opt):

    params = Params(
        f'Yet-Another-EfficientDet-Pytorch/projects/{opt.project}.yml')

    if params.num_gpus == 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    if torch.cuda.is_available():
        logging.debug("yes! CUDA is availabale")
        torch.cuda.manual_seed(42)
    else:
        torch.manual_seed(42)

    opt.saved_path = opt.saved_path + f'/checkpoints/'
    opt.log_path = opt.log_path + f'/{params.project_name}/tensorboard/'
    os.makedirs(opt.log_path, exist_ok=True)
    os.makedirs(opt.saved_path, exist_ok=True)

    name_tensorboard_logs = input("NAME TENSORBOARD LOG [optional]:")

    if name_tensorboard_logs:
        writer = SummaryWriter(opt.log_path + f"/{name_tensorboard_logs}")
    else:
        writer = SummaryWriter(
            opt.log_path +
            f'/{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}/')

    compound_coefficient, _ = coefficient_from_weights_filepath(
        opt.load_weights)

    training_params = {
        'batch_size': opt.batch_size,
        'shuffle': True,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    val_params = {
        'batch_size': opt.batch_size,
        'shuffle': False,
        'drop_last': True,
        'collate_fn': collater,
        'num_workers': opt.num_workers
    }

    input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
    training_set = ThermalDataset(
        root_dir=os.path.join(opt.data_path, params.project_name),
        set=params.train_set,
        transform=transforms.Compose([
            Normalizer(mean=params.mean, std=params.std),
            Augmenter(),
            Resizer(input_sizes[compound_coefficient])
        ]))

    train_size = int(0.7 * len(training_set))
    test_size = len(training_set) - train_size
    print(f"Train size split: {train_size}\n, Test size split: {test_size}")

    train_dataset, val_dataset = torch.utils.data.random_split(
        training_set, [train_size, test_size])

    training_generator = DataLoader(train_dataset, **training_params)

    val_generator = DataLoader(val_dataset, **val_params)

    model = EfficientDetBackbone(num_classes=len(params.obj_list),
                                 compound_coef=compound_coefficient,
                                 ratios=eval(params.anchors_ratios),
                                 scales=eval(params.anchors_scales))

    # load last weights
    if opt.load_weights is not None:
        if opt.load_weights.endswith('.pth'):
            weights_path = opt.load_weights
        else:
            weights_path = get_last_weights(opt.saved_path)
        try:
            last_step = int(
                os.path.basename(weights_path).split('_')[-1].split('.')[0])
        except:
            last_step = 0

        try:
            ret = model.load_state_dict(torch.load(weights_path), strict=False)
        except RuntimeError as e:
            print(f'[Warning] Ignoring {e}')
            print(
                '[Warning] Don\'t panic if you see this, this might be because you load a pretrained weights with different number of classes. The rest of the weights should be loaded already.'
            )

        print(
            f'[Info] loaded weights: {os.path.basename(weights_path)}, resuming checkpoint from step: {last_step}'
        )
    else:
        last_step = 0
        print('[Info] initializing weights...')
        init_weights(model)

    # freeze backbone if train head_only
    if opt.head_only:

        def freeze_backbone(m):
            classname = m.__class__.__name__
            for ntl in ['EfficientNet', 'BiFPN']:
                if ntl in classname:
                    for param in m.parameters():
                        param.requires_grad = False

        model.apply(freeze_backbone)
        print('[Info] froze backbone')

    # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
    # apply sync_bn when using multiple gpus and the per-gpu batch size is lower than 4;
    # useful when gpu memory is limited, because with such small batches plain bn makes
    # training unstable or slow to converge. sync_bn fixes this by normalizing over the
    # mini-batches of all gpus as one batch and sending the result back to each gpu,
    # at the cost of slightly slower training.
    if params.num_gpus > 1 and opt.batch_size // params.num_gpus < 4:
        model.apply(replace_w_sync_bn)
        use_sync_bn = True
    else:
        use_sync_bn = False

    # wrap the model with the loss function to reduce memory usage on gpu0 and speed up training
    model = ModelWithLoss(model, debug=opt.debug)

    if params.num_gpus > 0:
        model = model.cuda()
        if params.num_gpus > 1:
            model = CustomDataParallel(model, params.num_gpus)
            if use_sync_bn:
                patch_replication_callback(model)

    if opt.optim == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(),
                                    opt.lr,
                                    momentum=0.9,
                                    nesterov=True)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    epoch = 0
    best_loss = 1e5
    best_epoch = 0
    step = max(0, last_step)
    model.train()

    num_iter_per_epoch = len(training_generator)

    try:
        for epoch in range(opt.num_epochs):
            last_epoch = step // num_iter_per_epoch
            if epoch < last_epoch:
                continue

            epoch_loss = []
            progress_bar = tqdm(training_generator)
            for iter, data in enumerate(progress_bar):
                if iter < step - last_epoch * num_iter_per_epoch:
                    progress_bar.update()
                    continue
                try:
                    imgs = data['img']
                    annot = data['annot']

                    if params.num_gpus == 1:
                        # if only one gpu, just send it to cuda:0
                        # elif multiple gpus, send it to multiple gpus in CustomDataParallel, not here
                        imgs = imgs.cuda()
                        annot = annot.cuda()

                    optimizer.zero_grad()
                    cls_loss, reg_loss = model(imgs,
                                               annot,
                                               obj_list=params.obj_list)
                    cls_loss = cls_loss.mean()
                    reg_loss = reg_loss.mean()

                    loss = cls_loss + reg_loss
                    if loss == 0 or not torch.isfinite(loss):
                        continue

                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
                    optimizer.step()

                    epoch_loss.append(float(loss))

                    progress_bar.set_description(
                        'Step: {}. Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. Total loss: {:.5f}'
                        .format(step, epoch, opt.num_epochs, iter + 1,
                                num_iter_per_epoch, cls_loss.item(),
                                reg_loss.item(), loss.item()))

                    writer.add_scalar('Loss', loss, step)
                    writer.add_scalar('Regression_loss', reg_loss, step)
                    writer.add_scalar('Classification_loss', cls_loss, step)

                    # log learning_rate
                    current_lr = optimizer.param_groups[0]['lr']
                    writer.add_scalar('learning_rate', current_lr, step)

                    step += 1

                    if step % opt.save_interval == 0 and step > 0:
                        save_checkpoint(
                            model,
                            f'efficientdet-d{compound_coefficient}_{epoch}_{step}.pth'
                        )
                        print('checkpoint...')

                except Exception as e:
                    print('[Error]', traceback.format_exc())
                    print(e)
                    continue
            scheduler.step(np.mean(epoch_loss))

            if epoch % opt.val_interval == 0:
                model.eval()
                loss_regression_ls = []
                loss_classification_ls = []
                for iter, data in enumerate(val_generator):
                    with torch.no_grad():
                        imgs = data['img']
                        annot = data['annot']

                        if params.num_gpus == 1:
                            imgs = imgs.cuda()
                            annot = annot.cuda()

                        cls_loss, reg_loss = model(imgs,
                                                   annot,
                                                   obj_list=params.obj_list)
                        cls_loss = cls_loss.mean()
                        reg_loss = reg_loss.mean()

                        loss = cls_loss + reg_loss
                        if loss == 0 or not torch.isfinite(loss):
                            continue

                        loss_classification_ls.append(cls_loss.item())
                        loss_regression_ls.append(reg_loss.item())

                cls_loss = np.mean(loss_classification_ls)
                reg_loss = np.mean(loss_regression_ls)
                loss = cls_loss + reg_loss

                print(
                    'Val. Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                    .format(epoch, opt.num_epochs, cls_loss, reg_loss, loss))
                writer.add_scalars('Loss', {'val': loss}, step)

                if loss + opt.es_min_delta < best_loss:
                    best_loss = loss
                    best_epoch = epoch

                    save_checkpoint(
                        model,
                        f'efficientdet-d{compound_coefficient}_{epoch}_{step}.pth'
                    )

                model.train()

                # Early stopping
                if epoch - best_epoch > opt.es_patience > 0:
                    print(
                        '[Info] Stop training at epoch {}. The lowest loss achieved is {}'
                        .format(epoch, best_loss))
                    break
    except KeyboardInterrupt:
        save_checkpoint(
            model, f'efficientdet-d{compound_coefficient}_{epoch}_{step}.pth')
        writer.close()
    writer.close()
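The validation block above combines best-loss tracking, checkpointing, and patience-based early stopping inline. A minimal sketch of the same pattern factored into a reusable helper is below; the class name EarlyStopper and the min_delta/patience parameters are illustrative stand-ins for opt.es_min_delta and opt.es_patience, not part of the snippet above.

class EarlyStopper:
    """Tracks the best validation loss and signals when to stop.

    Mirrors the inline logic above: an improvement must beat the best
    loss by at least `min_delta`, and training stops once more than
    `patience` epochs pass without improvement (patience <= 0 disables
    stopping).
    """

    def __init__(self, min_delta=0.0, patience=0):
        self.min_delta = min_delta
        self.patience = patience
        self.best_loss = float('inf')
        self.best_epoch = 0

    def update(self, loss, epoch):
        """Return (improved, should_stop) for the current epoch."""
        improved = loss + self.min_delta < self.best_loss
        if improved:
            self.best_loss = loss
            self.best_epoch = epoch
        should_stop = 0 < self.patience < epoch - self.best_epoch
        return improved, should_stop

In the loop above, an `improved` result would trigger the same save_checkpoint call, and `should_stop` would replace the explicit patience comparison.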
Example #18
0
def train(args, train_dataset, model: BertForMlmWithClassification,
          tokenizer: BertTokenizer) -> Tuple[int, float]:
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(data: List[torch.Tensor]):
        sentences, labels = list(zip(*data))
        if tokenizer._pad_token is None:
            return pad_sequence(sentences, batch_first=True)
        return (
            pad_sequence(sentences,
                         batch_first=True,
                         padding_value=tokenizer.pad_token_id),
            torch.tensor(labels),
        )

    train_sampler = (RandomSampler(train_dataset) if args.local_rank == -1 else
                     DistributedSampler(train_dataset))
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size,
                                  collate_fn=collate)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = (
            args.max_steps //
            (len(train_dataloader) // args.gradient_accumulation_steps) + 1)
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if (args.model_name_or_path and os.path.isfile(
            os.path.join(args.model_name_or_path, "optimizer.pt"))
            and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt"))):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to the global_step of the last saved checkpoint from the model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss, logging_loss = 0.0, 0.0

    # Take care of distributed/parallel training
    model_to_resize = model.module if hasattr(model, "module") else model
    model_to_resize.resize_token_embeddings(len(tokenizer))

    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            batch, class_labels = batch

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, mask_labels = (mask_tokens(batch, tokenizer, args)
                                   if args.mlm else (batch, batch))
            inputs = inputs.to(args.device)
            mask_labels = mask_labels.to(args.device) if args.mlm else None
            class_labels = class_labels.to(args.device)
            model.train()
            outputs = model(input_ids=inputs,
                            masked_lm_labels=mask_labels,
                            class_labels=class_labels)

            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if (args.local_rank in [-1, 0] and args.logging_steps > 0
                        and global_step % args.logging_steps == 0):
                    # Log metrics
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                    tb_writer.add_scalar("lr",
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if (args.local_rank in [-1, 0] and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir,
                        "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
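The inner loop above interleaves gradient accumulation with fp16 scaling, clipping, logging, and checkpointing. Stripped of those concerns, the accumulation pattern itself reduces to the sketch below; `model`, `optimizer`, `scheduler`, `dataloader`, and `accum_steps` are placeholder names, and the `model(inputs, labels=labels)[0]` call is an assumption about the model returning the loss first, as in the snippet.

import torch

def run_epoch(model, optimizer, scheduler, dataloader, accum_steps):
    """Accumulate gradients over `accum_steps` micro-batches per optimizer step."""
    model.train()
    model.zero_grad()
    for step, (inputs, labels) in enumerate(dataloader):
        loss = model(inputs, labels=labels)[0]
        # Scale so the summed gradients match one large-batch update.
        (loss / accum_steps).backward()
        if (step + 1) % accum_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
            model.zero_grad()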
Example #19
0
def main():

    args = parse_args()
    trainroot = args.trainroot
    testroot = args.testroot
    backbone = args.backbone
    print("trainroot:", trainroot)
    print("testroot:", testroot)
    print("backbone:", backbone)

    if config.output_dir is None:
        config.output_dir = 'output'
    if config.restart_training:
        shutil.rmtree(config.output_dir, ignore_errors=True)
    if not os.path.exists(config.output_dir):
        os.makedirs(config.output_dir)

    logger = setup_logger(os.path.join(config.output_dir, 'train_log'))
    logger.info(config.print())

    torch.manual_seed(config.seed)  # set the random seed for the CPU
    if config.gpu_id is not None and torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        logger.info('train with gpu {} and pytorch {}'.format(
            config.gpu_id, torch.__version__))
        device = torch.device("cuda:0")
        torch.cuda.manual_seed(config.seed)  # set the random seed for the current GPU
        torch.cuda.manual_seed_all(config.seed)  # set the random seed for all GPUs
    else:
        logger.info('train with cpu and pytorch {}'.format(torch.__version__))
        device = torch.device("cpu")

    train_data = MyDataset(trainroot,
                           data_shape=config.data_shape,
                           n=config.n,
                           m=config.m,
                           transform=transforms.ToTensor())

    print("len(train_data):", len(train_data))
    train_loader = Data.DataLoader(dataset=train_data,
                                   batch_size=config.train_batch_size,
                                   shuffle=True,
                                   num_workers=int(config.workers))

    writer = SummaryWriter(config.output_dir)
    model = PSENet(backbone=backbone,
                   pretrained=config.pretrained,
                   result_num=config.n,
                   scale=config.scale)
    if not config.pretrained and not config.restart_training:
        model.apply(weights_init)

    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        model = nn.DataParallel(model)
    model = model.to(device)
    # dummy_input = torch.autograd.Variable(torch.Tensor(1, 3, 600, 800).to(device))
    # writer.add_graph(model, input_to_model=dummy_input)
    criterion = PSELoss(Lambda=config.Lambda,
                        ratio=config.OHEM_ratio,
                        reduction='mean')
    # optimizer = torch.optim.SGD(model.parameters(), lr=config.lr, momentum=0.99)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    if config.checkpoint != '' and not config.restart_training:
        start_epoch = load_checkpoint(config.checkpoint, model, logger, device,
                                      optimizer)
        start_epoch += 1
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            config.lr_decay_step,
            gamma=config.lr_gamma,
            last_epoch=start_epoch)
    else:
        start_epoch = config.start_epoch
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         config.lr_decay_step,
                                                         gamma=config.lr_gamma)

    all_step = len(train_loader)
    logger.info('train dataset has {} samples,{} in dataloader'.format(
        train_data.__len__(), all_step))
    epoch = 0
    best_model = {'recall': 0, 'precision': 0, 'f1': 0, 'models': ''}
    try:
        for epoch in range(start_epoch, config.epochs):
            start = time.time()
            train_loss, lr = train_epoch(model, optimizer, scheduler,
                                         train_loader, device, criterion,
                                         epoch, all_step, writer, logger)
            logger.info(
                '[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.format(
                    epoch, config.epochs, train_loss,
                    time.time() - start, lr))
            # net_save_path = '{}/PSENet_{}_loss{:.6f}.pth'.format(config.output_dir, epoch, train_loss)
            # save_checkpoint(net_save_path, model, optimizer, epoch, logger)
            if (0.3 < train_loss < 0.4 and epoch % 4 == 0) or train_loss < 0.3:
                recall, precision, f1 = eval(
                    model, os.path.join(config.output_dir, 'output'), testroot,
                    device)
                logger.info(
                    'test: recall: {:.6f}, precision: {:.6f}, f1: {:.6f}'.
                    format(recall, precision, f1))

                net_save_path = '{}/PSENet_{}_loss{:.6f}_r{:.6f}_p{:.6f}_f1{:.6f}.pth'.format(
                    config.output_dir, epoch, train_loss, recall, precision,
                    f1)
                save_checkpoint(net_save_path, model, optimizer, epoch, logger)
                if f1 > best_model['f1']:
                    best_path = glob.glob(config.output_dir + '/Best_*.pth')
                    for b_path in best_path:
                        if os.path.exists(b_path):
                            os.remove(b_path)

                    best_model['recall'] = recall
                    best_model['precision'] = precision
                    best_model['f1'] = f1
                    best_model['models'] = net_save_path

                    best_save_path = '{}/Best_{}_r{:.6f}_p{:.6f}_f1{:.6f}.pth'.format(
                        config.output_dir, epoch, recall, precision, f1)
                    if os.path.exists(net_save_path):
                        shutil.copyfile(net_save_path, best_save_path)
                    else:
                        save_checkpoint(best_save_path, model, optimizer,
                                        epoch, logger)

                    pse_path = glob.glob(config.output_dir + '/PSENet_*.pth')
                    for p_path in pse_path:
                        if os.path.exists(p_path):
                            os.remove(p_path)

                writer.add_scalar(tag='Test/recall',
                                  scalar_value=recall,
                                  global_step=epoch)
                writer.add_scalar(tag='Test/precision',
                                  scalar_value=precision,
                                  global_step=epoch)
                writer.add_scalar(tag='Test/f1',
                                  scalar_value=f1,
                                  global_step=epoch)
        writer.close()
    except KeyboardInterrupt:
        save_checkpoint('{}/final.pth'.format(config.output_dir), model,
                        optimizer, epoch, logger)
    finally:
        if best_model['models']:
            logger.info(best_model)
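When the F1 score improves, the loop above deletes the previous Best_*.pth file, records the new metrics, and copies the latest checkpoint under a Best_ name. A minimal sketch of that bookkeeping is shown here; `output_dir`, `ckpt_path`, `metrics`, and `best` are illustrative parameters rather than objects taken from the snippet.

import glob
import os
import shutil

def promote_best(output_dir, ckpt_path, metrics, best):
    """Replace the previous best checkpoint if `metrics['f1']` improved."""
    if metrics['f1'] <= best.get('f1', 0):
        return best
    for old in glob.glob(os.path.join(output_dir, 'Best_*.pth')):
        os.remove(old)
    best = dict(metrics, models=ckpt_path)
    best_path = os.path.join(
        output_dir, 'Best_f1{:.6f}.pth'.format(metrics['f1']))
    shutil.copyfile(ckpt_path, best_path)
    return best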
Example #20
0
def train_net(net,
              device,
              epochs=5,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5):

    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)

    writer = SummaryWriter(
        comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(),
                              lr=lr,
                              weight_decay=1e-8,
                              momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % (n_train // (10 * batch_size)) == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(),
                                             global_step)
                        writer.add_histogram('grads/' + tag,
                                             value.grad.data.cpu().numpy(),
                                             global_step)
                    val_score = eval_net(net, val_loader, device)
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'],
                                      global_step)

                    if net.n_classes > 1:
                        logging.info(
                            'Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info(
                            'Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks,
                                          global_step)
                        writer.add_images('masks/pred',
                                          torch.sigmoid(masks_pred) > 0.5,
                                          global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
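The scheduler above is built with mode 'min' when the validation score is a cross-entropy loss (lower is better) and 'max' when it is a Dice coefficient (higher is better), so scheduler.step(val_score) lowers the learning rate only after `patience` evaluations without improvement. A small, self-contained sketch of that mechanism follows; the single-parameter model and the hard-coded scores are stand-ins for illustration only.

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.RMSprop([param], lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=2, factor=0.1)

# Dice-style score that stops improving after a few evaluations.
for score in [0.60, 0.70, 0.72, 0.72, 0.72, 0.72]:
    scheduler.step(score)
    # stays at 1e-3 until patience is exhausted, then drops to 1e-4
    print(optimizer.param_groups[0]['lr'])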
Example #21
0
def train_correspondence_block(json_file, cls, gpu, synthetic, epochs=50, batch_size=64, val_ratio=0.2,
                               save_model=True, iter_print=10):
    """
    Train a UNet for each class using real and/or synthetic training data
    Args:
        json_file: .txt file listing the directories of the training images
        cls: the class to train on, from 1 to 6
        gpu: gpu id to use
        synthetic: whether use synthetic data or not
        epochs: number of epochs to train
        batch_size: batch size
        val_ratio: validation ratio during training
        save_model: save model or not
        iter_print: print training results every iter_print iterations

    """
    train_data = NOCSDataset(json_file, cls, synthetic=synthetic, resize=64,
                             transform=transforms.Compose([transforms.ColorJitter(brightness=(0.6, 1.4),
                                                                                  contrast=(0.8, 1.2),
                                                                                  saturation=(0.8, 1.2),
                                                                                  hue=(-0.01, 0.01)),
                                                           AddGaussianNoise(10 / 255)]))
    print('Size of trainset ', len(train_data))
    indices = list(range(len(train_data)))
    np.random.shuffle(indices)

    num_train = len(indices)
    split = int(np.floor(num_train * val_ratio))
    train_idx, valid_idx = indices[split:], indices[:split]

    # define samplers for obtaining training and validation batches
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # prepare data loaders (combine dataset and sampler)
    num_workers = 4
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                               sampler=train_sampler, num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                             sampler=valid_sampler, num_workers=num_workers)
    device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")
    print("device: ", f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")
    # architecture for correspondence block - 13 objects + background = 14 channels for ID masks
    correspondence_block = UNet()
    correspondence_block = correspondence_block.to(device)

    # custom loss function and optimizer
    criterion_x = nn.CrossEntropyLoss()
    criterion_y = nn.CrossEntropyLoss()
    criterion_z = nn.CrossEntropyLoss()

    # specify optimizer
    optimizer = optim.Adam(correspondence_block.parameters(), lr=3e-4, weight_decay=3e-5)

    # training loop
    val_loss_min = np.Inf
    save_path = model_save_path(cls)
    writer = SummaryWriter(save_path.parent / save_path.stem / datetime.now().strftime("%d%H%M"))

    for epoch in range(epochs):
        t0 = time.time()
        train_loss = 0
        val_loss = 0
        print("------ Epoch ", epoch, " ---------")
        correspondence_block.train()
        print("training")
        for iter, (rgb, xmask, ymask, zmask, adr_rgb) in enumerate(train_loader):

            rgb = rgb.to(device)
            xmask = xmask.to(device)
            ymask = ymask.to(device)
            zmask = zmask.to(device)

            optimizer.zero_grad()
            xmask_pred, ymask_pred, zmask_pred = correspondence_block(rgb)

            loss_x = criterion_x(xmask_pred, xmask)
            loss_y = criterion_y(ymask_pred, ymask)
            loss_z = criterion_z(zmask_pred, zmask)

            loss = loss_x + loss_y + loss_z

            loss.backward()
            optimizer.step()
            train_loss += loss.item()

            if iter % iter_print == 0:
                print(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.
                        format(epoch, iter * len(rgb), len(train_loader.dataset),
                               100. * iter / len(train_loader), loss.item()))

        correspondence_block.eval()

        print("validating")
        for rgb, xmask, ymask, zmask, _ in val_loader:
            rgb = rgb.to(device)
            xmask = xmask.to(device)
            ymask = ymask.to(device)
            zmask = zmask.to(device)

            xmask_pred, ymask_pred, zmask_pred = correspondence_block(rgb)

            loss_x = criterion_x(xmask_pred, xmask)
            loss_y = criterion_y(ymask_pred, ymask)
            loss_z = criterion_z(zmask_pred, zmask)

            loss = loss_x + loss_y + loss_z
            val_loss += loss.item()

        # calculate average losses
        train_loss = train_loss / len(train_loader.sampler)
        val_loss = val_loss / len(val_loader.sampler)
        t_end = time.time()
        print(f'{t_end - t0} seconds')
        writer.add_scalar('train loss', train_loss, epoch)
        writer.add_scalar('val loss', val_loss, epoch)
        writer.add_scalar('epoch time', t_end - t0, epoch)

        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch, train_loss, val_loss))

        # save model if validation loss has decreased
        if val_loss <= val_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                val_loss_min,
                val_loss))
            if save_model:
                torch.save(correspondence_block.state_dict(), save_path)
            val_loss_min = val_loss
    writer.close()
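Rather than duplicating the dataset, the example above draws training and validation batches from one dataset through two SubsetRandomSamplers over disjoint index ranges. A minimal sketch of that split, assuming any map-style dataset; `dataset`, `batch_size`, `val_ratio`, and `seed` are illustrative parameters.

import numpy as np
from torch.utils.data import DataLoader, SubsetRandomSampler

def split_loaders(dataset, batch_size, val_ratio=0.2, num_workers=4, seed=0):
    """Build train/val loaders that share one dataset via index samplers."""
    indices = np.random.default_rng(seed).permutation(len(dataset))
    split = int(np.floor(len(dataset) * val_ratio))
    train_sampler = SubsetRandomSampler(indices[split:].tolist())
    valid_sampler = SubsetRandomSampler(indices[:split].tolist())
    train_loader = DataLoader(dataset, batch_size=batch_size,
                              sampler=train_sampler, num_workers=num_workers)
    val_loader = DataLoader(dataset, batch_size=batch_size,
                            sampler=valid_sampler, num_workers=num_workers)
    return train_loader, val_loader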
def train_model(args):
    writer = SummaryWriter()

    transforms = DetectionTransform(output_size=args.resize,
                                    greyscale=True,
                                    normalize=True)

    dataset = FBSDetectionDataset(database_path=args.db,
                                  data_path=args.images,
                                  greyscale=True,
                                  transforms=transforms,
                                  categories_filter={'person': True},
                                  area_filter=[100**2, 500**2])
    dataset.print_categories()
    ''' 
    Split dataset into train and validation
    '''
    train_len = int(0.75 * len(dataset))
    dataset_lens = [train_len, len(dataset) - train_len]
    print("Splitting dataset into pieces: ", dataset_lens)
    datasets = torch.utils.data.random_split(dataset, dataset_lens)
    print(datasets)
    '''
    Setup the data loader objects (collation, batching)
    '''
    loader = torch.utils.data.DataLoader(collate_fn=collate_detection_samples,
                                         dataset=datasets[0],
                                         batch_size=args.batch_size,
                                         pin_memory=True,
                                         num_workers=args.num_data_workers)

    validation_loader = torch.utils.data.DataLoader(
        dataset=datasets[1],
        batch_size=args.batch_size,
        pin_memory=True,
        collate_fn=collate_detection_samples,
        num_workers=args.num_data_workers)
    '''
    Select device (cpu/gpu)
    '''
    device = torch.device(args.device)
    '''
    Create the model and transfer weights to device
    '''
    model = ObjectDetection(input_image_shape=args.resize,
                            pos_threshold=args.pos_anchor_iou,
                            neg_threshold=args.neg_anchor_iou,
                            num_classes=len(dataset.categories),
                            predict_conf_threshold=0.5).to(device)
    '''
    Select optimizer
    '''
    optim = torch.optim.SGD(params=model.parameters(),
                            lr=args.lr,
                            momentum=0.5)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optim,
                                                   step_size=2,
                                                   gamma=0.1,
                                                   last_epoch=-1)
    '''
    Outer training loop
    '''
    for epoch in range(1, args.epochs + 1):
        '''
        Inner training loop
        '''
        print("\n BEGINNING TRAINING STEP EPOCH {}".format(epoch))
        cummulative_loss = 0.0
        start_time = time.time()

        batch: ObjectDetectionBatch
        for idx, batch in enumerate(loader):
            '''
            Reset gradient
            '''
            optim.zero_grad()
            '''
            Push the data to the gpu (if necessary)
            '''
            batch.to(device)
            batch.debug = True if idx % args.log_interval == 0 else False
            '''
            Run the model
            '''
            losses, model_data = model(batch)
            cummulative_loss += losses["class_loss"].item()
            '''
            Calc gradient and step optimizer.
            '''
            losses['class_loss'].backward()
            optim.step()
            '''
            Log Metrics and Visualizations
            '''
            if (idx + 1) % args.log_interval == 0:
                step = (epoch - 1) * len(loader) + idx + 1

                print(
                    "Ep {} Training Step {} Batch {}/{} Loss : {:.3f}".format(
                        epoch, step, idx, len(loader), cummulative_loss))
                '''
                Save visualizations and metrics with tensorboard
                Note: For research, to reproduce graphs you will want some way to save the collected metrics (e.g. the loss values)
                to an array for recreating figures for a paper. To do so, metrics are often wrapped in a "metering" class
                that takes care of logging to tensorboard, resetting cumulative metrics, saving arrays, etc.
                '''
                '''
                training_image - the raw training images with box labels

                training_image_predicted_anchors - predictions for the same image, using basic thresholding (0.7 confidence on the logit)
                
                training_image_predicted_post_nms - predictions for the same image, filtered at 0.7 confidence followed by Non-Max-Suppression

                training_image_positive_anchors - shows anchors which received a positive label in the labeling step in the model
                '''
                sample_image = normalize_tensor(batch.images[0])
                writer.add_image_with_boxes("training_image",
                                            sample_image,
                                            box_tensor=batch.boxes[0],
                                            global_step=step)
                writer.add_image_with_boxes(
                    "training_image_predicted_anchors",
                    sample_image,
                    model_data["pos_predicted_anchors"][0],
                    global_step=step)

                keep_ind = nms(model_data["pos_predicted_anchors"][0],
                               model_data["pos_predicted_confidence"][0],
                               iou_threshold=args.nms_iou)

                writer.add_image_with_boxes(
                    "training_image_predicted_post_nms",
                    sample_image,
                    model_data["pos_predicted_anchors"][0][keep_ind],
                    global_step=step)
                writer.add_image_with_boxes(
                    "training_image_positive_anchors",
                    sample_image,
                    box_tensor=model_data["pos_labeled_anchors"][0],
                    global_step=step)
                '''
                Scalars - batch_time, training loss
                '''
                writer.add_scalar(
                    "batch_time",
                    ((time.time() - start_time) / float(args.log_interval)) *
                    1000.0,
                    global_step=step)
                writer.add_scalar("training_loss",
                                  losses['class_loss'].item(),
                                  global_step=step)

                writer.add_scalar(
                    "avg_pos_labeled_anchor_conf",
                    torch.tensor([
                        c.mean() for c in model_data["pos_labeled_confidence"]
                    ]).mean().item(),
                    global_step=step)

                start_time = time.time()

                writer.close()
            '''
            Reset metric meters as necessary
            '''
            if idx % args.metric_interval == 0:
                cummulative_loss = 0.0
        '''
        Inner validation loop
        '''
        print("\nBEGINNING VALIDATION STEP {}\n".format(epoch))
        with torch.no_grad():
            batch: ObjectDetectionBatch
            for idx, batch in enumerate(validation_loader):
                '''
                Push the data to the gpu (if necessary)
                '''
                batch.to(device)
                batch.debug = True if idx % args.log_interval == 0 else False
                '''
                Run the model
                '''
                losses, model_data = model(batch)

                if idx % args.log_interval == 0:
                    step = (epoch - 1) * len(validation_loader) + idx + 1

                    print("Ep {} Validation Step {} Batch {}/{} Loss : {:.3f}".
                          format(epoch, step, idx, len(validation_loader),
                                 losses["class_loss"].item()))
                    '''
                    Log Images
                    '''
                    sample_image = normalize_tensor(batch.images[0])
                    writer.add_image_with_boxes("validation_images",
                                                sample_image,
                                                box_tensor=batch.boxes[0],
                                                global_step=step)

                    writer.add_image_with_boxes(
                        "validation_img_predicted_anchors",
                        sample_image,
                        model_data["pos_predicted_anchors"][0],
                        global_step=step)

                    keep_ind = nms(model_data["pos_predicted_anchors"][0],
                                   model_data["pos_predicted_confidence"][0],
                                   iou_threshold=0.5)
                    print("Indices after NMS: ", keep_ind,
                          model_data["pos_predicted_confidence"][0].shape,
                          model_data["pos_predicted_anchors"][0].shape)

                    writer.add_image_with_boxes(
                        "validation_img_predicted_post_nms",
                        sample_image,
                        model_data["pos_predicted_anchors"][0][keep_ind],
                        global_step=step)
                    '''
                    Log Scalars
                    '''
                    writer.add_scalar("validation_loss",
                                      losses['class_loss'].item(),
                                      global_step=step)
                    writer.close()

        lr_scheduler.step()
        print("Stepped learning rate. Rate is now: ", lr_scheduler.get_lr())
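The comment in the logging block above suggests wrapping metrics in a "metering" class that handles TensorBoard logging, running averages, resets, and keeping the raw values for later figures. A minimal sketch of such a meter follows; the class name ScalarMeter and its constructor arguments are illustrative, not part of the snippet above.

class ScalarMeter:
    """Accumulates a scalar, logs its running mean, and resets on demand."""

    def __init__(self, writer, tag):
        self.writer = writer
        self.tag = tag
        self.total = 0.0
        self.count = 0
        self.history = []  # kept so figures can be recreated later

    def update(self, value):
        self.total += value
        self.count += 1

    def log(self, step):
        mean = self.total / max(self.count, 1)
        self.writer.add_scalar(self.tag, mean, global_step=step)
        self.history.append((step, mean))
        return mean

    def reset(self):
        self.total, self.count = 0.0, 0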
Example #23
0
    def train_net(self,
                  epochs=5,
                  batch_size=1,
                  lr=0.001,
                  val_percent=0.1,
                  save_cp=True,
                  img_scale=0.5,
                  dir_checkpoint='checkpoints/'):
        """Runs training on the data based on the given parameters

        Args:
            epochs (int, optional): Number of epochs to run the model. Defaults to 5.
            batch_size (int, optional): Batch size used by the data loaders. Defaults to 1.
            lr (float, optional): Learning rate for stepping. Defaults to 0.001.
            val_percent (float, optional): Percentage of data to be taken for validation. Defaults to 0.1.
            save_cp (bool, optional): Save the weights or not. Defaults to True.
            img_scale (float, optional): Scale percentage of the original image to use. Defaults to 0.5.
            dir_checkpoint (str, optional): path to save the trained weights. Defaults to 'checkpoints/'.

        Returns:
            float: best validation score recorded during training
        """

        device = self.device
        net = self.net
        mode = self.mode

        # Randomly determines the training and validation dataset
        file_list = [
            os.path.splitext(file)[0] for file in os.listdir(self.dir_img)
            if not file.startswith('.')
        ]
        random.shuffle(file_list)
        n_val = int(len(file_list) * val_percent)
        n_train = len(file_list) - n_val
        train_list = file_list[:n_train]
        val_list = file_list[n_train:]
        dataset_train = BasicDataset(train_list, self.dir_img, self.dir_mask,
                                     epochs, img_scale, 'train', mode)
        dataset_val = BasicDataset(val_list, self.dir_img, self.dir_mask,
                                   epochs, img_scale, 'val', mode)
        train_loader = DataLoader(dataset_train,
                                  batch_size=batch_size,
                                  shuffle=False,
                                  num_workers=8,
                                  pin_memory=True)
        val_loader = DataLoader(dataset_val,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=8,
                                pin_memory=True,
                                drop_last=True)

        # Tensorboard initialization
        writer = SummaryWriter(
            comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
        global_step = 0
        val_score_list = []
        logging.info(f'''Starting training:
            Epochs:          {epochs}
            Batch size:      {batch_size}
            Learning rate:   {lr}
            Training size:   {n_train}
            Validation size: {n_val}
            Checkpoints:     {save_cp}
            Device:          {self.device.type}
            Images scaling:  {img_scale}
        ''')

        # Gradient descent method
        optimizer = optim.RMSprop(net.parameters(),
                                  lr=lr,
                                  weight_decay=1e-8,
                                  momentum=0.9)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
        if net.n_classes > 1:
            criterion = nn.CrossEntropyLoss()
        else:
            criterion = nn.BCEWithLogitsLoss()

        for epoch in range(epochs):
            net.train()
            epoch_loss = 0

            # Progress bar shown on the terminal
            with tqdm(total=n_train,
                      desc=f'Epoch {epoch + 1}/{epochs}',
                      unit='img') as pbar:
                for batch in train_loader:
                    imgs = batch['image']
                    true_masks = batch['mask']
                    assert imgs.shape[1] == net.n_channels, \
                        f'Network has been defined with {net.n_channels} input channels, ' \
                        f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                        'the images are loaded correctly.'

                    imgs = imgs.to(device=device, dtype=torch.float32)
                    mask_type = torch.float32 if net.n_classes == 1 else torch.long
                    true_masks = true_masks.to(device=device, dtype=mask_type)

                    masks_pred = net(imgs)
                    loss = criterion(masks_pred, true_masks)
                    epoch_loss += loss.item()
                    writer.add_scalar('Loss/train', loss.item(), global_step)

                    pbar.set_postfix(**{'loss (batch)': loss.item()})

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_value_(net.parameters(), 0.1)
                    optimizer.step()

                    pbar.update(imgs.shape[0])
                    global_step += 1

                    # Validation phase
                    if global_step % (n_train // (10 * batch_size)) == 0:
                        for tag, value in net.named_parameters():
                            tag = tag.replace('.', '/')
                            writer.add_histogram('weights/' + tag,
                                                 value.data.cpu().numpy(),
                                                 global_step)
                            writer.add_histogram('grads/' + tag,
                                                 value.grad.data.cpu().numpy(),
                                                 global_step)
                        val_score = eval_net(net, val_loader, device)
                        val_score_list.append(val_score)
                        scheduler.step(val_score)
                        writer.add_scalar('learning_rate',
                                          optimizer.param_groups[0]['lr'],
                                          global_step)

                        if net.n_classes > 1:
                            logging.info('Validation cross entropy: {}'.format(
                                val_score))
                            writer.add_scalar('Loss/test', val_score,
                                              global_step)
                        else:
                            logging.info(
                                'Validation Dice Coeff: {}'.format(val_score))
                            writer.add_scalar('Dice/test', val_score,
                                              global_step)
                        # If temporal, the images can't be added to Tensorboard
                        if mode != 'temporal' and mode != 'temporal_augmentation':
                            writer.add_images('images', imgs, global_step)
                        if net.n_classes == 1:
                            writer.add_images('masks/true', true_masks,
                                              global_step)
                            writer.add_images('masks/pred',
                                              torch.sigmoid(masks_pred) > 0.5,
                                              global_step)

            if save_cp:  # save the trained weights
                try:
                    os.mkdir(dir_checkpoint)
                    logging.info('Created checkpoint directory')
                except OSError:
                    pass
                torch.save(net.state_dict(),
                           dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
                logging.info(f'Checkpoint {epoch + 1} saved !')

        writer.close()
        return max(val_score_list)
def main():
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    # IXI dataset as a demo, downloadable from https://brain-development.org/ixi-dataset/
    images = [
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI314-IOP-0889-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI249-Guys-1072-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI609-HH-2600-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI173-HH-1590-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI020-Guys-0700-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI342-Guys-0909-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI134-Guys-0780-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI577-HH-2661-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI066-Guys-0731-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI130-HH-1528-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI607-Guys-1097-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI175-HH-1570-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI385-HH-2078-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI344-Guys-0905-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI409-Guys-0960-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI584-Guys-1129-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI253-HH-1694-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI092-HH-1436-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI574-IOP-1156-T1.nii.gz"
        ]),
        os.sep.join([
            "workspace", "data", "medical", "ixi", "IXI-T1",
            "IXI585-Guys-1130-T1.nii.gz"
        ]),
    ]

    # 2 binary labels for gender classification: man and woman
    labels = np.array(
        [0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0],
        dtype=np.int64)
    train_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[:10], labels[:10])]
    val_files = [{
        "img": img,
        "label": label
    } for img, label in zip(images[-10:], labels[-10:])]

    # Define transforms for image
    train_transforms = Compose([
        LoadImaged(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        RandRotate90d(keys=["img"], prob=0.8, spatial_axes=[0, 2]),
        ToTensord(keys=["img"]),
    ])
    val_transforms = Compose([
        LoadImaged(keys=["img"]),
        AddChanneld(keys=["img"]),
        ScaleIntensityd(keys=["img"]),
        Resized(keys=["img"], spatial_size=(96, 96, 96)),
        ToTensord(keys=["img"]),
    ])

    # Define dataset, data loader
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds,
                              batch_size=2,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["label"])

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(train_ds,
                              batch_size=2,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=torch.cuda.is_available())

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds,
                            batch_size=2,
                            num_workers=4,
                            pin_memory=torch.cuda.is_available())

    # Create DenseNet121, CrossEntropyLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.densenet.densenet121(spatial_dims=3,
                                                     in_channels=1,
                                                     out_channels=2).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)

    # start a typical PyTorch training
    val_interval = 2
    best_metric = -1
    best_metric_epoch = -1
    writer = SummaryWriter()
    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(
                device), batch_data["label"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = len(train_ds) // train_loader.batch_size
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
            writer.add_scalar("train_loss", loss.item(),
                              epoch_len * epoch + step)
        epoch_loss /= step
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(
                        device), val_data["label"].to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)

                acc_value = torch.eq(y_pred.argmax(dim=1), y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                auc_metric = compute_roc_auc(y_pred,
                                             y,
                                             to_onehot_y=True,
                                             softmax=True)
                if acc_metric > best_metric:
                    best_metric = acc_metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(),
                               "best_metric_model_classification3d_dict.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current accuracy: {:.4f} current AUC: {:.4f} best accuracy: {:.4f} at epoch {}"
                    .format(epoch + 1, acc_metric, auc_metric, best_metric,
                            best_metric_epoch))
                writer.add_scalar("val_accuracy", acc_metric, epoch + 1)
    print(
        f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}"
    )
    writer.close()
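The validation loop above accumulates logits and labels across all batches before computing accuracy, so the metric reflects the whole validation set rather than a per-batch average. A small sketch of that reduction in isolation; the tensors below are illustrative stand-ins for the concatenated y_pred and y.

import torch

# Stand-ins for logits and labels gathered over the whole validation set.
y_pred = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [0.1, 0.4]])
y = torch.tensor([0, 1, 1, 1])

correct = torch.eq(y_pred.argmax(dim=1), y)
accuracy = correct.sum().item() / len(correct)
print(accuracy)  # 0.75 for these stand-in values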
Example #25
0
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    record_result = []

    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    inputs_list = []
    global_step = 0
    epochs_trained = 0
    collect_step = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        except ValueError:
            global_step = 0
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            if collect_step < 5:
                batch = tuple(t.to(args.device) for t in batch)
                inputs = {
                    "input_ids": batch[0],
                    "attention_mask": batch[1],
                    "labels": batch[3]
                }

                if args.model_type != "distilbert":
                    inputs["token_type_ids"] = (
                        batch[2] if args.model_type
                        in ["bert", "xlnet", "albert"] else None
                    )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids

                inputs_list.append(inputs)
                collect_step += 1
                continue
            else:
                print('collect_step', collect_step)
                print('start pruning')

                new_mask = GraSP(model,
                                 args.tt,
                                 inputs_list,
                                 args.device,
                                 original_mask=None)

                rate = 0
                sum1 = 0
                for key in new_mask.keys():
                    rate += float(torch.sum(new_mask[key] == 0))
                    sum1 += float(new_mask[key].nelement())

                print('zero rate = ', rate / sum1)

                torch.save(new_mask, 'grasp_mask2/' + args.task_name + '.pt')

                # The function exits here after saving the pruning mask, so the
                # regular training steps below in this loop are never reached.
                return 0, 0

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2]
                    if args.model_type in ["bert", "xlnet", "albert"] else None
                )  # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    len(epoch_iterator) <= args.gradient_accumulation_steps and
                (step + 1) == len(epoch_iterator)):
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if (
                            args.local_rank == -1
                            and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        record_result.append(results)
                        for key, value in results.items():
                            eval_key = "eval_{}".format(key)
                            logs[eval_key] = value

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    print(json.dumps({**logs, **{"step": global_step}}))

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)
                    torch.save(model, os.path.join(output_dir, "model.pt"))

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    results = evaluate(args, model, tokenizer)
    record_result.append(results)
    torch.save(record_result, os.path.join(args.output_dir, "result.pt"))

    return global_step, tr_loss / global_step
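Example #25 stops after computing a GraSP pruning mask and saving it under 'grasp_mask2/'; how the mask is consumed afterwards is not shown. The sketch below is one plausible way to zero out the pruned weights, assuming the mask dictionary is keyed by parameter name and its values are 0/1 tensors shaped like the corresponding parameters.

# Hedged sketch (an assumption, not part of the original code): apply a saved
# pruning mask by zeroing the masked weights in place.
import torch

def apply_pruning_mask(model, mask_path, device="cpu"):
    mask = torch.load(mask_path, map_location=device)
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in mask:
                param.mul_(mask[name].to(param.device).to(param.dtype))
    return model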
Example #26
def _main(rank, world_size, args, savepath, logger):

    if rank == 0:
        logger.info(args)
        logger.info(f"Saving to {savepath}")
        tb_writer = SummaryWriter(os.path.join(savepath, "tb_logdir"))

    device = torch.device(
        f'cuda:{rank:d}' if torch.cuda.is_available() else 'cpu')

    if rank == 0:
        if device.type == 'cuda':
            logger.info('Found {} CUDA devices.'.format(
                torch.cuda.device_count()))
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                logger.info('{} \t Memory: {:.2f}GB'.format(
                    props.name, props.total_memory / (1024**3)))
        else:
            logger.info('WARNING: Using device {}'.format(device))

    t0, t1 = map(lambda x: cast(x, device), get_t0_t1(args.data))

    train_set = load_data(args.data, split="train")
    val_set = load_data(args.data, split="val")
    test_set = load_data(args.data, split="test")

    train_epoch_iter = EpochBatchIterator(
        dataset=train_set,
        collate_fn=datasets.spatiotemporal_events_collate_fn,
        batch_sampler=train_set.batch_by_size(args.max_events),
        seed=args.seed + rank,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=args.test_bsz,
        shuffle=False,
        collate_fn=datasets.spatiotemporal_events_collate_fn,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.test_bsz,
        shuffle=False,
        collate_fn=datasets.spatiotemporal_events_collate_fn,
    )

    if rank == 0:
        logger.info(
            f"{len(train_set)} training examples, {len(val_set)} val examples, {len(test_set)} test examples"
        )

    x_dim = get_dim(args.data)

    if args.model == "jumpcnf" and args.tpp == "neural":
        model = JumpCNFSpatiotemporalModel(
            dim=x_dim,
            hidden_dims=list(map(int, args.hdims.split("-"))),
            tpp_hidden_dims=list(map(int, args.tpp_hdims.split("-"))),
            actfn=args.actfn,
            tpp_cond=args.tpp_cond,
            tpp_style=args.tpp_style,
            tpp_actfn=args.tpp_actfn,
            share_hidden=args.share_hidden,
            solve_reverse=args.solve_reverse,
            tol=args.tol,
            otreg_strength=args.otreg_strength,
            tpp_otreg_strength=args.tpp_otreg_strength,
            layer_type=args.layer_type,
        ).to(device)
    elif args.model == "attncnf" and args.tpp == "neural":
        model = SelfAttentiveCNFSpatiotemporalModel(
            dim=x_dim,
            hidden_dims=list(map(int, args.hdims.split("-"))),
            tpp_hidden_dims=list(map(int, args.tpp_hdims.split("-"))),
            actfn=args.actfn,
            tpp_cond=args.tpp_cond,
            tpp_style=args.tpp_style,
            tpp_actfn=args.tpp_actfn,
            share_hidden=args.share_hidden,
            solve_reverse=args.solve_reverse,
            l2_attn=args.l2_attn,
            tol=args.tol,
            otreg_strength=args.otreg_strength,
            tpp_otreg_strength=args.tpp_otreg_strength,
            layer_type=args.layer_type,
            lowvar_trace=not args.naive_hutch,
        ).to(device)
    elif args.model == "cond_gmm" and args.tpp == "neural":
        model = JumpGMMSpatiotemporalModel(
            dim=x_dim,
            hidden_dims=list(map(int, args.hdims.split("-"))),
            tpp_hidden_dims=list(map(int, args.tpp_hdims.split("-"))),
            actfn=args.actfn,
            tpp_cond=args.tpp_cond,
            tpp_style=args.tpp_style,
            tpp_actfn=args.tpp_actfn,
            share_hidden=args.share_hidden,
            tol=args.tol,
            tpp_otreg_strength=args.tpp_otreg_strength,
        ).to(device)
    else:
        # Mix and match between spatial and temporal models.
        if args.tpp == "poisson":
            tpp_model = HomogeneousPoissonPointProcess()
        elif args.tpp == "hawkes":
            tpp_model = HawkesPointProcess()
        elif args.tpp == "correcting":
            tpp_model = SelfCorrectingPointProcess()
        elif args.tpp == "neural":
            tpp_hidden_dims = list(map(int, args.tpp_hdims.split("-")))
            tpp_model = NeuralPointProcess(
                cond_dim=x_dim,
                hidden_dims=tpp_hidden_dims,
                cond=args.tpp_cond,
                style=args.tpp_style,
                actfn=args.tpp_actfn,
                otreg_strength=args.tpp_otreg_strength,
                tol=args.tol)
        else:
            raise ValueError(f"Invalid tpp model {args.tpp}")

        if args.model == "gmm":
            model = CombinedSpatiotemporalModel(GaussianMixtureSpatialModel(),
                                                tpp_model).to(device)
        elif args.model == "cnf":
            model = CombinedSpatiotemporalModel(
                IndependentCNF(dim=x_dim,
                               hidden_dims=list(map(int,
                                                    args.hdims.split("-"))),
                               layer_type=args.layer_type,
                               actfn=args.actfn,
                               tol=args.tol,
                               otreg_strength=args.otreg_strength,
                               squash_time=True), tpp_model).to(device)
        elif args.model == "tvcnf":
            model = CombinedSpatiotemporalModel(
                IndependentCNF(dim=x_dim,
                               hidden_dims=list(map(int,
                                                    args.hdims.split("-"))),
                               layer_type=args.layer_type,
                               actfn=args.actfn,
                               tol=args.tol,
                               otreg_strength=args.otreg_strength),
                tpp_model).to(device)
        elif args.model == "jumpcnf":
            model = CombinedSpatiotemporalModel(
                JumpCNF(dim=x_dim,
                        hidden_dims=list(map(int, args.hdims.split("-"))),
                        layer_type=args.layer_type,
                        actfn=args.actfn,
                        tol=args.tol,
                        otreg_strength=args.otreg_strength),
                tpp_model).to(device)
        elif args.model == "attncnf":
            model = CombinedSpatiotemporalModel(
                SelfAttentiveCNF(dim=x_dim,
                                 hidden_dims=list(
                                     map(int, args.hdims.split("-"))),
                                 layer_type=args.layer_type,
                                 actfn=args.actfn,
                                 l2_attn=args.l2_attn,
                                 tol=args.tol,
                                 otreg_strength=args.otreg_strength),
                tpp_model).to(device)
        else:
            raise ValueError(f"Invalid model {args.model}")

    params = []
    attn_params = []
    for name, p in model.named_parameters():
        if "self_attns" in name:
            attn_params.append(p)
        else:
            params.append(p)

    optimizer = torch.optim.AdamW([{
        "params": params
    }, {
        "params": attn_params
    }],
                                  lr=args.lr,
                                  weight_decay=args.weight_decay,
                                  betas=(0.9, 0.98))

    if rank == 0:
        ema = utils.ExponentialMovingAverage(model)

    model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    if rank == 0:
        logger.info(model)

    begin_itr = 0
    checkpt_path = os.path.join(savepath, "model.pth")
    if os.path.exists(checkpt_path):
        # Restart from checkpoint if run is a restart.
        if rank == 0:
            logger.info(f"Resuming checkpoint from {checkpt_path}")
        checkpt = torch.load(checkpt_path, "cpu")
        model.module.load_state_dict(checkpt["state_dict"])
        optimizer.load_state_dict(checkpt["optim_state_dict"])
        begin_itr = checkpt["itr"] + 1

    elif args.resume:
        # Check the resume flag if run is new.
        if rank == 0:
            logger.info(f"Resuming model from {args.resume}")
        checkpt = torch.load(args.resume, "cpu")
        model.module.load_state_dict(checkpt["state_dict"])
        optimizer.load_state_dict(checkpt["optim_state_dict"])
        begin_itr = checkpt["itr"] + 1

    space_loglik_meter = utils.RunningAverageMeter(0.98)
    time_loglik_meter = utils.RunningAverageMeter(0.98)
    gradnorm_meter = utils.RunningAverageMeter(0.98)

    model.train()
    start_time = time.time()
    iteration_counter = itertools.count(begin_itr)
    begin_epoch = begin_itr // len(train_epoch_iter)
    for epoch in range(begin_epoch,
                       math.ceil(args.num_iterations / len(train_epoch_iter))):
        batch_iter = train_epoch_iter.next_epoch_itr(shuffle=True)
        for batch in batch_iter:
            itr = next(iteration_counter)

            optimizer.zero_grad()

            event_times, spatial_locations, input_mask = map(
                lambda x: cast(x, device), batch)
            N, T = input_mask.shape
            num_events = input_mask.sum()

            if num_events == 0:
                raise RuntimeError("Got batch with no observations.")

            space_loglik, time_loglik = model(event_times, spatial_locations,
                                              input_mask, t0, t1)

            space_loglik = space_loglik.sum() / num_events
            time_loglik = time_loglik.sum() / num_events
            loglik = time_loglik + space_loglik

            space_loglik_meter.update(space_loglik.item())
            time_loglik_meter.update(time_loglik.item())

            loss = loglik.mul(-1.0).mean()
            loss.backward()

            # Set learning rate
            total_itrs = math.ceil(
                args.num_iterations /
                len(train_epoch_iter)) * len(train_epoch_iter)
            lr = learning_rate_schedule(itr, args.warmup_itrs, args.lr,
                                        total_itrs)
            set_learning_rate(optimizer, lr)

            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=args.gradclip).item()
            gradnorm_meter.update(grad_norm)

            optimizer.step()

            if rank == 0:
                if itr > 0.8 * args.num_iterations:
                    ema.apply()
                else:
                    ema.apply(decay=0.0)

            if rank == 0:
                tb_writer.add_scalar("train/lr", lr, itr)
                tb_writer.add_scalar("train/temporal_loss", time_loglik.item(),
                                     itr)
                tb_writer.add_scalar("train/spatial_loss", space_loglik.item(),
                                     itr)
                tb_writer.add_scalar("train/grad_norm", grad_norm, itr)

            if itr % args.logfreq == 0:
                elapsed_time = time.time() - start_time

                # Average NFE across devices.
                nfe = 0
                for m in model.modules():
                    if isinstance(m, TimeVariableCNF) or isinstance(
                            m, TimeVariableODE):
                        nfe += m.nfe
                nfe = torch.tensor(nfe).to(device)
                dist.all_reduce(nfe, op=dist.ReduceOp.SUM)
                nfe = nfe // world_size

                # Sum memory usage across devices.
                mem = torch.tensor(memory_usage_psutil()).float().to(device)
                dist.all_reduce(mem, op=dist.ReduceOp.SUM)

                if rank == 0:
                    logger.info(
                        f"Iter {itr} | Epoch {epoch} | LR {lr:.5f} | Time {elapsed_time:.1f}"
                        f" | Temporal {time_loglik_meter.val:.4f}({time_loglik_meter.avg:.4f})"
                        f" | Spatial {space_loglik_meter.val:.4f}({space_loglik_meter.avg:.4f})"
                        f" | GradNorm {gradnorm_meter.val:.2f}({gradnorm_meter.avg:.2f})"
                        f" | NFE {nfe.item()}"
                        f" | Mem {mem.item():.2f} MB")

                    tb_writer.add_scalar("train/nfe", nfe, itr)
                    tb_writer.add_scalar("train/time_per_itr",
                                         elapsed_time / args.logfreq, itr)

                start_time = time.time()

            if rank == 0 and itr % args.testfreq == 0:
                # ema.swap()
                val_space_loglik, val_time_loglik = validate(
                    model, val_loader, t0, t1, device)
                test_space_loglik, test_time_loglik = validate(
                    model, test_loader, t0, t1, device)
                # ema.swap()
                logger.info(
                    f"[Test] Iter {itr} | Val Temporal {val_time_loglik:.4f} | Val Spatial {val_space_loglik:.4f}"
                    f" | Test Temporal {test_time_loglik:.4f} | Test Spatial {test_space_loglik:.4f}"
                )

                tb_writer.add_scalar("val/temporal_loss", val_time_loglik, itr)
                tb_writer.add_scalar("val/spatial_loss", val_space_loglik, itr)

                tb_writer.add_scalar("test/temporal_loss", test_time_loglik,
                                     itr)
                tb_writer.add_scalar("test/spatial_loss", test_space_loglik,
                                     itr)

                torch.save(
                    {
                        "itr": itr,
                        "state_dict": model.module.state_dict(),
                        "optim_state_dict": optimizer.state_dict(),
                        "ema_parmas": ema.ema_params,
                    }, checkpt_path)

                start_time = time.time()

    if rank == 0:
        tb_writer.close()
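Example #26 drives the learning rate through two helpers, learning_rate_schedule and set_learning_rate, whose definitions are not included in the example. The sketch below shows what such helpers commonly look like (linear warmup followed by cosine decay); the project's actual schedule may differ.

# Hedged sketch of the missing helpers, assuming linear warmup for
# `warmup_itrs` steps and cosine decay to zero over the remaining iterations.
import math

def learning_rate_schedule(itr, warmup_itrs, base_lr, total_itrs):
    if warmup_itrs > 0 and itr < warmup_itrs:
        return base_lr * (itr + 1) / warmup_itrs
    progress = (itr - warmup_itrs) / max(1, total_itrs - warmup_itrs)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * min(progress, 1.0)))

def set_learning_rate(optimizer, lr):
    # Overwrite the learning rate of every parameter group in place.
    for group in optimizer.param_groups:
        group["lr"] = lr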
Example #27
                                     epoch)

                loss.backward()
                optimizer.step()

                print("Batch: {}/{}".format(index, len(generator)))

    if epoch % 5 == 0:
        print("∞" * 20)
        print("TEST " * 20)
        for index, batch in enumerate(generator_test):
            model.eval()
            input = batch['input']
            input = input.to(device)

            output = model(input)

            diff = torch.abs(input - output)

            output = output.detach()

            writer.add_image('Test/input_images',
                             torchvision.utils.make_grid(input), epoch)
            writer.add_image('Test/reconstructed_images',
                             torchvision.utils.make_grid(output), epoch)
            writer.add_image('Test/differences',
                             torchvision.utils.make_grid(diff), epoch)
            break

writer.close()
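Example #27 is only a fragment; the visible part logs input images, reconstructions, and their absolute differences to TensorBoard every fifth epoch. A self-contained sketch of that logging pattern is shown below; the writer and tensors are stand-ins, not the original model or data.

# Minimal sketch of the image-logging pattern from the fragment above.
# `writer` is a torch.utils.tensorboard SummaryWriter; tensors are (N, C, H, W).
import torch
import torchvision

def log_reconstructions(writer, inputs, outputs, epoch, tag="Test"):
    inputs = inputs.detach()
    outputs = outputs.detach()
    diff = torch.abs(inputs - outputs)
    writer.add_image(f"{tag}/input_images",
                     torchvision.utils.make_grid(inputs), epoch)
    writer.add_image(f"{tag}/reconstructed_images",
                     torchvision.utils.make_grid(outputs), epoch)
    writer.add_image(f"{tag}/differences",
                     torchvision.utils.make_grid(diff), epoch)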
Example #28
def main():
    trainset, validset, testset = [], [], []
    if args.inference:  # only the test set is needed for inference
        with open(args.testset_path, 'r', encoding='utf8') as fr:
            for line in fr:
                testset.append(json.loads(line))
        print(f'Loaded {len(testset)} test examples')
    else:  # load the training and validation sets for training
        with open(args.trainset_path, 'r', encoding='utf8') as fr:
            for line in fr:
                trainset.append(json.loads(line))
        print(f'Loaded {len(trainset)} training examples')
        with open(args.validset_path, 'r', encoding='utf8') as fr:
            for line in fr:
                validset.append(json.loads(line))
        print(f'Loaded {len(validset)} validation examples')

    vocab, embeds = [], []
    with open(args.embed_path, 'r', encoding='utf8') as fr:
        for line in fr:
            line = line.strip()
            word = line[:line.find(' ')]
            vec = line[line.find(' ') + 1:].split()
            embed = [float(v) for v in vec]
            assert len(embed) == config.embedding_size  # check the word-vector dimension
            vocab.append(word)
            embeds.append(embed)
    print(f'Loaded vocabulary: {len(vocab)} words')
    print(f'Word-vector dimension: {config.embedding_size}')

    vads = []
    with open(args.vad_path, 'r', encoding='utf8') as fr:
        for line in fr:
            line = line.strip()
            vad = line[line.find(' ') + 1:].split()
            vad = [float(item) for item in vad]
            assert len(vad) == config.affect_embedding_size
            vads.append(vad)
    print(f'Loaded VAD lexicon: {len(vads)} entries')
    print(f'VAD dimension: {config.affect_embedding_size}')

    # Build word2index and index2word utilities from the vocabulary
    sentence_processor = SentenceProcessor(vocab, config.pad_id,
                                           config.start_id, config.end_id,
                                           config.unk_id)

    model = Model(config)
    model.print_parameters()  # print the number of model parameters
    epoch = 0  # number of passes over the training set
    global_step = 0  # number of parameter updates

    # Load the model
    if os.path.isfile(args.model_path):  # load the model if the checkpoint exists
        epoch, global_step = model.load_model(args.model_path)
        model.affect_embedding.embedding.weight.requires_grad = False
        print('Model loaded')
        log_dir = os.path.split(args.model_path)[0]
    elif args.inference:  # inference without a trained checkpoint makes no sense
        print('Please test a trained model!')
        return
    else:  # no checkpoint: train from scratch with the pretrained word vectors
        model.embedding.embedding.weight = torch.nn.Parameter(
            torch.FloatTensor(embeds))
        model.affect_embedding.embedding.weight = torch.nn.Parameter(
            torch.tensor(vads).float())
        model.affect_embedding.embedding.weight.requires_grad = False
        print('Model initialized')
        log_dir = os.path.join(args.log_path, 'run' + str(int(time.time())))
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
    if args.gpu:
        model.to('cuda')  # move the model parameters to the GPU

    # Configure the optimizer
    optim = Optim(config.method, config.lr, config.lr_decay,
                  config.weight_decay, config.eps, config.max_grad_norm)
    optim.set_parameters(model.parameters())  # hand the parameters to the optimizer
    optim.update_lr(epoch)  # update the learning rate every epoch

    # Training
    if not args.inference:
        summary_writer = SummaryWriter(os.path.join(
            log_dir, 'summary'))  # directory for tensorboard logs
        dp_train = DataProcessor(trainset, config.batch_size,
                                 sentence_processor)  # data iterator
        dp_valid = DataProcessor(validset,
                                 config.batch_size,
                                 sentence_processor,
                                 shuffle=False)

        while epoch < args.max_epoch:  # maximum number of training epochs
            model.train()  # switch to training mode
            for data in dp_train.get_batch_data():
                start_time = time.time()
                feed_data = prepare_feed_data(data)
                rl_loss, reward, nll_loss, ppl = train(model, feed_data)
                if args.reinforce:
                    rl_loss.mean().backward()
                else:
                    nll_loss.mean().backward()  # backpropagation
                optim.step()  # update the parameters
                optim.optimizer.zero_grad()  # clear the gradients
                use_time = time.time() - start_time

                global_step += 1  # one more parameter update
                if global_step % args.print_per_step == 0:
                    print(
                        'epoch: {:d}, global_step: {:d}, lr: {:g},  reward: {:.2f}, nll_loss: {:.2f}, ppl: {:.2f},'
                        ' time: {:.2f}s/step'.format(epoch, global_step,
                                                     optim.lr,
                                                     reward.mean().item(),
                                                     nll_loss.mean().item(),
                                                     ppl.mean().exp().item(),
                                                     use_time))
                    summary_writer.add_scalar('train_reward',
                                              reward.mean().item(),
                                              global_step)
                    summary_writer.add_scalar('train_nll',
                                              nll_loss.mean().item(),
                                              global_step)
                    summary_writer.add_scalar('train_ppl',
                                              ppl.mean().exp().item(),
                                              global_step)
                    summary_writer.flush()  # flush the buffer to disk

                if global_step % args.log_per_step == 0:  # save the model
                    log_file = os.path.join(
                        log_dir,
                        '{:03d}{:012d}.model'.format(epoch, global_step))
                    model.save_model(epoch, global_step, log_file)
                    model.eval()
                    reward, nll_loss, ppl = valid(model, dp_valid)
                    model.train()
                    print(
                        'Validation REWARD: {:g}, NLL loss: {:g}, PPL: {:g}'.format(
                            reward, nll_loss, np.exp(ppl)))
                    summary_writer.add_scalar('valid_reward', reward,
                                              global_step)
                    summary_writer.add_scalar('valid_nll', nll_loss,
                                              global_step)
                    summary_writer.add_scalar('valid_ppl', np.exp(ppl),
                                              global_step)
                    summary_writer.flush()  # flush the buffer to disk

            epoch += 1  # one more pass over the dataset
            optim.update_lr(epoch)  # adjust the learning rate

            log_file = os.path.join(
                log_dir, '{:03d}{:012d}.model'.format(epoch, global_step))
            model.save_model(epoch, global_step, log_file)
            model.eval()
            reward, nll_loss, ppl = valid(model, dp_valid)
            print('Validation REWARD: {:g}, NLL loss: {:g}, PPL: {:g}'.format(
                reward, nll_loss, np.exp(ppl)))
            summary_writer.add_scalar('valid_reward', reward, global_step)
            summary_writer.add_scalar('valid_nll', nll_loss, global_step)
            summary_writer.add_scalar('valid_ppl', np.exp(ppl), global_step)
            summary_writer.flush()  # flush the buffer to disk

        summary_writer.close()
    else:  # inference
        if not os.path.exists(args.result_path):  # create the results directory
            os.makedirs(args.result_path)

        result_file = os.path.join(args.result_path,
                                   '{:03d}{:012d}.txt'.format(
                                       epoch, global_step))  # name the result file
        fw = open(result_file, 'w', encoding='utf8')
        dp_test = DataProcessor(testset,
                                config.batch_size,
                                sentence_processor,
                                shuffle=False)

        model.eval()
        reward, nll_loss, ppl = valid(model, dp_test)  # evaluate perplexity
        print('Test REWARD: {:g}, NLL loss: {:g}, PPL: {:g}'.format(
            reward, nll_loss, np.exp(ppl)))

        len_results = []  # track the lengths of the generated results
        for data in dp_test.get_batch_data():
            posts = data['str_posts']
            responses = data['str_responses']
            feed_data = prepare_feed_data(data, inference=True)
            results = test(model, feed_data)  # run the model to get results [batch, len_decoder]

            for idx, result in enumerate(results):
                new_data = dict()
                new_data['post'] = posts[idx]
                new_data['response'] = responses[idx]
                new_data['result'] = sentence_processor.index2word(
                    result)  # convert the output sentence back to words
                len_results.append(len(new_data['result']))
                fw.write(json.dumps(new_data, ensure_ascii=False) + '\n')

        fw.close()
        print(f'Average generated sentence length: {1.0 * sum(len_results) / len(len_results)}')
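Example #28 maps between words and indices through a SentenceProcessor built from the vocabulary, but its implementation is not part of the example. The class below is a minimal sketch under the same constructor signature, purely for illustration; the real SentenceProcessor may handle padding and special tokens differently.

# Hedged sketch of a word2index / index2word utility with the constructor
# signature used above; an illustration only, not the project's class.
class SentenceProcessor:
    def __init__(self, vocab, pad_id, start_id, end_id, unk_id):
        self.pad_id = pad_id
        self.start_id = start_id
        self.end_id = end_id
        self.unk_id = unk_id
        self.vocab = list(vocab)
        self.word2index_map = {w: i for i, w in enumerate(self.vocab)}

    def word2index(self, words):
        # Unknown words fall back to the UNK index.
        return [self.word2index_map.get(w, self.unk_id) for w in words]

    def index2word(self, indices):
        # Stop at end-of-sentence; skip pad/start and out-of-range ids.
        words = []
        for idx in indices:
            if idx == self.end_id:
                break
            if idx in (self.pad_id, self.start_id) or idx >= len(self.vocab):
                continue
            words.append(self.vocab[idx])
        return words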
Example #29
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()
        log_writer = open(os.path.join(args.output_dir, "evaluate_logs.txt"),
                          'w')

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            args.weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    logger.info("  Logging steps = %d", args.logging_steps)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    # (the resume branch below is disabled by the "and False" guard)
    if os.path.exists(args.model_name_or_path) and False:
        # set global_step to global_step of last saved checkpoint from model path
        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // (len(train_dataloader) //
                                         args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (
            len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch",
                    steps_trained_in_current_epoch)

    tr_loss, logging_loss, best_avg = 0.0, 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained,
                            int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility

    def logging():
        results = {}  # avoid an unbound name when evaluation is skipped below
        if args.evaluate_during_training:
            results = evaluate(args, model, tokenizer, single_gpu=True)
            for task, result in results.items():
                for key, value in result.items():
                    tb_writer.add_scalar("eval_{}_{}".format(task, key), value,
                                         global_step)
            log_writer.write("{0}\t{1}\n".format(global_step,
                                                 json.dumps(results)))
            log_writer.flush()
        tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
        tb_writer.add_scalar("loss",
                             (tr_loss - logging_loss) / args.logging_steps,
                             global_step)
        return results

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "labels": batch[3]
            }
            if args.model_type != "distilbert":
                inputs["token_type_ids"] = (
                    batch[2] if args.model_type in ["bert"] else None
                )  # XLM and DistilBERT don't use segment_ids
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    cur_result = logging()
                    logging_loss = tr_loss

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.local_rank in [-1, 0] and args.logging_each_epoch:

            cur_result = logging()

            logging_loss = tr_loss
            task_metric = "acc"
            if args.task_name == "rel":
                task_metric = "ndcg"
            if best_avg < cur_result["valid_avg"][task_metric]:
                best_avg = cur_result["valid_avg"][task_metric]
                output_dir = os.path.join(args.output_dir, "checkpoint-best")
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training

                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)

                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()
        log_writer.close()

    return global_step, tr_loss / (global_step + 1)
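Example #29 appends one tab-separated line per logging step to evaluate_logs.txt: the global step followed by the JSON-encoded evaluation results. A small sketch for reading that file back, for example to plot metric curves, follows.

# Sketch: parse the "<global_step>\t<json results>" lines written by the
# logging() helper above into (step, results) pairs.
import json

def read_evaluate_logs(path):
    records = []
    with open(path, "r", encoding="utf8") as fr:
        for line in fr:
            line = line.rstrip("\n")
            if not line:
                continue
            step_str, results_json = line.split("\t", 1)
            records.append((int(step_str), json.loads(results_json)))
    return records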
Example #30
def train_model(train_data_words, test_data_words, model, epochs=30):
    log_file = os.path.join(
        LOGS_DIR, f'{model.__class__.__name__}.{str(train_data_words)}')
    checkpoint_file = f'{CHECKPOINT_PREFIX}.{model.__class__.__name__}.{str(train_data_words)}'

    model = model.cuda()

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

    if os.path.exists(checkpoint_file):
        print('Loading checkpoint')
        epoch, best_score, vocabulary = load_train_state(
            checkpoint_file, model, optimizer, scheduler)
    else:
        epoch = 0
        best_score = -1
        vocabulary = create_vocabulary(train_data_words,
                                       vocabulary_size=VOCABULARY_SIZE)

    best_model = copy.deepcopy(model)

    train_data = WordIndexDataset(train_data_words,
                                  vocabulary,
                                  max_words=MAX_MESSAGE_LENGTH_WORDS)
    test_data = WordIndexDataset(test_data_words,
                                 vocabulary,
                                 max_words=MAX_MESSAGE_LENGTH_WORDS)
    train_loader = DataLoader(train_data,
                              batch_size=TRAIN_BATCH_SIZE,
                              shuffle=True,
                              num_workers=2,
                              collate_fn=IndexVectorCollator())
    test_loader = DataLoader(test_data,
                             batch_size=TEST_BATCH_SIZE,
                             shuffle=True,
                             num_workers=2,
                             collate_fn=IndexVectorCollator())

    writer = SummaryWriter(log_file, purge_step=epoch, flush_secs=60)

    sample_input, sample_lens, _ = next(iter(train_loader))
    summary(model=model,
            input_data=sample_input.cuda(),
            lens=sample_lens,
            device=torch.device('cuda'))

    print("Learning started")

    while epoch < epochs:
        epoch += 1
        print(f"Epoch: {epoch}")
        epoch_losses = []
        epoch_accuracy = []
        model.train()

        loss_fn = FocalLoss(alpha=0.5, gamma=2)

        for step, (x, x_len, y) in enumerate(train_loader):
            x, y = x.cuda(), y.cuda()
            y_pred = model(x, x_len)
            loss_val = loss_fn(y_pred, y)
            accuracy = torch.argmax(y_pred, 1).eq(y).sum().item() / y.shape[0]

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()
            epoch_losses.append(loss_val.item())
            epoch_accuracy.append(accuracy)
            print('    Batch {} of {} loss: {}, accuracy: {}, lr: {}'.format(
                step + 1, len(train_loader), loss_val.item(), accuracy,
                optimizer.param_groups[0]["lr"]),
                  file=sys.stderr)
        print(
            f'Train loss: {np.mean(epoch_losses):.4f}, accuracy: {np.mean(epoch_accuracy):.4f}'
        )
        writer.add_scalar('Loss/train',
                          np.mean(epoch_losses),
                          global_step=epoch)
        writer.add_scalar('Accuracy/train',
                          np.mean(epoch_accuracy),
                          global_step=epoch)
        writer.add_scalar('LearningRate',
                          optimizer.param_groups[0]["lr"],
                          global_step=epoch)

        score = evaluate(model,
                         test_loader,
                         loss_fn,
                         writer=writer,
                         epoch=epoch)
        if score > best_score:
            best_model = copy.deepcopy(model)
            best_score = score
            print('New best score')
            save_train_state(epoch, model, optimizer, scheduler, best_score,
                             vocabulary, checkpoint_file)
        scheduler.step()
    if best_score < 0:
        best_score = evaluate(model, test_loader, writer=writer)

    writer.close()
    save_file_path = os.path.join(
        SAVED_MODELS_PATH,
        '{}.{}.{}.{:.2f}.pck'.format(model.__class__.__name__,
                                     str(train_data_words),
                                     datetime.datetime.now().isoformat(),
                                     best_score))
    log_file_path = os.path.join(
        LOGS_DIR, '{}.{}.{}.{:.2f}'.format(model.__class__.__name__,
                                           str(train_data_words),
                                           datetime.datetime.now().isoformat(),
                                           best_score))
    os.makedirs(os.path.dirname(save_file_path), exist_ok=True)
    shutil.move(checkpoint_file, save_file_path)
    shutil.move(log_file, log_file_path)

    return best_model, best_score
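Example #30 resumes and saves training progress through save_train_state and load_train_state, whose definitions are not shown. The sketch below captures only the state those call sites imply (epoch, best score, vocabulary, and the model/optimizer/scheduler state dicts); the project's real helpers may store more.

# Hedged sketch of the checkpoint helpers implied by the calls above.
# Matches the call sites: save_train_state(epoch, model, optimizer, scheduler,
# best_score, vocabulary, path) and
# load_train_state(path, model, optimizer, scheduler) -> (epoch, best_score, vocabulary).
import torch

def save_train_state(epoch, model, optimizer, scheduler, best_score,
                     vocabulary, path):
    torch.save({
        "epoch": epoch,
        "best_score": best_score,
        "vocabulary": vocabulary,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
    }, path)

def load_train_state(path, model, optimizer, scheduler):
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model_state"])
    optimizer.load_state_dict(state["optimizer_state"])
    scheduler.load_state_dict(state["scheduler_state"])
    return state["epoch"], state["best_score"], state["vocabulary"]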