Esempio n. 1
0
    def load_parameters(self, params):
        """
        Override BaseModel.load_parameters.

        Rebuild the network and restore its weights from ``params``.

        The base64-encoded PyTorch state dict stored under
        ``params['h5_model_base64']`` is decoded into an in-memory buffer and
        read back with ``torch.load``. The other persistent hyperparameters are
        recovered by setting the model's private attributes.

        Args:
            params: dict of serialized parameters. Keys read here:
                ``h5_model_base64`` (base64 text of a saved state dict),
                ``image_size``, ``normalize_mean`` / ``normalize_std``
                (JSON-encoded lists), ``num_classes``, ``label_mapper``
                (JSON-encoded) and, only when the ``enable_label_adaptation``
                knob is set, the entry named by the label drift adapter's
                ``get_mod_name()``.
        """
        import io  # local import keeps this method self-contained

        # Restore the persistent hyperparameters.
        h5_model_base64 = params['h5_model_base64']
        self._image_size = params['image_size']
        self._normalize_mean = np.array(json.loads(params['normalize_mean']))
        self._normalize_std = np.array(json.loads(params['normalize_std']))
        self._num_classes = params['num_classes']
        self.label_mapper = json.loads(params['label_mapper'])

        # Decode in memory. Writing a NamedTemporaryFile and re-opening it by
        # name while it is still open does not work on Windows. It also costs
        # an unnecessary disk round-trip.
        h5_model_bytes = base64.b64decode(h5_model_base64.encode('utf-8'))

        # Rebuild the architecture, including the optional MC-dropout and
        # model-slicing rewrites, so the saved state dict's keys match the
        # module's parameters.
        self._model = self._create_model(
            scratch=self._knobs.get("scratch"),
            num_classes=self._num_classes)
        if self._knobs.get("enable_mc_dropout"):
            self._model = update_model(self._model)
        if self._knobs.get("enable_model_slicing"):
            self._model = upgrade_dynamic_layers(
                model=self._model,
                num_groups=self._knobs.get("model_slicing_groups"),
                sr_in_list=[0.5, 0.75, 1.0])

        if not torch.cuda.is_available():
            print(
                'GPU is not available. Model parameters storages are mapped to CPU'
            )
            map_location = torch.device('cpu')
        else:
            print(
                'GPU is available. Model parameters storages are mapped to GPU'
            )
            # None is torch.load's default: keep the devices stored in the
            # checkpoint.
            map_location = None

        # NOTE: torch.load unpickles its input; only load parameters that come
        # from a trusted source.
        with io.BytesIO(h5_model_bytes) as buffer:
            state_dict = torch.load(buffer, map_location=map_location)
        self._model.load_state_dict(state_dict)

        if self._knobs.get("enable_label_adaptation"):
            self._label_drift_adapter = LabelDriftAdapter(
                model=self._model, num_classes=self._num_classes)
            self._label_drift_adapter.load_parameters(
                params=params[self._label_drift_adapter.get_mod_name()])
