def __init__(self, opts, device_ids, load_model=False):
    """Build the U-Net, its loss functions, and optionally restore weights.

    Args:
        opts: Options object. This method reads ``max_gpus``,
            ``load_model`` and ``output_path`` from it and forwards the
            whole object to ``Unnet.UNet``.
        device_ids: Sequence of CUDA device ids. Only consulted when more
            than one GPU is visible and ``opts.max_gpus > 1``.
        load_model: When true, restore weights from the hard-coded
            ``model_dict_0`` checkpoint instead of the one under
            ``opts.output_path``.

    Raises:
        RuntimeError: If the fallback checkpoint still does not match the
            network after its ``module.`` prefixes have been adjusted.
        OSError: If the fallback checkpoint file cannot be read.
    """
    self.opts = opts
    self.epoch_step = 0
    self.model_num = 0
    # Other backbones tried here previously: Resnet.resnet152(),
    # framelet.Framelets(), googlenet.GoogLeNet(), densenet.densenet161().
    self.network = Unnet.UNet(opts)

    if torch.cuda.device_count() > 1 and opts.max_gpus > 1:
        if len(device_ids) <= opts.max_gpus:
            # No explicit device list: DataParallel picks its own default.
            self.network = torch.nn.DataParallel(self.network)
        else:
            # Keep exactly max_gpus devices (the previous slice ended at
            # max_gpus - 1 and therefore dropped one GPU too many).
            self.network = torch.nn.DataParallel(
                self.network, device_ids=device_ids[:opts.max_gpus])
    self.network.cuda()

    # Loss functions
    self.loss_func_l1 = torch.nn.L1Loss()
    self.loss_func_MSE = torch.nn.MSELoss()
    self.MedGanloss = MedGanloss()
    self.mssim_loss = MSSSIM(window_size=9, size_average=True)
    self.loss_func_poss = torch.nn.PoissonNLLLoss()
    self.loss_func_KLDiv = torch.nn.KLDivLoss()
    self.loss_func_Smoothl1 = torch.nn.SmoothL1Loss()
    self.loss_func_part = torch.nn.L1Loss()
    self.test_loss = torch.nn.MSELoss(reduction='none')
    self.averagepool = torch.nn.AvgPool2d(3, stride=2)
    self.optim_count = 0

    # TODO2: Change the load model dict
    # `== True` is kept on purpose so non-boolean option values behave as before.
    if self.opts.load_model == True or load_model:
        print("Restoring model")
        default_path = os.path.join(self.opts.output_path, 'model',
                                    'model_dict')
        if load_model:
            checkpoint_path = '/home/liang/Desktop/output/model/model_dict_0'
        else:
            checkpoint_path = default_path
        try:
            self.network.load_state_dict(torch.load(checkpoint_path))
        except (RuntimeError, OSError):
            # RuntimeError: key mismatch, typically because the checkpoint
            # was written with/without the DataParallel wrapper.
            # OSError: the requested checkpoint file is not readable.
            # Either way, retry with the checkpoint under opts.output_path
            # and make its `module.` prefixes match the current network.
            from collections import OrderedDict
            state_dict = torch.load(default_path)
            wrapped = isinstance(self.network, torch.nn.DataParallel)
            prefix = 'module.'
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                bare = k[len(prefix):] if k.startswith(prefix) else k
                new_state_dict[prefix + bare if wrapped else bare] = v
            self.network.load_state_dict(new_state_dict)
def __init__(self, shape1, shape2):
    """Create the SETLayer network and the loss objects used with it.

    ``shape1`` and ``shape2`` are accepted but not read here; the layer
    size comes from the module-level ``wl`` (``wl * wl`` in and out).
    """
    # An earlier variant built Net(shape1, shape2) and moved it to CUDA.
    n_features = wl * wl
    self.network = SETLayer(n_features, n_features)

    losses = torch.nn
    self.loss_func = losses.L1Loss(reduction='mean')
    self.mssim_loss = MSSSIM(window_size=11, size_average=True)
    self.loss_func_MSE = losses.MSELoss()

    # Counter; only initialised to zero here.
    self.model_num = 0
