Example #1
0
    def __init__(self, context: PyTorchTrialContext) -> None:
        """Create the MNIST classifier and data module, then initialise the parent.

        Args:
            context: trial context supplying the ``learning_rate`` hyperparameter,
                the data config (``url`` key) and the distributed rank.
        """
        classifier = mnist.LightningMNISTClassifier(
            lr=context.get_hparam('learning_rate'),
        )

        # The rank is embedded in the path, so every rank uses its own directory.
        data_dir = f"/tmp/data-rank{context.distributed.get_rank()}"
        data_url = context.get_data_config()["url"]
        self.dm = mnist.MNISTDataModule(data_url, data_dir)

        super().__init__(context, lightning_module=classifier)
        self.dm.prepare_data()
Example #2
0
    def __init__(self, context: det_torch.PyTorchTrialContext) -> None:
        """Load the XNLI splits, run the parent init, and register the eval reducer.

        Args:
            context: trial context providing hyperparameters, the data config and
                the ``experimental.wrap_reducer`` hook.
        """
        self.logger = logging.getLogger(__name__)
        self.hparams = attrdict.AttrDict(context.get_hparams())
        self.data_config = attrdict.AttrDict(context.get_data_config())
        self.context = context

        # The datasets are loaded before the parent constructor runs because
        # num_labels has to be stored in hparams first.
        train_language = self.data_config.train_language
        if train_language is None:
            # No dedicated training language: reuse the language of the validation split.
            train_language = self.data_config.language
        train_split = datasets.load_dataset("xnli", train_language, split="train")
        validation_split = datasets.load_dataset(
            "xnli", self.data_config.language, split="validation"
        )

        self.raw_datasets = {"train": train_split, "validation": validation_split}
        self.hparams.num_labels = len(train_split.features["label"].names)

        super(XNLITrial, self).__init__(context)
        self.logger.info(self.config)

        # Tokenization happens after the parent init because it needs the model
        # and tokenizer to be available.
        self.tokenized_datasets = self.build_datasets()
        num_train_records = len(self.tokenized_datasets["train"])
        self.logger.info("training records: {}".format(num_train_records))
        if "records_per_epoch" in self.exp_config:
            if num_train_records != self.exp_config["records_per_epoch"]:
                self.logger.warning(
                    "number of train records {} does not match records_per_epoch of {}"
                    .format(num_train_records, self.exp_config["records_per_epoch"]))

        # Metric reducer used for validation only (for_training=False).
        metric = datasets.load_metric("xnli", timeout=200)

        def compute_metrics(pred_labels) -> Dict:
            # pred_labels is iterated as (predictions, labels) pairs.
            raw_preds, raw_labels = zip(*pred_labels)
            scores = utils.expand_like(raw_preds)
            references = utils.expand_like(raw_labels)
            predicted = np.argmax(scores, axis=1)
            return metric.compute(predictions=predicted, references=references)

        self.reducer = context.experimental.wrap_reducer(
            compute_metrics, for_training=False
        )
