Example #1
    def load_parameters(self, params):
        """
        Override BaseModel.load_parameters

        Write base64 encoded PyTorch model state dict to temp file and then read it back with torch.load.
        The other persistent hyperparameters are recovered by setting model's private property
        """
        # Load model parameters
        h5_model_base64 = params['h5_model_base64']
        self._image_size = params['image_size']
        self._normalize_mean = np.array(json.loads(params['normalize_mean']))
        self._normalize_std = np.array(json.loads(params['normalize_std']))
        self._num_classes = params['num_classes']
        self.label_mapper = json.loads(params['label_mapper'])

        with tempfile.NamedTemporaryFile() as tmp:
            # Convert back to bytes & write to temp file
            h5_model_bytes = base64.b64decode(h5_model_base64.encode('utf-8'))
            with open(tmp.name, 'wb') as f:
                f.write(h5_model_bytes)

            # Load model from temp file
            self._model = self._create_model(
                scratch=self._knobs.get("scratch"),
                num_classes=self._num_classes)
            if self._knobs.get("enable_mc_dropout"):
                self._model = update_model(self._model)
            if self._knobs.get("enable_model_slicing"):
                self._model = upgrade_dynamic_layers(
                    model=self._model,
                    num_groups=self._knobs.get("model_slicing_groups"),
                    sr_in_list=[0.5, 0.75, 1.0])
            if not torch.cuda.is_available():
                print(
                    'GPU is not available. Model parameter storage is mapped to the CPU'
                )
                self._model.load_state_dict(
                    torch.load(tmp.name, map_location=torch.device('cpu')))
            else:
                print(
                    'GPU is available. Model parameter storage is mapped to the GPU'
                )
                self._model.load_state_dict(torch.load(tmp.name))

        if self._knobs.get("enable_label_adaptation"):
            self._label_drift_adapter = LabelDriftAdapter(
                model=self._model, num_classes=self._num_classes)
            self._label_drift_adapter.load_parameters(
                params=params[self._label_drift_adapter.get_mod_name()])
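For context, here is a minimal, self-contained sketch of the base64 round trip that load_parameters relies on: the state dict is written with torch.save, base64-encoded into the params dict, and later decoded and re-read with torch.load. The helper names below are illustrative, not part of the model class above.

import base64
import tempfile

import torch
import torch.nn as nn


def state_dict_to_base64(model: nn.Module) -> str:
    # Serialize the state dict with torch.save, then base64-encode the raw bytes.
    with tempfile.NamedTemporaryFile() as tmp:
        torch.save(model.state_dict(), tmp.name)
        with open(tmp.name, 'rb') as f:
            return base64.b64encode(f.read()).decode('utf-8')


def base64_to_state_dict(encoded: str) -> dict:
    # Decode the base64 string back to bytes and read the state dict with torch.load.
    with tempfile.NamedTemporaryFile() as tmp:
        with open(tmp.name, 'wb') as f:
            f.write(base64.b64decode(encoded.encode('utf-8')))
        return torch.load(tmp.name, map_location=torch.device('cpu'))


model = nn.Linear(4, 2)
restored = nn.Linear(4, 2)
restored.load_state_dict(base64_to_state_dict(state_dict_to_base64(model)))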
Example #2
    def train(self,
              dataset_path: str,
              shared_params: Optional[Params] = None,
              **train_args):
        """
        Overide BaseModel.train()
        Train the model with given dataset_path

        parameters:
            dataset_path: path to dataset_path
                type: str
            **kwargs:
                optional arguments

        return:
            nothing
        """
        torch.manual_seed(self._knobs.get("seed"))
        dataset = utils.dataset.load_dataset_of_image_files(
            dataset_path,
            min_image_size=32,
            max_image_size=self._knobs.get("max_image_size"),
            mode='RGB')
        self._normalize_mean, self._normalize_std = dataset.get_stat()
        # self._normalize_mean = [0.48233507, 0.48233507, 0.48233507]
        # self._normalize_std = [0.07271624, 0.07271624, 0.07271624]

        self._num_classes = dataset.classes
        self.label_mapper = dataset.label_mapper

        # construct the model
        self._model = self._create_model(scratch=self._knobs.get("scratch"),
                                         num_classes=self._num_classes)
        if self._knobs.get("enable_mc_dropout"):
            self._model = update_model(self._model)

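        # Model slicing: upgrade layers to dynamic variants that can run at slice ratios 0.5, 0.75 and 1.0.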
        if self._knobs.get("enable_model_slicing"):
            self._model = upgrade_dynamic_layers(
                model=self._model,
                num_groups=self._knobs.get("model_slicing_groups"),
                sr_in_list=[0.5, 0.75, 1.0])

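        # GM-prior regularization: register every named parameter with the GM optimizer;
        # the matching constraint is applied after each backward pass in the training loop.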
        if self._knobs.get("enable_gm_prior_regularization"):
            self._gm_optimizer = GMOptimizer()
            for name, f in self._model.named_parameters():
                self._gm_optimizer.gm_register(
                    name,
                    f.data.cpu().numpy(),
                    model_name="PyVGG",
                    hyperpara_list=[
                        self._knobs.get("gm_prior_regularization_a"),
                        self._knobs.get("gm_prior_regularization_b"),
                        self._knobs.get("gm_prior_regularization_alpha"),
                    ],
                    gm_num=self._knobs.get("gm_prior_regularization_num"),
                    gm_lambda_ratio_value=self._knobs.get(
                        "gm_prior_regularization_lambda"),
                    uptfreq=[
                        self._knobs.get("gm_prior_regularization_upt_freq"),
                        self._knobs.get(
                            "gm_prior_regularization_param_upt_freq")
                    ])

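        # SPL (self-paced learning): per-sample loss scores and the inclusion threshold are updated during training below.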
        if self._knobs.get("enable_spl"):
            self._spl = SPL()

        train_dataset = TorchImageDataset(sa_dataset=dataset,
                                          image_scale_size=self._image_size,
                                          norm_mean=self._normalize_mean,
                                          norm_std=self._normalize_std,
                                          is_train=True)
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=self._knobs.get("batch_size"),
                                      shuffle=True)

        # Setup Criterion
        # print("self._num_classes is: ", self._num_classes)
        self.train_criterion = nn.MultiLabelSoftMarginLoss()  # expects float (multi-hot) targets

        # Setup Optimizer
        if self._knobs.get("optimizer") == "adam":
            optimizer = optim.Adam(
                filter(lambda p: p.requires_grad, self._model.parameters()),
                lr=self._knobs.get("lr"),
                weight_decay=self._knobs.get("weight_decay"))
        elif self._knobs.get("optimizer") == "rmsprop":
            optimizer = optim.RMSprop(
                filter(lambda p: p.requires_grad, self._model.parameters()),
                lr=self._knobs.get("lr"),
                weight_decay=self._knobs.get("weight_decay"))
        elif self._knobs.get("optimizer") == "sgd":
            optimizer = optim.SGD(filter(lambda p: p.requires_grad,
                                         self._model.parameters()),
                                  lr=self._knobs.get("lr"),
                                  weight_decay=self._knobs.get("weight_decay"))
        else:
            raise NotImplementedError()

        # Setup Learning Rate Scheduler
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   patience=1,
                                                   threshold=0.001,
                                                   factor=0.1)

        self._model = self._model.to(self.device)

        self._model.train()

        if self._knobs.get("enable_model_slicing"):
            sr_scheduler = create_sr_scheduler(
                scheduler_type=self._knobs.get("model_slicing_scheduler_type"),
                sr_rand_num=self._knobs.get("model_slicing_randnum"),
                sr_list=[0.5, 0.75, 1.0],
                sr_prob=None)
        utils.logger.define_plot('Loss Over Epochs',
                                 ['loss', 'epoch_accuracy'],
                                 x_axis='epoch')
        utils.logger.log(loss=0.0, epoch_accuracy=0.0, epoch=0)
        for epoch in range(1, self._knobs.get("max_epochs") + 1):
            print("Epoch {}/{}".format(epoch, self._knobs.get("max_epochs")))
            batch_accuracy = []
            batch_losses = []
            for batch_idx, (raw_indices, traindata,
                            batch_classes) in enumerate(train_dataloader):
                print("Got batch_idx and batchdata", batch_idx)
                inputs, labels = self._transform_data(traindata,
                                                      batch_classes,
                                                      train=True)
                print("zero the optimizer")
                optimizer.zero_grad()
                if self._knobs.get("enable_model_slicing"):
                    for sr_idx in next(sr_scheduler):
                        self._model.update_sr_idx(sr_idx)
                        outputs = self._model(inputs)
                        trainloss = self.train_criterion(outputs, labels)
                        trainloss.backward()
                else:
                    # torch.Size([256, 3, 128, 128])
                    outputs = self._model(inputs)
                    trainloss = self.train_criterion(outputs, labels)
                    print("doing backward")
                    trainloss.backward()
                if self._knobs.get("enable_gm_prior_regularization"):
                    for name, f in self._model.named_parameters():
                        self._gm_optimizer.apply_GM_regularizer_constraint(
                            labelnum=dataset.classes,
                            trainnum=dataset.size,
                            epoch=epoch,
                            weight_decay=self._knobs.get("weight_decay"),
                            f=f,
                            name=name,
                            step=batch_idx)

                if self._knobs.get("enable_spl"):
                    train_dataset.update_sample_score(
                        raw_indices,
                        trainloss.detach().cpu().numpy())
                optimizer.step()
                print("Epoch: {:d} Batch: {:d} Train Loss: {:.6f}".format(
                    epoch, batch_idx, trainloss.item()))
                sys.stdout.flush()

                transferred_labels = torch.max(labels.data, 1)
                transferred_outputs = torch.max(torch.sigmoid(outputs), 1)
                batch_accuracy.append(transferred_labels[1].eq(
                    transferred_outputs[1]).sum().item() /
                                      transferred_labels[1].size(0))
                batch_losses.append(trainloss.item())
            train_loss = np.mean(batch_losses)
            batch_accuracy_mean = np.mean(batch_accuracy)
            utils.logger.log(loss=train_loss,
                             epoch_accuracy=batch_accuracy_mean,
                             epoch=epoch)
            print("Training Loss: {:.6f}".format(train_loss))
            if self._knobs.get("enable_spl"):
                train_dataset.update_score_threshold(
                    threshold=self._spl.calculate_threshold_by_epoch(
                        epoch=epoch,
                        threshold_init=self._knobs.get("spl_threshold_init"),
                        mu=self._knobs.get("spl_mu")))
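The per-batch accuracy above is computed by taking the argmax of the multi-hot label tensor and of the sigmoid-activated outputs and counting matches. A standalone sketch with illustrative tensors:

import torch

# Illustrative batch: 4 samples, 3 classes, one-hot labels.
labels = torch.tensor([[1., 0., 0.],
                       [0., 1., 0.],
                       [0., 0., 1.],
                       [1., 0., 0.]])
outputs = torch.tensor([[2.0, -1.0, 0.1],
                        [0.2, 1.5, -0.3],
                        [0.0, 0.1, 0.2],
                        [-1.0, 2.0, 0.5]])

# torch.max over dim 1 returns (values, indices); index 1 holds the predicted class.
label_idx = torch.max(labels, 1)[1]
pred_idx = torch.max(torch.sigmoid(outputs), 1)[1]
accuracy = label_idx.eq(pred_idx).sum().item() / label_idx.size(0)
print(accuracy)  # 0.75 for these tensors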
Example #3
    def train(self,
              dataset_path: str,
              shared_params: Optional[Params] = None,
              **train_args):
        """
        Overide BaseModel.train()
        Train the model with given dataset_path

        parameters:
            dataset_path: path to dataset_path
                type: str
            **kwargs:
                optional arguments

        return:
            nothing
        """
        dataset = utils.dataset.load_dataset_of_image_files(
            dataset_path,
            min_image_size=32,
            max_image_size=self._knobs.get("max_image_size"),
            mode='RGB',
            lazy_load=True)
        self._normalize_mean, self._normalize_std = dataset.get_stat()
        # self._normalize_mean = [0.48233507, 0.48233507, 0.48233507]
        # self._normalize_std = [0.07271624, 0.07271624, 0.07271624]

        self._num_classes = dataset.classes
        print('num_class', dataset.classes)

        # construct the model
        self._model = self._create_model(scratch=self._knobs.get("scratch"),
                                         num_classes=self._num_classes)

        if self._knobs.get("enable_model_slicing"):
            self._model = upgrade_dynamic_layers(
                model=self._model,
                num_groups=self._knobs.get("model_slicing_groups"),
                sr_in_list=[0.5, 0.75, 1.0])

        if self._knobs.get("enable_gm_prior_regularization"):
            self._gm_optimizer = GMOptimizer()
            for name, f in self._model.named_parameters():
                self._gm_optimizer.gm_register(
                    name,
                    f.data.cpu().numpy(),
                    model_name="PyVGG",
                    hyperpara_list=[
                        self._knobs.get("gm_prior_regularization_a"),
                        self._knobs.get("gm_prior_regularization_b"),
                        self._knobs.get("gm_prior_regularization_alpha"),
                    ],
                    gm_num=self._knobs.get("gm_prior_regularization_num"),
                    gm_lambda_ratio_value=self._knobs.get(
                        "gm_prior_regularization_lambda"),
                    uptfreq=[
                        self._knobs.get("gm_prior_regularization_upt_freq"),
                        self._knobs.get(
                            "gm_prior_regularization_param_upt_freq")
                    ])

        if self._knobs.get("enable_spl"):
            self._spl = SPL()

        train_dataset = TorchImageDataset(sa_dataset=dataset,
                                          image_scale_size=128,
                                          norm_mean=self._normalize_mean,
                                          norm_std=self._normalize_std,
                                          is_train=True)
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=self._knobs.get("batch_size"),
                                      shuffle=True)

        # Setup Criterion
        if self._num_classes == 2:
            # binary case: integer class targets
            self.train_criterion = nn.CrossEntropyLoss()
            # selection head loss
            self.selectionhead_criterion = nn.CrossEntropyLoss()
        else:
            # multi-label case: float (multi-hot) targets
            self.train_criterion = nn.MultiLabelSoftMarginLoss()
            # selection head loss
            self.selectionhead_criterion = nn.MultiLabelSoftMarginLoss()

        # Setup Optimizer
        if self._knobs.get("optimizer") == "adam":
            optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                          self._model.parameters()),
                                   lr=self._knobs.get("lr"),
                                   weight_decay=self._knobs.get("weight_decay"))
        elif self._knobs.get("optimizer") == "rmsprop":
            optimizer = optim.RMSprop(
                filter(lambda p: p.requires_grad, self._model.parameters()),
                lr=self._knobs.get("lr"),
                weight_decay=self._knobs.get("weight_decay"))
        else:
            raise NotImplementedError()

        # Setup Learning Rate Scheduler
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   patience=1,
                                                   threshold=0.001,
                                                   factor=0.1)

        if self._use_gpu:
            self._model = self._model.cuda()

        self._model.train()

        if self._knobs.get("enable_model_slicing"):
            sr_scheduler = create_sr_scheduler(
                scheduler_type=self._knobs.get("model_slicing_scheduler_type"),
                sr_rand_num=self._knobs.get("model_slicing_randnum"),
                sr_list=[0.5, 0.75, 1.0],
                sr_prob=None)

        # SelectiveNet params
        lamda = self._knobs.get("lamda")
        selectionheadloss_weight = self._knobs.get("selectionheadloss_weight")
        target_coverage = self._knobs.get("target_coverage")

        for epoch in range(1, self._knobs.get("max_epochs") + 1):
            print("Epoch {}/{}".format(epoch, self._knobs.get("max_epochs")))
            batch_losses = []
            for batch_idx, (raw_indices, traindata,
                            batch_classes) in enumerate(train_dataloader):
                inputs, labels = self._transform_data(traindata,
                                                      batch_classes,
                                                      train=True)
                optimizer.zero_grad()

                if self._knobs.get("enable_model_slicing"):
                    for sr_idx in next(sr_scheduler):
                        self._model.update_sr_idx(sr_idx)
                        # selection head output: one selection score per sample (a column)
                        (outputs, selectionhead) = self._model(inputs)
                        predloss = self.train_criterion(outputs, labels)
                        # apply the interior point method on the labels; equivalent to
                        # selectionhead.view(-1, 1).repeat(1, self._num_classes).view(selectionhead.shape[0], -1) * labels
                        interior_point_of_labels = selectionhead * labels
                        auxiliaryhead = outputs
                        empirical_coverage = selectionhead.type(
                            torch.float64).mean()
                        selectionheadloss = self.selectionhead_criterion(
                            interior_point_of_labels, auxiliaryhead) + lamda * (
                                target_coverage -
                                empirical_coverage).clamp(min=0)**2
                        # cast to float; re-wrapping with torch.tensor would detach it from the graph
                        selectionheadloss = selectionheadloss.float()
                        trainloss = selectionheadloss * selectionheadloss_weight + predloss * (
                            1 - selectionheadloss_weight)
                        trainloss.backward()
                else:
                    # selection head output: one selection score per sample (a column)
                    (outputs, selectionhead) = self._model(inputs)
                    predloss = self.train_criterion(outputs, labels)
                    # apply the interior point method on the labels; equivalent to
                    # selectionhead.view(-1, 1).repeat(1, self._num_classes).view(selectionhead.shape[0], -1) * labels
                    interior_point_of_labels = selectionhead * labels
                    auxiliaryhead = outputs
                    empirical_coverage = selectionhead.type(
                        torch.float64).mean()
                    selectionheadloss = self.selectionhead_criterion(
                        interior_point_of_labels, auxiliaryhead) + lamda * (
                            target_coverage -
                            empirical_coverage).clamp(min=0)**2
                    # cast to float; re-wrapping with torch.tensor would detach it from the graph
                    selectionheadloss = selectionheadloss.float()
                    trainloss = selectionheadloss * selectionheadloss_weight + predloss * (
                        1 - selectionheadloss_weight)
                    trainloss.backward()

                if self._knobs.get("enable_gm_prior_regularization"):
                    for name, f in self._model.named_parameters():
                        self._gm_optimizer.apply_GM_regularizer_constraint(
                            labelnum=1,
                            trainnum=0,
                            epoch=epoch,
                            weight_decay=self._knobs.get("weight_decay"),
                            f=f,
                            name=name,
                            step=batch_idx)

                if self._knobs.get("enable_spl"):
                    train_dataset.update_sample_score(
                        raw_indices,
                        trainloss.detach().cpu().numpy())

                optimizer.step()
                print("Epoch: {:d} Batch: {:d} Train Loss: {:.6f}".format(
                    epoch, batch_idx, trainloss.item()))
                sys.stdout.flush()
                batch_losses.append(trainloss.item())

            train_loss = np.mean(batch_losses)
            print("Training Loss: {:.6f}".format(train_loss))
            if self._knobs.get("enable_spl"):
                train_dataset.update_score_threshold(
                    threshold=self._spl.calculate_threshold_by_epoch(
                        epoch=epoch,
                        threshold_init=self._knobs.get("spl_threshold_init"),
                        mu=self._knobs.get("spl_mu")))
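The selection-head term above follows a SelectiveNet-style objective: the auxiliary loss is augmented with a quadratic penalty whenever the empirical coverage of the selection head falls below target_coverage. A standalone sketch of just the coverage penalty, with illustrative values:

import torch

lamda = 32.0             # illustrative penalty weight (the "lamda" knob above)
target_coverage = 0.9    # illustrative target coverage

# Selection-head scores in [0, 1] for a batch of 6 samples (illustrative).
selectionhead = torch.tensor([0.9, 0.2, 0.8, 0.1, 0.95, 0.7])

# Empirical coverage is the mean selection score over the batch.
empirical_coverage = selectionhead.type(torch.float64).mean()

# The quadratic penalty is active only when coverage falls short of the target.
coverage_penalty = lamda * (target_coverage - empirical_coverage).clamp(min=0) ** 2
print(empirical_coverage.item(), coverage_penalty.item())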