def load_model(path, input_nc, output_nc):
	"""Load a trained DenseFuse_net checkpoint and prepare it for inference.

	Args:
		path: Path to a saved state_dict (output of torch.save).
		input_nc: Number of input channels for the network.
		output_nc: Number of output channels for the network.

	Returns:
		The DenseFuse_net model in eval mode, moved to the default CUDA device.
	"""
	nest_model = DenseFuse_net(input_nc, output_nc)
	nest_model.load_state_dict(torch.load(path))

	# Report the parameter footprint in MB, assuming 4 bytes per float32 weight.
	para = sum(p.numel() for p in nest_model.parameters())
	type_size = 4
	# BUG FIX: original used '{:4f}' (field width 4), which still printed six
	# decimals; '{:.4f}' (four decimal places) was clearly intended.
	print('Model {} : params: {:.4f}M'.format(nest_model._get_name(), para * type_size / 1000 / 1000))

	nest_model.eval()
	nest_model.cuda()

	return nest_model
# Esempio n. 2 (Example 2)
# 0
def main_worker(gpu, ngpus_per_node, args):
    """Per-process training worker (one process per GPU in distributed mode).

    Trains a DenseFuse_net autoencoder to reconstruct its inputs, saving a
    resumable checkpoint every args.save_per_epoch epochs.

    Args:
        gpu: GPU index for this process, or None for CPU / DataParallel.
        ngpus_per_node: GPUs on this node; used to shard batch size and
            loader workers across distributed processes.
        args: Parsed command-line namespace (lr, batch_size, CHANNELS,
            HEIGHT, WIDTH, epochs, resume, dataset, save_model_dir, ...).
    """
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.multiprocessing_distributed:
            # Global rank = node rank * GPUs-per-node + local GPU index.
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, world_size=args.world_size, rank=args.rank)

    model = DenseFuse_net(input_nc=args.CHANNELS, output_nc=args.CHANNELS)
    optimizer = torch.optim.Adam(model.parameters(), args.lr)

    epoch = 0

    if not torch.cuda.is_available():
        print('using CPU, this will be slow')
    elif args.distributed:
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # Shard the per-node batch and loader workers across processes.
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        model = torch.nn.DataParallel(model).cuda()

    if args.resume:
        # Map the checkpoint onto this process's device when pinned to a GPU.
        if args.gpu is None:
            checkpoint = torch.load(args.resume)
        else:
            loc = 'cuda:{}'.format(args.gpu)
            checkpoint = torch.load(args.resume, map_location=loc)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']

    img_path_file = args.dataset

    # BUG FIX: the original `assert args.CHANNELS == 1 or 3` was always true
    # (it parsed as `(args.CHANNELS == 1) or 3`), so an invalid channel count
    # slipped through and left `custom_transform` unbound below.
    assert args.CHANNELS in (1, 3), "Input channels should be either 1 or 3"
    if args.CHANNELS == 1:
        custom_transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),
                                            transforms.Resize((args.HEIGHT, args.WIDTH)),
                                            transforms.ToTensor()])
    else:  # args.CHANNELS == 3, guaranteed by the assert above
        custom_transform = transforms.Compose([transforms.Resize((args.HEIGHT, args.WIDTH)),
                                            transforms.ToTensor()])

    # NOTE(review): num_workers is hard-coded to 4 even though args.workers is
    # computed in the distributed branch above — confirm which is intended.
    trainloader = DataLoader(MyTrainDataset(img_path_file, custom_transform=custom_transform), batch_size=args.batch_size, shuffle=False, num_workers=4)

    for ep in range(epoch, args.epochs):

        pbar = tqdm(trainloader)

        for inputs in pbar:
            if args.gpu is not None:
                inputs = inputs.cuda(args.gpu, non_blocking=True)

            optimizer.zero_grad()

            # Autoencoder pass: reconstruct the input from its encoding.
            # NOTE(review): when the model is wrapped in (Distributed)DataParallel
            # above, .encoder/.decoder live on model.module — confirm this path
            # is only hit unwrapped.
            en = model.encoder(inputs)
            predicts = model.decoder(en)

            loss = compute_loss(predicts, inputs, args.ssim_weight, w_idx=2)
            loss.backward()
            optimizer.step()

        if (ep + 1) % args.save_per_epoch == 0:
            # Periodic checkpoint with everything needed to resume training.
            # BUG FIX: path was built by raw string concatenation, which
            # silently produced 'dirckpt_0.pt' when save_model_dir lacked a
            # trailing separator.
            torch.save({
                        'epoch': ep,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss
                    }, os.path.join(args.save_model_dir, 'ckpt_{}.pt'.format(ep)))

    print('Finished training')