Example #3
0
    def __init__(self, context: PyTorchTrialContext) -> None:
        """Download the corpus and build the wrapped model, optimizer and LR scheduler.

        Args:
            context: trial context providing hyperparameters, the data config and
                the ``wrap_model`` / ``wrap_optimizer`` / ``wrap_lr_scheduler`` hooks.
        """
        self.context = context
        self.data_config = context.get_data_config()
        self.hparams = AttrDict(context.get_hparams())

        # NOTE(review): the directory is taken from the data config as-is; nothing
        # in this method makes it rank-specific -- confirm that is intended for
        # distributed training.
        download_dir = self.data_config["data_download_dir"]
        self.download_directory = download_dir
        data.download_data(download_dir)
        self.corpus = data_util.Corpus(download_dir)
        self.ntokens = len(self.corpus.dictionary)
        self.hidden = None

        # Bookkeeping for the eval history; per the original note it is used to
        # switch to ASGD once validation perplexity stops improving.
        self._last_loss = None
        self._eval_history = []
        self._last_epoch = -1

        # Model
        genotype = self.get_genotype_from_hps()
        hp = self.hparams
        rnn = RNNModel(
            self.ntokens,
            hp.emsize,
            hp.nhid,
            hp.nhidlast,
            hp.dropout,
            hp.dropouth,
            hp.dropoutx,
            hp.dropouti,
            hp.dropoute,
            genotype=genotype,
        )
        self.model = self.context.wrap_model(rnn)
        param_count = sum(p.data.nelement() for p in self.model.parameters())
        logging.info("Model total parameters: {}".format(param_count))

        # Optimizer
        sgd = HybridSGD(
            self.model.parameters(),
            hp.learning_rate,
            hp.weight_decay,
            lambd=0,
            t0=0,
        )
        self._optimizer = self.context.wrap_optimizer(sgd)

        # LR scheduler, wrapped with the MANUAL_STEP step mode.
        self.myLR = MyLR(self._optimizer, self.hparams)
        self.wrapped_LR = self.context.wrap_lr_scheduler(
            self.myLR,
            step_mode=LRScheduler.StepMode.MANUAL_STEP,
        )
Example #4
0
    def __init__(self, context: PyTorchTrialContext) -> None:
        """Read the episode settings and build the conv encoder, Adam and StepLR.

        Args:
            context: trial context providing hyperparameters, the data config and
                the model / optimizer / LR-scheduler wrapping hooks.
        """
        self.context = context
        self.data_config = context.get_data_config()

        # Per-split episode settings.
        self.num_classes = dict(
            train=context.get_hparam("num_classes_train"),
            val=context.get_hparam("num_classes_val"),
        )
        self.num_support = dict(
            train=context.get_hparam("num_support_train"),
            val=context.get_hparam("num_support_val"),
        )
        self.num_query = dict(
            train=context.get_hparam("num_query_train"),
            val=None,  # Use all available examples for val at meta-test time
        )
        self.get_train_valid_splits()

        x_dim = 1  # Omniglot is black and white
        hid_dim = self.context.get_hparam("hidden_dim")
        z_dim = self.context.get_hparam("embedding_dim")

        def make_stage(n_in, n_out):
            # Conv(3x3, padding 1) -> BatchNorm -> ReLU -> MaxPool(2)
            stage_layers = [
                nn.Conv2d(n_in, n_out, 3, padding=1),
                nn.BatchNorm2d(n_out),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
            return nn.Sequential(*stage_layers)

        # Four stages built in order, followed by Flatten.
        channel_plan = [
            (x_dim, hid_dim),
            (hid_dim, hid_dim),
            (hid_dim, hid_dim),
            (hid_dim, z_dim),
        ]
        stages = [make_stage(n_in, n_out) for n_in, n_out in channel_plan]
        encoder = nn.Sequential(*stages, Flatten())
        self.model = self.context.wrap_model(encoder)

        adam = torch.optim.Adam(
            self.model.parameters(),
            lr=self.context.get_hparam("learning_rate"),
            weight_decay=self.context.get_hparam("weight_decay"),
        )
        self.optimizer = self.context.wrap_optimizer(adam)

        step_lr = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            self.context.get_hparam("reduce_every"),
            gamma=self.context.get_hparam("lr_gamma"),
        )
        self.lr_scheduler = self.context.wrap_lr_scheduler(
            step_lr, LRScheduler.StepMode.STEP_EVERY_EPOCH
        )
Example #5
0
    def __init__(self, context: PyTorchTrialContext) -> None:
        """Build the genotype-defined network, SGD optimizer and cosine LR schedule.

        Args:
            context: trial context providing hyperparameters, the data config and
                the model / optimizer / LR-scheduler wrapping hooks.
        """
        self.context = context
        self.data_config = context.get_data_config()
        self.hparams = context.get_hparams()
        self.criterion = torch.nn.functional.cross_entropy
        # The last epoch is only used for logging.
        self._last_epoch = -1
        self.results = {"loss": float("inf"), "top1_accuracy": 0, "top5_accuracy": 0}

        # Model
        genotype = self.get_genotype_from_hps()
        network = Network(
            self.hparams["init_channels"],
            10,  # num_classes
            self.hparams["layers"],
            self.hparams["auxiliary"],
            genotype,
        )
        self.model = self.context.wrap_model(network)
        size_in_mb = utils.count_parameters_in_MB(self.model)
        print("param size = {} MB".format(size_in_mb))
        size = sum(param.nelement() for param in self.model.parameters())
        print("param count: {}".format(size))

        # Constraints are applied only when the hyperparameter is present and truthy.
        constraints_enabled = (
            "use_constraints" in self.hparams and self.hparams["use_constraints"]
        )
        if constraints_enabled:
            apply_constraints(self.hparams, size)

        # Optimizer
        sgd = torch.optim.SGD(
            self.model.parameters(),
            lr=self.context.get_hparam("learning_rate"),
            momentum=self.context.get_hparam("momentum"),
            weight_decay=self.context.get_hparam("weight_decay"),
        )
        self.optimizer = self.context.wrap_optimizer(sgd)

        # LR scheduler: the raw scheduler is kept alongside the wrapped one.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer,
            self.context.get_hparam("train_epochs"),
        )
        self.wrapped_scheduler = self.context.wrap_lr_scheduler(
            self.scheduler,
            step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH,
        )
Example #6
0
    def __init__(self, context: det_torch.PyTorchTrialContext) -> None:
        """Build config, tokenizer, model, optimizer, LR scheduler and clipping hook.

        Args:
            context: trial context providing hyperparameters, data config,
                experiment config and the wrapping hooks.
        """
        self.context = context

        # A subclass may already have set these attributes, so only the missing
        # ones are filled in from the context.
        config_sources = (
            ("hparams", context.get_hparams),
            ("data_config", context.get_data_config),
            ("exp_config", context.get_experiment_config),
        )
        for attr_name, fetch in config_sources:
            if not hasattr(self, attr_name):
                setattr(self, attr_name, attrdict.AttrDict(fetch()))

        # Check to make sure all expected hyperparameters are set.
        self.check_hparams()

        # Split hparams into the keyword groups used below.
        parsed = hf_parse.default_parse_config_tokenizer_model_kwargs(self.hparams)
        self.config_kwargs, self.tokenizer_kwargs, self.model_kwargs = parsed
        optimizer_kwargs, scheduler_kwargs = (
            hf_parse.default_parse_optimizer_lr_scheduler_kwargs(self.hparams)
        )

        self.config, self.tokenizer, self.model = build_using_auto(
            self.config_kwargs,
            self.tokenizer_kwargs,
            self.hparams.model_mode,
            self.model_kwargs,
            use_pretrained_weights=self.hparams.use_pretrained_weights,
        )
        self.model = self.context.wrap_model(self.model)

        optimizer = build_default_optimizer(self.model, optimizer_kwargs)
        self.optimizer = self.context.wrap_optimizer(optimizer)

        if self.hparams.use_apex_amp:
            self.model, self.optimizer = self.context.configure_apex_amp(
                models=self.model,
                optimizers=self.optimizer,
            )

        scheduler = build_default_lr_scheduler(self.optimizer, scheduler_kwargs)
        self.lr_scheduler = self.context.wrap_lr_scheduler(
            scheduler,
            det_torch.LRScheduler.StepMode.STEP_EVERY_BATCH,
        )

        # The clipping hook stays None unless a positive max_grad_norm is configured.
        if optimizer_kwargs.max_grad_norm > 0:  # type: ignore
            def clip_by_norm(params):
                return torch.nn.utils.clip_grad_norm_(
                    params, optimizer_kwargs.max_grad_norm)

            self.grad_clip_fn = clip_by_norm
        else:
            self.grad_clip_fn = None
Example #7
0
    def __init__(self, context: PyTorchTrialContext) -> None:
        """Create the MNIST data module and GAN module, then initialise the parent.

        Args:
            context: trial context supplying the data config (``url`` key), the
                per-slot batch size, the distributed rank and the ``lr`` / ``b1`` /
                ``b2`` hyperparameters.
        """
        # The rank is embedded in the path, so every rank uses its own directory.
        data_dir = f'/tmp/data-rank{context.distributed.get_rank()}'
        data_url = context.get_data_config()['url']
        self.dm = gan.MNISTDataModule(
            data_url,
            data_dir,
            batch_size=context.get_per_slot_batch_size(),
        )

        # The GAN is sized from the image dimensions reported by the data module.
        channels, width, height = self.dm.size()
        gan_module = gan.GAN(
            channels,
            width,
            height,
            batch_size=context.get_per_slot_batch_size(),
            lr=context.get_hparam('lr'),
            b1=context.get_hparam('b1'),
            b2=context.get_hparam('b2'),
        )

        super().__init__(context, lightning_module=gan_module)
        self.dm.prepare_data()
Example #8
0
    def __init__(self, context: PyTorchTrialContext, *args, **kwargs) -> None:
        """Create the LitMNIST module and data module, then initialise the parent.

        Args:
            context: trial context supplying hyperparameters, the data config
                (``url`` key), the per-slot batch size and the distributed rank.
            *args: extra positional arguments forwarded to the parent constructor.
            **kwargs: extra keyword arguments forwarded to the parent constructor.
        """
        hidden_size = context.get_hparam('hidden_size')
        learning_rate = context.get_hparam('learning_rate')
        lit_module = mnist.LitMNIST(
            hidden_size=hidden_size,
            learning_rate=learning_rate,
        )

        # The rank is embedded in the path, so every rank uses its own directory.
        data_dir = f"/tmp/data-rank{context.distributed.get_rank()}"
        self.dm = data.MNISTDataModule(
            data_url=context.get_data_config()["url"],
            data_dir=data_dir,
            batch_size=context.get_per_slot_batch_size(),
        )

        # Positional extras are written before the keyword argument; the call
        # binds arguments exactly as it would with the keyword listed first.
        super().__init__(context, *args, lightning_module=lit_module, **kwargs)
        self.dm.prepare_data()
    def __init__(self, context: det_torch.PyTorchTrialContext) -> None:
        """Build the mmdet detector, its optimizer and the optional clipping hook.

        Args:
            context: trial context providing hyperparameters, the data config,
                the experiment config and the model / optimizer wrapping hooks.
        """
        self.context = context
        self.hparams = attrdict.AttrDict(context.get_hparams())
        self.data_config = attrdict.AttrDict(context.get_data_config())
        self.cfg = self.build_mmdet_config()
        # Automatic device transfer is disabled; this trial controls how data
        # is moved to the GPU.
        self.context.experimental.disable_auto_to_device()

        # Build the detector from the mmdet config and initialise its weights.
        self.model = mmdet.models.build_detector(self.cfg.model)
        self.model.init_weights()

        # Optionally load pretrained weights for the config, when a checkpoint
        # path is returned for it.
        if self.hparams.use_pretrained:
            ckpt_path, ckpt = utils.get_pretrained_ckpt_path(
                "/tmp", self.hparams.config_file
            )
            if ckpt_path is not None:
                logging.info("Loading from pretrained weights.")
                # The weights may be nested under "state_dict" or be the checkpoint itself.
                weights = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
                self.model.load_state_dict(weights)

        # If fp16 is specified in the mmdet config, torch native amp is set up.
        fp16_cfg = self.cfg.get("fp16", None)
        if fp16_cfg is not None:
            self.setup_torch_amp(fp16_cfg)

        self.model = self.context.wrap_model(self.model)

        optimizer = mmcv.runner.build_optimizer(self.model, self.cfg.optimizer)
        self.optimizer = self.context.wrap_optimizer(optimizer)
        self.model.zero_grad()

        # The clipping hook stays None when the config has no grad_clip section.
        if self.cfg.optimizer_config.grad_clip is None:
            self.clip_grads_fn = None
        else:
            def clip_by_norm(params):
                # Settings are read from the config each time the hook is called.
                return torch.nn.utils.clip_grad_norm_(
                    params,
                    self.cfg.optimizer_config.grad_clip.max_norm,
                    self.cfg.optimizer_config.grad_clip.norm_type,
                )

            self.clip_grads_fn = clip_by_norm

        # mmdet sets loggers in the package that interfere with Determined logging,
        # so the root logger is reset after the mmdet model is initialised.
        debug_enabled = bool(self.context.env.experiment_config.get("debug", False))
        set_logger(debug_enabled)
Example #10
0
    def __init__(self, trial_context: PyTorchTrialContext) -> None:
        """Build the network, its two optimizers and the cosine LR schedule.

        Args:
            trial_context: trial context providing hyperparameters, the data
                config, the distributed rank and the wrapping hooks.
        """
        self.context = trial_context
        self.data_config = trial_context.get_data_config()
        self.hparams = AttrDict(trial_context.get_hparams())
        self.last_epoch = 0

        # Rank-specific sub-directory under the configured download directory.
        rank_subdir = f"data-rank{self.context.distributed.get_rank()}"
        self.data_dir = os.path.join(self.data_config["download_dir"], rank_subdir)

        # Model; the loss criterion is handed to the network constructor.
        hp = self.hparams
        criterion = nn.CrossEntropyLoss()
        network = Network(
            hp.init_channels,
            hp.n_classes,
            hp.layers,
            criterion,
            hp.nodes,
            k=hp.shuffle_factor,
        )
        self.model = self.context.wrap_model(network)

        # Optimizer over model.ws_parameters().
        weights_sgd = torch.optim.SGD(
            self.model.ws_parameters(),
            hp.learning_rate,
            momentum=hp.momentum,
            weight_decay=hp.weight_decay,
        )
        self.ws_opt = self.context.wrap_optimizer(weights_sgd)

        def normalize_last_dim(p):
            # Divide by the sum along the last dimension.
            return p / p.sum(dim=-1, keepdim=True)

        # Optimizer over model.arch_parameters().
        arch_eg = EG(
            self.model.arch_parameters(),
            hp.arch_learning_rate,
            normalize_last_dim,
        )
        self.arch_opt = self.context.wrap_optimizer(arch_eg)

        # The schedule is attached to the ws optimizer only.
        cosine = CosineAnnealingLR(
            self.ws_opt,
            hp.scheduler_epochs,
            hp.min_learning_rate,
        )
        self.lr_scheduler = self.context.wrap_lr_scheduler(
            lr_scheduler=cosine,
            step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH,
        )
Example #11
0
    def __init__(self, context: det_torch.PyTorchTrialContext) -> None:
        """Load the NER datasets, run the parent init, and register the eval reducer.

        Args:
            context: trial context providing hyperparameters, the data config and
                the ``experimental.wrap_reducer`` hook.
        """
        self.logger = logging.getLogger(__name__)
        self.context = context
        self.hparams = attrdict.AttrDict(context.get_hparams())
        self.data_config = attrdict.AttrDict(context.get_data_config())

        # The datasets are loaded before the parent constructor runs because
        # num_labels has to be stored in hparams first.
        self.raw_datasets = hf.default_load_dataset(self.data_config)
        metadata = ner_utils.get_dataset_metadata(self.raw_datasets, self.hparams)
        self.hparams.num_labels = metadata.num_labels

        super(NERTrial, self).__init__(context)
        self.logger.info(self.config)

        # Tokenization happens after the parent init because it needs the model
        # and tokenizer to be available.
        self.tokenized_datasets = ner_utils.build_tokenized_datasets(
            self.raw_datasets,
            self.model,
            self.data_config,
            self.tokenizer,
            metadata.text_column_name,
            metadata.label_column_name,
            metadata.label_to_id,
        )
        num_train_records = len(self.tokenized_datasets["train"])
        self.logger.info("training records: {}".format(num_train_records))
        if "records_per_epoch" in self.exp_config:
            if num_train_records != self.exp_config["records_per_epoch"]:
                self.logger.warning(
                    "number of train records {} does not match records_per_epoch of {}".format(
                        num_train_records, self.exp_config["records_per_epoch"]
                    )
                )

        # Metric reducer used for validation only (for_training=False); the label
        # list is bound as the first argument of compute_metrics.
        metrics_fn = functools.partial(ner_utils.compute_metrics, metadata.label_list)
        self.reducer = context.experimental.wrap_reducer(
            metrics_fn,
            for_training=False,
        )