def __init__(self, shape1, shape2):
    """Create a CUDA-resident ``unet.UNet`` and the loss objects used with it.

    ``shape1`` and ``shape2`` are accepted but not read here (the UNet is
    constructed without arguments).
    """
    net = unet.UNet()
    self.network = net
    self.network.cuda()

    losses = torch.nn
    self.loss_func = losses.L1Loss(reduction='mean')
    self.mssim_loss = MSSSIM(window_size=3, size_average=True)
    # Attribute name (including its spelling) is kept as-is for callers.
    self.crossengropy = losses.CrossEntropyLoss()
    self.loss_func_MSE = losses.MSELoss()

    # Counter; only initialised to zero here.
    self.model_num = 0
def __init__(self, shape1, shape2):
    """Create loss objects and a trainable sparse weight matrix on CUDA.

    The weight has ``shape2[0] * shape2[1]`` rows and
    ``shape1[0] * shape1[1]`` columns, is drawn from ``torch.randn`` and
    then converted to a sparse tensor that requires gradients.
    """
    # An earlier variant built Net(shape1, shape2) on CUDA instead of a raw weight.
    losses = torch.nn
    self.loss_func = losses.L1Loss(reduction='sum')
    self.mssim_loss = MSSSIM(window_size=11, size_average=True)
    self.loss_func_MSE = losses.MSELoss()
    self.model_num = 0

    n_rows = shape2[0] * shape2[1]
    n_cols = shape1[0] * shape1[1]
    dense = torch.randn(n_rows, n_cols, device='cuda')
    self.weight = dense.to_sparse().requires_grad_(True)

    self.learning_rate = 1e-3
    self.count = 0
