コード例 #1
0
ファイル: main.py プロジェクト: liuyixin-louis/DCP_py3
    def _set_checkpoint(self):
        """
        Load pre-trained model or resume checkpoint
        """

        assert self.pruned_model is not None, "please create model first"

        self.checkpoint = AuxCheckPoint(self.settings.save_path, self.logger)
        self._load_pretrained()
コード例 #2
0
ファイル: main.py プロジェクト: qdmy/DCP-1
class Experiment(object):
    """
    Run experiments with pre-defined pipeline
    """
    def __init__(self, options=None, conf_path=None):
        self.settings = options or Option(conf_path)
        self.checkpoint = None
        self.train_loader = None
        self.val_loader = None
        self.model = None

        self.trainer = None
        self.seg_opt_state = None
        self.fc_opt_state = None
        self.aux_fc_state = None

        os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.gpu

        self.settings.set_save_path()
        write_settings(self.settings)
        self.logger = get_logger(self.settings.save_path, "auxnet")
        self.tensorboard_logger = TensorboardLogger(self.settings.save_path)
        self.settings.copy_code(self.logger,
                                src=os.path.abspath('./'),
                                dst=os.path.join(self.settings.save_path,
                                                 'code'))
        self.logger.info("|===>Result will be saved at {}".format(
            self.settings.save_path))

        self.epoch = 0

        self.prepare()

    def prepare(self):
        """
        Prepare experiments
        """

        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._cal_pivot()
        self._set_checkpoint()
        self._set_trainier()

    def _set_gpu(self):
        """
        Initialize the seed of random number generator
        """

        # set torch seed
        # init random seed
        torch.manual_seed(self.settings.seed)
        torch.cuda.manual_seed(self.settings.seed)
        torch.cuda.set_device(0)
        cudnn.benchmark = True

    def _set_dataloader(self):
        """
        Create train loader and validation loader for auxnet
        """

        if 'cifar' in self.settings.dataset:
            self.train_loader, self.val_loader = get_cifar_dataloader(
                self.settings.dataset, self.settings.batch_size,
                self.settings.n_threads, self.settings.data_path, self.logger)
        elif self.settings.dataset == 'imagenet':
            self.train_loader, self.val_loader = get_imagenet_dataloader(
                self.settings.dataset, self.settings.batch_size,
                self.settings.n_threads, self.settings.data_path, self.logger)

    def _set_model(self):
        """
        Available model
        cifar:
            preresnet
        imagenet:
            resnet
        """

        self.model, _ = get_model(self.settings.dataset,
                                  self.settings.net_type, self.settings.depth,
                                  self.settings.n_classes)

    def _set_checkpoint(self):
        """
        Load pre-trained model or resume checkpoint
        """

        assert self.model is not None, "please create model first"

        self.checkpoint = AuxCheckPoint(self.settings.save_path, self.logger)
        self._load_pretrained()
        self._load_resume()

    def _load_pretrained(self):
        """
        Load pretrained model
        the pretrained model contains state_dicts of model
        """

        if self.settings.pretrained is not None:
            check_point_params = torch.load(self.settings.pretrained)
            model_state = check_point_params
            self.model = self.checkpoint.load_state(self.model, model_state)
            self.logger.info("|===>load restrain file: {}".format(
                self.settings.pretrained))

    def _load_resume(self):
        """
        Load resume checkpoint
        the checkpoint contains state_dicts of model and optimizer, as well as training epoch
        """

        if self.settings.resume is not None:
            check_point_params = torch.load(self.settings.resume)
            model_state = check_point_params["model"]
            self.aux_fc_state = check_point_params["aux_fc"]
            self.seg_opt_state = check_point_params["seg_opt"]
            self.fc_opt_state = check_point_params["fc_opt"]
            self.epoch = check_point_params["epoch"]
            self.model = self.checkpoint.load_state(self.model, model_state)
            self.logger.info("|===>load resume file: {}".format(
                self.settings.resume))

    def _set_trainier(self):
        """
        Initialize trainer for AuxNet
        """

        self.trainer = AuxTrainer(model=self.model,
                                  train_loader=self.train_loader,
                                  val_loader=self.val_loader,
                                  settings=self.settings,
                                  logger=self.logger,
                                  tensorboard_logger=self.tensorboard_logger)
        if self.seg_opt_state is not None:
            self.trainer.update_aux_fc(aux_fc_state=self.aux_fc_state,
                                       aux_fc_opt_state=self.fc_opt_state,
                                       seg_opt_state=self.seg_opt_state)
        if self.aux_fc_state is not None:
            self.trainer.update_aux_fc(aux_fc_state=self.aux_fc_state,
                                       aux_fc_opt_state=self.fc_opt_state,
                                       seg_opt_state=self.seg_opt_state)

    def _cal_pivot(self):
        """
        Calculate the block index for additional loss
        """

        self.num_segments, self.settings.pivot_set = cal_pivot(
            self.settings.n_losses, self.settings.net_type,
            self.settings.depth, self.logger)

    def run(self):
        """
        Learn the parameters of the additional classifier and
        fine tune model with the additional losses and the final loss
        """

        best_top1 = 100
        best_top5 = 100

        start_epoch = 0
        # if load resume checkpoint
        if self.epoch != 0:
            start_epoch = self.epoch + 1
            self.epoch = 0

        self.trainer.val(0)

        for epoch in range(start_epoch, self.settings.n_epochs):
            train_error, train_loss, train5_error = self.trainer.train(epoch)
            val_error, val_loss, val5_error = self.trainer.val(epoch)

            # write log
            log_str = "{:d}\t".format(epoch)
            for i in range(len(train_error)):
                log_str += "{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t".format(
                    train_error[i], train_loss[i], val_error[i], val_loss[i],
                    train5_error[i], val5_error[i])
            write_log(self.settings.save_path, 'log.txt', log_str)

            # save model and checkpoint
            best_flag = False
            if best_top1 >= val_error[-1]:
                best_top1 = val_error[-1]
                best_top5 = val5_error[-1]
                best_flag = True

            if best_flag:
                self.checkpoint.save_aux_model(self.trainer.model,
                                               self.trainer.aux_fc)

            self.logger.info(
                "|===>Best Result is: Top1 Error: {:f}, Top5 Error: {:f}\n".
                format(best_top1, best_top5))
            self.logger.info(
                "|==>Best Result is: Top1 Accuracy: {:f}, Top5 Accuracy: {:f}\n"
                .format(100 - best_top1, 100 - best_top5))

            if self.settings.dataset in ["imagenet"]:
                self.checkpoint.save_aux_checkpoint(self.trainer.model,
                                                    self.trainer.seg_optimizer,
                                                    self.trainer.fc_optimizer,
                                                    self.trainer.aux_fc, epoch,
                                                    epoch + 1)
            else:
                self.checkpoint.save_aux_checkpoint(self.trainer.model,
                                                    self.trainer.seg_optimizer,
                                                    self.trainer.fc_optimizer,
                                                    self.trainer.aux_fc, epoch)
