def _set_checkpoint(self):
        """
        Load a pre-trained model or resume from a checkpoint.
        """

        assert self.pruned_model is not None, "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path, self.logger)
        self._load_pretrained()
        self._load_resume()
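
For reference, the PyTorch examples here assume a CheckPoint helper with roughly the interface sketched below. This is a reconstruction from the call sites in this listing, not any single project's actual code (the file-naming scheme and checkpoint dict keys are assumptions). Note that the pg_control, PBFT, and SDISS examples further down use unrelated classes that merely share the CheckPoint name.

import os
import torch


class CheckPoint(object):
    """Minimal sketch of the interface the examples below rely on."""

    def __init__(self, save_path, logger=None):
        self.save_path = save_path
        self.logger = logger
        os.makedirs(save_path, exist_ok=True)

    def load_model(self, path):
        # Return a saved state_dict.
        return torch.load(path, map_location="cpu")

    def load_state(self, model, model_state):
        # Copy a state_dict into an existing model and return the model.
        model.load_state_dict(model_state)
        return model

    def load_checkpoint(self, path):
        # Return (model_state, optimizer_state, epoch); dict keys are assumed.
        params = torch.load(path, map_location="cpu")
        return params["model"], params["optimizer"], params["epoch"]

    def save_checkpoint(self, model, optimizer, epoch):
        torch.save({"model": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch},
                   os.path.join(self.save_path, "checkpoint_%03d.pth" % epoch))

    def save_model(self, model, index=0, best_flag=False, tag=""):
        torch.save(model.state_dict(),
                   os.path.join(self.save_path, "model_%s_%03d.pth" % (tag, index)))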
Example #2
    def __init__(self, index, conf):
        self._nodes = conf['nodes']
        self._node_cnt = len(self._nodes)
        self._index = index
        # Number of faults tolerated.
        self._f = (self._node_cnt - 1) // 3
        # leader
        if self._index == 0:
            self._is_leader = True
        else:
            self._is_leader = False
        self._leader = 0

        self._view = View(0, self._node_cnt)
        # The largest view either promised or accepted
        self._follow_view = View(0, self._node_cnt)

        self._next_propose_slot = 0

        # tracks whether commit_decisions have been committed to the blockchain
        self.committed_to_blockchain = False
        # Checkpoint

        # Network simulation
        self._loss_rate = conf['loss%'] / 100

        # Time configuration
        self._network_timeout = conf['misc']['network_timeout']
        # After finishing committing self._checkpoint_interval slots,
        # trigger to propose new checkpoint.
        self._checkpoint_interval = conf['ckpt_interval']
        self._ckpt = CheckPoint(self._checkpoint_interval, self._nodes,
                                self._f, self._index, self._loss_rate,
                                self._network_timeout)
        # Commit
        self._last_commit_slot = -1

        self._dump_interval = conf['dump_interval']

        # Stores the vote count and vote information for each view number
        self._view_change_votes_by_view_number = {}

        # Record the status of each slot
        # To match JSON keys, slots are stored as string integers.
        self._status_by_slot = {}

        self._sync_interval = conf['sync_interval']

        self._blockchain = Blockchain()

        self._session = None
        self._log = logging.getLogger(__name__)
    def _set_checkpoint(self):
        assert self.model is not None, "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path, self.logger)
        if self.settings.retrain is not None:
            model_state = self.checkpoint.load_model(self.settings.retrain)
            self.model = self.checkpoint.load_state(self.model, model_state)

        if self.settings.resume is not None:
            model_state, optimizer_state, epoch = self.checkpoint.load_checkpoint(
                self.settings.resume)
            self.model = self.checkpoint.load_state(self.model, model_state)
            self.start_epoch = epoch
            self.optimizer_state = optimizer_state
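
The `_f = (self._node_cnt - 1) // 3` line in the PBFT __init__ above is the standard BFT bound: with n replicas, at most f = (n - 1) // 3 Byzantine faults can be tolerated, and the 2f + 1 vote thresholds used throughout the full PBFTHandler (example #12 below) follow from it. A quick self-contained illustration:

# Tolerated faults and vote quorums for small clusters (illustration only):
for n in (4, 7, 10, 13):
    f = (n - 1) // 3
    print("n=%2d  f=%d  quorum(2f+1)=%d" % (n, f, 2 * f + 1))
# n= 4  f=1  quorum(2f+1)=3
# n= 7  f=2  quorum(2f+1)=5
# n=10  f=3  quorum(2f+1)=7
# n=13  f=4  quorum(2f+1)=9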
Example #4
    def create_mtcnn_net(self, p_model_path=None, r_model_path=None, o_model_path=None, use_cuda=True):
        dirname, _ = os.path.split(p_model_path)
        checkpoint = CheckPoint(dirname)

        pnet, rnet, onet = None, None, None
        self.device = torch.device(
            "cuda:0" if use_cuda and torch.cuda.is_available() else "cpu")

        if p_model_path is not None:
            pnet = PNet()
            pnet_model_state = checkpoint.load_model(p_model_path)
            pnet = checkpoint.load_state(pnet, pnet_model_state)
            if use_cuda:
                pnet.to(self.device)
            pnet.eval()

        if r_model_path is not None:
            rnet = RNet()
            rnet_model_state = checkpoint.load_model(r_model_path)
            rnet = checkpoint.load_state(rnet, rnet_model_state)
            if use_cuda:
                rnet.to(self.device)
            rnet.eval()

        if o_model_path is not None:
            onet = ONet()
            onet_model_state = checkpoint.load_model(o_model_path)
            onet = checkpoint.load_state(onet, onet_model_state)
            if use_cuda:
                onet.to(self.device)
            onet.eval()

        return pnet, rnet, onet
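
A call site might look like the following; the owning class name and the model paths are placeholders, since the example does not show them:

# Hypothetical usage of create_mtcnn_net; class name and paths are assumed.
detector = MTCNNDetector()
pnet, rnet, onet = detector.create_mtcnn_net(
    p_model_path="./models/pnet.pth",
    r_model_path="./models/rnet.pth",
    o_model_path="./models/onet.pth",
    use_cuda=True)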
Example #5
    def __load(self, filename):
        try:
            with open(filename, 'rb') as file:
                pgctl_data = file.read(PG_CONTROL_SIZE)

                self.systemidentifier, = struct.unpack_from('Q', pgctl_data, memoffset['ControlFileData.system_identifier'])
                self.catalog_version_no, = struct.unpack_from('I', pgctl_data, memoffset['ControlFileData.catalog_version_no'])
                self.dbstate, = struct.unpack_from(pgstruct.get_type_format('DBState'), pgctl_data, memoffset['ControlFileData.state'])
                self.checkPoint, = struct.unpack_from('Q', pgctl_data, memoffset['ControlFileData.checkPoint'])
                self.minRecoveryPoint, = struct.unpack_from('Q', pgctl_data, memoffset['ControlFileData.minRecoveryPoint'])
                self.minRecoveryPointTLI, = struct.unpack_from('I', pgctl_data, memoffset['ControlFileData.minRecoveryPointTLI'])
                self.checkPointCopy = CheckPoint(pgctl_data, memoffset['ControlFileData.checkPointCopy'])
                self.xlog_blcksz, = struct.unpack_from('I', pgctl_data, memoffset['ControlFileData.xlog_blcksz'])
                self.xlog_seg_size, = struct.unpack_from('I', pgctl_data, memoffset['ControlFileData.xlog_seg_size'])
                self.blcksz, = struct.unpack_from('I', pgctl_data, memoffset['ControlFileData.blcksz'])
                self.relseg_size, = struct.unpack_from('I', pgctl_data, memoffset['ControlFileData.relseg_size'])
                
                self.nameDataLen, = struct.unpack_from('I', pgctl_data, memoffset['ControlFileData.nameDataLen'])
                
                self.float8ByVal, = struct.unpack_from(pgstruct.get_type_format('bool'), pgctl_data, memoffset['ControlFileData.float8ByVal'])
                self.float4ByVal, = struct.unpack_from(pgstruct.get_type_format('bool'), pgctl_data, memoffset['ControlFileData.float4ByVal'])
                
                self.crc, = struct.unpack_from('I', pgctl_data, memoffset['ControlFileData.crc'])
                
                crcdatas = []
                crcdatas.append((pgctl_data, memoffset['ControlFileData.crc']))

                if not upgcrc.crceq(crcdatas, self.crc):
                    logger.error('pg_control has invalid CRC')
                    raise UPgException('pg_control has invalid CRC')
        except IOError:
            logger.error("Error in reading control file!")
            raise
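
The trailing commas in the assignments above (e.g. `self.crc, = struct.unpack_from(...)`) are deliberate: struct.unpack_from always returns a tuple, and the comma unpacks its single element. A self-contained demonstration:

import struct

buf = struct.pack('QI', 42, 7)            # an 8-byte and a 4-byte unsigned int
value, = struct.unpack_from('Q', buf, 0)  # the comma unpacks the 1-tuple
other, = struct.unpack_from('I', buf, 8)
print(value, other)                       # 42 7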
Example #6
    def _set_checkpoint(self):
        assert self.model is not None, "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path)
        if self.settings.retrain is not None:
            model_state = self.checkpoint.load_model(self.settings.retrain)
            self.model = self.checkpoint.load_state(self.model, model_state)

        if self.settings.resume is not None:
            check_point_params = torch.load(self.settings.resume)
            model_state = check_point_params["model"]
            self.seg_opt_state = check_point_params["seg_opt"]
            self.fc_opt_state = check_point_params["fc_opt"]
            self.aux_fc_state = check_point_params["aux_fc"]
            self.model = self.checkpoint.load_state(self.model, model_state)
            self.start_epoch = 90
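
The resume branch above reads four keys out of a single torch.load'ed file, which implies a save side roughly like the sketch below. The optimizer variable names are assumptions, since the saving code is not part of this example:

# Assumed counterpart of the resume format read above:
check_point_params = {
    "model": model.state_dict(),
    "seg_opt": seg_optimizer.state_dict(),  # hypothetical optimizer names
    "fc_opt": fc_optimizer.state_dict(),
    "aux_fc": aux_fc.state_dict(),
}
torch.save(check_point_params, "checkpoint.pth")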
    async def register(self):
        """
        POST {'host': xx, 'port': xx} to the CA to register and get this
        node's index and the current node list, then broadcast join_request.
        """
        if not self._session:
            timeout = aiohttp.ClientTimeout(self._network_timeout)
            self._session = aiohttp.ClientSession(timeout=timeout)
        resp = await self._session.post(self.make_url(self._ca, MessageType.REGISTER), json=self._node)
        resp = await resp.json()
        self._index = resp['index']
        self._nodes = resp['nodes']
        self._log.info('register to ca, get index %d, current nodes: %s', self._index, self._nodes)

        self._node_cnt = len(self._nodes)
        self._f = (self._node_cnt - 1) // 3
        self._is_leader = False
        self._ckpt = CheckPoint(self._checkpoint_interval, self._nodes, 
            self._f, self._index, self._loss_rate, self._network_timeout)

        await self.join_request()
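
The CA response consumed above has roughly this shape, inferred from the keys read out of it (values are illustrative):

# {'index': 2,
#  'nodes': [{'host': 'localhost', 'port': 30000},
#            {'host': 'localhost', 'port': 30001}, ...]}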
Example #8
    def StartProgram(self):

        # ---Get the User Input and make it globally accessible---#

        cg.SampleRate = float(
            self.sample_rate.get())  # sample rate for experiment in seconds

        if cg.method == "Continuous Scan":
            cg.numFiles = int(self.numfiles.get())  # file limit
        elif cg.method == "Frequency Map":
            cg.numFiles = 1

        cg.q = Queue()

        if cg.delimiter == 1:
            cg.delimiter = " "
        elif cg.delimiter == 2:
            cg.delimiter = "\t"
        elif cg.delimiter == 3:
            cg.delimiter = ","

        if cg.extension == 1:
            cg.extension = ".txt"
        elif cg.extension == 2:
            cg.extension = ".csv"
        elif cg.extension == 3:
            cg.extension = ".DTA"

        # None until the user selects an injection point
        cg.InjectionPoint = None
        cg.InitializedNormalization = False  # tracks if the data has been normalized
        # to the starting normalization point
        cg.RatioMetricCheck = False  # tracks changes to high and low frequencies
        # tracks if a warning label for the normalization has been created
        cg.NormWarningExists = False

        cg.NormalizationPoint = 3
        cg.starting_file = 1

        cg.SaveVar = self.SaveVar.get()  # tracks if text file export has been activated
        cg.InjectionVar = self.InjectionVar.get()  # tracks if injection was selected
        cg.resize_interval = int(self.resize_entry.get())  # interval at which xaxis of plots resizes
        cg.handle_variable = self.ImportFileEntry.get()  # string handle used for the input file

        # --- Y Limit Adjustment Parameters ---#
        cg.min_norm = float(self.norm_data_min.get())  # normalization y limits
        cg.max_norm = float(self.norm_data_max.get())
        cg.min_raw = float(
            self.raw_data_min.get())  # raw data y limit adjustment variables
        cg.max_raw = float(self.raw_data_max.get())
        cg.min_data = float(
            self.data_min.get())  # raw data y limit adjustment variables
        cg.max_data = float(self.data_max.get())
        cg.ratio_min = float(self.KDM_min.get())  # KDM min and max
        cg.ratio_max = float(self.KDM_max.get())

        #############################################################
        # Interval at which the program searches for files (ms) ###
        #############################################################
        cg.Interval = self.Interval.get()

        # set the resizability of the container ##
        # frame to handle PlotContainer resize   ##
        cg.container.columnconfigure(1, weight=1)

        # --- High and Low Frequency Selection for Drift Correction (KDM) ---#
        cg.HighFrequency = max(cg.frequency_list)
        cg.LowFrequency = min(cg.frequency_list)
        cg.HighLowList["High"] = cg.HighFrequency
        cg.HighLowList["Low"] = cg.LowFrequency

        # --- Create a timevault for normalization variables if the chosen
        # normalization point has not yet been analyzed ---#
        cg.NormalizationVault = []  # timevault for Normalization Points
        cg.NormalizationVault.append(
            cg.NormalizationPoint)  # append the starting normalization point

        ################################################################
        # If all checkpoints have been met, initialize the program ###
        ################################################################
        if not self.NoSelection:
            if cg.FoundFilePath:

                _ = CheckPoint(self.parent, self.controller)
Example #9
File: train.py  Project: HaoKun-Li/HCCR
                                               shuffle=True,
                                               **kwargs)

    #Set valid_data
    valid_data = torchvision.datasets.MNIST(root=config.dataPath, train=False)
    valid_x = torch.unsqueeze(valid_data.test_data, dim=1).type(
        torch.FloatTensor
    )[:2000] / 255.  # shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
    valid_y = valid_data.test_labels[:2000]

    #Set model
    model = LeNet_5()
    model = model.to(device)

    #Set checkpoint
    checkpoint = CheckPoint(config.save_path)

    #Set optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
    scheduler = None

    #Set trainer
    logger = Logger(config.save_path)
    trainer = LeNet_5Trainer(config.lr, train_loader, valid_x, valid_y, model,
                             optimizer, scheduler, logger, device)

    epoch_dict = 1
    model_dict, optimizer_dict, epoch_dict = checkpoint.load_checkpoint(
        os.path.join(checkpoint.save_path, 'checkpoint_005.pth'))
    model.load_state_dict(model_dict)
    optimizer.load_state_dict(optimizer_dict)
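
The counterpart that writes such files appears in the ExperimentDesign example below; with the file naming sketched earlier, epoch 5 would produce 'checkpoint_005.pth':

# Assumed counterpart of load_checkpoint (see ExperimentDesign.training below):
checkpoint.save_checkpoint(model, optimizer, epoch)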
Example #10
class ExperimentDesign:
    def __init__(self, options=None):
        self.settings = options or Option()
        self.checkpoint = None
        self.data_loader = None
        self.model = None

        self.optimizer_state = None
        self.trainer = None
        self.start_epoch = 0

        self.model_analyse = None

        self.visualize = vs.Visualization(self.settings.save_path)
        self.logger = vs.Logger(self.settings.save_path)
        self.test_input = None
        self.lr_master = None
        self.prepare()

    def prepare(self):
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._set_checkpoint()
        self._set_parallel()
        self._set_lr_policy()
        self._set_trainer()
        self._draw_net()

    def _set_gpu(self):
        # set torch seed
        # init random seed
        torch.manual_seed(self.settings.manualSeed)
        torch.cuda.manual_seed(self.settings.manualSeed)
        assert self.settings.GPU <= torch.cuda.device_count() - 1, \
            "Invalid GPU ID"
        torch.cuda.set_device(self.settings.GPU)
        print("|===>Set GPU done!")

    def _set_dataloader(self):
        # create data loader
        self.data_loader = DataLoader(dataset=self.settings.dataset,
                                      batch_size=self.settings.batchSize,
                                      data_path=self.settings.dataPath,
                                      n_threads=self.settings.nThreads,
                                      ten_crop=self.settings.tenCrop)
        print("|===>Set data loader done!")

    def _set_checkpoint(self):
        assert self.model is not None, "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path)
        if self.settings.retrain is not None:
            model_state = self.checkpoint.load_model(self.settings.retrain)
            self.model = self.checkpoint.load_state(self.model, model_state)

        if self.settings.resume is not None:
            model_state, optimizer_state, epoch = self.checkpoint.load_checkpoint(
                self.settings.resume)
            self.model = self.checkpoint.load_state(self.model, model_state)
            # self.start_epoch = epoch
            # self.optimizer_state = optimizer_state
        print("|===>Set checkpoint done!")

    def _set_model(self):
        if self.settings.dataset == "sphere":
            if self.settings.netType == "SphereNet":
                self.model = md.SphereNet(
                    depth=self.settings.depth,
                    num_features=self.settings.featureDim)

            elif self.settings.netType == "SphereNIN":
                self.model = md.SphereNIN(
                    num_features=self.settings.featureDim)

            elif self.settings.netType == "wcSphereNet":
                self.model = md.wcSphereNet(
                    depth=self.settings.depth,
                    num_features=self.settings.featureDim,
                    rate=self.settings.rate)
            else:
                assert False, "use %s data while network is %s" % (
                    self.settings.dataset, self.settings.netType)
        else:
            assert False, "unsupport data set: " + self.settings.dataset
        print("|===>Set model done!")

    def _set_parallel(self):
        self.model = utils.data_parallel(self.model, self.settings.nGPU,
                                         self.settings.GPU)

    def _set_lr_policy(self):
        self.lr_master = utils.LRPolicy(self.settings.lr, self.settings.nIters,
                                        self.settings.lrPolicy)
        params_dict = {
            'gamma': self.settings.gamma,
            'step': self.settings.step,
            'end_lr': self.settings.endlr,
            'decay_rate': self.settings.decayRate
        }

        self.lr_master.set_params(params_dict=params_dict)

    def _set_trainer(self):

        train_loader, test_loader = self.data_loader.getloader()
        self.trainer = Trainer(model=self.model,
                               lr_master=self.lr_master,
                               n_epochs=self.settings.nEpochs,
                               n_iters=self.settings.nIters,
                               train_loader=train_loader,
                               test_loader=test_loader,
                               feature_dim=self.settings.featureDim,
                               momentum=self.settings.momentum,
                               weight_decay=self.settings.weightDecay,
                               optimizer_state=self.optimizer_state,
                               logger=self.logger)

    def _draw_net(self):
        # visualize model

        if self.settings.dataset == "sphere":
            rand_input = torch.randn(1, 3, 112, 96)
        else:
            assert False, "invalid data set"
        rand_input = Variable(rand_input.cuda())
        self.test_input = rand_input

        if self.settings.drawNetwork:
            rand_output, _ = self.trainer.forward(rand_input)
            self.visualize.save_network(rand_output)
            print("|===>Draw network done!")

        self.visualize.write_settings(self.settings)

    def pruning(self, run_count=0):
        net_type = None
        if self.settings.dataset == "sphere":
            if self.settings.netType == "wcSphereNet":
                net_type = "SphereNet"

        assert net_type is not None, "net_type for pruning is None"

        self.trainer.test()

        if isinstance(self.model, nn.DataParallel):
            model = self.model.module
        else:
            model = self.model

        if net_type == "SphereNet":
            model_prune = prune.SpherePrune(model)
        model_prune.run()
        self.trainer.reset_model(model_prune.model)
        self.model = self.trainer.model

        self.trainer.test()
        self.checkpoint.save_model(self.trainer.model,
                                   index=run_count,
                                   tag="pruning")

        # analyse model
        self.model_analyse = utils.ModelAnalyse(self.trainer.model,
                                                self.visualize)
        params_num = self.model_analyse.params_count()
        self.model_analyse.flops_compute(self.test_input)

    def fine_tuning(self, run_count=0):
        # set lr
        self.settings.lr = 0.01  #  0.1
        self.settings.nIters = 12000  # 28000
        self.settings.lrPolicy = "multi_step"
        self.settings.decayRate = 0.1
        self.settings.step = [6000]  #  [16000, 24000]

        self._set_lr_policy()
        self.trainer.reset_lr(self.lr_master, self.settings.nIters)

        # run fine-tuning
        self.training(run_count, tag="fine-tuning")

    def retrain(self, run_count=0):
        self.settings.lr = 0.1
        self.settings.nIters = 28000
        self.settings.lrPolicy = "multi_step"
        self.settings.decayRate = 0.1
        self.settings.step = [16000, 24000]
        self._set_lr_policy()
        self.trainer.reset_lr(self.lr_master, self.settings.nIters)

        # run retrain
        self.training(run_count, tag="training")

    def run(self, run_count=0):
        """
        if run_count == 0:
            print "|===> training"
            self.retrain(run_count)
        else:     
            print "|===> fine-tuning"
            self.fine_tuning(run_count)
        """
        if run_count >= 1:
            print("|===> training")
            self.retrain(run_count)

            self.trainer.reset_model(self.model)
            print("|===> fine-tuning")
            self.fine_tuning(run_count)

        print("|===> pruning")
        self.pruning(run_count)

        # keep margin_linear static
        layer_count = 0
        for layer in self.model.modules():
            if isinstance(layer, md.MarginLinear):
                layer.iteration.fill_(0)
                layer.margin_type.data.fill_(1)
                layer.weight.requires_grad = False

            elif isinstance(layer, nn.Linear):
                layer.weight.requires_grad = False
                layer.bias.requires_grad = False

            elif isinstance(layer, nn.Conv2d):
                if layer.bias is not None:
                    bias_flag = True
                else:
                    bias_flag = False
                new_layer = prune.wcConv2d(layer.weight.size(1),
                                           layer.weight.size(0),
                                           kernel_size=layer.kernel_size,
                                           stride=layer.stride,
                                           padding=layer.padding,
                                           bias=bias_flag,
                                           rate=self.settings.rate)
                new_layer.weight.data.copy_(layer.weight.data)
                if layer.bias is not None:
                    new_layer.bias.data.copy_(layer.bias.data)
                if layer_count == 1:
                    self.model.conv2 = new_layer
                elif layer_count == 2:
                    self.model.conv3 = new_layer
                elif layer_count == 3:
                    self.model.conv4 = new_layer
                layer_count += 1

        print(self.model)
        self.trainer.reset_model(self.model)
        # assert False

    def training(self, run_count=0, tag="training"):
        best_top1 = 100
        # start_time = time.time()
        self.trainer.test()
        # assert False
        for epoch in range(self.start_epoch, self.settings.nEpochs):
            if self.trainer.iteration >= self.trainer.n_iters:
                break
            # training and testing
            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            acc_mean, acc_std, acc_all = self.trainer.test()

            test_error_mean = 100 - acc_mean * 100
            # write and print result
            log_str = "%d\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                epoch, train_error, train_loss, train5_error, acc_mean,
                acc_std)
            for acc in acc_all:
                log_str += "%.4f\t" % acc

            self.visualize.write_log(log_str)
            best_flag = False
            if best_top1 >= test_error_mean:
                best_top1 = test_error_mean
                best_flag = True
                print(
                    colored(
                        "# %d ==>Best Result is: Top1 Error: %f\n" %
                        (run_count, best_top1), "red"))
            else:
                print(
                    colored(
                        "# %d ==>Best Result is: Top1 Error: %f\n" %
                        (run_count, best_top1), "blue"))

            self.checkpoint.save_checkpoint(self.model, self.trainer.optimizer,
                                            epoch)

            if best_flag:
                self.checkpoint.save_model(self.model,
                                           best_flag=best_flag,
                                           tag="%s_%d" % (tag, run_count))

            if (epoch + 1) % self.settings.drawInterval == 0:
                self.visualize.draw_curves()

            for name, value in self.model.named_parameters():
                if 'weight' in name:
                    name = name.replace('.', '/')
                    self.logger.histo_summary(
                        name,
                        value.data.cpu().numpy(),
                        run_count * self.settings.nEpochs + epoch + 1)
                    if value.grad is not None:
                        self.logger.histo_summary(
                            name + "/grad",
                            value.grad.data.cpu().numpy(),
                            run_count * self.settings.nEpochs + epoch + 1)

        # end_time = time.time()

        if isinstance(self.model, nn.DataParallel):
            self.model = self.model.module
        # draw experimental curves
        self.visualize.draw_curves()

        # compute cost time
        # time_interval = end_time - start_time
        # t_string = "Running Time is: " + \
        #    str(datetime.timedelta(seconds=time_interval)) + "\n"
        # print(t_string)
        # write cost time to file
        # self.visualize.write_readme(t_string)

        # save experimental results
        self.model_analyse = utils.ModelAnalyse(self.trainer.model,
                                                self.visualize)
        self.visualize.write_readme("Best Result of all is: Top1 Error: %f\n" %
                                    best_top1)

        # analyse model
        params_num = self.model_analyse.params_count()

        # save analyse result to file
        self.visualize.write_readme("Number of parameters is: %d" %
                                    (params_num))
        self.model_analyse.prune_rate()

        self.model_analyse.flops_compute(self.test_input)

        return best_top1
Example #11
File: main.py  Project: L-Zhe/SDISS
from SDISS.Optim import WarmUpOpt, LabelSmoothing
from checkpoint import CheckPoint, saveOutput, loadModel
from run import fit
from eval import eval_score
from generate import generator
import os

if __name__ == '__main__':
    checkpoint = CheckPoint(trainSrcFilePath=Parameter.trainSrcFilePath,
                            validSrcFilePath=Parameter.validSrcFilePath,
                            testSrcFilePath=Parameter.testSrcFilePath,
                            trainTgtFilePath=Parameter.trainTgtFilePath,
                            validTgtFilePath=Parameter.validTgtFilePath,
                            testTgtFilePath=Parameter.testTgtFilePath,
                            trainGraph=Parameter.trainGraph,
                            validGraph=Parameter.validGraph,
                            testGraph=Parameter.testGraph,
                            min_freq=Parameter.min_freq,
                            BATCH_SIZE=Parameter.BATCH_SIZE,
                            dataPath=Parameter.dataPath,
                            dataFile=Parameter.dataFile,
                            checkpointPath=Parameter.checkpointPath,
                            checkpointFile=Parameter.checkpointFile,
                            mode=Parameter.mode)

    trainDataSet, validDataSet, testDataSet, index2word, gword2index = checkpoint.LoadData(
    )
    model = createModel(len(index2word), len(gword2index)).to(device)
    if Parameter.mode == 'train':
        criterion = LabelSmoothing(smoothing=Parameter.smoothing,
                                   lamda=Parameter.lamda).to(device)
Example #12
class PBFTHandler:
    def __init__(self, index, conf):
        self._nodes = conf['nodes']
        self._node_cnt = len(self._nodes)
        self._index = index
        # Number of faults tolerated.
        self._f = (self._node_cnt - 1) // 3
        # leader
        if self._index == 0:
            self._is_leader = True
        else:
            self._is_leader = False
        self._leader = 0

        self._view = View(0, self._node_cnt)
        # The largest view either promised or accepted
        self._follow_view = View(0, self._node_cnt)

        self._next_propose_slot = 0

        # tracks whether commit_decisions have been committed to the blockchain
        self.committed_to_blockchain = False
        # Checkpoint

        # Network simulation
        self._loss_rate = conf['loss%'] / 100

        # Time configuration
        self._network_timeout = conf['misc']['network_timeout']
        # After finishing committing self._checkpoint_interval slots,
        # trigger to propose new checkpoint.
        self._checkpoint_interval = conf['ckpt_interval']
        self._ckpt = CheckPoint(self._checkpoint_interval, self._nodes,
                                self._f, self._index, self._loss_rate,
                                self._network_timeout)
        # Commit
        self._last_commit_slot = -1

        self._dump_interval = conf['dump_interval']

        # Stores the vote count and vote information for each view number
        self._view_change_votes_by_view_number = {}

        # Record the status of each slot
        # To match JSON keys, slots are stored as string integers.
        self._status_by_slot = {}

        self._sync_interval = conf['sync_interval']

        self._blockchain = Blockchain()

        self._session = None
        self._log = logging.getLogger(__name__)

    @staticmethod
    def make_url(node, command):
        '''
        input:
            node: dictionary with 'host' (URL) and 'port' keys
            command: action
        output:
            The URL to send to for the given node and action.
        '''
        return "http://{}:{}/{}".format(node['host'], node['port'], command)

    async def _make_requests(self, nodes, command, json_data):
        '''
        Send json data:

        input:
            nodes: list of dictionary with key: host, port
            command: Command to execute.
            json_data: Json data.
        output:
            list of tuple: (node_index, response)

        '''
        resp_list = []
        for i, node in enumerate(nodes):
            if random() > self._loss_rate:
                if not self._session:
                    timeout = aiohttp.ClientTimeout(self._network_timeout)
                    self._session = aiohttp.ClientSession(timeout=timeout)
                self._log.debug("make request to %d, %s", i, command)
                try:
                    resp = await self._session.post(self.make_url(
                        node, command),
                                                    json=json_data)
                    resp_list.append((i, resp))

                except Exception as e:
                    #resp_list.append((i, e))
                    self._log.error(e)
                    pass
        return resp_list

    async def _make_response(self, resp):
        '''
        Drop the response by chance, by sleeping for the network timeout.
        '''
        if random() < self._loss_rate:
            await asyncio.sleep(self._network_timeout)
        return resp

    async def _post(self, nodes, command, json_data):
        '''
        Broadcast json_data to all node in nodes with given command.
        input:
            nodes: list of nodes
            command: action
            json_data: Data in json format.
        '''
        if not self._session:
            timeout = aiohttp.ClientTimeout(self._network_timeout)
            self._session = aiohttp.ClientSession(timeout=timeout)
        for i, node in enumerate(nodes):
            if random() > self._loss_rate:
                self._log.info("make request to %s, %s",
                               node['host'] + ':' + str(node['port']),
                               json_data['type'])
                try:
                    _ = await self._session.post(self.make_url(node, command),
                                                 json=json_data)
                except Exception as e:
                    #resp_list.append((i, e))
                    self._log.error(" %s while making request %s. %s", str(e),
                                    self.make_url(node, command),
                                    str(json_data))
                    pass

    def _legal_slot(self, slot):
        '''
        The slot is legal only when it lies between the lower bound and the upper bound.
        input:
            slot: string integer direct get from the json_data proposal key.
        output:
            boolean to express the result.
        '''
        # special case: before commit, a node has received 2f + 1 ckpt_votes and updated its ckpt,
        # so the message to be committed becomes illegal

        # TODO: restore after fix ckpt_sync
        # if int(slot) < self._ckpt.next_slot - 1 or int(slot) >= self._ckpt.get_commit_upperbound():
        #     return False
        # else:
        #     return True
        return True

    async def get_request(self, request):
        '''
        Handle the request from client if leader, otherwise 
        redirect to the leader.
        '''
        self._log.info("%d: on request", self._index)

        if not self._is_leader:
            if self._leader is not None:
                raise web.HTTPTemporaryRedirect(
                    self.make_url(self._nodes[self._leader],
                                  MessageType.REQUEST))
            else:
                raise web.HTTPServiceUnavailable()
        else:

            # print(request.headers)
            # print(request.__dict__)

            json_data = await request.json()

            # print("\t\t--->node"+str(self._index)+": on request :")
            # print(json_data)

            await self.preprepare(json_data)
            return web.Response()

    async def preprepare(self, json_data):
        '''
        Preprepare: deal with a request from the client and broadcast it to other replicas.
        input:
            json_data: JSON-transformed web request from the client
                {
                    id: (client_id, client_seq),
                    client_url: "url string",
                    timestamp: "time",
                    data: "string"
                }

        '''
        this_slot = str(self._next_propose_slot)
        self._next_propose_slot = int(this_slot) + 1

        self._log.info("%d: on preprepare, propose at slot: %d, %s",
                       self._index, int(this_slot), json_data)

        if this_slot not in self._status_by_slot:
            self._status_by_slot[this_slot] = Status(self._f)
        self._status_by_slot[this_slot].request = json_data

        preprepare_msg = {
            'leader': self._index,
            'view': self._view.get_view(),
            'proposal': {
                this_slot: json_data
            },
            'type': 'preprepare'
        }

        await self._post(self._nodes, MessageType.PREPARE, preprepare_msg)
        # # require replicas to feedback instead of prepare
        # await self._post(self._nodes, MessageType.FEEDBACK, preprepare_msg)

    async def prepare(self, request):
        '''
        Once a preprepare message is received from the leader, broadcast a
        prepare message to all replicas.

        input: 
            request: preprepare message from preprepare:
                preprepare_msg = {
                    'leader': self._index,
                    'view': self._view.get_view(),
                    'proposal': {
                        this_slot: json_data
                    },
                    'type': 'preprepare'
                }

        '''
        json_data = await request.json()

        if json_data['view'] < self._follow_view.get_view():
            # when receive message with view < follow_view, do nothing
            return web.Response()

        # self._log.info("%d: on prepare", self._index)
        self._log.info("%d: receive preprepare msg from %d", self._index,
                       json_data['leader'])

        for slot in json_data['proposal']:

            if not self._legal_slot(slot):
                continue

            if slot not in self._status_by_slot:
                self._status_by_slot[slot] = Status(self._f)

            prepare_msg = {
                'index': self._index,
                'view': json_data['view'],
                'proposal': {
                    slot: json_data['proposal'][slot]
                },
                'type': MessageType.PREPARE
            }
            await self._post(self._nodes, MessageType.COMMIT, prepare_msg)
        return web.Response()

    async def commit(self, request):
        '''
        Once more than 2f + 1 prepare messages have been received,
        send the commit message.
        input:
            request: prepare message from prepare:
                prepare_msg = {
                    'index': self._index,
                    'view': self._n,
                    'proposal': {
                        this_slot: json_data
                    },
                    'type': 'prepare'
                }
        '''
        json_data = await request.json()
        # self._log.info("%d: on commit", self._index)
        self._log.info("%d: receive prepare msg from %d", self._index,
                       json_data['index'])

        # print("\t--->node "+str(self._index)+": receive prepare msg from node "+str(json_data['index']))
        # print(json_data)

        if json_data['view'] < self._follow_view.get_view():
            # when receive message with view < follow_view, do nothing
            return web.Response()

        for slot in json_data['proposal']:
            if not self._legal_slot(slot):
                continue

            if slot not in self._status_by_slot:
                self._status_by_slot[slot] = Status(self._f)
            status = self._status_by_slot[slot]

            view = View(json_data['view'], self._node_cnt)

            status._update_sequence(json_data['type'], view,
                                    json_data['proposal'][slot],
                                    json_data['index'])

            if status._check_majority(json_data['type']):
                status.prepare_certificate = Status.Certificate(
                    view, json_data['proposal'][slot])
                commit_msg = {
                    'index': self._index,
                    'view': json_data['view'],
                    'proposal': {
                        slot: json_data['proposal'][slot]
                    },
                    'type': MessageType.COMMIT
                }
                await self._post(self._nodes, MessageType.REPLY, commit_msg)
        return web.Response()

    async def reply(self, request):
        '''
        Once more than 2f + 1 commit messages have been received, append the
        commit certificate, which cannot change anymore. In addition, if there
        are no bubbles ahead, commit the given slots and update the last_commit_slot.
        input:
            request: commit message from commit:
                preprepare_msg = {
                    'index': self._index,
                    'n': self._n,
                    'proposal': {
                        this_slot: json_data
                    },
                    'type': 'commit'
                }
        '''

        json_data = await request.json()
        # self._log.info(" %d: on reply", self._index)
        # print("\t--->node "+str(self._index)+": on reply ")

        if json_data['view'] < self._follow_view.get_view():
            # when receive message with view < follow_view, do nothing
            return web.Response()

        self._log.info(" %d: receive commit msg from %d", self._index,
                       json_data['index'])

        for slot in json_data['proposal']:
            if not self._legal_slot(slot):
                self._log.error("%d: message %s not in valid slot",
                                self._index, json_data)
                continue

            if slot not in self._status_by_slot:
                self._status_by_slot[slot] = Status(self._f)
            status = self._status_by_slot[slot]

            view = View(json_data['view'], self._node_cnt)

            status._update_sequence(json_data['type'], view,
                                    json_data['proposal'][slot],
                                    json_data['index'])

            # Commit only when no commit certificate and got more than 2f + 1
            if not status.commit_certificate and status._check_majority(
                    json_data['type']):
                status.commit_certificate = Status.Certificate(
                    view, json_data['proposal'][slot])

                self._log.debug("Add commit certifiacte to slot %d", int(slot))

                # Reply only once and only when no bubble ahead
                if self._last_commit_slot == int(
                        slot) - 1 and not status.is_committed:

                    status.is_committed = True
                    self._last_commit_slot += 1
                    if not self._is_leader:
                        self._next_propose_slot += 1

                    #    When commit messages fill the next checkpoint, propose a new checkpoint.
                    if (self._last_commit_slot +
                            1) % self._checkpoint_interval == 0:
                        self._log.info(
                            "%d: Propose checkpoint with last slot: %d. "
                            "In addition, current checkpoint's next_slot is: %d",
                            self._index, self._last_commit_slot,
                            self._ckpt.next_slot)
                        await self._ckpt.propose_vote(
                            self.get_commit_decisions())

                    if (self._last_commit_slot + 1) % self._dump_interval == 0:
                        await self.dump_to_file()

                    reply_msg = {
                        'index': self._index,
                        'view': json_data['view'],
                        'proposal': json_data['proposal'][slot],
                        'type': MessageType.REPLY
                    }
                    try:
                        await self._session.post(
                            json_data['proposal'][slot]['client_url'],
                            json=reply_msg)
                    except Exception:
                        self._log.error(
                            "Send message failed to %s",
                            json_data['proposal'][slot]['client_url'])
                        pass
                    else:
                        self._log.info(
                            "%d reply to %s successfully!!", self._index,
                            json_data['proposal'][slot]['client_url'])

        return web.Response()

    def get_commit_decisions(self):
        '''
        Get the commit decisions from the next slot of the current ckpt up to the last commit slot.
        output:
            commit_decisions: list of tuples: [((client_index, client_seq), data), ... ]
        '''
        commit_decisions = []
        # print(self._ckpt.next_slot, self._last_commit_slot + 1)
        for i in range(self._ckpt.next_slot, self._last_commit_slot + 1):
            status = self._status_by_slot[str(i)]
            proposal = status.commit_certificate._proposal

            commit_decisions.append((str(proposal['id']), proposal['data']))
        return commit_decisions

    async def dump_to_file(self):
        '''
        Dump the current commit decisions to disk.
        '''
        try:
            transactions = []
            self._log.debug("ready to dump, last_commit_slot = %d",
                            self._last_commit_slot)
            for i in range(self._last_commit_slot - self._dump_interval + 1,
                           self._last_commit_slot + 1):
                status = self._status_by_slot[str(i)]
                proposal = status.commit_certificate._proposal
                transactions.append((str(proposal['id']), proposal['data']))
            self._log.debug("collect %d transactions", len(transactions))
            try:
                timestamp = time.asctime(time.localtime(proposal['timestamp']))
            except Exception as e:
                self._log.error(
                    "received invalid timestamp. replacing with current timestamp"
                )
                timestamp = time.asctime(time.localtime(time.time()))

            new_block = Block(self._blockchain.length, transactions, timestamp,
                              self._blockchain.last_block_hash())
            self._blockchain.add_block(new_block)

        except Exception as e:
            traceback.print_exc()
            print(e)

        with open("~$node_{}.blockchain".format(self._index), 'a') as f:
            self._log.debug("write block from %d to %d",
                            self._blockchain.commit_counter,
                            self._blockchain.length)
            for i in range(self._blockchain.commit_counter,
                           self._blockchain.length):
                f.write(
                    str(self._blockchain.chain[i].get_json()) +
                    '\n------------\n')
                self._blockchain.update_commit_counter()

    async def receive_ckpt_vote(self, request):
        '''
        Receive the message sent from CheckPoint.propose_vote()
        '''
        self._log.info("%d: receive checkpoint vote.", self._index)
        json_data = await request.json()

        await self._ckpt.receive_vote(json_data)
        return web.Response()

    async def get_prepare_certificates(self):
        '''
        For view change, get all prepare certificates in the valid commit interval.
        output:
            prepare_certificate_by_slot: dictionary which contains the mapping between
            each slot and its prepare_certificate if exists.

        '''
        prepare_certificate_by_slot = {}
        for i in range(self._ckpt.next_slot,
                       self._ckpt.get_commit_upperbound()):
            slot = str(i)
            if slot in self._status_by_slot:
                status = self._status_by_slot[slot]
                if status.prepare_certificate:
                    prepare_certificate_by_slot[slot] = (
                        status.prepare_certificate.to_dict())
        return prepare_certificate_by_slot

    async def _post_view_change_vote(self):
        '''
        Broadcast the view change vote messages to all the nodes.
        View change vote messages contain the current node index,
        the proposed new view number, checkpoint info, and all the
        prepare certificates between valid slots.
        '''
        view_change_vote = {
            "index": self._index,
            "view_number": self._follow_view.get_view(),
            "checkpoint": self._ckpt.get_ckpt_info(),
            "prepare_certificates": await self.get_prepare_certificates(),
            'type': MessageType.VIEW_CHANGE_VOTE
        }
        await self._post(self._nodes, MessageType.VIEW_CHANGE_VOTE,
                         view_change_vote)

    async def get_view_change_request(self, request):
        '''
        Get a view change request from the client. Broadcast the view change
        vote and all the information needed for the view change (checkpoint,
        prepare certificates) to every replica.
        input:
            request: view change request messages from client.
                json_data{
                    "type" : "view_change_request"
                }
        '''

        self._log.info("%d: receive view change request from client.",
                       self._index)
        json_data = await request.json()
        # Make sure the message is valid.
        if json_data['type'] != MessageType.VIEW_CHANGE_REQUEST:
            return web.Response()
        # Update view number by 1 and change the followed leader. In addition,
        # if receive view update message within update interval, do nothing.
        if not self._follow_view.set_view(self._follow_view.get_view() + 1):
            return web.Response()

        self._leader = self._follow_view.get_leader()
        if self._is_leader:
            self._log.info("%d is not leader anymore. View number: %d",
                           self._index, self._follow_view.get_view())
            self._is_leader = False

        self._log.debug("%d: vote for view change to %d.", self._index,
                        self._follow_view.get_view())

        await self._post_view_change_vote()

        return web.Response()

    async def receive_view_change_vote(self, request):
        '''
        Receive the vote message for view change. (1) Update the checkpoint
        if the received message carries a larger checkpoint. (2) Update the
        vote record (the node it came from and its prepare certificates).
        (3) Change the view after receiving f + 1 votes. (4) After receiving
        votes from more than 2f + 1 nodes, if this node is the leader of the
        current view, become leader and preprepare the valid slots.

        input: 
            request. After transform to json:
                json_data = {
                    "index": self._index,
                    "view_number": self._follow_view.get_view(),
                    "checkpoint":self._ckpt.get_ckpt_info(),
                    "prepared_certificates":self.get_prepare_certificates(),
                    "type": view_change_vote
                }
        '''

        json_data = await request.json()
        self._log.info("%d receive view change vote from %d", self._index,
                       json_data['index'])
        view_number = json_data['view_number']
        if view_number not in self._view_change_votes_by_view_number:
            self._view_change_votes_by_view_number[view_number] = (
                ViewChangeVotes(self._index, self._node_cnt))

        self._ckpt.update_checkpoint(json_data['checkpoint'])
        self._last_commit_slot = max(self._last_commit_slot,
                                     self._ckpt.next_slot - 1)

        votes = self._view_change_votes_by_view_number[view_number]

        votes.receive_vote(json_data)

        # Received 2f + 1 or more votes. If this node is the leader
        # in charge of the current view, become leader and propose
        # preprepare for all slots.

        if len(votes.from_nodes) >= 2 * self._f + 1:

            if self._follow_view.get_leader(
            ) == self._index and not self._is_leader:

                self._log.info("%d: Change to be leader!! view_number: %d",
                               self._index, self._follow_view.get_view())

                self._is_leader = True
                self._view.set_view(self._follow_view.get_view())

                last_certificate_slot = max(
                    [int(slot)
                     for slot in votes.prepare_certificate_by_slot] + [-1])

                # Update the next_slot!!
                # case 1: prepare_certificate != [], _next_propose_slot = last_prepare + 1
                if last_certificate_slot != -1:
                    self._next_propose_slot = last_certificate_slot + 1
                # case 2: prepare_certificate == [], _next_propose_slot = ckpt.next_slot
                else:
                    self._next_propose_slot = self._ckpt.next_slot
                self._log.debug("change next_propose_slot to %d",
                                self._next_propose_slot)
                proposal_by_slot = {}
                for i in range(self._ckpt.next_slot,
                               last_certificate_slot + 1):
                    slot = str(i)
                    if slot not in votes.prepare_certificate_by_slot:

                        self._log.debug("%d decide no_op for slot %d",
                                        self._index, int(slot))

                        proposal = {
                            'id': (-1, -1),
                            'client_url': "no_op",
                            'timestamp': "no_op",
                            'data': MessageType.NO_OP
                        }
                        proposal_by_slot[slot] = proposal
                    elif not self._status_by_slot[slot].commit_certificate:
                        proposal = votes.prepare_certificate_by_slot[
                            slot].get_proposal()
                        proposal_by_slot[slot] = proposal
                await self.fill_bubbles(proposal_by_slot)
        return web.Response()

    async def fill_bubbles(self, proposal_by_slot):
        '''
        Fill the bubbles during view change. Basically, it is a
        preprepare that assigns the proposed slots instead of using
        new slots.

        input:
            proposal_by_slot: dictionary keyed by slot whose values
            are the preprepared proposals
        '''
        self._log.info("%d: on fill bubbles.", self._index)
        self._log.debug("Number of bubbles: %d", len(proposal_by_slot))

        bubbles = {
            'leader': self._index,
            'view': self._view.get_view(),
            'proposal': proposal_by_slot,
            'type': 'preprepare'
        }

        await self._post(self._nodes, MessageType.PREPARE, bubbles)

    async def garbage_collection(self):
        '''
        Delete the statuses in self._status_by_slot whose slots are
        smaller than the checkpoint's next slot.
        '''
        await asyncio.sleep(self._sync_interval)
        delete_slots = []
        for slot in self._status_by_slot:
            if int(slot) < self._ckpt.next_slot:
                delete_slots.append(slot)
        for slot in delete_slots:
            del self._status_by_slot[slot]

        # Garbage collection for checkpoint.
        await self._ckpt.garbage_collection()

    async def show_blockchain(self, request):
        name = request.match_info.get("Anonymous")
        text = "show blockchain here "
        print('Node ' + str(self._index) + ' anything')
        return web.Response(text=text)

    async def receive_sync(self, request):
        '''
        Update the checkpoint and fill the bubble when receive sync messages.
        input:
            request: {
                'checkpoint': json_data = {
                    'next_slot': self._next_slot
                    'ckpt': json.dumps(ckpt)
                }
                'commit_certificates':commit_certificates
                    (Elements are commit_certificate.to_dict())
            }
        '''
        self._log.info("%d: on receive sync stage.", self._index)
        json_data = await request.json()

        try:
            # print(len(self._status_by_slot))
            # print(self._ckpt.next_slot, self._last_commit_slot + 1)
            # # print(len(json_data['checkpoint']))
            # print('node :' + str(self._index) +' > '+str(self._blockchain.commit_counter)+' : '+str(self._blockchain.length))
            # print()
            # print()
            self.committed_to_blockchain = False
        except Exception as e:
            traceback.print_exc()
            print(e)

        self._ckpt.update_checkpoint(json_data['checkpoint'])
        self._last_commit_slot = max(self._last_commit_slot,
                                     self._ckpt.next_slot - 1)
        # TODO: Only check bubble instead of all slots between lowerbound
        # and upperbound of the commit.

        for slot in json_data['commit_certificates']:
            # Skip those slot not qualified for update.
            if int(slot) >= self._ckpt.get_commit_upperbound() or (
                    int(slot) < self._ckpt.next_slot):
                continue

            certificate = json_data['commit_certificates'][slot]
            if slot not in self._status_by_slot:
                self._status_by_slot[slot] = Status(self._f)
                commit_certificate = Status.Certificate(View(
                    0, self._node_cnt))
                commit_certificate.dumps_from_dict(certificate)
                self._status_by_slot[
                    slot].commit_certificate = commit_certificate
            elif not self._status_by_slot[slot].commit_certificate:
                commit_certificate = Status.Certificate(View(
                    0, self._node_cnt))
                commit_certificate.dumps_from_dict(certificate)
                self._status_by_slot[
                    slot].commit_certificate = commit_certificate

        # Commit once the next slot of the last_commit_slot get commit certificate
        while (str(self._last_commit_slot + 1) in self._status_by_slot
               and self._status_by_slot[str(self._last_commit_slot +
                                            1)].commit_certificate):
            self._last_commit_slot += 1

            # When commit messages fill the next checkpoint,
            # propose a new checkpoint.
            if (self._last_commit_slot + 1) % self._checkpoint_interval == 0:
                await self._ckpt.propose_vote(self.get_commit_decisions())

                self._log.info(
                    "%d: During receive_sync, propose checkpoint with "
                    "last slot: %d. In addition, current checkpoint's next_slot is: %d",
                    self._index, self._last_commit_slot, self._ckpt.next_slot)

        await self.dump_to_file()

        return web.Response()

    async def synchronize(self):
        '''
        Broadcast the current checkpoint and all the commit certificates
        between the next slot of the checkpoint and the commit upperbound.

        output:
            json_data = {
                'checkpoint': json_data = {
                    'next_slot': self._next_slot,
                    'ckpt': json.dumps(ckpt)
                },
                'commit_certificates': commit_certificates
                    (Elements are commit_certificate.to_dict())
            }
        '''
        # TODO: Only send bubble slot message instead of all.
        while True:
            await asyncio.sleep(self._sync_interval)
            commit_certificates = {}
            for i in range(self._ckpt.next_slot,
                           self._ckpt.get_commit_upperbound()):
                slot = str(i)
                if (slot in self._status_by_slot) and (
                        self._status_by_slot[slot].commit_certificate):
                    status = self._status_by_slot[slot]
                    commit_certificates[
                        slot] = status.commit_certificate.to_dict()
            json_data = {
                'checkpoint': self._ckpt.get_ckpt_info(),
                'commit_certificates': commit_certificates
            }
            await self._post(self._nodes, MessageType.RECEIVE_SYNC, json_data)
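
# Illustrative shape of the payload exchanged by synchronize()/receive_sync()
# above (field names follow the docstrings; all values here are made up):
example_sync_payload = {
    'checkpoint': {
        'next_slot': 100,            # first slot not yet covered by the checkpoint
        'ckpt': '["decision-99"]',   # json.dumps of the checkpointed decisions
    },
    'commit_certificates': {
        # slot (stringified) -> commit_certificate.to_dict(); contents illustrative
        '100': {'view': 0, 'proposal': {'id': 'client-0', 'data': 'tx'}},
    },
}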
Example #13
class ExperimentDesign(object):
    """
    run experiments with pre-defined pipeline
    """
    def __init__(self):
        self.settings = Option()
        self.checkpoint = None
        self.data_loader = None
        self.model = None

        self.trainer = None
        self.seg_opt_state = None
        self.fc_opt_state = None
        self.aux_fc_state = None
        self.start_epoch = 0

        self.model_analyse = None

        self.visualize = vs.Visualization(self.settings.save_path)
        self.logger = vs.Logger(self.settings.save_path)

        self.prepare()

    def prepare(self):
        """
        preparing experiments
        """
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._set_checkpoint()
        self._set_trainer()

    def _set_gpu(self):
        # set torch seed
        # init random seed
        torch.manual_seed(self.settings.manualSeed)
        torch.cuda.manual_seed(self.settings.manualSeed)
        assert self.settings.GPU <= torch.cuda.device_count() - 1, \
            "Invalid GPU ID"
        torch.cuda.set_device(self.settings.GPU)
        cudnn.benchmark = True

    def _set_dataloader(self):
        # create data loader
        self.data_loader = DataLoader(dataset=self.settings.dataset,
                                      batch_size=self.settings.batchSize,
                                      data_path=self.settings.dataPath,
                                      n_threads=self.settings.nThreads,
                                      ten_crop=self.settings.tenCrop)

    def _set_checkpoint(self):
        assert self.model is not None, "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path)
        if self.settings.retrain is not None:
            model_state = self.checkpoint.load_model(self.settings.retrain)
            self.model = self.checkpoint.load_state(self.model, model_state)

        if self.settings.resume is not None:
            check_point_params = torch.load(self.settings.resume)
            model_state = check_point_params["model"]
            self.seg_opt_state = check_point_params["seg_opt"]
            self.fc_opt_state = check_point_params["fc_opt"]
            self.aux_fc_state = check_point_params["aux_fc"]
            self.model = self.checkpoint.load_state(self.model, model_state)
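            # this checkpoint format stores no epoch; resuming assumes training
            # previously stopped at epoch 90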
            self.start_epoch = 90

    def _set_model(self):
        print("netType:", self.settings.netType)
        if self.settings.dataset in ["cifar10", "cifar100"]:
            if self.settings.netType == "DARTSNet":
                genotype = md.genotypes.DARTS
                self.model = md.DARTSNet(self.settings.init_channels,
                                         self.settings.nClasses,
                                         self.settings.layers,
                                         self.settings.auxiliary, genotype)
            elif self.settings.netType == "PreResNet":
                self.model = md.PreResNet(depth=self.settings.depth,
                                          num_classes=self.settings.nClasses,
                                          wide_factor=self.settings.wideFactor)
            elif self.settings.netType == "CifarResNeXt":
                self.model = md.CifarResNeXt(self.settings.cardinality,
                                             self.settings.depth,
                                             self.settings.nClasses,
                                             self.settings.base_width,
                                             self.settings.widen_factor)

            elif self.settings.netType == "ResNet":
                self.model = md.ResNet(self.settings.depth,
                                       self.settings.nClasses)
            else:
                assert False, "use %s data while network is %s" % (
                    self.settings.dataset, self.settings.netType)
        else:
            assert False, "unsupported data set: " + self.settings.dataset

    def _set_trainer(self):
        # set lr master
        lr_master = utils.LRPolicy(self.settings.lr, self.settings.nEpochs,
                                   self.settings.lrPolicy)
        params_dict = {
            'power': self.settings.power,
            'step': self.settings.step,
            'end_lr': self.settings.endlr,
            'decay_rate': self.settings.decayRate
        }

        lr_master.set_params(params_dict=params_dict)
        # set trainer
        train_loader, test_loader = self.data_loader.getloader()
        self.trainer = Trainer(model=self.model,
                               train_loader=train_loader,
                               test_loader=test_loader,
                               lr_master=lr_master,
                               settings=self.settings,
                               logger=self.logger)
        if self.seg_opt_state is not None:
            self.trainer.resume(aux_fc_state=self.aux_fc_state,
                                seg_opt_state=self.seg_opt_state,
                                fc_opt_state=self.fc_opt_state)

    def _save_best_model(self, model, aux_fc):

        check_point_params = {}
        if isinstance(model, nn.DataParallel):
            check_point_params["model"] = model.module.state_dict()
        else:
            check_point_params["model"] = model.state_dict()
        aux_fc_state = []
        for i in range(len(aux_fc)):
            if isinstance(aux_fc[i], nn.DataParallel):
                aux_fc_state.append(aux_fc[i].module.state_dict())
            else:
                aux_fc_state.append(aux_fc[i].state_dict())
        check_point_params["aux_fc"] = aux_fc_state
        torch.save(
            check_point_params,
            os.path.join(self.checkpoint.save_path,
                         "best_model_with_AuxFC.pth"))

    def _save_checkpoint(self,
                         model,
                         seg_optimizers,
                         fc_optimizers,
                         aux_fc,
                         index=0):
        check_point_params = {}
        if isinstance(model, nn.DataParallel):
            check_point_params["model"] = model.module.state_dict()
        else:
            check_point_params["model"] = model.state_dict()
        seg_opt_state = []
        fc_opt_state = []
        aux_fc_state = []
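        # aux_fc is assumed to match fc_optimizers in length: their states are
        # collected in the same loop below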
        for i in range(len(seg_optimizers)):
            seg_opt_state.append(seg_optimizers[i].state_dict())
        for i in range(len(fc_optimizers)):
            fc_opt_state.append(fc_optimizers[i].state_dict())
            if isinstance(aux_fc[i], nn.DataParallel):
                aux_fc_state.append(aux_fc[i].module.state_dict())
            else:
                aux_fc_state.append(aux_fc[i].state_dict())

        check_point_params["seg_opt"] = seg_opt_state
        check_point_params["fc_opt"] = fc_opt_state
        check_point_params["aux_fc"] = aux_fc_state

        torch.save(
            check_point_params,
            os.path.join(self.checkpoint.save_path,
                         "checkpoint_%03d.pth" % index))

    def run(self, run_count=0):
        """
        training and testing
        """

        best_top1 = 100
        best_top5 = 100
        start_time = time.time()
        if self.settings.resume is not None or self.settings.retrain is not None:
            self.trainer.test(0)
        for epoch in range(self.start_epoch, self.settings.nEpochs):
            self.start_epoch = 0

            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            test_error, test_loss, test5_error = self.trainer.test(epoch=epoch)

            # write and print result
            if isinstance(train_error, np.ndarray):
                log_str = "%d\t" % (epoch)
                for i in range(len(train_error)):
                    log_str += "%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                        train_error[i],
                        train_loss[i],
                        test_error[i],
                        test_loss[i],
                        train5_error[i],
                        test5_error[i],
                    )

                best_flag = False
                if best_top1 >= test_error[-1]:
                    best_top1 = test_error[-1]
                    best_top5 = test5_error[-1]
                    best_flag = True

            else:
                log_str = "%d\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t%.4f\t" % (
                    epoch, train_error, train_loss, test_error, test_loss,
                    train5_error, test5_error)
                best_flag = False
                if best_top1 >= test_error:
                    best_top1 = test_error
                    best_top5 = test5_error
                    best_flag = True

            self.visualize.write_log(log_str)

            if best_flag:
                self.checkpoint.save_model(self.trainer.model,
                                           best_flag=best_flag)
                self._save_best_model(self.trainer.model, self.trainer.auxfc)

                print(colored(
                    "# %d ==>Best Result is: Top1 Error: %f, Top5 Error: %f\n"
                    % (run_count, best_top1, best_top5), "red"))
            else:
                print(colored(
                    "# %d ==>Best Result is: Top1 Error: %f, Top5 Error: %f\n"
                    % (run_count, best_top1, best_top5), "blue"))

            self._save_checkpoint(self.trainer.model,
                                  self.trainer.seg_optimizer,
                                  self.trainer.fc_optimizer,
                                  self.trainer.auxfc)

        end_time = time.time()

        # compute cost time
        time_interval = end_time - start_time
        t_string = "Running Time is: " + \
            str(datetime.timedelta(seconds=time_interval)) + "\n"
        print(t_string)
        # write cost time to file
        self.visualize.write_readme(t_string)

        # save experimental results
        self.visualize.write_readme(
            "Best Result of all is: Top1 Error: %f, Top5 Error: %f\n" %
            (best_top1, best_top5))
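
# A minimal sketch of reading back the checkpoint written by _save_checkpoint()
# above (key names from the code; the file name assumes index 0):
import torch

check_point_params = torch.load("checkpoint_000.pth")
model_state = check_point_params["model"]        # model state_dict
seg_opt_state = check_point_params["seg_opt"]    # list of segment-optimizer states
fc_opt_state = check_point_params["fc_opt"]      # list of fc-optimizer states
aux_fc_state = check_point_params["aux_fc"]      # list of auxiliary-fc states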
Example #14
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
train_loader = torch.utils.data.DataLoader(FaceDataset(train_config.annoPath,
                                                       transform=transform,
                                                       is_train=True),
                                           batch_size=train_config.batchSize,
                                           shuffle=True,
                                           **kwargs)

# Set model
model = ONet(config.NUM_LANDMARKS)
model = model.to(device)

# Set checkpoint
checkpoint = CheckPoint(train_config.save_path)

# Set optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=train_config.lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                 milestones=train_config.step,
                                                 gamma=0.1)

# Set trainer
logger = Logger(train_config.save_path)
trainer = ONetTrainer(train_config.lr, train_loader, model, optimizer,
                      scheduler, logger, device)

for epoch in range(1, train_config.nEpochs + 1):
    trainer.train(epoch)
    checkpoint.save_model(model, index=epoch)

Example #15
class ExperimentDesign:
    def __init__(self, options=None):
        self.settings = options or Option()
        self.checkpoint = None
        self.train_loader = None
        self.test_loader = None
        self.model = None

        self.optimizer_state = None
        self.trainer = None
        self.start_epoch = 0
        self.test_input = None
        self.model_analyse = None

        os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
        os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.visible_devices

        self.settings.set_save_path()
        self.logger = self.set_logger()
        self.settings.paramscheck(self.logger)
        self.visualize = vs.Visualization(self.settings.save_path, self.logger)
        self.tensorboard_logger = vs.Logger(self.settings.save_path)

        self.prepare()

    def set_logger(self):
        logger = logging.getLogger('sphereface')
        file_formatter = logging.Formatter(
            '%(asctime)s %(levelname)s: %(message)s')
        console_formatter = logging.Formatter('%(message)s')
        # file log
        file_handler = logging.FileHandler(
            os.path.join(self.settings.save_path, "train_test.log"))
        file_handler.setFormatter(file_formatter)

        # console log
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

        logger.setLevel(logging.INFO)
        return logger

    def prepare(self):
        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._set_checkpoint()
        self._set_trainer()
        self._draw_net()

    def _set_gpu(self):
        # set torch seed
        # init random seed
        torch.manual_seed(self.settings.manualSeed)
        torch.cuda.manual_seed(self.settings.manualSeed)
        assert self.settings.GPU <= torch.cuda.device_count() - 1, \
            "Invalid GPU ID"
        torch.cuda.set_device(self.settings.GPU)
        cudnn.benchmark = True

    def _set_dataloader(self):
        # create data loader
        data_loader = DataLoader(dataset=self.settings.dataset,
                                 batch_size=self.settings.batchSize,
                                 data_path=self.settings.dataPath,
                                 n_threads=self.settings.nThreads,
                                 ten_crop=self.settings.tenCrop,
                                 logger=self.logger)
        self.train_loader, self.test_loader = data_loader.getloader()

    def _set_checkpoint(self):
        assert self.model is not None, "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path, self.logger)
        if self.settings.retrain is not None:
            model_state = self.checkpoint.load_model(self.settings.retrain)
            self.model = self.checkpoint.load_state(self.model, model_state)

        if self.settings.resume is not None:
            model_state, optimizer_state, epoch = self.checkpoint.load_checkpoint(
                self.settings.resume)
            self.model = self.checkpoint.load_state(self.model, model_state)
            self.start_epoch = epoch
            self.optimizer_state = optimizer_state

    def _set_model(self):
        if self.settings.dataset in ["sphere", "sphere_large"]:
            self.test_input = Variable(torch.randn(1, 3, 112, 96).cuda())
            if self.settings.netType == "SphereNet":
                self.model = md.SphereNet(
                    depth=self.settings.depth,
                    num_features=self.settings.featureDim)

            elif self.settings.netType == "SphereNIN":
                self.model = md.SphereNIN(
                    num_features=self.settings.featureDim)

            elif self.settings.netType == "SphereMobileNet_v2":
                self.model = md.SphereMobleNet_v2(
                    num_features=self.settings.featureDim)
            else:
                assert False, "use %s data while network is %s" % (
                    self.settings.dataset, self.settings.netType)
        else:
            assert False, "unsupport data set: " + self.settings.dataset

    def _set_trainer(self):
        lr_master = utils.LRPolicy(self.settings.lr, self.settings.nEpochs,
                                   self.settings.lrPolicy)
        params_dict = {
            'power': self.settings.power,
            'step': self.settings.step,
            'end_lr': self.settings.endlr,
            'decay_rate': self.settings.decayRate
        }

        lr_master.set_params(params_dict=params_dict)

        self.trainer = Trainer(model=self.model,
                               train_loader=self.train_loader,
                               test_loader=self.test_loader,
                               lr_master=lr_master,
                               settings=self.settings,
                               logger=self.logger,
                               tensorboard_logger=self.tensorboard_logger,
                               optimizer_state=self.optimizer_state)

    def _draw_net(self):
        # visualize model
        if self.settings.drawNetwork:
            rand_output, _ = self.trainer.forward(self.test_input)
            self.visualize.save_network(rand_output)
            self.logger.info("|===>Draw network done!")

        self.visualize.write_settings(self.settings)

    def _model_analyse(self, model):
        # analyse model
        model_analyse = utils.ModelAnalyse(model, self.visualize)
        params_num = model_analyse.params_count()
        zero_num = model_analyse.zero_count()
        zero_rate = zero_num * 1.0 / params_num
        self.logger.info("zero rate is: {}".format(zero_rate))

        # save analyse result to file
        self.visualize.write_readme(
            "Number of parameters is: %d, number of zeros is: %d, zero rate is: %f"
            % (params_num, zero_num, zero_rate))

        # model_analyse.flops_compute(self.test_input)
        model_analyse.madds_compute(self.test_input)

    def run(self, run_count=0):
        self.logger.info("|===>Start training")
        best_top1 = 100
        start_time = time.time()
        self.trainer.test()
        for epoch in range(self.start_epoch, self.settings.nEpochs):
            if self.trainer.iteration >= self.settings.nIters:
                break
            self.start_epoch = 0
            # training and testing
            train_error, train_loss, train5_error = self.trainer.train(
                epoch=epoch)
            acc_mean, acc_std, acc_all = self.trainer.test()

            test_error_mean = 100 - acc_mean * 100
            # write and print result
            log_str = "{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t".format(
                epoch, train_error, train_loss, train5_error, acc_mean,
                acc_std)
            for acc in acc_all:
                log_str += "%.4f\t" % acc

            self.visualize.write_log(log_str)
            best_flag = False
            if best_top1 >= test_error_mean:
                best_top1 = test_error_mean
                best_flag = True
            self.logger.info(
                "# {:d} ==>Best Result is: Top1 Error: {:f}\n".format(
                    run_count, best_top1))

            self.checkpoint.save_checkpoint(self.model, self.trainer.optimizer,
                                            epoch)

            if best_flag:
                self.checkpoint.save_model(self.model, best_flag=best_flag)

            if (epoch + 1) % self.settings.drawInterval == 0:
                self.visualize.draw_curves()

        end_time = time.time()

        if isinstance(self.model, nn.DataParallel):
            self.model = self.model.module
        # draw experimental curves
        self.visualize.draw_curves()

        # compute cost time
        time_interval = end_time - start_time
        t_string = "Running Time is: " + \
                   str(datetime.timedelta(seconds=time_interval)) + "\n"
        self.logger.info(t_string)
        # write cost time to file
        self.visualize.write_readme(t_string)
        # analyse model
        self._model_analyse(self.model)

        return best_top1
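
# A stand-alone sketch of the zero-rate statistic reported by _model_analyse()
# above, assuming only torch (ModelAnalyse internals are not shown here):
import torch.nn as nn

def zero_rate(model: nn.Module) -> float:
    params_num = sum(p.numel() for p in model.parameters())
    zero_num = sum(p.eq(0).sum().item() for p in model.parameters())
    return zero_num * 1.0 / params_num

# e.g. a pruned network with half its weights masked to zero returns ~0.5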
Example #16
auc_last = 0
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_model,
    mode="min",
    factor=0.1,
    patience=0,
    verbose=True,
    min_lr=1e-8,
    threshold=0.0001,
    threshold_mode='abs')
path_ckpt = '{}/ckpt/{}'.format(ROOT_DIR, exp_name)
# learning checkpointer
ckpter = CheckPoint(model=model,
                    optimizer=optimizer_model,
                    path=path_ckpt,
                    prefix=run_name,
                    interval=1,
                    save_num=n_save_epoch,
                    loss0=loss0)
ckpter_lr = CheckPoint(model=logisticReg,
                       optimizer=optimizer_model,
                       path=path_ckpt,
                       prefix=run_name + '_lr',
                       interval=1,
                       save_num=n_save_epoch,
                       loss0=loss0)
ckpter_auc = CheckPoint(model=model,
                        optimizer=optimizer_model,
                        path=path_ckpt,
                        prefix=run_name,
                        interval=1,
                        save_num=n_save_epoch,
                        loss0=loss0)

Example #17
class Experiment(object):
    """
    run experiments with pre-defined pipeline
    """
    def __init__(self, options=None, conf_path=None):
        self.settings = options or Option(conf_path)
        self.checkpoint = None
        self.train_loader = None
        self.val_loader = None
        self.pruned_model = None
        self.network_wise_trainer = None
        self.optimizer_state = None

        os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.gpu

        self.settings.set_save_path()
        self.write_settings()
        self.logger = self.set_logger()
        self.tensorboard_logger = TensorboardLogger(self.settings.save_path)

        self.epoch = 0

        self.prepare()

    def write_settings(self):
        """
        save experiment settings to a file
        """

        with open(os.path.join(self.settings.save_path, "settings.log"),
                  "w") as f:
            for k, v in self.settings.__dict__.items():
                f.write(str(k) + ": " + str(v) + "\n")

    def set_logger(self):
        """
        initialize logger
        """

        logger = logging.getLogger('channel_selection')
        file_formatter = logging.Formatter(
            '%(asctime)s %(levelname)s: %(message)s')
        console_formatter = logging.Formatter('%(message)s')
        # file log
        file_handler = logging.FileHandler(
            os.path.join(self.settings.save_path, "train_test.log"))
        file_handler.setFormatter(file_formatter)

        # console log
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

        logger.setLevel(logging.INFO)
        return logger

    def prepare(self):
        """
        preparing experiments
        """

        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._set_checkpoint()
        self._set_trainer()

    def _set_gpu(self):
        """
        initialize the seed of random number generator
        """

        # set torch seed
        # init random seed
        torch.manual_seed(self.settings.seed)
        torch.cuda.manual_seed(self.settings.seed)
        torch.cuda.set_device(0)
        cudnn.benchmark = True

    def _set_dataloader(self):
        """
        create train loader and validation loader for channel pruning
        """

        if self.settings.dataset == 'cifar10':
            data_root = os.path.join(self.settings.data_path, "cifar")

            norm_mean = [0.49139968, 0.48215827, 0.44653124]
            norm_std = [0.24703233, 0.24348505, 0.26158768]
            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(norm_mean, norm_std)
            ])
            val_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(norm_mean, norm_std)
            ])

            train_dataset = datasets.CIFAR10(root=data_root,
                                             train=True,
                                             transform=train_transform,
                                             download=True)
            val_dataset = datasets.CIFAR10(root=data_root,
                                           train=False,
                                           transform=val_transform)

            self.train_loader = torch.utils.data.DataLoader(
                dataset=train_dataset,
                batch_size=self.settings.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=self.settings.n_threads)
            self.val_loader = torch.utils.data.DataLoader(
                dataset=val_dataset,
                batch_size=self.settings.batch_size,
                shuffle=False,
                pin_memory=True,
                num_workers=self.settings.n_threads)
        elif self.settings.dataset == 'imagenet':
            dataset_path = os.path.join(self.settings.data_path, "imagenet")
            traindir = os.path.join(dataset_path, "train")
            valdir = os.path.join(dataset_path, 'val')
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])

            self.train_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(
                    traindir,
                    transforms.Compose([
                        transforms.RandomResizedCrop(224),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        normalize,
                    ])),
                batch_size=self.settings.batch_size,
                shuffle=True,
                num_workers=self.settings.n_threads,
                pin_memory=True)

            self.val_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(
                    valdir,
                    transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize,
                    ])),
                batch_size=self.settings.batch_size,
                shuffle=False,
                num_workers=self.settings.n_threads,
                pin_memory=True)

    def _set_model(self):
        """
        get model
        """

        if self.settings.dataset in ["cifar10", "cifar100"]:
            if self.settings.net_type == "preresnet":
                self.pruned_model = md.PreResNet(
                    depth=self.settings.depth,
                    num_classes=self.settings.n_classes)
            else:
                assert False, "use {} data while network is {}".format(
                    self.settings.dataset, self.settings.net_type)

        elif self.settings.dataset in ["imagenet", "imagenet_mio"]:
            if self.settings.net_type == "resnet":
                self.pruned_model = md.ResNet(
                    depth=self.settings.depth,
                    num_classes=self.settings.n_classes)
            else:
                assert False, "use {} data while network is {}".format(
                    self.settings.dataset, self.settings.net_type)

        else:
            assert False, "unsupported data set: {}".format(
                self.settings.dataset)

        # replace the conv layer in resnet with mask_conv
        if self.settings.net_type in ["preresnet", "resnet"]:
            for module in self.pruned_model.modules():
                if isinstance(module, (PreBasicBlock, BasicBlock, Bottleneck)):
                    # replace conv2
                    temp_conv = MaskConv2d(
                        in_channels=module.conv2.in_channels,
                        out_channels=module.conv2.out_channels,
                        kernel_size=module.conv2.kernel_size,
                        stride=module.conv2.stride,
                        padding=module.conv2.padding,
                        bias=(module.conv2.bias is not None))

                    temp_conv.weight.data.copy_(module.conv2.weight.data)
                    if module.conv2.bias is not None:
                        temp_conv.bias.data.copy_(module.conv2.bias.data)
                    module.conv2 = temp_conv

                    if isinstance(module, (Bottleneck)):
                        # replace conv3
                        temp_conv = MaskConv2d(
                            in_channels=module.conv3.in_channels,
                            out_channels=module.conv3.out_channels,
                            kernel_size=module.conv3.kernel_size,
                            stride=module.conv3.stride,
                            padding=module.conv3.padding,
                            bias=(module.conv3.bias is not None))

                        temp_conv.weight.data.copy_(module.conv3.weight.data)
                        if module.conv3.bias is not None:
                            temp_conv.bias.data.copy_(module.conv3.bias.data)
                        module.conv3 = temp_conv

    def _set_checkpoint(self):
        """
        load pre-trained model or resume checkpoint
        """

        assert self.pruned_model is not None, "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path, self.logger)
        self._load_pretrained()
        self._load_resume()

    def _load_pretrained(self):
        """
        load pre-trained model
        """

        if self.settings.retrain is not None:
            check_point_params = torch.load(self.settings.retrain)
            model_state = check_point_params["pruned_model"]
            self.pruned_model = self.checkpoint.load_state(
                self.pruned_model, model_state)
            self.logger.info("|===>load restrain file: {}".format(
                self.settings.retrain))

    def _load_resume(self):
        """
        load resume checkpoint
        """

        if self.settings.resume is not None:
            check_point_params = torch.load(self.settings.resume)
            pruned_model_state = check_point_params["pruned_model"]
            self.optimizer_state = check_point_params["optimizer_state"]
            self.epoch = check_point_params["epoch"]
            self.pruned_model = self.checkpoint.load_state(
                self.pruned_model, pruned_model_state)
            self.logger.info("|===>load resume file: {}".format(
                self.settings.resume))

    def _set_trainer(self):
        """
        initialize network-wise trainer
        """

        self.network_wise_trainer = NetworkWiseTrainer(
            pruned_model=self.pruned_model,
            train_loader=self.train_loader,
            val_loader=self.val_loader,
            settings=self.settings,
            logger=self.logger,
            tensorboard_logger=self.tensorboard_logger)

    def pruning(self):
        """
        prune channels
        """

        self.logger.info(self.pruned_model)
        self.network_wise_trainer.val(0)

        if self.settings.net_type in ["preresnet", "resnet"]:
            model_prune = ResModelPrune(model=self.pruned_model,
                                        net_type=self.settings.net_type,
                                        depth=self.settings.depth)
        else:
            assert False, "unsupport net_type: {}".format(
                self.settings.net_type)

        model_prune.run()
        self.network_wise_trainer.update_model(model_prune.model,
                                               self.optimizer_state)

        self.network_wise_trainer.val(0)
        self.logger.info(self.pruned_model)

    def fine_tuning(self):
        """
        conduct network-wise fine-tuning
        """

        best_top1 = 100
        best_top5 = 100

        start_epoch = 0
        if self.epoch != 0:
            start_epoch = self.epoch + 1
            self.epoch = 0

        self.network_wise_trainer.val(0)

        for epoch in range(start_epoch, self.settings.network_wise_n_epochs):
            train_error, train_loss, train5_error = self.network_wise_trainer.train(
                epoch)
            val_error, val_loss, val5_error = self.network_wise_trainer.val(
                epoch)

            for module in self.pruned_model.modules():
                if isinstance(module, MaskConv2d):
                    print(module.pruned_weight[0, :, 0, 0].eq(0).sum())

            # write and print result
            best_flag = False
            if best_top1 >= val_error:
                best_top1 = val_error
                best_top5 = val5_error
                best_flag = True

            if best_flag:
                self.checkpoint.save_network_wise_fine_tune_model(
                    self.pruned_model, best_flag)

            self.logger.info(
                "|===>Best Result is: Top1 Error: {:f}, Top5 Error: {:f}\n".
                format(best_top1, best_top5))
            self.logger.info(
                "|==>Best Result is: Top1 Accuracy: {:f}, Top5 Accuracy: {:f}\n"
                .format(100 - best_top1, 100 - best_top5))

            if self.settings.dataset in ["imagenet"]:
                self.checkpoint.save_network_wise_fine_tune_checkpoint(
                    self.pruned_model, self.network_wise_trainer.optimizer,
                    epoch, epoch)
            else:
                self.checkpoint.save_network_wise_fine_tune_checkpoint(
                    self.pruned_model, self.network_wise_trainer.optimizer,
                    epoch)
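
# A minimal usage sketch for the Experiment pipeline above (the conf_path
# value is illustrative; Option and the model code come from this project):
if __name__ == "__main__":
    experiment = Experiment(conf_path="configs/preresnet56_cifar10.hocon")
    experiment.pruning()      # replace convs with MaskConv2d and prune channels
    experiment.fine_tuning()  # network-wise fine-tuning of the pruned network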
Example #18
File: train.py Project: HaoKun-Li/HCCR
                                               shuffle=True,
                                               drop_last=True,
                                               **kwargs)

    print(len(train_loader.dataset))
    print(len(valid_loader.dataset))

    # Set model
    model = FastCNN()
    para = sum([np.prod(list(p.size())) for p in model.parameters()])
    print('Model {} : params: {:4f}M'.format(model._get_name(),
                                             para * 4 / 1000 / 1000))
    model = model.to(device)

    # Set checkpoint
    checkpoint = CheckPoint(config.save_path)

    # Set optimizer
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=config.lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=config.step,
                                                     gamma=0.1)

    # Set trainer
    logger = Logger(config.save_path)
    trainer = FastCNNTrainer(config.lr, train_loader, valid_loader, model,
                             optimizer, scheduler, logger, device)

    print(model)
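
    # The snippet is cut off here; a plausible continuation, mirroring the
    # other training scripts in this collection (epoch range and save-call
    # form are assumed):
    for epoch in range(1, config.nEpochs + 1):
        trainer.train(epoch)
        checkpoint.save_model(model, index=epoch)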
Example #19
        last_epoch = int(reporter.select_last(run=run_name).last_epoch)
        loss0 = reporter.select_last(run=run_name).last_loss
        loss0 = float(loss0[:-4])
        model.load_state_dict(
            torch.load(last_model_filename)['model_state_dict'])
    else:
        last_epoch = -1
        loss0 = 0

    optimizer_model = torch.optim.Adam(model.parameters(), lr=lr)
    path_ckpt = '{}/ckpt/{}'.format(ROOT_DIR, triplet_method)
    # learning embedding checkpointer.
    ckpter = CheckPoint(model=model,
                        optimizer=optimizer_model,
                        path=path_ckpt,
                        prefix=run_name,
                        interval=1,
                        save_num=n_save_epoch,
                        loss0=loss0)
    ckpter_v2 = CheckPoint(model=model,
                           optimizer=optimizer_model,
                           path=path_ckpt,
                           prefix='X' + run_name,
                           interval=1,
                           save_num=n_save_epoch,
                           loss0=loss0)
    train_hist = History(name='train_hist' + run_name)
    validation_hist = History(name='validation_hist' + run_name)
    #  --------------------------------------------------------------------------------------
    # Computing metrics on validation set before starting training
    #  --------------------------------------------------------------------------------------
Example #20
device = torch.device("cuda" if use_cuda else "cpu")
torch.backends.cudnn.benchmark = True

# Set dataloader
kwargs = {'num_workers': config.nThreads, 'pin_memory': True} if use_cuda else {}
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
train_loader = torch.utils.data.DataLoader(
    FaceDataset(config.annoPath, transform=transform, is_train=True), batch_size=config.batchSize, shuffle=True, **kwargs)

# Set model
model = ONet()
model = model.to(device)

# Set checkpoint
checkpoint = CheckPoint(config.save_path)

# Set optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config.step, gamma=0.1)

# Set trainer
logger = Logger(config.save_path)
trainer = ONetTrainer(config.lr, train_loader, model, optimizer, scheduler, logger, device)

for epoch in range(1, config.nEpochs + 1):
    trainer.train(epoch)
    checkpoint.save_model(model, index=epoch)
Example #21
import os
import logging
import torch
import numpy as np
from dataloader import create
from checkpoint import CheckPoint, update_history
from models.init import setup
from train import Trainer
from option import args
import utility

if __name__ == "__main__":
    torch.manual_seed(0)
    np.random.seed(0)

    loaders = create(args)
    check_p, optim_state = CheckPoint.latest(args)
    model = setup(args, check_p)
    trainer = Trainer(model, args, optim_state)

    start_epoch = check_p['epoch'] if check_p else args.start_epoch

    if args.val_only:
        results = trainer.test(0, loaders[1], 'val')
        exit(0)

    for epoch in range(start_epoch, args.n_epochs):
        train_loss = trainer.train(epoch, loaders[0], 'train')
        update_history(args, epoch + 1, train_loss, 'train')

        if (epoch + 1) % args.save_interval == 0:
            print('\n\n===== Epoch {} saving checkpoint ====='.format(epoch + 1))
Example #22
File: train.py Project: HaoKun-Li/HCCR
    para_1 = sum([np.prod(list(p.size())) for p in model_1.parameters()])
    para_2 = sum([np.prod(list(p.size())) for p in model_2.parameters()])
    para_3 = sum([np.prod(list(p.size())) for p in model_3.parameters()])
    print('Model_1 {} : params: {:4f}M'.format(model_1._get_name(),
                                               para_1 * 4 / 1000 / 1000))
    print('Model_2 {} : params: {:4f}M'.format(model_2._get_name(),
                                               para_2 * 4 / 1000 / 1000))
    print('Model_3 {} : params: {:4f}M'.format(model_3._get_name(),
                                               para_3 * 4 / 1000 / 1000))
    model_1 = model_1.to(device)
    model_2 = model_2.to(device)
    model_3 = model_3.to(device)

    # Set checkpoint
    checkpoint = CheckPoint(config.save_path)

    # Set optimizer
    optimizer_1 = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                          model_1.parameters()),
                                   lr=0.0001)
    scheduler_1 = torch.optim.lr_scheduler.MultiStepLR(optimizer_1,
                                                       milestones=config.step,
                                                       gamma=0.1)

    optimizer_2 = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                          model_2.parameters()),
                                   lr=0.01)
    scheduler_2 = torch.optim.lr_scheduler.MultiStepLR(optimizer_2,
                                                       milestones=config.step,
                                                       gamma=0.1)
Example #23
                                               batch_size=config.batchSize,
                                               shuffle=True,
                                               **kwargs)

    print(len(train_loader.dataset))
    print(len(valid_loader.dataset))

    # Set model
    model = YangNet()
    para = sum([np.prod(list(p.size())) for p in model.parameters()])
    print('Model {} : params: {:4f}M'.format(model._get_name(),
                                             para * 4 / 1000 / 1000))
    model = model.to(device)

    # Set checkpoint
    checkpoint = CheckPoint(config.save_path)

    # Set optimizer
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=config.lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=config.step,
                                                     gamma=0.1)

    # Set trainer
    logger = Logger(config.save_path)
    trainer = AlexNetTrainer(config.lr, train_loader, valid_loader, model,
                             optimizer, scheduler, logger, device)

    print(model)
Example #24
class Experiment(object):
    """
    run experiments with pre-defined pipeline
    """
    def __init__(self, options=None, conf_path=None):
        self.settings = options or Option(conf_path)
        self.checkpoint = None
        self.train_loader = None
        self.val_loader = None
        self.ori_model = None
        self.pruned_model = None
        self.segment_wise_trainer = None

        self.aux_fc_state = None
        self.aux_fc_opt_state = None
        self.seg_opt_state = None
        self.current_pivot_index = None
        self.is_segment_wise_finetune = False
        self.is_channel_selection = False

        self.epoch = 0

        self.feature_cache_origin = {}
        self.feature_cache_pruned = {}

        os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.gpu

        self.settings.set_save_path()
        self.write_settings()
        self.logger = self.set_logger()
        self.tensorboard_logger = TensorboardLogger(self.settings.save_path)

        self.prepare()

    def write_settings(self):
        """
        save experimental settings to a file
        """

        with open(os.path.join(self.settings.save_path, "settings.log"),
                  "w") as f:
            for k, v in self.settings.__dict__.items():
                f.write(str(k) + ": " + str(v) + "\n")

    def set_logger(self):
        """
        initialize logger
        """
        logger = logging.getLogger('channel_selection')
        file_formatter = logging.Formatter(
            '%(asctime)s %(levelname)s: %(message)s')
        console_formatter = logging.Formatter('%(message)s')
        # file log
        file_handler = logging.FileHandler(
            os.path.join(self.settings.save_path, "train_test.log"))
        file_handler.setFormatter(file_formatter)

        # console log
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

        logger.setLevel(logging.INFO)
        return logger

    def prepare(self):
        """
        preparing experiments
        """

        self._set_gpu()
        self._set_dataloader()
        self._set_model()
        self._cal_pivot()
        self._set_checkpoint()
        self._set_trainer()

    def _set_gpu(self):
        """
        initialize the seed of random number generator
        """

        # set torch seed
        # init random seed
        torch.manual_seed(self.settings.seed)
        torch.cuda.manual_seed(self.settings.seed)
        torch.cuda.set_device(0)
        cudnn.benchmark = True

    def _set_dataloader(self):
        """
        create train loader and validation loader for channel pruning
        """

        if self.settings.dataset == 'cifar10':
            data_root = os.path.join(self.settings.data_path, "cifar")

            norm_mean = [0.49139968, 0.48215827, 0.44653124]
            norm_std = [0.24703233, 0.24348505, 0.26158768]
            train_transform = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(norm_mean, norm_std)
            ])
            val_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(norm_mean, norm_std)
            ])

            train_dataset = datasets.CIFAR10(root=data_root,
                                             train=True,
                                             transform=train_transform,
                                             download=True)
            val_dataset = datasets.CIFAR10(root=data_root,
                                           train=False,
                                           transform=val_transform)

            self.train_loader = torch.utils.data.DataLoader(
                dataset=train_dataset,
                batch_size=self.settings.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=self.settings.n_threads)
            self.val_loader = torch.utils.data.DataLoader(
                dataset=val_dataset,
                batch_size=self.settings.batch_size,
                shuffle=False,
                pin_memory=True,
                num_workers=self.settings.n_threads)
        elif self.settings.dataset == 'imagenet':
            dataset_path = os.path.join(self.settings.data_path, "imagenet")
            traindir = os.path.join(dataset_path, "train")
            valdir = os.path.join(dataset_path, 'val')
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])

            self.train_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(
                    traindir,
                    transforms.Compose([
                        transforms.RandomResizedCrop(224),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                        normalize,
                    ])),
                batch_size=self.settings.batch_size,
                shuffle=True,
                num_workers=self.settings.n_threads,
                pin_memory=True)

            self.val_loader = torch.utils.data.DataLoader(
                datasets.ImageFolder(
                    valdir,
                    transforms.Compose([
                        transforms.Resize(256),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize,
                    ])),
                batch_size=self.settings.batch_size,
                shuffle=False,
                num_workers=self.settings.n_threads,
                pin_memory=True)

    def _set_trainer(self):
        """
        initialize segment-wise trainer
        """

        # initialize segment-wise trainer
        self.segment_wise_trainer = SegmentWiseTrainer(
            ori_model=self.ori_model,
            pruned_model=self.pruned_model,
            train_loader=self.train_loader,
            val_loader=self.val_loader,
            settings=self.settings,
            logger=self.logger,
            tensorboard_logger=self.tensorboard_logger)
        if self.aux_fc_state is not None:
            self.segment_wise_trainer.update_aux_fc(self.aux_fc_state,
                                                    self.aux_fc_opt_state,
                                                    self.seg_opt_state)

    def _set_model(self):
        """
        get model
        """

        if self.settings.dataset in ["cifar10", "cifar100"]:
            if self.settings.net_type == "preresnet":
                self.ori_model = md.PreResNet(
                    depth=self.settings.depth,
                    num_classes=self.settings.n_classes)
                self.pruned_model = md.PreResNet(
                    depth=self.settings.depth,
                    num_classes=self.settings.n_classes)
            else:
                assert False, "use {} data while network is {}".format(
                    self.settings.dataset, self.settings.net_type)

        elif self.settings.dataset in ["imagenet"]:
            if self.settings.net_type == "resnet":
                self.ori_model = md.ResNet(depth=self.settings.depth,
                                           num_classes=self.settings.n_classes)
                self.pruned_model = md.ResNet(
                    depth=self.settings.depth,
                    num_classes=self.settings.n_classes)
            else:
                assert False, "use {} data while network is {}".format(
                    self.settings.dataset, self.settings.net_type)

        else:
            assert False, "unsupported data set: {}".format(
                self.settings.dataset)

    def _set_checkpoint(self):
        """
        load pre-trained model or resume checkpoint
        """

        assert self.ori_model is not None and self.pruned_model is not None, "please create model first"

        self.checkpoint = CheckPoint(self.settings.save_path, self.logger)
        self._load_retrain()
        self._load_resume()

    def _load_retrain(self):
        """
        load pre-trained model
        """

        if self.settings.retrain is not None:
            check_point_params = torch.load(self.settings.retrain)
            if "ori_model" not in check_point_params:
                model_state = check_point_params
                self.ori_model = self.checkpoint.load_state(
                    self.ori_model, model_state)
                self.pruned_model = self.checkpoint.load_state(
                    self.pruned_model, model_state)
                self.logger.info("|===>load restrain file: {}".format(
                    self.settings.retrain))
            else:
                ori_model_state = check_point_params["ori_model"]
                pruned_model_state = check_point_params["pruned_model"]
                # self.current_block_count = check_point_params["current_pivot"]
                self.aux_fc_state = check_point_params["aux_fc"]
                # self.replace_layer()
                self.ori_model = self.checkpoint.load_state(
                    self.ori_model, ori_model_state)
                self.pruned_model = self.checkpoint.load_state(
                    self.pruned_model, pruned_model_state)
                self.logger.info("|===>load pre-trained model: {}".format(
                    self.settings.retrain))

    def _load_resume(self):
        """
        load resume checkpoint
        """

        if self.settings.resume is not None:
            check_point_params = torch.load(self.settings.resume)
            ori_model_state = check_point_params["ori_model"]
            pruned_model_state = check_point_params["pruned_model"]
            self.aux_fc_state = check_point_params["aux_fc"]
            self.aux_fc_opt_state = check_point_params["aux_fc_opt"]
            self.seg_opt_state = check_point_params["seg_opt"]
            self.current_pivot_index = check_point_params["current_pivot"]
            self.is_segment_wise_finetune = check_point_params[
                "segment_wise_finetune"]
            self.is_channel_selection = check_point_params["channel_selection"]
            self.epoch = check_point_params["epoch"]
            self.epoch = self.settings.segment_wise_n_epochs
            self.current_block_count = check_point_params[
                "current_block_count"]

            if self.is_channel_selection or \
                    (self.is_segment_wise_finetune and self.current_pivot_index > self.settings.pivot_set[0]):
                self.replace_layer()
            self.ori_model = self.checkpoint.load_state(
                self.ori_model, ori_model_state)
            self.pruned_model = self.checkpoint.load_state(
                self.pruned_model, pruned_model_state)
            self.logger.info("|===>load resume file: {}".format(
                self.settings.resume))

    def _cal_pivot(self):
        """
        calculate the inserted layer for additional loss
        """

        self.num_segments = self.settings.n_losses + 1
        num_block_per_segment = (
            block_num[self.settings.net_type + str(self.settings.depth)] //
            self.num_segments) + 1
        pivot_set = []
        for i in range(self.num_segments - 1):
            pivot_set.append(num_block_per_segment * (i + 1))
        self.settings.pivot_set = pivot_set
        self.logger.info("pivot set: {}".format(pivot_set))

    def segment_wise_fine_tune(self, index):
        """
        conduct segment-wise fine-tuning
        :param index: segment index
        """

        best_top1 = 100
        best_top5 = 100

        start_epoch = 0
        if self.is_segment_wise_finetune and self.epoch != 0:
            start_epoch = self.epoch + 1
            self.epoch = 0
        for epoch in range(start_epoch, self.settings.segment_wise_n_epochs):
            train_error, train_loss, train5_error = self.segment_wise_trainer.train(
                epoch, index)
            val_error, val_loss, val5_error = self.segment_wise_trainer.val(
                epoch)

            # write and print result
            if isinstance(train_error, list):
                best_flag = False
                if best_top1 >= val_error[-1]:
                    best_top1 = val_error[-1]
                    best_top5 = val5_error[-1]
                    best_flag = True

            else:
                best_flag = False
                if best_top1 >= val_error:
                    best_top1 = val_error
                    best_top5 = val5_error
                    best_flag = True

            if best_flag:
                self.checkpoint.save_model(
                    ori_model=self.ori_model,
                    pruned_model=self.pruned_model,
                    aux_fc=self.segment_wise_trainer.aux_fc,
                    current_pivot=self.current_pivot_index,
                    segment_wise_finetune=True,
                    index=index)

            self.logger.info(
                "|===>Best Result is: Top1 Error: {:f}, Top5 Error: {:f}\n".
                format(best_top1, best_top5))
            self.logger.info(
                "|===>Best Result is: Top1 Accuracy: {:f}, Top5 Accuracy: {:f}\n"
                .format(100 - best_top1, 100 - best_top5))

            ckpt_kwargs = dict(
                ori_model=self.ori_model,
                pruned_model=self.pruned_model,
                aux_fc=self.segment_wise_trainer.aux_fc,
                aux_fc_opt=self.segment_wise_trainer.fc_optimizer,
                seg_opt=self.segment_wise_trainer.seg_optimizer,
                current_pivot=self.current_pivot_index,
                segment_wise_finetune=True,
                index=index)
            if self.settings.dataset in ["imagenet"]:
                # record the epoch so that long ImageNet runs can resume
                # mid-segment
                ckpt_kwargs["epoch"] = epoch
            self.checkpoint.save_checkpoint(**ckpt_kwargs)

    def replace_layer(self):
        """
        Replace the conv layers of already-processed blocks with
        mask convolutional layers (MaskConv2d)
        """

        block_count = 0
        if self.settings.net_type in ["preresnet", "resnet"]:
            for module in self.pruned_model.modules():
                if isinstance(module, (PreBasicBlock, BasicBlock, Bottleneck)):
                    block_count += 1
                    layer = module.conv2
                    if block_count <= self.current_block_count and not isinstance(
                            layer, MaskConv2d):
                        temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                               out_channels=layer.out_channels,
                                               kernel_size=layer.kernel_size,
                                               stride=layer.stride,
                                               padding=layer.padding,
                                               bias=(layer.bias is not None))
                        temp_conv.weight.data.copy_(layer.weight.data)

                        if layer.bias is not None:
                            temp_conv.bias.data.copy_(layer.bias.data)
                        module.conv2 = temp_conv

                    if isinstance(module, Bottleneck):
                        layer = module.conv3
                        if block_count <= self.current_block_count and not isinstance(
                                layer, MaskConv2d):
                            temp_conv = MaskConv2d(
                                in_channels=layer.in_channels,
                                out_channels=layer.out_channels,
                                kernel_size=layer.kernel_size,
                                stride=layer.stride,
                                padding=layer.padding,
                                bias=(layer.bias is not None))
                            temp_conv.weight.data.copy_(layer.weight.data)

                            if layer.bias is not None:
                                temp_conv.bias.data.copy_(layer.bias.data)
                            module.conv3 = temp_conv

    def channel_selection(self):
        """
        conduct channel selection
        """

        # get testing error
        self.segment_wise_trainer.val(0)
        time_start = time.time()

        restart_index = None
        # find restart segment index
        if self.current_pivot_index:
            if self.current_pivot_index in self.settings.pivot_set:
                restart_index = self.settings.pivot_set.index(
                    self.current_pivot_index)
            else:
                restart_index = len(self.settings.pivot_set)
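        # e.g. with pivot_set == [7, 14, 21], resuming at pivot 14 gives
        # restart_index == 1, while a pivot beyond the last entry (the final
        # block count) maps to the last segment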

        for index in range(self.num_segments):
            if restart_index is not None:
                if index < restart_index:
                    continue
                elif index == restart_index:
                    if self.is_channel_selection and self.current_block_count == self.current_pivot_index:
                        self.is_channel_selection = False
                        continue

            if index == self.num_segments - 1:
                self.current_pivot_index = self.segment_wise_trainer.final_block_count
            else:
                self.current_pivot_index = self.settings.pivot_set[index]

            # fine tune the network with additional loss and final loss
            if (not self.is_segment_wise_finetune and not self.is_channel_selection) or \
                    (self.is_segment_wise_finetune and self.epoch != self.settings.segment_wise_n_epochs - 1):
                self.segment_wise_fine_tune(index)
            else:
                self.is_segment_wise_finetune = False

            # load best model
            best_model_path = os.path.join(
                self.checkpoint.save_path,
                'model_{:0>3d}_swft.pth'.format(index))
            check_point_params = torch.load(best_model_path)
            ori_model_state = check_point_params["ori_model"]
            pruned_model_state = check_point_params["pruned_model"]
            aux_fc_state = check_point_params["aux_fc"]
            self.ori_model = self.checkpoint.load_state(
                self.ori_model, ori_model_state)
            self.pruned_model = self.checkpoint.load_state(
                self.pruned_model, pruned_model_state)
            self.segment_wise_trainer.update_model(self.ori_model,
                                                   self.pruned_model,
                                                   aux_fc_state)

            # replace the baseline model
            if index == 0:
                if self.settings.net_type in ['preresnet']:
                    self.ori_model.conv = copy.deepcopy(self.pruned_model.conv)
                    for ori_module, pruned_module in zip(
                            self.ori_model.modules(),
                            self.pruned_model.modules()):
                        if isinstance(ori_module, PreBasicBlock):
                            ori_module.bn1 = copy.deepcopy(pruned_module.bn1)
                            ori_module.bn2 = copy.deepcopy(pruned_module.bn2)
                            ori_module.conv1 = copy.deepcopy(
                                pruned_module.conv1)
                            ori_module.conv2 = copy.deepcopy(
                                pruned_module.conv2)
                            if ori_module.downsample is not None:
                                ori_module.downsample = copy.deepcopy(
                                    pruned_module.downsample)
                    self.ori_model.bn = copy.deepcopy(self.pruned_model.bn)
                    self.ori_model.fc = copy.deepcopy(self.pruned_model.fc)
                elif self.settings.net_type in ['resnet']:
                    self.ori_model.conv1 = copy.deepcopy(
                        self.pruned_model.conv1)
                    self.ori_model.bn1 = copy.deepcopy(self.pruned_model.bn1)
                    for ori_module, pruned_module in zip(
                            self.ori_model.modules(),
                            self.pruned_model.modules()):
                        if isinstance(ori_module, BasicBlock):
                            ori_module.conv1 = copy.deepcopy(
                                pruned_module.conv1)
                            ori_module.conv2 = copy.deepcopy(
                                pruned_module.conv2)
                            ori_module.bn1 = copy.deepcopy(pruned_module.bn1)
                            ori_module.bn2 = copy.deepcopy(pruned_module.bn2)
                            if ori_module.downsample is not None:
                                ori_module.downsample = copy.deepcopy(
                                    pruned_module.downsample)
                        if isinstance(ori_module, Bottleneck):
                            ori_module.conv1 = copy.deepcopy(
                                pruned_module.conv1)
                            ori_module.conv2 = copy.deepcopy(
                                pruned_module.conv2)
                            ori_module.conv3 = copy.deepcopy(
                                pruned_module.conv3)
                            ori_module.bn1 = copy.deepcopy(pruned_module.bn1)
                            ori_module.bn2 = copy.deepcopy(pruned_module.bn2)
                            ori_module.bn3 = copy.deepcopy(pruned_module.bn3)
                            if ori_module.downsample is not None:
                                ori_module.downsample = copy.deepcopy(
                                    pruned_module.downsample)
                    self.ori_model.fc = copy.deepcopy(self.pruned_model.fc)

                aux_fc_state = []
                for fc in self.segment_wise_trainer.aux_fc:
                    if isinstance(fc, nn.DataParallel):
                        aux_fc_state.append(fc.module.state_dict())
                    else:
                        aux_fc_state.append(fc.state_dict())
                self.segment_wise_trainer.update_model(self.ori_model,
                                                       self.pruned_model,
                                                       aux_fc_state)
            self.segment_wise_trainer.val(0)

            # conduct channel selection
            # contains [0:index] segments
            net_origin_list = []
            net_pruned_list = []
            for j in range(index + 1):
                net_origin_list += utils.model2list(
                    self.segment_wise_trainer.ori_segments[j])
                net_pruned_list += utils.model2list(
                    self.segment_wise_trainer.pruned_segments[j])

            net_origin = nn.Sequential(*net_origin_list)
            net_pruned = nn.Sequential(*net_pruned_list)

            self._seg_channel_selection(
                net_origin=net_origin,
                net_pruned=net_pruned,
                aux_fc=self.segment_wise_trainer.aux_fc[index],
                pivot_index=self.current_pivot_index,
                index=index)

            # collect aux_fc states and refresh the trainer (and its
            # optimizers) with the newly pruned model
            aux_fc_state = []
            for fc in self.segment_wise_trainer.aux_fc:
                if isinstance(fc, nn.DataParallel):
                    aux_fc_state.append(fc.module.state_dict())
                else:
                    aux_fc_state.append(fc.state_dict())

            self.segment_wise_trainer.update_model(self.ori_model,
                                                   self.pruned_model,
                                                   aux_fc_state)

            self.checkpoint.save_checkpoint(
                self.ori_model,
                self.pruned_model,
                self.segment_wise_trainer.aux_fc,
                self.segment_wise_trainer.fc_optimizer,
                self.segment_wise_trainer.seg_optimizer,
                self.current_pivot_index,
                channel_selection=True,
                index=index,
                block_count=self.current_pivot_index)

            self.logger.info(self.ori_model)
            self.logger.info(self.pruned_model)
            self.segment_wise_trainer.val(0)
            self.current_pivot_index = None

        self.checkpoint.save_model(self.ori_model,
                                   self.pruned_model,
                                   self.segment_wise_trainer.aux_fc,
                                   self.segment_wise_trainer.final_block_count,
                                   index=self.num_segments)
        time_interval = time.time() - time_start
        log_str = "cost time: {}".format(
            str(datetime.timedelta(seconds=time_interval)))
        self.logger.info(log_str)

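    # forward hooks: cache the target layer's output feature map per GPU
    # (keyed by device id) so the reconstruction (MSE) loss can compare
    # original and pruned features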
    def _hook_origin_feature(self, module, input, output):
        gpu_id = str(output.get_device())
        self.feature_cache_origin[gpu_id] = output

    def _hook_pruned_feature(self, module, input, output):
        gpu_id = str(output.get_device())
        self.feature_cache_pruned[gpu_id] = output

    @staticmethod
    def _concat_gpu_data(data):
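        # gather the per-GPU feature tensors cached by the hooks onto cuda:0
        # in device order, matching DataParallel's batch split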
        data_cat = data["0"]
        for i in range(1, len(data)):
            data_cat = torch.cat((data_cat, data[str(i)].cuda(0)))
        return data_cat

    def _layer_channel_selection(self,
                                 net_origin,
                                 net_pruned,
                                 aux_fc,
                                 module,
                                 block_count,
                                 layer_name="conv2"):
        """
        conduct channel selection for module
        :param net_origin: original network segments
        :param net_pruned: pruned network segments
        :param aux_fc: auxiliary fully-connected layer
        :param module: the module to be pruned
        :param block_count: current block number
        :param layer_name: the name of the layer to be pruned
        """

        self.logger.info(
            "|===>layer-wise channel selection: block-{}-{}".format(
                block_count, layer_name))
        # layer-wise channel selection
        if layer_name == "conv2":
            layer = module.conv2
        elif layer_name == "conv3":
            layer = module.conv3
        else:
            assert False, "unsupported layer: {}".format(layer_name)

        if not isinstance(layer, MaskConv2d):
            temp_conv = MaskConv2d(in_channels=layer.in_channels,
                                   out_channels=layer.out_channels,
                                   kernel_size=layer.kernel_size,
                                   stride=layer.stride,
                                   padding=layer.padding,
                                   bias=(layer.bias is not None))
            temp_conv.weight.data.copy_(layer.weight.data)

            if layer.bias is not None:
                temp_conv.bias.data.copy_(layer.bias.data)
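            # start from an empty selection: zero out pruned_weight and mark
            # every input channel as unselected (d == 0)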
            temp_conv.pruned_weight.data.fill_(0)
            temp_conv.d.fill_(0)

            if layer_name == "conv2":
                module.conv2 = temp_conv
            elif layer_name == "conv3":
                module.conv3 = temp_conv
            layer = temp_conv

        # define criterion
        criterion_mse = nn.MSELoss().cuda()
        criterion_softmax = nn.CrossEntropyLoss().cuda()

        # register hook
        if layer_name == "conv2":
            hook_origin = net_origin[block_count].conv2.register_forward_hook(
                self._hook_origin_feature)
            hook_pruned = module.conv2.register_forward_hook(
                self._hook_pruned_feature)
        elif layer_name == "conv3":
            hook_origin = net_origin[block_count].conv3.register_forward_hook(
                self._hook_origin_feature)
            hook_pruned = module.conv3.register_forward_hook(
                self._hook_pruned_feature)

        net_origin_parallel = utils.data_parallel(net_origin,
                                                  self.settings.n_gpus)
        net_pruned_parallel = utils.data_parallel(net_pruned,
                                                  self.settings.n_gpus)

        # avoid computing the gradient
        for params in net_origin_parallel.parameters():
            params.requires_grad = False
        for params in net_pruned_parallel.parameters():
            params.requires_grad = False

        net_origin_parallel.eval()
        net_pruned_parallel.eval()

        layer.pruned_weight.requires_grad = True
        aux_fc.cuda()
        logger_counter = 0
        record_time = utils.AverageMeter()

        for channel in range(layer.in_channels):
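            # stop once the number of still-unselected channels is within the
            # pruning quota, i.e. a fraction `pruning_rate` of the input
            # channels stays pruned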
            if layer.d.eq(0).sum() <= math.floor(
                    layer.in_channels * self.settings.pruning_rate):
                break

            time_start = time.time()
            cum_grad = None
            record_selection_mse_loss = utils.AverageMeter()
            record_selection_softmax_loss = utils.AverageMeter()
            record_selection_loss = utils.AverageMeter()
            img_count = 0
            for i, (images, labels) in enumerate(self.train_loader):
                images = images.cuda()
                labels = labels.cuda()
                net_origin_parallel(images)
                output = net_pruned_parallel(images)
                softmax_loss = criterion_softmax(aux_fc(output), labels)

                origin_feature = self._concat_gpu_data(
                    self.feature_cache_origin)
                self.feature_cache_origin = {}
                pruned_feature = self._concat_gpu_data(
                    self.feature_cache_pruned)
                self.feature_cache_pruned = {}
                mse_loss = criterion_mse(pruned_feature, origin_feature)

                loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight
                loss.backward()
                record_selection_loss.update(loss.item(), images.size(0))
                record_selection_mse_loss.update(mse_loss.item(),
                                                 images.size(0))
                record_selection_softmax_loss.update(softmax_loss.item(),
                                                     images.size(0))

                if cum_grad is None:
                    cum_grad = layer.pruned_weight.grad.data.clone()
                else:
                    cum_grad.add_(layer.pruned_weight.grad.data)
                # clear the gradient after every batch; autograd accumulates
                # into .grad, so leaving it set would double-count batches
                layer.pruned_weight.grad = None

                img_count += images.size(0)
                if self.settings.max_samples != -1 and img_count >= self.settings.max_samples:
                    break

            # write tensorboard log
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_LossAll".format(block_count, layer_name),
                value=record_selection_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_MSELoss".format(block_count, layer_name),
                value=record_selection_mse_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="S-block-{}_{}_SoftmaxLoss".format(block_count,
                                                       layer_name),
                value=record_selection_softmax_loss.avg,
                step=logger_counter)
            cum_grad.abs_()
            # per-input-channel Frobenius norm of the accumulated gradient
            grad_fnorm = cum_grad.mul(cum_grad).sum((2, 3)).sqrt().sum(0)

            # greedily pick the unselected channel with the largest gradient
            # norm, copy its original weights into pruned_weight, and mark it
            # selected in d
            while True:
                _, max_index = torch.topk(grad_fnorm, 1)
                if layer.d[max_index[0]] == 0:
                    layer.d[max_index[0]] = 1
                    layer.pruned_weight.data[:, max_index[0], :, :] = \
                        layer.weight[:, max_index[0], :, :].data.clone()
                    break
                else:
                    grad_fnorm[max_index[0]] = -1

            # fine-tune average meter
            record_finetune_softmax_loss = utils.AverageMeter()
            record_finetune_mse_loss = utils.AverageMeter()
            record_finetune_loss = utils.AverageMeter()

            record_finetune_top1_error = utils.AverageMeter()
            record_finetune_top5_error = utils.AverageMeter()

            # define optimizer
            params_list = []
            params_list.append({
                "params": layer.pruned_weight,
                "lr": self.settings.layer_wise_lr
            })
            if layer.bias is not None:
                layer.bias.requires_grad = True
                params_list.append({"params": layer.bias, "lr": 0.001})
            optimizer = torch.optim.SGD(
                params=params_list,
                weight_decay=self.settings.weight_decay,
                momentum=self.settings.momentum,
                nesterov=True)
            img_count = 0
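            # one short fine-tuning pass over (at most max_samples of) the
            # training set after each newly selected channel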
            for epoch in range(1):
                for i, (images, labels) in enumerate(self.train_loader):
                    images = images.cuda()
                    labels = labels.cuda()
                    features = net_pruned_parallel(images)
                    net_origin_parallel(images)
                    output = aux_fc(features)
                    softmax_loss = criterion_softmax(output, labels)

                    origin_feature = self._concat_gpu_data(
                        self.feature_cache_origin)
                    self.feature_cache_origin = {}
                    pruned_feature = self._concat_gpu_data(
                        self.feature_cache_pruned)
                    self.feature_cache_pruned = {}
                    mse_loss = criterion_mse(pruned_feature, origin_feature)

                    top1_error, _, top5_error = utils.compute_singlecrop(
                        outputs=output,
                        labels=labels,
                        loss=softmax_loss,
                        top5_flag=True,
                        mean_flag=True)

                    # update parameters
                    optimizer.zero_grad()
                    loss = mse_loss * self.settings.mse_weight + softmax_loss * self.settings.softmax_weight
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(layer.parameters(),
                                                   max_norm=10.0)
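                    # zero the gradients of unselected channels so the update
                    # only touches channels with d == 1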
                    layer.pruned_weight.grad.data.mul_(
                        layer.d.unsqueeze(0).unsqueeze(2).unsqueeze(
                            3).expand_as(layer.pruned_weight))
                    optimizer.step()
                    # update record info
                    record_finetune_softmax_loss.update(
                        softmax_loss.item(), images.size(0))
                    record_finetune_mse_loss.update(mse_loss.item(),
                                                    images.size(0))
                    record_finetune_loss.update(loss.item(), images.size(0))
                    record_finetune_top1_error.update(top1_error,
                                                      images.size(0))
                    record_finetune_top5_error.update(top5_error,
                                                      images.size(0))

                    img_count += images.size(0)
                    if self.settings.max_samples != -1 and img_count >= self.settings.max_samples:
                        break

            layer.pruned_weight.grad = None
            if layer.bias is not None:
                layer.bias.requires_grad = False

            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_SoftmaxLoss".format(block_count,
                                                       layer_name),
                value=record_finetune_softmax_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Loss".format(block_count, layer_name),
                value=record_finetune_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_MSELoss".format(block_count, layer_name),
                value=record_finetune_mse_loss.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Top1Error".format(block_count, layer_name),
                value=record_finetune_top1_error.avg,
                step=logger_counter)
            self.tensorboard_logger.scalar_summary(
                tag="F-block-{}_{}_Top5Error".format(block_count, layer_name),
                value=record_finetune_top5_error.avg,
                step=logger_counter)

            # write log information to file
            self._write_log(
                dir_name=os.path.join(self.settings.save_path, "log"),
                file_name="log_block-{:0>2d}_{}.txt".format(
                    block_count, layer_name),
                log_str=
                "{:d}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t{:f}\t\n".
                format(int(layer.d.sum()), record_selection_loss.avg,
                       record_selection_mse_loss.avg,
                       record_selection_softmax_loss.avg,
                       record_finetune_loss.avg, record_finetune_mse_loss.avg,
                       record_finetune_softmax_loss.avg,
                       record_finetune_top1_error.avg,
                       record_finetune_top5_error.avg))
            log_str = "Block-{:0>2d}-{}\t#channels: [{:0>4d}|{:0>4d}]\t".format(
                block_count, layer_name, int(layer.d.sum()), layer.d.size(0))
            log_str += "[selection]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
                record_selection_loss.avg, record_selection_mse_loss.avg,
                record_selection_softmax_loss.avg)
            log_str += "[fine-tuning]loss: {:4f}\tmseloss: {:4f}\tsoftmaxloss: {:4f}\t".format(
                record_finetune_loss.avg, record_finetune_mse_loss.avg,
                record_finetune_softmax_loss.avg)
            log_str += "top1error: {:4f}\ttop5error: {:4f}".format(
                record_finetune_top1_error.avg, record_finetune_top5_error.avg)
            self.logger.info(log_str)

            logger_counter += 1
            time_interval = time.time() - time_start
            record_time.update(time_interval)

        for params in net_origin_parallel.parameters():
            params.requires_grad = True
        for params in net_pruned_parallel.parameters():
            params.requires_grad = True

        # remove hook
        hook_origin.remove()
        hook_pruned.remove()
        log_str = "|===>Select channel from block-{:d}_{}: time_total:{} time_avg: {}".format(
            block_count, layer_name,
            str(datetime.timedelta(seconds=record_time.sum)),
            str(datetime.timedelta(seconds=record_time.avg)))
        self.logger.info(log_str)
        log_str = "|===>fine-tuning result: loss: {:f}, mse_loss: {:f}, softmax_loss: {:f}, top1error: {:f} top5error: {:f}".format(
            record_finetune_loss.avg, record_finetune_mse_loss.avg,
            record_finetune_softmax_loss.avg, record_finetune_top1_error.avg,
            record_finetune_top5_error.avg)
        self.logger.info(log_str)

        self.logger.info("|===>remove hook")

    @staticmethod
    def _write_log(dir_name, file_name, log_str):
        """
        Write log to file
        :param dir_name: the path of the directory
        :param file_name: the name of the saved file
        :param log_str: the string that needs to be saved
        """

        if not os.path.isdir(dir_name):
            os.mkdir(dir_name)
        with open(os.path.join(dir_name, file_name), "a+") as f:
            f.write(log_str)

    def _seg_channel_selection(self, net_origin, net_pruned, aux_fc,
                               pivot_index, index):
        """
        conduct segment channel selection
        :param net_origin: original network segments
        :param net_pruned: pruned network segments
        :param aux_fc: auxiliary fully-connected layer
        :param pivot_index: the layer index of the additional loss
        :param index: the index of the segment
        """
        block_count = 0
        if self.settings.net_type in ["preresnet", "resnet"]:
            for module in net_pruned.modules():
                if isinstance(module, (PreBasicBlock, BasicBlock)):
                    block_count += 1
                    # do not prune already-pruned blocks again
                    if not isinstance(module.conv2, MaskConv2d):
                        self._layer_channel_selection(net_origin=net_origin,
                                                      net_pruned=net_pruned,
                                                      aux_fc=aux_fc,
                                                      module=module,
                                                      block_count=block_count,
                                                      layer_name="conv2")
                        self.logger.info("|===>checking layer type: {}".format(
                            type(module.conv2)))

                        self.checkpoint.save_model(
                            self.ori_model,
                            self.pruned_model,
                            self.segment_wise_trainer.aux_fc,
                            pivot_index,
                            channel_selection=True,
                            index=index,
                            block_count=block_count)
                        self.checkpoint.save_checkpoint(
                            self.ori_model,
                            self.pruned_model,
                            self.segment_wise_trainer.aux_fc,
                            self.segment_wise_trainer.fc_optimizer,
                            self.segment_wise_trainer.seg_optimizer,
                            pivot_index,
                            channel_selection=True,
                            index=index,
                            block_count=block_count)

                elif isinstance(module, Bottleneck):
                    block_count += 1
                    if not isinstance(module.conv2, MaskConv2d):
                        self._layer_channel_selection(net_origin=net_origin,
                                                      net_pruned=net_pruned,
                                                      aux_fc=aux_fc,
                                                      module=module,
                                                      block_count=block_count,
                                                      layer_name="conv2")

                    if not isinstance(module.conv3, MaskConv2d):
                        self._layer_channel_selection(net_origin=net_origin,
                                                      net_pruned=net_pruned,
                                                      aux_fc=aux_fc,
                                                      module=module,
                                                      block_count=block_count,
                                                      layer_name="conv3")

                        self.checkpoint.save_model(
                            self.ori_model,
                            self.pruned_model,
                            self.segment_wise_trainer.aux_fc,
                            pivot_index,
                            channel_selection=True,
                            index=index,
                            block_count=block_count)
                        self.checkpoint.save_checkpoint(
                            self.ori_model,
                            self.pruned_model,
                            self.segment_wise_trainer.aux_fc,
                            self.segment_wise_trainer.fc_optimizer,
                            self.segment_wise_trainer.seg_optimizer,
                            pivot_index,
                            channel_selection=True,
                            index=index,
                            block_count=block_count)