예제 #1
0
    def __init__(self, config):
        """Prepare the multi-resolution test-set evaluator.

        Creates the TensorBoard writer, seeds the RNGs, selects CUDA devices,
        builds an ``ImageFolder`` test set rooted at ``data/c617a1/test`` with
        its ``DataLoader``, and wraps the configured backbone in a
        ``MultiResolutionModelWrapper``.

        Args:
            config: dict-like configuration. Keys read here: ``tb_dir``
                (required), ``seed``, ``cpu_only``, ``gpu_ids``,
                ``num_workers``, ``batch_size``, ``model``, ``checkpoints``.
        """
        super(TestSetEvaluator, self).__init__(config)

        # TensorBoard logging; the stringified config is stored as text.
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Seed the RNGs (set_seed receives None when no seed is configured).
        run_seed = config.get("seed")
        set_seed(run_seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(run_seed))

        # Select CUDA devices unless cpu_only is requested.
        self.use_cuda, self.gpu_ids = set_cuda_devices(
            config.get("cpu_only", False),
            config.get("gpu_ids"),
            logger=self.logger,
        )

        # Per-channel statistics handed to torchvision's Normalize(mean, std).
        channel_mean = [0.663295328617096, 0.6501832604408264, 0.6542291045188904]
        channel_std = [0.19360290467739105, 0.22194330394268036, 0.23059576749801636]
        test_transform = Compose(
            [
                ToTensor(),
                Normalize(channel_mean, channel_std),
                Lambda(expand_multires),
            ]
        )
        self.test_set = ImageFolder("data/c617a1/test", transform=test_transform)
        self.test_loader = DataLoader(
            self.test_set,
            batch_size=config.get("batch_size", 1),
            num_workers=config.get("num_workers", 0),
            shuffle=True,
        )

        # Swap the backbone's classifier for a 300-wide head whose output the
        # multi-resolution wrapper consumes. 1280 is assumed to be the
        # backbone's classifier input width — confirm against the model config.
        backbone = init_class(config.get("model"))
        head_width = 300
        backbone.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2, inplace=True),
            torch.nn.Linear(1280, head_width),
        )
        self.model = MultiResolutionModelWrapper(backbone, base_out_size=head_width)

        self.logger.info("Test Dataset: %s", self.test_set)
        self.checkpoints = config.get("checkpoints")
        self.map_location = torch.device("cpu") if not self.use_cuda else None

        # Move to GPU(s); DataParallel only when more than one id is configured.
        if self.use_cuda:
            if len(self.gpu_ids) > 1:
                self.model = torch.nn.DataParallel(self.model)
            self.model = self.model.cuda()
예제 #2
0
    def __init__(self, config):
        """Set up the model-profiling agent from ``config``.

        Creates the TensorBoard writer, seeds the RNGs, selects CUDA devices,
        builds the evaluation dataset/loader, instantiates the model in eval
        mode, logs experiment details and opens the ``profiler.out`` tab log.

        Args:
            config: dict-like configuration. Keys read here: ``tb_dir`` and
                ``out_dir`` (required), ``seed``, ``cpu_only``, ``gpu_ids``,
                ``eval_data``, ``model``.
        """
        super(ModelProfilerAgent, self).__init__(config)

        # TensorBoard Summary Writer
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Configure seed, if provided
        seed = config.get("seed")
        set_seed(seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(seed))

        # Setup CUDA
        cpu_only = config.get("cpu_only", False)
        self.use_cuda, self.gpu_ids = set_cuda_devices(cpu_only,
                                                       config.get("gpu_ids"),
                                                       logger=self.logger)

        # Instantiate Datasets and Dataloaders
        self.eval_set, self.eval_loader = init_data(config.get("eval_data"))

        # Instantiate the model and put it in eval mode before tracing.
        self.model = init_class(config.get("model"))
        self.model.eval()
        try:
            # Best-effort: trace one sample to draw the model graph in TensorBoard.
            model_input, _target = next(iter(self.eval_set))
            self.tb_sw.add_graph(self.model, model_input.unsqueeze(0))
        except Exception as e:  # graph visualisation is optional; never abort setup
            # Logger.warn is a deprecated alias; use Logger.warning.
            self.logger.warning("Could not add model graph to TensorBoard: %s", e)

        # Log the classification experiment details
        self.logger.info("Eval Dataset: %s", self.eval_set)
        self.logger.info("Model: %s", self.model)
        self.logger.info("Batch Size: %d", self.eval_loader.batch_size)
        num_params = sum(p.numel() for p in self.model.parameters())
        num_lrn_p = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad)
        self.logger.info(
            "Num Parameters: %(params)d (%(lrn_params)d requires gradient)",
            {
                "params": num_params,
                "lrn_params": num_lrn_p
            },
        )

        # Tab-separated profiler output, one column per recorded metric.
        t_log_fpath = os.path.join(config["out_dir"], "profiler.out")
        self.t_log = TabLogger(t_log_fpath)
        self.t_log.set_names([
            "Batch Size", "Self CPU Time", "CPU Time Total", "CUDA Time Total"
        ])
        self.logger.info("Storing tab log output at: %s", t_log_fpath)