コード例 #3
0
ファイル: main.py プロジェクト: liuyixin-louis/DCP_py3
class Experiment(object):
    """
    Run experiments with pre-defined pipeline
    """
    def __init__(self, options=None, conf_path=None):
        self.settings = options or Option(conf_path)
        self.checkpoint = None
        self.train_loader = None
        self.val_loader = None
        self.pruned_model = None
        self.network_wise_trainer = None
        self.optimizer_state = None
        self.aux_fc_state = None

        os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.gpu

        self.settings.set_save_path()
        write_settings(self.settings)
        self.logger = get_logger(self.settings.save_path, "finetune")
        self.tensorboard_logger = TensorboardLogger(self.settings.save_path)
        self.logger.info("|===>Result will be saved at {}".format(
            self.settings.save_path))
        self.epoch = 0
        self.test_input = None

        self.prepare()

    def write_settings(self):
        """
        Save expriment settings to a file
        """

        with open(os.path.join(self.settings.save_path, "settings.log"),
                  "w") as f:
            for k, v in self.settings.__dict__.items():
                f.write(str(k) + ": " + str(v) + "\n")

    def prepare(self):
        """
        Preparing experiments
        """

        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._set_checkpoint()
        self._set_trainier()

    def _set_gpu(self):
        """
        Initialize the seed of random number generator
        """

        # set torch seed
        # init random seed
        torch.manual_seed(self.settings.seed)
        torch.cuda.manual_seed(self.settings.seed)
        torch.cuda.set_device(0)
        cudnn.benchmark = True

    def _set_dataloader(self):
        """
        Create train loader and validation loader for channel pruning
        """

        if 'cifar' in self.settings.dataset:
            self.train_loader, self.val_loader = get_cifar_dataloader(
                self.settings.dataset, self.settings.batch_size,
                self.settings.n_threads, self.settings.data_path, self.logger)
        elif self.settings.dataset == 'imagenet':
            self.train_loader, self.val_loader = get_imagenet_dataloader(
                self.settings.dataset, self.settings.batch_size,
                self.settings.n_threads, self.settings.data_path, self.logger)

    def _set_model(self):
        """
        Get model
        """

        if self.settings.dataset in ["cifar10", "cifar100"]:
            if self.settings.net_type == "preresnet":
                self.pruned_model = PrunedPreResNet(
                    depth=self.settings.depth,
                    pruning_rate=self.settings.pruning_rate,
                    num_classes=self.settings.n_classes)
            elif self.settings.net_type == "vgg":
                self.pruned_model = PrunedVGGCIFAR(
                    pruning_rate=self.settings.pruning_rate,
                    num_classes=self.settings.n_classes)
            else:
                assert False, "use {} data while network is {}".format(
                    self.settings.dataset, self.settings.net_type)

        elif self.settings.dataset in ["imagenet", "imagenet_mio"]:
            if self.settings.net_type == "resnet":
                self.pruned_model = PrunedResNet(
                    depth=self.settings.depth,
                    pruning_rate=self.settings.pruning_rate,
                    num_classes=self.settings.n_classes)
            else:
                assert False, "use {} data while network is {}".format(
                    self.settings.dataset, self.settings.net_type)

        else:
            assert False, "unsupported data set: {}".format(
                self.settings.dataset)

    def _set_checkpoint(self):
        """
        Load pre-trained model or resume checkpoint
        """

        assert self.pruned_model is not None, "please create model first"

        self.checkpoint = AuxCheckPoint(self.settings.save_path, self.logger)
        self._load_pretrained()

    def _load_pretrained(self):
        """
        Load pre-trained model
        """

        if self.settings.pretrained is not None:
            check_point_params = torch.load(self.settings.pretrained)
            model_state = check_point_params
            # self.aux_fc_state = check_point_params["aux_fc"]
            self.pruned_model = self.checkpoint.load_state(
                self.pruned_model, model_state)
            self.logger.info("|===>load restrain file: {}".format(
                self.settings.pretrained))

    def _set_trainier(self):
        """
        Initialize network-wise trainer
        """

        self.trainer = Trainer(model=self.pruned_model,
                               train_loader=self.train_loader,
                               val_loader=self.val_loader,
                               settings=self.settings,
                               logger=self.logger,
                               tensorboard_logger=self.tensorboard_logger,
                               run_count=self.epoch)

    def evaluation(self):
        """
        Conduct network-wise fine-tuning after channel selection
        """

        self.trainer.val(0)
