def build(self):
    """Assemble the configured network together with its training components.

    Steps, in order:

    1. If ``cf.pretrained_model`` is ``'custom'`` (case-insensitive) and
       ``cf.load_weight_only`` is false, the whole model is obtained from
       ``self.restore_model()``, stored in ``self.net`` and returned
       immediately. Nothing else is built on that path.
    2. Otherwise the architecture named by ``cf.model_type``
       (case-insensitive) is instantiated and moved to CUDA. The SSD
       variants additionally get ``self.box_coder``.
    3. Weights are restored from ``cf.input_model_path`` when resuming an
       experiment or when only custom weights are requested; otherwise
       the network loads its basic pretrained weights if
       ``self.net.pretrained`` is truthy, else it initialises them.
    4. The loss (only when ``self.loss`` is still ``None``), the optimizer
       and the learning-rate scheduler are created.

    Returns:
        The restored object on the whole-model path of step 1; ``None``
        in every other case.

    Raises:
        ValueError: if ``cf.model_type`` names no supported network.
    """
    cfg = self.cf

    # Whole-model restore: hand the loaded object back straight away.
    if cfg.pretrained_model.lower() == 'custom' and not cfg.load_weight_only:
        self.net = self.restore_model()
        return self.net

    # One lazy constructor per supported architecture. The lambdas keep
    # every class lookup and configuration read deferred until the
    # selected entry is actually called.
    constructors = {
        # segmentation networks
        'densenetfcn': lambda: FCDenseNet(
            cfg,
            nb_layers_per_block=cfg.model_layers,
            growth_rate=cfg.model_growth,
            nb_dense_block=cfg.model_blocks,
            n_channel_start=48,
            n_classes=cfg.num_classes,
            drop_rate=0,
            bottle_neck=False),
        'fcn8': lambda: FCN8(
            cfg, num_classes=cfg.num_classes,
            pretrained=cfg.basic_pretrained_model),
        'fcn8atonce': lambda: FCN8AtOnce(
            cfg, num_classes=cfg.num_classes,
            pretrained=cfg.basic_pretrained_model),
        'deeplabv3plus': lambda: DeepLabv3_plus(
            cfg, n_classes=cfg.num_classes,
            pretrained=cfg.basic_pretrained_model),
        'deeplabv3xception': lambda: DeepLabv3_xception(
            cfg, n_classes=cfg.num_classes,
            pretrained=cfg.basic_pretrained_model),
        'deeplabv2': lambda: MS_Deeplab(
            cfg, n_classes=cfg.num_classes,
            pretrained=cfg.basic_pretrained_model),
        # object detection networks
        'ssd320': lambda: SSD300(
            cfg, num_classes=cfg.num_classes,
            pretrained=cfg.basic_pretrained_model),
        'ssd512': lambda: SSD512(
            cfg, num_classes=cfg.num_classes,
            pretrained=cfg.basic_pretrained_model),
        # classification networks
        'vgg16': lambda: VGG16(
            cfg, num_classes=cfg.num_classes,
            pretrained=cfg.basic_pretrained_model),
    }

    arch = cfg.model_type.lower()
    if arch not in constructors:
        raise ValueError('Unknown model')
    self.net = constructors[arch]().cuda()
    if arch in ('ssd320', 'ssd512'):
        # The detection heads need a box coder built from the network.
        self.box_coder = SSDBoxCoder(self.net)

    # Decide where the weights come from.
    restore_from_disk = cfg.resume_experiment or (
        cfg.pretrained_model.lower() == 'custom' and cfg.load_weight_only)
    if not restore_from_disk:
        if self.net.pretrained:
            self.net.load_basic_weights()
        else:
            self.net.initialize_weights()
    else:
        self.net.restore_weights(os.path.join(cfg.input_model_path))
        if cfg.resume_experiment:
            # Resuming also brings back the best statistics seen so far.
            self.load_statistics()

    # Loss definition (an already assigned loss is left untouched)
    if self.loss is None:
        self.loss = Loss_Builder(cfg).build().cuda()

    # Optimizer definition
    self.optimizer = Optimizer_builder().build(cfg, self.net)

    # Learning rate scheduler
    self.scheduler = scheduler_builder().build(cfg, self.optimizer)
