Ejemplo n.º 1
0
    def __init__(self, config):
        """Build the multi-resolution test-set evaluator from ``config``.

        ``config["tb_dir"]`` is indexed directly; ``"seed"``, ``"cpu_only"``,
        ``"gpu_ids"``, ``"num_workers"``, ``"batch_size"``, ``"model"`` and
        ``"checkpoints"`` are read with ``.get``. The test images are loaded
        from the fixed folder ``data/c617a1/test``.
        """
        super(TestSetEvaluator, self).__init__(config)

        # TensorBoard writer; the whole config is stored as a text entry.
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Seed handling (the value is None when no seed is configured).
        seed = config.get("seed")
        set_seed(seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(seed))

        # Device selection.
        self.use_cuda, self.gpu_ids = set_cuda_devices(
            config.get("cpu_only", False),
            config.get("gpu_ids"),
            logger=self.logger,
        )

        # Dataset and loader are built by hand here instead of via
        # init_data(config.get("test_data")).
        # NOTE(review): the two triples are presumably per-channel mean / std
        # for Normalize -- confirm against the dataset statistics.
        channel_mean = [0.663295328617096, 0.6501832604408264, 0.6542291045188904]
        channel_std = [0.19360290467739105, 0.22194330394268036, 0.23059576749801636]
        pipeline = Compose(
            [
                ToTensor(),
                Normalize(channel_mean, channel_std),
                Lambda(expand_multires),
            ]
        )
        self.test_set = ImageFolder("data/c617a1/test", transform=pipeline)
        loader_options = {
            "num_workers": config.get("num_workers", 0),
            "batch_size": config.get("batch_size", 1),
            "shuffle": True,
        }
        self.test_loader = DataLoader(self.test_set, **loader_options)

        # Model: swap the backbone's classifier for a 1280 -> 300 projection
        # and hand the backbone to the multi-resolution wrapper.
        backbone = init_class(config.get("model"))
        embedding_size = 300
        projection_head = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2, inplace=True),
            torch.nn.Linear(1280, embedding_size),
        )
        backbone.classifier = projection_head
        self.model = MultiResolutionModelWrapper(
            backbone, base_out_size=embedding_size
        )

        self.logger.info("Test Dataset: %s", self.test_set)
        self.checkpoints = config.get("checkpoints")
        # Checkpoints are remapped onto the CPU when CUDA is not in use.
        if self.use_cuda:
            self.map_location = None
        else:
            self.map_location = torch.device("cpu")

        # Move to GPU; wrap in DataParallel first when several GPUs are listed.
        if self.use_cuda:
            multi_gpu = len(self.gpu_ids) > 1
            target = torch.nn.DataParallel(self.model) if multi_gpu else self.model
            self.model = target.cuda()
Ejemplo n.º 2
0
    def __init__(self, config):
        """Prepare the profiler agent: logging, devices, eval data and model.

        Args:
            config: Experiment settings. ``"tb_dir"`` and ``"out_dir"`` are
                indexed directly; ``"seed"``, ``"cpu_only"``, ``"gpu_ids"``,
                ``"eval_data"`` and ``"model"`` are read with ``.get`` and
                forwarded to the project helpers.
        """
        super(ModelProfilerAgent, self).__init__(config)

        # TensorBoard Summary Writer
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Configure seed, if provided
        seed = config.get("seed")
        set_seed(seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(seed))

        # Setup CUDA
        cpu_only = config.get("cpu_only", False)
        self.use_cuda, self.gpu_ids = set_cuda_devices(
            cpu_only, config.get("gpu_ids"), logger=self.logger
        )

        # Instantiate Datasets and Dataloaders
        self.eval_set, self.eval_loader = init_data(config.get("eval_data"))

        # Instantiate Models (put straight into eval mode)
        self.model = init_class(config.get("model"))
        self.model.eval()
        try:
            # Best effort: trace the model with one sample (batch dim added)
            # so TensorBoard can show the graph structure.
            model_input, _target = next(iter(self.eval_set))
            self.tb_sw.add_graph(self.model, model_input.unsqueeze(0))
        except Exception as e:
            # Graph export is optional, so a failure is only logged.
            # Logger.warn is a deprecated alias; warning() is the supported
            # spelling (assumes self.logger is a stdlib logging.Logger).
            self.logger.warning(e)

        # Log the classification experiment details
        self.logger.info("Eval Dataset: %s", self.eval_set)
        self.logger.info("Model: %s", self.model)
        self.logger.info("Batch Size: %d", self.eval_loader.batch_size)
        num_params = sum(p.numel() for p in self.model.parameters())
        num_lrn_p = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        self.logger.info(
            "Num Parameters: %(params)d (%(lrn_params)d requires gradient)",
            {"params": num_params, "lrn_params": num_lrn_p},
        )

        # Tab-separated profiler log written under the output directory
        t_log_fpath = os.path.join(config["out_dir"], "profiler.out")
        self.t_log = TabLogger(t_log_fpath)
        self.t_log.set_names(
            ["Batch Size", "Self CPU Time", "CPU Time Total", "CUDA Time Total"]
        )
        self.logger.info("Storing tab log output at: %s", t_log_fpath)
Ejemplo n.º 3
0
    def __init__(self, config):
        """Create the packing agent and load the weights to be packed.

        ``config["pack"]`` (indexed directly) is the checkpoint path; the
        loaded object must contain a ``"state_dict"`` entry. ``"cpu_only"``,
        ``"gpu_ids"``, ``"model"`` and ``"apply_sigmoid"`` are read with
        ``.get``.
        """
        super(MaskablePackingAgent, self).__init__(config)

        # Device selection
        self.use_cuda, self.gpu_ids = set_cuda_devices(
            config.get("cpu_only", False),
            config.get("gpu_ids"),
            logger=self.logger,
        )
        # Remap checkpoint tensors onto the CPU when CUDA is not in use.
        if self.use_cuda:
            map_location = None
        else:
            map_location = torch.device("cpu")

        self.model = init_class(config.get("model"))
        self.apply_sigmoid = config.get("apply_sigmoid", True)

        checkpoint_path = config["pack"]
        self.logger.info("Packing model from checkpoint: %s", checkpoint_path)
        checkpoint = torch.load(checkpoint_path, map_location=map_location)
        self.model.load_state_dict(checkpoint["state_dict"])