コード例 #4
0
class Experiment(object):
    """
    Run experiments with pre-defined pipeline
    """
    def __init__(self, options=None, conf_path=None):
        self.settings = options or Option(conf_path)
        self.checkpoint = None
        self.train_loader = None
        self.val_loader = None
        self.pruned_model = None
        self.network_wise_trainer = None
        self.optimizer_state = None
        self.aux_fc_state = None

        os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.gpu

        self.settings.set_save_path()
        write_settings(self.settings)
        self.logger = get_logger(self.settings.save_path, "finetune")
        self.tensorboard_logger = TensorboardLogger(self.settings.save_path)
        self.logger.info("|===>Result will be saved at {}".format(
            self.settings.save_path))
        self.epoch = 0
        self.test_input = None

        self.prepare()

    def write_settings(self):
        """
        Save expriment settings to a file
        """

        with open(os.path.join(self.settings.save_path, "settings.log"),
                  "w") as f:
            for k, v in self.settings.__dict__.items():
                f.write(str(k) + ": " + str(v) + "\n")

    def prepare(self):
        """
        Preparing experiments
        """

        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._cal_pivot()
        self._set_checkpoint()
        self._set_trainier()

    def _set_gpu(self):
        """
        Initialize the seed of random number generator
        """

        # set torch seed
        # init random seed
        torch.manual_seed(self.settings.seed)
        torch.cuda.manual_seed(self.settings.seed)
        torch.cuda.set_device(0)
        cudnn.benchmark = True

    def _set_dataloader(self):
        """
        Create train loader and validation loader for channel pruning
        """

        if 'cifar' in self.settings.dataset:
            self.train_loader, self.val_loader = get_cifar_dataloader(
                self.settings.dataset, self.settings.batch_size,
                self.settings.n_threads, self.settings.data_path, self.logger)
        elif self.settings.dataset == 'imagenet':
            self.train_loader, self.val_loader = get_imagenet_dataloader(
                self.settings.dataset, self.settings.batch_size,
                self.settings.n_threads, self.settings.data_path, self.logger)

    def replace_layer_with_mask_conv_resnet(self):
        """
        Replace the conv layer in resnet with mask_conv for ResNet
        """

        for module in self.pruned_model.modules():
            if isinstance(module, (PreBasicBlock, BasicBlock, Bottleneck)):
                # replace conv2
                temp_conv = MaskConv2d(in_channels=module.conv2.in_channels,
                                       out_channels=module.conv2.out_channels,
                                       kernel_size=module.conv2.kernel_size,
                                       stride=module.conv2.stride,
                                       padding=module.conv2.padding,
                                       bias=(module.conv2.bias is not None))

                temp_conv.weight.data.copy_(module.conv2.weight.data)
                if module.conv2.bias is not None:
                    temp_conv.bias.data.copy_(module.conv2.bias.data)
                module.conv2 = temp_conv

                if isinstance(module, Bottleneck):
                    # replace conv3
                    temp_conv = MaskConv2d(
                        in_channels=module.conv3.in_channels,
                        out_channels=module.conv3.out_channels,
                        kernel_size=module.conv3.kernel_size,
                        stride=module.conv3.stride,
                        padding=module.conv3.padding,
                        bias=(module.conv3.bias is not None))

                    temp_conv.weight.data.copy_(module.conv3.weight.data)
                    if module.conv3.bias is not None:
                        temp_conv.bias.data.copy_(module.conv3.bias.data)
                    module.conv3 = temp_conv

    def replace_layer_with_mask_conv(self):
        """
        Replace the conv layer in resnet with mask_conv
        """

        if self.settings.net_type in ["preresnet", "resnet"]:
            self.replace_layer_with_mask_conv_resnet()

    def _set_model(self):
        """
        Get model
        """

        self.pruned_model, self.test_input = get_model(self.settings.dataset,
                                                       self.settings.net_type,
                                                       self.settings.depth,
                                                       self.settings.n_classes)
        self.replace_layer_with_mask_conv()

    def _set_checkpoint(self):
        """
        Load pre-trained model or resume checkpoint
        """

        assert self.pruned_model is not None, "please create model first"

        self.checkpoint = AuxCheckPoint(self.settings.save_path, self.logger)
        self._load_pretrained()
        # self._load_resume()

    def _load_pretrained(self):
        """
        Load pre-trained model
        """

        if self.settings.pretrained is not None:
            check_point_params = torch.load(self.settings.pretrained)
            model_state = check_point_params["pruned_model"]
            self.aux_fc_state = check_point_params["aux_fc"]
            self.pruned_model = self.checkpoint.load_state(
                self.pruned_model, model_state)
            self.logger.info("|===>load restrain file: {}".format(
                self.settings.pretrained))

    def _load_resume(self):
        """
        Load resume checkpoint
        """

        if self.settings.resume is not None:
            check_point_params = torch.load(self.settings.resume)
            pruned_model_state = check_point_params["model"]
            # self.optimizer_state = check_point_params["optimizer"]
            self.epoch = check_point_params["epoch"]
            self.pruned_model = self.checkpoint.load_state(
                self.pruned_model, pruned_model_state)
            self.logger.info("|===>load resume file: {}".format(
                self.settings.resume))
            self.network_wise_trainer.run_count = self.epoch

    def _set_trainier(self):
        """
        Initialize network-wise trainer
        """

        self.network_wise_trainer = AuxTrainer(
            model=self.pruned_model,
            train_loader=self.train_loader,
            val_loader=self.val_loader,
            settings=self.settings,
            logger=self.logger,
            tensorboard_logger=self.tensorboard_logger,
            run_count=self.epoch)
        if self.aux_fc_state is not None:
            self.network_wise_trainer.update_aux_fc(
                aux_fc_state=self.aux_fc_state)

    def _cal_pivot(self):
        """
        Calculate the block index for additional loss
        """

        self.num_segments, self.settings.pivot_set = cal_pivot(
            self.settings.n_losses, self.settings.net_type,
            self.settings.depth, self.logger)

    def pruning(self):
        """
        Prune channels
        """
        self.logger.info("Before pruning:")
        self.logger.info(self.pruned_model)
        # self.network_wise_trainer.val(0)
        model_analyse = ModelAnalyse(self.pruned_model, self.logger)
        params_num = model_analyse.params_count()
        zero_num = model_analyse.zero_count()
        zero_rate = zero_num * 1.0 / params_num
        self.logger.info("zero rate is: {}".format(zero_rate))
        model_analyse.madds_compute(self.test_input)

        if self.settings.net_type in ["preresnet", "resnet"]:
            model_prune = ResModelPrune(model=self.pruned_model,
                                        net_type=self.settings.net_type,
                                        depth=self.settings.depth)
        else:
            assert False, "unsupport net_type: {}".format(
                self.settings.net_type)

        model_prune.run()
        self.network_wise_trainer.update_model(model_prune.model,
                                               self.optimizer_state)

        self.logger.info("After pruning:")
        self.logger.info(self.pruned_model)
        # self.network_wise_trainer.val(0)
        model_analyse = ModelAnalyse(self.pruned_model, self.logger)
        params_num = model_analyse.params_count()
        zero_num = model_analyse.zero_count()
        zero_rate = zero_num * 1.0 / params_num
        self.logger.info("zero rate is: {}".format(zero_rate))
        model_analyse.madds_compute(self.test_input)

    def fine_tuning(self):
        """
        Conduct network-wise fine-tuning after channel selection
        """

        best_top1 = 100
        best_top5 = 100

        start_epoch = 0
        if self.epoch != 0:
            start_epoch = self.epoch + 1
            self.epoch = 0

        for epoch in range(start_epoch, self.settings.n_epochs):
            train_error, train_loss, train5_error = self.network_wise_trainer.train(
                epoch)
            val_error, val_loss, val5_error = self.network_wise_trainer.val(
                epoch)

            # write log
            log_str = "{:d}\t".format(epoch)
            for i in range(len(train_error)):
                log_str += "{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t".format(
                    train_error[i], train_loss[i], val_error[i], val_loss[i],
                    train5_error[i], val5_error[i])
            write_log(self.settings.save_path, 'log.txt', log_str)

            # save model and checkpoint
            best_flag = False
            if best_top1 >= val_error[-1]:
                best_top1 = val_error[-1]
                best_top5 = val5_error[-1]
                best_flag = True

            if best_flag:
                self.checkpoint.save_aux_model(
                    self.network_wise_trainer.model,
                    self.network_wise_trainer.aux_fc)

            self.logger.info(
                "|===>Best Result is: Top1 Error: {:f}, Top5 Error: {:f}\n".
                format(best_top1, best_top5))
            self.logger.info(
                "|==>Best Result is: Top1 Accuracy: {:f}, Top5 Accuracy: {:f}\n"
                .format(100 - best_top1, 100 - best_top5))

            if "imagenet" in self.settings.dataset:
                self.checkpoint.save_aux_checkpoint(
                    self.network_wise_trainer.model,
                    self.network_wise_trainer.seg_optimizer,
                    self.network_wise_trainer.fc_optimizer,
                    self.network_wise_trainer.aux_fc, epoch, epoch + 1)
            else:
                self.checkpoint.save_aux_checkpoint(
                    self.network_wise_trainer.model,
                    self.network_wise_trainer.seg_optimizer,
                    self.network_wise_trainer.fc_optimizer,
                    self.network_wise_trainer.aux_fc, epoch)