def main():
    """Train the original RIN composer (Decomposer + Shader) on MPI-Sintel.

    Trains for --num_epochs epochs, decays the learning rate by 0.75 every
    120 epochs, evaluates albedo/shading MSE after each epoch, and saves a
    per-epoch checkpoint plus a best-so-far log under --save_path.
    """
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='F:\\sintel',
                        help='base folder of datasets')
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path', type=str,
        default='MPI_logs\\RIID_origin_RIN_updateLR_CosBF_VGG0.1_shading_SceneSplit\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')
    # Was type=float: torch DataLoader(num_workers=...) requires an int.
    parser.add_argument('--loader_threads', type=int, default=8,
                        help='number of parallel data-loading threads')
    # Was type=bool: argparse evaluates bool('False') == True, so the flag
    # could never be switched off from the command line. StrToBool (used by
    # the other entry points in this file) parses the string properly.
    parser.add_argument('--save_model', type=StrToBool, default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=240)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--checkpoint', type=StrToBool, default=False)
    parser.add_argument('--state_dict', type=str, default='composer_state.t7')
    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    # Was "cuda: 1" (with a space), which torch.device rejects as an
    # invalid device string.
    device = torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101
    shader = RIN.Shader(output_ch=3)
    reflection = RIN.Decomposer()
    composer = RIN.Composer(reflection, shader).to(device)

    MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'
    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    cur_epoch = 0
    if args.checkpoint:
        composer.load_state_dict(
            torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')
        # Resume from the epoch encoded in the checkpoint filename
        # (composer_state_<epoch>.t7). Was a duplicated
        # `cur_epoch = cur_epoch = ...` assignment.
        cur_epoch = int(args.state_dict.split('_')[-1].split('.')[0]) + 1

    train_set = RIN_pipeline.MPI_Dataset_Revisit(train_txt)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)
    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=True)

    writer = SummaryWriter(log_dir=args.save_path)
    trainer = RIN_pipeline.MPI_TrainerOriginRemoveShape(
        composer, train_loader, args.lr, device, writer)

    best_average_loss = 9999
    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))
        trainer.train()
        if (epoch + 1) % 120 == 0:
            # Step-decay the learning rate by 0.75 every 120 epochs.
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)
        if args.save_model:
            state = composer.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path,
                             'composer_state_{}.t7'.format(epoch)))
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_remove_shape(
            composer, test_loader, device)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)
        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                  'a+') as f:
            f.write(
                'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                .format(epoch, average_loss, albedo_test_loss,
                        shading_test_loss))
        if average_loss < best_average_loss:
            best_average_loss = average_loss
            if args.save_model:
                state = composer.state_dict()
                torch.save(
                    state,
                    os.path.join(args.save_path,
                                 'composer_state_{}.t7'.format(epoch)))
            with open(os.path.join(args.save_path, 'loss.txt'), 'a+') as f:
                f.write(
                    'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                    .format(epoch, average_loss, albedo_test_loss,
                            shading_test_loss))
