def get_callbacks(self, stage: str):
    """Assemble the callback set for this stage: loss, optimizer step,
    checkpointing, and the device/type sanity-check callbacks."""
    callbacks = {
        "criterion": CriterionCallback(
            metric_key="loss", input_key="logits", target_key="targets"
        ),
        "optimizer": OptimizerCallback(metric_key="loss"),
        # "scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
        "checkpoint": CheckpointCallback(
            self._logdir,
            loader_key="valid",
            metric_key="loss",
            minimize=True,
            save_n_best=3,
        ),
        "test_nn_module": ModuleTypeChecker(),
        "test_device": DeviceCheckCallback(self._device, logger=logger),
        "test_loss_minimization": LossMinimizationCallback("loss", logger=logger),
        "test_logits_type": TensorTypeChecker("logits"),
        # "loss_type_checker": TensorTypeChecker("loss", True),
    }
    return callbacks
def get_callbacks(self, stage: str):
    """Assemble the callback set for this stage: loss plus classification
    metrics (accuracy, AUC, precision/recall/F1) and a counter callback."""
    callbacks = {
        "criterion": CriterionCallback(
            metric_key="loss", input_key="logits", target_key="targets"
        ),
        "accuracy": AccuracyCallback(
            input_key="logits", target_key="targets", topk_args=(1,)
        ),
        "auc": AUCCallback(input_key="scores", target_key="targets_onehot"),
        "classification": PrecisionRecallF1SupportCallback(
            input_key="logits", target_key="targets", num_classes=4
        ),
        # "optimizer": OptimizerCallback(metric_key="loss"),
        # "checkpoint": CheckpointCallback(
        #     self._logdir, loader_key="valid", metric_key="loss", minimize=True, save_n_best=3
        # ),
        # "verbose": TqdmCallback(),
        "counter": CounterCallback(),
    }
    return callbacks
def test_pruning():
    """Smoke-test the pruning pipeline: train an MLP on MNIST for one
    check run with an L1-unstructured PruningCallback attached."""
    import torch
    from torch.utils.data import DataLoader
    from torchvision.transforms import ToTensor

    from catalyst.callbacks import (
        AccuracyCallback,
        ControlFlowCallback,
        CriterionCallback,
        OptimizerCallback,
        PruningCallback,
    )
    from catalyst.contrib.datasets import MNIST

    from compressors.distillation.callbacks import KLDivCallback, MetricAggregationCallback
    from compressors.models import MLP
    from compressors.pruning.runners import PruneRunner
    from compressors.utils.data import TorchvisionDatasetWrapper as Wrp

    net = MLP(num_layers=3)

    dataset_map = {
        "train": Wrp(MNIST("./data", train=True, download=True, transform=ToTensor())),
        "valid": Wrp(MNIST("./data", train=False, transform=ToTensor())),
    }
    loader_map = {
        split: DataLoader(ds, shuffle=(split == "train"), batch_size=32)
        for split, ds in dataset_map.items()
    }

    opt = torch.optim.Adam(net.parameters())
    runner = PruneRunner(num_sessions=10)

    runner.train(
        model=net,
        loaders=loader_map,
        optimizer=opt,
        criterion=torch.nn.CrossEntropyLoss(),
        callbacks=[
            PruningCallback(
                pruning_fn="l1_unstructured",
                amount=0.2,
                remove_reparametrization_on_stage_end=False,
            ),
            OptimizerCallback(metric_key="loss"),
            CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            AccuracyCallback(input_key="logits", target_key="targets"),
        ],
        logdir="./pruned_model",
        valid_loader="valid",
        valid_metric="accuracy",
        minimize_valid_metric=False,
        check=True,
    )
