示例#1
0
    def __init__(self, model: Model, config: dict) -> None:
        """
        Creates a new TrainManager for a model, specified as in configuration.

        Reads the "training" section of ``config`` (plus the encoder hidden
        size from the "model" section and the segmentation level from the
        "data" section), sets up logging, the loss, optimizer and scheduler,
        initializes the training statistics and, if ``training.load_model``
        is given, restores state from that checkpoint.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        :raises ConfigurationError: if "normalization", "eval_metric",
            "early_stopping_metric" or the data "level" has an invalid value
        """
        train_config = config["training"]

        # files for logging and storing
        self.model_dir = make_model_dir(train_config["model_dir"],
                                        overwrite=train_config.get(
                                            "overwrite", False))
        self.logger = make_logger("{}/train.log".format(self.model_dir))
        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = "{}/validations.txt".format(self.model_dir)
        self.tb_writer = SummaryWriter(log_dir=self.model_dir +
                                       "/tensorboard/")

        # model
        self.model = model
        self.pad_index = self.model.pad_index
        self.bos_index = self.model.bos_index
        self._log_parameters_list()

        # objective
        self.label_smoothing = train_config.get("label_smoothing", 0.0)
        self.loss = XentLoss(pad_index=self.pad_index,
                             smoothing=self.label_smoothing)
        self.normalization = train_config.get("normalization", "batch")
        if self.normalization not in ["batch", "tokens", "none"]:
            # note the trailing space: the two literals are concatenated
            raise ConfigurationError("Invalid normalization option. "
                                     "Valid options: "
                                     "'batch', 'tokens', 'none'.")

        # optimization
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)

        self.clip_grad_fun = build_gradient_clipper(config=train_config)
        self.optimizer = build_optimizer(config=train_config,
                                         parameters=model.parameters())

        # validation & early stopping
        self.validation_freq = train_config.get("validation_freq", 1000)
        self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
        # bounded queue of checkpoint paths (at most keep_last_ckpts entries)
        self.ckpt_queue = queue.Queue(
            maxsize=train_config.get("keep_last_ckpts", 5))
        self.eval_metric = train_config.get("eval_metric", "bleu")
        if self.eval_metric not in [
                'bleu', 'chrf', 'token_accuracy', 'sequence_accuracy'
        ]:
            raise ConfigurationError("Invalid setting for 'eval_metric', "
                                     "valid options: 'bleu', 'chrf', "
                                     "'token_accuracy', 'sequence_accuracy'.")
        self.early_stopping_metric = train_config.get("early_stopping_metric",
                                                      "eval_metric")

        # early_stopping_metric decides on how to find the early stopping point:
        # ckpts are written when there's a new high/low score for this metric.
        # If we schedule after BLEU/chrf/accuracy, we want to maximize the
        # score, else we want to minimize it.
        if self.early_stopping_metric in ["ppl", "loss"]:
            self.minimize_metric = True
        elif self.early_stopping_metric == "eval_metric":
            if self.eval_metric in [
                    "bleu", "chrf", "token_accuracy", "sequence_accuracy"
            ]:
                self.minimize_metric = False
            # eval metric that has to get minimized (not yet implemented;
            # currently unreachable because eval_metric is validated above)
            else:
                self.minimize_metric = True
        else:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', 'eval_metric'.")

        # learning rate scheduling (its mode follows the early stopping
        # direction: "min" for loss/ppl, "max" for score-like metrics)
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer,
            hidden_size=config["model"]["encoder"]["hidden_size"])

        # data & batch handling
        self.level = config["data"]["level"]
        if self.level not in ["word", "bpe", "char"]:
            raise ConfigurationError("Invalid segmentation level. "
                                     "Valid options: 'word', 'bpe', 'char'.")
        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        self.batch_type = train_config.get("batch_type", "sentence")
        # evaluation batching falls back to the training batching settings
        self.eval_batch_size = train_config.get("eval_batch_size",
                                                self.batch_size)
        self.eval_batch_type = train_config.get("eval_batch_type",
                                                self.batch_type)

        self.batch_multiplier = train_config.get("batch_multiplier", 1)
        self.current_batch_multiplier = self.batch_multiplier

        # generation
        self.max_output_length = train_config.get("max_output_length", None)

        # CPU / GPU
        self.use_cuda = train_config["use_cuda"]
        if self.use_cuda:
            self.model.cuda()
            self.loss.cuda()

        # initialize accumulated batch loss (needed for batch_multiplier)
        self.norm_batch_loss_accumulated = 0
        # initialize training statistics
        self.steps = 0
        # stop training if this flag is True by reaching learning rate minimum
        self.stop = False
        self.total_tokens = 0
        self.best_ckpt_iteration = 0
        # initial values for best scores
        self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf
        # comparison function for scores; the whole conditional expression is
        # the lambda body, so the direction is chosen on every call
        self.is_best = lambda score: score < self.best_ckpt_score \
            if self.minimize_metric else score > self.best_ckpt_score

        # model parameters
        if "load_model" in train_config.keys():
            model_load_path = train_config["load_model"]
            self.logger.info("Loading model from %s", model_load_path)
            reset_best_ckpt = train_config.get("reset_best_ckpt", False)
            reset_scheduler = train_config.get("reset_scheduler", False)
            reset_optimizer = train_config.get("reset_optimizer", False)
            self.init_from_checkpoint(model_load_path,
                                      reset_best_ckpt=reset_best_ckpt,
                                      reset_scheduler=reset_scheduler,
                                      reset_optimizer=reset_optimizer)
