# Shared imports for the snippets in this section.
import os
from tempfile import TemporaryDirectory

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from catalyst import dl


def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        # teacher/student pair for knowledge distillation
        teacher = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        student = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        model = {"teacher": teacher, "student": student}
        criterion = {"cls": nn.CrossEntropyLoss(), "kl": nn.KLDivLoss(reduction="batchmean")}
        optimizer = optim.Adam(student.parameters(), lr=0.02)
        loaders = {
            "train": DataLoader(
                MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()),
                batch_size=32,
            ),
            "valid": DataLoader(
                MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()),
                batch_size=32,
            ),
        }

        # DistilRunner is not defined in this snippet — a sketch follows right
        # after it; the last example in this section defines the same runner inline.
        runner = DistilRunner()
        # model training
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            logdir=logdir,
            verbose=False,
            callbacks=[
                # MNIST has 10 classes
                dl.AccuracyCallback(
                    input_key="t_logits", target_key="targets", num_classes=10, prefix="teacher_"
                ),
                dl.AccuracyCallback(
                    input_key="s_logits", target_key="targets", num_classes=10, prefix="student_"
                ),
                # supervised cross-entropy on the student's logits
                dl.CriterionCallback(
                    input_key="s_logits",
                    target_key="targets",
                    metric_key="cls_loss",
                    criterion_key="cls",
                ),
                # KL divergence between student log-probs and teacher probs
                dl.CriterionCallback(
                    input_key="s_logprobs",
                    target_key="t_probs",
                    metric_key="kl_div_loss",
                    criterion_key="kl",
                ),
                dl.MetricAggregationCallback(
                    metric_key="loss", metrics=["kl_div_loss", "cls_loss"], mode="mean"
                ),
                # only the student is optimized
                dl.OptimizerCallback(metric_key="loss", model_key="student"),
                dl.CheckpointCallback(
                    logdir=logdir,
                    loader_key="valid",
                    metric_key="loss",
                    minimize=True,
                    save_n_best=3,
                ),
            ],
        )
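# The snippet above references DistilRunner without defining it. A minimal
# sketch is given here, mirroring the inline definition in the last example of
# this section (it relies on the shared imports at the top of this section):
# the teacher runs in eval mode under no_grad, and the produced batch exposes
# exactly the keys the callbacks above consume.
class DistilRunner(dl.Runner):
    def handle_batch(self, batch):
        x, y = batch

        self.model["teacher"].eval()  # the teacher is frozen for distillation
        with torch.no_grad():
            t_logits = self.model["teacher"](x)

        # the student's forward pass stays outside no_grad, so gradients flow
        s_logits = self.model["student"](x)
        self.batch = {
            "t_logits": t_logits,
            "s_logits": s_logits,
            "targets": y,
            "s_logprobs": F.log_softmax(s_logits, dim=-1),
            "t_probs": F.softmax(t_logits, dim=-1),
        }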
def test_aggregation_2():
    """Aggregation with a custom function."""
    # prepare_experiment() is not shown in this excerpt; a sketch follows this test
    loaders, model, criterion, optimizer = prepare_experiment()
    runner = dl.SupervisedRunner()

    def aggregation_function(metrics, runner):
        # linearly shift the weight from focal loss to BCE over three epochs:
        # epoch 1 -> (1.0, 0.0), epoch 2 -> (0.5, 0.5), epoch 3 -> (0.0, 1.0)
        epoch = runner.stage_epoch_step
        loss = (3 / 2 - epoch / 2) * metrics["loss_focal"] + (
            1 / 2 * epoch - 1 / 2
        ) * metrics["loss_bce"]
        return loss

    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir="./logs/aggregation_2/",
        num_epochs=3,
        callbacks=[
            dl.CriterionCallback(
                input_key="logits",
                target_key="targets",
                metric_key="loss_bce",
                criterion_key="bce",
            ),
            dl.CriterionCallback(
                input_key="logits",
                target_key="targets",
                metric_key="loss_focal",
                criterion_key="focal",
            ),
            # loss aggregation
            dl.MetricAggregationCallback(metric_key="loss", mode=aggregation_function),
        ],
    )

    # after the last epoch the custom weights are (0, 1), so the aggregated
    # loss must equal the BCE term alone
    for loader in ["train", "valid"]:
        metrics = runner.epoch_metrics[loader]
        loss_1 = metrics["loss_bce"]
        loss_2 = metrics["loss"]
        assert np.abs(loss_1 - loss_2) < 1e-5
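# Both aggregation tests call prepare_experiment(), which isn't shown in this
# excerpt. A minimal sketch under the following assumptions: a synthetic
# binary-classification dataset, a one-layer model, and a criterion dict with
# "bce" and "focal" keys. The FocalLoss below is a plain PyTorch
# implementation of the standard binary focal loss, not necessarily the class
# the original tests used. Relies on the shared imports at the top of this section.
class FocalLoss(nn.Module):
    """Binary focal loss: mean of (1 - p_t)^gamma * BCE."""

    def __init__(self, gamma: float = 2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p_t = torch.exp(-bce)  # probability assigned to the true class
        return ((1 - p_t) ** self.gamma * bce).mean()


def prepare_experiment():
    # synthetic binary classification data
    X = torch.rand(int(1e4), 10)
    y = (torch.rand(int(1e4), 1) > 0.5).to(torch.float32)
    loader = DataLoader(TensorDataset(X, y), batch_size=32)
    loaders = {"train": loader, "valid": loader}

    model = nn.Linear(10, 1)
    criterion = {"bce": nn.BCEWithLogitsLoss(), "focal": FocalLoss()}
    optimizer = optim.Adam(model.parameters())
    return loaders, model, criterion, optimizer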
def test_aggregation_1():
    """Aggregation as weighted_sum."""
    loaders, model, criterion, optimizer = prepare_experiment()
    runner = dl.SupervisedRunner()
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        loaders=loaders,
        logdir="./logs/aggregation_1/",
        num_epochs=3,
        callbacks=[
            dl.CriterionCallback(
                input_key="logits",
                target_key="targets",
                metric_key="loss_bce",
                criterion_key="bce",
            ),
            dl.CriterionCallback(
                input_key="logits",
                target_key="targets",
                metric_key="loss_focal",
                criterion_key="focal",
            ),
            # loss aggregation
            dl.MetricAggregationCallback(
                metric_key="loss",
                metrics={"loss_focal": 0.6, "loss_bce": 0.4},
                mode="weighted_sum",
            ),
        ],
    )

    for loader in ["train", "valid"]:
        metrics = runner.epoch_metrics[loader]
        loss_1 = metrics["loss_bce"] * 0.4 + metrics["loss_focal"] * 0.6
        loss_2 = metrics["loss"]
        assert np.abs(loss_1 - loss_2) < 1e-5
import wandb
from torch.optim.lr_scheduler import MultiStepLR
from torchvision import datasets, transforms

from catalyst.utils import set_global_seed

# Wrp, PreActResNet18, NAME2MODEL, NAME2OPTIM, KDRunner, WandbCallback,
# DiffOutputCallback, DiffHiddenCallback, get_loss_coefs and
# load_model_from_path are project-specific helpers defined elsewhere.


def main(args):
    wandb.init(project="teacher-pruning", config=vars(args))
    set_global_seed(42)

    # dataloader initialization
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    train_dataset = Wrp(
        datasets.CIFAR10(root=os.getcwd(), train=True, transform=transform_train, download=True)
    )
    valid_dataset = Wrp(
        datasets.CIFAR10(root=os.getcwd(), train=False, transform=transform_test)
    )
    train_dataloader = DataLoader(
        dataset=train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2
    )
    valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=128, num_workers=2)
    loaders = {
        "train": train_dataloader,
        "valid": valid_dataloader,
    }

    # model initialization
    model = PreActResNet18()
    model.fc = nn.Linear(512, 10)
    if args.teacher_model is not None:
        # knowledge-distillation setup: train the student against a frozen teacher
        is_kd = True
        teacher_model = NAME2MODEL[args.teacher_model]()
        load_model_from_path(model=teacher_model, path=args.teacher_path)
        model = {
            "student": model,
            "teacher": teacher_model,
        }
        output_hiddens = args.beta is None
        is_kd_on_hiddens = output_hiddens
        runner = KDRunner(device=args.device, output_hiddens=output_hiddens)
        parameters = model["student"].parameters()
    else:
        is_kd = False
        runner = dl.SupervisedRunner(device=args.device)
        parameters = model.parameters()

    # optimizer
    optimizer_cls = NAME2OPTIM[args.optimizer]
    optimizer_kwargs = {"params": parameters, "lr": args.lr}
    if args.optimizer == "sgd":
        optimizer_kwargs["momentum"] = args.momentum
    else:
        optimizer_kwargs["betas"] = (args.beta1, args.beta2)
    optimizer = optimizer_cls(**optimizer_kwargs)
    scheduler = MultiStepLR(optimizer, milestones=[80, 120], gamma=args.gamma)
    logdir = f"logs/{wandb.run.name}"

    # callbacks
    callbacks = [dl.AccuracyCallback(num_classes=10), WandbCallback()]
    if is_kd:
        metrics = {}
        callbacks.append(dl.CriterionCallback(output_key="cls_loss"))
        callbacks.append(DiffOutputCallback())
        coefs = get_loss_coefs(args.alpha, args.beta)
        metrics["cls_loss"] = coefs[0]
        metrics["diff_output_loss"] = coefs[1]
        if is_kd_on_hiddens:
            callbacks.append(DiffHiddenCallback())
            metrics["diff_hidden_loss"] = coefs[2]
        # aggregate the weighted losses into a single "loss" metric,
        # but only on the train loader
        aggregator_callback = dl.MetricAggregationCallback(
            prefix="loss", metrics=metrics, mode="weighted_sum"
        )
        wrapped_agg_callback = dl.ControlFlowCallback(aggregator_callback, loaders=["train"])
        callbacks.append(wrapped_agg_callback)

    runner.train(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=nn.CrossEntropyLoss(),
        loaders=loaders,
        callbacks=callbacks,
        num_epochs=args.epoch,
        logdir=logdir,
        verbose=True,
    )
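# NAME2OPTIM and NAME2MODEL are not shown above; minimal sketches are given
# here, consistent with how main() uses them: the "sgd" entry receives a
# momentum kwarg, every other entry receives betas (i.e. an Adam-family
# optimizer). The model registry entry is illustrative only — the real project
# maps names to its own teacher architectures. Relies on the shared imports at
# the top of this section.
NAME2OPTIM = {
    "sgd": optim.SGD,
    "adam": optim.Adam,
    "adamw": optim.AdamW,
}

NAME2MODEL = {
    # hypothetical entry; PreActResNet18 is the only architecture shown above
    "preact_resnet18": PreActResNet18,
}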
def train_experiment(device):  # `device` is unused in this snippet
    with TemporaryDirectory() as logdir:
        # sample data: one input, two independent classification targets
        num_samples, num_features, num_classes1, num_classes2 = int(1e4), int(1e1), 4, 10
        X = torch.rand(num_samples, num_features)
        y1 = (torch.rand(num_samples) * num_classes1).to(torch.int64)
        y2 = (torch.rand(num_samples) * num_classes2).to(torch.int64)

        # pytorch loaders
        dataset = TensorDataset(X, y1, y2)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        class CustomModule(nn.Module):
            """Shared encoder with two classification heads."""

            def __init__(self, in_features: int, out_features1: int, out_features2: int):
                super().__init__()
                self.shared = nn.Linear(in_features, 128)
                self.head1 = nn.Linear(128, out_features1)
                self.head2 = nn.Linear(128, out_features2)

            def forward(self, x):
                x = self.shared(x)
                y1 = self.head1(x)
                y2 = self.head2(x)
                return y1, y2

        # model, criterion, optimizer, scheduler
        model = CustomModule(num_features, num_classes1, num_classes2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters())
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [2])

        class CustomRunner(dl.Runner):
            def handle_batch(self, batch):
                x, y1, y2 = batch
                y1_hat, y2_hat = self.model(x)
                self.batch = {
                    "features": x,
                    "logits1": y1_hat,
                    "logits2": y2_hat,
                    "targets1": y1,
                    "targets2": y2,
                }

        # model training
        runner = CustomRunner()
        runner.train(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            logdir=logdir,
            num_epochs=1,
            verbose=False,
            callbacks=[
                # one CriterionCallback per head, aggregated into "loss" below
                dl.CriterionCallback(metric_key="loss1", input_key="logits1", target_key="targets1"),
                dl.CriterionCallback(metric_key="loss2", input_key="logits2", target_key="targets2"),
                dl.MetricAggregationCallback(prefix="loss", metrics=["loss1", "loss2"], mode="mean"),
                dl.OptimizerCallback(metric_key="loss"),
                dl.SchedulerCallback(),
                dl.AccuracyCallback(
                    input_key="logits1",
                    target_key="targets1",
                    num_classes=num_classes1,
                    prefix="one_",
                ),
                dl.AccuracyCallback(
                    input_key="logits2",
                    target_key="targets2",
                    num_classes=num_classes2,
                    prefix="two_",
                ),
                dl.ConfusionMatrixCallback(
                    input_key="logits1",
                    target_key="targets1",
                    num_classes=num_classes1,
                    prefix="one_cm",
                ),  # catalyst[ml] required
                dl.ConfusionMatrixCallback(
                    input_key="logits2",
                    target_key="targets2",
                    num_classes=num_classes2,
                    prefix="two_cm",
                ),  # catalyst[ml] required
                # keep the best checkpoints per head metric
                dl.CheckpointCallback(
                    "./logs/one",
                    loader_key="valid",
                    metric_key="one_accuracy",
                    minimize=False,
                    save_n_best=1,
                ),
                dl.CheckpointCallback(
                    "./logs/two",
                    loader_key="valid",
                    metric_key="two_accuracy03",
                    minimize=False,
                    save_n_best=3,
                ),
            ],
            loggers={
                "console": dl.ConsoleLogger(),
                "tb": dl.TensorboardLogger("./logs/tb"),
            },
        )
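# Example invocation for the snippet above; the multi-head checkpoints land
# under ./logs/one and ./logs/two, TensorBoard logs under ./logs/tb.
if __name__ == "__main__":
    train_experiment("cuda" if torch.cuda.is_available() else "cpu")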
from catalyst.settings import SETTINGS

# CustomModule and CustomRunner are the classes defined inline in the
# previous snippet.


def train_experiment(engine=None):
    with TemporaryDirectory() as logdir:
        # sample data: one input, two independent classification targets
        num_samples, num_features, num_classes1, num_classes2 = int(1e4), int(1e1), 4, 10
        X = torch.rand(num_samples, num_features)
        y1 = (torch.rand(num_samples) * num_classes1).to(torch.int64)
        y2 = (torch.rand(num_samples) * num_classes2).to(torch.int64)

        # pytorch loaders
        dataset = TensorDataset(X, y1, y2)
        loader = DataLoader(dataset, batch_size=32, num_workers=1)
        loaders = {"train": loader, "valid": loader}

        # model, criterion, optimizer, scheduler
        model = CustomModule(num_features, num_classes1, num_classes2)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters())
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [2])

        callbacks = [
            # one CriterionCallback per head, aggregated into "loss" below
            dl.CriterionCallback(metric_key="loss1", input_key="logits1", target_key="targets1"),
            dl.CriterionCallback(metric_key="loss2", input_key="logits2", target_key="targets2"),
            dl.MetricAggregationCallback(metric_key="loss", metrics=["loss1", "loss2"], mode="mean"),
            dl.BackwardCallback(metric_key="loss"),
            dl.OptimizerCallback(metric_key="loss"),
            dl.SchedulerCallback(),
            dl.AccuracyCallback(
                input_key="logits1",
                target_key="targets1",
                num_classes=num_classes1,
                prefix="one_",
            ),
            dl.AccuracyCallback(
                input_key="logits2",
                target_key="targets2",
                num_classes=num_classes2,
                prefix="two_",
            ),
            # keep the best checkpoints per head metric
            dl.CheckpointCallback(
                "./logs/one",
                loader_key="valid",
                metric_key="one_accuracy01",
                minimize=False,
                topk=1,
            ),
            dl.CheckpointCallback(
                "./logs/two",
                loader_key="valid",
                metric_key="two_accuracy03",
                minimize=False,
                topk=3,
            ),
        ]
        if SETTINGS.ml_required:
            # catalyst[ml] required
            callbacks.append(
                dl.ConfusionMatrixCallback(
                    input_key="logits1",
                    target_key="targets1",
                    num_classes=num_classes1,
                    prefix="one_cm",
                )
            )
            # catalyst[ml] required
            callbacks.append(
                dl.ConfusionMatrixCallback(
                    input_key="logits2",
                    target_key="targets2",
                    num_classes=num_classes2,
                    prefix="two_cm",
                )
            )

        # model training
        runner = CustomRunner()
        runner.train(
            engine=engine,
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            loaders=loaders,
            num_epochs=1,
            logdir=logdir,
            verbose=False,
            callbacks=callbacks,
            loggers={
                "console": dl.ConsoleLogger(),
                "tb": dl.TensorboardLogger("./logs/tb"),
            },
        )
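# Usage sketch for the engine-based variant above, assuming the Catalyst 22.x
# engine API this snippet targets (dl.CPUEngine / dl.GPUEngine); passing
# engine=None falls back to Catalyst's default engine selection.
if __name__ == "__main__":
    train_experiment(engine=dl.CPUEngine())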
def train_experiment(device):
    with TemporaryDirectory() as logdir:
        teacher = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        student = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        criterion = {
            "cls": nn.CrossEntropyLoss(),
            "kl": nn.KLDivLoss(reduction="batchmean"),
        }
        optimizer = optim.Adam(student.parameters(), lr=0.02)
        loaders = {
            "train": DataLoader(
                MNIST(os.getcwd(), train=True, download=True, transform=ToTensor()),
                batch_size=32,
            ),
            "valid": DataLoader(
                MNIST(os.getcwd(), train=False, download=True, transform=ToTensor()),
                batch_size=32,
            ),
        }

        class DistilRunner(dl.Runner):
            def handle_batch(self, batch):
                x, y = batch

                teacher.eval()  # let's manually set teacher model to eval mode
                with torch.no_grad():
                    t_logits = self.model["teacher"](x)

                # the student's forward pass stays outside no_grad
                s_logits = self.model["student"](x)
                self.batch = {
                    "t_logits": t_logits,
                    "s_logits": s_logits,
                    "targets": y,
                    "s_logprobs": F.log_softmax(s_logits, dim=-1),
                    "t_probs": F.softmax(t_logits, dim=-1),
                }

        runner = DistilRunner()
        # model training
        runner.train(
            engine=dl.DeviceEngine(device),
            model={"teacher": teacher, "student": student},
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            logdir=logdir,
            verbose=True,
            callbacks=[
                # MNIST has 10 classes
                dl.AccuracyCallback(
                    input_key="t_logits", target_key="targets", num_classes=10, prefix="teacher_"
                ),
                dl.AccuracyCallback(
                    input_key="s_logits", target_key="targets", num_classes=10, prefix="student_"
                ),
                dl.CriterionCallback(
                    input_key="s_logits",
                    target_key="targets",
                    metric_key="cls_loss",
                    criterion_key="cls",
                ),
                dl.CriterionCallback(
                    input_key="s_logprobs",
                    target_key="t_probs",
                    metric_key="kl_div_loss",
                    criterion_key="kl",
                ),
                dl.MetricAggregationCallback(
                    prefix="loss", metrics=["kl_div_loss", "cls_loss"], mode="mean"
                ),
                # only the student is optimized
                dl.OptimizerCallback(metric_key="loss", model_key="student"),
                dl.CheckpointCallback(
                    logdir=logdir,
                    loader_key="valid",
                    metric_key="loss",
                    minimize=True,
                    save_n_best=3,
                ),
            ],
        )
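# After training, the top-3 student checkpoints (by validation loss) sit in
# `logdir`. A sketch of restoring the best one — assuming Catalyst's usual
# "best.pth" filename; since `logdir` is a TemporaryDirectory, this has to run
# inside the with-block above, before the directory is deleted:
#
#     from catalyst import utils
#
#     checkpoint = utils.load_checkpoint(os.path.join(logdir, "best.pth"))
#     utils.unpack_checkpoint(checkpoint, model=student)
#     student.eval()  # ready for inference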