def load_model(config):
    """Build the model described by ``config`` and load its checkpoint weights.

    Args:
        config: Configuration object. Reads ``config.model.name``,
            ``config.model.num_classes`` and ``config.model.checkpoint``
            (path to a file holding a dict with a ``'state_dict'`` entry).

    Returns:
        The model, moved to CUDA when a GPU is available, with weights
        loaded from the checkpoint. ``model.eval()`` is NOT called here.
    """
    model = choice_model(config.model.name, config.model.num_classes)
    if torch.cuda.is_available():
        model = model.cuda()
    # Deserialize onto the CPU: load_state_dict copies the tensors into the
    # model's existing parameters on whatever device they live on. This also
    # lets a checkpoint saved on a GPU be loaded on a CPU-only machine (or on
    # a machine with a different GPU index).
    checkpoint = torch.load(config.model.checkpoint, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    return model
def main(checkpoint='/home/du/Desktop/my_project/pytorch_classification/logs/11-13_17:32_efficientnet-b4_c254/50.pth',
         model_name='efficientnet-b4',
         n_class=254,
         wts_path='eff-b4.wts',
         device='cuda:0'):
    """Load a trained checkpoint, sanity-check it, and dump its weights as text.

    The model is built with ``choice_model``, restored with ``resume_custom``,
    run once on an all-ones input, summarized, and finally every entry of its
    ``state_dict`` is written to ``wts_path`` (see ``_write_wts``).

    Args:
        checkpoint: Path handed to ``resume_custom``.
        model_name: Architecture name handed to ``choice_model``.
        n_class: Number of output classes handed to ``choice_model``.
        wts_path: Output text file for the serialized weights.
        device: Torch device string the model and test input are moved to.
    """
    model = choice_model(model_name, n_class)
    net = resume_custom(checkpoint, model)
    net.to(device).eval()
    print('model: ', net)

    # Smoke-test a single forward pass before exporting.
    tmp = torch.ones(1, 3, 224, 224).to(device)
    print('input: ', tmp)
    out = net(tmp)
    print('output:', out)
    summary(net, (3, 224, 224))

    _write_wts(net.state_dict(), wts_path)


def _write_wts(state_dict, path):
    """Serialize ``state_dict`` to ``path`` as a line-oriented text file.

    Format: the first line holds the number of entries. Each following line is
    ``<name> <count>`` followed by ``count`` space-separated 8-digit hex words,
    one per element, each being that element's big-endian float32 encoding.
    NOTE(review): presumably consumed by a TensorRT-style ``.wts`` loader —
    confirm against the consumer before changing the format.
    """
    with open(path, 'w') as f:
        f.write("{}\n".format(len(state_dict)))
        for name, tensor in state_dict.items():
            print('key: ', name)
            print('value: ', tensor.shape)
            values = tensor.reshape(-1).cpu().numpy()
            # Build the whole line at once instead of two writes per element.
            hex_words = "".join(" " + struct.pack(">f", float(v)).hex() for v in values)
            f.write("{} {}{}\n".format(name, len(values), hex_words))
def load_model(model_name='efficientnet-b7', num_classes=15,
               checkpoint='../logs/2020-07-28 10:40:25/23.pth'):
    """Build a model, load its checkpoint weights, and switch it to eval mode.

    Args:
        model_name: Architecture name handed to ``choice_model``.
        num_classes: Number of output classes handed to ``choice_model``.
        checkpoint: Path to a file holding a dict with a ``'state_dict'`` entry.
            Defaults to the previously hard-coded relative path.

    Returns:
        The model, moved to CUDA when a GPU is available, with weights loaded
        and ``eval()`` already called.
    """
    model = choice_model(model_name, num_classes)
    if torch.cuda.is_available():
        model = model.cuda()
    # Deserialize onto the CPU so GPU-saved checkpoints also load on CPU-only
    # machines; load_state_dict copies into the model's own parameter devices.
    state = torch.load(checkpoint, map_location='cpu')
    model.load_state_dict(state['state_dict'])
    model.eval()
    return model
def get_model(config):
    """Construct the model described by ``config`` and move it to ``config.device``.

    Args:
        config: Configuration object. Reads ``config.model.name``,
            ``config.model.num_classes``, ``config.model.custom_pretrain`` and
            ``config.device``.

    Returns:
        The model, optionally restored via ``resume_custom`` when
        ``config.model.custom_pretrain`` is truthy, placed on ``config.device``.
    """
    model_cfg = config.model
    net = choice_model(model_cfg.name, model_cfg.num_classes)

    if model_cfg.custom_pretrain:
        # NOTE(review): other call sites of resume_custom seen alongside this
        # one pass a checkpoint path as the first argument, not the config —
        # confirm which resume_custom signature this module imports.
        net = resume_custom(config, net)

    # NOTE(review): this prints only when get_rank() is truthy. If get_rank()
    # returns 0 for the main process (the common convention), model info is
    # never printed on rank 0 or in single-process runs — confirm intent.
    if get_rank():
        get_model_info(net, model_cfg.name)

    target = torch.device(config.device)
    net.to(target)
    return net
# ans.pop('model.fc.weight') # ans.pop('model.fc.bias') model_dict = model.state_dict() model_dict.update(ans) model.load_state_dict(model_dict) return model if __name__ == "__main__": checkpoint = '/home/du/Desktop/my_project/pytorch_classification/logs/resnest50_c15/18.pth' model_name = 'resnest50' onnx_model_name = 'resnest50.onnx' n_class = 15 model = choice_model(model_name, n_class) model = resume_custom(checkpoint, model) if torch.cuda.is_available(): model = model.cuda() model.eval() x = torch.randn(1, 3, 224, 224, requires_grad=True).cuda() # Export the model torch.onnx.export(model, # model being run x, # model input (or a tuple for multiple inputs) onnx_model_name, # where to save the model (can be a file or file-like object) export_params=True, # store the trained parameter weights inside the model file opset_version=11, # the ONNX version to export the model to