def load_model(path, input_nc, output_nc):
    """Load a trained DenseFuse_net checkpoint and prepare it for inference.

    Args:
        path: filesystem path to a state-dict saved with ``torch.save``.
        input_nc: number of input channels for the network.
        output_nc: number of output channels for the network.

    Returns:
        The model in eval mode, moved to the default CUDA device.
    """
    nest_model = DenseFuse_net(input_nc, output_nc)
    nest_model.load_state_dict(torch.load(path))

    # Total parameter count; numel() replaces the np.prod(list(p.size())) round-trip.
    para = sum(p.numel() for p in nest_model.parameters())
    type_size = 4  # bytes per float32 parameter
    # BUG FIX: original format spec was '{:4f}' (min width 4, default 6 decimals),
    # clearly intended as '{:.4f}' (4 decimal places).
    print('Model {} : params: {:.4f}M'.format(nest_model._get_name(),
                                              para * type_size / 1000 / 1000))

    nest_model.eval()
    nest_model.cuda()
    return nest_model
def main_worker(gpu, ngpus_per_node, args):
    """Per-process training worker (one per GPU in multiprocessing-distributed mode).

    Builds a DenseFuse_net autoencoder, wraps it for the selected parallel mode
    (DDP, single-GPU, or DataParallel), optionally resumes from a checkpoint,
    and trains the network to reconstruct its own input.

    Args:
        gpu: local GPU index for this worker, or None to use all GPUs.
        ngpus_per_node: GPU count on this node; used to derive the global rank
            and to split batch size / data-loader workers across processes.
        args: parsed CLI namespace (CHANNELS, lr, batch_size, resume, dataset,
            epochs, save_per_epoch, save_model_dir, ssim_weight, ...).
    """
    args.gpu = gpu
    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.multiprocessing_distributed:
            # Global rank = node rank * GPUs-per-node + local GPU index.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                world_size=args.world_size, rank=args.rank)

    model = DenseFuse_net(input_nc=args.CHANNELS, output_nc=args.CHANNELS)
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    epoch = 0

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # Each DDP process handles a slice of the global batch / workers.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model,
                                                              device_ids=[args.gpu])
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        model = torch.nn.DataParallel(model).cuda()

    if args.resume:
        if args.gpu is None:
            checkpoint = torch.load(args.resume)
        else:
            # Map the checkpoint tensors straight onto this worker's device.
            loc = 'cuda:{}'.format(args.gpu)
            checkpoint = torch.load(args.resume, map_location=loc)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']

    img_path_file = args.dataset
    # BUG FIX: the original `assert args.CHANNELS == 1 or 3` was always true
    # (it parses as `(args.CHANNELS == 1) or 3`); test set membership instead.
    assert args.CHANNELS in (1, 3), "Input channels should be either 1 or 3"
    if args.CHANNELS == 1:
        custom_transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((args.HEIGHT, args.WIDTH)),
            transforms.ToTensor()])
    else:
        custom_transform = transforms.Compose([
            transforms.Resize((args.HEIGHT, args.WIDTH)),
            transforms.ToTensor()])

    trainloader = DataLoader(
        MyTrainDataset(img_path_file, custom_transform=custom_transform),
        batch_size=args.batch_size, shuffle=False, num_workers=4)

    # BUG FIX: DataParallel / DistributedDataParallel wrappers do not forward
    # the wrapped model's custom `encoder`/`decoder` attributes, so the original
    # `model.encoder(...)` raised AttributeError whenever CUDA was available.
    # Unwrap once before the loop. NOTE(review): calling encoder/decoder on the
    # unwrapped module bypasses multi-GPU scatter — confirm this matches intent.
    net = model.module if hasattr(model, 'module') else model

    for ep in range(epoch, args.epochs):
        pbar = tqdm(trainloader)
        for inputs in pbar:
            if args.gpu is not None:
                inputs = inputs.cuda(args.gpu, non_blocking=True)
            optimizer.zero_grad()
            en = net.encoder(inputs)
            predicts = net.decoder(en)
            # Autoencoder objective: reconstruct the input (pixel + SSIM terms).
            loss = compute_loss(predicts, inputs, args.ssim_weight, w_idx=2)
            loss.backward()
            optimizer.step()
        if (ep + 1) % args.save_per_epoch == 0:
            # Save a resumable checkpoint (weights + optimizer + last loss).
            torch.save({'epoch': ep,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss},
                       args.save_model_dir + 'ckpt_{}.pt'.format(ep))
    print('Finished training')
def _save_loss_mat(i, stem, key, values, count=None):
    """Dump one loss-history list to a .mat file under args.save_loss_dir.

    `count` (running batch counter) inserts the "_iters_<count>" tag used by
    the periodic checkpoints; the final dump omits it.
    """
    ts = time.ctime().replace(' ', '_').replace(':', '_')
    iters_tag = '' if count is None else '_iters_' + str(count)
    name = (args.ssim_path[i] + '/' + stem + '_epoch_' + str(args.epochs)
            + iters_tag + '_' + ts + '_' + args.ssim_path[i] + '.mat')
    scio.savemat(os.path.join(args.save_loss_dir, name), {key: np.array(values)})


def train(i, original_imgs_path):
    """Train a DenseFuse autoencoder, periodically checkpointing model + losses.

    Args:
        i: index into args.ssim_weight / args.ssim_path selecting the SSIM
           weighting and its per-weight output sub-directory.
        original_imgs_path: training-image source consumed by
           utils.load_dataset.
    """
    batch_size = args.batch_size

    # Channel setup: 1 -> grayscale ('L'), 3 -> RGB.
    in_c = 3
    img_model = 'L' if in_c == 1 else 'RGB'
    densefuse_model = DenseFuse_net(in_c, in_c)
    if args.resume is not None:
        print('Resuming, initializing using weight from {}.'.format(args.resume))
        densefuse_model.load_state_dict(torch.load(args.resume))
    print(densefuse_model)

    optimizer = Adam(densefuse_model.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()
    ssim_loss = pytorch_msssim.msssim

    if args.cuda:
        densefuse_model.cuda()

    tbar = trange(args.epochs)
    print('Start training.....')

    # Per-weight output directories (idempotent; replaces exists()+mkdir race).
    os.makedirs(os.path.join(args.save_model_dir, args.ssim_path[i]), exist_ok=True)
    os.makedirs(os.path.join(args.save_loss_dir, args.ssim_path[i]), exist_ok=True)

    Loss_pixel = []
    Loss_ssim = []
    Loss_all = []
    all_ssim_loss = 0.
    all_pixel_loss = 0.
    for e in tbar:
        print('Epoch %d.....' % e)
        # Reload (and implicitly reshuffle/trim) the training set every epoch.
        image_set_ir, batches = utils.load_dataset(original_imgs_path, batch_size)
        densefuse_model.train()
        count = 0
        for batch in range(batches):
            image_paths = image_set_ir[batch * batch_size:batch * batch_size + batch_size]
            img = utils.get_train_images_auto(image_paths, height=args.HEIGHT,
                                              width=args.WIDTH, mode=img_model)
            count += 1
            optimizer.zero_grad()
            img = Variable(img, requires_grad=False)
            if args.cuda:
                img = img.cuda()

            # Autoencode: encoder features -> decoder reconstruction(s).
            en = densefuse_model.encoder(img)
            outputs = densefuse_model.decoder(en)

            # Reconstruction target: detached copy of the input.
            x = Variable(img.data.clone(), requires_grad=False)
            ssim_loss_value = 0.
            pixel_loss_value = 0.
            for output in outputs:
                pixel_loss_value += mse_loss(output, x)
                ssim_loss_value += 1 - ssim_loss(output, x, normalize=True)
            ssim_loss_value /= len(outputs)
            pixel_loss_value /= len(outputs)

            total_loss = pixel_loss_value + args.ssim_weight[i] * ssim_loss_value
            total_loss.backward()
            optimizer.step()

            all_ssim_loss += ssim_loss_value.item()
            all_pixel_loss += pixel_loss_value.item()
            if (batch + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\t pixel loss: {:.6f}\t ssim loss: {:.6f}\t total: {:.6f}".format(
                    time.ctime(), e + 1, count, batches,
                    all_pixel_loss / args.log_interval,
                    all_ssim_loss / args.log_interval,
                    (args.ssim_weight[i] * all_ssim_loss + all_pixel_loss) / args.log_interval)
                tbar.set_description(mesg)
                Loss_pixel.append(all_pixel_loss / args.log_interval)
                Loss_ssim.append(all_ssim_loss / args.log_interval)
                Loss_all.append((args.ssim_weight[i] * all_ssim_loss + all_pixel_loss)
                                / args.log_interval)
                all_ssim_loss = 0.
                all_pixel_loss = 0.

            if (batch + 1) % (200 * args.log_interval) == 0:
                # Periodic checkpoint: model weights + running loss curves.
                densefuse_model.eval()
                densefuse_model.cpu()
                ts = time.ctime().replace(' ', '_').replace(':', '_')
                save_model_filename = (args.ssim_path[i] + '/' + "Epoch_" + str(e)
                                       + "_iters_" + str(count) + "_" + ts + "_"
                                       + args.ssim_path[i] + ".model")
                save_model_path = os.path.join(args.save_model_dir, save_model_filename)
                torch.save(densefuse_model.state_dict(), save_model_path)

                _save_loss_mat(i, 'loss_pixel', 'loss_pixel', Loss_pixel, count)
                _save_loss_mat(i, 'loss_ssim', 'loss_ssim', Loss_ssim, count)
                _save_loss_mat(i, 'loss_total', 'loss_total', Loss_all, count)

                densefuse_model.train()
                if args.cuda:
                    # BUG FIX: original called .cuda() unconditionally here,
                    # which crashes on CPU-only runs (args.cuda false).
                    densefuse_model.cuda()
                # BUG FIX: tqdm.set_description takes a single string; the
                # original passed the path as the boolean `refresh` argument.
                tbar.set_description(
                    "\nCheckpoint, trained model saved at " + save_model_path)

    # Final loss curves and final model snapshot.
    _save_loss_mat(i, 'Final_loss_pixel', 'loss_pixel', Loss_pixel)
    _save_loss_mat(i, 'Final_loss_ssim', 'loss_ssim', Loss_ssim)
    _save_loss_mat(i, 'Final_loss_total', 'loss_total', Loss_all)

    densefuse_model.eval()
    densefuse_model.cpu()
    ts = time.ctime().replace(' ', '_').replace(':', '_')
    save_model_filename = (args.ssim_path[i] + '/' + "Final_epoch_" + str(args.epochs)
                           + "_" + ts + "_" + args.ssim_path[i] + ".model")
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(densefuse_model.state_dict(), save_model_path)
    print("\nDone, trained model saved at", save_model_path)
img = Image.open(self.img_list[index]).convert('RGB') ir = Image.open(self.ir_list[index]).convert('RGB') if self.transform: img = self.transform(img) ir = self.transform(ir) return img, ir def __len__(self): return len(self.img_list) if __name__ == '__main__': model = DenseFuse_net() checkpoint = torch.load(args.resume) model.load_state_dict(checkpoint['model_state_dict']) strategy_type = args.strategy_type img_path_file = args.test_img ir_path_file = args.test_ir testloader = DataLoader(MyTestDataset(img_path_file, ir_path_file), batch_size=1, shuffle=False, num_workers=1) if is_cuda: model.cuda() for i, (img, ir) in enumerate(tqdm(testloader)): if is_cuda: