def load_parameters(self, params):
    """
    Override BaseModel.load_parameters

    Write the base64-encoded PyTorch model state dict to a temporary file,
    then read it back with torch.load. The other persistent hyperparameters
    are recovered by setting the model's private properties.
    """
    # Load model parameters
    h5_model_base64 = params['h5_model_base64']

    self._image_size = params['image_size']
    self._normalize_mean = np.array(json.loads(params['normalize_mean']))
    self._normalize_std = np.array(json.loads(params['normalize_std']))
    self._num_classes = params['num_classes']
    self.label_mapper = json.loads(params['label_mapper'])

    with tempfile.NamedTemporaryFile() as tmp:
        # Convert back to bytes & write to temp file
        h5_model_bytes = base64.b64decode(h5_model_base64.encode('utf-8'))
        with open(tmp.name, 'wb') as f:
            f.write(h5_model_bytes)

        # Rebuild the model, then load its weights from the temp file
        self._model = self._create_model(
            scratch=self._knobs.get("scratch"),
            num_classes=self._num_classes)
        if self._knobs.get("enable_mc_dropout"):
            self._model = update_model(self._model)
        if self._knobs.get("enable_model_slicing"):
            self._model = upgrade_dynamic_layers(
                model=self._model,
                num_groups=self._knobs.get("model_slicing_groups"),
                sr_in_list=[0.5, 0.75, 1.0])

        if not torch.cuda.is_available():
            print('GPU is not available. Mapping model parameter storage to CPU')
            self._model.load_state_dict(
                torch.load(tmp.name, map_location=torch.device('cpu')))
        else:
            print('GPU is available. Mapping model parameter storage to GPU')
            self._model.load_state_dict(torch.load(tmp.name))

    if self._knobs.get("enable_label_adaptation"):
        self._label_drift_adapter = LabelDriftAdapter(
            model=self._model, num_classes=self._num_classes)
        self._label_drift_adapter.load_parameters(
            params=params[self._label_drift_adapter.get_mod_name()])
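# Note: load_parameters assumes that the matching dump_parameters serialised the
# state dict symmetrically. The actual dump method is not shown in this section;
# the commented sketch below is an assumption, using only the keys that
# load_parameters reads back:
#
#     with tempfile.NamedTemporaryFile() as tmp:
#         torch.save(self._model.state_dict(), tmp.name)
#         with open(tmp.name, 'rb') as f:
#             h5_model_bytes = f.read()
#     params['h5_model_base64'] = base64.b64encode(h5_model_bytes).decode('utf-8')
#     params['image_size'] = self._image_size
#     params['normalize_mean'] = json.dumps(self._normalize_mean.tolist())
#     params['normalize_std'] = json.dumps(self._normalize_std.tolist())
#     params['num_classes'] = self._num_classes
#     params['label_mapper'] = json.dumps(self.label_mapper)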
def train(self, dataset_path: str, shared_params: Optional[Params] = None, **train_args):
    """
    Override BaseModel.train()

    Train the model with the dataset at the given path.

    Parameters:
        dataset_path (str): path to the training dataset
        shared_params: optional shared parameters
        **train_args: optional training arguments

    Returns:
        nothing
    """
    torch.manual_seed(self._knobs.get("seed"))

    dataset = utils.dataset.load_dataset_of_image_files(
        dataset_path,
        min_image_size=32,
        max_image_size=self._knobs.get("max_image_size"),
        mode='RGB')

    self._normalize_mean, self._normalize_std = dataset.get_stat()
    # self._normalize_mean = [0.48233507, 0.48233507, 0.48233507]
    # self._normalize_std = [0.07271624, 0.07271624, 0.07271624]

    self._num_classes = dataset.classes
    self.label_mapper = dataset.label_mapper

    # Construct the model
    self._model = self._create_model(
        scratch=self._knobs.get("scratch"),
        num_classes=self._num_classes)
    if self._knobs.get("enable_mc_dropout"):
        self._model = update_model(self._model)
    if self._knobs.get("enable_model_slicing"):
        self._model = upgrade_dynamic_layers(
            model=self._model,
            num_groups=self._knobs.get("model_slicing_groups"),
            sr_in_list=[0.5, 0.75, 1.0])
    if self._knobs.get("enable_gm_prior_regularization"):
        self._gm_optimizer = GMOptimizer()
        for name, f in self._model.named_parameters():
            self._gm_optimizer.gm_register(
                name,
                f.data.cpu().numpy(),
                model_name="PyVGG",
                hyperpara_list=[
                    self._knobs.get("gm_prior_regularization_a"),
                    self._knobs.get("gm_prior_regularization_b"),
                    self._knobs.get("gm_prior_regularization_alpha"),
                ],
                gm_num=self._knobs.get("gm_prior_regularization_num"),
                gm_lambda_ratio_value=self._knobs.get("gm_prior_regularization_lambda"),
                uptfreq=[
                    self._knobs.get("gm_prior_regularization_upt_freq"),
                    self._knobs.get("gm_prior_regularization_param_upt_freq")
                ])
    if self._knobs.get("enable_spl"):
        self._spl = SPL()

    train_dataset = TorchImageDataset(
        sa_dataset=dataset,
        image_scale_size=self._image_size,
        norm_mean=self._normalize_mean,
        norm_std=self._normalize_std,
        is_train=True)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=self._knobs.get("batch_size"),
        shuffle=True)

    # Set up criterion
    # print("self._num_classes is : ", self._num_classes)
    self.train_criterion = nn.MultiLabelSoftMarginLoss()

    # Set up optimizer
    if self._knobs.get("optimizer") == "adam":
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, self._model.parameters()),
            lr=self._knobs.get("lr"),
            weight_decay=self._knobs.get("weight_decay"))
    elif self._knobs.get("optimizer") == "rmsprop":
        optimizer = optim.RMSprop(
            filter(lambda p: p.requires_grad, self._model.parameters()),
            lr=self._knobs.get("lr"),
            weight_decay=self._knobs.get("weight_decay"))
    elif self._knobs.get("optimizer") == "sgd":
        optimizer = optim.SGD(
            filter(lambda p: p.requires_grad, self._model.parameters()),
            lr=self._knobs.get("lr"),
            weight_decay=self._knobs.get("weight_decay"))
    else:
        raise NotImplementedError()

    # Set up learning rate scheduler
    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=1, threshold=0.001, factor=0.1)

    self._model = self._model.to(self.device)
    self._model.train()

    if self._knobs.get("enable_model_slicing"):
        sr_scheduler = create_sr_scheduler(
            scheduler_type=self._knobs.get("model_slicing_scheduler_type"),
            sr_rand_num=self._knobs.get("model_slicing_randnum"),
            sr_list=[0.5, 0.75, 1.0],
            sr_prob=None)

    utils.logger.define_plot('Loss Over Epochs', ['loss', 'epoch_accuracy'],
                             x_axis='epoch')
    utils.logger.log(loss=0.0, epoch_accuracy=0.0, epoch=0)

    for epoch in range(1, self._knobs.get("max_epochs") + 1):
        print("Epoch {}/{}".format(epoch, self._knobs.get("max_epochs")))

        batch_accuracy = []
        batch_losses = []

        for batch_idx, (raw_indices, traindata,
                        batch_classes) in enumerate(train_dataloader):
            print("Got batch_idx and batchdata", batch_idx)
            inputs, labels = self._transform_data(traindata,
                                                  batch_classes,
                                                  train=True)

            print("zero the optimizer")
            optimizer.zero_grad()

            if self._knobs.get("enable_model_slicing"):
                for sr_idx in next(sr_scheduler):
                    self._model.update_sr_idx(sr_idx)
                    outputs = self._model(inputs)
                    trainloss = self.train_criterion(outputs, labels)
                    trainloss.backward()
            else:
                # inputs shape, e.g. torch.Size([256, 3, 128, 128])
                outputs = self._model(inputs)
                trainloss = self.train_criterion(outputs, labels)
                print("doing backward")
                trainloss.backward()

            if self._knobs.get("enable_gm_prior_regularization"):
                for name, f in self._model.named_parameters():
                    self._gm_optimizer.apply_GM_regularizer_constraint(
                        labelnum=dataset.classes,
                        trainnum=dataset.size,
                        epoch=epoch,
                        weight_decay=self._knobs.get("weight_decay"),
                        f=f,
                        name=name,
                        step=batch_idx)

            if self._knobs.get("enable_spl"):
                train_dataset.update_sample_score(
                    raw_indices, trainloss.detach().cpu().numpy())

            optimizer.step()

            print("Epoch: {:d} Batch: {:d} Train Loss: {:.6f}".format(
                epoch, batch_idx, trainloss.item()))
            sys.stdout.flush()

            transferred_labels = torch.max(labels.data, 1)
            transferred_outputs = torch.max(torch.sigmoid(outputs), 1)
            batch_accuracy.append(
                transferred_labels[1].eq(transferred_outputs[1]).sum().item() /
                transferred_labels[1].size(0))
            batch_losses.append(trainloss.item())

        train_loss = np.mean(batch_losses)
        batch_accuracy_mean = np.mean(batch_accuracy)
        utils.logger.log(loss=train_loss,
                         epoch_accuracy=batch_accuracy_mean,
                         epoch=epoch)
        print("Training Loss: {:.6f}".format(train_loss))

        if self._knobs.get("enable_spl"):
            train_dataset.update_score_threshold(
                threshold=self._spl.calculate_threshold_by_epoch(
                    epoch=epoch,
                    threshold_init=self._knobs.get("spl_threshold_init"),
                    mu=self._knobs.get("spl_mu")))
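# Example usage (sketch): how these methods are typically driven end to end. The
# class name, dataset path, and knob values below are hypothetical placeholders,
# not part of this module:
#
#     model = ThisImageClassificationModel(seed=0, max_epochs=10, batch_size=256,
#                                          optimizer='adam', lr=1e-3,
#                                          weight_decay=1e-4, max_image_size=128)
#     model.train('data/train_images.zip')
#     params = model.dump_parameters()
#
#     restored = ThisImageClassificationModel(**same_knobs)
#     restored.load_parameters(params)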