Ejemplo n.º 4
0
    def __init__(self, config):
        """Build the test-set evaluator from ``config``.

        ``config["tb_dir"]`` is indexed directly; ``"seed"``, ``"cpu_only"``,
        ``"gpu_ids"``, ``"test_data"``, ``"model"`` and ``"checkpoints"`` are
        read with ``.get``. The model's ``classifier`` attribute is replaced
        with a dropout + 1280 -> 2 linear head.
        """
        super(TestSetEvaluator, self).__init__(config)

        # TensorBoard writer; the whole config is stored as a text entry.
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Seed handling (the value is None when no seed is configured).
        seed = config.get("seed")
        set_seed(seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(seed))

        # Device selection.
        self.use_cuda, self.gpu_ids = set_cuda_devices(
            config.get("cpu_only", False),
            config.get("gpu_ids"),
            logger=self.logger,
        )

        # Test dataset and its loader.
        self.test_set, self.test_loader = init_data(config.get("test_data"))

        # Model with a two-output head in place of the original classifier.
        self.model = init_class(config.get("model"))
        binary_head = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2, inplace=True),
            torch.nn.Linear(1280, 2),
        )
        self.model.classifier = binary_head

        self.logger.info("Test Dataset: %s", self.test_set)
        self.checkpoints = config.get("checkpoints")
        # Checkpoints are remapped onto the CPU when CUDA is not in use.
        if self.use_cuda:
            self.map_location = None
        else:
            self.map_location = torch.device("cpu")

        # Move to GPU; wrap in DataParallel first when several GPUs are listed.
        if self.use_cuda:
            multi_gpu = len(self.gpu_ids) > 1
            target = torch.nn.DataParallel(self.model) if multi_gpu else self.model
            self.model = target.cuda()