# Esempio n. 3 (Example 3)
# 0
def _timestamp():
    """Return time.ctime() made filesystem-safe (spaces/colons -> underscores)."""
    return str(time.ctime()).replace(' ', '_').replace(':', '_')


def _save_loss_mat(save_dir, rel_path, key, values):
    """Save a loss-history list as a .mat file at save_dir/rel_path under `key`."""
    scio.savemat(os.path.join(save_dir, rel_path), {key: np.array(values)})


def train(i, original_imgs_path):
    """Train a DenseFuse autoencoder with pixel (MSE) + weighted SSIM loss.

    Args:
        i: Index into args.ssim_weight / args.ssim_path selecting the SSIM
            weight configuration and the matching output sub-directory.
        original_imgs_path: Training image collection handed to
            utils.load_dataset at the start of every epoch.

    Side effects:
        Writes model checkpoints under args.save_model_dir and loss curves
        (.mat files) under args.save_loss_dir.
    """
    batch_size = args.batch_size

    # Load network model; 1 channel = grayscale ('L'), 3 = RGB.
    in_c = 3
    img_model = 'L' if in_c == 1 else 'RGB'
    input_nc = in_c
    output_nc = in_c
    densefuse_model = DenseFuse_net(input_nc, output_nc)

    if args.resume is not None:
        print('Resuming, initializing using weight from {}.'.format(args.resume))
        densefuse_model.load_state_dict(torch.load(args.resume))
    print(densefuse_model)
    optimizer = Adam(densefuse_model.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()
    ssim_loss = pytorch_msssim.msssim

    if args.cuda:
        densefuse_model.cuda()

    tbar = trange(args.epochs)
    print('Start training.....')

    # Create the per-configuration output directories if missing.
    temp_path_model = os.path.join(args.save_model_dir, args.ssim_path[i])
    if not os.path.exists(temp_path_model):
        os.mkdir(temp_path_model)

    temp_path_loss = os.path.join(args.save_loss_dir, args.ssim_path[i])
    if not os.path.exists(temp_path_loss):
        os.mkdir(temp_path_loss)

    ssim_tag = args.ssim_path[i]
    Loss_pixel = []
    Loss_ssim = []
    Loss_all = []
    all_ssim_loss = 0.
    all_pixel_loss = 0.
    for e in tbar:
        print('Epoch %d.....' % e)
        # Reload the training set every epoch.
        image_set_ir, batches = utils.load_dataset(original_imgs_path,
                                                   batch_size)
        densefuse_model.train()
        count = 0
        for batch in range(batches):
            image_paths = image_set_ir[batch * batch_size:(batch * batch_size +
                                                           batch_size)]
            img = utils.get_train_images_auto(image_paths,
                                              height=args.HEIGHT,
                                              width=args.WIDTH,
                                              mode=img_model)

            count += 1
            optimizer.zero_grad()
            img = Variable(img, requires_grad=False)

            if args.cuda:
                img = img.cuda()
            # Autoencoder pass: encode, then reconstruct the input.
            en = densefuse_model.encoder(img)
            outputs = densefuse_model.decoder(en)
            # Reconstruction target: a detached copy of the input.
            x = Variable(img.data.clone(), requires_grad=False)

            # Average pixel and (1 - SSIM) losses over all decoder outputs.
            ssim_loss_value = 0.
            pixel_loss_value = 0.
            for output in outputs:
                pixel_loss_temp = mse_loss(output, x)
                ssim_loss_temp = ssim_loss(output, x, normalize=True)
                ssim_loss_value += (1 - ssim_loss_temp)
                pixel_loss_value += pixel_loss_temp
            ssim_loss_value /= len(outputs)
            pixel_loss_value /= len(outputs)

            # Total loss = pixel loss + weighted SSIM loss.
            total_loss = pixel_loss_value + args.ssim_weight[
                i] * ssim_loss_value
            total_loss.backward()
            optimizer.step()

            all_ssim_loss += ssim_loss_value.item()
            all_pixel_loss += pixel_loss_value.item()
            if (batch + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\t pixel loss: {:.6f}\t ssim loss: {:.6f}\t total: {:.6f}".format(
                    time.ctime(), e + 1, count, batches,
                    all_pixel_loss / args.log_interval,
                    all_ssim_loss / args.log_interval,
                    (args.ssim_weight[i] * all_ssim_loss + all_pixel_loss) /
                    args.log_interval)
                tbar.set_description(mesg)
                Loss_pixel.append(all_pixel_loss / args.log_interval)
                Loss_ssim.append(all_ssim_loss / args.log_interval)
                Loss_all.append(
                    (args.ssim_weight[i] * all_ssim_loss + all_pixel_loss) /
                    args.log_interval)

                # Reset the running sums for the next logging window.
                all_ssim_loss = 0.
                all_pixel_loss = 0.

            if (batch + 1) % (200 * args.log_interval) == 0:
                # Periodic checkpoint: save weights plus the loss curves so far.
                densefuse_model.eval()
                densefuse_model.cpu()
                ts = _timestamp()
                save_model_filename = '{0}/Epoch_{1}_iters_{2}_{3}_{0}.model'.format(
                    ssim_tag, e, count, ts)
                save_model_path = os.path.join(args.save_model_dir,
                                               save_model_filename)
                torch.save(densefuse_model.state_dict(), save_model_path)
                loss_suffix = 'epoch_{1}_iters_{2}_{3}_{0}.mat'.format(
                    ssim_tag, args.epochs, count, ts)
                _save_loss_mat(args.save_loss_dir,
                               '{}/loss_pixel_{}'.format(ssim_tag, loss_suffix),
                               'loss_pixel', Loss_pixel)
                _save_loss_mat(args.save_loss_dir,
                               '{}/loss_ssim_{}'.format(ssim_tag, loss_suffix),
                               'loss_ssim', Loss_ssim)
                _save_loss_mat(args.save_loss_dir,
                               '{}/loss_total_{}'.format(ssim_tag, loss_suffix),
                               'loss_total', Loss_all)

                densefuse_model.train()
                # BUG FIX: the original called .cuda() unconditionally here,
                # crashing CPU-only runs right after the first checkpoint.
                if args.cuda:
                    densefuse_model.cuda()
                # BUG FIX: set_description takes one string; the original
                # passed save_model_path as tqdm's `refresh` argument.
                tbar.set_description(
                    "Checkpoint, trained model saved at {}".format(save_model_path))

    # Final dump of loss curves and model weights.
    ts = _timestamp()
    final_suffix = 'epoch_{1}_{2}_{0}.mat'.format(ssim_tag, args.epochs, ts)
    _save_loss_mat(args.save_loss_dir,
                   '{}/Final_loss_pixel_{}'.format(ssim_tag, final_suffix),
                   'loss_pixel', Loss_pixel)
    _save_loss_mat(args.save_loss_dir,
                   '{}/Final_loss_ssim_{}'.format(ssim_tag, final_suffix),
                   'loss_ssim', Loss_ssim)
    _save_loss_mat(args.save_loss_dir,
                   '{}/Final_loss_total_{}'.format(ssim_tag, final_suffix),
                   'loss_total', Loss_all)
    # Save the final model (moved to CPU so it loads on any machine).
    densefuse_model.eval()
    densefuse_model.cpu()
    save_model_filename = '{0}/Final_epoch_{1}_{2}_{0}.model'.format(
        ssim_tag, args.epochs, ts)
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(densefuse_model.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
# Esempio n. 4 (Example 4)
# 0
        img = Image.open(self.img_list[index]).convert('RGB')
        ir = Image.open(self.ir_list[index]).convert('RGB')

        if self.transform:
            img = self.transform(img)
            ir = self.transform(ir)
        
        return img, ir

    def __len__(self):
        """Return the dataset size: the number of entries in self.img_list."""

        return len(self.img_list)

if __name__ == '__main__':

    model = DenseFuse_net()

    checkpoint = torch.load(args.resume)
    model.load_state_dict(checkpoint['model_state_dict'])
    strategy_type = args.strategy_type

    img_path_file = args.test_img
    ir_path_file = args.test_ir
    testloader = DataLoader(MyTestDataset(img_path_file, ir_path_file), batch_size=1, shuffle=False, num_workers=1)

    if is_cuda:
        model.cuda()

    for i, (img, ir) in enumerate(tqdm(testloader)):

        if is_cuda: