Example #1
    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """Returns the callbacks for a given stage."""
        callbacks_params = get_by_keys(self._stage_config, stage, "callbacks", default={})
        callbacks = OrderedDict(REGISTRY.get_from_params(**callbacks_params))

        is_callback_exists = lambda callback_fn: any(
            callback_isinstance(x, callback_fn) for x in callbacks.values())
        if self._verbose and not is_callback_exists(TqdmCallback):
            callbacks["_verbose"] = TqdmCallback()
        if self._timeit and not is_callback_exists(TimerCallback):
            callbacks["_timer"] = TimerCallback()
        if self._check and not is_callback_exists(CheckRunCallback):
            callbacks["_check"] = CheckRunCallback()
        if self._overfit and not is_callback_exists(BatchOverfitCallback):
            callbacks["_overfit"] = BatchOverfitCallback()

        if self._logdir is not None and not is_callback_exists(ICheckpointCallback):
            callbacks["_checkpoint"] = CheckpointCallback(
                logdir=os.path.join(self._logdir, "checkpoints")
            )

        return callbacks
Example #2
    def get_callbacks(self, stage: str) -> "OrderedDict[str, Callback]":
        """Returns the callbacks for a given stage."""
        callbacks_params = self._config.stages[stage].callbacks or {}

        callbacks: Dict[str, Callback] = {
            name: self._get_callback_from_params(callback_params)
            for name, callback_params in callbacks_params.items()
        }

        is_callback_exists = lambda callback_fn: any(
            callback_isinstance(x, callback_fn) for x in callbacks.values()
        )
        if self._verbose and not is_callback_exists(TqdmCallback):
            callbacks["_verbose"] = TqdmCallback()
        if self._timeit and not is_callback_exists(TimerCallback):
            callbacks["_timer"] = TimerCallback()
        if self._check and not is_callback_exists(CheckRunCallback):
            callbacks["_check"] = CheckRunCallback()
        if self._overfit and not is_callback_exists(BatchOverfitCallback):
            callbacks["_overfit"] = BatchOverfitCallback()

        if self._logdir is not None and not is_callback_exists(ICheckpointCallback):
            callbacks["_checkpoint"] = CheckpointCallback(
                logdir=os.path.join(self._logdir, "checkpoints")
            )

        return callbacks
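Both variants above implement the same "inject a default callback unless the user already supplied one" pattern. The following is a minimal, self-contained sketch of that pattern; the Callback and TqdmCallback classes here are stand-ins rather than the actual catalyst classes, so treat it as an illustration, not the library implementation.

from collections import OrderedDict


class Callback:  # stand-in for catalyst's Callback base class
    pass


class TqdmCallback(Callback):  # stand-in for a progress-bar callback
    pass


def with_default_callbacks(callbacks: "OrderedDict[str, Callback]", verbose: bool) -> "OrderedDict[str, Callback]":
    """Add a progress-bar callback only if none of the given callbacks is already one."""
    has_callback = lambda cls: any(isinstance(cb, cls) for cb in callbacks.values())
    if verbose and not has_callback(TqdmCallback):
        callbacks["_verbose"] = TqdmCallback()
    return callbacks


print(with_default_callbacks(OrderedDict(), verbose=True))  # OrderedDict with a "_verbose" entry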
Example #3
File: test_amp.py  Project: zkid18/catalyst
 def get_callbacks(self, stage: str):
     return {
         "criterion": CriterionCallback(
             metric_key="loss", input_key="logits", target_key="targets"
         ),
         "optimizer": OptimizerCallback(metric_key="loss"),
         # "scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
         "checkpoint": CheckpointCallback(
             self._logdir, loader_key="valid", metric_key="loss", minimize=True, save_n_best=3
         ),
         "test_nn_module": ModuleTypeChecker(),
         "test_device": DeviceCheckCallback(self._device, logger=logger),
         "test_loss_minimization": LossMinimizationCallback("loss", logger=logger),
         "test_logits_type": TensorTypeChecker("logits"),
         # "loss_type_checker": TensorTypeChecker("loss", True),
     }
Example #4
 def get_callbacks(self, stage: str):
     return {
         "criterion": CriterionCallback(
             metric_key="loss", input_key="logits", target_key="targets"
         ),
         "optimizer": OptimizerCallback(metric_key="loss"),
         # "scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
         "checkpoint": CheckpointCallback(
             self._logdir, loader_key="valid", metric_key="loss", minimize=True, save_n_best=3
         ),
         "test_nn_parallel_data_parallel": DataParallelTypeChecker(),
         "test_loss_minimization": LossMinimizationCallback("loss", logger=logger),
         "test_logits_type": OPTTensorTypeChecker("logits", self._opt_level),
     }
Example #5
 def get_callbacks(self, stage: str):
     return {
         "criterion": CriterionCallback(
             metric_key="loss", input_key="logits", target_key="targets"
         ),
         "optimizer": OptimizerCallback(metric_key="loss"),
         # "scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
         "checkpoint": CheckpointCallback(
             self._logdir, loader_key="valid", metric_key="loss", minimize=True, save_n_best=3
         ),
         "test_nn_parallel_distributed_data_parallel": DistributedDataParallelTypeChecker(),
         "test_loss_minimization": LossMinimizationCallback("loss", logger=logger),
         "test_world_size": WorldSizeCheckCallback(NUM_CUDA_DEVICES, logger=logger),
     }
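The three test runners above ignore the stage argument and return the same callbacks for every stage. As a hedged sketch (not taken from the catalyst test suite), the same hook could also branch on the stage name, reusing only the callback classes already shown above; the "train" stage name here is hypothetical.

 def get_callbacks(self, stage: str):
     callbacks = {
         "criterion": CriterionCallback(
             metric_key="loss", input_key="logits", target_key="targets"
         ),
         "optimizer": OptimizerCallback(metric_key="loss"),
     }
     if stage == "train":  # hypothetical stage name; only checkpoint while training
         callbacks["checkpoint"] = CheckpointCallback(
             self._logdir, loader_key="valid", metric_key="loss", minimize=True
         )
     return callbacks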
Example #6
def main(args):
    if args.wandb:
        import wandb
        wandb.init()
        logdir = args.logdir + "/" + wandb.run.name
    else:
        logdir = args.logdir
    set_global_seed(args.seed)
    datasets = load_dataset(args.dataset)

    tokenizer = AutoTokenizer.from_pretrained(args.teacher_model)
    datasets = datasets.map(
        lambda e: tokenizer(
            e["text"], truncation=True, padding="max_length", max_length=128),
        batched=True,
    )
    datasets = datasets.map(lambda e: {"labels": e["label"]}, batched=True)
    datasets.set_format(
        type="torch",
        columns=["input_ids", "token_type_ids", "attention_mask", "labels"],
    )
    loaders = {
        "train": DataLoader(datasets["train"], batch_size=args.batch_size, shuffle=True),
        "valid": DataLoader(datasets["test"], batch_size=args.batch_size),
    }
    teacher_model = AutoModelForSequenceClassification.from_pretrained(
        args.teacher_model, num_labels=args.num_labels)
    unpack_checkpoint(torch.load(args.teacher_path), model=teacher_model)
    metric_callback = LoaderMetricCallback(
        metric=HFMetric(metric=load_metric("accuracy")),
        input_key="s_logits",
        target_key="labels",
    )
    layers = [int(layer) for layer in args.layers.split(",")]
    slct_callback = ControlFlowCallback(
        HiddenStatesSelectCallback(hiddens_key="t_hidden_states",
                                   layers=layers),
        loaders="train",
    )

    lambda_hiddens_callback = ControlFlowCallback(
        LambdaPreprocessCallback(lambda s_hiddens, t_hiddens: (
            [c_s[:, 0] for c_s in s_hiddens],
            [t_s[:, 0] for t_s in t_hiddens],  # takes only the CLS token
        )),
        loaders="train",
    )

    mse_hiddens = ControlFlowCallback(MSEHiddenStatesCallback(),
                                      loaders="train")

    kl_div = ControlFlowCallback(
        KLDivCallback(temperature=args.kl_temperature), loaders="train")

    runner = HFDistilRunner()

    student_model = AutoModelForSequenceClassification.from_pretrained(
        args.student_model, num_labels=args.num_labels)
    callbacks = [
        metric_callback, slct_callback, lambda_hiddens_callback, kl_div,
        OptimizerCallback(metric_key="loss"),
        CheckpointCallback(logdir=logdir,
                           loader_key="valid",
                           mode="model",
                           metric_key="accuracy",
                           minimize=False)
    ]
    if args.beta > 0:
        aggregator = ControlFlowCallback(
            MetricAggregationCallback(
                prefix="loss",
                metrics={
                    "kl_div_loss": args.alpha,
                    "mse_loss": args.beta,
                    "task_loss": 1 - args.alpha
                },
                mode="weighted_sum",
            ),
            loaders="train",
        )
        callbacks.append(mse_hiddens)
        callbacks.append(aggregator)
    else:
        aggregator = ControlFlowCallback(
            MetricAggregationCallback(
                prefix="loss",
                metrics={
                    "kl_div_loss": args.alpha,
                    "task_loss": 1 - args.alpha
                },
                mode="weighted_sum",
            ),
            loaders="train",
        )
        callbacks.append(aggregator)
    runner.train(
        model=torch.nn.ModuleDict({"teacher": teacher_model, "student": student_model}),
        loaders=loaders,
        optimizer=torch.optim.Adam(student_model.parameters(), lr=args.lr),
        callbacks=callbacks,
        num_epochs=args.num_epochs,
        valid_metric="accuracy",
        logdir=logdir,
        minimize_valid_metric=False,
        valid_loader="valid",
        verbose=args.verbose,
        seed=args.seed,
    )

    if args.wandb:
        import csv
        import shutil
        with open(logdir + "/valid.csv") as fi:
            reader = csv.DictReader(fi)
            accuracy = []
            for row in reader:
                if row["accuracy"] == "accuracy":
                    continue
                accuracy.append(float(row["accuracy"]))

        wandb.log({"accuracy": max(accuracy[-args.num_epochs:])})
        shutil.rmtree(logdir)
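main() above reads a number of attributes from args, but the argument parser itself is not part of this example. The sketch below is a hypothetical parser covering exactly the attributes main() uses; the flag names and default values are assumptions, not the project's actual CLI.

import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical CLI inferred from the attributes main() reads from `args`.
    parser = argparse.ArgumentParser(description="Distill a HF teacher model into a student with Catalyst")
    parser.add_argument("--logdir", type=str, default="./logs")
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--dataset", type=str, required=True)
    parser.add_argument("--teacher-model", type=str, required=True)
    parser.add_argument("--teacher-path", type=str, required=True)
    parser.add_argument("--student-model", type=str, required=True)
    parser.add_argument("--num-labels", type=int, default=2)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--layers", type=str, default="1,3,5")  # comma-separated hidden-state layers
    parser.add_argument("--kl-temperature", type=float, default=2.0)
    parser.add_argument("--alpha", type=float, default=0.5)  # weight of the KL-divergence loss
    parser.add_argument("--beta", type=float, default=0.0)  # weight of the MSE hidden-states loss
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--num-epochs", type=int, default=5)
    parser.add_argument("--verbose", action="store_true")
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())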
Example #7
    optimizer = torch.optim.Adam(net.parameters(), lr=0.02)
    scheduler = OneCycleLR(
        optimizer,
        max_lr=0.02,
        epochs=args.n_epochs,
        steps_per_epoch=len(train_loaders["train"]),
    )

    runner = CustomRunner(n_classes=args.n_classes)
    runner.train(
        model=net,
        optimizer=optimizer,
        loaders=train_loaders,
        num_epochs=args.n_epochs,
        scheduler=scheduler,
        callbacks=[CheckpointCallback(logdir=logdir)],
        logdir=logdir,
        verbose=True,
    )

    segmentations = {}
    for subject in range(infer_loaders["infer"].dataset.subjects):
        segmentations[subject] = torch.zeros(
            tuple(np.insert(volume_shape, 0, args.n_classes)),
            dtype=torch.uint8,
        )

    segmentations = voxel_majority_predict_from_subvolumes(
        infer_loaders["infer"], args.n_classes, segmentations)
    subject_metrics = []
    for subject, subject_data in enumerate(