Ejemplo n.º 5
0
    def __init__(self, config):
        """Set up a classification training run, optionally resuming one.

        Args:
            config: Experiment settings. ``"tb_dir"`` and ``"out_dir"`` are
                indexed directly. Read with ``.get``: ``"seed"``,
                ``"cpu_only"``, ``"gpu_ids"``, ``"train_data"``,
                ``"eval_data"``, ``"model"``, ``"task_loss"``,
                ``"optimizer"``, ``"epochs"`` (default 300),
                ``"start_epoch"`` (default 0), ``"schedule"`` (default
                ``[150, 225]``), ``"gamma"`` (default 0.1) and ``"resume"``
                (path of a checkpoint to continue from).
        """
        super(ClassificationAgent, self).__init__(config)

        # TensorBoard Summary Writer
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Configure seed, if provided
        seed = config.get("seed")
        set_seed(seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(seed))

        # Setup CUDA
        cpu_only = config.get("cpu_only", False)
        self.use_cuda, self.gpu_ids = set_cuda_devices(
            cpu_only, config.get("gpu_ids"), logger=self.logger
        )

        # Instantiate Datasets and Dataloaders
        self.train_set, self.train_loader = init_data(config.get("train_data"))
        self.eval_set, self.eval_loader = init_data(config.get("eval_data"))

        # Instantiate Model
        self.model = init_class(config.get("model"))
        try:
            # Best effort: trace the model with one sample (batch dim added)
            # so TensorBoard can show the graph structure.
            model_input, _target = next(iter(self.eval_set))
            self.tb_sw.add_graph(self.model, model_input.unsqueeze(0))
        except Exception as e:
            # Graph export is optional, so a failure is only logged.
            # Logger.warn is a deprecated alias; warning() is the supported
            # spelling (assumes self.logger is a stdlib logging.Logger).
            self.logger.warning(e)

        # Instantiate task loss and optimizer
        self.task_loss_fn = init_class(config.get("task_loss"))
        self.optimizer = init_class(config.get("optimizer"), self.model.parameters())

        # Misc. Other classification hyperparameters
        self.epochs = config.get("epochs", 300)
        self.start_epoch = config.get("start_epoch", 0)
        self.schedule = config.get("schedule", [150, 225])
        self.gamma = config.get("gamma", 0.1)
        self.lr = self.optimizer.param_groups[0]["lr"]
        self.best_acc1 = 0

        # Log the classification experiment details
        self.logger.info("Train Dataset: %s", self.train_set)
        self.logger.info("Eval Dataset: %s", self.eval_set)
        self.logger.info("Task Loss (Criterion): %s", self.task_loss_fn)
        self.logger.info("Model Optimizer: %s", self.optimizer)
        self.logger.info("Model: %s", self.model)
        num_params = sum(p.numel() for p in self.model.parameters())
        num_lrn_p = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.logger.info(
            "Num Parameters: %(params)d (%(lrn_params)d requires gradient)",
            {"params": num_params, "lrn_params": num_lrn_p},
        )
        self.logger.info(
            "LR: %(lr)f decreasing by a factor of %(gamma)f at epochs %(schedule)s",
            {"lr": self.lr, "gamma": self.gamma, "schedule": self.schedule},
        )

        # Path to in progress checkpoint.pth.tar for resuming experiment
        resume = config.get("resume")
        t_log_fpath = os.path.join(config["out_dir"], "epoch.out")
        self.t_log = TabLogger(t_log_fpath, resume=bool(resume))  # tab logger
        self.t_log.set_names(
            [
                "Epoch",
                "Train Task Loss",
                "Train Acc",
                "Eval Task Loss",
                "Eval Acc",
                "LR",
            ]
        )
        if resume:
            self.logger.info("Resuming from checkpoint: %s", resume)
            # Remap tensors onto the CPU when CUDA is not in use; without this
            # a checkpoint saved from a GPU run cannot be loaded on a CPU-only
            # run. The model is still on the CPU at this point either way.
            map_location = None if self.use_cuda else torch.device("cpu")
            res_chkpt = torch.load(resume, map_location=map_location)
            self.model.load_state_dict(res_chkpt["state_dict"])
            self.start_epoch = res_chkpt.get("epoch", 0)
            self.best_acc1 = res_chkpt.get("best_acc1", 0)
            optim_state_dict = res_chkpt.get("optim_state_dict")
            if optim_state_dict:
                self.optimizer.load_state_dict(optim_state_dict)
            self.logger.info(
                "Resumed at epoch %d, eval best_acc1 %.2f",
                self.start_epoch,
                self.best_acc1,
            )
            # fastforward LR to match current schedule: replay every
            # milestone that is not after the resumed epoch
            for sched in self.schedule:
                if sched > self.start_epoch:
                    break
                new_lr = adjust_learning_rate(
                    self.optimizer,
                    sched,
                    lr=self.lr,
                    schedule=self.schedule,
                    gamma=self.gamma,
                )
                self.logger.info(
                    "LR fastforward from %(old)f to %(new)f at Epoch %(epoch)d",
                    {"old": self.lr, "new": new_lr, "epoch": sched},
                )
                self.lr = new_lr

        self.logger.info(
            "Training from Epoch %(start)d to %(end)d",
            {"start": self.start_epoch, "end": self.epochs},
        )

        # Support multiple GPUs using DataParallel
        if self.use_cuda:
            if len(self.gpu_ids) > 1:
                self.model = torch.nn.DataParallel(self.model).cuda()
            else:
                self.model = self.model.cuda()
    def __init__(self, config):
        """Set up joint knowledge-distillation / pruning training.

        A pretrained model is loaded from ``config["prune"]`` and its leaf
        module parameters are copied into the model to be pruned, which is
        then trained with task, mask and KD loss weights taken from config.

        Args:
            config: Experiment settings. ``"tb_dir"``, ``"prune"`` and
                ``"out_dir"`` are indexed directly. Read with ``.get``:
                ``"seed"``, ``"cpu_only"``, ``"gpu_ids"``, ``"train_data"``,
                ``"eval_data"``, ``"pretrained_model"``, ``"model"``,
                ``"task_loss"``, ``"mask_loss"``, ``"optimizer"``,
                ``"temperature"`` (default 4.0), ``"task_loss_reg"``,
                ``"mask_loss_reg"``, ``"kd_loss_reg"`` (each default 1.0),
                ``"epochs"`` (default 300), ``"start_epoch"`` (default 0),
                ``"schedule"`` (default ``[150, 225]``), ``"gamma"``
                (default 0.1) and ``"resume"``.
        """
        super(JointKnowledgeDistillationPruningAgent, self).__init__(config)

        # TensorBoard Summary Writer
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Configure seed, if provided
        seed = config.get("seed")
        set_seed(seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(seed))

        # Setup CUDA
        cpu_only = config.get("cpu_only", False)
        self.use_cuda, self.gpu_ids = set_cuda_devices(
            cpu_only, config.get("gpu_ids"), logger=self.logger
        )

        # Instantiate Datasets and Dataloaders
        self.train_set, self.train_loader = init_data(config.get("train_data"))
        self.eval_set, self.eval_loader = init_data(config.get("eval_data"))

        # Instantiate Models
        self.pretrained_model = init_class(config.get("pretrained_model"))
        self.model = init_class(config.get("model"))

        # Load the pretrained weights
        map_location = None if self.use_cuda else torch.device("cpu")
        prune_checkpoint = config["prune"]
        self.logger.info(
            "Loading pretrained model from checkpoint: %s", prune_checkpoint
        )
        prune_checkpoint = torch.load(prune_checkpoint, map_location=map_location)
        self.pretrained_model.load_state_dict(prune_checkpoint["state_dict"])
        self.pretrained_model.eval()
        modules_pretrained = list(self.pretrained_model.modules())
        modules_to_prune = list(self.model.modules())
        # module_idx walks the pretrained module list; it only advances for
        # non-mask modules so the two traversals stay paired up.
        # NOTE(review): presumably the pretrained model has no mask layers --
        # confirm against the model definitions.
        module_idx = 0
        self.mask_modules = []
        for module_to_prune in modules_to_prune:
            module_pretrained = modules_pretrained[module_idx]
            modstr = str(type(module_to_prune))
            # Skip the masking layers (collected for later use)
            if "MaskSTE" in modstr:
                self.mask_modules.append(module_to_prune)
                continue
            if len(list(module_to_prune.children())) == 0:
                # Leaf modules must be of the same type in both models.
                assert modstr == str(type(module_pretrained)), (
                    "module type mismatch while copying pretrained weights: "
                    "%s vs %s" % (modstr, type(module_pretrained))
                )
                # copy all parameters over
                param_lookup = dict(module_pretrained.named_parameters())
                for param_key, param_val in module_to_prune.named_parameters():
                    param_val.data.copy_(param_lookup[param_key].data)
                # BatchNorm layers are special and require copying of running_mean/running_var
                if "BatchNorm" in modstr:
                    module_to_prune.running_mean.copy_(module_pretrained.running_mean)
                    module_to_prune.running_var.copy_(module_pretrained.running_var)
            module_idx += 1

        try:
            # Best effort: trace the model with one sample (batch dim added)
            # so TensorBoard can show the graph structure.
            model_input, _target = next(iter(self.eval_set))
            self.tb_sw.add_graph(self.model, model_input.unsqueeze(0))
        except Exception as e:
            # Graph export is optional, so a failure is only logged.
            # Logger.warn is a deprecated alias; warning() is the supported
            # spelling (assumes self.logger is a stdlib logging.Logger).
            self.logger.warning(e)

        # Instantiate task loss and optimizer
        self.task_loss_fn = init_class(config.get("task_loss"))
        self.mask_loss_fn = init_class(config.get("mask_loss"))
        self.optimizer = init_class(config.get("optimizer"), self.model.parameters())
        self.temperature = config.get("temperature", 4.0)
        self.task_loss_reg = config.get("task_loss_reg", 1.0)
        self.mask_loss_reg = config.get("mask_loss_reg", 1.0)
        self.kd_loss_reg = config.get("kd_loss_reg", 1.0)

        # Misc. Other classification hyperparameters
        self.epochs = config.get("epochs", 300)
        self.start_epoch = config.get("start_epoch", 0)
        self.schedule = config.get("schedule", [150, 225])
        self.gamma = config.get("gamma", 0.1)
        self.lr = self.optimizer.param_groups[0]["lr"]
        self.best_acc_per_usage = {}

        # Log the classification experiment details
        self.logger.info("Train Dataset: %s", self.train_set)
        self.logger.info("Eval Dataset: %s", self.eval_set)
        self.logger.info("Task Loss (Criterion): %s", self.task_loss_fn)
        self.logger.info("Model Optimizer: %s", self.optimizer)
        self.logger.info("Model: %s", self.model)
        num_params = sum(p.numel() for p in self.model.parameters())
        num_lrn_p = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self.logger.info(
            "Num Parameters: %(params)d (%(lrn_params)d requires gradient)",
            {"params": num_params, "lrn_params": num_lrn_p},
        )
        self.logger.info(
            "LR: %(lr)f decreasing by a factor of %(gamma)f at epochs %(schedule)s",
            {"lr": self.lr, "gamma": self.gamma, "schedule": self.schedule},
        )

        # Path to in progress checkpoint.pth.tar for resuming experiment
        resume = config.get("resume")
        t_log_fpath = os.path.join(config["out_dir"], "epoch.out")
        self.t_log = TabLogger(t_log_fpath, resume=bool(resume))
        self.t_log.set_names(
            [
                "Epoch",
                "Train Task Loss",
                "Train KD Loss",
                "Train Mask Loss",
                "Train Acc",
                "Eval Task Loss",
                "Eval KD Loss",
                "Eval Mask Loss",
                "Eval Acc",
                "Num Parameters",
                "LR",
            ]
        )
        if resume:
            # Unlike the pretrained checkpoint, every key below is required.
            self.logger.info("Resuming from checkpoint: %s", resume)
            res_chkpt = torch.load(resume, map_location=map_location)
            self.start_epoch = res_chkpt["epoch"]
            self.model.load_state_dict(res_chkpt["state_dict"])
            eval_acc = res_chkpt["acc"]
            self.best_acc_per_usage = res_chkpt["best_acc_per_usage"]
            self.optimizer.load_state_dict(res_chkpt["optim_state_dict"])
            self.logger.info(
                "Resumed at epoch %d, eval acc %.2f", self.start_epoch, eval_acc
            )
            self.logger.info(pformat(self.best_acc_per_usage))
            # fastforward LR to match current schedule: replay every
            # milestone that is not after the resumed epoch
            for sched in self.schedule:
                if sched > self.start_epoch:
                    break
                new_lr = adjust_learning_rate(
                    self.optimizer,
                    sched,
                    lr=self.lr,
                    schedule=self.schedule,
                    gamma=self.gamma,
                )
                self.logger.info(
                    "LR fastforward from %(old)f to %(new)f at Epoch %(epoch)d",
                    {"old": self.lr, "new": new_lr, "epoch": sched},
                )
                self.lr = new_lr

        self.logger.info(
            "Training from Epoch %(start)d to %(end)d",
            {"start": self.start_epoch, "end": self.epochs},
        )

        # Support multiple GPUs using DataParallel
        if self.use_cuda:
            if len(self.gpu_ids) > 1:
                self.pretrained_model = torch.nn.DataParallel(
                    self.pretrained_model
                ).cuda()
                self.model = torch.nn.DataParallel(self.model).cuda()
            else:
                self.pretrained_model = self.pretrained_model.cuda()
                self.model = self.model.cuda()
Ejemplo n.º 7
0
    def __init__(self, config):
        """Set up fine-tuning of a frozen model with a new two-output head.

        One dataset is built from ``config["train_eval_data"]`` and randomly
        split into train and eval subsets. All existing model parameters are
        frozen before the ``classifier`` attribute is replaced, so only the
        new head requires gradients.

        Args:
            config: Experiment settings. ``"tb_dir"`` and ``"out_dir"`` are
                indexed directly. Read with ``.get``: ``"seed"``,
                ``"cpu_only"``, ``"gpu_ids"``, ``"train_eval_data"``,
                ``"model"``, ``"task_loss"``, ``"optimizer"``, ``"epochs"``
                (default 300), ``"start_epoch"`` (default 0), ``"schedule"``
                (default ``[150, 225]``), ``"gamma"`` (default 0.1) and
                ``"resume"``. The ``"train_eval_data"`` entry must have a
                ``"name"`` and may have ``"args"``, ``"kwargs"``,
                ``"transform"``, ``"target_transform"``,
                ``"train_eval_split_ratio"`` (default ``[0.85, 0.15]``) and
                ``"dataloader_kwargs"``.
        """
        super(FineTuneClassifier, self).__init__(config)

        # TensorBoard Summary Writer
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Configure seed, if provided
        seed = config.get("seed")
        set_seed(seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(seed))

        # Setup CUDA
        cpu_only = config.get("cpu_only", False)
        self.use_cuda, self.gpu_ids = set_cuda_devices(
            cpu_only, config.get("gpu_ids"), logger=self.logger
        )

        # Instantiate Datasets and Dataloaders
        # self.train_set, self.train_loader = init_data(config.get("train_eval_data"))
        train_eval_config = config.get("train_eval_data")
        ds_class = fetch_class(train_eval_config["name"])
        d_transform = list(
            map(init_class, train_eval_config.get("transform", [])))
        d_ttransform = list(
            map(init_class, train_eval_config.get("target_transform", [])))
        ds = ds_class(
            *train_eval_config.get("args", []),
            **train_eval_config.get("kwargs", {}),
            transform=Compose(d_transform) if d_transform else None,
            target_transform=Compose(d_ttransform) if d_ttransform else None)
        # Only the train ratio is used; the eval subset is whatever remains.
        # The two-element unpacking is kept so a malformed ratio still fails.
        train_set_ratio, _eval_set_ratio = train_eval_config.get(
            "train_eval_split_ratio", [0.85, 0.15])
        train_len = ceil(len(ds) * train_set_ratio)
        eval_len = len(ds) - train_len
        self.train_set, self.eval_set = random_split(ds, [train_len, eval_len])
        # Both loaders are built from the same dataloader_kwargs.
        self.train_loader = DataLoader(
            self.train_set, **train_eval_config.get("dataloader_kwargs", {}))
        self.eval_loader = DataLoader(
            self.eval_set, **train_eval_config.get("dataloader_kwargs", {}))

        # Instantiate Model
        self.model = init_class(config.get("model"))
        # Freeze all of the parameters (except for final classification layer which we add afterwards)
        for param in self.model.parameters():
            param.requires_grad = False
        # manually convert pretrained model into a binary classification problem
        self.model.classifier = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2, inplace=True), torch.nn.Linear(1280, 2))
        try:
            # Best effort: trace the model with one sample (batch dim added)
            # so TensorBoard can show the graph structure.
            model_input, _target = next(iter(self.eval_set))
            self.tb_sw.add_graph(self.model, model_input.unsqueeze(0))
        except Exception as e:
            # Graph export is optional, so a failure is only logged.
            # Logger.warn is a deprecated alias; warning() is the supported
            # spelling (assumes self.logger is a stdlib logging.Logger).
            self.logger.warning(e)

        # Instantiate task loss and optimizer
        self.task_loss_fn = init_class(config.get("task_loss"))
        self.optimizer = init_class(config.get("optimizer"),
                                    self.model.parameters())

        # Misc. Other classification hyperparameters
        self.epochs = config.get("epochs", 300)
        self.start_epoch = config.get("start_epoch", 0)
        self.schedule = config.get("schedule", [150, 225])
        self.gamma = config.get("gamma", 0.1)
        self.lr = self.optimizer.param_groups[0]["lr"]
        self.best_acc1 = 0

        # Log the classification experiment details
        self.logger.info("Train Dataset: %s", self.train_set)
        self.logger.info("Eval Dataset: %s", self.eval_set)
        self.logger.info("Task Loss (Criterion): %s", self.task_loss_fn)
        self.logger.info("Model Optimizer: %s", self.optimizer)
        self.logger.info("Model: %s", self.model)
        num_params = sum(p.numel() for p in self.model.parameters())
        num_lrn_p = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad)
        self.logger.info(
            "Num Parameters: %(params)d (%(lrn_params)d requires gradient)",
            {
                "params": num_params,
                "lrn_params": num_lrn_p
            },
        )
        self.logger.info(
            "LR: %(lr)f decreasing by a factor of %(gamma)f at epochs %(schedule)s",
            {
                "lr": self.lr,
                "gamma": self.gamma,
                "schedule": self.schedule
            },
        )

        # Path to in progress checkpoint.pth.tar for resuming experiment
        # Checkpoints are remapped onto the CPU when CUDA is not in use.
        self.map_location = None if self.use_cuda else torch.device("cpu")
        resume = config.get("resume")
        t_log_fpath = os.path.join(config["out_dir"], "epoch.out")
        self.t_log = TabLogger(t_log_fpath, resume=bool(resume))  # tab logger
        self.t_log.set_names([
            "Epoch",
            "Train Task Loss",
            "Train Acc",
            "Eval Task Loss",
            "Eval Acc",
            "LR",
        ])

        if resume:
            self.logger.info("Resuming from checkpoint: %s", resume)
            res_chkpt = torch.load(resume, map_location=self.map_location)
            self.model.load_state_dict(res_chkpt["state_dict"])
            self.start_epoch = res_chkpt.get("epoch", 0)
            self.best_acc1 = res_chkpt.get("best_acc1", 0)
            optim_state_dict = res_chkpt.get("optim_state_dict")
            if optim_state_dict:
                self.optimizer.load_state_dict(optim_state_dict)
            self.logger.info(
                "Resumed at epoch %d, eval best_acc1 %.2f",
                self.start_epoch,
                self.best_acc1,
            )
            # fastforward LR to match current schedule: replay every
            # milestone that is not after the resumed epoch
            for sched in self.schedule:
                if sched > self.start_epoch:
                    break
                new_lr = adjust_learning_rate(
                    self.optimizer,
                    sched,
                    lr=self.lr,
                    schedule=self.schedule,
                    gamma=self.gamma,
                )
                self.logger.info(
                    "LR fastforward from %(old)f to %(new)f at Epoch %(epoch)d",
                    {
                        "old": self.lr,
                        "new": new_lr,
                        "epoch": sched
                    },
                )
                self.lr = new_lr

        self.logger.info(
            "Training from Epoch %(start)d to %(end)d",
            {
                "start": self.start_epoch,
                "end": self.epochs
            },
        )

        # Support multiple GPUs using DataParallel
        if self.use_cuda:
            if len(self.gpu_ids) > 1:
                self.model = torch.nn.DataParallel(self.model).cuda()
            else:
                self.model = self.model.cuda()
Ejemplo n.º 8
0
    def __init__(self, config):
        """Set up multi-resolution fine-tuning of a frozen backbone.

        Images come from the fixed folder ``data/c617a1/train_eval`` and are
        randomly split into train and eval subsets. The backbone's existing
        parameters are frozen, its ``classifier`` is replaced by a
        1280 -> 300 projection, and the result is wrapped in
        ``MultiResolutionModelWrapper``.

        ``config["tb_dir"]`` is indexed directly; ``"seed"``, ``"cpu_only"``,
        ``"gpu_ids"``, ``"train_eval_split_ratio"``, ``"num_workers"``,
        ``"batch_size"``, ``"model"``, ``"task_loss"``, ``"optimizer"``,
        ``"epochs"``, ``"start_epoch"``, ``"schedule"`` and ``"gamma"`` are
        read with ``.get``.
        """
        super(MultiResolutionFineTuneClassifier, self).__init__(config)

        # TensorBoard writer; the whole config is stored as a text entry.
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Seed handling (the value is None when no seed is configured).
        seed = config.get("seed")
        set_seed(seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(seed))

        # Device selection.
        self.use_cuda, self.gpu_ids = set_cuda_devices(
            config.get("cpu_only", False),
            config.get("gpu_ids"),
            logger=self.logger,
        )

        # Dataset: built by hand here instead of via init_data on
        # "train_data" / "eval_data".
        # NOTE(review): the two triples are presumably per-channel mean / std
        # for Normalize -- confirm against the dataset statistics.
        channel_mean = [0.663295328617096, 0.6501832604408264, 0.6542291045188904]
        channel_std = [0.19360290467739105, 0.22194330394268036, 0.23059576749801636]
        pipeline = Compose(
            [
                RandomHorizontalFlip(),
                ToTensor(),
                Normalize(channel_mean, channel_std),
                Lambda(expand_multires),
            ]
        )
        ds = ImageFolder("data/c617a1/train_eval", transform=pipeline)

        # Only the first ratio is used; the eval subset takes the remainder.
        split_ratios = config.get("train_eval_split_ratio", [0.85, 0.15])
        train_set_ratio, eval_set_ratio = split_ratios
        train_len = ceil(len(ds) * train_set_ratio)
        eval_len = len(ds) - train_len
        self.train_set, self.eval_set = random_split(ds, [train_len, eval_len])

        def _shuffled_loader(subset):
            # Train and eval loaders share the same settings.
            return DataLoader(
                subset,
                num_workers=config.get("num_workers", 0),
                batch_size=config.get("batch_size", 128),
                shuffle=True,
            )

        self.train_loader = _shuffled_loader(self.train_set)
        self.eval_loader = _shuffled_loader(self.eval_set)

        # Backbone: freeze what exists now; the head attached below is new
        # and therefore stays trainable.
        backbone = init_class(config.get("model"))
        for frozen_param in backbone.parameters():
            frozen_param.requires_grad = False

        embedding_size = 300
        projection_head = torch.nn.Sequential(
            torch.nn.Dropout(p=0.2, inplace=True),
            torch.nn.Linear(1280, embedding_size),
        )
        backbone.classifier = projection_head
        self.model = MultiResolutionModelWrapper(
            backbone, base_out_size=embedding_size
        )

        # Loss and optimizer.
        self.task_loss_fn = init_class(config.get("task_loss"))
        self.optimizer = init_class(config.get("optimizer"), self.model.parameters())

        # Training hyperparameters.
        self.epochs = config.get("epochs", 300)
        self.start_epoch = config.get("start_epoch", 0)
        self.schedule = config.get("schedule", [150, 225])
        self.gamma = config.get("gamma", 0.1)
        self.lr = self.optimizer.param_groups[0]["lr"]
        self.best_acc1 = 0

        # Experiment summary.
        self.logger.info("Train Dataset: %s", self.train_set)
        self.logger.info("Eval Dataset: %s", self.eval_set)
        self.logger.info("Task Loss (Criterion): %s", self.task_loss_fn)
        self.logger.info("Model Optimizer: %s", self.optimizer)
        self.logger.info("Model: %s", self.model)
        total_count = 0
        trainable_count = 0
        for tensor in self.model.parameters():
            size = tensor.numel()
            total_count += size
            if tensor.requires_grad:
                trainable_count += size
        self.logger.info(
            "Num Parameters: %(params)d (%(lrn_params)d requires gradient)",
            {"params": total_count, "lrn_params": trainable_count},
        )
        self.logger.info(
            "LR: %(lr)f decreasing by a factor of %(gamma)f at epochs %(schedule)s",
            {"lr": self.lr, "gamma": self.gamma, "schedule": self.schedule},
        )

        # Move to GPU; wrap in DataParallel first when several GPUs are listed.
        if self.use_cuda:
            multi_gpu = len(self.gpu_ids) > 1
            target = torch.nn.DataParallel(self.model) if multi_gpu else self.model
            self.model = target.cuda()
    def __init__(self, config):
        super(AdaptivePruningAgent, self).__init__(config)

        # TensorBoard Summary Writer
        self.tb_sw = SummaryWriter(log_dir=config["tb_dir"])
        self.tb_sw.add_text("config", str(config))

        # Configure seed, if provided
        seed = config.get("seed")
        set_seed(seed, logger=self.logger)
        self.tb_sw.add_text("seed", str(seed))

        # Setup CUDA
        cpu_only = config.get("cpu_only", False)
        self.use_cuda, self.gpu_ids = set_cuda_devices(cpu_only,
                                                       config.get("gpu_ids"),
                                                       logger=self.logger)

        # Instantiate Datasets and Dataloaders
        self.train_set, self.train_loader = init_data(config.get("train_data"))
        self.eval_set, self.eval_loader = init_data(config.get("eval_data"))

        # Instantiate Models
        self.pretrained_model = init_class(config.get("pretrained_model"))
        self.model = init_class(config.get("model"))

        # Load the pretrained weights
        map_location = None if self.use_cuda else torch.device("cpu")
        prune_checkpoint = config["prune"]
        self.logger.info("Loading pretrained model from checkpoint: %s",
                         prune_checkpoint)
        prune_checkpoint = torch.load(prune_checkpoint,
                                      map_location=map_location)
        self.pretrained_model.load_state_dict(prune_checkpoint["state_dict"])
        self.pretrained_model.eval()
        modules_pretrained = list(self.pretrained_model.modules())
        modules_to_prune = list(self.model.modules())
        module_idx = 0
        for module_to_prune in modules_to_prune:
            module_pretrained = modules_pretrained[module_idx]
            modstr = str(type(module_to_prune))
            # Skip the masking layers
            if type(module_to_prune) == MaskSTE:
                continue
            if len(list(module_to_prune.children())) == 0:
                assert modstr == str(type(module_pretrained))
                # copy all parameters over
                param_lookup = dict(module_pretrained.named_parameters())
                for param_key, param_val in module_to_prune.named_parameters():
                    param_val.data.copy_(param_lookup[param_key].data)
                # BatchNorm layers are special and require copying of running_mean/running_var
                if "BatchNorm" in modstr:
                    module_to_prune.running_mean.copy_(
                        module_pretrained.running_mean)
                    module_to_prune.running_var.copy_(
                        module_pretrained.running_var)
            module_idx += 1

        try:
            # Try to visualize tensorboard model graph structure
            model_input, _target = next(iter(self.eval_set))
            self.tb_sw.add_graph(self.model, model_input.unsqueeze(0))
        except Exception as e:
            self.logger.warn(e)

        # Instantiate task loss and optimizer
        self.task_loss_fn = init_class(config.get("task_loss"))
        self.mask_loss_fn = init_class(config.get("mask_loss"))
        self.optimizer = init_class(config.get("optimizer"),
                                    self.model.parameters())
        self.temperature = config.get("temperature", 4.0)
        self.task_loss_reg = config.get("task_loss_reg", 1.0)
        self.mask_loss_reg = config.get("mask_loss_reg", 1.0)
        self.kd_loss_reg = config.get("kd_loss_reg", 1.0)

        # Misc. Other classification hyperparameters
        self.epochs = config.get("epochs", 300)
        self.start_epoch = config.get("start_epoch", 0)
        self.gamma = config.get("gamma", 1)
        self.lr = self.optimizer.param_groups[0]["lr"]
        self.best_acc_per_usage = {}

        # Log the classification experiment details
        self.logger.info("Train Dataset: %s", self.train_set)
        self.logger.info("Eval Dataset: %s", self.eval_set)
        self.logger.info("Task Loss (Criterion): %s", self.task_loss_fn)
        self.logger.info("Model Optimizer: %s", self.optimizer)
        self.logger.info("Model: %s", self.model)
        num_params = sum([p.numel() for p in self.model.parameters()])
        num_lrn_p = sum(
            [p.numel() for p in self.model.parameters() if p.requires_grad])
        self.logger.info(
            "Num Parameters: %(params)d (%(lrn_params)d requires gradient)",
            {
                "params": num_params,
                "lrn_params": num_lrn_p
            },
        )

        self.budget = config.get("budget", 4300000)
        self.criteria = config.get("criteria", "parameters")
        if self.criteria == "parameters":
            self.og_usage = sum(self.calculate_model_parameters()).data.item()
        else:
            raise NotImplementedError("Unknown criteria: {}".format(
                self.criteria))
        self.logger.info("Pruning from {:.2e} {} to {:.2e} {}.".format(
            self.og_usage, self.criteria, self.budget, self.criteria))
        self.short_term_fine_tune_patience = config.get(
            "short_term_fine_tune_patience", 2)
        self.long_term_fine_tune_patience = config.get(
            "long_term_fine_tune_patience", 4)

        t_log_fpath = os.path.join(config["out_dir"], "epoch.out")
        self.t_log = TabLogger(t_log_fpath)
        self.t_log.set_names([
            "Epoch", "Train Task Loss", "Train KD Loss", "Train Mask Loss",
            "Train Acc", "Eval Task Loss", "Eval KD Loss", "Eval Mask Loss",
            "Eval Acc", "LR", "Parameters"
        ])

        self.logger.info(
            "Training from Epoch %(start)d to %(end)d",
            {
                "start": self.start_epoch,
                "end": self.epochs
            },
        )

        # Support multiple GPUs using DataParallel
        if self.use_cuda:
            if len(self.gpu_ids) > 1:
                self.pretrained_model = torch.nn.DataParallel(
                    self.pretrained_model).cuda()
                self.model = torch.nn.DataParallel(self.model).cuda()
            else:
                self.pretrained_model = self.pretrained_model.cuda()
                self.model = self.model.cuda()
    def run(self):
        """Run the mask / pack / fine-tune loop from start_epoch to epochs.

        Every epoch performs one training pass and one evaluation pass (the
        evaluation pass under ``torch.no_grad()``) and logs both results.
        Each epoch is of type "Sparsity" or "FineTune":

        * "Sparsity": usage is measured before and after the epoch with
          ``calculate_model_parameters``. When it dropped, the model is
          packed through ``MaskablePackingAgent``, the optimizer is rebuilt
          from the config and the loop switches to "FineTune". If
          ``short_term_fine_tune_patience`` is 0 and the packed model is
          still above budget, masks are re-inserted right away and the loop
          stays in "Sparsity".
        * "FineTune": counts consecutive epochs without a new best eval
          top-1 accuracy. Above budget, once the short-term patience is
          reached, masks are re-inserted and the loop returns to "Sparsity".
          At or below budget, once the long-term patience is reached, the
          learning rate is multiplied by 0.1 the first time and the loop
          stops the second time.

        Raises:
            NotImplementedError: if ``self.criteria`` is not "parameters",
                or if the root module of the model is not a ``VGG``.
        """
        self.exp_start = time()

        epoch_type = "Sparsity"  # alternates with "FineTune"
        best_tune_eval_acc = 0
        fine_tune_counter = 0
        budget_achieved = False
        lr_decreased = False

        for epoch in range(self.start_epoch, self.epochs):
            epoch_start = time()

            # Make every parameter trainable again at the start of the epoch.
            for p in self.model.parameters():
                p.requires_grad = True

            # Get current usage (only needed to detect a drop after a
            # Sparsity epoch).
            if epoch_type == "Sparsity":
                if self.criteria == "parameters":
                    pre_epoch_usage = sum(
                        self.calculate_model_parameters()).data.item()
                else:
                    raise NotImplementedError("Unknown criteria: {}".format(
                        self.criteria))

            # Joint Sparsity Training
            train_res = self.run_epoch_pass(epoch=epoch,
                                            train=True,
                                            epoch_type=epoch_type)
            with torch.no_grad():
                eval_res = self.run_epoch_pass(epoch=epoch,
                                               train=False,
                                               epoch_type=epoch_type)
            epoch_elapsed = time() - epoch_start
            self.log_epoch_info(epoch, train_res, eval_res, epoch_type,
                                epoch_elapsed)

            if epoch_type == "Sparsity":
                # The first epoch is always a Sparsity epoch, so
                # post_epoch_usage is bound before any FineTune epoch reads it.
                post_epoch_usage = sum(
                    self.calculate_model_parameters()).data.item()

                if post_epoch_usage < pre_epoch_usage:
                    self.logger.info(
                        "Masking %s reduced from %.2e to %.2e (diff: %d) Packing...",
                        self.criteria,
                        pre_epoch_usage,
                        post_epoch_usage,
                        pre_epoch_usage - post_epoch_usage,
                    )
                    # time to pack, transfer weights, and short term fine tune
                    modules = list(self.model.modules())
                    # modules()[0] is the model itself; only an exact VGG
                    # root is supported by the packing helpers used below.
                    # NOTE(review): a DataParallel-wrapped model would fail
                    # this check -- confirm multi-GPU use is intended here.
                    if type(modules[0]) is VGG:
                        binary_masks = [torch.tensor([1, 1, 1])]
                        make_layers_config, pack_model = MaskablePackingAgent.gen_vgg_make_layers(
                            modules, binary_masks, use_cuda=self.use_cuda)
                        self.logger.info("New VGG configuration: %s",
                                         make_layers_config)
                        MaskablePackingAgent.transfer_vgg_parameters(
                            self.model, pack_model, binary_masks)
                        self.make_layers_config = make_layers_config
                        self.model = pack_model
                        # The packed model has new parameter tensors, so the
                        # optimizer must be rebuilt around them.
                        self.optimizer = init_class(
                            self.config.get("optimizer"),
                            self.model.parameters())
                        post_epoch_usage = sum(
                            [p.numel() for p in self.model.parameters()])
                        self.logger.info(
                            "Packed model contains %.2e %s!",
                            post_epoch_usage,
                            self.criteria,
                        )
                        epoch_type = "FineTune"
                        best_tune_eval_acc = eval_res["top1_acc"]
                        fine_tune_counter = 0
                    else:
                        # Exceptions do not interpolate %-style arguments the
                        # way logging calls do, so format the message here.
                        raise NotImplementedError(
                            "Cannot pack sparse module: {}".format(
                                modules[0]))

                    # Special case, do not fine tune in between Sparsity epochs
                    if self.short_term_fine_tune_patience == 0:
                        if post_epoch_usage <= self.budget:
                            self.logger.info(
                                "Sparsity constraint reached, beginning long term fine tuning..."
                            )
                        else:
                            self.logger.info(
                                "Skipping Fine Tuning Epoch, Re-masking model for sparisty pruning..."
                            )
                            self.model = MaskablePackingAgent.insert_masks_into_model(
                                self.model, use_cuda=self.use_cuda)
                            self.optimizer = init_class(
                                self.config.get("optimizer"),
                                self.model.parameters())
                            epoch_type = "Sparsity"

            elif epoch_type == "FineTune":
                cur_eval_acc = eval_res["top1_acc"]
                if cur_eval_acc > best_tune_eval_acc:
                    fine_tune_counter = 0
                    best_tune_eval_acc = cur_eval_acc
                    self.logger.info(
                        "Resetting Fine Tune Counter, New Best Tune Evaluation Acc: %.2f",
                        best_tune_eval_acc,
                    )
                else:
                    fine_tune_counter += 1
                    self.logger.info("Fine Tune Counter: %d",
                                     fine_tune_counter)

                if post_epoch_usage <= self.budget:
                    # Budget met: long term fine tuning. Drop the LR once
                    # when patience runs out, stop the second time.
                    if fine_tune_counter >= self.long_term_fine_tune_patience:
                        if lr_decreased:
                            budget_achieved = True
                        else:
                            new_lr = self.lr * 0.1
                            for param_group in self.optimizer.param_groups:
                                param_group["lr"] = new_lr
                            self.logger.info(
                                "Long Term Fine Tuning, LR Decreased from %.1e to %.1e",
                                self.lr,
                                new_lr,
                            )
                            self.lr = new_lr
                            lr_decreased = True
                            fine_tune_counter = 0
                else:
                    if fine_tune_counter >= self.short_term_fine_tune_patience:
                        # Time to learn sparsity, add in the mask layers now
                        self.model = MaskablePackingAgent.insert_masks_into_model(
                            self.model, use_cuda=self.use_cuda)
                        self.optimizer = init_class(
                            self.config.get("optimizer"),
                            self.model.parameters())
                        epoch_type = "Sparsity"

            if budget_achieved:
                self.logger.info(
                    "Budget %.2e %s acheived at %.2e %s (%d less)",
                    self.budget,
                    self.criteria,
                    post_epoch_usage,
                    self.criteria,
                    self.budget - post_epoch_usage,
                )
                break