def main():
    """Train ``AS_Net`` with periodic validation, plotting and checkpointing.

    Everything is configured through the module-level ``args``. Per epoch:

    * every training batch is optimised on
      ``5 * SmoothL1(output, target) + SmoothL1(textural map, target)``;
    * SSIM/PSNR of the first sample of each batch are averaged over the
      epoch (the meters are recreated every epoch so the logged averages
      and the best-checkpoint decision describe the current epoch only,
      instead of a running mean over all previous epochs);
    * every fifth epoch the test loader is evaluated under ``no_grad``;
    * the learning rate is divided by 5 every 50 epochs;
    * a ``latest_`` checkpoint is written every epoch and a ``best_``
      checkpoint whenever a validation run improves the mean test SSIM.
    """

    def to_numpy(tensor):
        """Copy a tensor into a host NumPy array without autograd history."""
        return np.array(tensor.detach().cpu())

    # NOTE(review): the CUDA device index is hard-coded to 1 — confirm this
    # is intended for machines with a single GPU.
    with torch.cuda.device(1):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model = AS_Net(in_channels=args.in_channels).to(device)
        batch_time = AverageMeter()
        vis = Visualizer(env=args.vis_env)

        train_dataset = multichanneldata.ReconDataset0526(
            args.dataset_pathr, train=True)
        test_dataset = multichanneldata.ReconDataset0526(
            args.dataset_pathr, train=False)
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=args.test_batch, shuffle=False)

        smooth_L1 = nn.SmoothL1Loss()
        msssim = MSSSIM(channel=1)
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.learning_rate)

        if args.loadcp:
            checkpoint = torch.load(
                args.save_path + 'latest_' + args.file_name)
            start_epoch = checkpoint['epoch']
            # NOTE(review): the epoch loop below starts at args.start_epoch,
            # not at this checkpoint epoch — confirm that is intended.
            print('%s%d' % ('training from epoch:', start_epoch))
            model = checkpoint['model']
            optimizer = checkpoint['optimizer']
            args.learning_rate = checkpoint['curr_lr']

        def save_checkpoint(tag, epoch):
            """Write model/optimizer/lr for ``epoch`` to ``<tag><file_name>``."""
            torch.save(
                {
                    'epoch': epoch,
                    'model': model,
                    'optimizer': optimizer,
                    'curr_lr': args.learning_rate,
                },
                args.save_path + tag + args.file_name)

        cudnn.benchmark = True
        total_step = len(train_loader)
        best_metric = {'test_epoch': 0, 'test_ssim': 0, 'test_psnr': 0}
        log.info('train image num: {}'.format(len(train_dataset)))
        log.info('val image num: {}'.format(len(test_dataset)))

        end = time.time()
        for epoch in range(args.start_epoch, args.num_epochs):
            # Fresh meters: averages below belong to this epoch only.
            train_ssim_meter = AverageMeter()
            train_psnr_meter = AverageMeter()

            for batch_idx, (rawdata, reimage, bfimg) in enumerate(
                    tqdm(train_loader)):
                rawdata = rawdata.to(device)
                reimage = reimage.to(device)
                bfimg = bfimg.to(device)

                fake_img, bf_feature, side = model(rawdata, bfimg)
                loss_pe = smooth_L1(fake_img, reimage)
                bf_loss = smooth_L1(bf_feature, reimage)
                loss = 5 * loss_pe + bf_loss

                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Metrics use only the first sample / first channel.
                target_np = to_numpy(reimage[0, 0, :, :])
                output_np = to_numpy(fake_img[0, 0, :, :])
                train_ssim_meter.update(compare_ssim(target_np, output_np))
                train_psnr_meter.update(
                    compare_psnr(target_np, output_np, data_range=1))

                # Visualization
                if (batch_idx + 1) % 5 == 0:
                    vis.img(name='ground truth',
                            img_=255 * reimage.detach()[0])
                    vis.img(name='DAS image', img_=255 * bfimg.detach()[0])
                    vis.img(name='textural map',
                            img_=255 * bf_feature.detach()[0])
                    vis.img(name='side_output', img_=255 * side.detach()[0])
                    vis.img(name='output', img_=255 * fake_img.detach()[0])

                batch_time.update(time.time() - end)
                end = time.time()

                log.info(
                    'Epoch [{}], Start [{}], Step [{}/{}], Loss: {:.4f}, Time [{batch_time.val:.3f}({batch_time.avg:.3f})]'
                    .format(epoch + 1, args.start_epoch, batch_idx + 1,
                            total_step, loss.item(), batch_time=batch_time))
                vis.plot_multi_win(
                    dict(
                        bfloss=bf_loss.item(),
                        loss_mse=loss_pe.item(),
                        total_loss=loss.item(),
                    ))

            vis.plot_multi_win(dict(train_ssim=train_ssim_meter.avg,
                                    train_psnr=train_psnr_meter.avg))
            log.info('tain_ssim: {}, train_psnr: {}'.format(
                train_ssim_meter.avg, train_psnr_meter.avg))

            # Validation
            validated = epoch % 5 == 0
            if validated:
                test_ssim_meter = AverageMeter()
                test_psnr_meter = AverageMeter()
                with torch.no_grad():
                    for batch_idx, (rawdata, reimage, bfimg) in enumerate(
                            tqdm(test_loader)):
                        rawdata = rawdata.to(device)
                        reimage = reimage.to(device)
                        bfimg = bfimg.to(device)

                        outputs, bf_feature, side_test = model(rawdata, bfimg)
                        # Only the last batch's MS-SSIM is plotted below.
                        test_ms_ssim = msssim(outputs, reimage)

                        target_np = to_numpy(reimage.squeeze())
                        output_np = to_numpy(outputs.squeeze())
                        test_ssim_meter.update(
                            compare_ssim(target_np, output_np))
                        test_psnr_meter.update(
                            compare_psnr(target_np, output_np, data_range=1))

                        if (batch_idx + 1) % 2 == 0:
                            vis.img(name='Test: ground truth',
                                    img_=255 * reimage.detach()[0])
                            vis.img(name='Test: DASimage',
                                    img_=255 * bfimg.detach()[0])
                            vis.img(name='Test: textural map',
                                    img_=255 * bf_feature.detach()[0])
                            vis.img(name='Test: output',
                                    img_=255 * outputs.detach()[0])
                            vis.img(name='Test: side_output',
                                    img_=255 * side_test.detach()[0])

                    vis.plot_multi_win(dict(
                        test_ssim=test_ssim_meter.avg,
                        test_psnr=test_psnr_meter.avg,
                        test_msssim=test_ms_ssim.item()
                    ))
                    log.info('test_ssim: {}, test_psnr: {}'.format(
                        test_ssim_meter.avg, test_psnr_meter.avg))

            # Decay learning rate
            if (epoch + 1) % 50 == 0:
                args.learning_rate /= 5
                update_lr(optimizer, args.learning_rate)

            save_checkpoint('latest_', epoch)

            # The best checkpoint can only change when validation actually ran.
            if validated and best_metric['test_ssim'] < test_ssim_meter.avg:
                save_checkpoint('best_', epoch)
                best_metric['test_epoch'] = epoch
                best_metric['test_ssim'] = test_ssim_meter.avg
                best_metric['test_psnr'] = test_psnr_meter.avg
            log.info('best_epoch: {}, best_ssim: {}, best_psnr: {}'.format(
                best_metric['test_epoch'], best_metric['test_ssim'],
                best_metric['test_psnr']))
