Exemple #1
0
import layer

############################# Stage 1: pre-train the seg (segmentation) network
# Loss functions are instantiated once at module level. Only their
# construction is visible here; which of them are actually used is decided by
# code outside this section.

# Built-in PyTorch criteria.
criterion_L1 = torch.nn.L1Loss()
criterion_MSE = torch.nn.MSELoss()
# Operates on raw logits (applies the sigmoid internally).
criterion_BCE = torch.nn.BCEWithLogitsLoss()

# Project-defined criteria from the `criterion` module.
# NOTE(review): presumably a cross-entropy loss -- confirm in criterion.crossentry.
criterion_CE = criterion.crossentry()
# Bound `.loss` method of the instance is stored, not the instance itself.
# NOTE(review): "NCC" presumably means normalized cross-correlation -- confirm.
criterion_ncc = criterion.NCC().loss
# NOTE(review): meaning of the arguments ('l2', 2) is not visible here; it looks
# like a gradient/smoothness penalty with an L2 norm -- confirm in criterion.Grad.
criterion_grad = criterion.Grad('l2', 2).loss
# NOTE(review): presumably a Dice loss averaged over classes -- confirm.
criterion_dice = criterion.DiceMeanLoss()

# Prefer the first GPU, but fall back to the CPU so the script can still start
# on a machine without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE: this rebinds the name `data` to its own `train_data` attribute, so the
# original object is no longer reachable through `data` after this line.
data = data.train_data
# One sample per batch, reshuffled on every pass over the dataset.
# (The variable name `dataloder` is kept as-is because later code may use it.)
dataloder = Datas.DataLoader(dataset=data, batch_size=1, shuffle=True)
# Segmentation network: 1 input channel, 4 output classes.
Segnet = Network.DenseBiasNet(n_channels=1, n_classes=4).to(device)
# Flownet = Network.VXm(2).to(device)

opt_seg = torch.optim.Adam(Segnet.parameters(), lr=0.0001)

# Disabled: restoring the flow network from its checkpoint (kept for reference,
# matches the commented-out Flownet construction above).
# pretrained_dict = torch.load('./pkl/net_epoch_100-Flow-Network.pkl')
# model_dict = Flownet.state_dict()
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# model_dict.update(pretrained_dict)
# Flownet.load_state_dict(model_dict)

# Warm-start the segmentation network from a saved checkpoint.
# `map_location=device` remaps the stored tensors onto the device selected
# above, so a checkpoint saved from a different GPU (or loaded on a CPU-only
# host) can still be deserialized.
pretrained_dict = torch.load('./pkl/net_epoch_99-Seg-Network.pkl',
                             map_location=device)
model_dict = Segnet.state_dict()
# Keep only the entries whose keys exist in the current model; anything else in
# the checkpoint is ignored, and model entries missing from the checkpoint keep
# their current values.
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
Segnet.load_state_dict(model_dict)