Example #12
0
    def __init__(self, context: PyTorchTrialContext) -> None:
        """Set up the criterion, wrapped model, SGD optimizer and LR scheduler.

        Args:
            context: trial context providing hyperparameters, the data config and
                the model / optimizer / LR-scheduler wrapping hooks.
        """
        self.context = context
        self.data_config = context.get_data_config()

        num_classes = context.get_hparam("num_classes")
        smoothing_rate = context.get_hparam("label_smoothing_rate")
        self.criterion = CrossEntropyLabelSmooth(num_classes, smoothing_rate)
        self.last_epoch_idx = -1

        network = self.build_model_from_config()
        self.model = self.context.wrap_model(network)

        sgd = torch.optim.SGD(
            self.model.parameters(),
            lr=self.context.get_hparam("learning_rate"),
            momentum=self.context.get_hparam("momentum"),
            weight_decay=self.context.get_hparam("weight_decay"),
        )
        self.optimizer = self.context.wrap_optimizer(sgd)

        scheduler = self.build_lr_scheduler_from_config(self.optimizer)
        self.lr_scheduler = self.context.wrap_lr_scheduler(
            scheduler,
            step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH,
        )
Example #13
0
    def __init__(self, context: det_torch.PyTorchTrialContext) -> None:
        """Set up an XLNet question-answering trial that uses beam search.

        This is a custom init rather than a call to the parent constructor,
        because ``XLNetForQuestionAnswering`` is needed instead of the default
        model returned by ``AutoModelForQuestionAnswering``.

        Args:
            context: trial context providing hyperparameters, data config,
                experiment config and the wrapping hooks.

        Raises:
            AssertionError: if ``version_2_with_negative`` does not match the
                ``squad`` / ``squad_v2`` dataset name.
            ValueError: if the tokenizer is not a fast tokenizer.
        """
        self.logger = logging.getLogger(__name__)
        self.hparams = attrdict.AttrDict(context.get_hparams())
        self.data_config = attrdict.AttrDict(context.get_data_config())
        self.context = context

        # Check to make sure the dataset is configured correctly.
        if self.data_config.dataset_name is not None:
            dataset_name = self.data_config.dataset_name
            if dataset_name == "squad":
                assert (not self.data_config.version_2_with_negative
                        ), "version_2_with_negative should be false for squad"
            elif dataset_name == "squad_v2":
                assert (
                    self.data_config.version_2_with_negative
                ), "version_2_with_negative should be true for squad_v2"

        self.data_processors = data_beam_search

        # Get the datasets: you can either provide your own CSV or JSON training and evaluation
        # files (see below) or just provide the name of one of the public datasets available on the
        # hub at https://huggingface.co/datasets/ (the dataset will be downloaded automatically
        # from the datasets Hub).

        # For CSV/JSON files, this script will use the column called 'text' or the first column if
        # no column called 'text' is found. You can easily tweak this behavior (see below).

        # See more about loading any type of standard or custom dataset (from files, python dict,
        # pandas DataFrame, etc) at
        # https://huggingface.co/docs/datasets/loading_datasets.html.
        self.raw_datasets = hf.default_load_dataset(self.data_config)
        self.column_names = self.raw_datasets["train"].column_names

        # For beam search, we need to use a different model from the default model returned by
        # AutoModelForQuestionAnswering.  We will use a custom init in this case that is a slight
        # modification of the BaseTransformerTrial init method.
        self.exp_config = attrdict.AttrDict(context.get_experiment_config())

        # Check to make sure all expected hyperparameters are set.
        self.check_hparams()

        # Parse hparams and data_config.
        (
            self.config_kwargs,
            self.tokenizer_kwargs,
            self.model_kwargs,
        ) = hf.default_parse_config_tokenizer_model_kwargs(self.hparams)
        optimizer_kwargs, scheduler_kwargs = hf.default_parse_optimizer_lr_scheduler_kwargs(
            self.hparams)

        self.config = transformers.XLNetConfig.from_pretrained(
            **self.config_kwargs)
        self.tokenizer = transformers.XLNetTokenizerFast.from_pretrained(
            **self.tokenizer_kwargs)

        # We need to use XLNetForQuestionAnswering instead of XLNetForQuestionAnsweringSimple
        # which is the default returned by AutoModelForQuestionAnswering.
        if self.hparams.use_pretrained_weights:
            self.model_kwargs["config"] = self.config
            self.model = transformers.XLNetForQuestionAnswering.from_pretrained(
                **self.model_kwargs)
        else:
            self.model = transformers.XLNetForQuestionAnswering(self.config)
        self.model = self.context.wrap_model(self.model)

        # The rest is the same as the parent init method.
        self.optimizer = self.context.wrap_optimizer(
            hf.build_default_optimizer(self.model, optimizer_kwargs))

        if self.hparams.use_apex_amp:
            self.model, self.optimizer = self.context.configure_apex_amp(
                models=self.model,
                optimizers=self.optimizer,
            )

        self.lr_scheduler = self.context.wrap_lr_scheduler(
            hf.build_default_lr_scheduler(self.optimizer, scheduler_kwargs),
            det_torch.LRScheduler.StepMode.STEP_EVERY_BATCH,
        )

        # The clipping hook stays None unless a positive max_grad_norm is configured.
        # Writing this as `lambda x: ... if cond else None` would make the lambda
        # body the whole conditional, so the attribute would always be a callable.
        self.grad_clip_fn = None
        if optimizer_kwargs.max_grad_norm > 0:  # type: ignore
            self.grad_clip_fn = lambda x: torch.nn.utils.clip_grad_norm_(
                x, optimizer_kwargs.max_grad_norm)

        self.logger.info(self.config)

        if not isinstance(self.tokenizer,
                          transformers.PreTrainedTokenizerFast):
            raise ValueError(
                "This example script only works for models that have a fast tokenizer. Checkout "
                "the big table of models at "
                "https://huggingface.co/transformers/index.html#bigtable to find the model types "
                "that meet this requirement")

        # The tokenized dataset is created here because it needs the model and
        # tokenizer to be available.
        self.tokenized_datasets = self.build_datasets()
        train_length = len(self.tokenized_datasets["train"])
        self.logger.info("training records: {}".format(train_length))
        if ("records_per_epoch" in self.exp_config
                and train_length != self.exp_config["records_per_epoch"]):
            self.logger.warning(
                "number of train records {} does not match records_per_epoch of {}"
                .format(train_length, self.exp_config["records_per_epoch"]))

        # Create metric reducer; the metric name follows version_2_with_negative.
        metric = datasets.load_metric("squad_v2" if self.data_config.
                                      version_2_with_negative else "squad")

        self.reducer = context.experimental.wrap_reducer(
            functools.partial(
                qa_utils.compute_metrics,
                self.data_config,
                self.column_names,
                self.data_processors.post_processing_function,
                self.raw_datasets,
                self.tokenized_datasets,
                self.model,
                metric,
            ),
            for_training=False,
        )