def main():
    """Export MPI-Sintel test predictions of a trained SE composer as PNGs.

    Loads separately trained reflectance and shading decomposers, wraps
    them in an SEComposer, runs the selected test split (fullsize 436x1024
    or 256px crops), and writes target and predicted albedo/shading images
    under --save_path.
    """
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument(
        '--save_path', type=str,
        default='MPI_log_paper\\GAN_RIID_updateLR3_epoch100_CosbfVGG_SceneSplit_refl-se-skip_shad-se-low_multi_new_shadSqueeze_256_Reduction2\\',
        help='save path of model, visualizations, and tensorboard')
    # Was type=float: torch DataLoader(num_workers=...) requires an int.
    parser.add_argument('--loader_threads', type=int, default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--refl_checkpoint', type=str, default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint', type=str, default='shad_checkpoint')
    parser.add_argument('--state_dict_refl', type=str,
                        default='composer_reflectance_state_25.t7')
    parser.add_argument('--state_dict_shad', type=str,
                        default='composer_shading_state_60.t7')
    parser.add_argument('--refl_skip_se', type=StrToBool, default=True)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=True)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=True)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=True)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=True)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=True)
    parser.add_argument('--refl_reduction', type=StrToInt, default=2)
    parser.add_argument('--shad_reduction', type=StrToInt, default=2)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--fullsize', type=StrToBool, default=True)
    args = parser.parse_args()

    device = torch.device(args.cuda)
    reflectance = RIN.SEDecomposerSingle(
        multi_size=args.refl_multi_size, low_se=args.refl_low_se,
        skip_se=args.refl_skip_se, detach=args.refl_detach_flag,
        reduction=args.refl_reduction).to(device)
    shading = RIN.SEDecomposerSingle(
        multi_size=args.shad_multi_size, low_se=args.shad_low_se,
        skip_se=args.shad_skip_se, se_squeeze=args.shad_squeeze_flag,
        reduction=args.shad_reduction,
        detach=args.shad_detach_flag).to(device)
    reflectance.load_state_dict(
        torch.load(os.path.join(args.save_path, args.refl_checkpoint,
                                args.state_dict_refl)))
    shading.load_state_dict(
        torch.load(os.path.join(args.save_path, args.shad_checkpoint,
                                args.state_dict_shad)))
    print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)

    if args.fullsize:
        print('test fullsize....')
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-test.txt'
        # Pad the 436x1024 frames up to a multiple of 16 for the network,
        # then crop back afterwards via a negative reflection pad.
        h, w = 436, 1024
        pad_h, pad_w = clc_pad(h, w, 16)
        print(pad_h, pad_w)
        tmp_pad = nn.ReflectionPad2d((0, pad_w, 0, pad_h))
        tmp_inversepad = nn.ReflectionPad2d((0, -pad_w, 0, -pad_h))
    else:
        print('test size256....')
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'
    if args.split == 'ImageSplit':
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)
    if args.fullsize:
        check_folder(os.path.join(args.save_path, "refl_target_fullsize"))
        check_folder(os.path.join(args.save_path, "refl_output_fullsize"))
        check_folder(os.path.join(args.save_path, "shad_target_fullsize"))
        check_folder(os.path.join(args.save_path, "shad_output_fullsize"))
    else:
        check_folder(os.path.join(args.save_path, "refl_target"))
        check_folder(os.path.join(args.save_path, "shad_target"))
        check_folder(os.path.join(args.save_path, "refl_output"))
        check_folder(os.path.join(args.save_path, "shad_output"))

    ToPIL = transforms.ToPILImage()
    composer.eval()
    with torch.no_grad():
        for ind, tensors in enumerate(test_loader):
            print(ind)
            inp = [t.to(device) for t in tensors]
            input_g, albedo_g, shading_g, mask_g = inp
            if args.fullsize:
                input_g = tmp_pad(input_g)
            # The composer's return arity depends on which branches were
            # built with multi-size supervision.
            if args.refl_multi_size and args.shad_multi_size:
                albedo_fake, shading_fake, _, _ = composer.forward(input_g)
            elif args.refl_multi_size or args.shad_multi_size:
                albedo_fake, shading_fake, _ = composer.forward(input_g)
            else:
                albedo_fake, shading_fake = composer.forward(input_g)
            if args.fullsize:
                albedo_fake, shading_fake = tmp_inversepad(
                    albedo_fake), tmp_inversepad(shading_fake)
            # Mask out the defect regions of the ground-truth albedo.
            albedo_fake = albedo_fake * mask_g
            albedo_fake = albedo_fake.cpu().clamp(0, 1)
            shading_fake = shading_fake.cpu().clamp(0, 1)
            albedo_g = albedo_g.cpu().clamp(0, 1)
            shading_g = shading_g.cpu().clamp(0, 1)
            lab_refl_targ = ToPIL(albedo_g.squeeze())
            lab_sha_targ = ToPIL(shading_g.squeeze())
            refl_pred = ToPIL(albedo_fake.squeeze())
            sha_pred = ToPIL(shading_fake.squeeze())
            lab_refl_targ.save(
                os.path.join(
                    args.save_path,
                    "refl_target_fullsize" if args.fullsize else "refl_target",
                    "{}.png".format(ind)))
            lab_sha_targ.save(
                os.path.join(
                    args.save_path,
                    "shad_target_fullsize" if args.fullsize else "shad_target",
                    "{}.png".format(ind)))
            refl_pred.save(
                os.path.join(
                    args.save_path,
                    "refl_output_fullsize" if args.fullsize else "refl_output",
                    "{}.png".format(ind)))
            sha_pred.save(
                os.path.join(
                    args.save_path,
                    "shad_output_fullsize" if args.fullsize else "shad_output",
                    "{}.png".format(ind)))
def main():
    """Train SE reflectance and shading decomposers jointly on MPI-Sintel.

    Tracks the best albedo and shading test MSE independently and saves the
    corresponding sub-network's weights whenever either one improves.
    """
    random.seed(520)
    torch.manual_seed(520)
    torch.cuda.manual_seed(520)
    np.random.seed(520)
    cudnn.benchmark = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='F:\\sintel',
                        help='base folder of datasets')
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='two')
    parser.add_argument(
        '--save_path', type=str,
        default='MPI_logs_new\\RIID_new_RIN_updateLR1_epoch240_CosBF_VGG0.1_shading_SceneSplit_selayer1_reflmultiSize\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')
    # Was type=float: torch DataLoader(num_workers=...) requires an int.
    parser.add_argument('--loader_threads', type=int, default=8,
                        help='number of parallel data-loading threads')
    # Was type=bool, which argparse cannot parse ('False' -> True); use
    # StrToBool like the other boolean flags in this parser.
    parser.add_argument('--save_model', type=StrToBool, default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=120)
    parser.add_argument('--batch_size', type=int, default=20)
    parser.add_argument('--checkpoint', type=StrToBool, default=False)
    parser.add_argument('--state_dict', type=str, default='composer_state.t7')
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--choose', type=str, default='refl')
    parser.add_argument('--refl_skip_se', type=StrToBool, default=False)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=False)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=False)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_vgg_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_vgg_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_bf_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_bf_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--image_size', type=StrToInt, default=256)
    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: disable=E1101
    reflectance = RIN.SEDecomposerSingle(
        multi_size=args.refl_multi_size, low_se=args.refl_low_se,
        skip_se=args.refl_skip_se).to(device)
    shading = RIN.SEDecomposerSingle(
        multi_size=args.shad_multi_size, low_se=args.shad_low_se,
        skip_se=args.shad_skip_se).to(device)
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)

    MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'
    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    cur_epoch = 0
    if args.checkpoint:
        # NOTE(review): only the weights are restored here; cur_epoch stays
        # 0, so training restarts from epoch 0 — confirm this is intended.
        composer.load_state_dict(
            torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')

    train_set = RIN_pipeline.MPI_Dataset_Revisit(
        train_txt,
        refl_multi_size=args.refl_multi_size,
        shad_multi_size=args.shad_multi_size,
        image_size=args.image_size)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)
    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    writer = SummaryWriter(log_dir=args.save_path)
    trainer = RIN_pipeline.OctaveTrainer(composer, train_loader, device,
                                         writer, args)

    best_albedo_loss = 9999
    best_shading_loss = 9999
    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))
        trainer.train()
        if (epoch + 1) % 40 == 0:
            # Step-decay the learning rate by 0.75 every 40 epochs.
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
            composer, test_loader, device,
            refl_multi_size=args.refl_multi_size,
            shad_multi_size=args.shad_multi_size)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)
        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                  'a+') as f:
            f.write(
                'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                .format(epoch, average_loss, albedo_test_loss,
                        shading_test_loss))
        if albedo_test_loss < best_albedo_loss:
            best_albedo_loss = albedo_test_loss
            if args.save_model:
                state = composer.reflectance.state_dict()
                torch.save(
                    state,
                    os.path.join(args.save_path,
                                 'composer_reflectance_state.t7'))
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(
                    epoch, albedo_test_loss))
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            if args.save_model:
                state = composer.shading.state_dict()
                torch.save(
                    state,
                    os.path.join(args.save_path, 'composer_shading_state.t7'))
            with open(os.path.join(args.save_path, 'shading_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(
                    epoch, shading_test_loss))