예제 #3
0
    def __init__(self, config):
        """Prepare a binary-classification test-set evaluator.

        Creates the TensorBoard writer, seeds the RNGs, selects CUDA devices,
        builds the test dataset/loader from ``config["test_data"]`` and
        replaces the model's ``classifier`` with a two-output head.

        Args:
            config: dict-like configuration. Keys read here: ``tb_dir``
                (required), ``seed``, ``cpu_only``, ``gpu_ids``,
                ``test_data``, ``model``, ``checkpoints``.
        """
        super(TestSetEvaluator, self).__init__(config)

        # TensorBoard logging; the stringified config is stored as text.
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Seed the RNGs (set_seed receives None when no seed is configured).
        run_seed = config.get("seed")
        set_seed(run_seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(run_seed))

        # Select CUDA devices unless cpu_only is requested.
        self.use_cuda, self.gpu_ids = set_cuda_devices(
            config.get("cpu_only", False),
            config.get("gpu_ids"),
            logger=self.logger,
        )

        # Test dataset and loader come from the "test_data" config section.
        self.test_set, self.test_loader = init_data(config.get("test_data"))

        # Build the model, then swap its classifier for a 2-way head.
        # 1280 is assumed to be the backbone's classifier input width — confirm.
        network = init_class(config.get("model"))
        network.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2, inplace=True),
            torch.nn.Linear(1280, 2),
        )
        self.model = network

        self.logger.info("Test Dataset: %s", self.test_set)
        self.checkpoints = config.get("checkpoints")
        self.map_location = torch.device("cpu") if not self.use_cuda else None

        # Move to GPU(s); DataParallel only when more than one id is configured.
        if self.use_cuda:
            if len(self.gpu_ids) > 1:
                self.model = torch.nn.DataParallel(self.model)
            self.model = self.model.cuda()
예제 #4
0
    def __init__(self, config):
        """Set up a supervised classification training agent from ``config``.

        Creates the TensorBoard writer, seeds the RNGs, selects CUDA devices,
        builds train/eval datasets, the model, task loss and optimizer, opens
        the ``epoch.out`` tab log and, when ``config["resume"]`` is set,
        restores model/optimizer state and fast-forwards the LR schedule.

        Args:
            config: dict-like configuration. Keys read here: ``tb_dir`` and
                ``out_dir`` (required), ``seed``, ``cpu_only``, ``gpu_ids``,
                ``train_data``, ``eval_data``, ``model``, ``task_loss``,
                ``optimizer``, ``epochs``, ``start_epoch``, ``schedule``,
                ``gamma``, ``resume``.
        """
        super(ClassificationAgent, self).__init__(config)

        # TensorBoard Summary Writer
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Configure seed, if provided
        seed = config.get("seed")
        set_seed(seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(seed))

        # Setup CUDA
        cpu_only = config.get("cpu_only", False)
        self.use_cuda, self.gpu_ids = set_cuda_devices(
            cpu_only, config.get("gpu_ids"), logger=self.logger
        )

        # Instantiate Datasets and Dataloaders
        self.train_set, self.train_loader = init_data(config.get("train_data"))
        self.eval_set, self.eval_loader = init_data(config.get("eval_data"))

        # Instantiate Model
        self.model = init_class(config.get("model"))
        try:
            # Best-effort: trace one sample to draw the model graph in TensorBoard.
            model_input, _target = next(iter(self.eval_set))
            self.tb_sw.add_graph(self.model, model_input.unsqueeze(0))
        except Exception as e:  # graph visualisation is optional; never abort setup
            # Logger.warn is a deprecated alias; use Logger.warning.
            self.logger.warning("Could not add model graph to TensorBoard: %s", e)

        # Instantiate task loss and optimizer
        self.task_loss_fn = init_class(config.get("task_loss"))
        self.optimizer = init_class(config.get("optimizer"), self.model.parameters())

        # Misc. Other classification hyperparameters
        self.epochs = config.get("epochs", 300)
        self.start_epoch = config.get("start_epoch", 0)
        self.schedule = config.get("schedule", [150, 225])
        self.gamma = config.get("gamma", 0.1)
        self.lr = self.optimizer.param_groups[0]["lr"]
        self.best_acc1 = 0

        # Log the classification experiment details
        self.logger.info("Train Dataset: %s", self.train_set)
        self.logger.info("Eval Dataset: %s", self.eval_set)
        self.logger.info("Task Loss (Criterion): %s", self.task_loss_fn)
        self.logger.info("Model Optimizer: %s", self.optimizer)
        self.logger.info("Model: %s", self.model)
        num_params = sum(p.numel() for p in self.model.parameters())
        num_lrn_p = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.logger.info(
            "Num Parameters: %(params)d (%(lrn_params)d requires gradient)",
            {"params": num_params, "lrn_params": num_lrn_p},
        )
        self.logger.info(
            "LR: %(lr)f decreasing by a factor of %(gamma)f at epochs %(schedule)s",
            {"lr": self.lr, "gamma": self.gamma, "schedule": self.schedule},
        )

        # Load checkpoints onto the CPU when CUDA is not in use, so checkpoints
        # saved from GPU tensors can still be resumed on CPU-only hosts.
        self.map_location = None if self.use_cuda else torch.device("cpu")

        # Path to in progress checkpoint.pth.tar for resuming experiment
        resume = config.get("resume")
        t_log_fpath = os.path.join(config["out_dir"], "epoch.out")
        self.t_log = TabLogger(t_log_fpath, resume=bool(resume))  # tab logger
        self.t_log.set_names(
            [
                "Epoch",
                "Train Task Loss",
                "Train Acc",
                "Eval Task Loss",
                "Eval Acc",
                "LR",
            ]
        )
        if resume:
            self.logger.info("Resuming from checkpoint: %s", resume)
            res_chkpt = torch.load(resume, map_location=self.map_location)
            self.model.load_state_dict(res_chkpt["state_dict"])
            self.start_epoch = res_chkpt.get("epoch", 0)
            self.best_acc1 = res_chkpt.get("best_acc1", 0)
            optim_state_dict = res_chkpt.get("optim_state_dict")
            if optim_state_dict:
                self.optimizer.load_state_dict(optim_state_dict)
            self.logger.info(
                "Resumed at epoch %d, eval best_acc1 %.2f",
                self.start_epoch,
                self.best_acc1,
            )
            # Fast-forward the LR through every schedule milestone already
            # reached by start_epoch (schedule is walked in list order).
            for sched in self.schedule:
                if sched > self.start_epoch:
                    break
                new_lr = adjust_learning_rate(
                    self.optimizer,
                    sched,
                    lr=self.lr,
                    schedule=self.schedule,
                    gamma=self.gamma,
                )
                self.logger.info(
                    "LR fastforward from %(old)f to %(new)f at Epoch %(epoch)d",
                    {"old": self.lr, "new": new_lr, "epoch": sched},
                )
                self.lr = new_lr

        self.logger.info(
            "Training from Epoch %(start)d to %(end)d",
            {"start": self.start_epoch, "end": self.epochs},
        )

        # Support multiple GPUs using DataParallel
        if self.use_cuda:
            if len(self.gpu_ids) > 1:
                self.model = torch.nn.DataParallel(self.model).cuda()
            else:
                self.model = self.model.cuda()
    def __init__(self, config):
        """Set up joint knowledge-distillation + mask-pruning training.

        Builds ``pretrained_model`` (weights loaded from ``config["prune"]``,
        put in eval mode) and a prunable ``model`` whose non-MaskSTE leaf
        modules are initialised from the pretrained weights. Also creates the
        task/mask losses, optimizer and tab logger. When ``config["resume"]``
        is set, it restores training state and fast-forwards the LR schedule.

        Args:
            config: dict-like configuration. Keys read here: ``tb_dir``,
                ``out_dir`` and ``prune`` (required), ``seed``, ``cpu_only``,
                ``gpu_ids``, ``train_data``, ``eval_data``,
                ``pretrained_model``, ``model``, ``task_loss``, ``mask_loss``,
                ``optimizer``, ``temperature``, ``task_loss_reg``,
                ``mask_loss_reg``, ``kd_loss_reg``, ``epochs``,
                ``start_epoch``, ``schedule``, ``gamma``, ``resume``.

        Raises:
            ValueError: if a non-mask leaf module of ``model`` does not have
                the same type as the corresponding module of the pretrained
                model.
        """
        super(JointKnowledgeDistillationPruningAgent, self).__init__(config)

        # TensorBoard Summary Writer
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Configure seed, if provided
        seed = config.get("seed")
        set_seed(seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(seed))

        # Setup CUDA
        cpu_only = config.get("cpu_only", False)
        self.use_cuda, self.gpu_ids = set_cuda_devices(
            cpu_only, config.get("gpu_ids"), logger=self.logger
        )

        # Instantiate Datasets and Dataloaders
        self.train_set, self.train_loader = init_data(config.get("train_data"))
        self.eval_set, self.eval_loader = init_data(config.get("eval_data"))

        # Instantiate Models
        self.pretrained_model = init_class(config.get("pretrained_model"))
        self.model = init_class(config.get("model"))

        # Load the pretrained weights (onto the CPU when CUDA is not in use)
        map_location = None if self.use_cuda else torch.device("cpu")
        prune_checkpoint = config["prune"]
        self.logger.info(
            "Loading pretrained model from checkpoint: %s", prune_checkpoint
        )
        prune_checkpoint = torch.load(prune_checkpoint, map_location=map_location)
        self.pretrained_model.load_state_dict(prune_checkpoint["state_dict"])
        self.pretrained_model.eval()

        # Walk both module lists in lockstep. MaskSTE layers are recorded in
        # self.mask_modules and skipped without advancing module_idx
        # (presumably they have no counterpart in the pretrained model).
        modules_pretrained = list(self.pretrained_model.modules())
        modules_to_prune = list(self.model.modules())
        module_idx = 0
        self.mask_modules = []
        for module_to_prune in modules_to_prune:
            module_pretrained = modules_pretrained[module_idx]
            modstr = str(type(module_to_prune))
            # Skip the masking layers
            if "MaskSTE" in modstr:
                self.mask_modules.append(module_to_prune)
                continue
            if len(list(module_to_prune.children())) == 0:
                pretrained_modstr = str(type(module_pretrained))
                # Explicit check rather than ``assert`` so it is not stripped
                # under ``python -O``.
                if modstr != pretrained_modstr:
                    raise ValueError(
                        "Pretrained module mismatch at index {}: model has {}, "
                        "pretrained model has {}".format(
                            module_idx, modstr, pretrained_modstr
                        )
                    )
                # copy all parameters over
                param_lookup = dict(module_pretrained.named_parameters())
                for param_key, param_val in module_to_prune.named_parameters():
                    param_val.data.copy_(param_lookup[param_key].data)
                # BatchNorm layers are special and require copying of running_mean/running_var
                if "BatchNorm" in modstr:
                    module_to_prune.running_mean.copy_(module_pretrained.running_mean)
                    module_to_prune.running_var.copy_(module_pretrained.running_var)
            module_idx += 1

        try:
            # Best-effort: trace one sample to draw the model graph in TensorBoard.
            model_input, _target = next(iter(self.eval_set))
            self.tb_sw.add_graph(self.model, model_input.unsqueeze(0))
        except Exception as e:  # graph visualisation is optional; never abort setup
            # Logger.warn is a deprecated alias; use Logger.warning.
            self.logger.warning("Could not add model graph to TensorBoard: %s", e)

        # Instantiate task loss and optimizer
        self.task_loss_fn = init_class(config.get("task_loss"))
        self.mask_loss_fn = init_class(config.get("mask_loss"))
        self.optimizer = init_class(config.get("optimizer"), self.model.parameters())
        self.temperature = config.get("temperature", 4.0)
        self.task_loss_reg = config.get("task_loss_reg", 1.0)
        self.mask_loss_reg = config.get("mask_loss_reg", 1.0)
        self.kd_loss_reg = config.get("kd_loss_reg", 1.0)

        # Misc. Other classification hyperparameters
        self.epochs = config.get("epochs", 300)
        self.start_epoch = config.get("start_epoch", 0)
        self.schedule = config.get("schedule", [150, 225])
        self.gamma = config.get("gamma", 0.1)
        self.lr = self.optimizer.param_groups[0]["lr"]
        self.best_acc_per_usage = {}

        # Log the classification experiment details
        self.logger.info("Train Dataset: %s", self.train_set)
        self.logger.info("Eval Dataset: %s", self.eval_set)
        self.logger.info("Task Loss (Criterion): %s", self.task_loss_fn)
        self.logger.info("Model Optimizer: %s", self.optimizer)
        self.logger.info("Model: %s", self.model)
        num_params = sum(p.numel() for p in self.model.parameters())
        num_lrn_p = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.logger.info(
            "Num Parameters: %(params)d (%(lrn_params)d requires gradient)",
            {"params": num_params, "lrn_params": num_lrn_p},
        )
        self.logger.info(
            "LR: %(lr)f decreasing by a factor of %(gamma)f at epochs %(schedule)s",
            {"lr": self.lr, "gamma": self.gamma, "schedule": self.schedule},
        )

        # Path to in progress checkpoint.pth.tar for resuming experiment
        resume = config.get("resume")
        t_log_fpath = os.path.join(config["out_dir"], "epoch.out")
        self.t_log = TabLogger(t_log_fpath, resume=bool(resume))
        self.t_log.set_names(
            [
                "Epoch",
                "Train Task Loss",
                "Train KD Loss",
                "Train Mask Loss",
                "Train Acc",
                "Eval Task Loss",
                "Eval KD Loss",
                "Eval Mask Loss",
                "Eval Acc",
                "Num Parameters",
                "LR",
            ]
        )
        if resume:
            self.logger.info("Resuming from checkpoint: %s", resume)
            res_chkpt = torch.load(resume, map_location=map_location)
            self.start_epoch = res_chkpt["epoch"]
            self.model.load_state_dict(res_chkpt["state_dict"])
            eval_acc = res_chkpt["acc"]
            self.best_acc_per_usage = res_chkpt["best_acc_per_usage"]
            self.optimizer.load_state_dict(res_chkpt["optim_state_dict"])
            self.logger.info(
                "Resumed at epoch %d, eval acc %.2f", self.start_epoch, eval_acc
            )
            self.logger.info(pformat(self.best_acc_per_usage))
            # Fast-forward the LR through every schedule milestone already
            # reached by start_epoch (schedule is walked in list order).
            for sched in self.schedule:
                if sched > self.start_epoch:
                    break
                new_lr = adjust_learning_rate(
                    self.optimizer,
                    sched,
                    lr=self.lr,
                    schedule=self.schedule,
                    gamma=self.gamma,
                )
                self.logger.info(
                    "LR fastforward from %(old)f to %(new)f at Epoch %(epoch)d",
                    {"old": self.lr, "new": new_lr, "epoch": sched},
                )
                self.lr = new_lr

        self.logger.info(
            "Training from Epoch %(start)d to %(end)d",
            {"start": self.start_epoch, "end": self.epochs},
        )

        # Support multiple GPUs using DataParallel
        if self.use_cuda:
            if len(self.gpu_ids) > 1:
                self.pretrained_model = torch.nn.DataParallel(
                    self.pretrained_model
                ).cuda()
                self.model = torch.nn.DataParallel(self.model).cuda()
            else:
                self.pretrained_model = self.pretrained_model.cuda()
                self.model = self.model.cuda()
예제 #6
0
    def __init__(self, config):
        """Set up a fine-tuning classifier with a frozen backbone.

        Builds one dataset from ``config["train_eval_data"]`` and randomly
        splits it into train/eval subsets. It instantiates the model, freezes
        every existing parameter and then attaches a new trainable two-output
        ``classifier``. It also creates the loss, optimizer and tab logger.
        When ``config["resume"]`` is set, it restores state and fast-forwards
        the LR schedule.

        Args:
            config: dict-like configuration. Keys read here: ``tb_dir`` and
                ``out_dir`` (required), ``seed``, ``cpu_only``, ``gpu_ids``,
                ``train_eval_data`` (with ``name``, ``args``, ``kwargs``,
                ``transform``, ``target_transform``,
                ``train_eval_split_ratio``, ``dataloader_kwargs``), ``model``,
                ``task_loss``, ``optimizer``, ``epochs``, ``start_epoch``,
                ``schedule``, ``gamma``, ``resume``.
        """
        super(FineTuneClassifier, self).__init__(config)

        # TensorBoard Summary Writer
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Configure seed, if provided
        seed = config.get("seed")
        set_seed(seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(seed))

        # Setup CUDA
        cpu_only = config.get("cpu_only", False)
        self.use_cuda, self.gpu_ids = set_cuda_devices(cpu_only,
                                                       config.get("gpu_ids"),
                                                       logger=self.logger)

        # Build a single dataset (class named by "name", transforms built with
        # init_class) and split it randomly into train and eval subsets.
        train_eval_config = config.get("train_eval_data")
        ds_class = fetch_class(train_eval_config["name"])
        d_transform = list(
            map(init_class, train_eval_config.get("transform", [])))
        d_ttransform = list(
            map(init_class, train_eval_config.get("target_transform", [])))
        ds = ds_class(
            *train_eval_config.get("args", []),
            **train_eval_config.get("kwargs", {}),
            transform=Compose(d_transform) if d_transform else None,
            target_transform=Compose(d_ttransform) if d_ttransform else None)
        # Only the train ratio is used; the eval subset takes the remainder.
        train_set_ratio, _eval_set_ratio = train_eval_config.get(
            "train_eval_split_ratio", [0.85, 0.15])
        train_len = ceil(len(ds) * train_set_ratio)
        eval_len = len(ds) - train_len
        self.train_set, self.eval_set = random_split(ds, [train_len, eval_len])
        # Both loaders share the same DataLoader keyword arguments.
        loader_kwargs = train_eval_config.get("dataloader_kwargs", {})
        self.train_loader = DataLoader(self.train_set, **loader_kwargs)
        self.eval_loader = DataLoader(self.eval_set, **loader_kwargs)

        # Instantiate Model
        self.model = init_class(config.get("model"))
        # Freeze all of the parameters (except for final classification layer which we add afterwards)
        for param in self.model.parameters():
            param.requires_grad = False
        # Replace the classifier with a fresh (trainable) 2-way head.
        # 1280 is assumed to be the backbone's classifier input width — confirm.
        self.model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2, inplace=True), torch.nn.Linear(1280, 2))
        try:
            # Best-effort: trace one sample to draw the model graph in TensorBoard.
            model_input, _target = next(iter(self.eval_set))
            self.tb_sw.add_graph(self.model, model_input.unsqueeze(0))
        except Exception as e:  # graph visualisation is optional; never abort setup
            # Logger.warn is a deprecated alias; use Logger.warning.
            self.logger.warning("Could not add model graph to TensorBoard: %s", e)

        # Instantiate task loss and optimizer
        self.task_loss_fn = init_class(config.get("task_loss"))
        self.optimizer = init_class(config.get("optimizer"),
                                    self.model.parameters())

        # Misc. Other classification hyperparameters
        self.epochs = config.get("epochs", 300)
        self.start_epoch = config.get("start_epoch", 0)
        self.schedule = config.get("schedule", [150, 225])
        self.gamma = config.get("gamma", 0.1)
        self.lr = self.optimizer.param_groups[0]["lr"]
        self.best_acc1 = 0

        # Log the classification experiment details
        self.logger.info("Train Dataset: %s", self.train_set)
        self.logger.info("Eval Dataset: %s", self.eval_set)
        self.logger.info("Task Loss (Criterion): %s", self.task_loss_fn)
        self.logger.info("Model Optimizer: %s", self.optimizer)
        self.logger.info("Model: %s", self.model)
        num_params = sum(p.numel() for p in self.model.parameters())
        num_lrn_p = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad)
        self.logger.info(
            "Num Parameters: %(params)d (%(lrn_params)d requires gradient)",
            {
                "params": num_params,
                "lrn_params": num_lrn_p
            },
        )
        self.logger.info(
            "LR: %(lr)f decreasing by a factor of %(gamma)f at epochs %(schedule)s",
            {
                "lr": self.lr,
                "gamma": self.gamma,
                "schedule": self.schedule
            },
        )

        # Load checkpoints onto the CPU when CUDA is not in use.
        self.map_location = None if self.use_cuda else torch.device("cpu")
        # Path to in progress checkpoint.pth.tar for resuming experiment
        resume = config.get("resume")
        t_log_fpath = os.path.join(config["out_dir"], "epoch.out")
        self.t_log = TabLogger(t_log_fpath, resume=bool(resume))  # tab logger
        self.t_log.set_names([
            "Epoch",
            "Train Task Loss",
            "Train Acc",
            "Eval Task Loss",
            "Eval Acc",
            "LR",
        ])

        if resume:
            self.logger.info("Resuming from checkpoint: %s", resume)
            res_chkpt = torch.load(resume, map_location=self.map_location)
            self.model.load_state_dict(res_chkpt["state_dict"])
            self.start_epoch = res_chkpt.get("epoch", 0)
            self.best_acc1 = res_chkpt.get("best_acc1", 0)
            optim_state_dict = res_chkpt.get("optim_state_dict")
            if optim_state_dict:
                self.optimizer.load_state_dict(optim_state_dict)
            self.logger.info(
                "Resumed at epoch %d, eval best_acc1 %.2f",
                self.start_epoch,
                self.best_acc1,
            )
            # Fast-forward the LR through every schedule milestone already
            # reached by start_epoch (schedule is walked in list order).
            for sched in self.schedule:
                if sched > self.start_epoch:
                    break
                new_lr = adjust_learning_rate(
                    self.optimizer,
                    sched,
                    lr=self.lr,
                    schedule=self.schedule,
                    gamma=self.gamma,
                )
                self.logger.info(
                    "LR fastforward from %(old)f to %(new)f at Epoch %(epoch)d",
                    {
                        "old": self.lr,
                        "new": new_lr,
                        "epoch": sched
                    },
                )
                self.lr = new_lr

        self.logger.info(
            "Training from Epoch %(start)d to %(end)d",
            {
                "start": self.start_epoch,
                "end": self.epochs
            },
        )

        # Support multiple GPUs using DataParallel
        if self.use_cuda:
            if len(self.gpu_ids) > 1:
                self.model = torch.nn.DataParallel(self.model).cuda()
            else:
                self.model = self.model.cuda()
예제 #7
0
    def __init__(self, config):
        """Set up multi-resolution fine-tuning on an ImageFolder dataset.

        Loads an ``ImageFolder`` dataset from ``config["train_eval_dir"]``
        (default ``data/c617a1/train_eval``) and randomly splits it into
        train/eval subsets. It freezes the backbone's existing parameters,
        attaches a new trainable 300-wide ``classifier`` and wraps the result
        in ``MultiResolutionModelWrapper``. It also creates the loss and
        optimizer.

        Args:
            config: dict-like configuration. Keys read here: ``tb_dir``
                (required), ``seed``, ``cpu_only``, ``gpu_ids``,
                ``train_eval_dir``, ``train_eval_split_ratio``,
                ``num_workers``, ``batch_size``, ``model``, ``task_loss``,
                ``optimizer``, ``epochs``, ``start_epoch``, ``schedule``,
                ``gamma``.
        """
        super(MultiResolutionFineTuneClassifier, self).__init__(config)
        # TensorBoard Summary Writer
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Configure seed, if provided
        seed = config.get("seed")
        set_seed(seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(seed))

        # Setup CUDA
        cpu_only = config.get("cpu_only", False)
        self.use_cuda, self.gpu_ids = set_cuda_devices(
            cpu_only, config.get("gpu_ids"), logger=self.logger
        )

        # Instantiate Datasets and Dataloaders. The dataset root was previously
        # hard-coded; it is now configurable with the old path as default.
        # NOTE(review): random_split shares one underlying dataset, so the eval
        # subset also receives RandomHorizontalFlip — confirm this is intended.
        data_root = config.get("train_eval_dir", "data/c617a1/train_eval")
        ds = ImageFolder(
            data_root,
            transform=Compose(
                [
                    RandomHorizontalFlip(),
                    ToTensor(),
                    Normalize(
                        [0.663295328617096, 0.6501832604408264, 0.6542291045188904],
                        [0.19360290467739105, 0.22194330394268036, 0.23059576749801636],
                    ),
                    Lambda(expand_multires),
                ]
            ),
        )
        # Only the train ratio is used; the eval subset takes the remainder.
        train_set_ratio, _eval_set_ratio = config.get(
            "train_eval_split_ratio", [0.85, 0.15]
        )
        train_len = ceil(len(ds) * train_set_ratio)
        eval_len = len(ds) - train_len
        self.train_set, self.eval_set = random_split(ds, [train_len, eval_len])
        # Both loaders share identical settings.
        loader_kwargs = {
            "num_workers": config.get("num_workers", 0),
            "batch_size": config.get("batch_size", 128),
            "shuffle": True,
        }
        self.train_loader = DataLoader(self.train_set, **loader_kwargs)
        self.eval_loader = DataLoader(self.eval_set, **loader_kwargs)

        # Instantiate Model
        base_model = init_class(config.get("model"))
        # Freeze all of the parameters (except for final classification layer which we add afterwards)
        for param in base_model.parameters():
            param.requires_grad = False

        # Replace the classifier with a fresh (trainable) 300-wide head that
        # feeds MultiResolutionModelWrapper. 1280 is assumed to be the
        # backbone's classifier input width — confirm.
        base_out_size = 300
        base_model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2, inplace=True), torch.nn.Linear(1280, base_out_size)
        )
        self.model = MultiResolutionModelWrapper(base_model, base_out_size=base_out_size)

        # Instantiate task loss and optimizer
        self.task_loss_fn = init_class(config.get("task_loss"))
        self.optimizer = init_class(config.get("optimizer"), self.model.parameters())

        # Misc. Other classification hyperparameters
        self.epochs = config.get("epochs", 300)
        self.start_epoch = config.get("start_epoch", 0)
        self.schedule = config.get("schedule", [150, 225])
        self.gamma = config.get("gamma", 0.1)
        self.lr = self.optimizer.param_groups[0]["lr"]
        self.best_acc1 = 0

        # Log the classification experiment details
        self.logger.info("Train Dataset: %s", self.train_set)
        self.logger.info("Eval Dataset: %s", self.eval_set)
        self.logger.info("Task Loss (Criterion): %s", self.task_loss_fn)
        self.logger.info("Model Optimizer: %s", self.optimizer)
        self.logger.info("Model: %s", self.model)
        num_params = sum(p.numel() for p in self.model.parameters())
        num_lrn_p = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.logger.info(
            "Num Parameters: %(params)d (%(lrn_params)d requires gradient)",
            {"params": num_params, "lrn_params": num_lrn_p},
        )
        self.logger.info(
            "LR: %(lr)f decreasing by a factor of %(gamma)f at epochs %(schedule)s",
            {"lr": self.lr, "gamma": self.gamma, "schedule": self.schedule},
        )

        # Support multiple GPUs using DataParallel
        if self.use_cuda:
            if len(self.gpu_ids) > 1:
                self.model = torch.nn.DataParallel(self.model).cuda()
            else:
                self.model = self.model.cuda()
    def __init__(self, config):
        """Set up adaptive (budgeted) pruning with a pretrained reference model.

        Builds ``pretrained_model`` (weights loaded from ``config["prune"]``,
        put in eval mode) and a prunable ``model`` whose non-MaskSTE leaf
        modules are initialised from the pretrained weights. It also creates
        the losses and optimizer, computes the original usage for the
        configured ``criteria`` and opens the ``epoch.out`` tab log.

        Args:
            config: dict-like configuration. Keys read here: ``tb_dir``,
                ``out_dir`` and ``prune`` (required), ``seed``, ``cpu_only``,
                ``gpu_ids``, ``train_data``, ``eval_data``,
                ``pretrained_model``, ``model``, ``task_loss``, ``mask_loss``,
                ``optimizer``, ``temperature``, ``task_loss_reg``,
                ``mask_loss_reg``, ``kd_loss_reg``, ``epochs``,
                ``start_epoch``, ``gamma``, ``budget``, ``criteria``,
                ``short_term_fine_tune_patience``,
                ``long_term_fine_tune_patience``.

        Raises:
            ValueError: if a non-mask leaf module of ``model`` does not have
                the same type as the corresponding pretrained module.
            NotImplementedError: if ``criteria`` is not ``"parameters"``.
        """
        super(AdaptivePruningAgent, self).__init__(config)

        # TensorBoard Summary Writer
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Configure seed, if provided
        seed = config.get("seed")
        set_seed(seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(seed))

        # Setup CUDA
        cpu_only = config.get("cpu_only", False)
        self.use_cuda, self.gpu_ids = set_cuda_devices(cpu_only,
                                                       config.get("gpu_ids"),
                                                       logger=self.logger)

        # Instantiate Datasets and Dataloaders
        self.train_set, self.train_loader = init_data(config.get("train_data"))
        self.eval_set, self.eval_loader = init_data(config.get("eval_data"))

        # Instantiate Models
        self.pretrained_model = init_class(config.get("pretrained_model"))
        self.model = init_class(config.get("model"))

        # Load the pretrained weights (onto the CPU when CUDA is not in use)
        map_location = None if self.use_cuda else torch.device("cpu")
        prune_checkpoint = config["prune"]
        self.logger.info("Loading pretrained model from checkpoint: %s",
                         prune_checkpoint)
        prune_checkpoint = torch.load(prune_checkpoint,
                                      map_location=map_location)
        self.pretrained_model.load_state_dict(prune_checkpoint["state_dict"])
        self.pretrained_model.eval()

        # Walk both module lists in lockstep. MaskSTE layers (including
        # subclasses) are skipped without advancing module_idx (presumably
        # they have no counterpart in the pretrained model).
        modules_pretrained = list(self.pretrained_model.modules())
        modules_to_prune = list(self.model.modules())
        module_idx = 0
        for module_to_prune in modules_to_prune:
            module_pretrained = modules_pretrained[module_idx]
            modstr = str(type(module_to_prune))
            # Skip the masking layers
            if isinstance(module_to_prune, MaskSTE):
                continue
            if len(list(module_to_prune.children())) == 0:
                pretrained_modstr = str(type(module_pretrained))
                # Explicit check rather than ``assert`` so it is not stripped
                # under ``python -O``.
                if modstr != pretrained_modstr:
                    raise ValueError(
                        "Pretrained module mismatch at index {}: model has {}, "
                        "pretrained model has {}".format(
                            module_idx, modstr, pretrained_modstr))
                # copy all parameters over
                param_lookup = dict(module_pretrained.named_parameters())
                for param_key, param_val in module_to_prune.named_parameters():
                    param_val.data.copy_(param_lookup[param_key].data)
                # BatchNorm layers are special and require copying of running_mean/running_var
                if "BatchNorm" in modstr:
                    module_to_prune.running_mean.copy_(
                        module_pretrained.running_mean)
                    module_to_prune.running_var.copy_(
                        module_pretrained.running_var)
            module_idx += 1

        try:
            # Best-effort: trace one sample to draw the model graph in TensorBoard.
            model_input, _target = next(iter(self.eval_set))
            self.tb_sw.add_graph(self.model, model_input.unsqueeze(0))
        except Exception as e:  # graph visualisation is optional; never abort setup
            # Logger.warn is a deprecated alias; use Logger.warning.
            self.logger.warning("Could not add model graph to TensorBoard: %s", e)

        # Instantiate task loss and optimizer
        self.task_loss_fn = init_class(config.get("task_loss"))
        self.mask_loss_fn = init_class(config.get("mask_loss"))
        self.optimizer = init_class(config.get("optimizer"),
                                    self.model.parameters())
        self.temperature = config.get("temperature", 4.0)
        self.task_loss_reg = config.get("task_loss_reg", 1.0)
        self.mask_loss_reg = config.get("mask_loss_reg", 1.0)
        self.kd_loss_reg = config.get("kd_loss_reg", 1.0)

        # Misc. Other classification hyperparameters
        self.epochs = config.get("epochs", 300)
        self.start_epoch = config.get("start_epoch", 0)
        self.gamma = config.get("gamma", 1)
        self.lr = self.optimizer.param_groups[0]["lr"]
        self.best_acc_per_usage = {}

        # Log the classification experiment details
        self.logger.info("Train Dataset: %s", self.train_set)
        self.logger.info("Eval Dataset: %s", self.eval_set)
        self.logger.info("Task Loss (Criterion): %s", self.task_loss_fn)
        self.logger.info("Model Optimizer: %s", self.optimizer)
        self.logger.info("Model: %s", self.model)
        num_params = sum(p.numel() for p in self.model.parameters())
        num_lrn_p = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad)
        self.logger.info(
            "Num Parameters: %(params)d (%(lrn_params)d requires gradient)",
            {
                "params": num_params,
                "lrn_params": num_lrn_p
            },
        )

        # Pruning target: reduce the chosen usage criterion down to `budget`.
        self.budget = config.get("budget", 4300000)
        self.criteria = config.get("criteria", "parameters")
        if self.criteria == "parameters":
            self.og_usage = sum(self.calculate_model_parameters()).data.item()
        else:
            raise NotImplementedError("Unknown criteria: {}".format(
                self.criteria))
        # Lazy %-style args: same text as before, formatted only if emitted.
        self.logger.info("Pruning from %.2e %s to %.2e %s.", self.og_usage,
                         self.criteria, self.budget, self.criteria)
        self.short_term_fine_tune_patience = config.get(
            "short_term_fine_tune_patience", 2)
        self.long_term_fine_tune_patience = config.get(
            "long_term_fine_tune_patience", 4)

        t_log_fpath = os.path.join(config["out_dir"], "epoch.out")
        self.t_log = TabLogger(t_log_fpath)
        self.t_log.set_names([
            "Epoch", "Train Task Loss", "Train KD Loss", "Train Mask Loss",
            "Train Acc", "Eval Task Loss", "Eval KD Loss", "Eval Mask Loss",
            "Eval Acc", "LR", "Parameters"
        ])

        self.logger.info(
            "Training from Epoch %(start)d to %(end)d",
            {
                "start": self.start_epoch,
                "end": self.epochs
            },
        )

        # Support multiple GPUs using DataParallel
        if self.use_cuda:
            if len(self.gpu_ids) > 1:
                self.pretrained_model = torch.nn.DataParallel(
                    self.pretrained_model).cuda()
                self.model = torch.nn.DataParallel(self.model).cuda()
            else:
                self.pretrained_model = self.pretrained_model.cuda()
                self.model = self.model.cuda()