示例#2
0
    def __init__(self, model: Model, config: dict) -> None:
        """
        Creates a new TrainManager for a model, specified as in configuration.

        Builds the training loss (cross entropy, possibly label-smoothed, or
        one of the entmax/sparsemax losses), an optional auxiliary language
        loss, the optimizer and scheduler, validates evaluation / early
        stopping / attention metrics, and optionally restores a checkpoint
        given by ``training.load_model``.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        :raises ConfigurationError: on an unknown loss, normalization,
            eval metric, early stopping metric or segmentation level
        """
        train_config = config["training"]

        # files for logging and storing
        self.model_dir = make_model_dir(train_config["model_dir"],
                                        overwrite=train_config.get(
                                            "overwrite", False))
        self.logger = make_logger(model_dir=self.model_dir)
        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = join(self.model_dir, "validations.txt")
        self.tb_writer = SummaryWriter(
            log_dir=join(self.model_dir, "tensorboard/")
        )
        self.log_sparsity = train_config.get("log_sparsity", False)

        self.apply_mask = train_config.get("apply_mask", False)
        self.valid_apply_mask = train_config.get("valid_apply_mask", True)

        # model
        self.model = model
        self.pad_index = self.model.pad_index
        self.bos_index = self.model.bos_index
        self._log_parameters_list()

        # objective
        objective = train_config.get("loss", "cross_entropy")
        loss_alpha = train_config.get("loss_alpha", 1.5)
        self.label_smoothing = train_config.get("label_smoothing", 0.0)
        # label smoothing is only applied to the cross entropy objective
        if self.label_smoothing > 0 and objective == "cross_entropy":
            xent_loss = partial(
                LabelSmoothingLoss, smoothing=self.label_smoothing)
        else:
            xent_loss = nn.CrossEntropyLoss

        assert loss_alpha >= 1, "loss_alpha must be >= 1"
        entmax_loss = partial(
            EntmaxBisectLoss, alpha=loss_alpha, n_iter=30
        )

        loss_funcs = {"cross_entropy": xent_loss,
                      "entmax15": partial(Entmax15Loss, k=512),
                      "sparsemax": partial(SparsemaxLoss, k=512),
                      "entmax": entmax_loss}
        if objective not in loss_funcs:
            raise ConfigurationError("Unknown loss function")
        loss_func = loss_funcs[objective]
        self.loss = loss_func(ignore_index=self.pad_index, reduction='sum')

        # optional auxiliary loss, built with the same loss class as self.loss
        if "language_loss" in train_config:
            assert "language_weight" in train_config, \
                "'language_loss' requires 'language_weight'"
            self.language_loss = loss_func(
                ignore_index=self.pad_index, reduction='sum'
            )
            self.language_weight = train_config["language_weight"]
        else:
            self.language_loss = None
            self.language_weight = 0.0

        self.norm_type = train_config.get("normalization", "batch")
        if self.norm_type not in ["batch", "tokens"]:
            raise ConfigurationError("Invalid normalization. "
                                     "Valid options: 'batch', 'tokens'.")

        # optimization
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)

        self.clip_grad_fun = build_gradient_clipper(config=train_config)
        self.optimizer = build_optimizer(
            config=train_config, parameters=model.parameters())

        # validation & early stopping
        self.validation_freq = train_config.get("validation_freq", 1000)
        self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
        self.plot_attention = train_config.get("plot_attention", False)
        self.ckpt_queue = queue.Queue(
            maxsize=train_config.get("keep_last_ckpts", 5))

        # "eval_metric" may be a single name or a list of names
        allowed = {'bleu', 'chrf', 'token_accuracy',
                   'sequence_accuracy', 'cer', 'wer'}
        eval_metrics = train_config.get("eval_metric", "bleu")
        if isinstance(eval_metrics, str):
            eval_metrics = [eval_metrics]
        if any(metric not in allowed for metric in eval_metrics):
            # sorted so the message is stable across runs (set order is not)
            ok_metrics = " ".join(sorted(allowed))
            raise ConfigurationError("Invalid setting for 'eval_metric', "
                                     "valid options: {}".format(ok_metrics))
        self.eval_metrics = eval_metrics

        # early stopping may use loss/ppl or any of the configured metrics;
        # error-rate style metrics (cer, wer) are minimized like loss/ppl
        early_stop_metric = train_config.get("early_stopping_metric", "loss")
        allowed_early_stop = {"ppl", "loss"} | set(self.eval_metrics)
        if early_stop_metric not in allowed_early_stop:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', and eval_metrics.")
        self.early_stopping_metric = early_stop_metric
        self.minimize_metric = early_stop_metric in {"ppl", "loss",
                                                     "cer", "wer"}

        attn_metrics = train_config.get("attn_metric", [])
        if isinstance(attn_metrics, str):
            attn_metrics = [attn_metrics]
        ok_attn_metrics = {"support"}
        assert all(met in ok_attn_metrics for met in attn_metrics), \
            "unsupported 'attn_metric'"
        self.attn_metrics = attn_metrics

        # learning rate scheduling; the hidden size comes from the single
        # encoder config, or from the "src" encoder of a multi-encoder model
        if "encoder" in config["model"]:
            hidden_size = config["model"]["encoder"]["hidden_size"]
        else:
            hidden_size = config["model"]["encoders"]["src"]["hidden_size"]
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer,
            hidden_size=hidden_size)

        # data & batch handling: per-side levels fall back to "level",
        # then to "word"
        data_cfg = config["data"]
        self.src_level = data_cfg.get(
            "src_level", data_cfg.get("level", "word")
        )
        self.trg_level = data_cfg.get(
            "trg_level", data_cfg.get("level", "word")
        )
        levels = ["word", "bpe", "char"]
        if self.src_level not in levels or self.trg_level not in levels:
            raise ConfigurationError("Invalid segmentation level. "
                                     "Valid options: 'word', 'bpe', 'char'.")

        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        self.batch_type = train_config.get("batch_type", "sentence")
        self.eval_batch_size = train_config.get("eval_batch_size",
                                                self.batch_size)
        self.eval_batch_type = train_config.get("eval_batch_type",
                                                self.batch_type)

        self.batch_multiplier = train_config.get("batch_multiplier", 1)

        # generation
        self.max_output_length = train_config.get("max_output_length", None)

        # CPU / GPU
        self.use_cuda = train_config["use_cuda"]
        if self.use_cuda:
            self.model.cuda()
            self.loss.cuda()
            # the auxiliary loss is the same kind of module as self.loss, so
            # it must live on the same device
            if self.language_loss is not None:
                self.language_loss.cuda()

        # initialize training statistics
        self.steps = 0
        # stop training if this flag is True by reaching learning rate minimum
        self.stop = False
        self.total_tokens = 0
        self.best_ckpt_iteration = 0
        # initial values for best scores
        self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf

        # model parameters
        if "load_model" in train_config.keys():
            model_load_path = train_config["load_model"]
            self.logger.info("Loading model from %s", model_load_path)
            restart_training = train_config.get("restart_training", False)
            self.init_from_checkpoint(model_load_path, restart_training)