def get_callbacks(self, stage: str):
    """Assemble the callback set for this stage, including the
    nn.DataParallel model-type check and an AMP logits-type check."""
    callbacks = {
        "criterion": CriterionCallback(
            metric_key="loss", input_key="logits", target_key="targets"
        ),
        "optimizer": OptimizerCallback(metric_key="loss"),
        # "scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
        "checkpoint": CheckpointCallback(
            self._logdir,
            loader_key="valid",
            metric_key="loss",
            minimize=True,
            save_n_best=3,
        ),
        "test_nn_parallel_data_parallel": DataParallelTypeChecker(),
        "test_loss_minimization": LossMinimizationCallback("loss", logger=logger),
        "test_logits_type": OPTTensorTypeChecker("logits", self._opt_level),
    }
    return callbacks
def get_callbacks(self, stage: str):
    """Assemble the callback set for this stage, including the
    DistributedDataParallel model-type check and a world-size check."""
    callbacks = {
        "criterion": CriterionCallback(
            metric_key="loss", input_key="logits", target_key="targets"
        ),
        "optimizer": OptimizerCallback(metric_key="loss"),
        # "scheduler": dl.SchedulerCallback(loader_key="valid", metric_key="loss"),
        "checkpoint": CheckpointCallback(
            self._logdir,
            loader_key="valid",
            metric_key="loss",
            minimize=True,
            save_n_best=3,
        ),
        "test_nn_parallel_distributed_data_parallel": DistributedDataParallelTypeChecker(),
        "test_loss_minimization": LossMinimizationCallback("loss", logger=logger),
        "test_world_size": WorldSizeCheckCallback(NUM_CUDA_DEVICES, logger=logger),
    }
    return callbacks
def test_tracer_callback():
    """
    Tests a feature of `TracingCallback` for model tracing during training
    """
    logdir = "./logs"
    dataset_root = "./data"

    loaders = _get_loaders(root=dataset_root, batch_size=4, num_workers=1)
    images, targets = next(iter(loaders["train"]))
    _, c, h, w = images.shape

    model = _TracedNet((c, h, w))
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters())

    # Tracing configuration: trace the eval-mode `forward` of the best
    # checkpoint, without gradients and without apex optimization.
    method_name = "forward"
    mode = "eval"
    requires_grad = False
    checkpoint_name = "best"
    opt_level = None

    trace_name = get_trace_name(
        method_name=method_name,
        mode=mode,
        requires_grad=requires_grad,
        additional_string=checkpoint_name,
    )
    tracing_path = Path(logdir) / "trace" / trace_name

    callbacks = collections.OrderedDict(
        loss=CriterionCallback(),
        optimizer=OptimizerCallback(),
        tracer_callback=TracingCallback(
            metric="loss",
            minimize=False,
            trace_mode=mode,
            mode=checkpoint_name,
            do_once=True,
            method_name=method_name,
            requires_grad=requires_grad,
            opt_level=opt_level,
        ),
        # Verifies at stage end that the traced model file exists and works.
        test_callback=_OnStageEndCheckModelTracedCallback(
            path=tracing_path, inputs=images
        ),
    )

    runner = SupervisedRunner(input_key="x")
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir=logdir,
        callbacks=callbacks,
        check=True,
        verbose=True,
    )
    shutil.rmtree(logdir)
