def main():
    opt = ArchTrainOptions().parse()
    torch.cuda.manual_seed(12345)
    opt.path_helper = set_log_dir(opt.checkpoints_dir, opt.name)

    # create a dataset given opt.dataset_mode and other options
    dataset = create_dataset(opt)
    print('The number of training images = %d' % len(dataset))

    cycle_gan = CycleGANModel(opt)
    cycle_gan.setup(opt)
    cycle_gan.set_arch(opt.arch, opt.n_resnet - 1)

    writer_dict = {
        "writer": SummaryWriter(opt.path_helper['log_path']),
        'train_steps': 0
    }

    # for i, data in tqdm(enumerate(dataset)):
    #     cycle_gan.set_input(data)
    #     cycle_gan.forward()
    #     cycle_gan.compute_visuals()
    #     save_current_results(opt, cycle_gan.get_current_visuals(), i)

    cyclgan_train(opt, cycle_gan, dataset, writer_dict)
def gather_options(self):
    """Initialize our parser with basic options (only once).

    Add additional model-specific and dataset-specific options. These options
    are defined in the <modify_commandline_options> function in model and
    dataset classes.
    """
    if not self.initialized:  # check if it has been initialized
        parser = argparse.ArgumentParser(
            formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)  # get the basic options
        parser.parse_known_args()

        # modify model-related parser options
        parser = CycleGANModel.modify_commandline_options(parser, self.isTrain)
        parser.parse_known_args()  # parse again with new defaults

        # save the parser for later use
        self.parser = parser

    return self.parser.parse_args()
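# Illustrative sketch (not part of the original sources): gather_options above
# relies on the model class exposing a static modify_commandline_options(parser,
# is_train) hook, following the upstream pytorch-CycleGAN-and-pix2pix convention.
# The option names below (lambda_A, lambda_B, lambda_identity) come from that
# convention and are assumptions; they may not match this fork exactly.
class ExampleModel:
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add model-specific options and change defaults on the shared parser."""
        parser.set_defaults(no_dropout=True)  # default CycleGAN does not use dropout
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=10.0,
                                help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=10.0,
                                help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0.5,
                                help='weight for the identity mapping loss')
        return parser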
def cyclgan_train(opt, cycle_gan: CycleGANModel, train_loader, writer_dict):
    cycle_gan.train()

    writer = writer_dict['writer']
    total_iters = 0
    t_data = 0.0

    for epoch in trange(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0
        train_steps = writer_dict['train_steps']

        for i, data in enumerate(train_loader):
            iter_start_time = time.time()
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time

            total_iters += opt.batch_size
            epoch_iter += opt.batch_size

            cycle_gan.set_input(data)
            cycle_gan.optimize_parameters()

            if (i + 1) % opt.print_freq == 0:
                losses = cycle_gan.get_current_losses()
                t_comp = (time.time() - iter_start_time)
                message = "GAN: [Ep: %d/%d]" % (epoch, opt.n_epochs + opt.n_epochs_decay)
                message += "[Batch: %d/%d][time: %.3f][data: %.3f]" % (
                    epoch_iter, len(train_loader), t_comp, t_data)
                for k, v in losses.items():
                    message += '[%s: %.3f]' % (k, v)
                tqdm.write(message)

            if (total_iters + 1) % opt.display_freq == 0:
                cycle_gan.compute_visuals()
                save_current_results(opt, cycle_gan.get_current_visuals(), train_steps)

            if (total_iters + 1) % opt.save_latest_freq == 0:
                tqdm.write('saving the latest model (epoch %d, total_iters %d)' %
                           (epoch, total_iters))
                save_suffix = 'latest'
                cycle_gan.save_networks(save_suffix)

            iter_data_time = time.time()

        if (epoch + 1) % opt.save_epoch_freq == 0:
            cycle_gan.save_networks('latest')
            cycle_gan.save_networks(epoch)

        tqdm.write('End of epoch %d / %d \t Time Taken: %d sec' %
                   (epoch, opt.n_epochs + opt.n_epochs_decay,
                    time.time() - epoch_start_time))

        writer.add_scalars('Train/discriminator', {
            "A": float(cycle_gan.loss_D_A),
            "B": float(cycle_gan.loss_D_B),
        }, train_steps)
        writer.add_scalars('Train/generator', {
            "A": float(cycle_gan.loss_G_A),
            "B": float(cycle_gan.loss_G_B),
        }, train_steps)
        writer.add_scalars('Train/cycle', {
            "A": float(cycle_gan.loss_cycle_A),
            "B": float(cycle_gan.loss_cycle_B),
        }, train_steps)
        writer.add_scalars('Train/idt', {
            "A": float(cycle_gan.loss_idt_A),
            "B": float(cycle_gan.loss_idt_B),
        }, train_steps)
        writer_dict['train_steps'] += 1

        cycle_gan.update_learning_rate()
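# Illustrative sketch (assumption, not from the original sources): minimal
# wiring for cyclgan_train above. The function expects a TensorBoard
# SummaryWriter plus a mutable 'train_steps' counter that it increments once
# per epoch; the option names mirror those used in main() above.
from torch.utils.tensorboard import SummaryWriter

def make_writer_dict(log_path, start_step=0):
    """Build the dict contract consumed by cyclgan_train."""
    return {
        "writer": SummaryWriter(log_path),   # scalars are logged under Train/* tags
        "train_steps": start_step,           # incremented at the end of each epoch
    }

# usage (assuming opt, cycle_gan and dataset are built as in main() above):
# writer_dict = make_writer_dict(opt.path_helper['log_path'])
# cyclgan_train(opt, cycle_gan, dataset, writer_dict)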
class DrawingCanvas:
    def __init__(self):
        opt = TestOptions().parse()  # get test options

        # init pygame
        pygame.init()
        self.size = (256, 256)
        self.screen = pygame.display.set_mode(self.size)
        self.font = pygame.font.SysFont(pygame.font.get_fonts()[0], 64)
        self.time = pygame.time.get_ticks()
        #self.surface_test = pygame.surfarray.make_surface()
        self.screen.fill(pygame.Color(255, 255, 255))
        pygame.display.flip()

        self.model = CycleGANModel(opt)
        self.model.setup(opt)
        #norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
        #net = ResnetGenerator(256, 256, 64, norm_layer=norm_layer, use_dropout=False, n_blocks=9)
        #self.net = init_net(net, 'normal', 0.02, [])

        impath = os.getcwd() + "/datasets/bird/testA/514.png"
        image = pygame.image.load(impath)
        #self.screen.blit(image, (0, 0))

    """
    Method 'game_loop' is executed every frame to drive the display and the
    handling of events in the background. In Processing this is done behind
    the screen. Don't change this unless you know what you are doing.
    """

    def game_loop(self):
        current_time = pygame.time.get_ticks()
        delta_time = current_time - self.time
        self.time = current_time

        self.handle_events()
        self.update_game(delta_time)
        self.draw_components()

    """
    Method 'update_game' updates the state of variables and objects from
    frame to frame.
    """

    def update_game(self, dt):
        pass

    """
    Method 'draw_components' contains everything that draws one frame. It is
    similar to void draw() in Processing. Put all draw calls here and leave
    all updates in method 'update_game'.
    """

    def draw_components(self):
        #self.screen.fill([255, 255, 255])
        #pygame.display.flip()
        pass

    def reset(self):
        pass

    """
    Method 'handle_events' loops over all the event types and handles them
    accordingly. In Processing this is done behind the screen. Don't change
    this unless you know what you are doing.
    """

    def handle_events(self):
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                sys.exit()
            if event.type == pygame.KEYDOWN:
                self.handle_key_down(event)
            if event.type == pygame.KEYUP:
                self.handle_key_up(event)
            if event.type == pygame.MOUSEMOTION:
                self.handle_mouse_motion(event)
            if event.type == pygame.MOUSEBUTTONDOWN:
                self.handle_mouse_pressed(event)
            if event.type == pygame.MOUSEBUTTONUP:
                self.handle_mouse_released(event)

    """
    This method stores currently pressed buttons in the list
    'keyboard_handler.pressed'.
    """

    def handle_key_down(self, event):
        pass

    """
    This method removes a released button from the list
    'keyboard_handler.pressed'.
""" def handle_key_up(self, event): pass """ Similar to void mouseMoved() in Processing """ def handle_mouse_motion(self, event): #print("test: ",pygame.mouse.get_pressed()[0]) if pygame.mouse.get_pressed()[0]: pos = pygame.mouse.get_pos() pygame.display.update( pygame.draw.ellipse(self.screen, (0, 0, 0), [pos, [5, 5]])) #print(pos) self.screen.blit(self.screen, (0, 0)) """ Similar to void mousePressed() in Processing """ def handle_mouse_pressed(self, event): pos = pygame.mouse.get_pos() pygame.display.update( pygame.draw.rect(self.screen, (0, 0, 0), [pos, [5, 5]])) #(pos) self.screen.blit(self.screen, (0, 0)) """ Similar to void mouseReleased() in Processing """ def handle_mouse_released(self, event): #pygame.display.flip() test = pygame.surfarray.array3d(self.screen) print(test.shape) #test = test.T test = test.transpose(1, 0, 2) print(test.shape) #string_image = pygame.image.tostring(self.screen, 'RGBA') #temp_surf = pygame.image.fromstring(string_image, (512, 512), 'RGB') #tmp_arr = pygame.surfarray.array2d(temp_surf) compose = transforms.Compose([ transforms.ToPILImage(), #transforms.Resize(256, interpolation=Image.CUBIC), transforms.ToTensor() ]) test_tensor = compose(test).unsqueeze(0) #plt.figure() #plt.imshow(test) #plt.show() print(test_tensor.size()) #test = compose(test) #self.net.set_input(test) self.model.set_input(test_tensor) result = self.model.forward() result = self.model.get_generated() print("Result", result) resultT = result.squeeze(0) resultT[resultT < 0] = 0 im = transforms.ToPILImage()(resultT).convert("RGB") test = result.squeeze(0) print(test.size()) result = result.detach().numpy() result = np.squeeze(result, axis=0) result = result.transpose(1, 2, 0) print(result.shape) results = result[:] * 255 #results[result < 0] print(im) print(im.size) plt.imshow(im) plt.show() def lab2rgb(self, L, AB): """Convert an Lab tensor image to a RGB numpy output Parameters: L (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array) AB (2-channel tensor array): ab channel images (range: [-1, 1], torch tensor array) Returns: rgb (RGB numpy image): rgb output images (range: [0, 255], numpy array) """ AB2 = AB * 110.0 L2 = (L + 1.0) * 50.0 Lab = torch.cat([L2, AB2], dim=1) Lab = Lab[0].data.cpu().float().numpy() Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0)) rgb = color.lab2rgb(Lab) * 255 return rgb
from models.cycle_gan_with_distillation import CycleGANModelWithDistillation
from models.cycle_gan_model import CycleGANModel

if __name__ == '__main__':
    opt = TrainOptions().parse()  # get test options
    # hard-code some parameters for test
    opt.num_threads = 0        # test code only supports num_threads = 0
    opt.batch_size = 1         # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True         # no flip; comment this line if results on flipped images are needed.
    opt.display_id = -1        # no visdom display; the test code saves the results to a HTML file.

    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options

    opt2 = Namespace(**vars(opt))
    opt2.name = 'monet2photo_pretrained'
    opt2.isTrain = False
    teacher = CycleGANModel(opt2)
    teacher.isTeacher = True
    opt.continue_train = True
    opt2.continue_train = True
    teacher.setup(opt2)  # regular setup: load and print networks; create schedulers

    opt.netG = 'resnet_3blocks'
    opt.results_dir = 'results'
    model = CycleGANModelWithDistillation(opt, teacher)  # create a model given opt.model and other options
    model.setup(opt)  # regular setup: load and print networks; create schedulers

    # create a website
    web_dir = os.path.join(opt.results_dir, opt.name,
                           '{}_{}'.format(opt.phase, opt.epoch))  # define the website directory
    if opt.load_iter > 0:  # load_iter is 0 by default
import time

import numpy as np
import torch
from visdom import Visdom

from config import TrainOptions
from models.cycle_gan_model import CycleGANModel
from utilSet.visualizer import Visualizer, save_opt
from data.dataset import DataLoader

viz = Visdom()
assert viz.check_connection()
viz.close()

opt = TrainOptions().parse()
save_opt(opt)
data_loader = DataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
model = CycleGANModel()
model.initialize(opt)
visualizer = Visualizer(opt)

if __name__ == '__main__':
    total_steps = 0
    sparse_c_loss_points, sparse_c_loss_avr_points = [], []
    win_sparse_C = viz.line(X=torch.zeros((1, )),
                            Y=torch.zeros((1, )),
                            name="win_sparse_C")

    for epoch in range(1, opt.epoch + 1):
        epoch_start_time = time.time()
        epoch_iter = 0
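# Illustrative sketch (assumption, not from the original sources): the script
# above creates the 'win_sparse_C' Visdom window but the excerpt ends before
# it is updated. Appending points to an existing window uses the standard
# visdom update='append' pattern shown below; the loss values here are dummy
# placeholders, not the fork's real sparse_C loss.
import torch
from visdom import Visdom

def demo_append_line(num_steps=10):
    viz = Visdom()
    win = viz.line(X=torch.zeros(1), Y=torch.zeros(1), name="demo")
    for step in range(1, num_steps + 1):
        fake_loss = 1.0 / step                      # placeholder value
        viz.line(X=torch.tensor([float(step)]),
                 Y=torch.tensor([fake_loss]),
                 win=win,
                 name="demo",
                 update='append')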
import os
import ntpath

from utilSet import html
from models.cycle_gan_model import CycleGANModel
from utilSet.visualizer import Visualizer
from config import TestOptions
from data.dataset import DataLoader

opt = TestOptions().parse()
opt.nThreads = 1           # test code only supports nThreads = 1
opt.batchSize = 1          # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True         # no flip

data_loader = DataLoader(opt)
dataset = data_loader.load_data()
model = CycleGANModel()
model.initialize(opt)
visualizer = Visualizer(opt)

if __name__ == '__main__':
    root_dir = os.path.join(opt.result_root_dir, opt.variable)
    web_dir = os.path.join(root_dir, opt.variable_value, opt.phase)
    webpage = html.HTML(web_dir,
                        'Experiment = GAN2C, Phase = test, Epoch = latest')

    # test
    for i, data in enumerate(dataset):
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()
        img_path = model.get_image_paths()
def main():
    opt = SearchOptions().parse()
    torch.cuda.manual_seed(12345)

    _init_inception(MODEL_DIR)
    inception_path = check_or_download_inception(None)
    create_inception_graph(inception_path)

    start_search_iter = 0
    cur_stage = 1

    delta_grow_steps = [int(opt.grow_step ** i) for i in range(1, opt.max_skip_num)] + \
                       [int(opt.grow_step ** 3) for _ in range(1, opt.n_resnet - opt.max_skip_num + 1)]

    opt.max_search_iter = sum(delta_grow_steps)
    grow_steps = [
        sum(delta_grow_steps[:i]) for i in range(len(delta_grow_steps))
    ][1:]

    grow_ctrler = GrowCtrler(opt.grow_step, steps=grow_steps)

    if opt.load_path:
        print(f'=> resuming from {opt.load_path}')
        assert os.path.exists(opt.load_path)
        checkpoint_file = os.path.join(opt.load_path, 'Model', 'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file, map_location={'cuda:0': 'cpu'})

        # set controller && its optimizer
        cur_stage = checkpoint['cur_stage']
        start_search_iter = checkpoint["search_iter"]
        opt.path_helper = checkpoint['path_helper']

        cycle_gan = CycleGANModel(opt)
        cycle_gan.setup(opt)

        cycle_controller = CycleControllerModel(opt, cur_stage=cur_stage)
        cycle_controller.setup(opt)
        cycle_controller.set(cycle_gan)

        cycle_gan.load_from_state(checkpoint["cycle_gan"])
        cycle_controller.load_from_state(checkpoint["cycle_controller"])
    else:
        opt.path_helper = set_log_dir(opt.checkpoints_dir, opt.name)

        cycle_gan = CycleGANModel(opt)
        cycle_gan.setup(opt)

        cycle_controller = CycleControllerModel(opt, cur_stage=cur_stage)
        cycle_controller.setup(opt)
        cycle_controller.set(cycle_gan)

    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    print('The number of training images = %d' % len(dataset))

    writer_dict = {
        "writer": SummaryWriter(opt.path_helper['log_path']),
        'controller_steps': start_search_iter * opt.ctrl_step,
        'train_steps': start_search_iter * opt.shared_epoch
    }

    g_loss_history = RunningStats(opt.dynamic_reset_window)
    d_loss_history = RunningStats(opt.dynamic_reset_window)

    dynamic_reset = None
    for search_iter in tqdm(range(int(start_search_iter), int(opt.max_search_iter))):
        tqdm.write(f"<start search iteration {search_iter}>")
        cycle_controller.reset()

        if search_iter in grow_steps:
            cur_stage = grow_ctrler.cur_stage(search_iter) + 1
            tqdm.write(f'=> grow to stage {cur_stage}')

            prev_archs_A, prev_hiddens_A = cycle_controller.get_topk_arch_hidden_A()
            prev_archs_B, prev_hiddens_B = cycle_controller.get_topk_arch_hidden_B()

            del cycle_controller

            cycle_controller = CycleControllerModel(opt, cur_stage)
            cycle_controller.setup(opt)
            cycle_controller.set(cycle_gan, prev_hiddens_A, prev_hiddens_B,
                                 prev_archs_A, prev_archs_B)

        dynamic_reset = cyclgan_train(opt, cycle_gan, cycle_controller, dataset,
                                      g_loss_history, d_loss_history, writer_dict)

        controller_train(opt, cycle_gan, cycle_controller, writer_dict)

        if dynamic_reset:
            tqdm.write('re-initialize share GAN')
            del cycle_gan
            cycle_gan = CycleGANModel(opt)
            cycle_gan.setup(opt)

        save_checkpoint(
            {
                'cur_stage': cur_stage,
                'search_iter': search_iter + 1,
                'cycle_gan': cycle_gan.save_networks(epoch=search_iter),
                'cycle_controller': cycle_controller.save_networks(epoch=search_iter),
                'path_helper': opt.path_helper
            }, False, opt.path_helper['ckpt_path'])

    final_archs_A, _ = cycle_controller.get_topk_arch_hidden_A()
    final_archs_B, _ = cycle_controller.get_topk_arch_hidden_B()
    print(f"discovered archs: {final_archs_A}")
    print(f"discovered archs: {final_archs_B}")
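# Illustrative sketch (assumption, not from the original sources): a worked
# example of the grow schedule computed in main() above, using hypothetical
# values grow_step=3, max_skip_num=3, n_resnet=5 to make the arithmetic concrete.
def demo_grow_schedule(grow_step=3, max_skip_num=3, n_resnet=5):
    delta_grow_steps = [int(grow_step ** i) for i in range(1, max_skip_num)] + \
                       [int(grow_step ** 3) for _ in range(1, n_resnet - max_skip_num + 1)]
    # delta_grow_steps == [3, 9, 27, 27]; their sum (66) is the total search budget
    max_search_iter = sum(delta_grow_steps)
    # cumulative offsets, skipping the leading 0: the search iterations at which
    # the controller grows to the next stage
    grow_steps = [sum(delta_grow_steps[:i]) for i in range(len(delta_grow_steps))][1:]
    # grow_steps == [3, 12, 39]
    return max_search_iter, grow_steps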
def cyclgan_train(opt, cycle_gan: CycleGANModel,
                  cycle_controller: CycleControllerModel, train_loader,
                  g_loss_history: RunningStats, d_loss_history: RunningStats,
                  writer_dict):
    cycle_gan.train()
    cycle_controller.eval()

    dynamic_reset = False
    writer = writer_dict['writer']
    total_iters = 0
    t_data = 0.0

    for epoch in range(opt.shared_epoch):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0
        train_steps = writer_dict['train_steps']

        for i, data in enumerate(train_loader):
            iter_start_time = time.time()
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time

            total_iters += opt.batch_size
            epoch_iter += opt.batch_size

            cycle_controller.forward()
            cycle_gan.set_input(data)
            cycle_gan.optimize_parameters()

            g_loss_history.push(cycle_gan.loss_G.item())
            d_loss_history.push(cycle_gan.loss_D_A.item() + cycle_gan.loss_D_B.item())

            if (i + 1) % opt.print_freq == 0:
                losses = cycle_gan.get_current_losses()
                t_comp = (time.time() - iter_start_time)
                message = "GAN: [Ep: %d/%d]" % (epoch, opt.shared_epoch)
                message += "[Batch: %d/%d][time: %.3f][data: %.3f]" % (
                    epoch_iter, len(train_loader), t_comp, t_data)
                for k, v in losses.items():
                    message += '[%s: %.3f]' % (k, v)
                tqdm.write(message)

            if (total_iters + 1) % opt.display_freq == 0:
                cycle_gan.compute_visuals()
                save_current_results(opt, cycle_gan.get_current_visuals(), train_steps)

            if g_loss_history.is_full():
                if g_loss_history.get_var() < opt.dynamic_reset_threshold \
                        or d_loss_history.get_var() < opt.dynamic_reset_threshold:
                    dynamic_reset = True
                    tqdm.write("=> dynamic resetting triggered")
                    g_loss_history.clear()
                    d_loss_history.clear()
                    return dynamic_reset

            if (total_iters + 1) % opt.save_latest_freq == 0:
                # cache our latest model every <save_latest_freq> iterations
                tqdm.write('saving the latest model (epoch %d, total_iters %d)' %
                           (epoch, total_iters))
                save_suffix = 'latest'
                # cycle_gan.save_networks(train_steps)

            iter_data_time = time.time()

        if (epoch + 1) % opt.save_epoch_freq == 0:
            cycle_gan.save_networks('latest')
            # cycle_gan.save_networks(train_steps)

        tqdm.write('End of epoch %d / %d \t Time Taken: %d sec' %
                   (epoch, opt.shared_epoch, time.time() - epoch_start_time))

        writer.add_scalars('Train/discriminator', {
            "A": float(cycle_gan.loss_D_A),
            "B": float(cycle_gan.loss_D_B),
        }, train_steps)
        writer.add_scalars('Train/generator', {
            "A": float(cycle_gan.loss_G_A),
            "B": float(cycle_gan.loss_G_B),
        }, train_steps)
        writer.add_scalars('Train/cycle', {
            "A": float(cycle_gan.loss_cycle_A),
            "B": float(cycle_gan.loss_cycle_B),
        }, train_steps)
        writer.add_scalars('Train/idt', {
            "A": float(cycle_gan.loss_idt_A),
            "B": float(cycle_gan.loss_idt_B),
        }, train_steps)
        writer_dict['train_steps'] += 1

    return dynamic_reset
def controller_train(opt, cycle_gan: CycleGANModel,
                     cycle_controller: CycleControllerModel, writer_dict):
    writer = writer_dict['writer']

    # train mode
    cycle_controller.train()
    # eval mode
    cycle_gan.eval()

    iter_start_time = time.time()
    for i in range(0, opt.ctrl_step):
        controller_step = writer_dict['controller_steps']

        cycle_controller.step_A()
        cycle_controller.step_B()

        if (i + 1) % opt.print_freq_controller == 0:
            losses = cycle_controller.get_current_losses()
            t_comp = (time.time() - iter_start_time)
            iter_start_time = time.time()

            message = "Cont: [Ep: %d/%d]" % (i, opt.ctrl_step) + \
                      "[{}][{}]".format(cycle_controller.arch_A, cycle_controller.arch_B)
            message += "[time: %.3f]" % (t_comp)
            for k, v in losses.items():
                message += '[%s: %.3f]' % (k, v)
            tqdm.write(message)

        # write
        writer.add_scalars('Controller/loss', {
            "A": cycle_controller.loss_A.item(),
            "B": cycle_controller.loss_B.item()
        }, controller_step)
        writer.add_scalars('Controller/discriminator', {
            "A": cycle_controller.loss_D_A.item(),
            "B": cycle_controller.loss_D_B.item()
        }, controller_step)
        writer.add_scalars('Controller/inception_score', {
            "A": cycle_controller.loss_IS_A.item(),
            "B": cycle_controller.loss_IS_B.item()
        }, controller_step)
        writer.add_scalars('Controller/adv', {
            "A": cycle_controller.loss_adv_A,
            "B": cycle_controller.loss_adv_B
        }, controller_step)
        writer.add_scalars('Controller/entropy', {
            "A": cycle_controller.loss_entropy_A,
            "B": cycle_controller.loss_entropy_B
        }, controller_step)
        writer.add_scalars('Controller/reward', {
            "A": cycle_controller.loss_reward_A,
            "B": cycle_controller.loss_reward_B
        }, controller_step)

        writer_dict['controller_steps'] = controller_step + 1
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import time

from options.train_options import TrainOptions
from data import create_dataset
from models.cycle_gan_model import CycleGANModel
from util.util import create_log_txt, print_current_losses

if __name__ == '__main__':
    opt = TrainOptions().parse()   # get training options
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)    # get the number of images in the dataset.

    model = CycleGANModel(opt)     # create a model given opt.model and other options
    model.setup(opt)               # regular setup: load and print networks; create schedulers
    log_file = create_log_txt(opt) # create a log file for training progress
    total_iters = 0                # the total number of training iterations

    for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):
        # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch
        model.update_learning_rate()    # update learning rates in the beginning of every epoch.

        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time()
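# Illustrative sketch (assumption, not from the original sources): the inner
# loop above is cut off right after iter_start_time is taken. Based on the
# calls already used elsewhere in this section (set_input, optimize_parameters,
# get_current_losses, save_networks), one training step typically continues
# roughly as below. `log_losses` is a hypothetical callback, not the fork's
# print_current_losses, whose exact signature is not shown here.
import time

def train_one_step(model, data, opt, total_iters, epoch_iter, log_losses=print):
    iter_start_time = time.time()
    model.set_input(data)            # unpack data from the data loader
    model.optimize_parameters()      # compute losses, gradients, update weights
    total_iters += opt.batch_size
    epoch_iter += opt.batch_size
    if total_iters % opt.print_freq == 0:
        losses = model.get_current_losses()
        t_comp = (time.time() - iter_start_time) / opt.batch_size
        log_losses('iters %d: %s (%.3fs/img)' % (total_iters, losses, t_comp))
    if total_iters % opt.save_latest_freq == 0:
        model.save_networks('latest')  # cache the latest model
    return total_iters, epoch_iter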
    ngf=64,
    no_dropout=True,
    no_flip=True,
    norm="instance",
    ntest=float("inf"),
    num_test=100,
    num_threads=0,
    output_nc=3,
    phase="test",
    preprocess="no_preprocessing",
    results_dir="./results/",
    serial_batches=True,
    suffix="",
    verbose=False,
)

model = CycleGANModel(opt).netG_A
model.load_state_dict(torch.load(model_fp))

preprocess = get_transform(opt)


class SingleImageDataset(torch.utils.data.Dataset):
    def __init__(self, *args, **kwargs):
        img = kwargs.pop("img")
        super().__init__(*args, **kwargs)
        img = preprocess(img)
        self.img = img

    def __getitem__(self, i):
        return self.img
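# Illustrative sketch (assumption, not from the original sources): how the
# SingleImageDataset above might be used to push one image through the loaded
# generator (the netG_A extracted above). It indexes the dataset directly
# because the original excerpt does not define __len__, and it assumes
# `preprocess` returns a C x H x W tensor; the eval/no_grad handling is added
# here for the example.
import torch
from PIL import Image

def translate_single_image(image_path, generator):
    img = Image.open(image_path).convert("RGB")
    dataset = SingleImageDataset(img=img)
    generator.eval()
    with torch.no_grad():
        batch = dataset[0].unsqueeze(0)   # add the batch dimension: 1 x C x H x W
        fake = generator(batch)           # translated output, roughly in [-1, 1]
    return fake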
# -*- coding:utf-8 -*-
import time

from config import TrainOptions
from models.cycle_gan_model import CycleGANModel
from utilSet.visualizer import Visualizer, save_opt
from data.dataset import DataLoader

opt = TrainOptions().parse()
save_opt(opt)
data_loader = DataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
model = CycleGANModel()
model.initialize(opt)
visualizer = Visualizer(opt)

if __name__ == '__main__':
    total_steps = 0
    for epoch in range(1, opt.epoch + 1):
        epoch_start_time = time.time()
        epoch_iter = 0
        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            visualizer.reset()
            total_steps += 1
            epoch_iter += 1
            model.set_input(data)
            model.optimize_parameters()
from options.test_options import TestOptions
from data import create_dataset
from models.cycle_gan_model import CycleGANModel
from util.util import save_images, create_results_dir
from util import html

if __name__ == '__main__':
    opt = TestOptions().parse()  # get test options
    # hard-code some parameters for test
    opt.num_threads = 0        # test code only supports num_threads = 0
    opt.batch_size = 1         # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True         # no flip; comment this line if results on flipped images are needed.

    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = CycleGANModel(opt)     # create a model given opt.model and other options
    model.setup(opt)               # regular setup: load and print networks; create schedulers

    # create results dir
    image_dir = create_results_dir(opt)

    # test with eval mode. This only affects layers like batchnorm and dropout.
    # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.
    if opt.eval:
        model.eval()

    for i, data in enumerate(dataset):
        if i >= opt.num_test:  # only apply our model to opt.num_test images.
            break
        model.set_input(data)   # unpack data from data loader
        model.test()            # run inference
        visuals = model.get_current_visuals()  # get image results
        img_path = model.get_image_paths()     # get image paths
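# Illustrative sketch (assumption, not from the original sources): the test
# loop above ends before the results are written out. This script imports
# save_images from util.util, whose exact signature is not shown here, so the
# fallback below uses plain PIL/numpy to persist each visual instead; the
# [-1, 1] -> [0, 255] conversion mirrors the usual CycleGAN tensor-to-image logic.
import ntpath
import os

import numpy as np
from PIL import Image

def save_visuals_fallback(visuals, img_path, image_dir):
    """Write each (label, tensor) pair in `visuals` next to the input's basename."""
    short_name = os.path.splitext(ntpath.basename(img_path[0]))[0]
    for label, tensor in visuals.items():
        array = tensor[0].detach().cpu().float().numpy()          # C x H x W in [-1, 1]
        array = (np.transpose(array, (1, 2, 0)) + 1.0) / 2.0 * 255.0
        Image.fromarray(array.astype(np.uint8)).save(
            os.path.join(image_dir, f"{short_name}_{label}.png"))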