示例#3
0
    def __init__(self,
                 model: Model,
                 config: dict,
                 batch_class: Batch = Batch) -> None:
        """
        Creates a new TrainManager for a model, specified as in configuration.

        Sets the model's loss function, builds optimizer and scheduler, reads
        evaluation options from the optional "testing" section, places the
        model on the selected device, optionally enables apex fp16 training,
        restores a checkpoint given by ``training.load_model`` and wraps the
        model for multi-GPU training when more than one GPU is used.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        :param batch_class: batch class to encapsulate the torch class
        :raises ConfigurationError: on an invalid normalization, eval metric,
            early stopping metric or segmentation level
        :raises ImportError: if fp16 is requested but apex is not loaded
        """
        train_config = config["training"]
        self.batch_class = batch_class

        # files for logging and storing (the directory must already exist)
        self.model_dir = train_config["model_dir"]
        assert os.path.exists(self.model_dir)

        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = "{}/validations.txt".format(self.model_dir)
        self.tb_writer = SummaryWriter(log_dir=self.model_dir +
                                       "/tensorboard/")

        self.save_latest_checkpoint = train_config.get("save_latest_ckpt",
                                                       True)

        # model
        self.model = model
        self._log_parameters_list()

        # objective
        self.label_smoothing = train_config.get("label_smoothing", 0.0)
        self.model.loss_function = XentLoss(pad_index=self.model.pad_index,
                                            smoothing=self.label_smoothing)
        self.normalization = train_config.get("normalization", "batch")
        if self.normalization not in ["batch", "tokens", "none"]:
            # note the trailing space: the two literals are concatenated
            raise ConfigurationError("Invalid normalization option. "
                                     "Valid options: "
                                     "'batch', 'tokens', 'none'.")

        # optimization
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)

        self.clip_grad_fun = build_gradient_clipper(config=train_config)
        self.optimizer = build_optimizer(config=train_config,
                                         parameters=model.parameters())

        # validation & early stopping
        self.validation_freq = train_config.get("validation_freq", 1000)
        self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
        # bounded deque: appending beyond maxlen drops the oldest entry
        self.ckpt_queue = collections.deque(
            maxlen=train_config.get("keep_last_ckpts", 5))
        self.eval_metric = train_config.get("eval_metric", "bleu")
        if self.eval_metric not in [
                'bleu', 'chrf', 'token_accuracy', 'sequence_accuracy'
        ]:
            raise ConfigurationError("Invalid setting for 'eval_metric', "
                                     "valid options: 'bleu', 'chrf', "
                                     "'token_accuracy', 'sequence_accuracy'.")
        self.early_stopping_metric = train_config.get("early_stopping_metric",
                                                      "eval_metric")

        # early_stopping_metric decides on how to find the early stopping point:
        # ckpts are written when there's a new high/low score for this metric.
        # If we schedule after BLEU/chrf/accuracy, we want to maximize the
        # score, else we want to minimize it.
        if self.early_stopping_metric in ["ppl", "loss"]:
            self.minimize_metric = True
        elif self.early_stopping_metric == "eval_metric":
            if self.eval_metric in [
                    "bleu", "chrf", "token_accuracy", "sequence_accuracy"
            ]:
                self.minimize_metric = False
            # eval metric that has to get minimized (not yet implemented;
            # currently unreachable because eval_metric is validated above)
            else:
                self.minimize_metric = True
        else:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', 'eval_metric'.")

        # eval options; a missing/empty "testing" section or "sacrebleu"
        # entry falls back to the defaults instead of raising
        test_config = config.get("testing") or {}
        self.bpe_type = test_config.get("bpe_type", "subword-nmt")
        sacrebleu_config = test_config.get("sacrebleu") or {}
        self.sacrebleu = {
            "remove_whitespace": sacrebleu_config.get("remove_whitespace",
                                                      True),
            "tokenize": sacrebleu_config.get("tokenize", "13a"),
        }

        # learning rate scheduling
        self.scheduler, self.scheduler_step_at = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer,
            hidden_size=config["model"]["encoder"]["hidden_size"])

        # data & batch handling
        self.level = config["data"]["level"]
        if self.level not in ["word", "bpe", "char"]:
            raise ConfigurationError("Invalid segmentation level. "
                                     "Valid options: 'word', 'bpe', 'char'.")
        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        # Placeholder so that we can use the train_iter in other functions.
        self.train_iter = None
        self.train_iter_state = None
        # per-device batch_size = self.batch_size // self.n_gpu
        self.batch_type = train_config.get("batch_type", "sentence")
        self.eval_batch_size = train_config.get("eval_batch_size",
                                                self.batch_size)
        # per-device eval_batch_size = self.eval_batch_size // self.n_gpu
        self.eval_batch_type = train_config.get("eval_batch_type",
                                                self.batch_type)

        self.batch_multiplier = train_config.get("batch_multiplier", 1)

        # generation
        self.max_output_length = train_config.get("max_output_length", None)

        # CPU / GPU: CUDA is only used if requested AND available
        self.use_cuda = train_config["use_cuda"] and torch.cuda.is_available()
        self.n_gpu = torch.cuda.device_count() if self.use_cuda else 0
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        if self.use_cuda:
            self.model.to(self.device)

        # fp16
        self.fp16 = train_config.get("fp16", False)
        if self.fp16:
            if 'apex' not in sys.modules:
                # Deliberately not chained with `from <name>`: a name bound by
                # `except ImportError as <name>:` is deleted when that handler
                # ends, so referencing it here could raise a NameError that
                # masks this message.
                raise ImportError("Please install apex from "
                                  "https://www.github.com/nvidia/apex "
                                  "to use fp16 training.")
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level='O1')
            # opt level: one of {"O0", "O1", "O2", "O3"}
            # see https://nvidia.github.io/apex/amp.html#opt-levels

        # initialize training statistics
        self.stats = self.TrainStatistics(
            steps=0,
            stop=False,
            total_tokens=0,
            best_ckpt_iter=0,
            best_ckpt_score=np.inf if self.minimize_metric else -np.inf,
            minimize_metric=self.minimize_metric)

        # model parameters
        if "load_model" in train_config.keys():
            self.init_from_checkpoint(
                train_config["load_model"],
                reset_best_ckpt=train_config.get("reset_best_ckpt", False),
                reset_scheduler=train_config.get("reset_scheduler", False),
                reset_optimizer=train_config.get("reset_optimizer", False),
                reset_iter_state=train_config.get("reset_iter_state", False))

        # multi-gpu training (should be after apex fp16 initialization)
        if self.n_gpu > 1:
            self.model = _DataParallel(self.model)