class Sino_repair_net():
    """U-Net wrapper bundling loss functions, an optimisation step and checkpoint I/O."""

    def __init__(self, opts, device_ids, load_model=False):
        """Build the network and its losses, and optionally restore weights.

        Args:
            opts: Options object. This class reads ``max_gpus``,
                ``load_model`` and ``output_path`` from it and forwards the
                whole object to ``Unnet.UNet``.
            device_ids: Sequence of CUDA device ids. Only consulted when
                more than one GPU is visible and ``opts.max_gpus > 1``.
            load_model: When true, restore weights from the hard-coded
                ``model_dict_0`` checkpoint instead of the one under
                ``opts.output_path``.
        """
        self.opts = opts
        self.epoch_step = 0
        self.model_num = 0
        # Other backbones tried here previously: Resnet.resnet152(),
        # framelet.Framelets(), googlenet.GoogLeNet(), densenet.densenet161().
        self.network = Unnet.UNet(opts)

        if torch.cuda.device_count() > 1 and opts.max_gpus > 1:
            if len(device_ids) <= opts.max_gpus:
                # No explicit device list: DataParallel picks its own default.
                self.network = torch.nn.DataParallel(self.network)
            else:
                self.network = torch.nn.DataParallel(
                    self.network,
                    device_ids=self._limit_device_ids(device_ids,
                                                      opts.max_gpus))
        self.network.cuda()

        # Loss functions
        self.loss_func_l1 = torch.nn.L1Loss()
        self.loss_func_MSE = torch.nn.MSELoss()
        self.MedGanloss = MedGanloss()
        self.mssim_loss = MSSSIM(window_size=9, size_average=True)
        self.loss_func_poss = torch.nn.PoissonNLLLoss()
        self.loss_func_KLDiv = torch.nn.KLDivLoss()
        self.loss_func_Smoothl1 = torch.nn.SmoothL1Loss()
        self.loss_func_part = torch.nn.L1Loss()
        self.test_loss = torch.nn.MSELoss(reduction='none')
        self.averagepool = torch.nn.AvgPool2d(3, stride=2)
        self.optim_count = 0

        # TODO2: Change the load model dict
        # `== True` is kept on purpose so non-boolean option values behave
        # as before.
        if self.opts.load_model == True or load_model:
            print("Restoring model")
            self._restore_weights(load_model)

    @staticmethod
    def _limit_device_ids(device_ids, max_gpus):
        """Return the first ``max_gpus`` entries of ``device_ids``.

        The previous inline slice ended at ``max_gpus - 1`` and therefore
        used one GPU fewer than allowed.
        """
        return device_ids[:max_gpus]

    @staticmethod
    def _match_module_prefix(state_dict, wrapped):
        """Return a copy of ``state_dict`` whose keys fit the wrapper state.

        Every key carries exactly one leading ``module.`` when ``wrapped``
        is true (network is a ``DataParallel``) and none otherwise.
        """
        from collections import OrderedDict
        prefix = 'module.'
        adapted = OrderedDict()
        for key, value in state_dict.items():
            bare = key[len(prefix):] if key.startswith(prefix) else key
            adapted[prefix + bare if wrapped else bare] = value
        return adapted

    @staticmethod
    def _valid_indices(valid):
        """Return a CUDA tensor with the positions where ``valid`` equals 1."""
        return torch.tensor([i for i, n in enumerate(valid) if n == 1]).cuda()

    def _restore_weights(self, load_model):
        """Load a checkpoint into ``self.network``.

        The first attempt loads the hard-coded checkpoint (``load_model``
        true) or the one under ``opts.output_path``. If that fails with a
        key mismatch (``RuntimeError``) or an unreadable file (``OSError``),
        the checkpoint under ``opts.output_path`` is loaded with its
        ``module.`` prefixes adjusted to the current network.
        """
        default_path = os.path.join(self.opts.output_path, 'model',
                                    'model_dict')
        if load_model:
            checkpoint_path = '/home/liang/Desktop/output/model/model_dict_0'
        else:
            checkpoint_path = default_path
        try:
            self.network.load_state_dict(torch.load(checkpoint_path))
        except (RuntimeError, OSError):
            state_dict = torch.load(default_path)
            wrapped = isinstance(self.network, torch.nn.DataParallel)
            self.network.load_state_dict(
                self._match_module_prefix(state_dict, wrapped))

    def set_optimizer(self, optimizer):
        """Store the optimizer that ``optimize`` will step."""
        self.optimizer = optimizer

    def train_batch(self, input_img, target_img, valid=None):
        """Run one training step and return ``(output, l1, mse_or_msssim)``.

        With ``valid`` left as ``None`` the whole batch goes through the
        network and ``optimize``. Otherwise only the samples whose
        ``valid`` entry equals 1 are trained on; the remaining samples are
        returned as unmodified copies of the input. If no sample is valid,
        no optimisation step happens and the returned losses are the L1
        loss and ``(1 - MS-SSIM) / 2`` of the untouched input against the
        target.
        """
        if valid is None:
            output = self.network.forward(input_img)
            loss, loss2 = self.optimize(output, target_img)
            return output, loss, loss2

        final = input_img.clone()
        mask = self._valid_indices(valid)
        if len(mask) > 0:
            traininput = torch.index_select(input_img, 0, mask)
            trainoutput = self.network.forward(traininput)
            loss, loss2 = self.optimize(
                trainoutput, torch.index_select(target_img, 0, mask))
            final[mask] = trainoutput
        else:
            loss = self.loss_func_l1(final, target_img)
            loss2 = (1 - self.mssim_loss.forward(final, target_img)) / 2
        return final, loss, loss2

    def test(self, x, y, valid=None):
        """Run the network without optimising and return ``(output, loss)``.

        ``loss`` is the detached, unreduced squared error against ``y``.
        When ``valid`` is given, only samples flagged with 1 are passed
        through the network; the others stay equal to the input.
        """
        if valid is None:
            output = self.network.forward(x)
            loss = self.test_loss(output, y).detach()
            return output, loss

        final = x.clone()
        mask = self._valid_indices(valid)
        if len(mask) > 0:
            traininput = torch.index_select(x, 0, mask)
            trainoutput = self.network.forward(traininput)
            final[mask] = trainoutput
        loss = self.test_loss(final, y).detach()
        return final, loss

    def optimize(self, output, target_img):
        """Back-propagate the combined loss and step the optimizer.

        The optimised loss is ``L1 + MSE + 0.0001 * (1 - MS-SSIM)``.

        Returns:
            Tuple of the detached L1 loss and the detached MSE loss.
        """
        # TODO: can add other loss terms if needed
        # TODO: need to step though this code to make sure it works correctly
        input1 = output
        input2 = target_img

        loss1 = self.loss_func_l1(input1, input2)
        l1 = loss1.detach()

        # Consistency (MSE) term
        loss2 = self.loss_func_MSE(input1, input2)
        l2 = loss2.detach()

        # Lightly weighted structural term
        loss3 = 0.0001 * (1 - self.mssim_loss.forward(input1, input2))

        loss = loss1 + loss2 + loss3

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.optim_count += 1
        return l1, l2

    def save_network(self):
        """Save the network weights to ``<output_path>/model/model_dict_<n>``.

        ``n`` cycles through 0..4, so at most five checkpoints are kept.
        """
        print("saving network parameters")
        folder_path = os.path.join(self.opts.output_path, 'model')
        # exist_ok avoids the check-then-create race of a separate exists().
        os.makedirs(folder_path, exist_ok=True)
        torch.save(
            self.network.state_dict(),
            os.path.join(folder_path, "model_dict_{}".format(self.model_num)))
        self.model_num += 1
        if self.model_num >= 5:
            self.model_num = 0