def main():
    """Train SE generators adversarially (with two discriminators) on MPI-Sintel.

    Optionally uses data augmentation with the fullsize scene-split training
    list; saves the best reflectance and shading generators independently.
    """
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str,
                        default='F:\\MPI-sintel-PyrResNet\\Sintel\\images\\',
                        help='base folder of datasets')
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path', type=str,
        default='MPI_logs\\RIID_new_RIN_updateLR1_epoch240_CosBF_VGG0.1_shading_SceneSplit_GAN_selayer1_ReflMultiSize_DA\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')
    # Was type=float: torch DataLoader(num_workers=...) requires an int.
    parser.add_argument('--loader_threads', type=int, default=8,
                        help='number of parallel data-loading threads')
    # Was type=bool: bool('False') is True, so these flags (notably
    # --data_augmentation, defaulting to True) could never be disabled from
    # the command line. StrToBool matches the other flags in this parser.
    parser.add_argument('--save_model', type=StrToBool, default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=1000)
    parser.add_argument('--batch_size', type=int, default=20)
    parser.add_argument('--checkpoint', type=StrToBool, default=False)
    parser.add_argument('--state_dict', type=str, default='composer_state.t7')
    parser.add_argument('--skip_se', type=StrToBool, default=False)
    parser.add_argument('--cuda', type=str, default='cuda:1')
    parser.add_argument('--dilation', type=StrToBool, default=False)
    parser.add_argument('--se_improved', type=StrToBool, default=False)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--data_augmentation', type=StrToBool, default=True)
    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: disable=E1101
    print(args.skip_se)
    Generator_R = RIN.SESingleGenerator(
        multi_size=args.refl_multi_size).to(device)
    Generator_S = RIN.SESingleGenerator(
        multi_size=args.shad_multi_size).to(device)
    composer = RIN.SEComposerGenerater(Generator_R, Generator_S,
                                       args.refl_multi_size,
                                       args.shad_multi_size).to(device)
    Discriminator_R = RIN.SEUG_Discriminator().to(device)
    Discriminator_S = RIN.SEUG_Discriminator().to(device)

    MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    # With augmentation, train on the fullsize (defect-free) scene split so
    # random crops/flips see the whole frame.
    if args.data_augmentation:
        MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-train.txt'
    else:
        MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'
    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    cur_epoch = 0
    if args.checkpoint:
        composer.load_state_dict(
            torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')

    train_transform = (RIN_pipeline.MPI_Train_Agumentation()
                       if args.data_augmentation else None)
    train_set = RIN_pipeline.MPI_Dataset_Revisit(
        train_txt,
        transform=train_transform,
        refl_multi_size=args.refl_multi_size,
        shad_multi_size=args.shad_multi_size)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)
    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    writer = SummaryWriter(log_dir=args.save_path)
    trainer = RIN_pipeline.SEUGTrainer(composer, Discriminator_R,
                                       Discriminator_S, train_loader, args.lr,
                                       device, writer,
                                       weight_decay=args.weight_decay,
                                       refl_multi_size=args.refl_multi_size,
                                       shad_multi_size=args.shad_multi_size)

    best_albedo_loss = 9999
    best_shading_loss = 9999
    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))
        trainer.train()
        if (epoch + 1) % 100 == 0:
            # Step-decay the learning rate by 0.75 every 100 epochs.
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
            composer, test_loader, device,
            refl_multi_size=args.refl_multi_size,
            shad_multi_size=args.shad_multi_size)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)
        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                  'a+') as f:
            f.write(
                'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                .format(epoch, average_loss, albedo_test_loss,
                        shading_test_loss))
        if albedo_test_loss < best_albedo_loss:
            best_albedo_loss = albedo_test_loss
            if args.save_model:
                state = composer.reflectance.state_dict()
                torch.save(
                    state,
                    os.path.join(args.save_path,
                                 'composer_reflectance_state.t7'))
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(
                    epoch, albedo_test_loss))
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            if args.save_model:
                state = composer.shading.state_dict()
                torch.save(
                    state,
                    os.path.join(args.save_path, 'composer_shading_state.t7'))
            with open(os.path.join(args.save_path, 'shading_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(
                    epoch, shading_test_loss))
def main():
    """Train the Mish/Ranger octave model with a single output on MPI-Sintel.

    In --mode 'two' both albedo and shading MSE are evaluated; otherwise a
    single output (--choose) is evaluated. Saves the best model to
    composer_state.t7 under --save_path.
    """
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='F:\\sintel',
                        help='base folder of datasets')
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='one')
    parser.add_argument('--choose', type=str, default='refl')
    parser.add_argument(
        '--save_path', type=str,
        default='MPI_logs\\RIID_mish_ranger_Octave_one_bs10_updateLR1_epoch240_CosBF_VGG0.1_refl_SceneSplit\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')
    # Was type=float: torch DataLoader(num_workers=...) requires an int.
    parser.add_argument('--loader_threads', type=int, default=8,
                        help='number of parallel data-loading threads')
    # Was type=bool, which argparse cannot parse ('False' -> True); use the
    # project's StrToBool converter instead.
    parser.add_argument('--save_model', type=StrToBool, default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=240)
    parser.add_argument('--batch_size', type=StrToInt, default=10)
    parser.add_argument('--checkpoint', type=StrToBool, default=True)
    parser.add_argument('--state_dict', type=str, default='composer_state.t7')
    args = parser.parse_args()

    check_folder(args.save_path)
    # pylint: disable=E1101
    # Was "cuda: 1" (with a space), which torch.device rejects as an
    # invalid device string.
    device = torch.device("cuda:1" if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101
    composer = octave.get_mish_model_one_output().to(device)

    MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'
    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    # NOTE(review): hard-coded resume point for a specific interrupted run
    # (epoch 119, best loss so far); generalize before reusing this script.
    cur_epoch = 119
    if args.checkpoint:
        composer.load_state_dict(
            torch.load(os.path.join(args.save_path, args.state_dict)))
        print('load checkpoint success!')

    train_set = RIN_pipeline.MPI_Dataset_Revisit(train_txt)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)
    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.batch_size,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    writer = SummaryWriter(log_dir=args.save_path)
    trainer = RIN_pipeline.OctaveTrainer(composer, train_loader, args.lr,
                                         device, writer, mode=args.mode,
                                         choose=args.choose)

    # Best loss of the interrupted run this script resumes from.
    best_average_loss = 0.01970498738396499
    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))
        trainer.train()
        if (epoch + 1) % 120 == 0:
            # Step-decay the learning rate by 0.75 every 120 epochs.
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)
        if args.mode == 'two':
            albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
                composer, test_loader, device)
            average_loss = (albedo_test_loss + shading_test_loss) / 2
            writer.add_scalar('A_mse', albedo_test_loss, epoch)
            writer.add_scalar('S_mse', shading_test_loss, epoch)
            writer.add_scalar('aver_mse', average_loss, epoch)
            with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                      'a+') as f:
                f.write(
                    'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                    .format(epoch, average_loss, albedo_test_loss,
                            shading_test_loss))
        else:
            average_loss = RIN_pipeline.MPI_test_unet_one(composer,
                                                          test_loader,
                                                          device,
                                                          choose=args.choose)
            writer.add_scalar('test_mse', average_loss, epoch)
            with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                      'a+') as f:
                f.write('epoch{} --- test_loss: {}\n'.format(
                    epoch, average_loss))
        if average_loss < best_average_loss:
            best_average_loss = average_loss
            if args.save_model:
                state = composer.state_dict()
                torch.save(state,
                           os.path.join(args.save_path, 'composer_state.t7'))
            with open(os.path.join(args.save_path, 'loss.txt'), 'a+') as f:
                if args.mode == 'two':
                    f.write(
                        'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                        .format(epoch, average_loss, albedo_test_loss,
                                shading_test_loss))
                else:
                    f.write('epoch{} --- average_loss: {}\n'.format(
                        epoch, average_loss))
def main():
    """Export MPI-Sintel test predictions of a trained RIN composer as PNGs.

    Loads a composer checkpoint, runs the selected 256px test split, and
    writes target/predicted albedo and shading images plus masks under
    --save_path.
    """
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='ImageSplit')
    parser.add_argument('--mode', type=str, default='test')
    parser.add_argument(
        '--save_path', type=str,
        default='MPI_logs\\RIID_origin_RIN_updateLR_CosBF_VGG0.1_shading_epoch240_ImageSplit_size256_remove_shape\\',
        help='save path of model, visualizations, and tensorboard')
    # Was type=float: torch DataLoader(num_workers=...) requires an int.
    parser.add_argument('--loader_threads', type=int, default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--state_dict', type=str,
                        default='composer_state_222.t7')
    args = parser.parse_args()

    # pylint: disable=E1101
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101
    shader = RIN.Shader(output_ch=3)
    reflection = RIN.Decomposer()
    composer = RIN.Composer(reflection, shader).to(device)

    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'
    if args.split == 'ImageSplit':
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    composer.load_state_dict(
        torch.load(os.path.join(args.save_path, args.state_dict)))
    print('load checkpoint success!')

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)
    check_folder(os.path.join(args.save_path, "refl_target"))
    check_folder(os.path.join(args.save_path, "shad_target"))
    check_folder(os.path.join(args.save_path, "refl_output"))
    check_folder(os.path.join(args.save_path, "shad_output"))
    check_folder(os.path.join(args.save_path, "mask"))

    # scipy.misc.imsave was deprecated in SciPy 1.0 and removed in 1.3, so
    # the original save path no longer runs; save through PIL instead
    # (clamp to [0, 1] replaces the old np.clip, matching the PIL-based
    # export used elsewhere in this file).
    ToPIL = transforms.ToPILImage()
    composer.eval()
    with torch.no_grad():
        for ind, tensors in enumerate(test_loader):
            print(ind)
            inp = [t.to(device) for t in tensors]
            input_g, albedo_g, shading_g, mask_g = inp
            _, albedo_fake, shading_fake = composer.forward(input_g)
            # Mask out the defect regions of the predicted albedo.
            albedo_fake = albedo_fake * mask_g
            albedo_fake = albedo_fake.cpu().clamp(0, 1)
            shading_fake = shading_fake.cpu().clamp(0, 1)
            albedo_g = albedo_g.cpu().clamp(0, 1)
            shading_g = shading_g.cpu().clamp(0, 1)
            mask_g = mask_g.cpu().clamp(0, 1)
            ToPIL(albedo_g.squeeze()).save(
                os.path.join(args.save_path, "refl_target",
                             "{}.png".format(ind)))
            ToPIL(shading_g.squeeze()).save(
                os.path.join(args.save_path, "shad_target",
                             "{}.png".format(ind)))
            ToPIL(mask_g.squeeze()).save(
                os.path.join(args.save_path, "mask", "{}.png".format(ind)))
            ToPIL(albedo_fake.squeeze()).save(
                os.path.join(args.save_path, "refl_output",
                             "{}.png".format(ind)))
            ToPIL(shading_fake.squeeze()).save(
                os.path.join(args.save_path, "shad_output",
                             "{}.png".format(ind)))
def main():
    """Visualize SE-block attention heatmaps of a trained reflectance model.

    Loads a single SEDecomposerSingle reflectance checkpoint, runs the test
    split, and writes per-image "heatmap in"/"heatmap out" visualizations
    (channel-summed activations passed through `cam`) as PNGs via OpenCV.
    """
    cudnn.benchmark = True
    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_logs_new\\GAN_RIID_updateLR3_epoch160_CosbfVGG_SceneSplit_refl-se-skip_shad-se-low_multi_new_shadSqueeze_grad\\',
        help='save path of model, visualizations, and tensorboard')
    # NOTE(review): type=float here, but the value is later passed to
    # DataLoader(num_workers=...), which requires an int — confirm.
    parser.add_argument('--loader_threads',
                        type=float,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--refl_checkpoint', type=str,
                        default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint', type=str,
                        default='shad_checkpoint')
    parser.add_argument('--state_dict_refl', type=str,
                        default='composer_reflectance_state_81.t7')
    parser.add_argument('--state_dict_shad', type=str,
                        default='composer_shading_state_81.t7')
    parser.add_argument('--refl_skip_se', type=StrToBool, default=False)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=False)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=False)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_reduction', type=StrToInt, default=8)
    parser.add_argument('--shad_reduction', type=StrToInt, default=8)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--fullsize', type=StrToBool, default=True)
    parser.add_argument('--heatmap', type=StrToBool, default=False)
    args = parser.parse_args()
    device = torch.device(args.cuda)
    # Only the reflectance branch is loaded here; the shading arguments
    # above are unused by this entry point.
    model = RIN.SEDecomposerSingle(multi_size=args.refl_multi_size,
                                   low_se=args.refl_low_se,
                                   skip_se=args.refl_skip_se,
                                   detach=args.refl_detach_flag,
                                   reduction=args.refl_reduction,
                                   heatmap=args.heatmap).to(device)
    model.load_state_dict(
        torch.load(
            os.path.join(args.save_path, args.refl_checkpoint,
                         args.state_dict_refl)))
    print('load checkpoint success!')
    if args.fullsize:
        print('test fullsize....')
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-test.txt'
        # Pad 436x1024 frames to a multiple of 16 for the network; negative
        # reflection pads crop back afterwards.
        h, w = 436, 1024
        pad_h, pad_w = clc_pad(h, w, 16)
        print(pad_h, pad_w)
        tmp_pad = nn.ReflectionPad2d((0, pad_w, 0, pad_h))
        tmp_inversepad = nn.ReflectionPad2d((0, -pad_w, 0, -pad_h))
        # Heatmaps are at reduced resolution, so they get their own
        # (smaller) inverse crop. NOTE(review): the -3 crop appears tuned to
        # pad_h/4 for this specific input size — confirm.
        tmp_inversepad_heatmap = nn.ReflectionPad2d((0, 0, 0, -3))
    else:
        print('test size256....')
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'
    if args.split == 'ImageSplit':
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')
    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)
    if args.fullsize:
        check_folder(os.path.join(args.save_path, "refl_heapmapin"))
        check_folder(os.path.join(args.save_path, "refl_heapmapout"))
    ToPIL = transforms.ToPILImage()
    model.eval()
    with torch.no_grad():
        for ind, tensors in enumerate(test_loader):
            print(ind)
            inp = [t.to(device) for t in tensors]
            input_g, albedo_g, shading_g, mask_g = inp
            if args.fullsize:
                input_g = tmp_pad(input_g)
            # NOTE(review): `heapmap` is only bound in the multi-size
            # branch, yet it is used unconditionally below — running with
            # --refl_multi_size False raises NameError. Confirm the model's
            # return arity for both flag settings.
            if args.refl_multi_size:
                albedo_fake, _, heapmap = model.forward(input_g)
            else:
                albedo_fake = model.forward(input_g)
            if args.fullsize:
                input_g = tmp_inversepad(input_g)
                heapmap[0] = tmp_inversepad_heatmap(heapmap[0])
                heapmap[1] = tmp_inversepad_heatmap(heapmap[1])
            # albedo_fake = albedo_fake*mask_g
            # lab_refl_targ = albedo_g.squeeze().cpu().numpy().transpose(1,2,0)
            # lab_sha_targ = shading_g.squeeze().cpu().numpy().transpose(1,2,0)
            # refl_pred = albedo_fake.squeeze().cpu().numpy().transpose(1,2,0)
            # sha_pred = shading_fake.squeeze().cpu().numpy().transpose(1,2,0)
            print(heapmap[0].squeeze().size())
            # Collapse the channel dimension to a single-channel activation
            # map before converting to a colored heatmap.
            heapmap[0] = torch.sum(heapmap[0], dim=1, keepdim=True)
            heapmap[1] = torch.sum(heapmap[1], dim=1, keepdim=True)
            heapmapin = tensor2numpy(heapmap[0][0])
            heapmapout = tensor2numpy(heapmap[1][0])
            #heapmapout = torch.sum(heapmap[1].squeeze(), dim=0, keepdim=True).cpu().clamp(0,1).numpy().transpose(1,2,0)
            print(heapmapin.shape)
            heapmapin = cam(heapmapin)
            print(heapmapin.shape)
            heapmapout = cam(heapmapout)
            # heapmapin = heapmapin.transpose(2,0,1)
            # heapmapout = heapmapout.transpose(2,0,1)
            # input_g = input_g.squeeze().cpu().clamp(0, 1).numpy()
            # print(heapmapin.shape)
            # print(input_g.shape)
            # heapmapin = np.concatenate((heapmapin, input_g), 1).astype(np.float32)
            # heapmapout = np.concatenate((heapmapout, input_g), 1).astype(np.float32)
            # print(heapmapin.shape)
            # heapmapin = torch.from_numpy(heapmapin)
            # heapmapout = torch.from_numpy(heapmapout)
            # lab_refl_targ = ToPIL(input_g.squeeze())
            # refl_pred = ToPIL(albedo_fake.squeeze())
            # heapmapin = torch.cat([heapmapin, torch.zeros(2, h // 4, w // 4)])
            # heapmapout = torch.cat([heapmapout, torch.zeros(2, h // 4, w // 4)])
            # print(heapmapin.size)
            # print(heapmapout.size)
            # `cam` output is assumed to be float in [0, 1] — scaled to
            # 8-bit range for cv2.imwrite. TODO confirm.
            cv2.imwrite(
                os.path.join(args.save_path, "refl_heapmapin",
                             '{}.png'.format(ind)), heapmapin * 255.0)
            cv2.imwrite(
                os.path.join(args.save_path, "refl_heapmapout",
                             '{}.png'.format(ind)), heapmapout * 255.0)