示例#4
0
    def __init__(self, model: Model, config: dict) -> None:
        """
        Creates a new TrainManager for a model, specified as in configuration.

        Builds the training loss (cross entropy or an entmax-family loss,
        optionally wrapped in Fenchel-Young or Szegedy-style label smoothing),
        the optimizer and scheduler, validates evaluation / early stopping
        settings, reads the MRT (minimum risk training) options and
        optionally restores a checkpoint given by ``training.load_model``.

        :param model: torch module defining the model
        :param config: dictionary containing the training configurations
        :raises ConfigurationError: on an unknown loss, normalization,
            eval metric or early stopping metric
        """
        train_config = config["training"]

        # files for logging and storing
        self.model_dir = train_config["model_dir"]
        make_model_dir(
            self.model_dir, overwrite=train_config.get("overwrite", False)
        )
        self.logger = make_logger(model_dir=self.model_dir)
        self.logging_freq = train_config.get("logging_freq", 100)
        self.valid_report_file = join(self.model_dir, "validations.txt")
        self.tb_writer = SummaryWriter(
            log_dir=join(self.model_dir, "tensorboard/")
        )

        # model
        self.model = model
        self.pad_index = self.model.pad_index
        self._log_parameters_list()

        # objective
        objective = train_config.get("loss", "cross_entropy")
        loss_alpha = train_config.get("loss_alpha", 1.5)

        assert loss_alpha >= 1, "loss_alpha must be >= 1"
        # The base loss is looked up first; label smoothing (below) either
        # wraps it (FYLabelSmoothingLoss) or replaces it (LabelSmoothingLoss).
        if objective == "softmax":
            objective = "cross_entropy"
        loss_funcs = {
            "cross_entropy": nn.CrossEntropyLoss,
            "entmax15": partial(Entmax15Loss, k=512),
            "sparsemax": partial(SparsemaxLoss, k=512),
            "entmax": partial(EntmaxBisectLoss, alpha=loss_alpha, n_iter=30)
        }
        if objective not in loss_funcs:
            raise ConfigurationError("Unknown loss function")

        loss_module = loss_funcs[objective]
        loss_func = loss_module(ignore_index=self.pad_index, reduction='sum')

        label_smoothing = train_config.get("label_smoothing", 0.0)
        label_smoothing_type = train_config.get("label_smoothing_type", "fy")
        assert label_smoothing_type in ["fy", "szegedy"]
        smooth_dist = train_config.get("smoothing_distribution", "uniform")
        assert smooth_dist in ["uniform", "unigram"]
        if label_smoothing > 0:
            if label_smoothing_type == "fy":
                # label smoothing entmax loss. Only the "unigram" distribution
                # uses target-vocabulary frequencies; for "uniform" smooth_p
                # stays None (presumably FYLabelSmoothingLoss then smooths
                # uniformly — confirm). smooth_dist can never be None here
                # because of the assert above, so it must be compared by value.
                if smooth_dist == "unigram":
                    smooth_p = torch.FloatTensor(model.trg_vocab.frequencies)
                    smooth_p /= smooth_p.sum()
                else:
                    smooth_p = None
                loss_func = FYLabelSmoothingLoss(
                    loss_func, smoothing=label_smoothing, smooth_p=smooth_p
                )
            else:
                # Szegedy-style smoothing is only implemented for softmax
                assert objective == "cross_entropy"
                loss_func = LabelSmoothingLoss(
                    ignore_index=self.pad_index,
                    reduction="sum",
                    smoothing=label_smoothing
                )
        self.loss = loss_func

        self.norm_type = train_config.get("normalization", "batch")
        if self.norm_type not in ["batch", "tokens"]:
            raise ConfigurationError("Invalid normalization. "
                                     "Valid options: 'batch', 'tokens'.")

        # optimization
        self.learning_rate_min = train_config.get("learning_rate_min", 1.0e-8)

        self.clip_grad_fun = build_gradient_clipper(config=train_config)
        self.optimizer = build_optimizer(
            config=train_config, parameters=model.parameters())

        # validation & early stopping
        self.validate_by_label = train_config.get("validate_by_label", False)
        self.validation_freq = train_config.get("validation_freq", 1000)
        self.log_valid_sents = train_config.get("print_valid_sents", [0, 1, 2])
        self.plot_attention = train_config.get("plot_attention", False)
        self.ckpt_queue = queue.Queue(
            maxsize=train_config.get("keep_last_ckpts", 5))

        # "eval_metric" may be a single name or a list of names
        allowed = {'bleu', 'chrf', 'token_accuracy',
                   'sequence_accuracy', 'cer', "wer", "levenshtein_distance"}
        eval_metrics = train_config.get("eval_metric", "bleu")
        if isinstance(eval_metrics, str):
            eval_metrics = [eval_metrics]
        if any(metric not in allowed for metric in eval_metrics):
            # sorted so the message is stable across runs (set order is not)
            ok_metrics = " ".join(sorted(allowed))
            raise ConfigurationError("Invalid setting for 'eval_metric', "
                                     "valid options: {}".format(ok_metrics))
        self.eval_metrics = eval_metrics
        self.forced_sparsity = train_config.get("forced_sparsity", False)

        # early stopping may use loss/ppl or any configured metric;
        # error/distance metrics are minimized, the rest maximized
        early_stop_metric = train_config.get("early_stopping_metric", "loss")
        allowed_early_stop = {"ppl", "loss"} | set(self.eval_metrics)
        if early_stop_metric not in allowed_early_stop:
            raise ConfigurationError(
                "Invalid setting for 'early_stopping_metric', "
                "valid options: 'loss', 'ppl', and eval_metrics.")
        self.early_stopping_metric = early_stop_metric
        min_metrics = {"ppl", "loss", "cer", "wer", "levenshtein_distance"}
        self.minimize_metric = early_stop_metric in min_metrics

        # learning rate scheduling
        hidden_size = _parse_hidden_size(config["model"])
        self.scheduler, self.sched_incr = build_scheduler(
            config=train_config,
            scheduler_mode="min" if self.minimize_metric else "max",
            optimizer=self.optimizer,
            hidden_size=hidden_size)

        # data & batch handling
        # a shared "level" wins; otherwise both per-side levels are required
        if "level" in config["data"]:
            self.src_level = self.trg_level = config["data"]["level"]
        else:
            assert "src_level" in config["data"]
            assert "trg_level" in config["data"]
            self.src_level = config["data"]["src_level"]
            self.trg_level = config["data"]["trg_level"]

        self.shuffle = train_config.get("shuffle", True)
        self.epochs = train_config["epochs"]
        self.batch_size = train_config["batch_size"]
        self.batch_type = train_config.get("batch_type", "sentence")
        self.eval_batch_size = train_config.get("eval_batch_size",
                                                self.batch_size)
        self.eval_batch_type = train_config.get("eval_batch_type",
                                                self.batch_type)

        self.batch_multiplier = train_config.get("batch_multiplier", 1)

        # generation
        self.max_output_length = train_config.get("max_output_length", None)

        # CPU / GPU
        self.use_cuda = train_config["use_cuda"]
        if self.use_cuda:
            self.model.cuda()
            self.loss.cuda()

        # initialize training statistics
        self.steps = 0
        # stop training if this flag is True by reaching learning rate minimum
        self.stop = False
        self.total_tokens = 0
        self.best_ckpt_iteration = 0
        # initial values for best scores
        self.best_ckpt_score = np.inf if self.minimize_metric else -np.inf

        # MRT (minimum risk training) options; validated below
        mrt_schedule = train_config.get("mrt_schedule", None)
        assert mrt_schedule is None or mrt_schedule in ["warmup", "mix", "mtl"]
        self.mrt_schedule = mrt_schedule
        self.mrt_p = train_config.get("mrt_p", 0.0)
        self.mrt_lambda = train_config.get("mrt_lambda", 1.0)
        assert 0 <= self.mrt_p <= 1
        assert 0 <= self.mrt_lambda <= 1
        self.mrt_start_steps = train_config.get("mrt_start_steps", 0)
        self.mrt_samples = train_config.get("mrt_samples", 1)
        self.mrt_alpha = train_config.get("mrt_alpha", 1.0)
        self.mrt_strategy = train_config.get("mrt_strategy", "sample")
        self.mrt_cost = train_config.get("mrt_cost", "levenshtein")
        # NOTE(review): default of 31 is unexplained in the original — confirm
        self.mrt_max_len = train_config.get("mrt_max_len", 31)
        self.step_counter = count()

        assert self.mrt_alpha > 0
        assert self.mrt_strategy in ["sample", "topk"]
        assert self.mrt_cost in ["levenshtein", "bleu"]

        # model parameters
        if "load_model" in train_config.keys():
            model_load_path = train_config["load_model"]
            reset_training = train_config.get("reset_training", False)
            self.logger.info("Loading model from %s", model_load_path)
            self.init_from_checkpoint(model_load_path, reset=reset_training)