def train(netG1, netG2, netD1, netD2, vgg, optG1, optG2, optD1, optD2,
          dataloader, bs, l, savename):
    """Jointly train the segmentation GAN (G1/D1) and enhancement GAN (G2/D2).

    Runs 40 epochs over ``dataloader``. Every 100 iterations the four
    networks are saved under ``./models`` and a line of losses is appended
    to the results text file; the networks are saved again at the end of
    every epoch.

    Args:
        netG1, netD1: Segmentation generator and its discriminator.
        netG2, netD2: Enhancement generator and its discriminator.
        vgg: Feature extractor used for the content loss.
        optG1, optG2, optD1, optD2: Optimizers of the four networks.
        dataloader: Yields ``(x, y, r)`` batches.
        bs: Batch size; used to size the label tensors and to scale losses.
        l: Indexable of loss weights; only ``l[1]`` is read here (it
            scales the discriminator-1 loss).
        savename: Suffix inserted into checkpoint and results file names.
    """
    results_path = '/home/jupyter/STGAN/results_stgan' + savename + '.txt'

    def save_models():
        """Write the state dicts of all four networks under ``./models``."""
        for tag, net in (('netG1', netG1), ('netG2', netG2),
                         ('netD1', netD1), ('netD2', netD2)):
            torch.save(net.state_dict(), './models/' + tag + savename + '.pth')

    real_label = torch.ones(bs).long()
    fake_label = torch.zeros(bs).long()

    # Segmentation losses
    crit_discriminator = nn.CrossEntropyLoss(reduction='sum')
    crit_generator = nn.CrossEntropyLoss(reduction='mean')

    # Loss weights
    tv_w = 2000
    content_w = 10
    color_w = 0.5
    tv_msssim = 500  # only referenced by the disabled MS-SSIM term
    texture_w = 1
    # At least one patch, otherwise the per-patch averages below would
    # divide by zero for bs > 50.
    NUM_PATCHES = max(1, 50 // bs)

    # DPED losses
    color_criterion = nn.MSELoss(reduction='sum')
    texture_criterion = nn.BCELoss(reduction='mean')
    content_criterion = nn.MSELoss(reduction='sum')
    msssim_criterion = MSSSIM()  # kept for the disabled MS-SSIM term
    tv_criterion = TVLoss()
    blur = GaussianBlur()
    gray = GrayLayer()

    if torch.cuda.is_available():
        real_label = real_label.cuda()
        fake_label = fake_label.cuda()
        tv_criterion = tv_criterion.cuda()
        blur = blur.cuda()
        gray = gray.cuda()

    for epoch in range(40):
        begin = time.time()
        for i, (x, y, r) in enumerate(dataloader):
            if i % 10 == 0:
                print(epoch, i, end='\r')
            if torch.cuda.is_available():
                x = x.cuda()
                y = y.cuda()
                r = r.cuda()

            # --- Image segmentation ---
            # Generator 1: predict the class mask and compare with the
            # ground-truth mask.
            optG1.zero_grad()
            optG2.zero_grad()
            optD1.zero_grad()
            optD2.zero_grad()
            y_G1 = netG1(x)
            loss_G1_data = crit_generator(y_G1, y)
            _, y_G1 = torch.max(y_G1, 1)
            y_G1 = (y_G1.float() / 2.0 - 0.5) * 2
            x_y_G1 = torch.cat((x, y_G1.float().unsqueeze(1)), 1)
            x_y = torch.cat((x, y.float().unsqueeze(1)), 1)
            loss_G1_adv = crit_discriminator(netD1(x_y_G1), real_label.long())
            loss_G1 = loss_G1_data + loss_G1_adv

            # Discriminator 1
            y_D1_fake = netD1(x_y_G1.detach())
            y_D1_real = netD1(x_y)
            loss_D1_fake = crit_discriminator(y_D1_fake, fake_label.long())
            loss_D1_real = crit_discriminator(y_D1_real, real_label.long())
            loss_D1 = l[1] * (loss_D1_real + loss_D1_fake)

            # --- Image cropping ---
            x_y_G1_crop, x_y_crop, r_crop = x_y_G1.detach(), x_y, r
            texture_loss = 0
            content_loss = 0
            color_loss = 0
            tv_loss = 0
            loss_D2_fake = 0
            loss_D2_real = 0
            msssim_loss = 0  # MS-SSIM term is currently disabled
            for _ in range(NUM_PATCHES):
                x_y_G1, x_y, r = CropImage(x_y_G1_crop, x_y_crop, r_crop)

                # --- Image enhancement (DPED) ---
                # Generator 2
                y_G2 = netG2(x_y_G1)

                # Texture loss
                y_G2_gray = gray(y_G2)  # result is currently unused
                x_y_G1_G2 = torch.cat((x_y_G1, y_G2), 1)
                y_G2_pred = netD2(x_y_G1_G2)
                texture_loss += -texture_criterion(
                    y_G2_pred, fake_label.float()) * bs

                # Content loss
                vgg_y_G2 = vgg(y_G2)
                vgg_r = vgg(r).detach()
                _, c1, h1, w1 = y_G2.size()
                chw1 = c1 * h1 * w1
                content_loss += 1.0 / (2 * bs * chw1) * content_criterion(
                    vgg_y_G2, vgg_r)

                # Color loss
                y_G2_blur = blur(y_G2)
                r_blur = blur(r).detach()
                color_loss += color_criterion(y_G2_blur, r_blur) / (2 * bs)

                # Total variation loss
                tv_loss += tv_criterion(y_G2)

                # Discriminator 2
                x_y_r = torch.cat((x_y, r), 1)
                r_pred = netD2(x_y_r.detach())
                y_G2_pred = netD2(x_y_G1_G2.detach())
                loss_D2_fake += texture_criterion(
                    y_G2_pred, fake_label.float()) * bs
                loss_D2_real += texture_criterion(
                    r_pred, real_label.float()) * bs

            texture_loss /= NUM_PATCHES
            content_loss /= NUM_PATCHES
            color_loss /= NUM_PATCHES
            tv_loss /= NUM_PATCHES
            loss_D2_fake /= NUM_PATCHES
            loss_D2_real /= NUM_PATCHES

            # Total losses
            loss_G2 = (texture_w * texture_loss + content_w * content_loss +
                       color_w * color_loss + tv_w * tv_loss)
            loss_D2 = texture_w * (loss_D2_real + loss_D2_fake)

            # NOTE(review): the discriminators are stepped before the
            # generator losses (which pass through them) are
            # back-propagated — confirm this ordering works on the
            # PyTorch version in use.
            loss_D1.backward()
            loss_D2.backward()
            optD1.step()
            optD2.step()
            loss_G1.backward()
            loss_G2.backward()
            optG1.step()
            optG2.step()

            if i % 100 == 0:
                losses = {
                    'content': content_loss.item(),
                    'color': color_loss.item(),
                    'tv': tv_loss.item(),
                    'gen_texture_loss': texture_loss.item(),
                    'disc_fake_loss': loss_D2_fake.item() / bs,
                    'disc_real_loss': loss_D2_real.item() / bs,
                }
                save_models()
                # One handle, closed deterministically; previously three
                # files were opened per log event and never closed.
                with open(results_path, 'a+') as results_file:
                    print('epoch {:}'.format(epoch),
                          'iter {:}'.format(i),
                          'iter time {:.2f}'.format(time.time() - begin),
                          'loss_D1: {:.4f}'.format(loss_D1.item()),
                          'loss_D2: {:.4f}'.format(loss_D2.item()),
                          'loss_G1: {:.4f}'.format(loss_G1.item()),
                          'loss_G2: {:.4f}'.format(loss_G2.item()),
                          file=results_file)
                    for k, v in losses.items():
                        print(k, '{:.3f}'.format(v), end=', ',
                              file=results_file)
                    print('', file=results_file)

        save_models()
def mssdim(output, target):
    """Return ``(1 - MSSSIM(output, target)) / 2`` as a dissimilarity value."""
    similarity = MSSSIM(output, target)
    dissimilarity = 1. - similarity
    return dissimilarity / 2.
# --- Networks -------------------------------------------------------------
# Both networks are built first, then initialised, then moved to the GPU.
D = model.Discriminator(6)
G = model.VGG_VAE(5)

D.apply(weights_init)
G.apply(weights_init)

D.cuda()
G.cuda()

print(D)
print(G)

# --- Discriminator objective and optimizer --------------------------------
D_criterion = torch.nn.BCEWithLogitsLoss().cuda()
D_optimizer = torch.optim.SGD(D.parameters(), lr=1e-3)

# --- Generator objectives and optimizer -----------------------------------
G_criterion = torch.nn.BCEWithLogitsLoss().cuda()
G_l1 = torch.nn.L1Loss().cuda()
G_msssim = MSSSIM().cuda()
G_ssim = SSIM().cuda()
G_optimizer = torch.optim.Adam(G.parameters(), lr=1e-3)

# --- Output folders -------------------------------------------------------
pathlib.Path(sample_output).mkdir(parents=True, exist_ok=True)
pathlib.Path(os.path.join(sample_output, "images")).mkdir(
    parents=True, exist_ok=True)

# --- Training state -------------------------------------------------------
d_loss = 0
g_loss = 0

# Loss thresholds; their names suggest they decide when training switches
# between D and G (the logic using them is not in this excerpt).
d_to_g_threshold, g_to_d_threshold = 0.5, 0.3

train_d = True
train_g = True
conditional_training = False
def l1loss(x, y): return torch.nn.functional.l1_loss(x, y, reduction='mean') def l2loss(x, y): return torch.pow(x - y, 2).mean() def psnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return (10 * torch.log10(x.shape[-2] * x.shape[-1] / (x - y).pow(2).sum(dim=(2, 3)))).mean(dim=1) ssim = SSIM(data_range=1.0) msssim = MSSSIM(data_range=1.0) def gaussian(x, sigma=1.0): return np.exp(-(x**2) / (2 * (sigma**2))) def build_gauss_kernel(size=5, sigma=1.0, n_channels=1, device=None): """Construct the convolution kernel for a gaussian blur See https://en.wikipedia.org/wiki/Gaussian_blur for a definition. Overall I first generate a NxNx2 matrix of indices, and then use those to calculate the gaussian function on each element. The two dimensional Gaussian function is then the product along axis=2. Also, in_channels == out_channels == n_channels """ if size % 2 != 1:
def __init__(self, opts, device_ids, load_model=False):
    """Build the GoogLeNet, its loss functions, and optionally restore weights.

    Args:
        opts: Options object. This method reads ``max_gpus``,
            ``load_model`` and ``output_path`` from it.
        device_ids: Sequence of CUDA device ids. Only consulted when more
            than one GPU is visible and ``opts.max_gpus > 1``.
        load_model: When true, restore weights from the hard-coded
            ``model_dict_0`` checkpoint instead of the one under
            ``opts.output_path``.

    Raises:
        RuntimeError: If the fallback checkpoint still does not match the
            network after its ``module.`` prefixes have been adjusted.
        OSError: If the fallback checkpoint file cannot be read.
    """
    self.opts = opts
    self.epoch_step = 0
    self.model_num = 0
    # Other backbones tried here previously: UNet(opts), test2.Net(opts),
    # VGG.vgg11_bn(), VGG.vgg19_bn(), framelet.Framelets(),
    # Nets.AlexNet(channelnumber=1, num_classes=64*520).
    self.network = GoogleNet.GoogLeNet()

    # Logic to make training on a GPU cluster easier
    if torch.cuda.device_count() > 1 and opts.max_gpus > 1:
        if len(device_ids) <= opts.max_gpus:
            # No explicit device list: DataParallel picks its own default.
            self.network = torch.nn.DataParallel(self.network)
        else:
            # Keep exactly max_gpus devices (the previous slice ended at
            # max_gpus - 1 and therefore dropped one GPU too many).
            self.network = torch.nn.DataParallel(
                self.network, device_ids=device_ids[:opts.max_gpus])
    self.network.cuda()

    # Loss functions
    self.loss_func_l1 = torch.nn.L1Loss()
    self.loss_func_MSE = torch.nn.MSELoss()
    self.mssim_loss = MSSSIM(window_size=11, size_average=True)
    self.loss_func_poss = torch.nn.PoissonNLLLoss()
    self.loss_func_KLDiv = torch.nn.KLDivLoss()
    self.loss_func_Smoothl1 = torch.nn.SmoothL1Loss()
    self.loss_func_part = torch.nn.L1Loss()
    self.test_loss = torch.nn.MSELoss(reduction='none')
    self.OPT_count = 0

    # TODO2: Change the load model dict
    # `== True` is kept on purpose so non-boolean option values behave as before.
    if self.opts.load_model == True or load_model:
        print("Restoring model")
        default_path = os.path.join(self.opts.output_path, 'model',
                                    'model_dict')
        if load_model:
            checkpoint_path = '/home/liang/Desktop/output/model/model_dict_0'
        else:
            checkpoint_path = default_path
        try:
            self.network.load_state_dict(torch.load(checkpoint_path))
        except (RuntimeError, OSError):
            # RuntimeError: key mismatch, typically because the checkpoint
            # was written with/without the DataParallel wrapper.
            # OSError: the requested checkpoint file is not readable.
            # Either way, retry with the checkpoint under opts.output_path
            # and make its `module.` prefixes match the current network.
            from collections import OrderedDict
            state_dict = torch.load(default_path)
            wrapped = isinstance(self.network, torch.nn.DataParallel)
            prefix = 'module.'
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                bare = k[len(prefix):] if k.startswith(prefix) else k
                new_state_dict[prefix + bare if wrapped else bare] = v
            self.network.load_state_dict(new_state_dict)