Example #14
0
    def __init__(self, context: det_torch.PyTorchTrialContext) -> None:
        """Load the datasets, derive the label setup, and register the eval reducer.

        Args:
            context: trial context providing hyperparameters, the data config and
                the ``wrap_reducer`` hook.
        """
        self.logger = logging.getLogger(__name__)
        self.hparams = attrdict.AttrDict(context.get_hparams())
        self.data_config = attrdict.AttrDict(context.get_data_config())
        self.context = context

        # The datasets are loaded before the parent constructor runs because
        # num_labels has to be stored in hparams first.
        #
        # For CSV/JSON files, this example will use as labels the column called `label` and as pair
        # of sentences the sentences in columns called `sentence1` and `sentence2` if such column
        # exists or the first two columns not named label if at least two columns are provided.
        #
        # If the CSVs/JSONs contain only one non-label column, the example will do single sentence
        # classification on this single column.
        #
        # See more about loading any type of standard or custom dataset at
        # https://huggingface.co/docs/datasets/loading_datasets.html.
        self.raw_datasets = hf.default_load_dataset(self.data_config)

        if self.hparams.finetuning_task is not None:
            # With a named task, only "stsb" is treated as regression.
            is_regression = self.hparams.finetuning_task == "stsb"
            if is_regression:
                num_labels = 1
            else:
                label_list = self.raw_datasets["train"].features["label"].names
                num_labels = len(label_list)
        else:
            # Without a task, a float label dtype is taken to mean regression.
            label_dtype = self.raw_datasets["train"].features["label"].dtype
            is_regression = label_dtype in ["float32", "float64"]
            if is_regression:
                num_labels = 1
            else:
                # datasets.Dataset.unique, see
                # https://huggingface.co/docs/datasets/package_reference/main_classes.html
                label_list = self.raw_datasets["train"].unique("label")
                label_list.sort()  # sorted for determinism
                num_labels = len(label_list)

        self.is_regression = is_regression
        self.hparams.num_labels = num_labels
        # label_list is only bound (and only stored) for classification.
        if not self.is_regression:
            self.label_list = label_list

        super(GLUETrial, self).__init__(context)
        self.logger.info(self.config)

        # Tokenization happens after the parent init because it needs the model
        # and tokenizer to be available.
        self.tokenized_datasets = self.build_datasets()
        num_train_records = len(self.tokenized_datasets["train"])
        self.logger.info("training records: {}".format(num_train_records))
        if "records_per_epoch" in self.exp_config:
            if num_train_records != self.exp_config["records_per_epoch"]:
                self.logger.warning(
                    "number of train records {} does not match records_per_epoch of {}"
                    .format(num_train_records, self.exp_config["records_per_epoch"]))

        # Metric used by the reducer below.
        metric = datasets.load_metric("glue", self.hparams.finetuning_task)

        def compute_metrics(pred_labels) -> Dict:
            # pred_labels is iterated as (predictions, labels) pairs.
            raw_preds, raw_labels = zip(*pred_labels)
            preds = utils.expand_like(raw_preds)
            labels = utils.expand_like(raw_labels)
            if is_regression:
                preds = np.squeeze(preds)
            else:
                preds = np.argmax(preds, axis=1)

            if self.hparams.finetuning_task is not None:
                result = metric.compute(predictions=preds, references=labels)
                # When several scores are returned, their mean is added as well.
                if len(result) > 1:
                    result["combined_score"] = np.mean(list(
                        result.values())).item()
                return result
            if is_regression:
                return {"mse": ((preds - labels)**2).mean().item()}
            return {
                "accuracy":
                (preds == labels).astype(np.float32).mean().item()
            }

        # NOTE(review): wrap_reducer is called on the context directly here --
        # confirm the targeted Determined version exposes it outside `experimental`.
        self.reducer = context.wrap_reducer(compute_metrics, for_training=False)