Example #1: end-to-end Catalyst MNIST pipeline (training, inference, weight averaging, ONNX export, quantization, pruning, tracing)
import os
from tempfile import TemporaryDirectory

from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from catalyst import dl, utils
from catalyst.settings import SETTINGS


def train_experiment(device, engine=None):
    with TemporaryDirectory() as logdir:
        # simple linear classifier over flattened 28x28 MNIST digits
        model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=0.02)

        loaders = {
            "train":
            DataLoader(MNIST(os.getcwd(),
                             train=True,  # fixed: train=False here would train on the test split
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
            "valid":
            DataLoader(MNIST(os.getcwd(),
                             train=False,
                             download=True,
                             transform=ToTensor()),
                       batch_size=32),
        }

        runner = dl.SupervisedRunner(input_key="features",
                                     output_key="logits",
                                     target_key="targets",
                                     loss_key="loss")
        # classification metrics: top-k accuracy and per-class precision/recall/f1
        callbacks = [
            dl.AccuracyCallback(input_key="logits",
                                target_key="targets",
                                topk_args=(1, 3, 5)),
            dl.PrecisionRecallF1SupportCallback(input_key="logits",
                                                target_key="targets",
                                                num_classes=10),
        ]
        if SETTINGS.ml_required:
            callbacks.append(
                dl.ConfusionMatrixCallback(input_key="logits",
                                           target_key="targets",
                                           num_classes=10))
        # AUC is attached only when not running under an AMP engine
        is_amp_engine = isinstance(
            engine,
            (dl.AMPEngine, dl.DataParallelAMPEngine,
             dl.DistributedDataParallelAMPEngine),
        )
        if SETTINGS.amp_required and (engine is None or not is_amp_engine):
            callbacks.append(
                dl.AUCCallback(input_key="logits", target_key="targets"))
        if SETTINGS.onnx_required:
            callbacks.append(
                dl.OnnxCallback(logdir=logdir, input_key="features"))
        if SETTINGS.pruning_required:
            callbacks.append(
                dl.PruningCallback(pruning_fn="l1_unstructured", amount=0.5))
        if SETTINGS.quantization_required:
            callbacks.append(dl.QuantizationCallback(logdir=logdir))
        # tracing is skipped under the distributed engine
        if engine is None or not isinstance(engine,
                                            dl.DistributedDataParallelEngine):
            callbacks.append(
                dl.TracingCallback(logdir=logdir, input_key="features"))
        # model training
        runner.train(
            engine=engine or dl.DeviceEngine(device),
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            num_epochs=1,
            callbacks=callbacks,
            logdir=logdir,
            valid_loader="valid",
            valid_metric="loss",
            minimize_valid_metric=True,
            verbose=False,
            load_best_on_end=True,
            timeit=False,
            check=False,
            overfit=False,
            fp16=False,
            ddp=False,
        )
        # model inference
        for prediction in runner.predict_loader(loader=loaders["valid"]):
            assert prediction["logits"].detach().cpu().numpy().shape[-1] == 10
        # model post-processing: grab one feature batch for the export/tracing steps below
        features_batch = next(iter(loaders["valid"]))[0]
        # model stochastic weight averaging
        model.load_state_dict(
            utils.get_averaged_weights_by_path_mask(logdir=logdir,
                                                    path_mask="*.pth"))
        # model onnx export
        if SETTINGS.onnx_required:
            utils.onnx_export(
                model=runner.model,
                batch=runner.engine.sync_device(features_batch),
                file="./mnist.onnx",
                verbose=False,
            )
        # model quantization
        if SETTINGS.quantization_required:
            utils.quantize_model(model=runner.model)
        # model pruning
        if SETTINGS.pruning_required:
            utils.prune_model(model=runner.model,
                              pruning_fn="l1_unstructured",
                              amount=0.8)
        # model tracing
        utils.trace_model(model=runner.model, batch=features_batch)
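
For reference, a minimal sketch of how this experiment could be invoked; the device string and the bare call are assumptions, since the original snippet leaves invocation to the caller:

if __name__ == "__main__":
    # assumption: plain single-device run with the default engine
    train_experiment("cpu")  # any torch device string, e.g. "cuda:0", works the same way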
Example #2: MNIST pruning experiment with optional lottery-ticket rewinding and knowledge distillation
from copy import deepcopy
from functools import partial

import torch
from torch import nn
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from catalyst import dl, utils
from catalyst.contrib.nn import Flatten
from catalyst.dl import MetricAggregationCallback

# TorchvisionDatasetWrapper, PruneRunner, validate_model, LotteryTicketCallback,
# PrepareForFinePruningCallback, and KLDivCallback are helpers defined alongside
# this example.


def main(args):
    train_dataset = TorchvisionDatasetWrapper(
        MNIST(root="./", download=True, train=True, transform=ToTensor())
    )
    val_dataset = TorchvisionDatasetWrapper(
        MNIST(root="./", download=True, train=False, transform=ToTensor())
    )

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=64)
    loaders = {"train": train_dataloader, "valid": val_dataloader}
    utils.set_global_seed(args.seed)
    net = nn.Sequential(
        Flatten(),
        nn.Linear(28 * 28, 300),
        nn.ReLU(),
        nn.Linear(300, 100),
        nn.ReLU(),
        nn.Linear(100, 10),
    )
    # deepcopy the snapshot: state_dict() returns references that training would mutate,
    # which would defeat lottery-ticket rewinding
    initial_state_dict = deepcopy(net.state_dict())

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters())
    if args.device is not None:
        engine = dl.DeviceEngine(args.device)
    else:
        engine = None
    # vanilla pruning: train once, then prune iteratively and track accuracy
    if args.vanilla_pruning:
        runner = dl.SupervisedRunner(engine=engine)

        runner.train(
            model=net,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=[
                dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=10),
            ],
            logdir="./logdir",
            num_epochs=args.num_epochs,
            load_best_on_end=True,
            valid_metric="accuracy01",
            minimize_valid_metric=False,
            valid_loader="valid",
        )
        pruning_fn = partial(
            utils.pruning.prune_model,
            pruning_fn=args.pruning_method,
            amount=args.amount,
            keys_to_prune=["weight"],  # "weight" is the parameter name in nn.Linear; "weights" matches nothing
            dim=args.dim,
            l_norm=args.n,
        )
        acc, amount = validate_model(
            runner, pruning_fn=pruning_fn, loader=loaders["valid"], num_sessions=args.num_sessions
        )
        torch.save(acc, "accuracy.pth")
        torch.save(amount, "amount.pth")

    else:
        # iterative pruning: PruneRunner repeats the train/prune cycle across sessions
        runner = PruneRunner(num_sessions=args.num_sessions, engine=engine)
        callbacks = [
            dl.AccuracyCallback(input_key="logits", target_key="targets", num_classes=10),
            dl.PruningCallback(
                args.pruning_method,
                keys_to_prune=["weight"],
                amount=args.amount,
                remove_reparametrization_on_stage_end=False,
            ),
            dl.CriterionCallback(input_key="logits", target_key="targets", metric_key="loss"),
            dl.OptimizerCallback(metric_key="loss"),
        ]
        if args.lottery_ticket:
            callbacks.append(LotteryTicketCallback(initial_state_dict=initial_state_dict))
        # knowledge distillation: start from trained weights and mix in a weighted KL-divergence loss
        if args.kd:
            net.load_state_dict(torch.load(args.state_dict))
            callbacks.append(
                PrepareForFinePruningCallback(probability_shift=args.probability_shift)
            )
            callbacks.append(KLDivCallback(temperature=4, student_logits_key="logits"))
            callbacks.append(
                MetricAggregationCallback(
                    prefix="loss", metrics={"loss": 0.1, "kl_div_loss": 0.9}, mode="weighted_sum"
                )
            )

        runner.train(
            model=net,
            criterion=criterion,
            optimizer=optimizer,
            loaders=loaders,
            callbacks=callbacks,
            logdir=args.logdir,
            num_epochs=args.num_epochs,
            load_best_on_end=True,
            valid_metric="accuracy01",
            minimize_valid_metric=False,
            valid_loader="valid",
        )
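
The args namespace consumed by main() is never constructed in this snippet. Below is a minimal argparse sketch with the fields inferred from how args is used; every flag name and default value here is an assumption, not part of the original example:

import argparse


def parse_args():
    # each dest matches an attribute read inside main(); defaults are illustrative
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--vanilla-pruning", action="store_true")
    parser.add_argument("--pruning-method", type=str, default="l1_unstructured")
    parser.add_argument("--amount", type=float, default=0.5)
    parser.add_argument("--dim", type=int, default=None)
    parser.add_argument("--n", type=int, default=None)
    parser.add_argument("--num-sessions", type=int, default=5)
    parser.add_argument("--num-epochs", type=int, default=1)
    parser.add_argument("--lottery-ticket", action="store_true")
    parser.add_argument("--kd", action="store_true")
    parser.add_argument("--state-dict", type=str, default=None)
    parser.add_argument("--probability-shift", action="store_true")
    parser.add_argument("--logdir", type=str, default="./logdir")
    return parser.parse_args()


if __name__ == "__main__":
    main(parse_args())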