def main():
    """Train (or, with --mode test, evaluate) a VQVAE-based reflectance/shading
    decomposition model on the MPI-Sintel intrinsic-image dataset.

    Builds two RIN.VQVAE networks (reflectance and shading), wraps them in an
    SEComposer, trains with RIN_pipeline.VQVAETrainer, evaluates every epoch
    with RIN_pipeline.MPI_test_unet, and checkpoints whichever sub-network
    improved its best test loss. Losses are logged to TensorBoard and to
    plain-text files under --save_path.
    """
    # Fixed seeds for reproducibility across python / torch CPU / torch CUDA / numpy.
    random.seed(520)
    torch.manual_seed(520)
    torch.cuda.manual_seed(520)
    np.random.seed(520)
    cudnn.benchmark = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument(
        '--save_path',
        type=str,
        default='MPI_logs_vqvae\\vqvae_base_256x256\\',
        help='save path of model, visualizations, and tensorboard')
    parser.add_argument('--refl_checkpoint', type=str, default='refl_checkpoint')
    parser.add_argument('--shad_checkpoint', type=str, default='shad_checkpoint')
    parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')
    # FIX: was type=float — DataLoader's num_workers must be an int
    # (torch iterates range(num_workers), which rejects floats).
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    # FIX: was type=bool — argparse bool() treats any non-empty string
    # (including "False") as True; use the project's StrToBool converter
    # like every other boolean flag below.
    parser.add_argument('--save_model',
                        type=StrToBool,
                        default=True,
                        help='whether to save model or not')
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--checkpoint', type=StrToBool, default=False)
    parser.add_argument('--state_dict_refl', type=str,
                        default='composer_reflectance_state.t7')
    parser.add_argument('--state_dict_shad', type=str,
                        default='composer_shading_state.t7')
    parser.add_argument('--cur_epoch', type=StrToInt, default=0)
    parser.add_argument('--skip_se', type=StrToBool, default=False)
    parser.add_argument('--cuda', type=str, default='cuda')
    parser.add_argument('--dilation', type=StrToBool, default=False)
    parser.add_argument('--se_improved', type=StrToBool, default=False)
    parser.add_argument('--weight_decay', type=float, default=0.0001)
    parser.add_argument('--refl_skip_se', type=StrToBool, default=False)
    parser.add_argument('--shad_skip_se', type=StrToBool, default=False)
    parser.add_argument('--refl_low_se', type=StrToBool, default=False)
    parser.add_argument('--shad_low_se', type=StrToBool, default=False)
    parser.add_argument('--refl_multi_size', type=StrToBool, default=False)
    parser.add_argument('--shad_multi_size', type=StrToBool, default=False)
    parser.add_argument('--refl_vgg_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_vgg_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_bf_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_bf_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_cos_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_grad_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_detach_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_D_weight_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_D_weight_flag', type=StrToBool, default=False)
    parser.add_argument('--shad_squeeze_flag', type=StrToBool, default=False)
    parser.add_argument('--refl_reduction', type=StrToInt, default=8)
    parser.add_argument('--shad_reduction', type=StrToInt, default=8)
    parser.add_argument('--refl_bn', type=StrToBool, default=True)
    parser.add_argument('--shad_bn', type=StrToBool, default=True)
    parser.add_argument('--refl_act', type=str, default='relu')
    parser.add_argument('--shad_act', type=str, default='relu')
    parser.add_argument('--data_augmentation', type=StrToBool, default=False)
    parser.add_argument('--fullsize', type=StrToBool, default=False)
    parser.add_argument('--fullsize_test', type=StrToBool, default=False)
    parser.add_argument('--image_size', type=StrToInt, default=256)
    parser.add_argument('--ttur', type=StrToBool, default=False)
    parser.add_argument('--vae', type=StrToBool, default=True)
    parser.add_argument('--vq_flag', type=StrToBool, default=True)
    args = parser.parse_args()

    check_folder(args.save_path)
    check_folder(os.path.join(args.save_path, args.refl_checkpoint))
    check_folder(os.path.join(args.save_path, args.shad_checkpoint))
    # pylint: disable=E1101
    device = torch.device(args.cuda)
    # pylint: disable=E1101

    # Two independent VQVAEs: one predicts reflectance (albedo), one shading.
    reflectance = RIN.VQVAE(vq_flag=args.vq_flag).to(device)
    shading = RIN.VQVAE(vq_flag=args.vq_flag).to(device)
    cur_epoch = 0
    if args.checkpoint:
        reflectance.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.refl_checkpoint,
                             args.state_dict_refl)))
        shading.load_state_dict(
            torch.load(
                os.path.join(args.save_path, args.shad_checkpoint,
                             args.state_dict_shad)))
        cur_epoch = args.cur_epoch
        print('load checkpoint success!')
    composer = RIN.SEComposer(reflectance, shading, args.refl_multi_size,
                              args.shad_multi_size).to(device)

    # Choose the train/test file lists. Augmented training uses the larger
    # crops (fullsize or 320); otherwise the 256 lists are used.
    if args.data_augmentation:
        print('data_augmentation.....')
        if args.fullsize:
            MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-train.txt'
            MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-train.txt'
        else:
            MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-320-train.txt'
            MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-320-train.txt'
    else:
        MPI_Scene_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-train.txt'
        MPI_Image_Split_train_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-train.txt'
    # FIX: merged the two back-to-back `if args.fullsize_test` conditionals.
    if args.fullsize_test:
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-test.txt'
    else:
        MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-256-test.txt'
        MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-256-test.txt'

    if args.split == 'ImageSplit':
        train_txt = MPI_Image_Split_train_txt
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        train_txt = MPI_Scene_Split_train_txt
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    if args.data_augmentation:
        print('augmentation...')
        train_transform = RIN_pipeline.MPI_Train_Agumentation_fy2()
    train_set = RIN_pipeline.MPI_Dataset_Revisit(
        train_txt,
        transform=train_transform if args.data_augmentation else None,
        refl_multi_size=args.refl_multi_size,
        shad_multi_size=args.shad_multi_size,
        image_size=args.image_size)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               num_workers=args.loader_threads,
                                               shuffle=True)
    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    # Evaluation-only shortcut: score the (checkpoint-loaded) composer and exit.
    if args.mode == 'test':
        print('test mode .....')
        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
            composer, test_loader, device, args)
        print('albedo_test_loss: ', albedo_test_loss)
        print('shading_test_loss: ', shading_test_loss)
        return

    writer = SummaryWriter(log_dir=args.save_path)
    # FIX: the old `if not args.ttur ... else ...` constructed the identical
    # VQVAETrainer in both branches — the conditional was dead code.
    trainer = RIN_pipeline.VQVAETrainer(composer, train_loader, device, writer,
                                        args)

    best_albedo_loss = 9999
    best_shading_loss = 9999
    for epoch in range(cur_epoch, args.num_epochs):
        print('<Main> Epoch {}'.format(epoch))
        trainer.train()
        # Step-decay the learning rate by 0.75 every 40 epochs.
        if (epoch + 1) % 40 == 0:
            args.lr = args.lr * 0.75
            trainer.update_lr(args.lr)

        albedo_test_loss, shading_test_loss = RIN_pipeline.MPI_test_unet(
            composer, test_loader, device, args)
        average_loss = (albedo_test_loss + shading_test_loss) / 2
        writer.add_scalar('A_mse', albedo_test_loss, epoch)
        writer.add_scalar('S_mse', shading_test_loss, epoch)
        writer.add_scalar('aver_mse', average_loss, epoch)
        with open(os.path.join(args.save_path, 'loss_every_epoch.txt'),
                  'a+') as f:
            f.write(
                'epoch{} --- average_loss: {}, albedo_loss:{}, shading_loss:{}\n'
                .format(epoch, average_loss, albedo_test_loss,
                        shading_test_loss))

        # Checkpoint each sub-network independently when it improves.
        if albedo_test_loss < best_albedo_loss:
            state = composer.reflectance.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path, args.refl_checkpoint,
                             'composer_reflectance_state_{}.t7'.format(epoch)))
            best_albedo_loss = albedo_test_loss
            with open(os.path.join(args.save_path, 'reflectance_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- albedo_loss:{}\n'.format(
                    epoch, albedo_test_loss))
        if shading_test_loss < best_shading_loss:
            best_shading_loss = shading_test_loss
            state = composer.shading.state_dict()
            torch.save(
                state,
                os.path.join(args.save_path, args.shad_checkpoint,
                             'composer_shading_state_{}.t7'.format(epoch)))
            with open(os.path.join(args.save_path, 'shading_loss.txt'),
                      'a+') as f:
                f.write('epoch{} --- shading_loss:{}\n'.format(
                    epoch, shading_test_loss))
def main():
    """Run a trained RIN Composer over full-size (436x1024) MPI-Sintel test
    frames and dump predictions as PNGs.

    Each frame is covered by eight 256x256 crops: 4 columns of width 256, and
    two row bands (rows 0:256 and rows 180:436). The bands overlap on rows
    180:256, where the two predictions are averaged. Reflectance, shading and
    shape predictions plus the targets and masks are written under save_path.
    """
    random.seed(9999)
    torch.manual_seed(9999)
    cudnn.benchmark = True

    parser = argparse.ArgumentParser()
    parser.add_argument('--split', type=str, default='SceneSplit')
    parser.add_argument('--mode', type=str, default='test')
    parser.add_argument(
        '--save_path',
        type=str,
        default=
        'MPI_logs\\RIID_origin_RIN_updateLR_CosBF_VGG0.1_shading_SceneSplit\\',
        help='save path of model, visualizations, and tensorboard')
    # FIX: was type=float — DataLoader's num_workers must be an int.
    parser.add_argument('--loader_threads',
                        type=int,
                        default=8,
                        help='number of parallel data-loading threads')
    parser.add_argument('--state_dict', type=str,
                        default='composer_state_179.t7')
    args = parser.parse_args()

    # pylint: disable=E1101
    device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')
    # pylint: disable=E1101
    shader = RIN.Shader(output_ch=3)
    reflection = RIN.Decomposer()
    composer = RIN.Composer(reflection, shader).to(device)

    MPI_Image_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_imageSplit-fullsize-ChenSplit-test.txt'
    MPI_Scene_Split_test_txt = 'D:\\fangyang\\intrinsic_by_fangyang\\MPI_TXT\\MPI_main_sceneSplit-fullsize-NoDefect-test.txt'
    if args.split == 'ImageSplit':
        test_txt = MPI_Image_Split_test_txt
        print('Image split mode')
    else:
        test_txt = MPI_Scene_Split_test_txt
        print('Scene split mode')

    composer.load_state_dict(
        torch.load(os.path.join(args.save_path, args.state_dict)))
    print('load checkpoint success!')

    test_set = RIN_pipeline.MPI_Dataset_Revisit(test_txt)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=1,
                                              num_workers=args.loader_threads,
                                              shuffle=False)

    for folder in ("refl_target_fullsize", "shad_target_fullsize",
                   "refl_output_fullsize", "shad_output_fullsize",
                   "shape_output_fullsize", "mask_fullsize"):
        check_folder(os.path.join(args.save_path, folder))

    composer.eval()
    with torch.no_grad():
        for ind, tensors in enumerate(test_loader):
            print(ind)
            input_g, albedo_g, shading_g, mask_g = [
                t.to(device) for t in tensors
            ]

            # (H, W, C) accumulation buffers for the tiled predictions.
            lab_refl_pred = np.zeros(
                (input_g.size()[2], input_g.size()[3], input_g.size()[1]))
            lab_sha_pred = np.zeros_like(lab_refl_pred)
            lab_shape_pred = np.zeros_like(lab_refl_pred)

            # FIX: the even/odd tile branches were copy-pasted and differed
            # only in the row slice; consolidated with identical arithmetic.
            for ind2 in range(8):
                cols = slice((ind2 // 2) * 256, (ind2 // 2 + 1) * 256)
                # Even tiles cover the top band (rows 0:256), odd tiles the
                # bottom band (rows 180:end, i.e. 180:436 for fullsize).
                rows = slice(0, 256) if ind2 % 2 == 0 else slice(180, None)
                input_g_tmp = input_g[:, :, rows, cols]
                mask_g_tmp = mask_g[:, :, rows, cols]
                _, albedo_fake, shading_fake, shape_fake = composer.forward(
                    input_g_tmp)
                albedo_fake = albedo_fake * mask_g_tmp
                lab_refl_pred[rows, cols, :] += albedo_fake.squeeze().cpu(
                ).numpy().transpose(1, 2, 0)
                lab_sha_pred[rows, cols, :] += shading_fake.squeeze().cpu(
                ).numpy().transpose(1, 2, 0)
                lab_shape_pred[rows, cols, :] += shape_fake.squeeze().cpu(
                ).numpy().transpose(1, 2, 0)

            # Rows 180:256 were predicted by both bands; average the overlap.
            lab_refl_pred[180:256, :, :] /= 2
            lab_sha_pred[180:256, :, :] /= 2
            lab_shape_pred[180:256, :, :] /= 2

            lab_refl_targ = albedo_g.squeeze().cpu().numpy().transpose(1, 2, 0)
            lab_sha_targ = shading_g.squeeze().cpu().numpy().transpose(1, 2, 0)
            mask = mask_g.squeeze().cpu().numpy().transpose(1, 2, 0)

            # NOTE(review): scipy.misc.imsave was removed in SciPy 1.2+;
            # migrate to imageio.imwrite when the environment permits.
            outputs = (
                ("refl_target_fullsize", np.clip(lab_refl_targ, 0, 1)),
                ("shad_target_fullsize", np.clip(lab_sha_targ, 0, 1)),
                ("mask_fullsize", np.clip(mask, 0, 1)),
                ("refl_output_fullsize", np.clip(lab_refl_pred, 0, 1)),
                ("shad_output_fullsize", np.clip(lab_sha_pred, 0, 1)),
                ("shape_output_fullsize", np.clip(lab_shape_pred, 0, 1)),
            )
            for folder, image in outputs:
                scipy.misc.imsave(
                    os.path.join(args.save_path, folder,
                                 "{}.png".format(ind)), image)