import torch

import layer

# ---------------------------------------------------------------------------
# Stage 1: pre-train the segmentation network (Segnet).
#
# NOTE(review): `Datas`, `Network`, `criterion` and `data` are not imported in
# the visible source. Presumably `Datas` is `torch.utils.data` and the other
# three are project-local modules -- confirm these imports exist in the file.
# ---------------------------------------------------------------------------

# Loss functions made available to the training code.
criterion_L1 = torch.nn.L1Loss()
criterion_MSE = torch.nn.MSELoss()
criterion_BCE = torch.nn.BCEWithLogitsLoss()
criterion_CE = criterion.crossentry()
criterion_ncc = criterion.NCC().loss
criterion_grad = criterion.Grad('l2', 2).loss
criterion_dice = criterion.DiceMeanLoss()

# Prefer the first GPU, but fall back to CPU so the script still runs on a
# machine without CUDA (the original hard-coded "cuda:0" and failed there).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE: this rebinds the name `data` from the module to its `train_data`
# attribute. The name is kept because later code may rely on `data` being the
# dataset, but the `data` module is no longer reachable under this name.
data = data.train_data
dataloder = Datas.DataLoader(dataset=data, batch_size=1, shuffle=True)

Segnet = Network.DenseBiasNet(n_channels=1, n_classes=4).to(device)
# Flownet = Network.VXm(2).to(device)
opt_seg = torch.optim.Adam(Segnet.parameters(), lr=0.0001)

# Optional warm start for Flownet (currently disabled):
# pretrained_dict = torch.load('./pkl/net_epoch_100-Flow-Network.pkl')
# model_dict = Flownet.state_dict()
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# model_dict.update(pretrained_dict)
# Flownet.load_state_dict(model_dict)

# Warm-start Segnet from a saved state dict. Only keys that Segnet also has
# are copied; map_location lets a GPU-saved checkpoint load on any device.
pretrained_dict = torch.load('./pkl/net_epoch_99-Seg-Network.pkl', map_location=device)
model_dict = Segnet.state_dict()
_skipped_keys = [k for k in pretrained_dict if k not in model_dict]
_uncovered_keys = [k for k in model_dict if k not in pretrained_dict]
if _skipped_keys:
    # Previously these were dropped silently; surface them so a renamed or
    # mismatched architecture does not go unnoticed.
    print(f"Segnet warm start: ignored {len(_skipped_keys)} checkpoint key(s) "
          f"not present in the model: {_skipped_keys}")
if _uncovered_keys:
    # These parameters keep their fresh initialization.
    print(f"Segnet warm start: {len(_uncovered_keys)} model key(s) not found "
          f"in the checkpoint: {_uncovered_keys}")
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
Segnet.load_state_dict(model_dict)