def __init__(self, config, train, val):
    """Set up a training run: model, datasets, training components and an
    experiment directory.

    Args:
        config: experiment configuration dict. It is passed to
            ``validate_config`` and then read for ``model``, ``experiment``,
            ``train.transform.size``, ``num_epochs``, ``batch_size``,
            ``early_stopping`` and ``training_monitor``, plus the optional keys
            ``warmup_num`` (default 0), ``steps_per_epoch``,
            ``validation_steps`` and ``sample_input_shape``
            (default ``(1, 3, 32, 32)``).
        train: training dataset; ``len(train)`` is the default
            ``steps_per_epoch``.
        val: validation dataset; ``len(val)`` is the default
            ``validation_steps``.

    Side effects:
        * creates the experiment directory (a relative path, so it lands in
          the current working directory);
        * attaches a ``logging.FileHandler`` writing
          ``<experiment dir>/experiment.log`` to the module-level ``logger``
          and keeps it as ``self.log_handler`` so callers can remove/close it
          when the run is over;
        * writes ``config`` to ``<experiment dir>/config.yaml``.
    """
    validate_config(config)
    self.config = config
    self.model = get_net(config["model"])
    self.train_dataset = train
    self.val_dataset = val
    self.warmup_epochs = config.get("warmup_num", 0)

    # The directory name encodes the main hyper-parameters plus the current
    # Unix timestamp (whole seconds), so repeated runs get distinct folders.
    name_parts = [
        "experiment",
        self.config["experiment"]["folder"],
        self.config["experiment"]["name"],
        self.config["train"]["transform"]["size"],
        self.config["model"]["arch"],
        self.config["num_epochs"],
        self.config["batch_size"],
        int(datetime.timestamp(datetime.now())),
    ]
    self.experiment_base_dir = "_".join(map(str, name_parts))
    os.makedirs(self.experiment_base_dir, exist_ok=True)

    self.metric_counter = get_metric_counter(config, self.experiment_base_dir)
    self.steps_per_epoch = config.get("steps_per_epoch", len(self.train_dataset))
    self.validation_steps = config.get("validation_steps", len(self.val_dataset))
    self.sample_input_shape = config.get("sample_input_shape", (1, 3, 32, 32))

    self.criterion = get_loss(self.config["model"]["loss"])
    # Only parameters that require gradients are handed to the optimizer.
    self.optimizer = get_optimizer(
        config=config,
        params=filter(lambda p: p.requires_grad, self.model.parameters()),
    )
    self.scheduler = get_scheduler(config, self.optimizer)
    self.early_stopping = EarlyStopping(patience=self.config["early_stopping"])
    self.model_adapter = get_model_adapter(self.config["model"])
    self.monitor = TrainingMonitor(config["training_monitor"]["method"],
                                   config["training_monitor"]["interval"])

    # Mirror log records into the experiment directory. The handler is kept on
    # the instance because ``logger`` outlives this Trainer; without a
    # reference the handler (and its open file) could never be detached.
    self.log_handler = logging.FileHandler(
        osp.join(self.experiment_base_dir, "experiment.log"))
    logger.addHandler(self.log_handler)

    # Keep a copy of the exact configuration next to the run's artifacts.
    with open(osp.join(self.experiment_base_dir, "config.yaml"), "w") as outfile:
        yaml.dump(self.config, outfile, default_flow_style=False)
def test_dice_loss(tensor1, tensor2, tensor3, tensor4):
    """Compare ``dice_loss`` outputs with expected values for the fixture pairs."""
    dice_loss = loss.get_loss({"name": "dice_loss"})
    # Each entry: (first argument, second argument, expected loss value).
    expectations = (
        (tensor1, tensor1, 0.71844566),
        (tensor2, tensor2, 0.71844566),
        (tensor1, tensor2, 0.54923755),
        (tensor3, tensor4, 0.57556236),
    )
    for left, right, expected in expectations:
        observed = dice_loss.forward(left, right).data.numpy()
        assert observed == pytest.approx(expected, standard_approx)
def test_mixed_loss(tensor1, tensor2, tensor3, tensor4):
    """Compare ``mixed_loss`` outputs with expected values for the fixture pairs."""
    mixed_loss = loss.get_loss({"name": "mixed_loss"})
    # Each entry: (first argument, second argument, expected loss value).
    expectations = (
        (tensor1, tensor1, 1.3103894),
        (tensor2, tensor2, 1.3103894),
        (tensor1, tensor2, 4.975),
        (tensor3, tensor4, 3.230158),
    )
    for left, right, expected in expectations:
        observed = mixed_loss.forward(left, right).data.numpy()
        assert observed == pytest.approx(expected, standard_approx)
def test_iou_loss(tensor1, tensor2, tensor3, tensor4):
    """Check ``iou_loss.iou_metric`` on identical, near-disjoint and mixed pairs.

    Every result is converted with ``.data.numpy()`` and compared through
    ``pytest.approx`` so the checks do not depend on bit-exact floating-point
    results, matching the other loss tests.
    """
    iou_loss = loss.get_loss({"name": "iou_loss"})

    # Identical inputs: IoU of 1.
    assert iou_loss.iou_metric(tensor1, tensor1).data.numpy() == pytest.approx(
        1.0, standard_approx)
    assert iou_loss.iou_metric(tensor2, tensor2).data.numpy() == pytest.approx(
        1.0, standard_approx)

    # Almost zero. A relative tolerance is meaningless around 0, so the
    # tolerance is applied as an absolute bound.
    assert iou_loss.iou_metric(tensor1, tensor2).data.numpy() == pytest.approx(
        0.0, abs=standard_approx)

    assert iou_loss.iou_metric(tensor3, tensor4).data.numpy() == pytest.approx(
        0.3333333, standard_approx)
def _init_params(self):
    """Build the training components derived from ``self.config`` and the model.

    Attributes set:
        criterion: loss created by ``get_loss(config['model']['loss'])``.
        optimizer: result of ``self._get_optim`` applied to the model
            parameters that have ``requires_grad`` set.
        scheduler: result of ``self._get_scheduler`` for that optimizer.
        early_stopping: ``EarlyStopping`` with
            ``patience=config['early_stopping']``; intended to stop training
            when the validation loss no longer improves within that patience.
        model_adapter: adapter returned by ``get_model_adapter(config)``.

    Finally, the directory ``<experiment.folder>/<experiment.name>`` is
    created if it does not already exist.
    """
    cfg = self.config

    self.criterion = get_loss(cfg['model']['loss'])

    trainable_params = filter(lambda p: p.requires_grad, self.model.parameters())
    self.optimizer = self._get_optim(trainable_params)
    self.scheduler = self._get_scheduler(self.optimizer)

    self.early_stopping = EarlyStopping(patience=cfg['early_stopping'])
    self.model_adapter = get_model_adapter(cfg)

    run_dir = osp.join(cfg['experiment']['folder'], cfg['experiment']['name'])
    os.makedirs(run_dir, exist_ok=True)
def test_getting_incorrect_loss():
    """``get_loss`` must raise ``ValueError`` for an unknown loss name."""
    with pytest.raises(ValueError):
        # No ``assert`` around the call: it is expected to raise, so a return
        # value would never be checked. If it does not raise,
        # ``pytest.raises`` fails the test with "DID NOT RAISE".
        loss.get_loss({"name": "this_is_some_incorrect_name_for_loss_function"})
def test_getting_correct_loss():
    """Requesting the known loss name "lovasz" returns a truthy loss object."""
    criterion = loss.get_loss({"name": "lovasz"})
    assert criterion