class Model_builder():
    """Create, checkpoint and restore the network selected by a configuration.

    Instances hold the network (``net``), loss, optimizer, scheduler and the
    best statistics seen so far (``best_stats``). ``box_coder`` is only set
    by ``build()`` for the SSD model types.
    """

    # Supported ``cf.save_condition`` values (lower-cased) mapped to
    # (statistics split, metric attribute, True when a larger value is better).
    _SAVE_CRITERIA = {
        'train_loss': ('train', 'loss', False),
        'valid_loss': ('val', 'loss', False),
        'valid_miou': ('val', 'mIoU', True),
        'valid_macc': ('val', 'acc', True),
        'precision': ('val', 'precision', True),
        'recall': ('val', 'recall', True),
        'f1_score': ('val', 'f1score', True),
    }

    # Keys copied, in this order, from a saved statistics dict by
    # ``fill_statistics``.
    _STAT_FIELDS = (
        'loss', 'mIoU', 'acc', 'precision', 'recall', 'f1score', 'conf_m',
        'mIoU_perclass', 'acc_perclass', 'precision_perclass',
        'recall_perclass', 'f1score_perclass',
    )

    def __init__(self, cf):
        """Store the configuration ``cf`` and start with nothing built."""
        self.cf = cf
        self.net = None
        self.loss = None
        self.optimizer = None
        self.scheduler = None
        self.best_stats = Statistics()

    def build(self):
        """Build the network, its weights, loss, optimizer and scheduler.

        If ``cf.pretrained_model`` is ``'custom'`` (case-insensitive) and
        ``cf.load_weight_only`` is false, the whole model is loaded through
        ``restore_model()`` and returned immediately; loss, optimizer and
        scheduler are not created on that path.

        Otherwise the architecture named by ``cf.model_type`` is created on
        CUDA, its weights are restored / loaded / initialised, and the loss
        (only if ``self.loss`` is still ``None``), optimizer and scheduler
        are built.

        Returns:
            The restored object on the whole-model path, ``None`` otherwise.

        Raises:
            ValueError: if ``cf.model_type`` names no supported network.
        """
        # custom model
        if self.cf.pretrained_model.lower() == 'custom' and not self.cf.load_weight_only:
            self.net = self.restore_model()
            return self.net

        self._create_network()
        self._load_or_init_weights()

        # Loss definition
        if self.loss is None:
            self.loss = Loss_Builder(self.cf).build().cuda()

        # Optimizer definition
        self.optimizer = Optimizer_builder().build(self.cf, self.net)

        # Learning rate scheduler
        self.scheduler = scheduler_builder().build(self.cf, self.optimizer)

    def _create_network(self):
        """Instantiate the architecture named by ``cf.model_type`` on CUDA."""
        cf = self.cf
        model_type = cf.model_type.lower()

        # segmentation networks
        if model_type == 'densenetfcn':
            net = FCDenseNet(cf, nb_layers_per_block=cf.model_layers,
                             growth_rate=cf.model_growth,
                             nb_dense_block=cf.model_blocks,
                             n_channel_start=48, n_classes=cf.num_classes,
                             drop_rate=0, bottle_neck=False)
        elif model_type == 'fcn8':
            net = FCN8(cf, num_classes=cf.num_classes,
                       pretrained=cf.basic_pretrained_model)
        elif model_type == 'fcn8atonce':
            net = FCN8AtOnce(cf, num_classes=cf.num_classes,
                             pretrained=cf.basic_pretrained_model)
        elif model_type == 'deeplabv3plus':
            net = DeepLabv3_plus(cf, n_classes=cf.num_classes,
                                 pretrained=cf.basic_pretrained_model)
        elif model_type == 'deeplabv3xception':
            net = DeepLabv3_xception(cf, n_classes=cf.num_classes,
                                     pretrained=cf.basic_pretrained_model)
        elif model_type == 'deeplabv2':
            net = MS_Deeplab(cf, n_classes=cf.num_classes,
                             pretrained=cf.basic_pretrained_model)
        # object detection networks
        elif model_type == 'ssd320':
            net = SSD300(cf, num_classes=cf.num_classes,
                         pretrained=cf.basic_pretrained_model)
        elif model_type == 'ssd512':
            net = SSD512(cf, num_classes=cf.num_classes,
                         pretrained=cf.basic_pretrained_model)
        # classification networks
        elif model_type == 'vgg16':
            net = VGG16(cf, num_classes=cf.num_classes,
                        pretrained=cf.basic_pretrained_model)
        else:
            raise ValueError('Unknown model')

        self.net = net.cuda()
        if model_type in ('ssd320', 'ssd512'):
            # The box coder is built from the network once it is on CUDA.
            self.box_coder = SSDBoxCoder(self.net)

    def _load_or_init_weights(self):
        """Restore saved weights, load basic pretrained ones, or initialise."""
        cf = self.cf
        if cf.resume_experiment or (cf.pretrained_model.lower() == 'custom'
                                    and cf.load_weight_only):
            self.net.restore_weights(os.path.join(cf.input_model_path))
            if cf.resume_experiment:
                # Resuming also brings back the best statistics seen so far.
                self.load_statistics()
        elif self.net.pretrained:
            self.net.load_basic_weights()
        else:
            self.net.initialize_weights()

    def save_model(self):
        """Write ``<cf.model_name>.pth`` according to ``cf.save_weight_only``.

        With ``cf.save_weight_only`` set, only ``self.net.state_dict()`` is
        written under ``cf.output_model_path``; otherwise this whole builder
        object is written under ``cf.exp_folder``.
        """
        file_name = self.cf.model_name + '.pth'
        if self.cf.save_weight_only:
            torch.save(self.net.state_dict(),
                       os.path.join(self.cf.output_model_path, file_name))
        else:
            # NOTE(review): this pickles the Model_builder itself (not just
            # self.net) and targets cf.exp_folder, whereas restore_model()
            # reads from cf.input_model_path and assigns the loaded object to
            # self.net -- confirm the asymmetry is intended.
            torch.save(self, os.path.join(self.cf.exp_folder, file_name))

    def save(self, stats):
        """Save the model when ``cf.save_condition`` says so.

        ``'always'`` (matched case-insensitively, like every other condition
        handled by ``check_stat``) saves unconditionally; otherwise the model
        is saved only if ``check_stat(stats)`` reports an improvement. After
        a save, ``best_stats`` becomes a deep copy of ``stats``.

        Returns:
            Whether the model was saved.
        """
        if self.cf.save_condition.lower() == 'always':
            should_save = True
        else:
            should_save = self.check_stat(stats)
        if should_save:
            self.save_model()
            self.best_stats = copy.deepcopy(stats)
        return should_save

    def check_stat(self, stats):
        """Tell whether ``stats`` beats ``best_stats`` on the configured metric.

        The metric is chosen by ``cf.save_condition`` (case-insensitive).
        Losses improve when strictly smaller, every other metric when
        strictly larger. An unrecognised condition yields ``False``.
        """
        criterion = self._SAVE_CRITERIA.get(self.cf.save_condition.lower())
        if criterion is None:
            return False
        split, metric, higher_is_better = criterion
        candidate = getattr(getattr(stats, split), metric)
        best = getattr(getattr(self.best_stats, split), metric)
        if higher_is_better:
            return bool(candidate > best)
        return bool(candidate < best)

    def restore_model(self):
        """Load and return the object stored in ``<cf.model_name>.pth``.

        The file is looked up under ``cf.input_model_path``.
        """
        model_path = os.path.join(self.cf.input_model_path,
                                  self.cf.model_name + '.pth')
        # Report the file that is actually opened.
        print('\t Restoring weight from ' + model_path)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        net = torch.load(model_path)
        return net

    def load_statistics(self):
        """Fill ``best_stats`` from ``cf.best_json_file`` when that file exists.

        The JSON document is indexed as a sequence: element 0 supplies
        ``epoch`` and the train statistics, element 1 the validation ones.
        A missing file leaves ``best_stats`` untouched.
        """
        if not os.path.exists(self.cf.best_json_file):
            return
        with open(self.cf.best_json_file) as json_file:
            json_data = json.load(json_file)
        self.best_stats.epoch = json_data[0]['epoch']
        self.best_stats.train = self.fill_statistics(json_data[0],
                                                     self.best_stats.train)
        self.best_stats.val = self.fill_statistics(json_data[1],
                                                   self.best_stats.val)

    def fill_statistics(self, dict_stats, stats):
        """Copy every known metric from ``dict_stats`` onto ``stats``.

        Returns:
            The same ``stats`` object, updated in place.

        Raises:
            KeyError: if ``dict_stats`` lacks one of the expected keys.
        """
        for field in self._STAT_FIELDS:
            setattr(stats, field, dict_stats[field])
        return stats