Esempio n. 2
0
    def train(self,
              dataset_path: str,
              shared_params: Optional[Params] = None,
              **train_args):
        """
        Override BaseModel.train().

        Train the model on the image dataset found at ``dataset_path``.

        Parameters:
            dataset_path: path to the dataset of image files (str).
            shared_params: accepted for interface compatibility; not used by
                this implementation.
            **train_args: optional extra arguments; not used by this
                implementation.

        Returns:
            None. The trained network is left in ``self._model``. The dataset
            normalisation statistics, class count and label mapping are
            stored on the instance.
        """
        torch.manual_seed(self._knobs.get("seed"))
        dataset = utils.dataset.load_dataset_of_image_files(
            dataset_path,
            min_image_size=32,
            max_image_size=self._knobs.get("max_image_size"),
            mode='RGB')
        # Normalisation statistics are taken from the training data itself.
        self._normalize_mean, self._normalize_std = dataset.get_stat()

        self._num_classes = dataset.classes
        self.label_mapper = dataset.label_mapper

        # Construct the model, then apply the optional structural rewrites.
        self._model = self._create_model(scratch=self._knobs.get("scratch"),
                                         num_classes=self._num_classes)
        if self._knobs.get("enable_mc_dropout"):
            self._model = update_model(self._model)

        if self._knobs.get("enable_model_slicing"):
            self._model = upgrade_dynamic_layers(
                model=self._model,
                num_groups=self._knobs.get("model_slicing_groups"),
                sr_in_list=[0.5, 0.75, 1.0])

        if self._knobs.get("enable_gm_prior_regularization"):
            # Register every parameter tensor (as a CPU numpy copy) with the
            # Gaussian-mixture prior optimizer.
            self._gm_optimizer = GMOptimizer()
            for name, f in self._model.named_parameters():
                self._gm_optimizer.gm_register(
                    name,
                    f.data.cpu().numpy(),
                    model_name="PyVGG",
                    hyperpara_list=[
                        self._knobs.get("gm_prior_regularization_a"),
                        self._knobs.get("gm_prior_regularization_b"),
                        self._knobs.get("gm_prior_regularization_alpha"),
                    ],
                    gm_num=self._knobs.get("gm_prior_regularization_num"),
                    gm_lambda_ratio_value=self._knobs.get(
                        "gm_prior_regularization_lambda"),
                    uptfreq=[
                        self._knobs.get("gm_prior_regularization_upt_freq"),
                        self._knobs.get(
                            "gm_prior_regularization_param_upt_freq")
                    ])

        if self._knobs.get("enable_spl"):
            self._spl = SPL()

        train_dataset = TorchImageDataset(sa_dataset=dataset,
                                          image_scale_size=self._image_size,
                                          norm_mean=self._normalize_mean,
                                          norm_std=self._normalize_std,
                                          is_train=True)
        train_dataloader = DataLoader(train_dataset,
                                      batch_size=self._knobs.get("batch_size"),
                                      shuffle=True)

        # Set up the criterion. Targets are presumably float multi-hot/one-hot
        # tensors produced by self._transform_data (TODO confirm).
        self.train_criterion = nn.MultiLabelSoftMarginLoss()

        # Set up the optimizer over the trainable parameters only.
        optimizer_classes = {
            "adam": optim.Adam,
            "rmsprop": optim.RMSprop,
            "sgd": optim.SGD,
        }
        optimizer_name = self._knobs.get("optimizer")
        if optimizer_name not in optimizer_classes:
            raise NotImplementedError(
                "Unsupported optimizer: {!r}".format(optimizer_name))
        optimizer = optimizer_classes[optimizer_name](
            [p for p in self._model.parameters() if p.requires_grad],
            lr=self._knobs.get("lr"),
            weight_decay=self._knobs.get("weight_decay"))

        # Learning-rate scheduler: stepped once per epoch with the mean
        # training loss (see the end of the epoch loop).
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                   patience=1,
                                                   threshold=0.001,
                                                   factor=0.1)

        self._model = self._model.to(self.device)

        self._model.train()

        if self._knobs.get("enable_model_slicing"):
            sr_scheduler = create_sr_scheduler(
                scheduler_type=self._knobs.get("model_slicing_scheduler_type"),
                sr_rand_num=self._knobs.get("model_slicing_randnum"),
                sr_list=[0.5, 0.75, 1.0],
                sr_prob=None)
        utils.logger.define_plot('Loss Over Epochs',
                                 ['loss', 'epoch_accuracy'],
                                 x_axis='epoch')
        utils.logger.log(loss=0.0, epoch_accuracy=0.0, epoch=0)
        max_epochs = self._knobs.get("max_epochs")
        for epoch in range(1, max_epochs + 1):
            print("Epoch {}/{}".format(epoch, max_epochs))
            batch_accuracy = []
            batch_losses = []
            for batch_idx, (raw_indices, traindata,
                            batch_classes) in enumerate(train_dataloader):
                inputs, labels = self._transform_data(traindata,
                                                      batch_classes,
                                                      train=True)
                optimizer.zero_grad()
                if self._knobs.get("enable_model_slicing"):
                    # Accumulate gradients over every slice rate chosen for
                    # this step; a single optimizer.step() follows.
                    for sr_idx in next(sr_scheduler):
                        self._model.update_sr_idx(sr_idx)
                        outputs = self._model(inputs)
                        trainloss = self.train_criterion(outputs, labels)
                        trainloss.backward()
                else:
                    outputs = self._model(inputs)
                    trainloss = self.train_criterion(outputs, labels)
                    trainloss.backward()
                if self._knobs.get("enable_gm_prior_regularization"):
                    for name, f in self._model.named_parameters():
                        self._gm_optimizer.apply_GM_regularizer_constraint(
                            labelnum=dataset.classes,
                            trainnum=dataset.size,
                            epoch=epoch,
                            weight_decay=self._knobs.get("weight_decay"),
                            f=f,
                            name=name,
                            step=batch_idx)

                if self._knobs.get("enable_spl"):
                    # NOTE(review): the criterion uses the default 'mean'
                    # reduction, so this passes one batch-level loss rather
                    # than a loss per sample; confirm what
                    # update_sample_score expects.
                    train_dataset.update_sample_score(
                        raw_indices,
                        trainloss.detach().cpu().numpy())
                optimizer.step()
                print("Epoch: {:d} Batch: {:d} Train Loss: {:.6f}".format(
                    epoch, batch_idx, trainloss.item()))
                sys.stdout.flush()

                # Top-1 accuracy: the predicted class is the arg-max logit.
                # Sigmoid is monotonic, so this matches the arg-max of the
                # probabilities without float32 saturation ties. The true
                # class is the arg-max of the target row.
                predicted = outputs.detach().argmax(dim=1)
                expected = labels.detach().argmax(dim=1)
                batch_accuracy.append(
                    predicted.eq(expected).sum().item() / expected.size(0))
                batch_losses.append(trainloss.item())
            train_loss = np.mean(batch_losses)
            batch_accuracy_mean = np.mean(batch_accuracy)
            utils.logger.log(loss=train_loss,
                             epoch_accuracy=batch_accuracy_mean,
                             epoch=epoch)
            print("Training Loss: {:.6f}".format(train_loss))
            # Decay the LR by `factor` once the epoch loss stops improving.
            scheduler.step(train_loss)
            if self._knobs.get("enable_spl"):
                train_dataset.update_score_threshold(
                    threshold=self._spl.calculate_threshold_by_epoch(
                        epoch=epoch,
                        threshold_init=self._knobs.get("spl_threshold_init"),
                        mu=self._knobs.get("spl_mu")))