Esempio n. 1
0
 def save_model_structure(self, model):
     """Write ``str(model)`` to ``<work_dir>/<model_name>_model_structure``.

     ``work_dir`` is ``self.controller.checkpoint.work_dir`` and
     ``model_name`` is ``cls_name_trans`` applied to the model's class name.

     Args:
         model: The model to describe. ``nn.DataParallel`` and
             ``nn.parallel.DistributedDataParallel`` wrappers are unwrapped
             first, so the file is named after (and describes) the inner
             module rather than the wrapper.
     """
     # Both parallel wrappers keep the real network in ``.module``; without
     # unwrapping, the file would be named after the wrapper class.
     parallel_wrappers = (nn.DataParallel, nn.parallel.DistributedDataParallel)
     if isinstance(model, parallel_wrappers):
         model = model.module
     model_name = cls_name_trans(model.__class__.__name__)
     checkpoint = self.controller.checkpoint
     model_struct_file = checkpoint.work_dir.joinpath(f'{model_name}_model_structure')
     # Explicit encoding: the platform default (e.g. cp1252 on Windows) may
     # not be able to encode every character of a module's repr.
     with model_struct_file.open('w', encoding='utf-8') as fd:
         fd.write(str(model))
Esempio n. 2
0
 def load_cls(cls, sub_cls, sub_cls_name):
     """Look up a registered class by name in ``cls.cls_map``.

     Args:
         sub_cls: Class whose ``root_cls()`` selects the registry bucket:
             the bucket is ``cls.cls_map[root_cls().__name__]``, or
             ``cls.cls_map['PtBase']`` when ``root_cls()`` returns ``None``.
         sub_cls_name: Registered name to look up. If it contains any
             uppercase character it is first passed through
             ``cls_name_trans`` (the same normalisation ``register_cls``
             applies to the names it stores).

     Returns:
         The registered class, or ``None`` when ``sub_cls_name`` is ``None``
         or nothing is registered under that name in the chosen bucket.
     """
     if sub_cls_name is None:
         return None
     # Generator (not a list) so ``any`` stops at the first uppercase char.
     if any(char.isupper() for char in sub_cls_name):
         sub_cls_name = cls_name_trans(sub_cls_name)
     # Call root_cls() once and reuse the result to pick the bucket.
     root_cls = sub_cls.root_cls()
     bucket_name = 'PtBase' if root_cls is None else root_cls.__name__
     return cls.cls_map[bucket_name].get(sub_cls_name, None)
Esempio n. 3
0
 def default_build_optimizers(self, model):
     """Create one optimizer over the trainable parameters of ``model``.

     ``self.optim_kwargs`` is deep-copied (so the stored kwargs are never
     mutated), its ``'params'`` entry is set to the parameters whose
     ``requires_grad`` is true, and the result is passed to
     ``self.optim_cls``.

     Returns:
         ``{'optimizers': [entry], 'schedulers': None}`` where ``entry`` has
         keys ``'name'`` (``cls_name_trans`` of the model's class name),
         ``'optimizer'`` and ``'valid_fn'`` (always ``None``).
     """
     trainable = (p for p in model.parameters() if p.requires_grad)
     kwargs = copy.deepcopy(self.optim_kwargs)
     kwargs['params'] = trainable
     opt = self.optim_cls(**kwargs)
     entry = dict(
         name=cls_name_trans(model.__class__.__name__),
         optimizer=opt,
         valid_fn=None,
     )
     return dict(optimizers=[entry], schedulers=None)
Esempio n. 4
0
 def save_checkpoint(self, epoch_idx, model: nn.Module, val_acc=None):
     """Persist ``model`` for ``epoch_idx`` when ``self.need_save`` agrees.

     If the model has its own ``save_checkpoint`` attribute, saving is
     delegated to it as ``model.save_checkpoint(epoch_idx, self.save_dir,
     val_acc)`` and that call's return value is passed back. Otherwise
     ``model.state_dict()`` is written with ``torch.save`` to
     ``self.save_dir`` under the name
     ``self.file_format.format(epoch_idx=..., model_name=...)``. When
     ``self.only_best`` is set, ``self.delete_state_files(model_name)`` is
     called before the new file is written.

     Returns:
         The delegated ``save_checkpoint`` result, or ``None``.
     """
     if not self.need_save(epoch_idx, val_acc):
         return None
     if hasattr(model, 'save_checkpoint'):
         # The model handles its own checkpointing; defer to it.
         return model.save_checkpoint(epoch_idx, self.save_dir, val_acc)
     name = cls_name_trans(model.__class__.__name__)
     if self.only_best:
         # NOTE(review): old files are removed before the new one is saved;
         # a failing torch.save would leave no checkpoint - confirm intended.
         self.delete_state_files(name)
     state = model.state_dict()
     file_name = self.file_format.format(epoch_idx=epoch_idx, model_name=name)
     torch.save(state, self.save_dir.joinpath(file_name))
Esempio n. 5
0
 def register_cls(cls, sub_cls_name, sub_cls):
     """Register ``sub_cls`` in ``cls.cls_map`` under ``sub_cls_name``.

     The name is normalised with ``cls_name_trans`` and stored in the bucket
     ``cls.cls_map[<root class name>]``, where the root class is
     ``sub_cls.root_cls()``. Only ``PtBase`` may have no root; it is then
     treated as its own root. Classes whose ``is_root_cls()`` is true are
     additionally stored in the ``'PtBase'`` bucket.

     Raises:
         AssertionError: If ``sub_cls.root_cls()`` is ``None`` but
             ``sub_cls`` is not named ``PtBase``.
         KeyError: If the normalised name is already registered in the root
             class's bucket.
     """
     sub_cls_name = cls_name_trans(sub_cls_name)
     root_cls = sub_cls.root_cls()
     if root_cls is None:
         # Explicit check instead of ``assert`` so it still runs under
         # ``python -O``; AssertionError is kept for backward compatibility.
         if sub_cls.__name__ != 'PtBase':
             raise AssertionError(
                 f'{sub_cls.__name__} has no root class; only PtBase may')
         root_cls = sub_cls
     bucket = cls.cls_map[root_cls.__name__]
     if sub_cls_name in bucket:
         raise KeyError(f'Class name {sub_cls_name} has been registered')
     bucket[sub_cls_name] = sub_cls
     if sub_cls.is_root_cls():
         cls.cls_map['PtBase'][sub_cls_name] = sub_cls