def main(args):
    """Distill a CIFAR-100 teacher into a student with attention/hidden-state
    and KL-divergence losses, then train with SGD + multi-step LR schedule."""
    set_global_seed(42)

    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    dataset_map = {
        "train": Wrp(CIFAR100(root=".", train=True, download=True, transform=transform_train)),
        "valid": Wrp(CIFAR100(root=".", train=False, transform=transform_test)),
    }
    loaders = {
        split: DataLoader(
            ds, batch_size=args.batch_size, shuffle=(split == "train"), num_workers=2
        )
        for split, ds in dataset_map.items()
    }

    # Teacher weights come either from a local checkpoint or a released URL.
    teacher_model = NAME2MODEL[args.teacher](num_classes=100)
    if args.teacher_path is None:
        teacher_model.load_state_dict(load_state_dict_from_url(NAME2URL[args.teacher]))
    else:
        unpack_checkpoint(torch.load(args.teacher_path), model=teacher_model)

    student_model = NAME2MODEL[args.student](num_classes=100)
    optimizer = torch.optim.SGD(
        student_model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, [150, 180, 210], gamma=0.1
    )

    runner = DistilRunner(apply_probability_shift=args.probability_shift)
    runner.train(
        model={"teacher": teacher_model, "student": student_model},
        loaders=loaders,
        optimizer=optimizer,
        scheduler=scheduler,
        valid_metric="accuracy",
        minimize_valid_metric=False,
        logdir=args.logdir,
        callbacks=[
            # Distillation losses are computed on the train loader only.
            ControlFlowCallback(AttentionHiddenStatesCallback(), loaders="train"),
            ControlFlowCallback(KLDivCallback(temperature=4), loaders="train"),
            CriterionCallback(
                input_key="s_logits", target_key="targets", metric_key="cls_loss"
            ),
            ControlFlowCallback(
                MetricAggregationCallback(
                    prefix="loss",
                    metrics={
                        "attention_loss": args.beta,
                        "kl_div_loss": args.alpha,
                        "cls_loss": 1 - args.alpha,
                    },
                    mode="weighted_sum",
                ),
                loaders="train",
            ),
            AccuracyCallback(input_key="s_logits", target_key="targets"),
            OptimizerCallback(metric_key="loss", model_key="student"),
            SchedulerCallback(),
        ],
        valid_loader="valid",
        num_epochs=args.num_epochs,
        criterion=torch.nn.CrossEntropyLoss(),
        seed=args.seed,
    )
# Distillation run: select hidden states of layers 2-3 for the MSE loss,
# add a KL-divergence loss on the logits, and aggregate them with the
# classification loss as a weighted sum.
select_last_hidden_states = HiddenStatesSelectCallback(layers=[2, 3])
kl_callback = KLDivCallback()

runner = DistilRunner(output_hidden_states=True)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-2)
runner.train(
    model={"teacher": teacher, "student": student},
    loaders=loaders,
    # BUG FIX: keyword was misspelled "optimier", so the Adam optimizer
    # built above was never actually passed to the runner.
    optimizer=optimizer,
    criterion=nn.CrossEntropyLoss(),
    callbacks=[
        AccuracyCallback(input_key="s_logits", target_key="targets"),
        CriterionCallback(input_key="s_logits"),
        mse_callback,
        select_last_hidden_states,
        kl_callback,
        MetricAggregationCallback({"kl_loss": 0.2, "mse_loss": 0.2, "loss": 0.6}),
    ],
    valid_metric="accuracy01",
    minimize_valid_metric=False,
    num_epochs=5,
)
def main(args):
    """Train a single CIFAR-100 classifier (no distillation) with SGD and a
    multi-step learning-rate schedule using a SupervisedRunner."""
    set_global_seed(args.seed)

    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    )
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    dataset_map = {
        "train": Wrp(CIFAR100(root=".", train=True, download=True, transform=transform_train)),
        "valid": Wrp(CIFAR100(root=".", train=False, transform=transform_test)),
    }
    loaders = {
        split: DataLoader(
            ds, batch_size=args.batch_size, shuffle=(split == "train"), num_workers=2
        )
        for split, ds in dataset_map.items()
    }

    model = NAME2MODEL[args.model](num_classes=100)
    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, [150, 180, 210], gamma=0.1
    )

    runner = SupervisedRunner()
    runner.train(
        model=model,
        loaders=loaders,
        optimizer=optimizer,
        scheduler=scheduler,
        valid_metric="accuracy",
        minimize_valid_metric=False,
        logdir=args.logdir,
        callbacks=[
            CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            AccuracyCallback(input_key="logits", target_key="targets"),
            OptimizerCallback(metric_key="loss"),
            SchedulerCallback(),
        ],
        valid_loader="valid",
        num_epochs=args.num_epochs,
        criterion=torch.nn.CrossEntropyLoss(),
        seed=args.seed,
    )